From eb509f0c584ebae01834e773fb83584102a4f4da Mon Sep 17 00:00:00 2001 From: Guntupalli Venkata Sai Kalyan Date: Tue, 26 May 2020 15:44:10 -0700 Subject: [PATCH 001/707] move tensor's device from gpu to cpu (#2174) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [X] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Partially Fixes https://github.com/pytorch/fairseq/issues/2173 . ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2174 Reviewed By: ngoyal2707 Differential Revision: D21725035 Pulled By: myleott fbshipit-source-id: 16e0a66a104a9713c5d98fee97d3d97261b72c94 --- fairseq/sequence_generator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 4e0ffca210..7ecdde869f 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -833,6 +833,11 @@ def generate(self, models, sample, **kwargs): for i in range(bsz * beam_size) ] + if src_tokens.device != "cpu": + src_tokens = src_tokens.to('cpu') + tgt_tokens = tgt_tokens.to('cpu') + attn = [i.to('cpu') for i in attn] + # Process the attn matrix to extract hard alignments. 
for i in range(bsz * beam_size): alignment = utils.extract_hard_alignment( From 434eedd4ef4f09d352d44dbbf8eb53b4c0ac0a36 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 26 May 2020 15:49:13 -0700 Subject: [PATCH 002/707] correct link in quant_noise (#2184) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2184 Reviewed By: ngoyal2707 Differential Revision: D21725053 Pulled By: myleott fbshipit-source-id: bdb93af6695e96d4b44a58f104adb3c748a5fb40 --- examples/quant_noise/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 3341b98886..98d8c313ee 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -49,7 +49,7 @@ When evaluating a network, all quantized modules and activation hooks automatica #### Integration with your own code Looking to quantize your own models with Quant-Noise + Scalar Quantization? -- Use the function `quantize_model_` implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar/utils) to (1) replace all your modules by their quantized counterparts and (2) add hooks to those modules to quantize the activations. +- Use the function `quantize_model_` implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar/utils.py) to (1) replace all your modules by their quantized counterparts and (2) add hooks to those modules to quantize the activations. - Then, perform your training as usual. Note that in `eval()` mode, the network is always fully quantized (weights and activations) by default (`p=1`). From 145bc9de1278414812b2aef837be9ca0e9c1aebc Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 26 May 2020 15:55:53 -0700 Subject: [PATCH 003/707] Several small fixes (incl. 
set default --data-buffer-size=10) (#2163) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2163 Reviewed By: ngoyal2707 Differential Revision: D21665601 Pulled By: myleott fbshipit-source-id: 47673ff7f07acf0002c4e28380aa08ff917618ee --- README.md | 95 +++++++++++++------------- fairseq/benchmark/dummy_lm.py | 9 ++- fairseq/benchmark/dummy_masked_lm.py | 9 ++- fairseq/checkpoint_utils.py | 4 +- fairseq/data/encoders/fastbpe.py | 2 +- fairseq/models/fairseq_model.py | 8 +-- fairseq/models/huggingface/__init__.py | 16 ++++- fairseq/models/huggingface/hf_gpt2.py | 20 ++++++ fairseq/options.py | 2 +- fairseq_cli/train.py | 47 ++++++------- 10 files changed, 121 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index 8542791c2f..a3248d418a 100644 --- a/README.md +++ b/README.md @@ -13,30 +13,10 @@ Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks. 
+We provide reference implementations of various sequence modeling papers: -### What's New: - - -- April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) -- April 2020: [Quant-Noise code released](examples/quant_noise/README.md) -- April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) -- March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) -- February 2020: [mBART model and code released](examples/mbart/README.md) -- February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) -- December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) -- November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) -- November 2019: [CamemBERT model and code released](examples/camembert/README.md) -- November 2019: [BART model and code released](examples/bart/README.md) -- November 2019: [XLM-R models and code released](examples/xlmr/README.md) -- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) -- August 2019: [WMT'19 models released](examples/wmt19/README.md) -- July 2019: fairseq relicensed under MIT license -- July 2019: [RoBERTa models and code released](examples/roberta/README.md) -- June 2019: [wav2vec models and code released](examples/wav2vec/README.md) - -### Features: +
<details><summary>List of implemented papers</summary><p>

-Fairseq provides reference implementations of various sequence-to-sequence models, including: - **Convolutional Neural Networks (CNN)** - [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) - [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) @@ -65,9 +45,35 @@ Fairseq provides reference implementations of various sequence-to-sequence model - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +
</p></details>
+ +### What's New: + +- May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) +- April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) +- April 2020: [Quant-Noise code released](examples/quant_noise/README.md) +- April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +- March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) +
<details><summary>Previous updates</summary><p>

-**Additionally:** -- multi-GPU (distributed) training on one machine or across multiple machines +- February 2020: [mBART model and code released](examples/mbart/README.md) +- February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) +- December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) +- November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) +- November 2019: [CamemBERT model and code released](examples/camembert/README.md) +- November 2019: [BART model and code released](examples/bart/README.md) +- November 2019: [XLM-R models and code released](examples/xlmr/README.md) +- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) +- August 2019: [WMT'19 models released](examples/wmt19/README.md) +- July 2019: fairseq relicensed under MIT license +- July 2019: [RoBERTa models and code released](examples/roberta/README.md) +- June 2019: [wav2vec models and code released](examples/wav2vec/README.md) + +
</p></details>
+ +### Features: + +- multi-GPU training on one machine or across multiple machines (data and model parallel) - fast generation on both CPU and GPU with multiple search algorithms implemented: - beam search - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) @@ -86,41 +92,32 @@ en2de.translate('Hello world', beam=5) See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. -![Model](fairseq.gif) - # Requirements and Installation * [PyTorch](http://pytorch.org/) version >= 1.4.0 * Python version >= 3.6 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) -* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: +* **To install fairseq** and develop locally: ```bash -git clone https://github.com/NVIDIA/apex -cd apex -pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./ -``` +git clone https://github.com/pytorch/fairseq +cd fairseq +pip install --editable ./ -To install fairseq: -```bash -pip install fairseq +# on MacOS: +# CFLAGS="-stdlib=libc++" pip install --editable ./ ``` - -On MacOS: +* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: ```bash -CFLAGS="-stdlib=libc++" pip install fairseq +git clone https://github.com/NVIDIA/apex +cd apex +pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ + --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ + --global-option="--fast_multihead_attn" ./ ``` - -If you use Docker make sure to increase the shared memory size either with +* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install 
pyarrow` +* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`. -**Installing from source** - -To install fairseq from source and develop locally: -```bash -git clone https://github.com/pytorch/fairseq -cd fairseq -pip install --editable . -``` # Getting Started @@ -135,11 +132,11 @@ as well as example training and evaluation commands. - [Translation](examples/translation/README.md): convolutional and transformer models are available - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available -- [wav2vec](examples/wav2vec/README.md): wav2vec large model is available We also have more detailed READMEs to reproduce results from specific papers: - [Training with Quantization Noise for Extreme Model Compression](examples/quant_noise/README.md) - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) +- [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) @@ -156,10 +153,12 @@ We also have more detailed READMEs to reproduce results from specific papers: # Join the fairseq community +* Twitter: https://twitter.com/fairseq * Facebook page: https://www.facebook.com/groups/fairseq.users * Google group: https://groups.google.com/forum/#!forum/fairseq-users # License + fairseq(-py) is MIT-licensed. The license applies to the pre-trained models as well. 
diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index 710eb67437..92e9dc8df5 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -21,7 +21,7 @@ class DummyLMTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('--dict-size', default=50000, type=int) + parser.add_argument('--dict-size', default=49996, type=int) parser.add_argument('--dataset-size', default=100000, type=int) parser.add_argument('--tokens-per-sample', default=512, type=int, help='max number of total tokens over all segments ' @@ -32,6 +32,8 @@ def __init__(self, args, dictionary): self.dictionary = dictionary self.seed = args.seed + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1 self.dummy_src = seq[:-1] @@ -51,7 +53,10 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - bsz = self.args.max_sentences + if self.args.max_sentences is not None: + bsz = self.args.max_sentences + else: + bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) self.datasets[split] = DummyDataset( { 'id': 1, diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index bacd0e8acc..f2e459caa2 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -21,7 +21,7 @@ class DummyMaskedLMTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('--dict-size', default=50000, type=int) + parser.add_argument('--dict-size', default=49995, type=int) parser.add_argument('--dataset-size', default=100000, type=int) parser.add_argument('--tokens-per-sample', default=512, type=int, help='max number of total tokens over all segments ' @@ -34,7 +34,7 @@ def __init__(self, args, dictionary): # add 
mask token self.mask_idx = dictionary.add_symbol('<mask>') - assert len(dictionary) % 8 == 0 + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 mask_idx = 0 pad_idx = 1 @@ -62,7 +62,10 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - bsz = self.args.max_sentences + if self.args.max_sentences is not None: + bsz = self.args.max_sentences + else: + bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) self.datasets[split] = DummyDataset( { 'id': 1, diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index ecda632264..fe25b0a9dd 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -400,7 +400,7 @@ def create_pruning_pass(layers_to_keep, layer_name): for i in range(len(keep_layers)): mapping_dict[str(keep_layers[i])] = str(i) - regex = re.compile("^{layer}.*\.layers\.(\d+)".format(layer=layer_name)) + regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)) return {"substitution_regex": regex, "mapping_dict": mapping_dict} pruning_passes = [] @@ -411,7 +411,7 @@ def create_pruning_pass(layers_to_keep, layer_name): new_state_dict = {} for layer_name in state_dict.keys(): - match = re.search("\.layers\.(\d+)\.", layer_name) + match = re.search(r"\.layers\.(\d+)\.", layer_name) # if layer has no number in it, it is a supporting layer, such as an # embedding if not match: diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py index 376e22cd85..ea0badd544 100644 --- a/fairseq/data/encoders/fastbpe.py +++ b/fairseq/data/encoders/fastbpe.py @@ -19,7 +19,7 @@ def add_args(parser): def __init__(self, args): if args.bpe_codes is None: - raise ValueError('--bpe-codes is required for --bpe=subword_nmt') + raise ValueError('--bpe-codes is required for --bpe=fastbpe') codes = file_utils.cached_path(args.bpe_codes) try: import fastBPE diff --git a/fairseq/models/fairseq_model.py 
b/fairseq/models/fairseq_model.py index c9590b87a3..7f9b731ef6 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -375,13 +375,7 @@ def build_shared_embeddings( return build_embedding(shared_dict, embed_dim, pretrained_embed_path) def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): - decoder_outs = {} - for key in self.keys: - encoder_out = self.models[key].encoder(src_tokens, src_lengths, **kwargs) - decoder_outs[key] = self.models[key].decoder( - prev_output_tokens, encoder_out, **kwargs - ) - return decoder_outs + raise NotImplementedError def max_positions(self): """Maximum length supported by the model.""" diff --git a/fairseq/models/huggingface/__init__.py b/fairseq/models/huggingface/__init__.py index e186c4b196..633315f54d 100644 --- a/fairseq/models/huggingface/__init__.py +++ b/fairseq/models/huggingface/__init__.py @@ -3,4 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .hf_gpt2 import * # noqa +import importlib +import os + + +# automatically import any Python files in the models/huggingface/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith('_') + and not file.startswith('.') + and (file.endswith('.py') or os.path.isdir(path)) + ): + model_name = file[:file.find('.py')] if file.endswith('.py') else file + module = importlib.import_module('fairseq.models.huggingface.' + model_name) diff --git a/fairseq/models/huggingface/hf_gpt2.py b/fairseq/models/huggingface/hf_gpt2.py index 4107113e81..6a03406ef6 100644 --- a/fairseq/models/huggingface/hf_gpt2.py +++ b/fairseq/models/huggingface/hf_gpt2.py @@ -16,6 +16,18 @@ register_model_architecture, ) +try: + # Prepend the transformers submodule to the path, so that + # it's prioritized over other installations. 
This allows + # making local changes in the submodule. + sys.path.insert( + 0, os.path.join(os.path.dirname(__file__), 'transformers', 'src') + ) + from transformers import AutoModel, GPT2Config, GPT2LMHeadModel + has_hf = True +except ImportError: + has_hf = False + logger = logging.getLogger(__name__) @@ -28,6 +40,14 @@ class HuggingFaceGPT2LanguageModel(FairseqLanguageModel): def __init__(self, decoder): super().__init__(decoder) + if not has_hf: + raise ImportError( + '\n\nPlease install huggingface/transformers with:' + '\n\n pip install transformers' + '\n\nOr to make local edits, install the submodule:' + '\n\n git submodule update --init ' + 'fairseq/models/huggingface/transformers' + ) @staticmethod def add_args(parser): diff --git a/fairseq/options.py b/fairseq/options.py index ad7d85e54e..f3ea0bc52f 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -329,7 +329,7 @@ def add_dataset_args(parser, train=False, gen=False): parser.add_argument('--dataset-impl', metavar='FORMAT', choices=get_available_dataset_impl(), help='output dataset implementation') - group.add_argument('--data-buffer-size', default=0, type=int, metavar='N', + group.add_argument('--data-buffer-size', default=10, type=int, metavar='N', help='Number of batches to preload') if train: group.add_argument('--train-subset', default='train', metavar='SPLIT', diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index ee08d6febd..869b913230 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -109,7 +109,6 @@ def main(args, init_distributed=False): # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf - max_update = args.max_update or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() @@ -118,8 +117,8 @@ def main(args, init_distributed=False): and epoch_itr.next_epoch_idx <= max_epoch ): # train for one epoch - valid_losses = train(args, trainer, task, epoch_itr, max_update) - if should_stop_early(args, 
valid_losses[0]) or trainer.get_num_updates() >= max_update: + valid_losses, should_stop = train(args, trainer, task, epoch_itr) + if should_stop: break # only use first validation loss to update the learning rate @@ -172,7 +171,7 @@ def tpu_data_loader(args, itr): @metrics.aggregate('train') -def train(args, trainer, task, epoch_itr, max_update=math.inf): +def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( @@ -201,6 +200,7 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): trainer.begin_epoch(epoch_itr.epoch) valid_subsets = args.valid_subset.split(',') + should_stop = False for samples in progress: with metrics.aggregate('train_inner'): log_output = trainer.train_step(samples) @@ -218,10 +218,10 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): metrics.reset_meters('train_inner') end_of_epoch = not itr.has_next() - valid_losses = validate_and_save( + valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) - if should_stop_early(args, valid_losses[0]) or num_updates >= max_update: + if should_stop: break # log end-of-epoch stats @@ -230,7 +230,7 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): # reset epoch-level meters metrics.reset_meters('train') - return valid_losses + return valid_losses, should_stop def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): @@ -245,7 +245,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc ) do_validate = ( ( - do_save # saving requires validation + (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) ) and not args.disable_validation @@ -255,10 +255,19 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc valid_losses = [None] if 
do_validate: valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) - # Save - if do_save: + + # Stopping conditions + max_update = args.max_update or math.inf + should_stop = ( + should_stop_early(args, valid_losses[0]) + or trainer.get_num_updates() >= max_update + ) + + # Save checkpoint + if do_save or should_stop: checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) - return valid_losses + + return valid_losses, should_stop def get_training_stats(stats): @@ -276,21 +285,7 @@ def validate(args, trainer, task, epoch_itr, subsets): valid_losses = [] for subset in subsets: # Initialize data iterator - itr = task.get_batch_iterator( - dataset=task.dataset(subset), - max_tokens=args.max_tokens_valid, - max_sentences=args.max_sentences_valid, - max_positions=utils.resolve_max_positions( - task.max_positions(), - trainer.get_model().max_positions(), - ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - seed=args.seed, - num_shards=args.distributed_world_size, - shard_id=args.distributed_rank, - num_workers=args.num_workers, - ).next_epoch_itr(shuffle=False) + itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) if getattr(args, 'tpu', False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( From 95294bfbb627c7ba140e73ac27c8e98012045916 Mon Sep 17 00:00:00 2001 From: Xu Song Date: Wed, 27 May 2020 06:23:49 -0700 Subject: [PATCH 004/707] refactor superclass of MaskedLMModel (#2170) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: masked_lm is actually encoder-only model # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? 
- [ ] Did you write any new necessary tests? ## What does this PR do? Class Refactor ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2170 Reviewed By: ngoyal2707 Differential Revision: D21725071 Pulled By: myleott fbshipit-source-id: 75fd36008f3e3425f8f5180472734394046dfb77 --- fairseq/models/masked_lm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/fairseq/models/masked_lm.py b/fairseq/models/masked_lm.py index 0937059a4a..1cc8afcf23 100644 --- a/fairseq/models/masked_lm.py +++ b/fairseq/models/masked_lm.py @@ -11,7 +11,7 @@ from fairseq import utils from fairseq.models import ( - BaseFairseqModel, + FairseqEncoderModel, FairseqEncoder, register_model, register_model_architecture, @@ -28,15 +28,14 @@ @register_model('masked_lm') -class MaskedLMModel(BaseFairseqModel): +class MaskedLMModel(FairseqEncoderModel): """ Class for training a Masked Language Model. It also supports an additional sentence level prediction if the sent-loss argument is set. """ def __init__(self, args, encoder): - super().__init__() + super().__init__(encoder) self.args = args - self.encoder = encoder # if specified then apply bert initialization on the model. We need # to explictly call this to make sure that the output embeddings From 5453e4355b274645074d0068f668ac5bcea9905c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 27 May 2020 07:48:21 -0700 Subject: [PATCH 005/707] =?UTF-8?q?Avoid=20NaN=20in=20speech=5Frecognition?= =?UTF-8?q?=20with=20input=20having=20only=201=20spec=E2=80=A6=20(#1864)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …trogram # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? 
- [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/1863. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/1864 Reviewed By: yqwangustc Differential Revision: D21663642 Pulled By: myleott fbshipit-source-id: f411c5c01c7505375bec6d47554e85fb70877e9c --- .../speech_recognition/data/data_utils.py | 4 ++++ tests/speech_recognition/test_data_utils.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 tests/speech_recognition/test_data_utils.py diff --git a/examples/speech_recognition/data/data_utils.py b/examples/speech_recognition/data/data_utils.py index 03c41f47d9..cc4729e63c 100644 --- a/examples/speech_recognition/data/data_utils.py +++ b/examples/speech_recognition/data/data_utils.py @@ -19,6 +19,10 @@ def calc_mean_invstddev(feature): def apply_mv_norm(features): + # If there is less than 2 spectrograms, the variance cannot be computed (is NaN) + # and normalization is not possible, so return the item as it is + if features.size(0) < 2: + return features mean, invstddev = calc_mean_invstddev(features) res = (features - mean) * invstddev return res diff --git a/tests/speech_recognition/test_data_utils.py b/tests/speech_recognition/test_data_utils.py new file mode 100644 index 0000000000..5ca7c5c2a1 --- /dev/null +++ b/tests/speech_recognition/test_data_utils.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import unittest + +import torch + +from examples.speech_recognition.data import data_utils + + +class DataUtilsTest(unittest.TestCase): + + def test_normalization(self): + sample_len1 = torch.tensor([[-0.7661, -1.3889, -2.0972, -0.9134, -0.7071, -0.9765, -0.8700, -0.8283, + 0.7512, 1.3211, 2.1532, 2.1174, 1.2800, 1.2633, 1.6147, 1.6322, + 2.0723, 3.1522, 3.2852, 2.2309, 2.5569, 2.2183, 2.2862, 1.5886, + 0.8773, 0.8725, 1.2662, 0.9899, 1.1069, 1.3926, 1.2795, 1.1199, + 1.1477, 1.2687, 1.3843, 1.1903, 0.8355, 1.1367, 1.2639, 1.4707]]) + out = data_utils.apply_mv_norm(sample_len1) + assert not torch.isnan(out).any() + assert (out == sample_len1).all() From be5313acc7856e69a4d91aa804bca254cc2886c2 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 27 May 2020 09:56:08 -0700 Subject: [PATCH 006/707] Add bart.base to README (fixes #2189) (#2190) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2190 Reviewed By: ngoyal2707 Differential Revision: D21742525 Pulled By: myleott fbshipit-source-id: fa29e3e36eb136a337fb1277cad3996ae2b22546 --- examples/bart/README.md | 1 + fairseq/models/bart/model.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/bart/README.md b/examples/bart/README.md index 9ca210a401..027e2f1ef1 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -10,6 +10,7 @@ BART is sequence-to-sequence model trained with denoising as pretraining objecti Model | Description | # params | Download ---|---|---|--- +`bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz) `bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz) `bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz) `bart.large.cnn` | `bart.large` finetuned on 
`CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz) diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 389890a31c..62c495cb64 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -32,6 +32,7 @@ class BARTModel(TransformerModel): @classmethod def hub_models(cls): return { + 'bart.base': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz', 'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz', 'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz', 'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz', From 2f7e3f33235b787de2e34123d25f659e34a21558 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 27 May 2020 10:21:49 -0700 Subject: [PATCH 007/707] Support multi-GPU validation in fairseq-validate (#2162) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2162 Reviewed By: ngoyal2707 Differential Revision: D21663181 Pulled By: myleott fbshipit-source-id: d01e64f97482f76bd601cd8b20232c0ef637bb8a --- fairseq/criterions/adaptive_loss.py | 2 +- fairseq/options.py | 1 + fairseq_cli/validate.py | 17 ++++++++++++++--- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 33e9317e84..1916131bb1 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -23,7 +23,7 @@ def __init__(self, task, sentence_avg): @classmethod def build_criterion(cls, args, task): - if args.ddp_backend == 'c10d': + if getattr(args, 'ddp_backend', None) == 'c10d': raise Exception( 'AdaptiveLoss is not compatible with the c10d ' 'version of DistributedDataParallel. 
Please use ' diff --git a/fairseq/options.py b/fairseq/options.py index f3ea0bc52f..52c8a96129 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -53,6 +53,7 @@ def get_eval_lm_parser(default_task="language_modeling"): def get_validation_parser(default_task=None): parser = get_parser("Validation", default_task) add_dataset_args(parser, train=True) + add_distributed_training_args(parser) group = parser.add_argument_group("Evaluation") add_common_eval_args(group) return parser diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 580ac7b8b3..b339a056a0 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -5,6 +5,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from itertools import chain import logging import sys @@ -12,7 +13,7 @@ from fairseq import checkpoint_utils, distributed_utils, options, utils from fairseq.logging import metrics, progress_bar -from fairseq.options import add_distributed_training_args + logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', @@ -32,6 +33,9 @@ def main(args, override_args=None): use_fp16 = args.fp16 use_cuda = torch.cuda.is_available() and not args.cpu + if use_cuda: + torch.cuda.set_device(args.device_id) + if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, 'model_overrides', '{}'))) @@ -80,6 +84,8 @@ def main(args, override_args=None): ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, + num_shards=args.distributed_world_size, + shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( @@ -97,6 +103,13 @@ def main(args, override_args=None): progress.log(log_output, step=i) log_outputs.append(log_output) + if args.distributed_world_size > 1: + log_outputs = 
distributed_utils.all_gather_list( + log_outputs, + max_size=getattr(args, 'all_gather_list_size', 16384), + ) + log_outputs = list(chain.from_iterable(log_outputs)) + with metrics.aggregate() as agg: task.reduce_metrics(log_outputs, criterion) log_output = agg.get_smoothed_values() @@ -106,12 +119,10 @@ def main(args, override_args=None): def cli_main(): parser = options.get_validation_parser() - add_distributed_training_args(parser) args = options.parse_args_and_arch(parser) # only override args that are explicitly given on the command line override_parser = options.get_validation_parser() - add_distributed_training_args(override_parser) override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) distributed_utils.call_main(args, main, override_args=override_args) From 8e48f45aa469bbff85613520ffc161c0850e4744 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 28 May 2020 07:23:22 -0700 Subject: [PATCH 008/707] Miscellaneous fixes (#2193) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2193 Reviewed By: ngoyal2707 Differential Revision: D21748548 Pulled By: myleott fbshipit-source-id: d9f64540b55b4d427b3da6ad04a35f7b988b049a --- .../README.md => README.adaptive_inputs.md} | 8 ++-- .../{conv_lm/README.md => README.conv.md} | 3 +- examples/language_model/README.md | 9 ++-- examples/layerdrop/README.md | 12 +++-- .../speech_recognition/criterions/ASG_loss.py | 2 +- .../speech_recognition/criterions/CTC_loss.py | 2 +- fairseq/data/iterators.py | 2 +- fairseq/models/roberta/model.py | 47 +++++++++++-------- fairseq/models/transformer_lm.py | 2 +- fairseq/trainer.py | 10 ++++ fairseq/utils.py | 1 - 11 files changed, 60 insertions(+), 38 deletions(-) rename examples/language_model/{transformer_lm/README.md => README.adaptive_inputs.md} (89%) rename examples/language_model/{conv_lm/README.md => README.conv.md} (97%) diff --git a/examples/language_model/transformer_lm/README.md 
b/examples/language_model/README.adaptive_inputs.md similarity index 89% rename from examples/language_model/transformer_lm/README.md rename to examples/language_model/README.adaptive_inputs.md index 0ca6482afc..6873467115 100644 --- a/examples/language_model/transformer_lm/README.md +++ b/examples/language_model/README.adaptive_inputs.md @@ -4,13 +4,13 @@ Description | Parameters | Dataset | Model and Test set(s) ---|---:|---|--- -Adaptive Inputs
<br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
-Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
+Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
+Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
 ## Training an LM with adaptive inputs
-First, see the general [language modeling README](../README.md) for instructions
-on preprocessing the WikiText-103 data.
+First, see the general [language modeling README](README.md) for instructions on
+preprocessing the WikiText-103 data.
 Then use the following training command to train a model with adaptive inputs using the `transformer_lm_wiki103` model architecture:
diff --git a/examples/language_model/conv_lm/README.md b/examples/language_model/README.conv.md
similarity index 97%
rename from examples/language_model/conv_lm/README.md
rename to examples/language_model/README.conv.md
index 83ac0b454b..9fccfcc0ea 100644
--- a/examples/language_model/conv_lm/README.md
+++ b/examples/language_model/README.conv.md
@@ -2,8 +2,7 @@
 ## Example usage
-First download and preprocess the data following the main [language modeling
-README](../README.md).
+First download and preprocess the data following the main [language modeling README](README.md).
 Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103` architecture:
diff --git a/examples/language_model/README.md b/examples/language_model/README.md
index 43f3381a1f..3d5c3862bb 100644
--- a/examples/language_model/README.md
+++ b/examples/language_model/README.md
@@ -5,7 +5,7 @@
 Model | Description | Dataset | Download
 ---|---|---|---
 `transformer_lm.gbw.adaptive_huge` | Adaptive Inputs
<br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 1026M params | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
-`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2)
+`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
 `transformer_lm.wmt19.en` | English LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz)
 `transformer_lm.wmt19.de` | German LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz)
 `transformer_lm.wmt19.ru` | Russian LM <br>
([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz) @@ -72,8 +72,7 @@ fairseq-preprocess \ ### 2) Train a language model Next we'll train a basic transformer language model on wikitext-103. For more -advanced examples (e.g., using [adaptive inputs](https://arxiv.org/abs/1809.10853)), -please see the [Transformer LM README](transformer_lm/README.md). +advanced usage, see the [adaptive inputs README](README.adaptive_inputs.md). To train a basic LM (assumes 2 GPUs): ``` @@ -120,5 +119,5 @@ dataset, but results in better (lower) perplexity. ## Convolutional language models -Please see the [convolutional LM README](conv_lm/README.md) for instructions to -train convolutional language models. +Please see the [convolutional LM README](README.conv.md) for instructions on +training convolutional language models. diff --git a/examples/layerdrop/README.md b/examples/layerdrop/README.md index d7ede13642..394e710b0f 100644 --- a/examples/layerdrop/README.md +++ b/examples/layerdrop/README.md @@ -26,7 +26,11 @@ Model | Description | Download Evaluate performance of these pre-trained models: ```bash # Example for Machine Translation -python generate.py /path/to/bped/wmt/data --path nmt_checkpoint.pt --lenpen 0.4 --batch-size 64 --remove-bpe --beam 8 --gen-subset test > wmt16_gen.txt +fairseq-generate /path/to/bped/wmt/data --path nmt_checkpoint.pt \ + --beam 8 --lenpen 0.4 \ + --batch-size 64 \ + --remove-bpe \ + --gen-subset test > wmt16_gen.txt bash scripts/compound_split_bleu.sh wmt16_gen.txt # prints BLEU4 = 30.17 ``` @@ -111,8 +115,10 @@ num. model params: 146163712 ``` If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. 
A specific example would be for language modeling: -``` -python eval_lm.py /path/to/wikitext-103 --path '/path/to/model/checkpoint' --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}" +```bash +fairseq-eval-lm /path/to/wikitext-103 \ + --path /path/to/model/checkpoint.pt \ + --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}" ``` This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model. diff --git a/examples/speech_recognition/criterions/ASG_loss.py b/examples/speech_recognition/criterions/ASG_loss.py index 39962baaf5..8f932bcd5b 100644 --- a/examples/speech_recognition/criterions/ASG_loss.py +++ b/examples/speech_recognition/criterions/ASG_loss.py @@ -70,7 +70,7 @@ def __init__( self.linseg_maximum = linseg_updates self.linseg_message_state = "none" if hide_linseg_messages else "start" - @staticmethod + @classmethod def build_criterion(cls, args, task): return cls( task, diff --git a/examples/speech_recognition/criterions/CTC_loss.py b/examples/speech_recognition/criterions/CTC_loss.py index 33ed0fb135..df516f0d6e 100644 --- a/examples/speech_recognition/criterions/CTC_loss.py +++ b/examples/speech_recognition/criterions/CTC_loss.py @@ -81,7 +81,7 @@ def __init__(self, task): super().__init__(task) self.blank_idx = task.target_dictionary.index("") - @staticmethod + @classmethod def build_criterion(cls, args, task): return cls(task) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 169a4e72f2..064567eee0 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -454,7 +454,7 @@ def __next__(self): logger.info( "Data loading buffer is empty or nearly empty. This may " "indicate a data loading bottleneck, and increasing the " - "number of workers may help." + "number of workers (--num-workers) may help." 
) self.warning_time = time.time() diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index ce720db2d9..2303fbe26e 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -14,8 +14,8 @@ from fairseq import utils from fairseq.models import ( - FairseqDecoder, - FairseqLanguageModel, + FairseqEncoder, + FairseqEncoderModel, register_model, register_model_architecture, ) @@ -33,7 +33,7 @@ @register_model('roberta') -class RobertaModel(FairseqLanguageModel): +class RobertaModel(FairseqEncoderModel): @classmethod def hub_models(cls): @@ -116,12 +116,20 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, cla if classification_head_name is not None: features_only = True - x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs) + x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs) if classification_head_name is not None: x = self.classification_heads[classification_head_name](x) return x, extra + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + logits = net_output[0].float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): """Register a classification head.""" if name in self.classification_heads: @@ -163,13 +171,23 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na return RobertaHubInterface(x['args'], x['task'], x['models'][0]) def upgrade_state_dict_named(self, state_dict, name): - super().upgrade_state_dict_named(state_dict, name) - prefix = name + '.' 
if name != '' else '' - current_head_names = [] if not hasattr(self, 'classification_heads') else \ - self.classification_heads.keys() + + # rename decoder -> encoder before upgrading children modules + for k in list(state_dict.keys()): + if k.startswith(prefix + 'decoder'): + new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):] + state_dict[new_k] = state_dict[k] + del state_dict[k] + + # upgrade children modules + super().upgrade_state_dict_named(state_dict, name) # Handle new classification heads present in the state dict. + current_head_names = ( + [] if not hasattr(self, 'classification_heads') + else self.classification_heads.keys() + ) keys_to_delete = [] for k in state_dict.keys(): if not k.startswith(prefix + 'classification_heads.'): @@ -261,24 +279,15 @@ def forward(self, features, **kwargs): return x -class RobertaEncoder(FairseqDecoder): - """RoBERTa encoder. - - Implements the :class:`~fairseq.models.FairseqDecoder` interface required - by :class:`~fairseq.models.FairseqLanguageModel`. - """ +class RobertaEncoder(FairseqEncoder): + """RoBERTa encoder.""" def __init__(self, args, dictionary): super().__init__(dictionary) self.args = args - # RoBERTa is a sentence encoder model, so users will intuitively trim - # encoder layers. However, the implementation uses the fairseq decoder, - # so we fix here. 
if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - args.decoder_layers_to_keep = args.encoder_layers_to_keep - args.encoder_layers_to_keep = None self.sentence_encoder = TransformerSentenceEncoder( padding_idx=dictionary.pad(), diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 016e8c0743..dfc93d68d3 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -36,7 +36,7 @@ def moses_fastbpe(path): return { 'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2', - 'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2', + 'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2', 'transformer_lm.wmt19.en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2'), 'transformer_lm.wmt19.de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2'), 'transformer_lm.wmt19.ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2'), diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 377ce99b44..2d465c6f7a 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -870,6 +870,16 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): self.task.reduce_metrics(logging_outputs, self.get_criterion()) del logging_outputs + # extra warning for criterions that don't properly log a loss value + if "loss" not in agg: + if "loss" not in self._warn_once: + self._warn_once.add("loss") + logger.warning( + "Criterion.reduce_metrics did not log a 'loss' value, " + "which may break some functionality" + ) + metrics.log_scalar("loss", -1) + # support legacy interface if self.tpu: logging_output = {} diff --git a/fairseq/utils.py b/fairseq/utils.py index 22008012f5..da27529268 100644 --- 
a/fairseq/utils.py +++ b/fairseq/utils.py @@ -356,7 +356,6 @@ def import_user_module(args): if module_name not in sys.modules: sys.path.insert(0, module_parent) importlib.import_module(module_name) - sys.path.pop(0) def softmax(x, dim: int, onnx_trace: bool = False): From 29b8a4deb58ca9798b61690a31de1ea57de92122 Mon Sep 17 00:00:00 2001 From: Yongqiang Wang Date: Fri, 29 May 2020 11:17:42 -0700 Subject: [PATCH 009/707] improving error logging for "gradients are inconsistent between workers" Summary: "gradients are inconsistent between workers" are becoming increasingly more annoying and very difficult to debug. It may be caused by: - grad_norm becomes NaN in some worker - all_reduce is inconsistent We now will explicitly raise NaN or Inf if grad_norm is NaN/Inf, and print out grad_norm values if they are inconsistent Reviewed By: myleott Differential Revision: D21716726 fbshipit-source-id: 70593f001bc16b9e1cda460169a29b2be6aaed2c --- fairseq/trainer.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 2d465c6f7a..269e43313d 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -837,16 +837,39 @@ def _fast_stat_sync_sum( logging_outputs = [] return logging_outputs, extra_stats_to_sum + def _is_grad_norms_consistent(self, grad_norm_buf): + """check whether a given tensor (shape (N,)) is consistent """ + """consistent means all the values are diff within a tolerate range""" + diff = grad_norm_buf - grad_norm_buf[0] + max_abs_diff = torch.max(torch.abs(diff)).item() + first_grad_norm = grad_norm_buf[0].item() + # TODO: make 1e-6 a configurable value + return max_abs_diff / (first_grad_norm + 1e-6) < 1e-6 + def _check_grad_norms(self, grad_norm): """Check that grad norms are consistent across workers.""" if self._grad_norm_buf is not None: self._grad_norm_buf.zero_() self._grad_norm_buf[self.data_parallel_rank] = grad_norm - 
distributed_utils.all_reduce(self._grad_norm_buf, group=self.data_parallel_process_group) - if not (self._grad_norm_buf == self._grad_norm_buf[0]).all(): + distributed_utils.all_reduce( + self._grad_norm_buf, + group=self.data_parallel_process_group + ) + + if not self._is_grad_norms_consistent(self._grad_norm_buf): + pretty_detail = "\n".join( + "rank {:3d} = {:.8f}".format(r, n) + for r, n in enumerate(self._grad_norm_buf.tolist()) + ) + error_detail = "grad_norm across the workers:\n{}\n".format(pretty_detail) raise RuntimeError( "Fatal error: gradients are inconsistent between workers. " - "Try --ddp-backend=no_c10d." + "Try --ddp-backend=no_c10d. " + "Or are you mixing up different generation of GPUs in training?" + + "\n" + + "-" * 80 + + "\n{}\n".format(error_detail) + + "-" * 80 ) def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): From 1e40a48037eefc5ceb7672ea0fa11db629a35113 Mon Sep 17 00:00:00 2001 From: Duc Le Date: Sat, 30 May 2020 15:23:43 -0700 Subject: [PATCH 010/707] Enforce max limit of buffer size Reviewed By: jay-mahadeokar Differential Revision: D21804332 fbshipit-source-id: 51997455560a6b67f66d1401ef7095d4a1de4027 --- fairseq/data/iterators.py | 4 +++- fairseq/options.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 064567eee0..ef7d5fbfa8 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -211,7 +211,9 @@ def __init__( self.num_shards = num_shards self.shard_id = shard_id self.num_workers = num_workers - self.buffer_size = buffer_size + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. 
+        self.buffer_size = min(buffer_size, 5)
         self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
         self.shuffle = True
diff --git a/fairseq/options.py b/fairseq/options.py
index 52c8a96129..07390e25a5 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -330,7 +330,7 @@ def add_dataset_args(parser, train=False, gen=False):
     parser.add_argument('--dataset-impl', metavar='FORMAT',
                         choices=get_available_dataset_impl(),
                         help='output dataset implementation')
-    group.add_argument('--data-buffer-size', default=10, type=int, metavar='N',
+    group.add_argument('--data-buffer-size', default=2, type=int, metavar='N',
                        help='Number of batches to preload')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',

From fad3cf0769843e767155f4d0af18a61b9a804f59 Mon Sep 17 00:00:00 2001
From: Mike Ruberry
Date: Wed, 3 Jun 2020 09:45:20 -0700
Subject: [PATCH 011/707] Updates floor division to use floor division operator

Summary:
Performing floor division with torch.div is deprecated and will soon throw a runtime error. This diff updates the floor division to use the floor division operator.

Created from Diffusion's 'Open in Editor' feature.
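
For illustration only, the arithmetic behind this change can be sketched without PyTorch. The snippet below uses plain Python integers as a stand-in for the tensors in `fairseq/search.py`; the helper name `split_flat_index` and the sample values are assumptions made for this sketch, not part of the patch. It shows how floor division appears to recover a group (beam) number from a position in a flattened list holding `k` candidates per group, which is the role `final_indices // k` plays in the diff.

```python
# Illustrative sketch only: plain Python integers stand in for the tensors
# used in fairseq/search.py, and the helper name is hypothetical.
def split_flat_index(flat_index: int, k: int) -> tuple:
    """Map a position in a flattened (groups x k) candidate list to
    (group, offset) using floor division and modulo."""
    group = flat_index // k   # floor division: which block of k candidates
    offset = flat_index % k   # position inside that block
    return group, offset


k = 3
flat_indices = [0, 2, 3, 7]
pairs = [split_flat_index(i, k) for i in flat_indices]
print(pairs)

# True division produces a float, which cannot be used as an index;
# floor division keeps the result an integer.
print(7 / k, 7 // k)
```

The same reasoning is why an explicit floor-division operator is preferable to a general division call when the result is later used for indexing.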
Reviewed By: myleott

Differential Revision: D21848480

fbshipit-source-id: c9374e9406f4ba388f30315294eee7a2a4fcfecc
---
 fairseq/search.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fairseq/search.py b/fairseq/search.py
index 667a151ab0..1ee1d7cb44 100644
--- a/fairseq/search.py
+++ b/fairseq/search.py
@@ -335,7 +335,7 @@ def step(self, step: int, lprobs, scores):
             k,
         )
-        final_beams = torch.div(final_indices, k)
+        final_beams = final_indices // k
         for i in range(bsz):
             final_indices[i] = indices[i][final_indices[i]]

From 0b462f899925a1da6c91749bce0e0ed347604607 Mon Sep 17 00:00:00 2001
From: Naman Goyal
Date: Wed, 3 Jun 2020 12:58:40 -0700
Subject: [PATCH 012/707] adding model parallel roberta code (#1181)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1181

Reviewed By: myleott

Differential Revision: D21862030

fbshipit-source-id: 532ef608652e63f5490d554af486b87364af100e
---
 fairseq/model_parallel/models/__init__.py     |  11 +-
 .../model_parallel/models/roberta/__init__.py |   6 +
 .../model_parallel/models/roberta/model.py    | 268 ++++++++++++++++++
 fairseq/model_parallel/modules/__init__.py    |   4 +
 .../modules/transformer_sentence_encoder.py   |  64 +++++
 .../transformer_sentence_encoder_layer.py     |  79 ++++++
 .../modules/transformer_sentence_encoder.py   |  33 ++-
 .../transformer_sentence_encoder_layer.py     |  36 ++-
 8 files changed, 490 insertions(+), 11 deletions(-)
 create mode 100644 fairseq/model_parallel/models/roberta/__init__.py
 create mode 100644 fairseq/model_parallel/models/roberta/model.py
 create mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder.py
 create mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py

diff --git a/fairseq/model_parallel/models/__init__.py b/fairseq/model_parallel/models/__init__.py
index beae8afa80..a3207981ad 100644
--- a/fairseq/model_parallel/models/__init__.py
+++ b/fairseq/model_parallel/models/__init__.py
@@ -7,7 +7,10 @@ import os
-for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith('.py') and not file.startswith('_'):
-        model_name = file[:file.find('.py')]
-        importlib.import_module('fairseq.model_parallel.models.' + model_name)
+# automatically import any Python files in the models/ directory
+models_dir = os.path.dirname(__file__)
+for file in os.listdir(models_dir):
+    path = os.path.join(models_dir, file)
+    if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)):
+        model_name = file[:file.find('.py')] if file.endswith('.py') else file
+        module = importlib.import_module('fairseq.model_parallel.models.'
+ model_name) diff --git a/fairseq/model_parallel/models/roberta/__init__.py b/fairseq/model_parallel/models/roberta/__init__.py new file mode 100644 index 0000000000..117827c3e9 --- /dev/null +++ b/fairseq/model_parallel/models/roberta/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py new file mode 100644 index 0000000000..1a5f5647b4 --- /dev/null +++ b/fairseq/model_parallel/models/roberta/model.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +RoBERTa: A Robustly Optimized BERT Pretraining Approach. +""" + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.models import ( + FairseqDecoder, + register_model, + register_model_architecture, +) +from fairseq.models.roberta import ( + RobertaModel, + RobertaEncoder, + RobertaLMHead, + RobertaClassificationHead, +) +from fairseq.modules import ( + LayerNorm, + TransformerSentenceEncoder, +) +from fairseq.model_parallel.modules import ( + ModelParallelTransformerSentenceEncoder, +) +from fairseq.modules.transformer_sentence_encoder import init_bert_params +try: + from fairseq.model_parallel.megatron.mpu import ( + copy_to_model_parallel_region, + gather_from_model_parallel_region, + ColumnParallelLinear, + RowParallelLinear, + ) + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + +logger = logging.getLogger(__name__) + + +@register_model('model_parallel_roberta') +class ModelParallelRobertaModel(RobertaModel): + + + def __init__(self, 
args, encoder): + super().__init__(args, encoder) + + self.classification_heads = nn.ModuleDict() + + @staticmethod + def add_args(parser): + super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present + base_architecture(args) + + if not hasattr(args, 'max_positions'): + args.max_positions = args.tokens_per_sample + + encoder = ModelParallelRobertaEncoder(args, task.source_dictionary) + return cls(args, encoder) + + def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs): + if classification_head_name is not None: + features_only = True + + x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs) + + if classification_head_name is not None: + x = self.classification_heads[classification_head_name](x) + return x, extra + + def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + 'and inner_dim {} (prev: {})'.format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = ModelParallelRobertaClassificationHead( + self.args.encoder_embed_dim, + inner_dim or self.args.encoder_embed_dim, + num_classes, + self.args.pooler_activation_fn, + self.args.pooler_dropout, + ) + + +class ModelParallelRobertaLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = 
ColumnParallelLinear(embed_dim, embed_dim, gather_output=True) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, bias=False).weight + self.weight = weight + self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + # Only project the unmasked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + + features = copy_to_model_parallel_region(features) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + x = gather_from_model_parallel_region(x).contiguous() + x = x + self.bias + return x + + +class ModelParallelRobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout): + super().__init__() + self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ModelParallelRobertaEncoder(FairseqDecoder): + """RoBERTa encoder. + + Implements the :class:`~fairseq.models.FairseqDecoder` interface required + by :class:`~fairseq.models.FairseqLanguageModel`. + """ + + def __init__(self, args, dictionary): + super().__init__(dictionary) + self.args = args + + # RoBERTa is a sentence encoder model, so users will intuitively trim + # encoder layers. 
However, the implementation uses the fairseq decoder, + # so we fix here. + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + args.decoder_layers_to_keep = args.encoder_layers_to_keep + args.encoder_layers_to_keep = None + + self.sentence_encoder = ModelParallelTransformerSentenceEncoder( + padding_idx=dictionary.pad(), + vocab_size=len(dictionary), + num_encoder_layers=args.encoder_layers, + embedding_dim=args.encoder_embed_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + layerdrop=args.encoder_layerdrop, + max_seq_len=args.max_positions, + num_segments=0, + encoder_normalize_before=False, + apply_bert_init=False, + activation_fn=args.activation_fn, + ) + self.lm_head = ModelParallelRobertaLMHead( + embed_dim=args.encoder_embed_dim, + output_dim=len(dictionary), + activation_fn=args.activation_fn, + weight=self.sentence_encoder.embed_tokens.weight, + ) + + def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused): + """ + Args: + src_tokens (LongTensor): input tokens of shape `(batch, src_len)` + features_only (bool, optional): skip LM head and just return + features. If True, the output will be of shape + `(batch, src_len, embed_dim)`. + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + + Returns: + tuple: + - the LM output of shape `(batch, src_len, vocab)` + - a dictionary of additional data, where 'inner_states' + is a list of hidden states. Note that the hidden + states have shape `(src_len, batch, vocab)`. 
+ """ + x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens) + if not features_only: + x = self.output_layer(x, masked_tokens=masked_tokens) + return x, extra + + def extract_features(self, src_tokens, return_all_hiddens=False, **unused): + inner_states, _ = self.sentence_encoder( + src_tokens, + last_state_only=not return_all_hiddens, + ) + features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C + return features, {'inner_states': inner_states if return_all_hiddens else None} + + def output_layer(self, features, masked_tokens=None, **unused): + return self.lm_head(features, masked_tokens) + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + +@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta') +def base_architecture(args): + args.encoder_layers = getattr(args, 'encoder_layers', 12) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) + + args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') + + args.dropout = getattr(args, 'dropout', 0.1) + args.attention_dropout = getattr(args, 'attention_dropout', 0.1) + args.activation_dropout = getattr(args, 'activation_dropout', 0.0) + args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) + args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None) + args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) + + +@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_base') +def roberta_base_architecture(args): + base_architecture(args) + + +@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_large') +def roberta_large_architecture(args): + args.encoder_layers = 
getattr(args, 'encoder_layers', 24) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) + base_architecture(args) diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py index eb29b16bec..5c9431f92b 100644 --- a/fairseq/model_parallel/modules/__init__.py +++ b/fairseq/model_parallel/modules/__init__.py @@ -5,9 +5,13 @@ from .multihead_attention import ModelParallelMultiheadAttention from .transformer_layer import ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer +from .transformer_sentence_encoder_layer import ModelParallelTransformerSentenceEncoderLayer +from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder __all__ = [ 'ModelParallelMultiheadAttention', 'ModelParallelTransformerEncoderLayer', 'ModelParallelTransformerDecoderLayer', + 'ModelParallelTransformerSentenceEncoder', + 'ModelParallelTransformerSentenceEncoderLayer', ] diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py new file mode 100644 index 0000000000..101eca7bd4 --- /dev/null +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.modules import ( + LayerNorm, + MultiheadAttention, + PositionalEmbedding, + TransformerSentenceEncoder, +) + +from fairseq.model_parallel.modules import ( + ModelParallelTransformerSentenceEncoderLayer, +) + +try: + from fairseq.model_parallel.megatron.mpu import ( + copy_to_model_parallel_region, + gather_from_model_parallel_region, + VocabParallelEmbedding, + ) + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + +import random + + +class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder): + """ + Implementation for a Model Parallel Bi-directional Transformer based + Sentence Encoder used in BERT/XLM style pre-trained models. + """ + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) + + def build_transformer_sentence_encoder_layer( + self, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + **unused, + ): + return ModelParallelTransformerSentenceEncoderLayer( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + ) diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py new file mode 100644 index 0000000000..116b89ef83 --- /dev/null +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.modules import ( + TransformerSentenceEncoderLayer +) +from fairseq.model_parallel.modules import ModelParallelMultiheadAttention +try: + from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, + ) + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): + """ + Implements a Model Parallel Transformer Encoder Layer used in + BERT/XLM style pre-trained models. + """ + def build_fc1(self, input_dim, output_dim): + return ColumnParallelLinear(input_dim, output_dim, gather_output=False) + + def build_fc2(self, input_dim, output_dim): + return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) + + def build_self_attention( + self, + embed_dim, + num_attention_heads, + attention_dropout, + **kwargs, + ): + return ModelParallelMultiheadAttention( + embed_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True + ) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. 
+ """ + residual = x + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + return x diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index f8c708a7b2..32ba1cecac 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -114,7 +114,7 @@ def __init__( self.traceable = traceable self.tpu = False # whether we're on TPU - self.embed_tokens = nn.Embedding( + self.embed_tokens = self.build_embedding( self.vocab_size, self.embedding_dim, self.padding_idx ) self.embed_scale = embed_scale @@ -150,7 +150,7 @@ def __init__( else: self.layers = nn.ModuleList([]) self.layers.extend([ - TransformerSentenceEncoderLayer( + self.build_transformer_sentence_encoder_layer( embedding_dim=self.embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, @@ -188,6 +188,35 @@ def freeze_module_params(m): for layer in range(n_trans_layers_to_freeze): freeze_module_params(self.layers[layer]) + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return nn.Embedding(vocab_size, embedding_dim, padding_idx) + + def build_transformer_sentence_encoder_layer( + self, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + q_noise, + qn_block_size, + ): + return TransformerSentenceEncoderLayer( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + 
num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + def prepare_for_tpu_(self, **kwargs): self.tpu = True diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py index c34ae0c724..429814c7d7 100644 --- a/fairseq/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -15,6 +15,7 @@ ) from fairseq.modules.quant_noise import quant_noise + class TransformerSentenceEncoderLayer(nn.Module): """ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained @@ -43,7 +44,7 @@ def __init__( # Initialize blocks self.activation_fn = utils.get_activation_fn(activation_fn) - self.self_attn = MultiheadAttention( + self.self_attn = self.build_self_attention( self.embedding_dim, num_attention_heads, dropout=attention_dropout, @@ -54,15 +55,40 @@ def __init__( # layer norm associated with the self attention layer self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) - self.fc1 = quant_noise( + + self.fc1 = self.build_fc1(self.embedding_dim, ffn_embedding_dim, q_noise, qn_block_size) + self.fc2 = self.build_fc2(ffn_embedding_dim, self.embedding_dim, q_noise, qn_block_size) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( nn.Linear(self.embedding_dim, ffn_embedding_dim), q_noise, qn_block_size ) - self.fc2 = quant_noise( + + def build_fc2(self, input_dim, output_dim): + return quant_noise( nn.Linear(ffn_embedding_dim, self.embedding_dim), q_noise, qn_block_size ) - # layer norm associated with the position wise feed-forward NN - self.final_layer_norm = 
LayerNorm(self.embedding_dim, export=export) + def build_self_attention( + self, + embed_dim, + num_attention_heads, + attention_dropout, + self_attention, + q_noise, + qn_block_size, + ): + return MultiheadAttention( + embed_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) def forward( self, From 2cc8f6e5f2b9c3c36c64bc775f6ed61d4b8d97e0 Mon Sep 17 00:00:00 2001 From: Naman Goyal Date: Wed, 3 Jun 2020 14:04:07 -0700 Subject: [PATCH 013/707] some fixes for model parallel roberta (#1182) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1182 Reviewed By: myleott Differential Revision: D21868455 fbshipit-source-id: c12f90701ec36e55a72da393ca85c1198f23af04 --- .../model_parallel/models/roberta/model.py | 6 ++--- .../transformer_sentence_encoder_layer.py | 10 ++++---- .../transformer_sentence_encoder_layer.py | 24 +++++++++++++------ 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py index 1a5f5647b4..e0ae4a2c8f 100644 --- a/fairseq/model_parallel/models/roberta/model.py +++ b/fairseq/model_parallel/models/roberta/model.py @@ -14,7 +14,7 @@ from fairseq import utils from fairseq.models import ( - FairseqDecoder, + FairseqEncoder, register_model, register_model_architecture, ) @@ -76,7 +76,7 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, cla if classification_head_name is not None: features_only = True - x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs) + x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs) if classification_head_name is not None: x = self.classification_heads[classification_head_name](x) @@ -155,7 +155,7 @@ def forward(self, features, **kwargs): return x -class ModelParallelRobertaEncoder(FairseqDecoder): +class ModelParallelRobertaEncoder(FairseqEncoder): """RoBERTa encoder. 
Implements the :class:`~fairseq.models.FairseqDecoder` interface required diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py index 116b89ef83..0e1ea2b7d7 100644 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py @@ -26,23 +26,23 @@ class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLay Implements a Model Parallel Transformer Encoder Layer used in BERT/XLM style pre-trained models. """ - def build_fc1(self, input_dim, output_dim): + def build_fc1(self, input_dim, output_dim, **unused): return ColumnParallelLinear(input_dim, output_dim, gather_output=False) - def build_fc2(self, input_dim, output_dim): + def build_fc2(self, input_dim, output_dim, **unused): return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) def build_self_attention( self, embed_dim, num_attention_heads, - attention_dropout, + dropout, **kwargs, ): return ModelParallelMultiheadAttention( embed_dim, num_attention_heads, - dropout=attention_dropout, + dropout=dropout, self_attention=True ) @@ -76,4 +76,4 @@ def forward( x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x - return x + return x, None diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py index 429814c7d7..2d4747d041 100644 --- a/fairseq/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -56,27 +56,37 @@ def __init__( # layer norm associated with the self attention layer self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) - self.fc1 = self.build_fc1(self.embedding_dim, ffn_embedding_dim, q_noise, qn_block_size) - self.fc2 = self.build_fc2(ffn_embedding_dim, self.embedding_dim, q_noise, qn_block_size) + self.fc1 = 
self.build_fc1( + self.embedding_dim, + ffn_embedding_dim, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + self.fc2 = self.build_fc2( + ffn_embedding_dim, + self.embedding_dim, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) # layer norm associated with the position wise feed-forward NN self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise( - nn.Linear(self.embedding_dim, ffn_embedding_dim), q_noise, qn_block_size + nn.Linear(input_dim, output_dim), q_noise, qn_block_size ) - def build_fc2(self, input_dim, output_dim): + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise( - nn.Linear(ffn_embedding_dim, self.embedding_dim), q_noise, qn_block_size + nn.Linear(input_dim, output_dim), q_noise, qn_block_size ) def build_self_attention( self, embed_dim, num_attention_heads, - attention_dropout, + dropout, self_attention, q_noise, qn_block_size, @@ -84,7 +94,7 @@ def build_self_attention( return MultiheadAttention( embed_dim, num_attention_heads, - dropout=attention_dropout, + dropout=dropout, self_attention=True, q_noise=q_noise, qn_block_size=qn_block_size, From ea092c2aa697eb7c362447e663922d2dfe2f6da1 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Wed, 3 Jun 2020 18:49:00 -0700 Subject: [PATCH 014/707] Split out fairseq GPU tests & add new deeplearning_fairseq_gpu contbuild using remote execution Reviewed By: myleott Differential Revision: D21472387 fbshipit-source-id: efde278baf6a05e8a81a9630b44c7e7e7c7fe7fc --- tests/gpu/__init__.py | 0 tests/gpu/test_binaries_gpu.py | 281 +++++++++++++++ .../transformer_quantization_config.yaml | 0 tests/test_binaries.py | 336 ++---------------- tests/utils.py | 155 +++++++- 5 files changed, 460 insertions(+), 312 deletions(-) create mode 100644 tests/gpu/__init__.py create mode 100644 tests/gpu/test_binaries_gpu.py rename tests/{ => gpu}/transformer_quantization_config.yaml (100%) diff 
--git a/tests/gpu/__init__.py b/tests/gpu/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py new file mode 100644 index 0000000000..5ccb84c551 --- /dev/null +++ b/tests/gpu/test_binaries_gpu.py @@ -0,0 +1,281 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import os +import tempfile +import unittest +from io import StringIO + +import torch +from fairseq import options +from fairseq_cli import train +from tests.utils import ( + create_dummy_data, + generate_main, + preprocess_lm_data, + preprocess_translation_data, + train_translation_model, +) + + +class TestTranslationGPU(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_fp16(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fp16") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, "fconv_iwslt_de_en", ["--fp16"]) + generate_main(data_dir) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_memory_efficient_fp16(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_memory_efficient_fp16") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, "fconv_iwslt_de_en", ["--memory-efficient-fp16"] + ) + generate_main(data_dir) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_levenshtein_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + 
"test_levenshtein_transformer" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "levenshtein_transformer", + [ + "--apply-bert-init", + "--early-exit", + "6,6,6", + "--criterion", + "nat_loss", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) + + +def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=False): + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + "--task", + "language_modeling", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", + data_dir, + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + 0, + ] + + (extra_flags or []), + ) + train.main(train_args) + + # try scalar quantization + scalar_quant_train_parser = options.get_training_parser() + scalar_quant_train_args = options.parse_args_and_arch( + scalar_quant_train_parser, + [ + "--task", + "language_modeling", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", + data_dir, + "--max-update", + "3", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + 0, + "--quant-noise-scalar", + "0.5", + ] + + (extra_flags or []), + ) + train.main(scalar_quant_train_args) + + # try iterative PQ quantization + quantize_parser = options.get_training_parser() + quantize_args = 
options.parse_args_and_arch( + quantize_parser, + [ + "--task", + "language_modeling", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + "--max-tokens", + "50", + "--tokens-per-sample", + "50", + "--max-update", + "6", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + 0, + "--restore-file", + os.path.join(data_dir, "checkpoint_last.pt"), + "--reset-optimizer", + "--quantization-config-path", + os.path.join( + os.path.dirname(__file__), "transformer_quantization_config.yaml" + ), + ] + + (extra_flags or []), + ) + train.main(quantize_args) + + +class TestQuantization(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_quantization(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_quantization") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + # tests both scalar and iterative PQ quantization + _quantize_language_model(data_dir, "transformer_lm") + + +class TestOptimizersGPU(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_flat_grads(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_flat_grads") as data_dir: + # Use just a bit of data and tiny model to keep this test runtime reasonable + create_dummy_data(data_dir, num_examples=10, maxlen=5) + preprocess_translation_data(data_dir) + with self.assertRaises(RuntimeError): + # adafactor isn't compatible with flat grads, which + # are used by default with --fp16 + train_translation_model( + data_dir, + 
"lstm", + [ + "--required-batch-size-multiple", + "1", + "--encoder-layers", + "1", + "--encoder-hidden-size", + "32", + "--decoder-layers", + "1", + "--optimizer", + "adafactor", + "--fp16", + ], + ) + # but it should pass once we set --fp16-no-flatten-grads + train_translation_model( + data_dir, + "lstm", + [ + "--required-batch-size-multiple", + "1", + "--encoder-layers", + "1", + "--encoder-hidden-size", + "32", + "--decoder-layers", + "1", + "--optimizer", + "adafactor", + "--fp16", + "--fp16-no-flatten-grads", + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transformer_quantization_config.yaml b/tests/gpu/transformer_quantization_config.yaml similarity index 100% rename from tests/transformer_quantization_config.yaml rename to tests/gpu/transformer_quantization_config.yaml diff --git a/tests/test_binaries.py b/tests/test_binaries.py index e1f037bcb0..8e8732b643 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -8,19 +8,22 @@ import logging import os import random -import sys import tempfile import unittest import torch from fairseq import options -from fairseq_cli import preprocess from fairseq_cli import train -from fairseq_cli import generate -from fairseq_cli import interactive from fairseq_cli import eval_lm from fairseq_cli import validate +from tests.utils import ( + create_dummy_data, + preprocess_lm_data, + preprocess_translation_data, + train_translation_model, + generate_main, +) class TestTranslation(unittest.TestCase): @@ -47,24 +50,6 @@ def test_raw(self): train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--dataset-impl', 'raw']) generate_main(data_dir, ['--dataset-impl', 'raw']) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_fp16(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_fp16') as data_dir: - create_dummy_data(data_dir) - preprocess_translation_data(data_dir) - train_translation_model(data_dir, 
'fconv_iwslt_de_en', ['--fp16']) - generate_main(data_dir) - - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_memory_efficient_fp16(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir: - create_dummy_data(data_dir) - preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16']) - generate_main(data_dir) - def test_update_freq(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_update_freq') as data_dir: @@ -184,19 +169,28 @@ def test_transformer(self): ], run_validation=True) generate_main(data_dir) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_transformer_fp16(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_transformer') as data_dir: + with tempfile.TemporaryDirectory("test_transformer") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'transformer_iwslt_de_en', [ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--fp16', - ], run_validation=True) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--fp16", + ], + run_validation=True, + ) generate_main(data_dir) def test_multilingual_transformer(self): @@ -296,23 +290,6 @@ def test_cmlm_transformer(self): '--print-step', ]) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_levenshtein_transformer(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir: - 
create_dummy_data(data_dir) - preprocess_translation_data(data_dir, ['--joined-dictionary']) - train_translation_model(data_dir, 'levenshtein_transformer', [ - '--apply-bert-init', '--early-exit', '6,6,6', - '--criterion', 'nat_loss' - ], task='translation_lev') - generate_main(data_dir, [ - '--task', 'translation_lev', - '--iter-decode-max-iter', '9', - '--iter-decode-eos-penalty', '0', - '--print-step', - ]) - def test_nonautoregressive_transformer(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir: @@ -714,23 +691,6 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()): train.main(train_args) -class TestQuantization(unittest.TestCase): - def setUp(self): - logging.disable(logging.CRITICAL) - - def tearDown(self): - logging.disable(logging.NOTSET) - - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_quantization(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_quantization') as data_dir: - create_dummy_data(data_dir) - preprocess_lm_data(data_dir) - # tests both scalar and iterative PQ quantization - quantize_language_model(data_dir, 'transformer_lm') - - class TestOptimizers(unittest.TestCase): def setUp(self): @@ -759,74 +719,6 @@ def test_optimizers(self): ]) generate_main(data_dir) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_flat_grads(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_flat_grads') as data_dir: - # Use just a bit of data and tiny model to keep this test runtime reasonable - create_dummy_data(data_dir, num_examples=10, maxlen=5) - preprocess_translation_data(data_dir) - with self.assertRaises(RuntimeError): - # adafactor isn't compatible with flat grads, which - # are used by default with --fp16 - train_translation_model(data_dir, 'lstm', [ - '--required-batch-size-multiple', '1', - 
'--encoder-layers', '1', - '--encoder-hidden-size', '32', - '--decoder-layers', '1', - '--optimizer', 'adafactor', - '--fp16', - ]) - # but it should pass once we set --fp16-no-flatten-grads - train_translation_model(data_dir, 'lstm', [ - '--required-batch-size-multiple', '1', - '--encoder-layers', '1', - '--encoder-hidden-size', '32', - '--decoder-layers', '1', - '--optimizer', 'adafactor', - '--fp16', - '--fp16-no-flatten-grads', - ]) - - -def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False): - def _create_dummy_data(filename): - data = torch.rand(num_examples * maxlen) - data = 97 + torch.floor(26 * data).int() - with open(os.path.join(data_dir, filename), 'w') as h: - offset = 0 - for _ in range(num_examples): - ex_len = random.randint(1, maxlen) - ex_str = ' '.join(map(chr, data[offset:offset+ex_len])) - print(ex_str, file=h) - offset += ex_len - - def _create_dummy_alignment_data(filename_src, filename_tgt, filename): - with open(os.path.join(data_dir, filename_src), 'r') as src_f, \ - open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \ - open(os.path.join(data_dir, filename), 'w') as h: - for src, tgt in zip(src_f, tgt_f): - src_len = len(src.split()) - tgt_len = len(tgt.split()) - avg_len = (src_len + tgt_len) // 2 - num_alignments = random.randint(avg_len // 2, 2 * avg_len) - src_indices = torch.floor(torch.rand(num_alignments) * src_len).int() - tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int() - ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)]) - print(ex_str, file=h) - - _create_dummy_data('train.in') - _create_dummy_data('train.out') - _create_dummy_data('valid.in') - _create_dummy_data('valid.out') - _create_dummy_data('test.in') - _create_dummy_data('test.out') - - if alignment: - _create_dummy_alignment_data('train.in', 'train.out', 'train.align') - _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align') - 
_create_dummy_alignment_data('test.in', 'test.out', 'test.align') - def create_dummy_roberta_head_data(data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False): input_dir = 'input0' @@ -862,109 +754,6 @@ def _create_dummy_data(filename): _create_dummy_data('test') -def preprocess_translation_data(data_dir, extra_flags=None): - preprocess_parser = options.get_preprocessing_parser() - preprocess_args = preprocess_parser.parse_args( - [ - '--source-lang', 'in', - '--target-lang', 'out', - '--trainpref', os.path.join(data_dir, 'train'), - '--validpref', os.path.join(data_dir, 'valid'), - '--testpref', os.path.join(data_dir, 'test'), - '--thresholdtgt', '0', - '--thresholdsrc', '0', - '--destdir', data_dir, - ] + (extra_flags or []), - ) - preprocess.main(preprocess_args) - - -def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False, - lang_flags=None, extra_valid_flags=None): - if lang_flags is None: - lang_flags = [ - '--source-lang', 'in', - '--target-lang', 'out', - ] - train_parser = options.get_training_parser() - train_args = options.parse_args_and_arch( - train_parser, - [ - '--task', task, - data_dir, - '--save-dir', data_dir, - '--arch', arch, - '--lr', '0.05', - '--max-tokens', '500', - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--num-workers', 0, - ] + lang_flags + (extra_flags or []), - ) - train.main(train_args) - - if run_validation: - # test validation - validate_parser = options.get_validation_parser() - validate_args = options.parse_args_and_arch( - validate_parser, - [ - '--task', task, - data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - '--valid-subset', 'valid', - '--max-tokens', '500', - '--no-progress-bar', - ] + lang_flags + (extra_valid_flags or []) - ) - validate.main(validate_args) - - -def generate_main(data_dir, extra_flags=None): - if extra_flags is None: - extra_flags = [ - '--print-alignment', - ] - generate_parser = 
options.get_generation_parser() - generate_args = options.parse_args_and_arch( - generate_parser, - [ - data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - '--beam', '3', - '--batch-size', '64', - '--max-len-b', '5', - '--gen-subset', 'valid', - '--no-progress-bar', - ] + (extra_flags or []), - ) - - # evaluate model in batch mode - generate.main(generate_args) - - # evaluate model interactively - generate_args.buffer_size = 0 - generate_args.input = '-' - generate_args.max_sentences = None - orig_stdin = sys.stdin - sys.stdin = StringIO('h e l l o\n') - interactive.main(generate_args) - sys.stdin = orig_stdin - - -def preprocess_lm_data(data_dir): - preprocess_parser = options.get_preprocessing_parser() - preprocess_args = preprocess_parser.parse_args([ - '--only-source', - '--trainpref', os.path.join(data_dir, 'train.out'), - '--validpref', os.path.join(data_dir, 'valid.out'), - '--testpref', os.path.join(data_dir, 'test.out'), - '--destdir', data_dir, - ]) - preprocess.main(preprocess_args) - - def train_masked_lm(data_dir, arch, extra_flags=None): train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( @@ -1130,80 +919,5 @@ def train_masked_language_model(data_dir, arch, extra_args=()): train.main(train_args) -def quantize_language_model(data_dir, arch, extra_flags=None, run_validation=False): - train_parser = options.get_training_parser() - train_args = options.parse_args_and_arch( - train_parser, - [ - '--task', 'language_modeling', - data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'adaptive_loss', - '--adaptive-softmax-cutoff', '5,10,15', - '--max-tokens', '500', - '--tokens-per-sample', '500', - '--save-dir', data_dir, - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', 0, - ] + (extra_flags or []), - ) - train.main(train_args) - - # try scalar quantization - scalar_quant_train_parser = 
options.get_training_parser() - scalar_quant_train_args = options.parse_args_and_arch( - scalar_quant_train_parser, - [ - '--task', 'language_modeling', - data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'adaptive_loss', - '--adaptive-softmax-cutoff', '5,10,15', - '--max-tokens', '500', - '--tokens-per-sample', '500', - '--save-dir', data_dir, - '--max-update', '3', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', 0, - '--quant-noise-scalar', '0.5', - ] + (extra_flags or []), - ) - train.main(scalar_quant_train_args) - - # try iterative PQ quantization - quantize_parser = options.get_training_parser() - quantize_args = options.parse_args_and_arch( - quantize_parser, - [ - '--task', 'language_modeling', - data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'adaptive_loss', - '--adaptive-softmax-cutoff', '5,10,15', - '--max-tokens', '50', - '--tokens-per-sample', '50', - '--max-update', '6', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', 0, - '--restore-file', os.path.join(data_dir, 'checkpoint_last.pt'), - '--reset-optimizer', - '--quantization-config-path', os.path.join(os.path.dirname(__file__), 'transformer_quantization_config.yaml'), - ] + (extra_flags or []), - ) - train.main(quantize_args) - if __name__ == '__main__': unittest.main() diff --git a/tests/utils.py b/tests/utils.py index f908e5e74a..e207575d6f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,10 +4,14 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import os +import random +import sys import torch import torch.nn.functional as F -from fairseq import utils +from io import StringIO +from fairseq import options, utils from fairseq.data import Dictionary from fairseq.data.language_pair_dataset import collate from fairseq.models import ( @@ -17,6 +21,13 @@ ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.tasks import FairseqTask +from fairseq_cli import ( + generate, + interactive, + preprocess, + train, + validate, +) def dummy_dictionary(vocab_size, prefix='token_'): @@ -116,6 +127,148 @@ def sequence_generator_setup(): return tgt_dict, w1, w2, src_tokens, src_lengths, model +def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False): + def _create_dummy_data(filename): + data = torch.rand(num_examples * maxlen) + data = 97 + torch.floor(26 * data).int() + with open(os.path.join(data_dir, filename), 'w') as h: + offset = 0 + for _ in range(num_examples): + ex_len = random.randint(1, maxlen) + ex_str = ' '.join(map(chr, data[offset:offset+ex_len])) + print(ex_str, file=h) + offset += ex_len + + def _create_dummy_alignment_data(filename_src, filename_tgt, filename): + with open(os.path.join(data_dir, filename_src), 'r') as src_f, \ + open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \ + open(os.path.join(data_dir, filename), 'w') as h: + for src, tgt in zip(src_f, tgt_f): + src_len = len(src.split()) + tgt_len = len(tgt.split()) + avg_len = (src_len + tgt_len) // 2 + num_alignments = random.randint(avg_len // 2, 2 * avg_len) + src_indices = torch.floor(torch.rand(num_alignments) * src_len).int() + tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int() + ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)]) + print(ex_str, file=h) + + _create_dummy_data('train.in') + _create_dummy_data('train.out') + _create_dummy_data('valid.in') + _create_dummy_data('valid.out') + _create_dummy_data('test.in') + 
_create_dummy_data('test.out') + + if alignment: + _create_dummy_alignment_data('train.in', 'train.out', 'train.align') + _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align') + _create_dummy_alignment_data('test.in', 'test.out', 'test.align') + + +def preprocess_lm_data(data_dir): + preprocess_parser = options.get_preprocessing_parser() + preprocess_args = preprocess_parser.parse_args([ + '--only-source', + '--trainpref', os.path.join(data_dir, 'train.out'), + '--validpref', os.path.join(data_dir, 'valid.out'), + '--testpref', os.path.join(data_dir, 'test.out'), + '--destdir', data_dir, + ]) + preprocess.main(preprocess_args) + + +def preprocess_translation_data(data_dir, extra_flags=None): + preprocess_parser = options.get_preprocessing_parser() + preprocess_args = preprocess_parser.parse_args( + [ + '--source-lang', 'in', + '--target-lang', 'out', + '--trainpref', os.path.join(data_dir, 'train'), + '--validpref', os.path.join(data_dir, 'valid'), + '--testpref', os.path.join(data_dir, 'test'), + '--thresholdtgt', '0', + '--thresholdsrc', '0', + '--destdir', data_dir, + ] + (extra_flags or []), + ) + preprocess.main(preprocess_args) + + +def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False, + lang_flags=None, extra_valid_flags=None): + if lang_flags is None: + lang_flags = [ + '--source-lang', 'in', + '--target-lang', 'out', + ] + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + '--task', task, + data_dir, + '--save-dir', data_dir, + '--arch', arch, + '--lr', '0.05', + '--max-tokens', '500', + '--max-epoch', '1', + '--no-progress-bar', + '--distributed-world-size', '1', + '--num-workers', 0, + ] + lang_flags + (extra_flags or []), + ) + train.main(train_args) + + if run_validation: + # test validation + validate_parser = options.get_validation_parser() + validate_args = options.parse_args_and_arch( + validate_parser, + [ + '--task', task, 
+ data_dir, + '--path', os.path.join(data_dir, 'checkpoint_last.pt'), + '--valid-subset', 'valid', + '--max-tokens', '500', + '--no-progress-bar', + ] + lang_flags + (extra_valid_flags or []) + ) + validate.main(validate_args) + + +def generate_main(data_dir, extra_flags=None): + if extra_flags is None: + extra_flags = [ + '--print-alignment', + ] + generate_parser = options.get_generation_parser() + generate_args = options.parse_args_and_arch( + generate_parser, + [ + data_dir, + '--path', os.path.join(data_dir, 'checkpoint_last.pt'), + '--beam', '3', + '--batch-size', '64', + '--max-len-b', '5', + '--gen-subset', 'valid', + '--no-progress-bar', + ] + (extra_flags or []), + ) + + # evaluate model in batch mode + generate.main(generate_args) + + # evaluate model interactively + generate_args.buffer_size = 0 + generate_args.input = '-' + generate_args.max_sentences = None + orig_stdin = sys.stdin + sys.stdin = StringIO('h e l l o\n') + interactive.main(generate_args) + sys.stdin = orig_stdin + + class TestDataset(torch.utils.data.Dataset): def __init__(self, data): From 5abc774eead6a9b47b372cf5cde22aee49587edf Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 5 Jun 2020 06:04:31 -0700 Subject: [PATCH 015/707] Re-enable test_transformer_fp16 GPU test Reviewed By: theweiho Differential Revision: D21890628 fbshipit-source-id: 4088884dd2a82a831f1c129e675eb233c469242a --- fairseq/options.py | 2 +- tests/gpu/test_binaries_gpu.py | 24 ++++++++++++++++++++++++ tests/test_binaries.py | 24 ------------------------ 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/fairseq/options.py b/fairseq/options.py index 07390e25a5..092e1e1cb3 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -53,7 +53,7 @@ def get_eval_lm_parser(default_task="language_modeling"): def get_validation_parser(default_task=None): parser = get_parser("Validation", default_task) add_dataset_args(parser, train=True) - add_distributed_training_args(parser) + 
add_distributed_training_args(parser, default_world_size=1) group = parser.add_argument_group("Evaluation") add_common_eval_args(group) return parser diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 5ccb84c551..b65b545a4e 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -49,6 +49,30 @@ def test_memory_efficient_fp16(self): ) generate_main(data_dir) + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_transformer_fp16(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--fp16", + ], + run_validation=True, + ) + generate_main(data_dir) + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_levenshtein_transformer(self): with contextlib.redirect_stdout(StringIO()): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 8e8732b643..73db3f6385 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -169,30 +169,6 @@ def test_transformer(self): ], run_validation=True) generate_main(data_dir) - @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") - def test_transformer_fp16(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_transformer") as data_dir: - create_dummy_data(data_dir) - preprocess_translation_data(data_dir) - train_translation_model( - data_dir, - "transformer_iwslt_de_en", - [ - "--encoder-layers", - "2", - "--decoder-layers", - "2", - "--encoder-embed-dim", - "8", - "--decoder-embed-dim", - "8", - "--fp16", - ], - run_validation=True, - ) - generate_main(data_dir) - def test_multilingual_transformer(self): # test with 
all combinations of encoder/decoder lang tokens encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']] From 023f7af21f0987814289b6f605821213f927bfc6 Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Fri, 5 Jun 2020 09:28:22 -0700 Subject: [PATCH 016/707] enable choice of max-tokens in masked_lm (#1180) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Adds option to set max-positions in masked_lm model. This option is available in the RoBERTa model, so it should be available here too. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1180 Reviewed By: myleott Differential Revision: D21904170 Pulled By: joshim5 fbshipit-source-id: 37168dbf1a2758620d5e8e05c7e8a9ef8d09c765 --- fairseq/models/masked_lm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fairseq/models/masked_lm.py b/fairseq/models/masked_lm.py index 1cc8afcf23..35a6323ef2 100644 --- a/fairseq/models/masked_lm.py +++ b/fairseq/models/masked_lm.py @@ -78,6 +78,8 @@ def add_args(parser): ' (outside self attention)') parser.add_argument('--num-segment', type=int, metavar='N', help='num segment in the input') + parser.add_argument('--max-positions', type=int, + help='number of positional embeddings to learn') # Arguments related to sentence level prediction parser.add_argument('--sentence-class-num', type=int, metavar='N', From b1f9a8f5665ac4f8c8cd002e59beeb2837e320e4 Mon Sep 17 00:00:00 2001 From: Mike Ruberry Date: Fri, 5 Jun 2020 09:54:49 -0700 Subject: [PATCH 017/707] Updates div sometimes performing floor division to explicitly perform either true division or floor division Summary: torch.div will soon throw a runtime error when it would have performed floor division. This diff updates this instance of div to use either the true division or floor division operators as appropriate so the behavior doesn't change and the test won't throw a runtime error. Created from Diffusion's 'Open in Editor' feature.
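Illustration of the intent of this change: a minimal sketch, assuming plain Python floats and ints stand in for floating-point and integer tensors. `average_values` is a hypothetical helper written for this note, not a function in fairseq or PyTorch. Floating-point entries get true division, and integer entries (for example, step counters stored in a checkpoint) get explicit floor division, so neither path relies on an operator that picks the rounding behavior implicitly.

```python
def average_values(summed, num_models):
    """Divide each summed entry by num_models, choosing the operator by type.

    Hypothetical stand-in: floats play the role of floating-point tensors,
    ints play the role of integer tensors.
    """
    averaged = {}
    for key, value in summed.items():
        if isinstance(value, float):
            # True division keeps the fractional part.
            averaged[key] = value / num_models
        else:
            # Explicit floor division keeps integer entries integral.
            averaged[key] = value // num_models
    return averaged


if __name__ == "__main__":
    summed = {"weight": 3.0, "num_batches_tracked": 7}
    print(average_values(summed, 2))
```

The branch on type mirrors the `is_floating_point()` check in the diff that follows; the toy dictionary keys are made up for the example.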
Reviewed By: myleott Differential Revision: D21900423 fbshipit-source-id: 363c3e64d25608a033cd2942dcbb039a73018596 --- scripts/average_checkpoints.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py index 7890516154..edda69fb8f 100644 --- a/scripts/average_checkpoints.py +++ b/scripts/average_checkpoints.py @@ -65,7 +65,10 @@ def average_checkpoints(inputs): averaged_params = collections.OrderedDict() for k, v in params_dict.items(): averaged_params[k] = v - averaged_params[k].div_(num_models) + if averaged_params[k].is_floating_point(): + averaged_params[k].div_(num_models) + else: + averaged_params[k] //= num_models new_state['model'] = averaged_params return new_state From 152a3fe14322c3cb8a38a31ddbff59e7536ab689 Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Fri, 5 Jun 2020 12:13:34 -0700 Subject: [PATCH 018/707] Support residual connections in LSTM models (#1103) Summary: Adds support for residual connections in LSTM models. 
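Illustration of the update order used by the decoder change in this patch: a minimal sketch, assuming toy scalar "layers" in place of LSTM cells and no dropout. `step` and `toy_layer` are hypothetical names for this note, not fairseq APIs. For each layer, the value passed upward is the new hidden state plus that layer's hidden state from the previous time step, and the stored state is updated only afterwards.

```python
def toy_layer(x, prev_hidden):
    """Hypothetical stand-in for one recurrent cell: mixes input and state."""
    return 0.5 * x + 0.5 * prev_hidden


def step(x, layers, prev_hiddens, residuals=False):
    """Run one time step through a stack of layers, updating prev_hiddens in place."""
    for i, layer in enumerate(layers):
        hidden = layer(x, prev_hiddens[i])
        # The new hidden state becomes the input to the next layer.
        x = hidden
        if residuals:
            # Add the layer's state from the previous time step.
            x = x + prev_hiddens[i]
        # Save state for the next time step.
        prev_hiddens[i] = hidden
    return x


if __name__ == "__main__":
    layers = [toy_layer, toy_layer]
    for flag in (False, True):
        state = [0.0, 0.0]
        outputs = [step(2.0, layers, state, residuals=flag) for _ in range(2)]
        print(flag, outputs)
```

On the first time step the two variants agree because every stored state starts at zero; they diverge from the second step onward.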
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1103 Reviewed By: myleott Differential Revision: D21639942 Pulled By: joshim5 fbshipit-source-id: a02ddfe080a847fd91a9c6a5074cb6dc782f7727 --- fairseq/models/lstm.py | 9 ++++++--- fairseq/models/lstm_lm.py | 7 ++++++- tests/test_binaries.py | 14 ++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 1b470d064f..c2fbde33a4 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -176,7 +176,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): options.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == 'adaptive_loss' else None ), - max_target_positions=max_target_positions + max_target_positions=max_target_positions, + residuals=False ) return cls(encoder, decoder) @@ -346,7 +347,6 @@ def forward(self, input, source_hids, encoder_padding_mask): x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) return x, attn_scores - class LSTMDecoder(FairseqIncrementalDecoder): """LSTM decoder.""" def __init__( @@ -354,7 +354,8 @@ def __init__( num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, encoder_output_units=512, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, - max_target_positions=DEFAULT_MAX_TARGET_POSITIONS + max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, + residuals=False ): super().__init__(dictionary) self.dropout_in = dropout_in @@ -363,6 +364,7 @@ def __init__( self.share_input_output_embed = share_input_output_embed self.need_attn = True self.max_target_positions = max_target_positions + self.residuals = residuals self.num_layers = num_layers self.adaptive_softmax = None @@ -501,6 +503,7 @@ def extract_features( # hidden state becomes the input to the next layer input = F.dropout(hidden, p=self.dropout_out, training=self.training) + if self.residuals: input = input + prev_hiddens[i] # save 
state for next time step prev_hiddens[i] = hidden diff --git a/fairseq/models/lstm_lm.py b/fairseq/models/lstm_lm.py index d3b6972a38..9f6758a4bc 100644 --- a/fairseq/models/lstm_lm.py +++ b/fairseq/models/lstm_lm.py @@ -39,6 +39,9 @@ def add_args(parser): parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', help='comma separated list of adaptive softmax cutoff points. ' 'Must be used with adaptive_loss criterion') + parser.add_argument('--residuals', default=False, + action='store_true', + help='applying residuals between LSTM layers') # Granular dropout settings (if not specified these default to --dropout) parser.add_argument('--decoder-dropout-in', type=float, metavar='D', @@ -104,7 +107,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): options.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == 'adaptive_loss' else None ), - max_target_positions=max_target_positions + max_target_positions=max_target_positions, + residuals=args.residuals ) return cls(decoder) @@ -123,3 +127,4 @@ def base_architecture(args): args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') + args.residuals = getattr(args, 'residuals', False) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 73db3f6385..4eca1debf6 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -493,6 +493,20 @@ def test_lstm_lm(self): '--tokens-per-sample', '500', ]) + def test_lstm_lm_residuals(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_lstm_lm_residuals') as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, 'lstm_lm', ['--add-bos-token', '--residuals'], run_validation=True, + ) + eval_lm_main(data_dir) 
+ generate_main(data_dir, [ + '--task', 'language_modeling', + '--sample-break-mode', 'eos', + '--tokens-per-sample', '500', + ]) class TestMaskedLanguageModel(unittest.TestCase): From 2699f4a28b70cb7a2ec5890f71b6d6f27fd0af92 Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Fri, 5 Jun 2020 12:55:56 -0700 Subject: [PATCH 019/707] add random sequence truncation (#1173) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Allows taking random crops in language_modeling task. Discussed with myleott robert-verkuil tomsercu in meeting yesterday. Ultimately took a different, more general approach to implementing this. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1173 Reviewed By: myleott Differential Revision: D21904561 Pulled By: joshim5 fbshipit-source-id: 66e8dfb10a0d36b76acd2eb181e00db6fc2433fc --- .../roberta/README.custom_classification.md | 2 +- examples/roberta/README.race.md | 2 +- fairseq/data/__init__.py | 2 +- fairseq/data/mask_tokens_dataset.py | 1 + fairseq/data/shorten_dataset.py | 67 +++++++++++++++++++ fairseq/data/truncate_dataset.py | 31 --------- fairseq/tasks/language_modeling.py | 20 ++++-- fairseq/tasks/masked_lm.py | 16 +++++ fairseq/tasks/sentence_prediction.py | 21 ++++-- fairseq/tasks/sentence_ranking.py | 20 ++++-- 10 files changed, 134 insertions(+), 48 deletions(-) create mode 100644 fairseq/data/shorten_dataset.py delete mode 100644 fairseq/data/truncate_dataset.py diff --git a/examples/roberta/README.custom_classification.md b/examples/roberta/README.custom_classification.md index 1e53ce4c18..3b44aac027 100644 --- a/examples/roberta/README.custom_classification.md +++ b/examples/roberta/README.custom_classification.md @@ -123,7 +123,7 @@ CUDA_VISIBLE_DEVICES=0 python train.py IMDB-bin/ \ --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ --max-epoch 10 \ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ - --truncate-sequence \ + --shorten-method "truncate" \ --find-unused-parameters \ --update-freq 4 ``` diff --git a/examples/roberta/README.race.md b/examples/roberta/README.race.md index 365bcf248b..c2d1acaba6 100644 --- a/examples/roberta/README.race.md +++ b/examples/roberta/README.race.md @@ -28,7 +28,7 @@ CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=no_c10d \ --init-token 0 --separator-token 2 \ --max-option-length 128 \ --max-positions 512 \ - --truncate-sequence \ + --shorten-method "truncate" \ --arch roberta_large \ --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ --criterion sentence_ranking \ diff --git
a/fairseq/data/__init__.py b/fairseq/data/__init__.py index f844564dbc..30c6e88d82 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -44,7 +44,7 @@ from .token_block_dataset import TokenBlockDataset from .transform_eos_dataset import TransformEosDataset from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset -from .truncate_dataset import TruncateDataset +from .shorten_dataset import TruncateDataset, RandomCropDataset from .iterators import ( CountingIterator, diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 84d313a096..28bc3bc9cf 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -92,6 +92,7 @@ def __init__( self.epoch = 0 def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) self.epoch = epoch @lru_cache(maxsize=8) diff --git a/fairseq/data/shorten_dataset.py b/fairseq/data/shorten_dataset.py new file mode 100644 index 0000000000..f95288a5c0 --- /dev/null +++ b/fairseq/data/shorten_dataset.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from fairseq.data import data_utils + +from . 
import BaseWrapperDataset + + +class TruncateDataset(BaseWrapperDataset): + """Truncate a sequence by returning the first truncation_length tokens + """ + + def __init__(self, dataset, truncation_length): + super().__init__(dataset) + assert truncation_length is not None + self.truncation_length = truncation_length + self.dataset = dataset + + def __getitem__(self, index): + item = self.dataset[index] + item_len = item.size(0) + if item_len > self.truncation_length: + item = item[:self.truncation_length] + return item + + @property + def sizes(self): + return np.minimum(self.dataset.sizes, self.truncation_length) + + def __len__(self): + return len(self.dataset) + + +class RandomCropDataset(TruncateDataset): + """Truncate a sequence by returning a random crop of truncation_length tokens + """ + + def __init__(self, dataset, truncation_length, seed=1): + super().__init__(dataset, truncation_length) + self.seed = seed + self.epoch = 0 + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index): + with data_utils.numpy_seed(self.seed, self.epoch, index): + item = self.dataset[index] + item_len = item.size(0) + excess = item_len - self.truncation_length + if excess > 0: + start_idx = np.random.randint(0, excess) + item = item[start_idx:start_idx+self.truncation_length] + return item + +def maybe_shorten_dataset(dataset, split, shorten_data_split_whitelist, shorten_method, tokens_per_sample, seed): + truncate_split = split in shorten_data_split_whitelist.split(',') \ + or len(shorten_data_split_whitelist) == 0 + if shorten_method == 'truncate' and truncate_split: + dataset = TruncateDataset(dataset, tokens_per_sample) + elif shorten_method == 'random_crop' and truncate_split: + dataset = RandomCropDataset(dataset, tokens_per_sample, seed) + return dataset diff --git a/fairseq/data/truncate_dataset.py b/fairseq/data/truncate_dataset.py deleted file mode 100644 index efd1c6d1cb..0000000000 --- 
a/fairseq/data/truncate_dataset.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np - -from . import BaseWrapperDataset - - -class TruncateDataset(BaseWrapperDataset): - - def __init__(self, dataset, truncation_length): - super().__init__(dataset) - assert truncation_length is not None - self.truncation_length = truncation_length - self.dataset = dataset - - def __getitem__(self, index): - item = self.dataset[index] - item_len = item.size(0) - if item_len > self.truncation_length: - item = item[:self.truncation_length] - return item - - @property - def sizes(self): - return np.minimum(self.dataset.sizes, self.truncation_length) - - def __len__(self): - return len(self.dataset) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 80ec26c1b7..82f41b2c73 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -23,9 +23,9 @@ StripTokenDataset, TokenBlockDataset, TransformEosDataset, - TruncateDataset, TruncatedDictionary, ) +from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.tasks import FairseqTask, register_task @@ -88,8 +88,12 @@ def add_args(parser): help='prepend beginning of sentence token ()') parser.add_argument('--max-target-positions', type=int, metavar='N', help='max number of tokens in the target sequence') - parser.add_argument('--truncate-sequence', action='store_true', default=False, - help='truncate sequences to --tokens-per-sample') + parser.add_argument('--shorten-method', default='none', + choices=['none', 'truncate', 'random_crop'], + help='if not none, shorten sequences that exceed --tokens-per-sample') + parser.add_argument('--shorten-data-split-whitelist', default='', + help='comma-separated list of dataset splits to apply shortening to, ' + 'e.g., "train,valid" 
(default: all dataset splits)') # fmt: on def __init__(self, args, dictionary, output_dictionary=None, targets=None): @@ -169,8 +173,14 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): "Dataset not found: {} ({})".format(split, split_path) ) - if self.args.truncate_sequence: - dataset = TruncateDataset(dataset, self.args.tokens_per_sample) + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_whitelist, + self.args.shorten_method, + self.args.tokens_per_sample, + self.args.seed, + ) dataset = TokenBlockDataset( dataset, diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 8089abf7d8..7f03e04fba 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -21,6 +21,7 @@ SortDataset, TokenBlockDataset, ) +from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.tasks import FairseqTask, register_task from fairseq.data.encoders.utils import get_whole_word_mask from fairseq import utils @@ -58,6 +59,12 @@ def add_args(parser): help='sample random replacement words based on word frequencies') parser.add_argument('--mask-whole-words', default=False, action='store_true', help='mask whole words; you may also want to set --bpe') + parser.add_argument('--shorten-method', default='none', + choices=['none', 'truncate', 'random_crop'], + help='if not none, shorten sequences that exceed --tokens-per-sample') + parser.add_argument('--shorten-data-split-whitelist', default='', + help='comma-separated list of dataset splits to apply shortening to, ' + 'e.g., "train,valid" (default: all dataset splits)') def __init__(self, args, dictionary): super().__init__(args) @@ -95,6 +102,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): if dataset is None: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_whitelist, + self.args.shorten_method, + 
self.args.tokens_per_sample, + self.args.seed, + ) + # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 471baf428a..5cdfc97b7a 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -8,6 +8,7 @@ import numpy as np +from fairseq import utils from fairseq.data import ( ConcatSentencesDataset, data_utils, @@ -23,8 +24,8 @@ RollDataset, SortDataset, StripTokenDataset, - TruncateDataset, ) +from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.tasks import FairseqTask, register_task @@ -53,8 +54,12 @@ def add_args(parser): help='add separator token between inputs') parser.add_argument('--regression-target', action='store_true', default=False) parser.add_argument('--no-shuffle', action='store_true', default=False) - parser.add_argument('--truncate-sequence', action='store_true', default=False, - help='truncate sequence to max-positions') + parser.add_argument('--shorten-method', default='none', + choices=['none', 'truncate', 'random_crop'], + help='if not none, shorten sequences that exceed --tokens-per-sample') + parser.add_argument('--shorten-data-split-whitelist', default='', + help='comma-separated list of dataset splits to apply shortening to, ' + 'e.g., "train,valid" (default: all dataset splits)') parser.add_argument('--add-prev-output-tokens', action='store_true', default=False, help='add prev_output_tokens to sample, used for encoder-decoder arch') @@ -141,8 +146,14 @@ def make_dataset(type, dictionary): with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) - if self.args.truncate_sequence: - src_tokens = TruncateDataset(src_tokens, self.args.max_positions) + src_tokens = maybe_shorten_dataset( + src_tokens, + split, + self.args.shorten_data_split_whitelist, + self.args.shorten_method, + self.args.max_positions, + self.args.seed, + ) dataset = { 
'id': IdDataset(), diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py index 2438ab7404..7b667dfc86 100644 --- a/fairseq/tasks/sentence_ranking.py +++ b/fairseq/tasks/sentence_ranking.py @@ -8,6 +8,7 @@ import numpy as np +from fairseq import utils from fairseq.data import ( ConcatSentencesDataset, data_utils, @@ -20,8 +21,8 @@ RawLabelDataset, RightPadDataset, SortDataset, - TruncateDataset, ) +from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.tasks import FairseqTask, register_task @@ -49,8 +50,12 @@ def add_args(parser): parser.add_argument('--separator-token', type=int, help='add separator token between inputs') parser.add_argument('--no-shuffle', action='store_true') - parser.add_argument('--truncate-sequence', action='store_true', - help='Truncate sequence to max_positions') + parser.add_argument('--shorten-method', default='none', + choices=['none', 'truncate', 'random_crop'], + help='if not none, shorten sequences that exceed --tokens-per-sample') + parser.add_argument('--shorten-data-split-whitelist', default='', + help='comma-separated list of dataset splits to apply shortening to, ' + 'e.g., "train,valid" (default: all dataset splits)') parser.add_argument('--max-option-length', type=int, help='max length for each option') @@ -120,7 +125,14 @@ def make_dataset(type, dictionary): input_option = TruncateDataset(input_option, self.args.max_option_length) src_token = ConcatSentencesDataset(input_option, input0) if self.args.truncate_sequence: - src_token = TruncateDataset(src_token, self.args.max_positions) + src_token = maybe_shorten_dataset( + src_token, + split, + self.args.shorten_data_split_whitelist, + self.args.shorten_method, + self.args.max_positions, + self.args.seed, + ) src_tokens.append(src_token) with data_utils.numpy_seed(self.args.seed): From e03bfd9bf447ff9b040887968cb6152e5bba3a4a Mon Sep 17 00:00:00 2001 From: Xilun Chen Date: Mon, 8 Jun 2020 10:14:23 -0700 Subject: [PATCH 020/707] 
Check if the checkpoint is from the latest version before updating the state_dict in TransformerDecoder.upgrade_state_dict_named() (#2222) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2222 When share_input_output_embed is set to True, the existing code always overrides output_projection.weight with embed_tokens.weight. This is unnecessary and caused a very obscure bug in our custom BART model. Added a check to skip the update to state_dict if f"{name}.output_projection.weight" is already in the checkpoint. Reviewed By: myleott Differential Revision: D21915833 fbshipit-source-id: d298e24394be2ee85c8f686ba459b7e4cbd4298a --- fairseq/models/transformer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index b94a0db492..0275ecc10b 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -819,14 +819,15 @@ def upgrade_state_dict_named(self, state_dict, name): "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) - if self.share_input_output_embed: - embed_out_key = f"{name}.embed_tokens.weight" - else: - embed_out_key = f"{name}.embed_out" - if embed_out_key in state_dict: - state_dict[f"{name}.output_projection.weight"] = state_dict[embed_out_key] - if not self.share_input_output_embed: - del state_dict[embed_out_key] + if f"{name}.output_projection.weight" not in state_dict: + if self.share_input_output_embed: + embed_out_key = f"{name}.embed_tokens.weight" + else: + embed_out_key = f"{name}.embed_out" + if embed_out_key in state_dict: + state_dict[f"{name}.output_projection.weight"] = state_dict[embed_out_key] + if not self.share_input_output_embed: + del state_dict[embed_out_key] for i in range(self.num_layers): # update layer norms From 5c4f0f89035c20dda080a6072ac56ae594815938 Mon Sep 17 00:00:00 2001 From: Gil Keren Date: Mon, 8 Jun 2020 10:50:33 -0700 Subject: [PATCH 021/707] Better exception
handling for the data buffer thread Summary: When data-buffer-size != 0 was used, an exception raised during data preparation (and therefore in the buffer thread) was not propagated properly, and the main thread hung on `queue.get`. This fixes it by re-raising the error in the main thread. Reviewed By: myleott Differential Revision: D21917739 fbshipit-source-id: 8d3f875b663b37625f44a943fb3904e25216db06 --- fairseq/data/iterators.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index ef7d5fbfa8..b9d74e2946 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -423,11 +423,14 @@ def __init__(self, queue, source): self._source = source def run(self): - for item in self._source: - self._queue.put(item) + try: + for item in self._source: + self._queue.put(item) - # Signal the consumer we are done. - self._queue.put(_sentinel) + # Signal the consumer we are done. + self._queue.put(_sentinel) + except Exception as e: + self._queue.put(e) class BufferedIterator(object): @@ -462,6 +465,8 @@ def __next__(self): # Get next example item = self._queue.get(True) + if isinstance(item, Exception): + raise item if item is _sentinel: raise StopIteration() return item From 2e1da09a9c374d0a909e623054a33e9907bc0d82 Mon Sep 17 00:00:00 2001 From: Mike Ruberry Date: Tue, 9 Jun 2020 13:01:31 -0700 Subject: [PATCH 022/707] Updates argument to np.arange to avoid performing floor division using torch.div Summary: Performing floor division with torch.div is deprecated and will soon throw a runtime error. Perhaps surprisingly, calling np.arange on a torch tensor can use torch.div to perform floor division. Taking the number from the tensor using .item() should prevent this issue and keep this code working. Created from Diffusion's 'Open in Editor' feature.
Reviewed By: lematt1991 Differential Revision: D21941120 fbshipit-source-id: 4d76451d4b33d487946af1c2f9ed21eca858cb06 --- fairseq/data/noising.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/noising.py b/fairseq/data/noising.py index bd67e7336c..5801ae6eac 100644 --- a/fairseq/data/noising.py +++ b/fairseq/data/noising.py @@ -175,7 +175,7 @@ def noising(self, x, lengths, max_shuffle_distance=None): # generate a random permutation scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i] # ensure no reordering inside a word - scores += 1e-6 * np.arange(length_no_eos) + scores += 1e-6 * np.arange(length_no_eos.item()) permutation = scores.argsort() # shuffle words x2[:length_no_eos, i].copy_( From 242269d439dc9df346c8aaf7947aad4581d1894d Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Wed, 10 Jun 2020 11:56:43 -0700 Subject: [PATCH 023/707] Fix truncation in sentence_ranking (#1185) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes current breaking change on master. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1185 Reviewed By: myleott Differential Revision: D21924644 Pulled By: joshim5 fbshipit-source-id: 0eabd2393c76060dcf1568eba308878a90af7a87 --- fairseq/tasks/sentence_ranking.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py index 7b667dfc86..ea2d22c181 100644 --- a/fairseq/tasks/sentence_ranking.py +++ b/fairseq/tasks/sentence_ranking.py @@ -21,6 +21,7 @@ RawLabelDataset, RightPadDataset, SortDataset, + TruncateDataset ) from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.tasks import FairseqTask, register_task @@ -124,15 +125,14 @@ def make_dataset(type, dictionary): if self.args.max_option_length is not None: input_option = TruncateDataset(input_option, self.args.max_option_length) src_token = ConcatSentencesDataset(input_option, input0) - if self.args.truncate_sequence: - src_token = maybe_shorten_dataset( - src_token, - split, - self.args.shorten_data_split_whitelist, - self.args.shorten_method, - self.args.max_positions, - self.args.seed, - ) + src_token = maybe_shorten_dataset( + src_token, + split, + self.args.shorten_data_split_whitelist, + self.args.shorten_method, + self.args.max_positions, + self.args.seed, + ) src_tokens.append(src_token) with data_utils.numpy_seed(self.args.seed): From 86edf989dd6a53827d509ad268e3f333261e2425 Mon Sep 17 00:00:00 2001 From: Yongqiang Wang Date: Fri, 12 Jun 2020 22:11:13 -0700 Subject: [PATCH 024/707] cast grad_norm to float in case fp16 training Summary: We found that grad_norm could become inf because it is accumulated in the meter many times, and in fp16 it overflows easily.
Using fp32 for each `grad_norm` costs minimal memory. Reviewed By: myleott Differential Revision: D22015643 fbshipit-source-id: 429d24bbb9c9a785edf0bfb06480497022f80418 --- fairseq/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/utils.py b/fairseq/utils.py index da27529268..c83770a593 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -262,9 +262,11 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: return torch.tensor(0.) if len(grads) == 1: - total_norm = torch.norm(grads[0]) + total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) else: - total_norm = torch.norm(torch.stack([torch.norm(g) for g in grads])) + total_norm = torch.norm( + torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in grads]) + ) if aggregate_norm_fn is not None: total_norm = aggregate_norm_fn(total_norm) From 8570277f91d6bed03d71cc9c8326f096cd06b0d2 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Mon, 15 Jun 2020 15:26:24 -0700 Subject: [PATCH 025/707] dataset sampling for minibatch training Summary: **Motivation:** We have 3 datasets: Portal, Video, and Messenger Voice Clips. We want to specify a distribution [p1, p2, p3] such that we sample utterances from Portal with prob p1, etc. Previously, D21675421 sampled from datasets by **batches**. This is not acceptable for minibatch training, as we need to maintain LSTM states across consecutive batches. As a result, we need utterance-level sampling, not batch-level sampling. **Design** 1. Created a new MultiCorpusDataset, similar to MultiCorpusSampledDataset, except it samples at the utterance level. Specifically, every time `ordered_indices` is called, a new sample of the multiple datasets is generated based on an input distribution. We ensure that the randomness of this is seeded by the input seed and the epoch, to enable reproducibility on loading from checkpoints. 2.
Created MiniBatchMultiCorpusDataset, which adds minibatch specific logic to MultiCorpusDataset, mainly for handling things like the start frame and deleting cache. 3. Refactored different sampling strategies into a single `build_sampled_dataset` for easy re-use. 4. Added flag --reset-iterator, enabling us to reset the batch iterator every epoch, enabling a new `ordered_indices` to be generated every epoch 5. some minor refactoring of existing code **Usage** 1. In your data.json, include extra splits in addition to "train" (i.e. "portal", "video"), with whatever transforms/handle file you want. 2. In your flow, provide "--extra-splits portal 0.2 video 0.3" and "--reset-iterator" as flags 3. Enjoy WER improvements Differential Revision: D21887303 fbshipit-source-id: 6b377bed8a68a8e72e2528f8a5a28b675eebaadf --- fairseq/data/multi_corpus_dataset.py | 149 +++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 fairseq/data/multi_corpus_dataset.py diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py new file mode 100644 index 0000000000..02d269a17c --- /dev/null +++ b/fairseq/data/multi_corpus_dataset.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import OrderedDict +from typing import Dict, List + +import numpy as np +from fairseq.data import data_utils + +from . import FairseqDataset + + +logger = logging.getLogger(__name__) + + +class MultiCorpusDataset(FairseqDataset): + """ + Stores multiple instances of FairseqDataset together. Requires each instance + to be the same dataset, as the collate method needs to work on batches with + samples from each dataset. + + Allows specifying a distribution over the datasets to use. 
Note that unlike + MultiCorpusSampledDataset, this distribution allows sampling for each item, + rather than on a batch level. + + Each time ordered_indices() is called, a new sample is generated with + the specified distribution. + + Args: + datasets: a OrderedDict of FairseqDataset instances. + distribution: a List containing the probability of getting an utterance from + corresponding dataset + """ + + def __init__( + self, datasets: Dict[str, FairseqDataset], distribution: List[float], seed: int + ): + super().__init__() + assert isinstance(datasets, OrderedDict) + assert len(datasets) == len(distribution) + self.datasets = datasets + self.distribution = distribution + self.seed = seed + + # Avoid repeated conversions to list later + self.dataset_list = list(datasets.values()) + self.total_num_instances = 0 + + first_dataset = list(self.datasets.values())[0] + + self.dataset_offsets = [] + for dataset in datasets.values(): + assert isinstance(dataset, FairseqDataset) + assert type(dataset) is type(first_dataset) + self.dataset_offsets.append(self.total_num_instances) + self.total_num_instances += len(dataset) + + def ordered_indices(self): + with data_utils.numpy_seed(self.seed, self.epoch): + # Used to store the order of indices of each dataset to use + indices = [ + np.random.permutation(len(dataset)) + for dataset in self.datasets.values() + ] + # Keep track of which samples we've used for each dataset + counters = [0 for _ in self.datasets] + + return np.array( + [ + self._sample(indices, counters) + for _ in range(self.total_num_instances) + ], + dtype=np.int64, + ) + + def _sample(self, indices, counters): + # First pick dataset + dataset_idx = np.random.choice(len(self.distribution), p=self.distribution) + + # Then get dataset internal index + idx = indices[dataset_idx][counters[dataset_idx]] + + # Convert to multi-datasets index + idx += self.dataset_offsets[dataset_idx] + + counters[dataset_idx] += 1 + + # Reset if we reach end + if 
counters[dataset_idx] == len(self.dataset_list[dataset_idx]): + counters[dataset_idx] = 0 + indices[dataset_idx] = np.random.permutation( + len(self.dataset_list[dataset_idx]) + ) + + return idx + + def _map_index(self, index: int): + """ + If dataset A has length N and dataset B has length M + then index 1 maps to index 1 of dataset A, and index N + 1 + maps to index 1 of B. + """ + counter = 0 + for key, dataset in self.datasets.items(): + if index < counter + len(dataset): + return index - counter, key + counter += len(dataset) + raise ValueError( + "Invalid index: {}, max: {}".format(index, self.total_num_instances) + ) + + def __len__(self): + """ + Length of this dataset is the sum of individual datasets + """ + return self.total_num_instances + + def __getitem__(self, index): + index, key = self._map_index(index) + return self.datasets[key][index] + + def collater(self, samples): + """ + Since we enforce all datsets to be the same, collating is just + picking the first one and doing collate. 
+ """ + if len(samples) == 0: + return None + + return list(self.datasets.values())[0].collater(samples) + + def num_tokens(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].num_tokens(index) + + def size(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].size(index) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @property + def supports_prefetch(self): + return False From 14ee059a36092a4216ec38601ca32be11383eecb Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 16 Jun 2020 11:45:10 -0700 Subject: [PATCH 026/707] Dataloading fixes (#1189) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1189 Reviewed By: ngoyal2707 Differential Revision: D22052683 Pulled By: myleott fbshipit-source-id: afdfda291907ad4441af51cfc9e44f1bd01ea696 --- fairseq/data/iterators.py | 86 ++++++++++++++++++++++++++++++--------- fairseq/options.py | 2 +- fairseq/trainer.py | 29 ++++++++----- 3 files changed, 85 insertions(+), 32 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index b9d74e2946..8603085ec8 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -4,16 +4,18 @@ # LICENSE file in the root directory of this source tree. import itertools +import logging import math import operator import os +import queue import time +from threading import Thread + import numpy as np import torch -import queue -import logging -from threading import Thread -from . 
import data_utils + +from fairseq.data import data_utils logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -362,24 +364,68 @@ class GroupedIterator(CountingIterator): """ def __init__(self, iterable, chunk_size): - itr = _chunk_iterator(iterable, chunk_size) - super().__init__( - itr, - start=int(math.ceil(getattr(iterable, 'n', 0) / float(chunk_size))), - total=int(math.ceil(len(iterable) / float(chunk_size))), - ) self.chunk_size = chunk_size + n = getattr(iterable, 'n', 0) + itr = ichunked( + iterable, + chunk_size, + remaining=(len(iterable) - n), + ) + start = int(math.ceil(n / float(chunk_size))) + total = int(math.ceil(len(iterable) / float(chunk_size))) + super().__init__(itr, start=start, total=total) + + +class IndexableIterator(object): + + def __init__(self, iterable, length): + self.iterable = iterable + self.itr = iter(self) + self.n = length + self._cache = [] + + def __len__(self): + return self.n + + def __getitem__(self, index): + if index >= self.n: + raise IndexError + while len(self._cache) <= index: + self._cache.append(next(self.iterable)) + return self._cache[index] + + def __iter__(self): + for i in range(len(self)): + yield self[i] + + def __next__(self): + return next(self.itr) + + def __eq__(self, other): + if len(self) != len(other): + return False + for i in range(len(self)): + if self[i] != other[i]: + return False + return True + + +def ichunked(iterable, n, remaining=None): + """Adapted from more_itertools.ichunked""" + if remaining is None: + remaining = len(iterable) + source = iter(iterable) + while remaining > 0: + item = next(source) + + # Clone the source and yield an n-length slice + source, it = itertools.tee(itertools.chain([item], source)) + yield IndexableIterator(itertools.islice(it, n), min(remaining, n)) -def _chunk_iterator(itr, chunk_size): - chunk = [] - for x in itr: - chunk.append(x) - if len(chunk) == chunk_size: - yield chunk - chunk = [] - if len(chunk) > 0: - yield chunk + # Advance the 
source iterable + next(itertools.islice(source, n, n), None) + remaining = max(0, remaining - n) class ShardedIterator(CountingIterator): @@ -453,7 +499,7 @@ def __len__(self): def __next__(self): # Notify the user if there is a data loading bottleneck - if self._queue.qsize() < 2: + if self._queue.qsize() < max(1, self._queue.maxsize // 2): if time.time() - self.start_time > 5 * 60: if self.warning_time is None or time.time() - self.warning_time > 15 * 60: logger.info( diff --git a/fairseq/options.py b/fairseq/options.py index 092e1e1cb3..fe92e35e70 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -331,7 +331,7 @@ def add_dataset_args(parser, train=False, gen=False): choices=get_available_dataset_impl(), help='output dataset implementation') group.add_argument('--data-buffer-size', default=2, type=int, metavar='N', - help='Number of batches to preload') + help='number of batches to preload') if train: group.add_argument('--train-subset', default='train', metavar='SPLIT', help='data subset to use for training (e.g. 
train, valid, test)') diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 269e43313d..342555f390 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -278,6 +278,17 @@ def load_checkpoint( ) ) + # handle changed world size + cpt_world_size = state["args"].distributed_world_size + if cpt_world_size != self.args.distributed_world_size: + logger.info("world size changed from checkpoint: {} -> {}".format( + cpt_world_size, self.args.distributed_world_size + )) + old_iters = extra_state["train_iterator"]["iterations_in_epoch"] + extra_state["train_iterator"]["iterations_in_epoch"] = int( + old_iters * cpt_world_size / self.args.distributed_world_size + ) + self.lr_step(epoch) if "metrics" in extra_state and not reset_meters: @@ -485,7 +496,8 @@ def maybe_no_sync(): # take an optimization step self.optimizer.step() except FloatingPointError: - # re-run the forward and backward pass with hooks attached to print out where it fails + # re-run the forward and backward pass with hooks attached to print + # out where it fails with NanDetector(self.model): self.task.train_step( sample, self.model, self.criterion, self.optimizer, self.get_num_updates(), @@ -837,15 +849,6 @@ def _fast_stat_sync_sum( logging_outputs = [] return logging_outputs, extra_stats_to_sum - def _is_grad_norms_consistent(self, grad_norm_buf): - """check whether a given tensor (shape (N,)) is consistent """ - """consistent means all the values are diff within a tolerate range""" - diff = grad_norm_buf - grad_norm_buf[0] - max_abs_diff = torch.max(torch.abs(diff)).item() - first_grad_norm = grad_norm_buf[0].item() - # TODO: make 1e-6 a configurable value - return max_abs_diff / (first_grad_norm + 1e-6) < 1e-6 - def _check_grad_norms(self, grad_norm): """Check that grad norms are consistent across workers.""" if self._grad_norm_buf is not None: @@ -856,7 +859,11 @@ def _check_grad_norms(self, grad_norm): group=self.data_parallel_process_group ) - if not 
self._is_grad_norms_consistent(self._grad_norm_buf): + def is_consistent(tensor): + max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) + return (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + + if not is_consistent(self._grad_norm_buf): pretty_detail = "\n".join( "rank {:3d} = {:.8f}".format(r, n) for r, n in enumerate(self._grad_norm_buf.tolist()) From 3c16b002b94f22a62cbc257b5d339b6d9f4d5a07 Mon Sep 17 00:00:00 2001 From: Rohit Kopparthy Date: Tue, 16 Jun 2020 12:15:00 -0700 Subject: [PATCH 027/707] Scripting ConvTransformer Summary: This diff is building off of D21986239 to script the ConvTransformer Model instead of the VggTransformer. The changes made in data_utils.py were copied over from D20443519. A new file called test_convtransformer.py was added to test scripting the model. The scripted model compiles and also produces the same output as before scripting. Reviewed By: myleott Differential Revision: D22022654 fbshipit-source-id: 8f5a36a9af391142b468818650be3af218235fc2 --- fairseq/models/transformer.py | 59 ++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 0275ecc10b..47ba77a503 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -267,9 +267,7 @@ def forward( which are not supported by TorchScript. 
""" encoder_out = self.encoder( - src_tokens, - src_lengths=src_lengths, - return_all_hiddens=return_all_hiddens, + src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens ) decoder_out = self.decoder( prev_output_tokens, @@ -346,10 +344,9 @@ def __init__(self, args, dictionary, embed_tokens): self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) - self.layers.extend([ - self.build_encoder_layer(args) - for i in range(args.encoder_layers) - ]) + self.layers.extend( + [self.build_encoder_layer(args) for i in range(args.encoder_layers)] + ) self.num_layers = len(self.layers) if args.encoder_normalize_before: @@ -376,12 +373,7 @@ def forward_embedding(self, src_tokens): x = self.quant_noise(x) return x, embed - def forward( - self, - src_tokens, - src_lengths, - return_all_hiddens: bool = False, - ): + def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): """ Args: src_tokens (LongTensor): tokens in the source language of shape @@ -586,13 +578,17 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) - self.layers.extend([ - self.build_decoder_layer(args, no_encoder_attn) - for _ in range(args.decoder_layers) - ]) + self.layers.extend( + [ + self.build_decoder_layer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) self.num_layers = len(self.layers) - if args.decoder_normalize_before and not getattr(args, "no_decoder_final_norm", False): + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None @@ -679,6 +675,29 @@ def extract_features( full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + 
incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + ''' + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. Aa copy of + this function is made to be used in the subclass instead. + ''' + def extract_features_scriptable( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, ): """ Similar to *forward* but only return features. @@ -825,7 +844,9 @@ def upgrade_state_dict_named(self, state_dict, name): else: embed_out_key = f"{name}.embed_out" if embed_out_key in state_dict: - state_dict[f"{name}.output_projection.weight"] = state_dict[embed_out_key] + state_dict[f"{name}.output_projection.weight"] = state_dict[ + embed_out_key + ] if not self.share_input_output_embed: del state_dict[embed_out_key] From c294e2fcfb299290e53023d1e7cf3a53d27195a4 Mon Sep 17 00:00:00 2001 From: Yongqiang Wang Date: Wed, 17 Jun 2020 13:45:19 -0700 Subject: [PATCH 028/707] print out all the CUDA environment information (including name, memory size, Summary: Recently, we found there are more and more likely that different generations (V100 vs P100) / memory size (16GB, 32GB) GPUs are mixed up in training, while the users do not even know about this. 
Printing this information can help with debugging. Reviewed By: myleott Differential Revision: D21782630 fbshipit-source-id: 7e1075e1b928d969594bbee92275a819cf1a0877 --- fairseq/trainer.py | 13 +++++++++++++ fairseq/utils.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 342555f390..64fbb934d3 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -95,6 +95,19 @@ def __init__(self, args, task, model, criterion, quantizer=None): if self.quantizer is not None: self.quantizer.set_trainer(self) + # get detailed cuda environment + if self.cuda: + self.cuda_env = utils.CudaEnvironment() + if self.data_parallel_world_size > 1: + self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env) + else: + self.cuda_env_arr = [self.cuda_env] + if self.data_parallel_rank == 0: + utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) + else: + self.cuda_env = None + self.cuda_env_arr = None + metrics.log_start_time("wall", priority=790, round=0) def reinitialize(self): diff --git a/fairseq/utils.py b/fairseq/utils.py index c83770a593..aecc68d52d 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -517,3 +517,39 @@ def new_arange(x, *size): def get_tpu_device(args): import torch_xla.core.xla_model as xm return xm.xla_device() + + +def logging_multiple_line_messages(msg): + msg_arr = msg.split("\n") + for line in msg_arr: + logger.info(line) + + +class CudaEnvironment(object): + def __init__(self): + cur_device = torch.cuda.current_device() + prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device)) + self.name = prop.name + self.major = prop.major + self.minor = prop.minor + self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024 + + @staticmethod + def pretty_print_cuda_env_list(cuda_env_list): + """ + Given a list of CudaEnviorments, pretty print them + """ + num_workers = len(cuda_env_list) + center = "CUDA enviroments for all {}
workers".format(num_workers) + banner_len = 40 - len(center) // 2 + first_line = "*" * banner_len + center + "*" * banner_len + msg_arr = [first_line] + for r, env in enumerate(cuda_env_list): + msg_arr.append( + "rank {:3d}: ".format(r) + + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor) + + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB) + + "name = {:40s}".format(env.name) + ) + msg_arr.append(first_line) + logging_multiple_line_messages("\n".join(msg_arr)) From 82f99df8e4f54cc91ec1dffa7e5cc506c91fb465 Mon Sep 17 00:00:00 2001 From: Gil Keren Date: Wed, 17 Jun 2020 18:48:29 -0700 Subject: [PATCH 029/707] Gradually releasing the restrictions on data-buffer-size Summary: The buffer was suspected of contributing to some everstore overload and was therefore restricted in D21804332. Since its part in those problems was inconclusive, and the everstore read limit was increased for the speech group, we are gradually increasing it back. Differential Revision: D22076534 fbshipit-source-id: cb01d50d4df5843b86f7d730e1805a88ea3f41d8 --- fairseq/data/iterators.py | 2 +- fairseq/options.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 8603085ec8..cf40ae7435 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -215,7 +215,7 @@ def __init__( self.num_workers = num_workers # This upper limit here is to prevent people from abusing this feature # in a shared computing environment.
- self.buffer_size = min(buffer_size, 5) + self.buffer_size = min(buffer_size, 20) self.epoch = max(epoch, 1) # we use 1-based indexing for epochs self.shuffle = True diff --git a/fairseq/options.py b/fairseq/options.py index fe92e35e70..c2b00b3a08 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -330,7 +330,7 @@ def add_dataset_args(parser, train=False, gen=False): parser.add_argument('--dataset-impl', metavar='FORMAT', choices=get_available_dataset_impl(), help='output dataset implementation') - group.add_argument('--data-buffer-size', default=2, type=int, metavar='N', + group.add_argument('--data-buffer-size', default=10, type=int, metavar='N', help='number of batches to preload') if train: group.add_argument('--train-subset', default='train', metavar='SPLIT', From d617c292f8101e8c62d6b0660bc321ef3e43a138 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Thu, 18 Jun 2020 18:36:34 -0700 Subject: [PATCH 030/707] Apply black formatter to fairseq_cli/train.py Differential Revision: D22125634 fbshipit-source-id: a05f483ac4b564f5d7a21f5ae3605615e7fcd263 --- fairseq_cli/train.py | 144 +++++++++++++++++++++---------------------- 1 file changed, 72 insertions(+), 72 deletions(-) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 869b913230..ca6ae798c8 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -15,7 +15,6 @@ import numpy as np import torch - from fairseq import ( checkpoint_utils, distributed_utils, @@ -26,28 +25,29 @@ ) from fairseq.data import iterators from fairseq.logging import meters, metrics, progress_bar -from fairseq.trainer import Trainer from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from fairseq.trainer import Trainer logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, stream=sys.stdout, ) -logger = 
logging.getLogger('fairseq_cli.train') +logger = logging.getLogger("fairseq_cli.train") def main(args, init_distributed=False): utils.import_user_module(args) - assert args.max_tokens is not None or args.max_sentences is not None, \ - 'Must specify batch size either with --max-tokens or --max-sentences' + assert ( + args.max_tokens is not None or args.max_sentences is not None + ), "Must specify batch size either with --max-tokens or --max-sentences" metrics.reset() # Initialize CUDA and distributed training - if torch.cuda.is_available() and not args.cpu and not getattr(args, 'tpu', False): + if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): torch.cuda.set_device(args.device_id) np.random.seed(args.seed) utils.set_torch_seed(args.seed) @@ -64,18 +64,22 @@ def main(args, init_distributed=False): task = tasks.setup_task(args) # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in args.valid_subset.split(','): + for valid_sub_split in args.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion model = task.build_model(args) criterion = task.build_criterion(args) logger.info(model) - logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) - logger.info('num. model params: {} (num. trained: {})'.format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - )) + logger.info( + "model {}, criterion {}".format(args.arch, criterion.__class__.__name__) + ) + logger.info( + "num. model params: {} (num. 
trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + ) + ) # (optionally) Configure quantization if args.quantization_config_path is not None: @@ -93,18 +97,22 @@ def main(args, init_distributed=False): else: trainer = MegatronTrainer(args, task, model, criterion) - logger.info('training on {} devices (GPUs/TPUs)'.format(args.distributed_world_size)) - logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format( - args.max_tokens, - args.max_sentences, - )) + logger.info( + "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) + ) + logger.info( + "max tokens per GPU = {} and max sentences per GPU = {}".format( + args.max_tokens, args.max_sentences + ) + ) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) if args.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous('load_checkpoint') # wait for all workers + + xm.rendezvous("load_checkpoint") # wait for all workers xm.mark_step() # Train until the learning rate gets too small @@ -112,10 +120,7 @@ def main(args, init_distributed=False): lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - while ( - lr > args.min_lr - and epoch_itr.next_epoch_idx <= max_epoch - ): + while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: # train for one epoch valid_losses, should_stop = train(args, trainer, task, epoch_itr) if should_stop: @@ -127,10 +132,10 @@ def main(args, init_distributed=False): epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch - load_dataset=(os.pathsep in getattr(args, 'data', '')), + load_dataset=(os.pathsep in getattr(args, "data", "")), ) train_meter.stop() - logger.info('done training in {:.1f} seconds'.format(train_meter.sum)) + logger.info("done training in {:.1f} 
seconds".format(train_meter.sum)) def should_stop_early(args, valid_loss): @@ -143,7 +148,7 @@ def should_stop_early(args, valid_loss): def is_better(a, b): return a > b if args.maximize_best_checkpoint_metric else a < b - prev_best = getattr(should_stop_early, 'best', None) + prev_best = getattr(should_stop_early, "best", None) if prev_best is None or is_better(valid_loss, prev_best): should_stop_early.best = valid_loss should_stop_early.num_runs = 0 @@ -151,7 +156,11 @@ def is_better(a, b): else: should_stop_early.num_runs += 1 if should_stop_early.num_runs >= args.patience: - logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) + logger.info( + "early stop since valid performance hasn't improved for last {} runs".format( + args.patience + ) + ) return True else: return False @@ -160,17 +169,18 @@ def is_better(a, b): def tpu_data_loader(args, itr): import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl - xm.rendezvous('tpu_data_loader') # wait for all workers + + xm.rendezvous("tpu_data_loader") # wait for all workers xm.mark_step() device = utils.get_tpu_device(args) return iterators.CountingIterator( pl.ParallelLoader(itr, [device]).per_device_loader(device), - start=getattr(itr, 'n', 0), + start=getattr(itr, "n", 0), total=len(itr), ) -@metrics.aggregate('train') +@metrics.aggregate("train") def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" # Initialize data iterator @@ -184,7 +194,7 @@ def train(args, trainer, task, epoch_itr): else args.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(args, 'tpu', False): + if getattr(args, "tpu", False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, @@ -194,15 +204,15 @@ def train(args, trainer, task, epoch_itr): tensorboard_logdir=( args.tensorboard_logdir if distributed_utils.is_master(args) else None ), - 
default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) trainer.begin_epoch(epoch_itr.epoch) - valid_subsets = args.valid_subset.split(',') + valid_subsets = args.valid_subset.split(",") should_stop = False for samples in progress: - with metrics.aggregate('train_inner'): + with metrics.aggregate("train_inner"): log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue @@ -210,12 +220,12 @@ def train(args, trainer, task, epoch_itr): # log mid-epoch stats num_updates = trainer.get_num_updates() if num_updates % args.log_interval == 0: - stats = get_training_stats(metrics.get_smoothed_values('train_inner')) - progress.log(stats, tag='train_inner', step=num_updates) + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + progress.log(stats, tag="train_inner", step=num_updates) # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved - metrics.reset_meters('train_inner') + metrics.reset_meters("train_inner") end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( @@ -225,31 +235,25 @@ def train(args, trainer, task, epoch_itr): break # log end-of-epoch stats - stats = get_training_stats(metrics.get_smoothed_values('train')) - progress.print(stats, tag='train', step=num_updates) + stats = get_training_stats(metrics.get_smoothed_values("train")) + progress.print(stats, tag="train", step=num_updates) # reset epoch-level meters - metrics.reset_meters('train') + metrics.reset_meters("train") return valid_losses, should_stop def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): num_updates = trainer.get_num_updates() do_save = ( - ( - args.save_interval_updates > 0 - and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - ) - or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) - ) + args.save_interval_updates > 
0 + and num_updates > 0 + and num_updates % args.save_interval_updates == 0 + ) or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) do_validate = ( - ( - (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) - ) - and not args.disable_validation - ) + (not end_of_epoch and do_save) # validate during mid-epoch saves + or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + ) and not args.disable_validation # Validate valid_losses = [None] @@ -271,7 +275,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc def get_training_stats(stats): - stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) + stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) return stats @@ -286,7 +290,7 @@ def validate(args, trainer, task, epoch_itr, subsets): for subset in subsets: # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) - if getattr(args, 'tpu', False): + if getattr(args, "tpu", False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, @@ -297,7 +301,7 @@ def validate(args, trainer, task, epoch_itr, subsets): tensorboard_logdir=( args.tensorboard_logdir if distributed_utils.is_master(args) else None ), - default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) # create a new root metrics aggregator so validation metrics @@ -315,13 +319,12 @@ def validate(args, trainer, task, epoch_itr, subsets): def get_valid_stats(args, trainer, stats): - stats['num_updates'] = trainer.get_num_updates() - if hasattr(checkpoint_utils.save_checkpoint, 'best'): - key = 'best_{0}'.format(args.best_checkpoint_metric) + stats["num_updates"] = trainer.get_num_updates() + if hasattr(checkpoint_utils.save_checkpoint, "best"): + key = 
"best_{0}".format(args.best_checkpoint_metric) best_function = max if args.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, - stats[args.best_checkpoint_metric], + checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric] ) return stats @@ -353,29 +356,26 @@ def cli_main(modify_parser=None): else: distributed_main(args.device_id, args) elif args.distributed_world_size > 1: - if not getattr(args, 'tpu', False): + if not getattr(args, "tpu", False): # fallback for single node with multiple GPUs assert args.distributed_world_size <= torch.cuda.device_count() port = random.randint(10000, 20000) - args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + args.distributed_init_method = "tcp://localhost:{port}".format(port=port) args.distributed_rank = None # set based on device id torch.multiprocessing.spawn( - fn=distributed_main, - args=(args, ), - nprocs=args.distributed_world_size, + fn=distributed_main, args=(args,), nprocs=args.distributed_world_size ) else: import torch_xla.distributed.xla_multiprocessing as xmp - torch.multiprocessing.set_sharing_strategy('file_system') + + torch.multiprocessing.set_sharing_strategy("file_system") xmp.spawn( - fn=distributed_main, - args=(args, ), - nprocs=8, # use all 8 TPU cores + fn=distributed_main, args=(args,), nprocs=8 # use all 8 TPU cores ) else: # single GPU training main(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() From 8d8d773c5724f4a153216184821e512906c86149 Mon Sep 17 00:00:00 2001 From: Rohit Kopparthy Date: Fri, 19 Jun 2020 10:50:22 -0700 Subject: [PATCH 031/707] Set EncoderOut Attributes to None instead of torch.empty(0) Summary: The ConvTransformer model throws an error during training because of certain attributes having been changed to torch.empty(0) instead of None to meed torchscript type requirements. 
Existing assertion checks only check if these attributes are not None, rather than not torch.empty(0). To fix this, types have been modified to Optional types and allowed to stay as None like before. Reviewed By: zhengwy888 Differential Revision: D22115126 fbshipit-source-id: de3c7b64c5e7142c860a354f778b8b818a7b0bb8 --- fairseq/models/fairseq_encoder.py | 4 ++-- fairseq/models/transformer.py | 37 ++++++++++++++++++------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py index 62b0b04382..9c73633572 100644 --- a/fairseq/models/fairseq_encoder.py +++ b/fairseq/models/fairseq_encoder.py @@ -12,8 +12,8 @@ "EncoderOut", [ ("encoder_out", Tensor), # T x B x C - ("encoder_padding_mask", Tensor), # B x T - ("encoder_embedding", Tensor), # B x T x C + ("encoder_padding_mask", Optional[Tensor]), # B x T + ("encoder_embedding", Optional[Tensor]), # B x T x C ("encoder_states", Optional[List[Tensor]]), # List[T x B x C] ("src_tokens", Optional[Tensor]), # B x T ("src_lengths", Optional[Tensor]), # B x 1 diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 47ba77a503..352db5a293 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -436,22 +436,28 @@ def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): Returns: *encoder_out* rearranged according to *new_order* """ - new_encoder_out: Dict[str, Tensor] = {} + """ + Since encoder_padding_mask and encoder_embedding are both of type + Optional[Tensor] in EncoderOut, they need to be copied as local + variables for Torchscript Optional refinement + """ + encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask + encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding - new_encoder_out["encoder_out"] = ( + new_encoder_out = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) - 
new_encoder_out["encoder_padding_mask"] = ( - encoder_out.encoder_padding_mask - if encoder_out.encoder_padding_mask is None - else encoder_out.encoder_padding_mask.index_select(0, new_order) + new_encoder_padding_mask = ( + encoder_padding_mask + if encoder_padding_mask is None + else encoder_padding_mask.index_select(0, new_order) ) - new_encoder_out["encoder_embedding"] = ( - encoder_out.encoder_embedding - if encoder_out.encoder_embedding is None - else encoder_out.encoder_embedding.index_select(0, new_order) + new_encoder_embedding = ( + encoder_embedding + if encoder_embedding is None + else encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: @@ -467,9 +473,9 @@ def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( - encoder_out=new_encoder_out["encoder_out"], # T x B x C - encoder_padding_mask=new_encoder_out["encoder_padding_mask"], # B x T - encoder_embedding=new_encoder_out["encoder_embedding"], # B x T x C + encoder_out=new_encoder_out, # T x B x C + encoder_padding_mask=new_encoder_padding_mask, # B x T + encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 @@ -685,11 +691,12 @@ def extract_features( alignment_heads, ) - ''' + """ A scriptable subclass of this class has an extract_features method and calls super().extract_features, but super() is not supported in torchscript. Aa copy of this function is made to be used in the subclass instead. 
- ''' + """ + def extract_features_scriptable( self, prev_output_tokens, From 6f6461b81ac457b381669ebc8ea2d80ea798e53a Mon Sep 17 00:00:00 2001 From: Mandeep Baines Date: Fri, 19 Jun 2020 16:21:53 -0700 Subject: [PATCH 032/707] Add tracepoints (#1192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: There is no overhead when the profiling is not enabled. When running using profile.py, I measure an overhead of 3%. # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1192 Reviewed By: sidgoyal78 Differential Revision: D22102341 Pulled By: msbaines fbshipit-source-id: ffddb9cceb853df88db34195be18bae7723d4c98 --- fairseq/options.py | 1 + fairseq/tasks/fairseq_task.py | 6 ++++-- fairseq/trainer.py | 29 ++++++++++++++++------------- fairseq_cli/train.py | 12 ++++++++++-- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/fairseq/options.py b/fairseq/options.py index c2b00b3a08..9ccd58941a 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -250,6 +250,7 @@ def get_parser(desc, default_task="translation"): help='suffix to add to the checkpoint file name') parser.add_argument('--quantization-config-path', default=None, help='path to quantization config file') + parser.add_argument('--profile', action='store_true', help='enable autograd profiler emit_nvtx') from fairseq.registry import REGISTRIES for registry_name, REGISTRY in REGISTRIES.items(): diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 5036cfe293..bd9d75abd6 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -339,10 +339,12 @@ def train_step( """ model.train() model.set_num_updates(update_num) - loss, sample_size, logging_output = criterion(model, sample) + with torch.autograd.profiler.record_function("forward"): + loss, sample_size, logging_output = criterion(model, sample) if ignore_grad: loss *= 0 - optimizer.backward(loss) + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) return loss, sample_size, logging_output def valid_step(self, sample, model, criterion): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 64fbb934d3..b23de7ecd4 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -486,17 +486,19 @@ def maybe_no_sync(): gradients = xm._fetch_gradients(self.optimizer.optimizer) xm.all_reduce('sum', gradients, scale=1.0 /
self.data_parallel_world_size) - # multiply gradients by (# GPUs / sample_size) since DDP - # already normalizes by the number of GPUs. Thus we get - # (sum_of_gradients / sample_size). - if not self.args.use_bmuf: - self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size) - elif sample_size > 0: # BMUF needs to check sample size - num = self.data_parallel_world_size if self._sync_stats() else 1 - self.optimizer.multiply_grads(num / sample_size) - - # clip grads - grad_norm = self.clip_grad_norm(self.args.clip_norm) + with torch.autograd.profiler.record_function("multiply-grads"): + # multiply gradients by (# GPUs / sample_size) since DDP + # already normalizes by the number of GPUs. Thus we get + # (sum_of_gradients / sample_size). + if not self.args.use_bmuf: + self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size) + elif sample_size > 0: # BMUF needs to check sample size + num = self.data_parallel_world_size if self._sync_stats() else 1 + self.optimizer.multiply_grads(num / sample_size) + + with torch.autograd.profiler.record_function("clip-grads"): + # clip grads + grad_norm = self.clip_grad_norm(self.args.clip_norm) # check that grad norms are consistent across workers if ( @@ -506,8 +508,9 @@ def maybe_no_sync(): ): self._check_grad_norms(grad_norm) - # take an optimization step - self.optimizer.step() + with torch.autograd.profiler.record_function("optimizer"): + # take an optimization step + self.optimizer.step() except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index ca6ae798c8..fd9566b719 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -211,8 +211,8 @@ def train(args, trainer, task, epoch_itr): valid_subsets = args.valid_subset.split(",") should_stop = False - for samples in progress: - with metrics.aggregate("train_inner"): + for i, samples in enumerate(progress): + with 
metrics.aggregate("train_inner"), torch.autograd.profiler.record_function("train_step-%d" % i): log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue @@ -339,7 +339,15 @@ def distributed_main(i, args, start_rank=0): def cli_main(modify_parser=None): parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + if args.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + cli_main_helper(args) + else: + cli_main_helper(args) + +def cli_main_helper(args): if args.distributed_init_method is None: distributed_utils.infer_init_method(args) From 3ea511d89936caab6e7bf605366152b96cd95bcd Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Jun 2020 06:47:35 -0700 Subject: [PATCH 033/707] Revert Dataloader changes Summary: D22052683 may have introduced a memory leak, revert those parts for now The original motivation is described here: https://github.com/pytorch/fairseq/issues/2168. Previously I/O was bursty when training with large update frequency. This meant to even it out, but possibly introduced a memory leak. 
More context on the change can be found here: https://github.com/pytorch/fairseq/issues/2168 Reviewed By: yqwangustc Differential Revision: D22156157 fbshipit-source-id: 390ff39bc3e268d6312971768c34fe44d4bd84b7 --- fairseq/data/iterators.py | 75 +++++++++------------------------------ fairseq/trainer.py | 11 ------ 2 files changed, 16 insertions(+), 70 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index cf40ae7435..e909c44aa5 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -17,6 +17,7 @@ from fairseq.data import data_utils + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -364,68 +365,24 @@ class GroupedIterator(CountingIterator): """ def __init__(self, iterable, chunk_size): - self.chunk_size = chunk_size - - n = getattr(iterable, 'n', 0) - itr = ichunked( - iterable, - chunk_size, - remaining=(len(iterable) - n), + itr = _chunk_iterator(iterable, chunk_size) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, 'n', 0) / float(chunk_size))), + total=int(math.ceil(len(iterable) / float(chunk_size))), ) - start = int(math.ceil(n / float(chunk_size))) - total = int(math.ceil(len(iterable) / float(chunk_size))) - super().__init__(itr, start=start, total=total) - - -class IndexableIterator(object): - - def __init__(self, iterable, length): - self.iterable = iterable - self.itr = iter(self) - self.n = length - self._cache = [] - - def __len__(self): - return self.n - - def __getitem__(self, index): - if index >= self.n: - raise IndexError - while len(self._cache) <= index: - self._cache.append(next(self.iterable)) - return self._cache[index] - - def __iter__(self): - for i in range(len(self)): - yield self[i] + self.chunk_size = chunk_size - def __next__(self): - return next(self.itr) - def __eq__(self, other): - if len(self) != len(other): - return False - for i in range(len(self)): - if self[i] != other[i]: - return False - return True - - -def ichunked(iterable, n, 
remaining=None): - """Adapted from more_itertools.ichunked""" - if remaining is None: - remaining = len(iterable) - source = iter(iterable) - while remaining > 0: - item = next(source) - - # Clone the source and yield an n-length slice - source, it = itertools.tee(itertools.chain([item], source)) - yield IndexableIterator(itertools.islice(it, n), min(remaining, n)) - - # Advance the source iterable - next(itertools.islice(source, n, n), None) - remaining = max(0, remaining - n) +def _chunk_iterator(itr, chunk_size): + chunk = [] + for x in itr: + chunk.append(x) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if len(chunk) > 0: + yield chunk class ShardedIterator(CountingIterator): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index b23de7ecd4..22edb4451e 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -291,17 +291,6 @@ def load_checkpoint( ) ) - # handle changed world size - cpt_world_size = state["args"].distributed_world_size - if cpt_world_size != self.args.distributed_world_size: - logger.info("world size changed from checkpoint: {} -> {}".format( - cpt_world_size, self.args.distributed_world_size - )) - old_iters = extra_state["train_iterator"]["iterations_in_epoch"] - extra_state["train_iterator"]["iterations_in_epoch"] = int( - old_iters * cpt_world_size / self.args.distributed_world_size - ) - self.lr_step(epoch) if "metrics" in extra_state and not reset_meters: From 8eb9123f560d32940f96a01369d61c1684dce085 Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Mon, 22 Jun 2020 10:01:21 -0700 Subject: [PATCH 034/707] Patch masked_lm memory leak on GPUs (#1195) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? 
- [x] Did you write any new necessary tests? ## What does this PR do? Fixes memory leak in masked_lm criterion. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1195 Reviewed By: myleott Differential Revision: D22155285 Pulled By: joshim5 fbshipit-source-id: 9414e307e1e2d2a9225884dc94aae964a1627682 --- fairseq/criterions/masked_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py index bb5b35b41f..80864693ec 100644 --- a/fairseq/criterions/masked_lm.py +++ b/fairseq/criterions/masked_lm.py @@ -62,7 +62,7 @@ def forward(self, model, sample, reduce=True): ) logging_output = { - 'loss': loss, + 'loss': loss if self.tpu else loss.data, 'ntokens': sample['ntokens'], 'nsentences': sample['nsentences'], 'sample_size': sample_size, From 320bf8cf963772fe09429b8fb9990371a0e22db3 Mon Sep 17 00:00:00 2001 From: Mike Ruberry Date: Mon, 22 Jun 2020 11:55:51 -0700 Subject: [PATCH 035/707] Updates full to no longer use deprecated integer fill_value type inference Summary: In PyTorch 1.5 using an integer fill_value and not setting the dtype or out kwarg with torch.full was deprecated, and soon will throw a runtime error. In the future, torch.full will infer its dtype from the fill_value, and these would produce integer, not float, tensors. This update maintains the current behavior. Created from Diffusion's 'Open in Editor' feature.
Reviewed By: myleott Differential Revision: D22161456 fbshipit-source-id: b5d687e4de83dba6e76cae6e61b5106bf5b320db --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8d070a4758..35fb115dda 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -82,9 +82,9 @@ def test_clip_grad_norm_(self): params = [torch.nn.Parameter(torch.zeros(5)) for i in range(3)] for p in params: - p.grad = torch.full((5,), fill_value=2) + p.grad = torch.full((5,), fill_value=2.) grad_norm = utils.clip_grad_norm_(params, 1.0) - exp_grad_norm = torch.full((15,), fill_value=2).norm() + exp_grad_norm = torch.full((15,), fill_value=2.).norm() self.assertTrue(torch.is_tensor(grad_norm)) self.assertEqual(grad_norm, exp_grad_norm) From a9cb84df689d4e9343085d2434087c1b308a68a7 Mon Sep 17 00:00:00 2001 From: Tony Lekhtman Date: Mon, 22 Jun 2020 15:27:34 -0700 Subject: [PATCH 036/707] Update hub_utils.py (#2253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: fix bug for print_alignment # Before submitting - [ V] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ V] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? not relevant - [ ] Did you write any new necessary tests? not relevant ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/1880 . ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/2253 Reviewed By: huihuifan Differential Revision: D22162948 Pulled By: myleott fbshipit-source-id: 3ec5508506184a9effa330fbcd43ffe917b533c6 --- fairseq/hub_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 1b1091c881..c249eb23f5 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -189,7 +189,7 @@ def getarg(name, default): )) if hypo['alignment'] is not None and getarg('print_alignment', False): logger.info('A\t{}'.format( - ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu())) + ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in hypo['alignment']]) )) return outputs From e187f6e116c3926cfa693ee6440d277843d0972a Mon Sep 17 00:00:00 2001 From: Yi-Hsiu Liao Date: Mon, 22 Jun 2020 18:24:48 -0700 Subject: [PATCH 037/707] add maybe_no_sync for multilingual_translation task (#2238) Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ x ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ x ] Did you write any new necessary tests? ## What does this PR do? This PR reduces unnecessary communication overhead between GPUs since we only need to sync up once for all lang-pairs. We see significant training speedup especially with large number of lang-pairs.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2238 Reviewed By: pipibjc Differential Revision: D22149086 Pulled By: myleott fbshipit-source-id: 6fff09e5a51b49bdcf5bc3986c0719b19d31c0a9 --- fairseq/tasks/multilingual_translation.py | 29 +++++++++++++++++------ 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 9a1315e1db..031a9c58fa 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -7,6 +7,7 @@ import logging import os +import contextlib import torch from fairseq import metrics, options @@ -265,13 +266,27 @@ def train_step(self, sample, model, criterion, optimizer, update_num, ignore_gra model.train() from collections import defaultdict agg_loss, agg_sample_size, agg_logging_output = 0., 0., defaultdict(float) - for lang_pair in self.model_lang_pairs: - if sample[lang_pair] is None or len(sample[lang_pair]) == 0: - continue - loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) - if ignore_grad: - loss *= 0 - optimizer.backward(loss) + curr_lang_pairs = [ + lang_pair + for lang_pair in self.model_lang_pairs + if sample[lang_pair] is not None and len(sample[lang_pair]) != 0 + ] + + for idx, lang_pair in enumerate(curr_lang_pairs): + def maybe_no_sync(): + if ( + self.args.distributed_world_size > 1 + and hasattr(model, 'no_sync') + and idx < len(curr_lang_pairs) - 1 + ): + return model.no_sync() + else: + return contextlib.ExitStack() # dummy contextmanager + with maybe_no_sync(): + loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + if ignore_grad: + loss *= 0 + optimizer.backward(loss) agg_loss += loss.detach().item() # TODO make summing of the sample sizes configurable agg_sample_size += sample_size From 88c58b6718292be10311afc0aa7f829dc3fc0c27 Mon Sep 17 00:00:00 2001 From: gvskalyan Date: Mon, 22 Jun 2020 18:39:41 
-0700 Subject: [PATCH 038/707] Preprocess dict number (#2228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2227 . ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/2228 Reviewed By: huihuifan Differential Revision: D22163032 Pulled By: myleott fbshipit-source-id: a5afbfca2d9a11563026f47cd246654e131d92fb --- fairseq_cli/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py index 5a60d2c611..b107b9fa18 100644 --- a/fairseq_cli/preprocess.py +++ b/fairseq_cli/preprocess.py @@ -108,7 +108,7 @@ def build_dictionary(filenames, src=False, tgt=False): tgt_dict.save(dict_path(args.target_lang)) def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers): - logger.info("[{}] Dictionary: {} types".format(lang, len(vocab) - 1)) + logger.info("[{}] Dictionary: {} types".format(lang, len(vocab))) n_seq_tok = [0, 0] replaced = Counter() From d5d2cf3cd5f25c8a413328891c353c0379c22442 Mon Sep 17 00:00:00 2001 From: Ronan Riochet Date: Mon, 22 Jun 2020 18:40:50 -0700 Subject: [PATCH 039/707] Add timeout kwarg to EpochBatchIterator (#2261) Summary: Add an optional ```timeout``` argument to ```EpochBatchIterator```.
I need it to fix this issue: https://github.com/pytorch/pytorch/issues/2474 I could do something more general, allowing one to pass ```**dataloader_kwargs``` to ```torch.utils.data.DataLoader```, if you think it's worth. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2261 Reviewed By: huihuifan Differential Revision: D22162936 Pulled By: myleott fbshipit-source-id: 959b408a53356c19c04fc5ae94aad5f164a32dcd --- fairseq/data/iterators.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index e909c44aa5..23e4926fb9 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -200,11 +200,13 @@ class EpochBatchIterator(EpochBatchIterating): buffer_size (int, optional): the number of batches to keep ready in the queue. Helps speeding up dataloading. When buffer_size is zero, the default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) """ def __init__( self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, - num_workers=0, epoch=1, buffer_size=0 + num_workers=0, epoch=1, buffer_size=0, timeout=0, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset @@ -217,6 +219,7 @@ def __init__( # This upper limit here is to prevent people from abusing this feature # in a shared computing environment. 
self.buffer_size = min(buffer_size, 20) + self.timeout = timeout self.epoch = max(epoch, 1) # we use 1-based indexing for epochs self.shuffle = True @@ -342,6 +345,7 @@ def shuffle_batches(batches, seed): collate_fn=self.collate_fn, batch_sampler=batches[offset:], num_workers=self.num_workers, + timeout=self.timeout, ) # Wrap with a BufferedIterator if needed From d0ccc3e02e1a9015d05cade8dfc61896948275c7 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Jun 2020 18:53:19 -0700 Subject: [PATCH 040/707] Add FairseqDecoder.reorder_incremental_state_scripting for TorchScript (#1190) Summary: The main changes are in fairseq_incremental_decoder.py. I made the base `reorder_incremental_state` implementation a no-op and instead we expect callers (e.g., SequenceGenerator) to call `reorder_incremental_state_scripting`. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1190 Test Plan: I ran unit tests both in PyTorch 1.5 and nightly (1.6). I also tested some of the pretrained translation models, but it'd be good to test with some prod runs. Reviewed By: jhcross Differential Revision: D22095614 Pulled By: myleott fbshipit-source-id: 484b8d47b4feda4efe52233a3d46a207d0816766 --- fairseq/models/fairseq_incremental_decoder.py | 34 +++++++--- fairseq/models/lstm.py | 63 +++++++++++-------- fairseq/models/transformer.py | 11 ---- fairseq/modules/transformer_layer.py | 12 ---- fairseq/sequence_generator.py | 2 +- 5 files changed, 63 insertions(+), 59 deletions(-) diff --git a/fairseq/models/fairseq_incremental_decoder.py b/fairseq/models/fairseq_incremental_decoder.py index 51ab577288..68e583fea8 100644 --- a/fairseq/models/fairseq_incremental_decoder.py +++ b/fairseq/models/fairseq_incremental_decoder.py @@ -2,11 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+ +import logging from typing import Dict, Optional +from torch import Tensor + from fairseq.models import FairseqDecoder from fairseq.incremental_decoding_utils import with_incremental_state -from torch import Tensor + + +logger = logging.getLogger(__name__) @with_incremental_state @@ -68,18 +74,28 @@ def reorder_incremental_state( ): """Reorder incremental state. - This should be called when the order of the input has changed from the + This will be called when the order of the input has changed from the previous time step. A typical use case is beam search, where the input order changes between time steps based on the selection of beams. """ - seen: Dict[int, Optional[Tensor]] = {} - for _, module in self.named_modules(): + pass + + def reorder_incremental_state_scripting( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Main entry point for reordering the incremental state. + + Due to limitations in TorchScript, we call this function in + :class:`fairseq.sequence_generator.SequenceGenerator` instead of + calling :func:`reorder_incremental_state` directly. 
+ """ + for module in self.modules(): if hasattr(module, 'reorder_incremental_state'): - if id(module) not in seen and module is not self: - seen[id(module)] = None - result = module.reorder_incremental_state(incremental_state, new_order) - if result is not None: - incremental_state = result + result = module.reorder_incremental_state(incremental_state, new_order) + if result is not None: + incremental_state = result def set_beam_size(self, beam_size): """Sets the beam size in the decoder and all children.""" diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index c2fbde33a4..83baf7f065 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -347,6 +347,7 @@ def forward(self, input, source_hids, encoder_padding_mask): x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) return x, attn_scores + class LSTMDecoder(FairseqIncrementalDecoder): """LSTM decoder.""" def __init__( @@ -410,18 +411,6 @@ def __init__( elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) - def get_cached_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): - cached_state = self.get_incremental_state(incremental_state, 'cached_state') - assert cached_state is not None - prev_hiddens_ = cached_state["prev_hiddens"] - assert prev_hiddens_ is not None - prev_cells_ = cached_state["prev_cells"] - assert prev_cells_ is not None - prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] - prev_cells = [prev_cells_[j] for j in range(self.num_layers)] - input_feed = cached_state["input_feed"] # can be None for decoder-only language models - return prev_hiddens, prev_cells, input_feed - def forward( self, prev_output_tokens, @@ -529,9 +518,13 @@ def extract_features( prev_cells_tensor = torch.stack(prev_cells) cache_state = torch.jit.annotate( Dict[str, Optional[Tensor]], - {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": input_feed}) - 
self.set_incremental_state( - incremental_state, 'cached_state', cache_state) + { + "prev_hiddens": prev_hiddens_tensor, + "prev_cells": prev_cells_tensor, + "input_feed": input_feed, + } + ) + self.set_incremental_state(incremental_state, 'cached_state', cache_state) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) @@ -559,23 +552,41 @@ def output_layer(self, x): x = self.fc_out(x) return x - def reorder_state(self, state: List[Tensor], new_order): - return [ - state_i.index_select(0, new_order) if state_i is not None else None - for state_i in state - ] + def get_cached_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + ) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]: + cached_state = self.get_incremental_state(incremental_state, 'cached_state') + assert cached_state is not None + prev_hiddens_ = cached_state["prev_hiddens"] + assert prev_hiddens_ is not None + prev_cells_ = cached_state["prev_cells"] + assert prev_cells_ is not None + prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] + prev_cells = [prev_cells_[j] for j in range(self.num_layers)] + input_feed = cached_state["input_feed"] # can be None for decoder-only language models + return prev_hiddens, prev_cells, input_feed - def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order): + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): if incremental_state is None or len(incremental_state) == 0: return prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) - cached_state = (prev_hiddens, prev_cells, [input_feed]) - new_state = [self.reorder_state(state, new_order) for state in cached_state] - prev_hiddens_tensor = torch.stack(new_state[0]) - prev_cells_tensor = torch.stack(new_state[1]) + prev_hiddens = [p.index_select(0, new_order) for p in 
prev_hiddens] + prev_cells = [p.index_select(0, new_order) for p in prev_cells] + if input_feed is not None: + input_feed = input_feed.index_select(0, new_order) cached_state_new = torch.jit.annotate( Dict[str, Optional[Tensor]], - {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": new_state[2][0]}) + { + "prev_hiddens": torch.stack(prev_hiddens), + "prev_cells": torch.stack(prev_cells), + "input_feed": input_feed, + } + ) self.set_incremental_state(incremental_state, 'cached_state', cached_state_new), return diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 352db5a293..9171aaf4a2 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -882,17 +882,6 @@ def upgrade_state_dict_named(self, state_dict, name): return state_dict - # Overwrite the method to temporaily support JIT scripting in Transformer - @torch.jit.export - def reorder_incremental_state( - self, - incremental_state: Dict[str, Dict[str, Optional[Tensor]]], - new_order: Tensor, - ): - """Scriptable reorder incremental state in the transformer.""" - for layer in self.layers: - layer.reorder_incremental_state(incremental_state, new_order) - def Embedding(num_embeddings, embedding_dim, padding_idx): m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index cae8498315..8fb08b3aaf 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -379,18 +379,6 @@ def forward( def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn - @torch.jit.export - def reorder_incremental_state( - self, - incremental_state: Dict[str, Dict[str, Optional[Tensor]]], - new_order: Tensor, - ): - """Scriptable reorder incremental state in transformer layers.""" - self.self_attn.reorder_incremental_state(incremental_state, new_order) - - if self.encoder_attn is 
not None: - self.encoder_attn.reorder_incremental_state(incremental_state, new_order) - def Linear(in_features, out_features, bias=True): m = nn.Linear(in_features, out_features, bias) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 7ecdde869f..a523b1ea64 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -794,7 +794,7 @@ def reorder_incremental_state(self, new_order): if not self.has_incremental_states(): return for i, model in enumerate(self.models): - model.decoder.reorder_incremental_state( + model.decoder.reorder_incremental_state_scripting( self.incremental_states[i], new_order ) From a12c5c5de896390775ac45addaa4f4f90534d9b7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 23 Jun 2020 06:46:50 -0700 Subject: [PATCH 041/707] Add max position params to speech recognition (#1783) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/1782. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/1783 Reviewed By: okhonko Differential Revision: D21663633 Pulled By: myleott fbshipit-source-id: 5f3b4b7df83e27d866efb489daeffb3b38a66f38 --- examples/speech_recognition/tasks/speech_recognition.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py index b555cfeefa..e5717c0ef8 100644 --- a/examples/speech_recognition/tasks/speech_recognition.py +++ b/examples/speech_recognition/tasks/speech_recognition.py @@ -6,6 +6,7 @@ import json import os import re +import sys import torch from fairseq.data import Dictionary @@ -77,6 +78,10 @@ def add_args(parser): parser.add_argument( "--silence-token", default="\u2581", help="token for silence (used by w2l)" ) + parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N', + help='max number of frames in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') def __init__(self, args, tgt_dict): super().__init__(args) @@ -132,3 +137,7 @@ def source_dictionary(self): """Return the source :class:`~fairseq.data.Dictionary` (if applicable for this task).""" return None + + def max_positions(self): + """Return the max speech and sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) From da94e58c703866236b29242ae413146be69fe94f Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 24 Jun 2020 09:54:46 -0700 Subject: [PATCH 042/707] TPU support for Translation (#2245) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2245 Reviewed By: ngoyal2707 Differential Revision: D22070745 Pulled By: myleott fbshipit-source-id: e43a96a585366b10d997a12522e8cd6496294ad2 --- docs/tutorial_classifying_names.rst | 2 -
fairseq/data/__init__.py | 2 + fairseq/data/bucket_pad_length_dataset.py | 77 +++++++++++++++++ fairseq/data/data_utils.py | 26 ++++-- fairseq/data/data_utils_fast.pyx | 62 +++++++++++++- fairseq/data/fairseq_dataset.py | 60 ++++++++++++++ fairseq/data/language_pair_dataset.py | 91 ++++++++++++++++----- fairseq/modules/transformer_layer.py | 26 +++--- fairseq/options.py | 3 +- fairseq/tasks/fairseq_task.py | 3 +- fairseq/tasks/semisupervised_translation.py | 2 - fairseq/tasks/translation.py | 13 ++- 12 files changed, 314 insertions(+), 53 deletions(-) create mode 100644 fairseq/data/bucket_pad_length_dataset.py diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index b420d850bc..e2b5a67168 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -282,8 +282,6 @@ following contents:: tgt_sizes=torch.ones(len(labels)), # targets have length 1 tgt_dict=self.label_vocab, left_pad_source=False, - max_source_positions=self.args.max_positions, - max_target_positions=1, # Since our target is a single class label, there's no need for # teacher forcing. 
If we set this to ``True`` then our Model's # ``forward()`` method would receive an additional argument called diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 30c6e88d82..9bdb7a74ae 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -12,6 +12,7 @@ from .append_token_dataset import AppendTokenDataset from .audio.raw_audio_dataset import FileAudioDataset from .backtranslation_dataset import BacktranslationDataset +from .bucket_pad_length_dataset import BucketPadLengthDataset from .colorize_dataset import ColorizeDataset from .concat_dataset import ConcatDataset from .concat_sentences_dataset import ConcatSentencesDataset @@ -57,6 +58,7 @@ 'AppendTokenDataset', 'BacktranslationDataset', 'BaseWrapperDataset', + 'BucketPadLengthDataset', 'ColorizeDataset', 'ConcatDataset', 'ConcatSentencesDataset', diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py new file mode 100644 index 0000000000..6f53d01188 --- /dev/null +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch.nn.functional as F + +from fairseq.data import BaseWrapperDataset + + +class BucketPadLengthDataset(BaseWrapperDataset): + """ + Bucket and pad item lengths to the nearest bucket size. This can be used to + reduce the number of unique batch shapes, which is important on TPUs since + each new batch shape requires a recompilation. 
+ + Args: + dataset (FairseqDatset): dataset to bucket + sizes (List[int]): all item sizes + num_buckets (int): number of buckets to create + pad_idx (int): padding symbol + left_pad (bool): if True, pad on the left; otherwise right pad + """ + + def __init__( + self, + dataset, + sizes, + num_buckets, + pad_idx, + left_pad, + ): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + + assert num_buckets > 0 + self.buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, num_buckets + 1), + interpolation='lower', + )[1:] + ) + + def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes + + self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + + def __getitem__(self, index): + item = self.dataset[index] + bucket_size = self._bucketed_sizes[index] + num_pad = bucket_size - item.size(-1) + return F.pad( + item, + (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), + value=self.pad_idx, + ) + + @property + def sizes(self): + return self._bucketed_sizes + + def num_tokens(self, index): + return self._bucketed_sizes[index] + + def size(self, index): + return self._bucketed_sizes[index] diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index ab82ea4594..3b8f1afd2b 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -199,7 +199,7 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False): def batch_by_size( indices, num_tokens_fn, max_tokens=None, max_sentences=None, - required_batch_size_multiple=1, + required_batch_size_multiple=1, fixed_shapes=None, ): """ Yield mini-batches of indices bucketed by size. Batches may contain @@ -214,10 +214,15 @@ def batch_by_size( max_sentences (int, optional): max number of sentences in each batch (default: None). 
required_batch_size_multiple (int, optional): require batch size to - be a multiple of N (default: 1). + be less than N or a multiple of N (default: 1). + fixed_shapes (List[Tuple[int, int]], optional): if given, batches will + only be created with the given shapes. *max_sentences* and + *required_batch_size_multiple* will be ignored (default: None). """ try: - from fairseq.data.data_utils_fast import batch_by_size_fast + from fairseq.data.data_utils_fast import ( + batch_by_size_fast, batch_fixed_shapes_fast, + ) except ImportError: raise ImportError( 'Please build Cython components with: `pip install --editable .` ' @@ -228,10 +233,21 @@ def batch_by_size( max_sentences = max_sentences if max_sentences is not None else -1 bsz_mult = required_batch_size_multiple - if isinstance(indices, types.GeneratorType): + if not isinstance(indices, np.ndarray): indices = np.fromiter(indices, dtype=np.int64, count=-1) - return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult) + if fixed_shapes is None: + return batch_by_size_fast( + indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult, + ) + else: + fixed_shapes = np.array(fixed_shapes, dtype=np.int64) + sort_order = np.lexsort([ + fixed_shapes[:, 1].argsort(), # length + fixed_shapes[:, 0].argsort(), # bsz + ]) + fixed_shapes_sorted = fixed_shapes[sort_order] + return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted) def process_bpe_symbol(sentence: str, bpe_symbol: str): diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx index 6fa8acc09f..c1f97bf5b6 100644 --- a/fairseq/data/data_utils_fast.pyx +++ b/fairseq/data/data_utils_fast.pyx @@ -13,10 +13,10 @@ DTYPE = np.int64 ctypedef np.int64_t DTYPE_t -cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences): - if len(batch) == 0: +cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences): + if num_sentences == 0: return 0 - if 
max_sentences > 0 and len(batch) == max_sentences: + if max_sentences > 0 and num_sentences == max_sentences: return 1 if max_tokens > 0 and num_tokens > max_tokens: return 1 @@ -53,7 +53,7 @@ cpdef list batch_by_size_fast( ) num_tokens = (len(batch) + 1) * sample_len - if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences): mod_len = max( bsz_mult * (len(batch) // bsz_mult), len(batch) % bsz_mult, @@ -66,3 +66,57 @@ cpdef list batch_by_size_fast( if len(batch) > 0: batches.append(batch) return batches + + +cdef _find_valid_shape( + DTYPE_t[:, :] shapes_view, + long num_sentences, + long num_tokens, +): + """Return index of first valid shape of -1 if none is found.""" + for i in range(shapes_view.shape[0]): + if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + return i + return -1 + + +@cython.cdivision(True) +cpdef list batch_fixed_shapes_fast( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted, +): + cdef long sample_len = 0 + cdef list sample_lens = [] + cdef list batch = [] + cdef list batches = [] + cdef long mod_len + cdef long i + cdef long idx + cdef long num_tokens + cdef DTYPE_t[:] indices_view = indices + cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted + + for i in range(len(indices_view)): + idx = indices_view[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + if shape_idx == -1: + batches.append(batch) + batch = [] + sample_lens = [] + sample_len = 0 + shapes_view = fixed_shapes_sorted + elif shape_idx > 0: + # small optimization for the next call to _find_valid_shape + shapes_view = shapes_view[shape_idx:] + + batch.append(idx) + + if len(batch) > 0: + batches.append(batch) + + return batches diff --git a/fairseq/data/fairseq_dataset.py 
b/fairseq/data/fairseq_dataset.py index fe5681be5a..b03c90ed43 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -62,6 +62,66 @@ def prefetch(self, indices): """Prefetch the data required for this epoch.""" raise NotImplementedError + def get_batch_shapes(self): + """ + Return a list of valid batch shapes, for example:: + + [(8, 512), (16, 256), (32, 128)] + + The first dimension of each tuple is the batch size and can be ``None`` + to automatically infer the max batch size based on ``--max-tokens``. + The second dimension of each tuple is the max supported length as given + by :func:`fairseq.data.FairseqDataset.num_tokens`. + + This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size` + to restrict batch shapes. This is useful on TPUs to avoid too many + dynamic shapes (and recompilations). + """ + return None + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + """ + Given an ordered set of indices, return batches according to + *max_tokens*, *max_sentences* and *required_batch_size_multiple*. 
+ """ + from fairseq.data import data_utils + + fixed_shapes = self.get_batch_shapes() + if fixed_shapes is not None: + + def adjust_bsz(bsz, num_tokens): + if bsz is None: + assert max_tokens is not None, 'Must specify --max-tokens' + bsz = max_tokens // num_tokens + if max_sentences is not None: + bsz = min(bsz, max_sentences) + elif ( + bsz >= required_batch_size_multiple + and bsz % required_batch_size_multiple != 0 + ): + bsz -= (bsz % required_batch_size_multiple) + return bsz + + fixed_shapes = np.array([ + [adjust_bsz(bsz, num_tokens), num_tokens] + for (bsz, num_tokens) in fixed_shapes + ]) + + return data_utils.batch_by_size( + indices, + num_tokens_fn=self.num_tokens, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + fixed_shapes=fixed_shapes, + ) + class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): """For datasets that need to be read sequentially, usually because the data diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index d18a92d786..63c95911fd 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -8,14 +8,18 @@ import numpy as np import torch -from . 
import data_utils, FairseqDataset +from fairseq.data import data_utils, FairseqDataset logger = logging.getLogger(__name__) def collate( - samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, + samples, + pad_idx, + eos_idx, + left_pad_source=True, + left_pad_target=False, input_feeding=True, ): if len(samples) == 0: @@ -52,7 +56,9 @@ def compute_alignment_weights(alignments): id = torch.LongTensor([s['id'] for s in samples]) src_tokens = merge('source', left_pad=left_pad_source) # sort by descending source length - src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) + src_lengths = torch.LongTensor([ + s['source'].ne(pad_idx).long().sum() for s in samples + ]) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) src_tokens = src_tokens.index_select(0, sort_order) @@ -62,8 +68,10 @@ def compute_alignment_weights(alignments): if samples[0].get('target', None) is not None: target = merge('target', left_pad=left_pad_target) target = target.index_select(0, sort_order) - tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order) - ntokens = sum(len(s['target']) for s in samples) + tgt_lengths = torch.LongTensor([ + s['target'].ne(pad_idx).long().sum() for s in samples + ]).index_select(0, sort_order) + ntokens = tgt_lengths.sum().item() if input_feeding: # we create a shifted version of targets for feeding the @@ -75,7 +83,7 @@ def compute_alignment_weights(alignments): ) prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: - ntokens = sum(len(s['source']) for s in samples) + ntokens = src_lengths.sum().item() batch = { 'id': id, @@ -133,10 +141,6 @@ class LanguagePairDataset(FairseqDataset): (default: True). left_pad_target (bool, optional): pad target tensors on the left side (default: False). - max_source_positions (int, optional): max number of tokens in the - source sentence (default: 1024). 
- max_target_positions (int, optional): max number of tokens in the - target sentence (default: 1024). shuffle (bool, optional): shuffle dataset elements before batching (default: True). input_feeding (bool, optional): create a shifted version of the targets @@ -149,17 +153,19 @@ class LanguagePairDataset(FairseqDataset): containing alignments. append_bos (bool, optional): if set, appends bos to the beginning of source/target sentence. + num_buckets (int, optional): if set to a value greater than 0, then + batches will be bucketed into the given number of batch shapes. """ def __init__( self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, - max_source_positions=1024, max_target_positions=1024, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, - append_bos=False, eos=None + append_bos=False, eos=None, + num_buckets=0, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() @@ -175,8 +181,6 @@ def __init__( self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target - self.max_source_positions = max_source_positions - self.max_target_positions = max_target_positions self.shuffle = shuffle self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source @@ -187,6 +191,42 @@ def __init__( self.append_bos = append_bos self.eos = (eos if eos is not None else src_dict.eos()) + if num_buckets > 0: + from fairseq.data import BucketPadLengthDataset + self.src = BucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=self.src_dict.pad(), + left_pad=self.left_pad_source, + ) + self.src_sizes = self.src.sizes + logger.info('bucketing source lengths: {}'.format(list(self.src.buckets))) + if self.tgt is not None: + self.tgt = BucketPadLengthDataset( + self.tgt, + sizes=self.tgt_sizes, + num_buckets=num_buckets, + 
pad_idx=self.tgt_dict.pad(), + left_pad=self.left_pad_target, + ) + self.tgt_sizes = self.tgt.sizes + logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets))) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to BucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) + for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + + def get_batch_shapes(self): + return self.buckets + def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None src_item = self.src[index] @@ -255,8 +295,11 @@ def collater(self, samples): on the left if *left_pad_target* is ``True``. """ return collate( - samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos, - left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, + samples, + pad_idx=self.src_dict.pad(), + eos_idx=self.eos, + left_pad_source=self.left_pad_source, + left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, ) @@ -277,9 +320,19 @@ def ordered_indices(self): indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) - if self.tgt_sizes is not None: - indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] - return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[ + np.argsort(self.tgt_sizes[indices], kind='mergesort') + ] + return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + else: + # sort by bucketed_num_tokens, which is: + # max(padded_src_len, padded_tgt_len) + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind='mergesort') + ] @property def supports_prefetch(self): diff --git 
a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 8fb08b3aaf..854e2437c8 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -89,30 +89,28 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape - `(batch, src_len)` where padding elements are indicated by ``1``. - attn_mask (ByteTensor): binary tensor of shape (T_tgt, T_src), where - T_tgt is the length of query, while T_src is the length of key, - though here both query and key is x here, - attn_mask[t_tgt, t_src] = 1 means when calculating embedding - for t_tgt, t_src is excluded (or masked out), =0 means it is - included in attention + `(batch, seq_len)` where padding elements are indicated by ``1``. + attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, + where `tgt_len` is the length of output and `src_len` is the + length of input, though here both are equal to `seq_len`. + `attn_mask[tgt_i, src_j] = 1` means that when calculating the + embedding for `tgt_i`, we exclude (mask out) `src_j`. This is + useful for strided self-attention. Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ - residual = x - if self.normalize_before: - x = self.self_attn_layer_norm(x) - if attn_mask is not None: - attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) # anything in original attn_mask = 1, becomes -1e8 # anything in original attn_mask = 0, becomes 0 # Note that we cannot use -inf here, because at some edge cases, # the attention weight (before softmax) for some padded element in query # will become -inf, which results in NaN in model parameters - # TODO: to formally solve this problem, we need to change fairseq's - # MultiheadAttention. We will do this later on. 
+ if attn_mask is not None: + attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) x, _ = self.self_attn( query=x, key=x, diff --git a/fairseq/options.py b/fairseq/options.py index 9ccd58941a..1972d4b850 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -327,7 +327,8 @@ def add_dataset_args(parser, train=False, gen=False): group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N', help='maximum number of sentences in a batch') group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N', - help='batch size will be a multiplier of this value') + help='batch size will either be less than this value, ' + 'or a multiple of this value') parser.add_argument('--dataset-impl', metavar='FORMAT', choices=get_available_dataset_impl(), help='output dataset implementation') diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index bd9d75abd6..f58ccea8cc 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -173,9 +173,8 @@ def get_batch_iterator( ) # create mini-batches with given size constraints - batch_sampler = data_utils.batch_by_size( + batch_sampler = dataset.batch_by_size( indices, - dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences, required_batch_size_multiple=required_batch_size_multiple, diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py index bf770bfe15..3f919be6f3 100644 --- a/fairseq/tasks/semisupervised_translation.py +++ b/fairseq/tasks/semisupervised_translation.py @@ -264,8 +264,6 @@ def language_pair_dataset(lang_pair): tgt_dataset, tgt_dataset.sizes, self.dicts[tgt], left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, - max_source_positions=self.args.max_source_positions, - max_target_positions=self.args.max_target_positions, ), self.dicts[src].eos(), src, diff 
--git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 6e6ea5596c..c3237aa968 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -39,7 +39,8 @@ def load_langpair_dataset( combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, - truncate_source=False, append_source_id=False + truncate_source=False, append_source_id=False, + num_buckets=0, ): def split_exists(split, src, tgt, lang, data_path): @@ -124,9 +125,8 @@ def split_exists(split, src, tgt, lang, data_path): tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, - max_source_positions=max_source_positions, - max_target_positions=max_target_positions, - align_dataset=align_dataset, eos=eos + align_dataset=align_dataset, eos=eos, + num_buckets=num_buckets, ) @@ -176,6 +176,10 @@ def add_args(parser): help='amount to upsample primary dataset') parser.add_argument('--truncate-source', action='store_true', default=False, help='truncate source to max-source-positions') + parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N', + help='if >0, then bucket source and target lengths into N ' + 'buckets and pad accordingly; this is useful on TPUs ' + 'to minimize the number of compilations') # options for reporting BLEU during validation parser.add_argument('--eval-bleu', action='store_true', @@ -255,6 +259,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): max_target_positions=self.args.max_target_positions, load_alignments=self.args.load_alignments, truncate_source=self.args.truncate_source, + num_buckets=self.args.num_batch_buckets, ) def build_dataset_for_inference(self, src_tokens, src_lengths): From f0a61a2774aff2efbc1adb0b5daee346a8401605 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 24 Jun 2020 10:03:35 -0700 Subject: [PATCH 043/707] Miscellaneous fixes (#1196) 
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Incorporate several fixes, incl. from OSS contributors:
- fix model argument in sequence generator in semisupervised_translation.py
- fix aggregate logging in semisupervised_translation.py
- Fix EOS token in multilingual_denoising
- Handle missing eos_idx in data_utils.collate_tokens
- Better OOM handling for single-GPU training
- fix prepend_bos argument in translation_from_pretrained_bart.py …
- Fix eos_idx in multilingual_denoising
- Small logging fixes
- Fix fb_hub on PyTorch 1.6
- Better variable names
- Add support for model parallel to interactive.py
- Use `//` operator to fix Integer division warning
- Set default `--clip-norm=0.0`
- Cleanup some binaries in root directory

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1196

Reviewed By: ngoyal2707

Differential Revision: D22162202

Pulled By: myleott

fbshipit-source-id: 835b0c0ad9246827f9d915fdb4e89d7b5be2475d
---
 README.md | 2 +-
 fairseq.gif => docs/fairseq.gif | Bin
 fairseq_logo.png => docs/fairseq_logo.png | Bin
 eval_lm.py | 11 ----
 examples/backtranslation/tokenized_bleu.sh | 4 +-
 examples/bart/README.glue.md | 2 +-
 examples/bart/README.summarization.md | 2 +-
 examples/mbart/README.md | 51 +++++++++++-----
 examples/megatron_11b/README.md | 2 +-
 examples/quant_noise/README.md | 2 +-
 .../roberta/README.custom_classification.md | 2 +-
 examples/roberta/README.glue.md | 2 +-
 .../speech_recognition/criterions/__init__.py | 6 +-
 examples/translation/README.md | 2 +-
 fairseq/data/data_utils.py | 6 +-
 fairseq/data/denoising_dataset.py | 13 ++--
 fairseq/data/resampling_dataset.py | 8 ++-
 fairseq/data/shorten_dataset.py | 13 +++-
 .../model_parallel/models/transformer_lm.py | 57 +++++++++++++++---
 fairseq/models/transformer_lm.py | 1 +
 fairseq/options.py | 3 +-
 fairseq/search.py | 5 +-
 fairseq/sequence_generator.py | 20 +++---
 fairseq/tasks/language_modeling.py | 4 +-
 fairseq/tasks/masked_lm.py | 4 +-
 fairseq/tasks/multilingual_denoising.py | 27 +++++----
 fairseq/tasks/semisupervised_translation.py | 22 ++-----
 fairseq/tasks/sentence_prediction.py | 4 +-
 fairseq/tasks/sentence_ranking.py | 4 +-
 .../tasks/translation_from_pretrained_bart.py | 14 +++--
 fairseq/trainer.py | 4 ++
 fairseq_cli/eval_lm.py | 1 -
 fairseq_cli/generate.py | 1 +
 fairseq_cli/interactive.py | 7 ++-
 generate.py | 11 ----
 interactive.py | 11 ----
 preprocess.py | 11 ----
 score.py | 11 ----
 train.py | 3 +
 validate.py | 11 ----
 40 files changed, 192 insertions(+), 172 deletions(-)
 rename fairseq.gif => docs/fairseq.gif (100%)
 rename fairseq_logo.png => docs/fairseq_logo.png (100%)
 delete mode 100644 eval_lm.py
 delete mode 100644 generate.py
 delete mode 100644 interactive.py
 delete mode 100644 preprocess.py
 delete mode 100644 score.py
 delete mode 100644 validate.py

diff --git a/README.md b/README.md
index a3248d418a..accea254b0 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@

- +

MIT License diff --git a/fairseq.gif b/docs/fairseq.gif similarity index 100% rename from fairseq.gif rename to docs/fairseq.gif diff --git a/fairseq_logo.png b/docs/fairseq_logo.png similarity index 100% rename from fairseq_logo.png rename to docs/fairseq_logo.png diff --git a/eval_lm.py b/eval_lm.py deleted file mode 100644 index b5e965a19f..0000000000 --- a/eval_lm.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from fairseq_cli.eval_lm import cli_main - - -if __name__ == '__main__': - cli_main() diff --git a/examples/backtranslation/tokenized_bleu.sh b/examples/backtranslation/tokenized_bleu.sh index 1589da334a..c6d6aaa193 100644 --- a/examples/backtranslation/tokenized_bleu.sh +++ b/examples/backtranslation/tokenized_bleu.sh @@ -37,10 +37,10 @@ sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \ | sacremoses normalize -l $SRCLANG -q \ | sacremoses tokenize -a -l $SRCLANG -q \ | python $BPEROOT/apply_bpe.py -c $BPECODE \ -| python interactive.py $DATABIN --path $MODEL \ +| fairseq-interactive $DATABIN --path $MODEL \ -s $SRCLANG -t $TGTLANG \ --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \ | grep ^H- | cut -f 3- \ -| python score.py --ref $TMP_REF +| fairseq-score --ref $TMP_REF rm -f $TMP_REF diff --git a/examples/bart/README.glue.md b/examples/bart/README.glue.md index 797fdee31f..2948ff25ea 100644 --- a/examples/bart/README.glue.md +++ b/examples/bart/README.glue.md @@ -24,7 +24,7 @@ NUM_CLASSES=2 MAX_SENTENCES=16 # Batch size. 
BART_PATH=/path/to/bart/model.pt -CUDA_VISIBLE_DEVICES=0,1 python train.py RTE-bin/ \ +CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \ --restore-file $BART_PATH \ --max-sentences $MAX_SENTENCES \ --max-tokens 4400 \ diff --git a/examples/bart/README.summarization.md b/examples/bart/README.summarization.md index 4af7ab8d6a..d7fecc9ce6 100644 --- a/examples/bart/README.summarization.md +++ b/examples/bart/README.summarization.md @@ -52,7 +52,7 @@ MAX_TOKENS=2048 UPDATE_FREQ=4 BART_PATH=/path/to/bart/model.pt -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py cnn_dm-bin \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \ --restore-file $BART_PATH \ --max-tokens $MAX_TOKENS \ --task translation \ diff --git a/examples/mbart/README.md b/examples/mbart/README.md index 13e54c435d..08a2c6bee0 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -45,18 +45,18 @@ ${SPM} --model=${MODEL} < ${DATA}/${TEST}.${TGT} > ${DATA}/${TEST}.spm.${TGT} & ```bash DICT=dict.txt -python preprocess.py \ ---source-lang ${SRC} \ ---target-lang ${TGT} \ ---trainpref ${DATA}/${TRAIN}.spm \ ---validpref ${DATA}/${VALID}.spm \ ---testpref ${DATA}/${TEST}.spm \ ---destdir ${DEST}/${NAME} \ ---thresholdtgt 0 \ ---thresholdsrc 0 \ ---srcdict ${DICT} \ ---tgtdict ${DICT} \ ---workers 70 +fairseq-preprocess \ + --source-lang ${SRC} \ + --target-lang ${TGT} \ + --trainpref ${DATA}/${TRAIN}.spm \ + --validpref ${DATA}/${VALID}.spm \ + --testpref ${DATA}/${TEST}.spm \ + --destdir ${DEST}/${NAME} \ + --thresholdtgt 0 \ + --thresholdsrc 0 \ + --srcdict ${DICT} \ + --tgtdict ${DICT} \ + --workers 70 ``` ## Finetune on EN-RO @@ -66,7 +66,23 @@ Finetune on mbart CC25 PRETRAIN=/path/to/model/mbart.cc25 langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN -python train.py path_2_data --encoder-normalize-before --decoder-normalize-before --arch mbart_large --task 
translation_from_pretrained_bart --source-lang en_XX --target-lang ro_RO --criterion label_smoothed_cross_entropy --label-smoothing 0.2 --dataset-impl mmap --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 --max-tokens 1024 --update-freq 2 --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints --seed 222 --log-format simple --log-interval 2 --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler --restore-file $PRETRAIN --langs $langs --layernorm-embedding --ddp-backend no_c10d +fairseq-train path_2_data \ + --encoder-normalize-before --decoder-normalize-before \ + --arch mbart_large --layernorm-embedding \ + --task translation_from_pretrained_bart \ + --source-lang en_XX --target-lang ro_RO \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --dataset-impl mmap \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 \ + --restore-file $PRETRAIN \ + --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \ + --langs $langs \ + --ddp-backend no_c10d ``` ## Generate on EN-RO Get sacrebleu on finetuned en-ro model @@ -77,7 +93,14 @@ tar -xzvf mbart.cc25.ft.enro.tar.gz ```bash model=model.pt -python generate.py path_2_data --path $model --task translation_from_pretrained_bart --gen-subset test -t ro_RO -s en_XX --bpe 'sentencepiece' --sentencepiece-vocab sentence.bpe.model --sacrebleu --remove-bpe 'sentencepiece' --max-sentences 32 --langs $langs > 
en_ro +fairseq-generate path_2_data \ + --path $model \ + --task translation_from_pretrained_bart \ + --gen-subset test \ + -t ro_RO -s en_XX \ + --bpe 'sentencepiece' --sentencepiece-vocab sentence.bpe.model \ + --sacrebleu --remove-bpe 'sentencepiece'\ + --max-sentences 32 --langs $langs > en_ro cat en_ro | grep -P "^H" |sort -V |cut -f 3- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.hyp cat en_ro | grep -P "^T" |sort -V |cut -f 2- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.ref diff --git a/examples/megatron_11b/README.md b/examples/megatron_11b/README.md index 3cf7aa3acd..d6b6cc0774 100644 --- a/examples/megatron_11b/README.md +++ b/examples/megatron_11b/README.md @@ -74,7 +74,7 @@ Note: Above was tested on `DGX-1` box, with `8xV100-32Gb` GPUs. **[Wikitext103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/)** -Model | vValid perplexity | Test perplexity +Model | Valid perplexity | Test perplexity ---|---|--- `megatron_11b` | 10.64 | 10.54 diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 98d8c313ee..1dc7dc619c 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -219,7 +219,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ To **evaluate** this model, note you need to use the `eval.py` script. The following command can be used to evaluate: ```bash -python eval_lm.py /path/to/wikitext-103/data --path /path/to/model/checkpoint \ +fairseq-eval-lm /path/to/wikitext-103/data --path /path/to/model/checkpoint \ --sample-break-mode complete \ --max-tokens 3072 \ --context-window 2560 \ diff --git a/examples/roberta/README.custom_classification.md b/examples/roberta/README.custom_classification.md index 3b44aac027..72e490ddc7 100644 --- a/examples/roberta/README.custom_classification.md +++ b/examples/roberta/README.custom_classification.md @@ -103,7 +103,7 @@ NUM_CLASSES=2 # Number of classes for the classification task. 
MAX_SENTENCES=8 # Batch size. ROBERTA_PATH=/path/to/roberta.large/model.pt -CUDA_VISIBLE_DEVICES=0 python train.py IMDB-bin/ \ +CUDA_VISIBLE_DEVICES=0 fairseq-train IMDB-bin/ \ --restore-file $ROBERTA_PATH \ --max-positions 512 \ --max-sentences $MAX_SENTENCES \ diff --git a/examples/roberta/README.glue.md b/examples/roberta/README.glue.md index d0a266b868..db20360e2c 100644 --- a/examples/roberta/README.glue.md +++ b/examples/roberta/README.glue.md @@ -24,7 +24,7 @@ NUM_CLASSES=2 MAX_SENTENCES=16 # Batch size. ROBERTA_PATH=/path/to/roberta/model.pt -CUDA_VISIBLE_DEVICES=0 python train.py RTE-bin/ \ +CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \ --restore-file $ROBERTA_PATH \ --max-positions 512 \ --max-sentences $MAX_SENTENCES \ diff --git a/examples/speech_recognition/criterions/__init__.py b/examples/speech_recognition/criterions/__init__.py index e3a348afa4..88af9f340f 100644 --- a/examples/speech_recognition/criterions/__init__.py +++ b/examples/speech_recognition/criterions/__init__.py @@ -3,14 +3,14 @@ # ASG loss requires wav2letter -blacklist = set() +files_to_skip = set() try: import wav2letter except ImportError: - blacklist.add("ASG_loss.py") + files_to_skip.add("ASG_loss.py") for file in os.listdir(os.path.dirname(__file__)): - if file.endswith(".py") and not file.startswith("_") and file not in blacklist: + if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip: criterion_name = file[: file.find(".py")] importlib.import_module( "examples.speech_recognition.criterions." + criterion_name diff --git a/examples/translation/README.md b/examples/translation/README.md index 67e99f6efd..e61d166c6e 100644 --- a/examples/translation/README.md +++ b/examples/translation/README.md @@ -225,7 +225,7 @@ train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets. Note that we use slightly different preprocessing here than for the IWSLT'14 En-De data above. 
In particular we learn a joint BPE code for all three -languages and use interactive.py and sacrebleu for scoring the test set. +languages and use fairseq-interactive and sacrebleu for scoring the test set. ```bash # First install sacrebleu and sentencepiece diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 3b8f1afd2b..b3eee1cb9c 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -38,7 +38,11 @@ def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_be def copy_tensor(src, dst): assert dst.numel() == src.numel() if move_eos_to_beginning: - dst[0] = eos_idx + if eos_idx is None: + # if no eos_idx is specified, then use the last token in src + dst[0] = src[-1] + else: + dst[0] = eos_idx dst[1:] = src[:-1] else: dst.copy_(src) diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py index 8f96600c6d..ee3a03940d 100644 --- a/fairseq/data/denoising_dataset.py +++ b/fairseq/data/denoising_dataset.py @@ -26,7 +26,10 @@ def collate( def merge(key, left_pad, move_eos_to_beginning=False): return data_utils.collate_tokens( [s[key] for s in samples], - pad_idx, eos_idx, left_pad, move_eos_to_beginning, + pad_idx, + eos_idx=None, # use eos_idx of each sample instead of vocab.eos() + left_pad=left_pad, + move_eos_to_beginning=move_eos_to_beginning, ) id = torch.LongTensor([s['id'] for s in samples]) @@ -126,11 +129,11 @@ def __init__( self.replace_length = args.replace_length if not self.replace_length in [-1, 0, 1]: - raise (f'invalid arg: replace_length={self.replace_length}') + raise ValueError(f'invalid arg: replace_length={self.replace_length}') if not args.mask_length in ['subword', 'word', 'span-poisson']: - raise (f'invalid arg: mask-length={args.mask_length}') + raise ValueError(f'invalid arg: mask-length={args.mask_length}') if args.mask_length == 'subword' and not args.replace_length in [0, 1]: - raise (f'if using subwords, use replace-length=1 or 0') + raise 
ValueError(f'if using subwords, use replace-length=1 or 0') self.mask_span_distribution = None if args.mask_length == 'span-poisson': @@ -352,7 +355,7 @@ def collater(self, samples): Returns: dict: a mini-batch of data """ - return collate(samples, self.vocab.pad(), self.vocab.eos(), self.vocab) + return collate(samples, self.vocab.pad(), self.eos, self.vocab) def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to diff --git a/fairseq/data/resampling_dataset.py b/fairseq/data/resampling_dataset.py index 2967916163..a2c9b31d79 100644 --- a/fairseq/data/resampling_dataset.py +++ b/fairseq/data/resampling_dataset.py @@ -3,9 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging + import numpy as np -from . import BaseWrapperDataset, plasma_utils +from fairseq.data import BaseWrapperDataset, plasma_utils + + +logger = logging.getLogger(__name__) class ResamplingDataset(BaseWrapperDataset): @@ -103,6 +108,7 @@ def prefetch(self, indices): self.dataset.prefetch(self._cur_indices.array[indices]) def set_epoch(self, epoch): + logger.debug('ResamplingDataset.set_epoch: {}'.format(epoch)) super().set_epoch(epoch) if epoch == self._cur_epoch: diff --git a/fairseq/data/shorten_dataset.py b/fairseq/data/shorten_dataset.py index f95288a5c0..9c84219dc7 100644 --- a/fairseq/data/shorten_dataset.py +++ b/fairseq/data/shorten_dataset.py @@ -57,9 +57,16 @@ def __getitem__(self, index): item = item[start_idx:start_idx+self.truncation_length] return item -def maybe_shorten_dataset(dataset, split, shorten_data_split_whitelist, shorten_method, tokens_per_sample, seed): - truncate_split = split in shorten_data_split_whitelist.split(',') \ - or len(shorten_data_split_whitelist) == 0 +def maybe_shorten_dataset( + dataset, + split, + shorten_data_split_list, + shorten_method, + tokens_per_sample, + seed, +): + truncate_split = split in 
shorten_data_split_list.split(',') \ + or len(shorten_data_split_list) == 0 if shorten_method == 'truncate' and truncate_split: dataset = TruncateDataset(dataset, tokens_per_sample) elif shorten_method == 'random_crop' and truncate_split: diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 507cb02d1a..9a8a4b0fdd 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -3,25 +3,27 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import torch import torch.nn as nn -from fairseq.models import ( - register_model, - register_model_architecture, -) +from fairseq import utils +from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer_lm import ( base_lm_architecture, TransformerLanguageModel, ) -from fairseq.model_parallel.models.transformer import ( - ModelParallelTransformerDecoder, -) +from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder try: + from fairseq.model_parallel.megatron.mpu import get_model_parallel_group + from fairseq.model_parallel.megatron.mpu import get_model_parallel_rank + from fairseq.model_parallel.megatron.mpu import get_model_parallel_world_size from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding + from fairseq.model_parallel.megatron.mpu.utils import VocabUtility has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False + DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -67,6 +69,47 @@ def _vocab_init(tensor, **kwargs): embed_tokens = VocabParallelEmbedding(len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init) return embed_tokens + def get_normalized_probs( + self, + net_output, + log_probs, + sample, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = 
net_output[0] + vocab_size = len(self.decoder.dictionary) + + if logits.size(-1) == vocab_size: + # we have the full set of logits + return super().get_normalized_probs(net_output, log_probs, sample) + # else: vocab-parallel logits, need to combine them + + assert logits.dim() == 3 + + # Get the partition's vocab indices + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = logits.size(-1) + rank = get_model_parallel_rank() + world_size = get_model_parallel_world_size() + vocab_start_index, vocab_end_index = get_vocab_range( + partition_vocab_size, rank, world_size, + ) + + # Assemble full logits + full_logits = logits.new_zeros(logits.size(0), logits.size(1), vocab_size) + full_logits[:, :, vocab_start_index:vocab_end_index] = logits + torch.distributed.all_reduce( + full_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group(), + ) + + if log_probs: + return utils.log_softmax(full_logits, dim=-1) + else: + return utils.softmax(full_logits, dim=-1) + @register_model_architecture('model_parallel_transformer_lm', 'transformer_lm_megatron') def transformer_lm_megatron(args): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index dfc93d68d3..b59363900e 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -170,6 +170,7 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad()) return embed_tokens + @register_model_architecture('transformer_lm', 'transformer_lm') def base_lm_architecture(args): # backward compatibility for older model checkpoints diff --git a/fairseq/options.py b/fairseq/options.py index 1972d4b850..62aba383d3 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -32,6 +32,7 @@ def get_training_parser(default_task="translation"): def get_generation_parser(interactive=False, default_task="translation"): parser = 
get_parser("Generation", default_task) add_dataset_args(parser, gen=True) + add_distributed_training_args(parser) add_generation_args(parser) if interactive: add_interactive_args(parser) @@ -436,7 +437,7 @@ def add_optimization_args(parser): help='force stop training at specified epoch') group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N', help='force stop training at specified update') - group.add_argument('--clip-norm', default=25, type=float, metavar='NORM', + group.add_argument('--clip-norm', default=0.0, type=float, metavar='NORM', help='clip threshold of gradients') group.add_argument('--sentence-avg', action='store_true', help='normalize gradients by the number of sentences in a batch' diff --git a/fairseq/search.py b/fairseq/search.py index 1ee1d7cb44..32e1450a1d 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -75,10 +75,7 @@ def step(self, step: int, lprobs, scores: Optional[Tensor]): ) scores_buf = top_prediction[0] indices_buf = top_prediction[1] - if torch.__version__ < '1.6.0': - beams_buf = torch.div(indices_buf, vocab_size) - else: - beams_buf = torch.floor_divide(indices_buf, vocab_size) + beams_buf = indices_buf // vocab_size indices_buf = indices_buf.fmod(vocab_size) return scores_buf, indices_buf, beams_buf diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index a523b1ea64..31daddc1e4 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -212,11 +212,11 @@ def _generate( tokens[:, 0] = self.eos if bos_token is None else bos_token attn: Optional[Tensor] = None - # The blacklist indicates candidates that should be ignored. + # A list that indicates candidates that should be ignored. # For example, suppose we're sampling and have already finalized 2/5 - # samples. Then the blacklist would mark 2 positions as being ignored, + # samples. Then cands_to_ignore would mark 2 positions as being ignored, # so that we only finalize the remaining 3 samples. 
- blacklist = ( + cands_to_ignore = ( torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) ) # forward and backward-compatible False mask @@ -317,7 +317,7 @@ def _generate( # finalize hypotheses that end in eos eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) - eos_mask[:, :beam_size][blacklist] = torch.tensor(0).to(eos_mask) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) # only consider eos when it's among the top beam_size indices eos_bbsz_idx = torch.masked_select( @@ -369,7 +369,7 @@ def _generate( if prefix_tokens is not None: prefix_tokens = prefix_tokens[batch_idxs] src_lengths = src_lengths[batch_idxs] - blacklist = blacklist[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) @@ -386,7 +386,7 @@ def _generate( # Rewrite the operator since the element wise or is not supported in torchscript. 
- eos_mask[:, :beam_size] = ~((~blacklist) & (~eos_mask[:, :beam_size])) + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) active_mask = torch.add( eos_mask.type_as(cand_offsets) * cand_size, cand_offsets[: eos_mask.size(1)], @@ -394,13 +394,13 @@ def _generate( # get the top beam_size active hypotheses, which are just the hypos # with the smallest values in active_mask - new_blacklist, active_hypos = torch.topk( + new_cands_to_ignore, active_hypos = torch.topk( active_mask, k=beam_size, dim=1, largest=False ) - # update blacklist to ignore any finalized hypos - blacklist = new_blacklist.ge(cand_size)[:, :beam_size] - assert (~blacklist).any(dim=1).all() + # update cands_to_ignore to ignore any finalized hypos + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + assert (~cands_to_ignore).any(dim=1).all() active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 82f41b2c73..942477efb0 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -91,7 +91,7 @@ def add_args(parser): parser.add_argument('--shorten-method', default='none', choices=['none', 'truncate', 'random_crop'], help='if not none, shorten sequences that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-whitelist', default='', + parser.add_argument('--shorten-data-split-list', default='', help='comma-separated list of dataset splits to apply shortening to, ' 'e.g., "train,valid" (default: all dataset splits)') # fmt: on @@ -176,7 +176,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dataset = maybe_shorten_dataset( dataset, split, - self.args.shorten_data_split_whitelist, + self.args.shorten_data_split_list, self.args.shorten_method, self.args.tokens_per_sample, self.args.seed, diff --git 
a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 7f03e04fba..4d7ea54b64 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -62,7 +62,7 @@ def add_args(parser): parser.add_argument('--shorten-method', default='none', choices=['none', 'truncate', 'random_crop'], help='if not none, shorten sequences that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-whitelist', default='', + parser.add_argument('--shorten-data-split-list', default='', help='comma-separated list of dataset splits to apply shortening to, ' 'e.g., "train,valid" (default: all dataset splits)') @@ -105,7 +105,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dataset = maybe_shorten_dataset( dataset, split, - self.args.shorten_data_split_whitelist, + self.args.shorten_data_split_list, self.args.shorten_method, self.args.tokens_per_sample, self.args.seed, diff --git a/fairseq/tasks/multilingual_denoising.py b/fairseq/tasks/multilingual_denoising.py index f1d7068832..18ee717fff 100644 --- a/fairseq/tasks/multilingual_denoising.py +++ b/fairseq/tasks/multilingual_denoising.py @@ -34,7 +34,7 @@ class MultilingualDenoisingTask(DenoisingTask): def add_args(parser): DenoisingTask.add_args(parser) parser.add_argument('--multilang-sampling-alpha', type=float, default=1.0, - help='smoothing alpha for sample rations across multiple datasets') + help='smoothing alpha for sample ratios across multiple datasets') parser.add_argument('--add-lang-token', default=False, action='store_true') parser.add_argument('--langs', type=str, help="language ids we are considering", default=None) parser.add_argument('--no-whole-word-mask-langs', type=str, default='', metavar='N', @@ -61,7 +61,7 @@ def setup_task(cls, args, **kwargs): for lang in languages: dictionary.add_symbol('[{}]'.format(lang)) - logger.info("| dictionary: {} types".format(len(dictionary))) + logger.info("dictionary: {} types".format(len(dictionary))) if not hasattr(args, 
'shuffle_instance'): args.shuffle_instance = False return cls(args, dictionary) @@ -105,10 +105,11 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): else: languages = self.langs.split(',') for name in languages: - assert os.path.exists(os.path.join(data_path, name)), "all the languages must exist" + p = os.path.join(data_path, name) + assert os.path.exists(p), "data not found: {}".format(p) - logger.info("| Training on {0} languages: {1}".format(len(languages), languages)) - logger.info("| Language to id mapping: ", { + logger.info("Training on {0} languages: {1}".format(len(languages), languages)) + logger.info("Language to id mapping: ", { lang: id for id, lang in enumerate(languages) } ) @@ -140,7 +141,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): eos=end_token, break_mode=self.args.sample_break_mode, ) - logger.info('| loaded {} blocks from: {}'.format(len(dataset), split_path)) + logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path)) # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) @@ -165,23 +166,25 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dtype=float, ) logger.info( - '| loaded total {} blocks for all languages'.format( - dataset_lengths.sum(), + 'loaded total {} blocks for all languages'.format( + int(dataset_lengths.sum()), ) ) if split == self.args.train_subset: # For train subset, additionally up or down sample languages. 
sample_probs = self._get_sample_prob(dataset_lengths) - logger.info("| Sample probability by language: ", { + logger.info( + "Sample probability by language: {}".format({ lang: "{0:.4f}".format(sample_probs[id]) for id, lang in enumerate(languages) - } + }) ) size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths - logger.info("| Up/Down Sampling ratio by language: ", { + logger.info( + "Up/Down Sampling ratio by language: {}".format({ lang: "{0:.2f}".format(size_ratio[id]) for id, lang in enumerate(languages) - } + }) ) resampled_lang_datasets = [ diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py index 3f919be6f3..c81d362886 100644 --- a/fairseq/tasks/semisupervised_translation.py +++ b/fairseq/tasks/semisupervised_translation.py @@ -298,6 +298,7 @@ def build_model(self, args): src, tgt = lang_pair.split('-') key = '{}-{}'.format(tgt, src) self.sequence_generators[key] = SequenceGenerator( + [model.models[key]], tgt_dict=self.dicts[src], beam_size=args.bt_beam_size, max_len_a=args.bt_max_len_a, @@ -340,7 +341,9 @@ def forward_backward(model, samples, logging_output_key, weight): agg_loss += loss.detach().item() # TODO make summing of the sample sizes configurable agg_sample_size += sample_size - agg_logging_output[logging_output_key] = logging_output + for k in logging_output: + agg_logging_output[k] += logging_output[k] + agg_logging_output[logging_output_key] += logging_output[k] if self.lambda_parallel > 0.0: for lang_pair in self.lang_pairs: @@ -380,20 +383,3 @@ def lambda_step_func(config, n_iter): self.lambda_denoising = lambda_step_func(self.lambda_denoising_steps, num_updates) if self.lambda_otf_bt_steps is not None: self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates) - - def aggregate_logging_outputs(self, logging_outputs, criterion): - # aggregate logging outputs for each language pair - logging_output_keys = { - key - for logging_output in logging_outputs - for 
key in logging_output - } - lang_pair_keys = set(self.lang_pairs + [ - _get_bt_dataset_key(lang_pair) - for lang_pair in self.lang_pairs - ] + [ - _get_denoising_dataset_key(lang_pair) - for lang_pair in self.lang_pairs - ]) - logging_output_keys = logging_output_keys.intersection(lang_pair_keys) - return super().aggregate_logging_outputs(logging_outputs, criterion, logging_output_keys) diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 5cdfc97b7a..b50c9922cc 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -57,7 +57,7 @@ def add_args(parser): parser.add_argument('--shorten-method', default='none', choices=['none', 'truncate', 'random_crop'], help='if not none, shorten sequences that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-whitelist', default='', + parser.add_argument('--shorten-data-split-list', default='', help='comma-separated list of dataset splits to apply shortening to, ' 'e.g., "train,valid" (default: all dataset splits)') parser.add_argument('--add-prev-output-tokens', action='store_true', default=False, @@ -149,7 +149,7 @@ def make_dataset(type, dictionary): src_tokens = maybe_shorten_dataset( src_tokens, split, - self.args.shorten_data_split_whitelist, + self.args.shorten_data_split_list, self.args.shorten_method, self.args.max_positions, self.args.seed, diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py index ea2d22c181..ea4b50a294 100644 --- a/fairseq/tasks/sentence_ranking.py +++ b/fairseq/tasks/sentence_ranking.py @@ -54,7 +54,7 @@ def add_args(parser): parser.add_argument('--shorten-method', default='none', choices=['none', 'truncate', 'random_crop'], help='if not none, shorten sequences that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-whitelist', default='', + parser.add_argument('--shorten-data-split-list', default='', help='comma-separated list of dataset splits to 
apply shortening to, ' 'e.g., "train,valid" (default: all dataset splits)') parser.add_argument('--max-option-length', type=int, @@ -128,7 +128,7 @@ def make_dataset(type, dictionary): src_token = maybe_shorten_dataset( src_token, split, - self.args.shorten_data_split_whitelist, + self.args.shorten_data_split_list, self.args.shorten_method, self.args.max_positions, self.args.seed, diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index 00eaf4830b..181ad3c6ff 100644 --- a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -39,10 +39,14 @@ def add_args(parser): # fmt: off TranslationTask.add_args(parser) parser.add_argument('--langs', required=True, metavar='LANG', - help='comma-separated list of monolingual language, for example, "en,de,fr"' - 'be careful these langs are what you used for pretraining (the same order),' - 'not for finetuning.' - 'you should always add all pretraining language idx during finetuning.') + help='comma-separated list of monolingual language, ' + 'for example, "en,de,fr". These should match the ' + 'langs from pretraining (and be in the same order). 
' + 'You should always add all pretraining language idx ' + 'during finetuning.') + parser.add_argument('--prepend-bos', action='store_true', + help='prepend bos token to each sentence, which matches ' + 'mBART pretraining') # fmt: on def __init__(self, args, src_dict, tgt_dict): @@ -75,7 +79,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): max_source_positions=getattr(self.args, 'max_source_positions', 1024), max_target_positions=getattr(self.args, 'max_target_positions', 1024), load_alignments=self.args.load_alignments, - prepend_bos=getattr(self.args, 'preprend_bos', False), + prepend_bos=getattr(self.args, 'prepend_bos', False), append_source_id=True ) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 22edb4451e..d456bceecf 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -439,6 +439,10 @@ def maybe_no_sync(): ) ooms += 1 self.zero_grad() + if self.cuda: + torch.cuda.empty_cache() + if self.args.distributed_world_size == 1: + return None else: raise e diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index be529253a4..e538d7eeec 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -19,7 +19,6 @@ from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer -from fairseq.options import add_distributed_training_args from fairseq import distributed_utils diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 70fed512c8..81cf86b337 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -70,6 +70,7 @@ def _main(args, output_file): utils.split_paths(args.path), arg_overrides=eval(args.model_overrides), task=task, + suffix=getattr(args, "checkpoint_suffix", ""), ) # Optimize ensemble for generation diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index cfcd2c535b..df6120a6cb 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -16,7 +16,7 @@ 
import torch -from fairseq import checkpoint_utils, options, tasks, utils +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders @@ -94,6 +94,7 @@ def main(args): args.path.split(os.pathsep), arg_overrides=eval(args.model_overrides), task=task, + suffix=getattr(args, "checkpoint_suffix", ""), ) # Set dictionaries @@ -208,9 +209,9 @@ def decode_fn(x): def cli_main(): - parser = options.get_generation_parser(interactive=True) + parser = options.get_interactive_generation_parser() args = options.parse_args_and_arch(parser) - main(args) + distributed_utils.call_main(args, main) if __name__ == '__main__': diff --git a/generate.py b/generate.py deleted file mode 100644 index 67109c51b7..0000000000 --- a/generate.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from fairseq_cli.generate import cli_main - - -if __name__ == '__main__': - cli_main() diff --git a/interactive.py b/interactive.py deleted file mode 100644 index 0dc3da3378..0000000000 --- a/interactive.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from fairseq_cli.interactive import cli_main - - -if __name__ == '__main__': - cli_main() diff --git a/preprocess.py b/preprocess.py deleted file mode 100644 index 6a64ce71e7..0000000000 --- a/preprocess.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -from fairseq_cli.preprocess import cli_main - - -if __name__ == '__main__': - cli_main() diff --git a/score.py b/score.py deleted file mode 100644 index 153d037294..0000000000 --- a/score.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from fairseq_cli.score import cli_main - - -if __name__ == '__main__': - cli_main() diff --git a/train.py b/train.py index 8e7d1115fd..3967ef48f3 100644 --- a/train.py +++ b/train.py @@ -3,6 +3,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +""" +Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead. +""" from fairseq_cli.train import cli_main diff --git a/validate.py b/validate.py deleted file mode 100644 index 9c1c66bba5..0000000000 --- a/validate.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -u -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from fairseq_cli.validate import cli_main - - -if __name__ == '__main__': - cli_main() From d2b5265a60c4016d9fc3cbd55ad983dd86f1aa6f Mon Sep 17 00:00:00 2001 From: Ning Dong Date: Wed, 24 Jun 2020 22:29:58 -0700 Subject: [PATCH 044/707] Merge FBSequenceGenerator & SequenceGenerator Summary: See discussion in D20995796 (https://github.com/pytorch/fairseq/commit/4725487bbc3bdee89c45ced0a8664cffd8e1ab01). 
Will merge 2 diffs if this looks good to you myleott jhcross Reviewed By: myleott Differential Revision: D21214974 fbshipit-source-id: ebb59b0491a8c209bed2420a0cd94e9c41d05f2e --- fairseq/search.py | 1 + fairseq/sequence_generator.py | 60 +++++++++++++++++------------------ 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/fairseq/search.py b/fairseq/search.py index 32e1450a1d..9e18581a97 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -88,6 +88,7 @@ def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): self.max_len_a = max_len_a self.max_len_b = max_len_b self.beam = BeamSearch(tgt_dict) + self.needs_src_lengths = True def step(self, step: int, lprobs, scores): min_lens = self.min_len_a * self.src_lengths + self.min_len_b diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 31daddc1e4..aa0d98ff49 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -86,6 +86,10 @@ def __init__( self.search = ( search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. 
+ self.should_set_src_lengths = hasattr(self.search, 'needs_src_lengths') and self.search.needs_src_lengths if not self.retain_dropout: self.model.eval() @@ -109,7 +113,6 @@ def forward( bos_token (int, optional): beginning of sentence token (default: self.eos) """ - self.model.reset_incremental_state() return self._generate(sample, prefix_tokens, bos_token) # TODO(myleott): unused, deprecate after pytorch-translate migration @@ -157,7 +160,6 @@ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): bos_token (int, optional): beginning of sentence token (default: self.eos) """ - self.model.reset_incremental_state() return self._generate(sample, **kwargs) def _generate( @@ -166,6 +168,13 @@ def _generate( prefix_tokens: Optional[Tensor] = None, bos_token: Optional[int] = None, ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) net_input = sample["net_input"] src_tokens = net_input["src_tokens"] # length of the source text being the character length except EndOfSentence and pad @@ -252,13 +261,16 @@ def _generate( reorder_state.view(-1, beam_size).add_( corr.unsqueeze(-1) * beam_size ) - self.model.reorder_incremental_state(reorder_state) + self.model.reorder_incremental_state(incremental_states, reorder_state) encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state ) lprobs, avg_attn_scores = self.model.forward_decoder( - tokens[:, : step + 1], encoder_outs, self.temperature + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, ) lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) @@ -299,7 +311,8 @@ def _generate( scores ) # scores of hypothesis ending with eos (finished sentences) - self.search.set_src_lengths(src_lengths) + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) if self.no_repeat_ngram_size 
> 0: lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step) @@ -653,8 +666,6 @@ def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): class EnsembleModel(nn.Module): """A wrapper around an ensemble of models.""" - incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]] - def __init__(self, models): super().__init__() self.models_size = len(models) @@ -662,13 +673,6 @@ def __init__(self, models): self.single_model = models[0] self.models = nn.ModuleList(models) - self.incremental_states = torch.jit.annotate( - List[Dict[str, Dict[str, Optional[Tensor]]]], - [ - torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) - for i in range(self.models_size) - ], - ) self.has_incremental: bool = False if all( hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) @@ -679,17 +683,6 @@ def __init__(self, models): def forward(self): pass - def reset_incremental_state(self): - if self.has_incremental_states(): - self.incremental_states = torch.jit.annotate( - List[Dict[str, Dict[str, Optional[Tensor]]]], - [ - torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) - for i in range(self.models_size) - ], - ) - return - def has_encoder(self): return hasattr(self.single_model, "encoder") @@ -710,7 +703,11 @@ def forward_encoder(self, net_input: Dict[str, Tensor]): @torch.jit.export def forward_decoder( - self, tokens, encoder_outs: List[EncoderOut], temperature: float = 1.0 + self, + tokens, + encoder_outs: List[EncoderOut], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, ): log_probs = [] avg_attn: Optional[Tensor] = None @@ -723,7 +720,7 @@ def forward_decoder( decoder_out = model.decoder.forward( tokens, encoder_out=encoder_out, - incremental_state=self.incremental_states[i], + incremental_state=incremental_states[i], ) else: decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) @@ -790,12 +787,16 @@ def reorder_encoder_out(self, 
encoder_outs: Optional[List[EncoderOut]], new_orde return new_outs @torch.jit.export - def reorder_incremental_state(self, new_order): + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): if not self.has_incremental_states(): return for i, model in enumerate(self.models): model.decoder.reorder_incremental_state_scripting( - self.incremental_states[i], new_order + incremental_states[i], new_order ) @@ -816,7 +817,6 @@ def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs): @torch.no_grad() def generate(self, models, sample, **kwargs): - self.model.reset_incremental_state() finalized = super()._generate(sample, **kwargs) src_tokens = sample["net_input"]["src_tokens"] From 894ae64858b62927d849c0fbc05e8f55d680a4f1 Mon Sep 17 00:00:00 2001 From: Belinda Li Date: Sat, 27 Jun 2020 16:10:14 -0700 Subject: [PATCH 045/707] Add Linformer to internal fairseq Summary: Adding linformer Reviewed By: myleott Differential Revision: D22253918 fbshipit-source-id: 0bb86dddae1be09450544cb25530400e914c640f --- fairseq/modules/transformer_sentence_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 32ba1cecac..414035f2bc 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -82,7 +82,7 @@ def __init__( dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, - layerdrop : float = 0.0, + layerdrop: float = 0.0, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, From a87cafda718c7706e6f1694f0d39fc589ed2b264 Mon Sep 17 00:00:00 2001 From: Daniel Adkins Date: Tue, 30 Jun 2020 12:51:52 -0700 Subject: [PATCH 046/707] update fairseq binarizer to use PathManager Summary: Currently, fairseq binarizer does not work with Manifold files, making it incompatible 
with some internal procedures. This change preserves the old functionality while allowing Manifold files to be passed into binarizer functions. motivated by theweiho: "I think we should change Binarizer to use PathManager so that it can handle either Manifold path or POSIX path" (D22241626) Reviewed By: akinh Differential Revision: D22293525 fbshipit-source-id: d1bf4f8b50dda6a9214ee2fbe45e112ca9628f60 --- fairseq/binarizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index 3a7ee0854f..ec3b90f211 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -8,7 +8,7 @@ from fairseq.tokenizer import tokenize_line import torch - +from fairseq.file_io import PathManager def safe_readline(f): pos = f.tell() @@ -40,7 +40,7 @@ def replaced_consumer(word, idx): if idx == dict.unk_index and word != dict.unk_word: replaced.update([word]) - with open(filename, "r", encoding="utf-8") as f: + with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: f.seek(offset) # next(f) breaks f.tell(), hence readline() must be used line = safe_readline(f) @@ -79,7 +79,7 @@ def replaced_consumer(word, idx): def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1): nseq = 0 - with open(filename, "r") as f: + with open(PathManager.get_local_path(filename), "r") as f: f.seek(offset) line = safe_readline(f) while line: @@ -93,7 +93,7 @@ def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1): @staticmethod def find_offsets(filename, num_chunks): - with open(filename, "r", encoding="utf-8") as f: + with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: size = os.fstat(f.fileno()).st_size chunk_size = size // num_chunks offsets = [0 for _ in range(num_chunks + 1)] From fc29aab2030055585a087727d93a2083edbb1678 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 6 Jul 2020 08:22:25 -0700 Subject: [PATCH 047/707] Fix model parallel 
training after quantization/interactive.py changes (#1202) Summary: - fix model parallel training after output_projection changes - fix training with non-vocab parallel criterions Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1202 Reviewed By: ngoyal2707 Differential Revision: D22266462 Pulled By: myleott fbshipit-source-id: c7bb9a95c01f5fdaf415a709a93bacb15336271c --- fairseq/model_parallel/models/transformer.py | 10 +++-- .../model_parallel/models/transformer_lm.py | 41 ------------------- 2 files changed, 6 insertions(+), 45 deletions(-) diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py index 0b194ad8c5..f5756ad898 100644 --- a/fairseq/model_parallel/models/transformer.py +++ b/fairseq/model_parallel/models/transformer.py @@ -97,13 +97,15 @@ def build_decoder_layer(self, args, no_encoder_attn=False): def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" + if not self.share_input_output_embed: + raise NotImplementedError( + 'Model parallel training currently requires --share-decoder-input-output-embed' + ) + features = copy_to_model_parallel_region(features) # project back to size of vocabulary - if self.share_input_output_embed: - x = F.linear(features, self.embed_tokens.weight) - else: - x = F.linear(features, self.embed_out) + x = self.output_projection(features) if getattr(self.args, 'criterion') != 'vocab_parallel_cross_entropy': x = gather_from_model_parallel_region(x).contiguous() diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 9a8a4b0fdd..37d3a26336 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -69,47 +69,6 @@ def _vocab_init(tensor, **kwargs): embed_tokens = VocabParallelEmbedding(len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init) return embed_tokens - def get_normalized_probs( - 
self, - net_output, - log_probs, - sample, - ): - """Get normalized probabilities (or log probs) from a net's output.""" - - logits = net_output[0] - vocab_size = len(self.decoder.dictionary) - - if logits.size(-1) == vocab_size: - # we have the full set of logits - return super().get_normalized_probs(net_output, log_probs, sample) - # else: vocab-parallel logits, need to combine them - - assert logits.dim() == 3 - - # Get the partition's vocab indices - get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = logits.size(-1) - rank = get_model_parallel_rank() - world_size = get_model_parallel_world_size() - vocab_start_index, vocab_end_index = get_vocab_range( - partition_vocab_size, rank, world_size, - ) - - # Assemble full logits - full_logits = logits.new_zeros(logits.size(0), logits.size(1), vocab_size) - full_logits[:, :, vocab_start_index:vocab_end_index] = logits - torch.distributed.all_reduce( - full_logits, - op=torch.distributed.ReduceOp.SUM, - group=get_model_parallel_group(), - ) - - if log_probs: - return utils.log_softmax(full_logits, dim=-1) - else: - return utils.softmax(full_logits, dim=-1) - @register_model_architecture('model_parallel_transformer_lm', 'transformer_lm_megatron') def transformer_lm_megatron(args): From 97ca0c022c74232690aae1c67f8d535602d071be Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 7 Jul 2020 10:21:30 -0700 Subject: [PATCH 048/707] Fix data hang with buffered iterator (#1206) Summary: According to Tom Birch: "I think there's an issue with torch.utils.data.dataloader._MultiProcessingDataLoaderIter when next(...) is supposed to raise StopIteration it just blocks indefinitely instead." This PR is a workaround that fixes the issue. 
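The workaround can be sketched in isolation as below. This is a minimal illustration only: `HangsAtEnd` and `read_all` are hypothetical names, not fairseq code, and the `RuntimeError` merely stands in for the reported indefinite block. The idea is to call `next(...)` exactly `len(source)` times so the loop never depends on the iterator signalling its own end.

```python
class HangsAtEnd:
    """Hypothetical source whose iterator misbehaves if asked for one item too many."""

    def __init__(self, items):
        self._items = list(items)

    def __len__(self):
        return len(self._items)

    def __iter__(self):
        for x in self._items:
            yield x
        # Stand-in for the reported behaviour: instead of a clean StopIteration,
        # the extra next(...) call never returns normally.
        raise RuntimeError("would block here instead of raising StopIteration")


def read_all(source):
    """Read exactly len(source) items, never asking the iterator for more."""
    it = iter(source)
    out = []
    for _ in range(len(source)):
        out.append(next(it))
    return out
```

This assumes `len(source)` is accurate; if the source yields fewer items than it reports, `next(...)` would still be called past the end.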
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1206 Reviewed By: froody Differential Revision: D22411150 Pulled By: myleott fbshipit-source-id: 7cdfa67cf55e9cff81cf7d4904f1d38bfa36a0d0 --- fairseq/data/iterators.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 23e4926fb9..a86a3b6038 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -431,7 +431,9 @@ def __init__(self, queue, source): def run(self): try: - for item in self._source: + self._source_iter = iter(self._source) + for _ in range(len(self._source)): + item = next(self._source_iter) self._queue.put(item) # Signal the consumer we are done. @@ -439,6 +441,8 @@ def run(self): except Exception as e: self._queue.put(e) + del self._source_iter + class BufferedIterator(object): def __init__(self, size, iterable): From 578164a0ef642307bc1bf4e63dd29a8f70a176ea Mon Sep 17 00:00:00 2001 From: Siddharth Shah Date: Tue, 7 Jul 2020 16:14:11 -0700 Subject: [PATCH 049/707] 0 warmup in tri stage lr scheduler Summary: Current code fails due to division by zero. This diff allows for zero warmup in tri stage scheduler. 
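The guard described in this summary can be sketched as a standalone function. This is an illustration under assumptions: `warmup_rate` is a hypothetical helper name, not part of fairseq; it only mirrors the expression the patch places in the scheduler's `__init__`.

```python
def warmup_rate(init_lr: float, peak_lr: float, warmup_steps: int) -> float:
    """Per-step learning-rate increase during warmup.

    Hypothetical helper for illustration. Returns 0 when there is no warmup
    phase, which avoids dividing by zero.
    """
    if warmup_steps == 0:
        return 0.0
    return (peak_lr - init_lr) / warmup_steps
```

With `warmup_steps=0` the rate is simply unused, since the schedule moves straight to the hold stage.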
Reviewed By: myleott Differential Revision: D22416482 fbshipit-source-id: dedb41ac141528314dc86cd73b8b67e699bf457b --- fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index b5f99c54c7..3460fa1226 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -64,7 +64,10 @@ def __init__(self, args, optimizer): self.hold_steps = args.hold_steps self.decay_steps = args.decay_steps - self.warmup_rate = (self.peak_lr - self.init_lr) / self.warmup_steps + self.warmup_rate = ( + (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 + else 0 + ) self.decay_factor = -math.log(args.final_lr_scale) / args.decay_steps # initial learning rate From 7816946ff92db0a0d0ec89faedd986e3b83d96d5 Mon Sep 17 00:00:00 2001 From: Gil Keren Date: Tue, 7 Jul 2020 16:41:05 -0700 Subject: [PATCH 050/707] Fix memory leak with small data-buffer-size Summary: As part of zhengwy888's debugging of a memory leak, he suggested that trimming the number of batches in pyspeech's train.py may cause the BufferedIterator to leave some batches in the queue, causing a memory leak. Therefore, propagating `take` to the buffered iterator, which should prevent the consumer thread from hanging on `queue.put`. 
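A minimal sketch of the idea, assuming a simplified producer thread: `BoundedProducer` and `drain` are hypothetical names used only for illustration, and the real `BackgroundConsumer` and `BufferedIterator` differ in detail. The producer stops after `max_len` items, so it never sits blocked on `queue.put` holding batches that nobody will read.

```python
import queue
import threading

_sentinel = object()


class BoundedProducer(threading.Thread):
    """Hypothetical stand-in for a background thread that fills a bounded queue."""

    def __init__(self, q, source, max_len=None):
        super().__init__(daemon=True)
        self._queue = q
        self._source = source
        self._max_len = max_len
        self.count = 0

    def run(self):
        for item in self._source:
            self._queue.put(item)  # blocks while the queue is full
            self.count += 1
            # Stop once the requested number of items has been produced.
            if self._max_len is not None and self.count >= self._max_len:
                break
        # Signal the reader that no more items will arrive.
        self._queue.put(_sentinel)


def drain(source, buffer_size, max_len=None):
    """Read everything the producer emits and return it as a list."""
    q = queue.Queue(buffer_size)
    producer = BoundedProducer(q, source, max_len)
    producer.start()
    items = []
    while True:
        item = q.get()
        if item is _sentinel:
            break
        items.append(item)
    producer.join()
    return items
```

Without the `max_len` check, a reader that stops early would leave the producer blocked with up to `buffer_size` items still queued.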
Reviewed By: myleott Differential Revision: D22405263 fbshipit-source-id: 80f40a355652016af4ba8c386b623cb0552b1928 --- fairseq/data/iterators.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index a86a3b6038..cd53885d7d 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -82,6 +82,10 @@ def take(self, n): """ self.total = min(self.total, n) + # Propagate this change to the underlying iterator + if hasattr(self.iterable, "take"): + self.iterable.take(n) + class EpochBatchIterating(object): def __len__(self) -> int: @@ -423,11 +427,13 @@ def __init__(self, iterable, num_shards, shard_id, fill_value=None): class BackgroundConsumer(Thread): - def __init__(self, queue, source): + def __init__(self, queue, source, max_len): Thread.__init__(self) self._queue = queue self._source = source + self._max_len = max_len + self.count = 0 def run(self): try: @@ -436,6 +442,11 @@ def run(self): item = next(self._source_iter) self._queue.put(item) + # Stop if we reached the maximum length + self.count += 1 + if self._max_len is not None and self.count >= self._max_len: + break + # Signal the consumer we are done. 
self._queue.put(_sentinel) except Exception as e: @@ -448,21 +459,35 @@ class BufferedIterator(object): def __init__(self, size, iterable): self._queue = queue.Queue(size) self._iterable = iterable - - self._consumer = BackgroundConsumer(self._queue, iterable) - self._consumer.daemon = True - self._consumer.start() + self.max_len = None + self._consumer = None self.start_time = time.time() self.warning_time = None + def _create_consumer(self): + self._consumer = BackgroundConsumer( + self._queue, + self._iterable, + self.max_len + ) + self._consumer.daemon = True + self._consumer.start() + def __iter__(self): return self def __len__(self): return len(self._iterable) + def take(self, n): + self.max_len = n + def __next__(self): + # Create consumer if not created yet + if self._consumer is None: + self._create_consumer() + # Notify the user if there is a data loading bottleneck if self._queue.qsize() < max(1, self._queue.maxsize // 2): if time.time() - self.start_time > 5 * 60: From 9f92b05e2a10f1c559d44dc1c264af31723d1d76 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Wed, 8 Jul 2020 00:23:48 -0700 Subject: [PATCH 051/707] TorchElastic for fairseq FBTranslate Summary: Use TorchElastic for multi-node, multi-GPU training Reviewed By: cndn Differential Revision: D22083634 fbshipit-source-id: 3673308671b0bc985b6012ee5327d604d995409f --- fairseq/distributed_utils.py | 3 --- fairseq_cli/train.py | 25 ++++++++++++++++++++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 84fcc683dc..62f70991e7 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -77,9 +77,6 @@ def infer_init_method(args): def distributed_init(args): - if args.distributed_world_size == 1: - raise ValueError('Cannot initialize distributed with distributed_world_size=1') - if not getattr(args, 'tpu', False): if torch.distributed.is_initialized(): warnings.warn('Distributed is already initialized, cannot 
initialize twice!') diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index fd9566b719..02ebeae6c1 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -7,11 +7,13 @@ Train a new model on one or across multiple GPUs. """ +import argparse import logging import math import os import random import sys +from typing import Callable, Optional import numpy as np import torch @@ -38,7 +40,13 @@ logger = logging.getLogger("fairseq_cli.train") -def main(args, init_distributed=False): +def main( + args, + init_distributed=False, + after_distributed_init_fn: Optional[ + Callable[[argparse.Namespace], argparse.Namespace] + ] = None, +): utils.import_user_module(args) assert ( @@ -53,6 +61,8 @@ def main(args, init_distributed=False): utils.set_torch_seed(args.seed) if init_distributed: args.distributed_rank = distributed_utils.distributed_init(args) + if after_distributed_init_fn: + args = after_distributed_init_fn(args) if distributed_utils.is_master(args): checkpoint_utils.verify_checkpoint_directory(args.save_dir) @@ -329,11 +339,20 @@ def get_valid_stats(args, trainer, stats): return stats -def distributed_main(i, args, start_rank=0): +def distributed_main( + i, + args, + start_rank=0, + after_distributed_init_fn: Optional[ + Callable[[argparse.Namespace], argparse.Namespace] + ] = None, +): args.device_id = i if args.distributed_rank is None: # torch.multiprocessing.spawn args.distributed_rank = start_rank + i - main(args, init_distributed=True) + main( + args, init_distributed=True, after_distributed_init_fn=after_distributed_init_fn + ) def cli_main(modify_parser=None): From d73e543e3853bb813d8f7955a06ce19359810707 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 8 Jul 2020 13:04:55 -0700 Subject: [PATCH 052/707] Update LinformerSentenceEncoder to inherit from TransformerSentenceEncoder Summary: It seems we can make this work by setting `compress_layer` in `build_transformer_sentence_encoder_layer` and adding an "init_fn" callback. 
Doing this refactoring now since the stacked diff (D22048889) broke Linformer training, so safer to inherit from TransformerSentenceEncoder directly. Reviewed By: ngoyal2707 Differential Revision: D22411012 fbshipit-source-id: d4ecb71eedd6ddf49abbb1e700d0f2af24e39e5a --- fairseq/modules/transformer_sentence_encoder_layer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py index 2d4747d041..cadcc89981 100644 --- a/fairseq/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -2,7 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional + +from typing import Callable, Optional import torch import torch.nn as nn @@ -34,9 +35,13 @@ def __init__( export: bool = False, q_noise: float = 0.0, qn_block_size: int = 8, + init_fn: Callable = None, ) -> None: - super().__init__() + + if init_fn is not None: + init_fn() + # Initialize parameters self.embedding_dim = embedding_dim self.dropout = dropout From 28876638114948711fd4bd4e350fdd6809013f1e Mon Sep 17 00:00:00 2001 From: m_fomicheva Date: Wed, 8 Jul 2020 13:04:55 -0700 Subject: [PATCH 053/707] Implemented applying dropout at inference time (#2308) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2308 Implemented Monte Carlo dropout. Added README to reproduce the results from our paper that applies this idea for unsupervised quality estimation of NMT (joint work of Facebook AI and the University of Sheffield): Marina Fomicheva, Shuo Sun, Lisa Yankovskaya, Frédéric Blain, Francisco Guzmán, Mark Fishel, Nikolaos Aletras, Vishrav Chaudhary, Lucia Specia. Unsupervised Quality Estimation for Neural Machine Translation. 
Accepted to TACL.

Retaining dropout at test time is not possible in the current code base. The statement
```
if not self.retain_dropout:
    model.eval()
```
in `SequenceGenerator` has no effect, since the model's `training` attribute is already set to False by `make_generation_fast_`, which is applied before `SequenceGenerator` is initialized in `generate.py`. After it has been applied, `make_generation_fast_` raises an exception on any attempt to set `training` back to True. Also, if I am not mistaken, `self.training=True` can have other effects, so setting it to True only to retain dropout at test time might be confusing. I propose an alternative implementation where `retain_dropout` is an attribute of the FairseqModel class.

# Before submitting

- [N] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [Y] Did you make sure to update the docs?
- [Y] Did you write any new necessary tests?

## What does this PR do?
New feature.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding

Pull Request resolved: https://github.com/pytorch/fairseq/pull/2151

Reviewed By: ngoyal2707

Differential Revision: D22048889

Pulled By: myleott

fbshipit-source-id: 0d0d4784a7314fc7a45b76341fd3b8232b3e2cf0
---
 examples/byte_level_bpe/gru_transformer.py | 4 +-
 .../models/transformer_monotonic_attention.py | 2 +-
 .../modules/monotonic_multihead_attention.py | 2 +-
 .../models/w2l_conv_glu_enc.py | 12 +-
 .../unsupervised_quality_estimation/README.md | 126 ++++++++++++++++++
 .../aggregate_scores.py | 40 ++++++
 .../unsupervised_quality_estimation/meteor.py | 97 ++++++++++++++
 .../repeat_lines.py | 28 ++++
 fairseq/hub_utils.py | 8 +-
 .../modules/multihead_attention.py | 11 +-
 fairseq/models/fairseq_model.py | 48 +++++--
 fairseq/models/fconv.py | 28 ++--
 fairseq/models/fconv_self_att.py | 25 ++--
 fairseq/models/lightconv.py | 45 ++++---
 fairseq/models/lstm.py | 39 +++---
 fairseq/models/nat/levenshtein_transformer.py | 2 +-
 .../nat/nonautoregressive_transformer.py | 2 +-
 fairseq/models/transformer.py | 10 +-
 fairseq/modules/__init__.py | 2 +
 fairseq/modules/adaptive_softmax.py | 10 +-
 .../downsampled_multihead_attention.py | 10 +-
 fairseq/modules/dynamic_convolution.py | 13 +-
 .../dynamicconv_layer/dynamicconv_layer.py | 16 ++-
 fairseq/modules/fairseq_dropout.py | 52 ++++++++
 .../lightconv_layer/lightconv_layer.py | 12 +-
 fairseq/modules/lightweight_convolution.py | 15 ++-
 fairseq/modules/multihead_attention.py | 20 +--
 .../sparse_transformer_sentence_encoder.py | 2 +-
 fairseq/modules/transformer_layer.py | 37 ++---
 .../modules/transformer_sentence_encoder.py | 11 +-
 .../transformer_sentence_encoder_layer.py | 13 +-
 fairseq/options.py | 13 +-
 fairseq/sequence_generator.py | 8 +-
 fairseq/tasks/multilingual_translation.py | 13 +-
 fairseq_cli/eval_lm.py | 2 +-
 fairseq_cli/generate.py | 12 +-
 fairseq_cli/interactive.py | 12 +-
 tests/test_binaries.py | 1 +
 tests/test_inference_dropout.py | 59 ++++++++
 39 files changed, 665 insertions(+), 197
deletions(-) create mode 100644 examples/unsupervised_quality_estimation/README.md create mode 100644 examples/unsupervised_quality_estimation/aggregate_scores.py create mode 100644 examples/unsupervised_quality_estimation/meteor.py create mode 100644 examples/unsupervised_quality_estimation/repeat_lines.py create mode 100644 fairseq/modules/fairseq_dropout.py create mode 100644 tests/test_inference_dropout.py diff --git a/examples/byte_level_bpe/gru_transformer.py b/examples/byte_level_bpe/gru_transformer.py index 79ebbda1fc..7ba8e4084f 100644 --- a/examples/byte_level_bpe/gru_transformer.py +++ b/examples/byte_level_bpe/gru_transformer.py @@ -36,13 +36,13 @@ def forward_embedding(self, src_tokens): # contextualize embeddings x = x.transpose(0, 1) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x, _ = self.emb_ctx.forward(x) x = x.transpose(0, 1) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) return x, embed diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index 24c3ba5353..759f195386 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -193,7 +193,7 @@ def pre_attention( if positions is not None: x += positions - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index cfa7fdcd16..d508b8cfba 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ 
b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -533,7 +533,7 @@ def expected_attention(self, alpha, query, key, value, key_padding_mask, increme beta = exp_soft_energy * torch.cumsum(inner_items.flip(dims=[2]), dim=2).flip(dims=[2]) - beta = F.dropout(beta, p=self.dropout, training=self.training) + beta = self.dropout_module(beta) assert not torch.isnan(beta).any(), "NaN detected in beta." diff --git a/examples/speech_recognition/models/w2l_conv_glu_enc.py b/examples/speech_recognition/models/w2l_conv_glu_enc.py index 31cf3401b3..26f27553d4 100644 --- a/examples/speech_recognition/models/w2l_conv_glu_enc.py +++ b/examples/speech_recognition/models/w2l_conv_glu_enc.py @@ -10,15 +10,17 @@ import torch import torch.nn as nn import torch.nn.functional as F + from fairseq.models import ( FairseqEncoder, FairseqEncoderModel, register_model, register_model_architecture, ) +from fairseq.modules.fairseq_dropout import FairseqDropout -default_conv_enc_config = """[ +default_conv_enc_config = """[ (400, 13, 170, 0.2), (440, 14, 0, 0.214), (484, 15, 0, 0.22898), @@ -106,7 +108,9 @@ def __init__( layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding) layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init self.conv_layers.append(nn.utils.weight_norm(layer)) - self.dropouts.append(dropout) + self.dropouts.append( + FairseqDropout(dropout, module_name=self.__class__.__name__) + ) if out_channels % 2 != 0: raise ValueError("odd # of out_channels is incompatible with GLU") cur_channels = out_channels // 2 # halved by GLU @@ -129,12 +133,12 @@ def forward(self, src_tokens, src_lengths, **kwargs): for layer_idx in range(len(self.conv_layers)): x = self.conv_layers[layer_idx](x) x = F.glu(x, dim=1) - x = F.dropout(x, p=self.dropouts[layer_idx], training=self.training) + x = self.dropouts[layer_idx](x) x = x.transpose(1, 2).contiguous() # (B, T, 908) x = self.linear_layers[0](x) x = F.glu(x, dim=2) - x = F.dropout(x, 
p=self.dropouts[-1]) + x = self.dropouts[-1](x) x = self.linear_layers[1](x) assert x.size(0) == B diff --git a/examples/unsupervised_quality_estimation/README.md b/examples/unsupervised_quality_estimation/README.md new file mode 100644 index 0000000000..809a58e41b --- /dev/null +++ b/examples/unsupervised_quality_estimation/README.md @@ -0,0 +1,126 @@ +# Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020) + +This page includes instructions for reproducing results from the paper [Unsupervised Quality Estimation for Neural +Machine Translation (Fomicheva et al., 2020)](https://arxiv.org/abs/2005.10608) + +## Requirements: + +* mosesdecoder: https://github.com/moses-smt/mosesdecoder +* subword-nmt: https://github.com/rsennrich/subword-nmt +* flores: https://github.com/facebookresearch/flores + +## Download Models and Test Data + +Download translation models and test data from [MLQE dataset repository](https://github.com/facebookresearch/mlqe). + +## Set up: + +Given a testset consisting of source sentences and reference translations: + +* `SRC_LANG`: source language +* `TGT_LANG`: target language +* `INPUT`: input prefix, such that the file `$INPUT.$SRC_LANG` contains source sentences and `$INPUT.$TGT_LANG` +contains the reference sentences +* `OUTPUT_DIR`: output path to store results +* `MOSES_DECODER`: path to mosesdecoder installation +* `BPE_ROOT`: path to subword-nmt installation +* `BPE`: path to BPE model +* `MODEL_DIR`: directory containing the NMT model `.pt` file as well as the source and target vocabularies. +* `TMP`: directory for intermediate temporary files +* `GPU`: if translating with GPU, id of the GPU to use for inference +* `DROPOUT_N`: number of stochastic forward passes + +`$DROPOUT_N` is set to 30 in the experiments reported in the paper. However, we observed that increasing it beyond 10 +does not bring substantial improvements. 
+
+## Translate the data using standard decoding
+
+Preprocess the input data:
+```
+for LANG in $SRC_LANG $TGT_LANG; do
+    perl $MOSES_DECODER/scripts/tokenizer/tokenizer.perl -threads 80 -a -l $LANG < $INPUT.$LANG > $TMP/preprocessed.tok.$LANG
+    python $BPE_ROOT/apply_bpe.py -c ${BPE} < $TMP/preprocessed.tok.$LANG > $TMP/preprocessed.tok.bpe.$LANG
+done
+```
+
+Binarize the data for faster translation:
+
+```
+fairseq-preprocess --srcdict $MODEL_DIR/dict.$SRC_LANG.txt --tgtdict $MODEL_DIR/dict.$TGT_LANG.txt \
+    --source-lang ${SRC_LANG} --target-lang ${TGT_LANG} --testpref $TMP/preprocessed.tok.bpe --destdir $TMP/bin --workers 4
+```
+
+Translate:
+
+```
+CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 \
+    --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out
+grep ^H $TMP/fairseq.out | cut -f3- > $TMP/mt.out
+```
+
+Post-process:
+
+```
+sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/mt.out | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl \
+    -l $TGT_LANG > $OUTPUT_DIR/mt.out
+```
+
+## Produce uncertainty estimates
+
+### Scoring
+
+Make temporary files to store the translations repeated N times.
+
+```
+python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/preprocessed.tok.bpe.$SRC_LANG -n $DROPOUT_N \
+    -o $TMP/repeated.$SRC_LANG
+python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/mt.out -n $DROPOUT_N -o $TMP/repeated.$TGT_LANG
+
+fairseq-preprocess --srcdict ${MODEL_DIR}/dict.${SRC_LANG}.txt $TGT_DIC --source-lang ${SRC_LANG} \
+    --target-lang ${TGT_LANG} --testpref ${TMP}/repeated --destdir ${TMP}/bin-repeated
+```
+
+Produce model scores for the generated translations using the `--retain-dropout` option to apply dropout at inference time:
+
+```
+CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 \
+    --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout \
+    --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer \
+    TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out
+
+grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores
+
+```
+
+Use `--retain-dropout-modules` to specify the modules. By default, dropout is applied in the same places
+as for training.
+
+Compute the mean of the resulting output distribution:
+
+```
+python $SCRIPTS/scripts/uncertainty/aggregate_scores.py -i $TMP/dropout.scores -o $OUTPUT_DIR/dropout.scores.mean \
+    -n $DROPOUT_N
+```
+
+### Generation
+
+Produce multiple translation hypotheses for the same source using the `--retain-dropout` option:
+
+```
+CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt \
+    --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --retain-dropout \
+    --unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder \
+    TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out
+
+grep ^H $TMP/dropout.generation.out | cut -f3- > $TMP/dropout.hypotheses_
+
+sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl \
+    -l $TGT_LANG > $TMP/dropout.hypotheses
+```
+
+Compute similarity between multiple hypotheses corresponding to the same source sentence using the Meteor
+evaluation metric:
+```
+python meteor.py -i $TMP/dropout.hypotheses -m -n $DROPOUT_N -o \
+    $OUTPUT_DIR/dropout.gen.sim.meteor
+```
diff --git a/examples/unsupervised_quality_estimation/aggregate_scores.py b/examples/unsupervised_quality_estimation/aggregate_scores.py
new file mode 100644
index 0000000000..35a6baf67d
--- /dev/null
+++ b/examples/unsupervised_quality_estimation/aggregate_scores.py
@@ -0,0 +1,40 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +import argparse +import numpy as np +import sys + + +aggregate_funcs = { + 'std': np.std, + 'var': np.var, + 'median': np.median, + 'mean': np.mean, + 'min': np.min, + 'max': np.max, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input_file', required=True, type=str) + parser.add_argument('-n', '--repeat_times', required=True, type=int) + parser.add_argument('-o', '--output_file', required=False) + parser.add_argument('-f', '--func', required=False, default='mean') + args = parser.parse_args() + + stream = open(args.output_file, 'w') if args.output_file else sys.stdout + + segment_scores = [] + for line in open(args.input_file): + segment_scores.append(float(line.strip())) + if len(segment_scores) == args.repeat_times: + stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores))) + segment_scores = [] + + +if __name__ == '__main__': + main() diff --git a/examples/unsupervised_quality_estimation/meteor.py b/examples/unsupervised_quality_estimation/meteor.py new file mode 100644 index 0000000000..ed4ba4ec34 --- /dev/null +++ b/examples/unsupervised_quality_estimation/meteor.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import os +import sys +import subprocess +import tempfile +import math + +from itertools import combinations +from collections import defaultdict + + +def read_translations(path, n_repeats): + segment_counter = 0 + segment_translations = [] + translations = defaultdict(list) + for line in open(path): + segment_translations.append(' '.join(line.split())) + if len(segment_translations) == n_repeats: + translations[segment_counter] = segment_translations + segment_translations = [] + segment_counter += 1 + return translations + + +def generate_input(translations, n_repeats): + _, ref_path = tempfile.mkstemp() + _, mt_path = tempfile.mkstemp() + ref_fh = open(ref_path, 'w') + mt_fh = open(mt_path, 'w') + for segid in sorted(translations.keys()): + assert len(translations[segid]) == n_repeats + indexes = combinations(range(n_repeats), 2) + for idx1, idx2 in indexes: + mt_fh.write(translations[segid][idx1].strip() + '\n') + ref_fh.write(translations[segid][idx2].strip() + '\n') + sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path)) + return ref_path, mt_path + + +def run_meteor(ref_path, mt_path, metric_path, lang='en'): + _, out_path = tempfile.mkstemp() + subprocess.call([ + 'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path, + '-p', '0.5 0.2 0.6 0.75', # default parameters, only changed alpha to give equal weight to P and R + '-norm', + '-l', lang], stdout=open(out_path, 'w')) + os.remove(ref_path) + os.remove(mt_path) + sys.stderr.write('\nSaved Meteor output to %s' % out_path) + return out_path + + +def read_output(meteor_output_path, n_repeats): + n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2)) + raw_scores = [] + average_scores = [] + for line in open(meteor_output_path): + if not line.startswith('Segment '): + continue + score = float(line.strip().split('\t')[1]) + raw_scores.append(score) + if len(raw_scores) == n_combinations: + 
average_scores.append(sum(raw_scores)/n_combinations)
+            raw_scores = []
+    os.remove(meteor_output_path)
+    return average_scores
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input')
+    parser.add_argument('-n', '--repeat_times', type=int)
+    parser.add_argument('-m', '--meteor')
+    parser.add_argument('-o', '--output')
+    args = parser.parse_args()
+
+    translations = read_translations(args.input, args.repeat_times)
+    sys.stderr.write('\nGenerating input for Meteor...')
+    ref_path, mt_path = generate_input(translations, args.repeat_times)
+    sys.stderr.write('\nRunning Meteor...')
+    out_path = run_meteor(ref_path, mt_path, args.meteor)
+    sys.stderr.write('\nReading output...')
+    scores = read_output(out_path, args.repeat_times)
+    sys.stderr.write('\nWriting results...')
+    with open(args.output, 'w') as o:
+        for scr in scores:
+            o.write('{}\n'.format(scr))
+    o.close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/unsupervised_quality_estimation/repeat_lines.py b/examples/unsupervised_quality_estimation/repeat_lines.py
new file mode 100644
index 0000000000..661ca17c1b
--- /dev/null
+++ b/examples/unsupervised_quality_estimation/repeat_lines.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +import argparse +import sys + + +def _normalize_spaces(line): + return ' '.join(line.split()) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input_file', required=True, type=str) + parser.add_argument('-n', '--repeat_times', required=True, type=int) + parser.add_argument('-o', '--output_file', required=False, type=str) + args = parser.parse_args() + stream = open(args.output_file, 'w') if args.output_file else sys.stdout + + for line in open(args.input_file): + for _ in range(args.repeat_times): + stream.write(_normalize_spaces(line) + '\n') + + +if __name__ == '__main__': + main() diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index c249eb23f5..9a4a28da15 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -96,13 +96,7 @@ def __init__(self, args, task, models): # optimize model for generation for model in self.models: - model.make_generation_fast_( - beamable_mm_beam_size=( - None if getattr(args, 'no_beamable_mm', False) - else getattr(args, 'beam', 5) - ), - need_attn=getattr(args, 'print_alignment', False), - ) + model.prepare_for_inference_(args) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py index 06c7ef712a..e92a3f6a71 100644 --- a/fairseq/model_parallel/modules/multihead_attention.py +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -10,6 +10,7 @@ from fairseq import utils from torch import Tensor, nn from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout try: from fairseq.model_parallel.megatron.mpu import ( @@ -61,7 +62,9 @@ def __init__( self.num_heads_partition * self.model_parallel_size == num_heads ), "Number of heads must be divisble by model parallel size" - self.dropout = dropout + 
self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim @@ -228,11 +231,7 @@ def forward( attn_weights = attn_weights_float.type_as(attn_weights) with get_cuda_rng_tracker().fork(): - attn_probs = F.dropout( - attn_weights_float.type_as(attn_weights), - p=self.dropout, - training=self.training, - ) + attn_probs = self.dropout_module(attn_weights) assert v is not None attn = torch.bmm(attn_probs, v) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 7f9b731ef6..5cf6cba118 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -120,16 +120,33 @@ def do_upgrade(m, prefix): do_upgrade(self, name) def set_num_updates(self, num_updates): - """ State from trainer to pass along to model at every update """ + """State from trainer to pass along to model at every update.""" def _apply(m): if hasattr(m, 'set_num_updates') and m != self: m.set_num_updates(num_updates) self.apply(_apply) + def prepare_for_inference_(self, args): + """Prepare model for inference.""" + kwargs = {} + kwargs['beamable_mm_beam_size'] = ( + None if getattr(args, 'no_beamable_mm', False) + else getattr(args, 'beam', 5) + ) + kwargs['need_attn'] = getattr(args, 'print_alignment', False) + if hasattr(args, 'retain_dropout'): + kwargs['retain_dropout'] = args.retain_dropout + kwargs['retain_dropout_modules'] = getattr( + args, 'retain_dropout_modules', None + ) + self.make_generation_fast_(**kwargs) def make_generation_fast_(self, **kwargs): - """Optimize model for faster generation.""" + """ + Legacy entry point to optimize model for faster generation. + Prefer prepare_for_inference_. 
+ """ if self._is_generation_fast: return # only apply once self._is_generation_fast = True @@ -143,18 +160,23 @@ def apply_remove_weight_norm(module): self.apply(apply_remove_weight_norm) - seen = set() - - def apply_make_generation_fast_(module): - if ( - module != self - and hasattr(module, "make_generation_fast_") - and module not in seen - ): - seen.add(module) - module.make_generation_fast_(**kwargs) + def apply_make_generation_fast_(module, prefix): + if len(prefix) > 0: + prefix += "." - self.apply(apply_make_generation_fast_) + base_func = BaseFairseqModel.make_generation_fast_ + for n, m in module.named_modules(): + if ( + m != self + and hasattr(m, "make_generation_fast_") + # don't call this implementation again, e.g., if + # children modules also inherit from BaseFairseqModel + and m.make_generation_fast_.__func__ is not base_func + ): + name = prefix + n + m.make_generation_fast_(name=name, **kwargs) + + apply_make_generation_fast_(self, "") def train(mode=True): if mode: diff --git a/fairseq/models/fconv.py b/fairseq/models/fconv.py index 38e241feb5..c60a2f4e5f 100644 --- a/fairseq/models/fconv.py +++ b/fairseq/models/fconv.py @@ -18,7 +18,7 @@ register_model_architecture, ) from fairseq.modules import ( - AdaptiveSoftmax, BeamableMM, GradMultiply, LearnedPositionalEmbedding, + AdaptiveSoftmax, BeamableMM, FairseqDropout, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, ) @@ -151,7 +151,9 @@ def __init__( convolutions=((512, 3),) * 20, dropout=0.1, ): super().__init__(dictionary) - self.dropout = dropout + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.num_attention_layers = None num_embeddings = len(dictionary) @@ -214,7 +216,7 @@ def forward(self, src_tokens, src_lengths): """ # embed tokens and positions x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) input_embedding = x # 
project to size of convolution @@ -240,7 +242,7 @@ def forward(self, src_tokens, src_lengths): if encoder_padding_mask is not None: x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) if conv.kernel_size[0] % 2 == 1: # padding is implicit in the conv x = conv(x) @@ -351,11 +353,13 @@ def __init__( self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256, max_positions=1024, convolutions=((512, 3),) * 20, attention=True, dropout=0.1, share_embed=False, positional_embeddings=True, - adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, + adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0., ): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([2])) - self.dropout = dropout + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.need_attn = True convolutions = extend_conv_spec(convolutions) @@ -440,7 +444,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, # embed tokens and combine with positional embeddings x += pos_embed - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) target_embedding = x # project to size of convolution @@ -461,7 +465,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, else: residual = None - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = conv(x, incremental_state) x = F.glu(x, dim=2) @@ -491,7 +495,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, # project back to size of vocabulary if not using adaptive softmax if self.fc2 is not None and self.fc3 is not None: x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = self.fc3(x) return x, avg_attn_scores @@ -581,7 +585,7 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx): 
return m -def Linear(in_features, out_features, dropout=0): +def Linear(in_features, out_features, dropout=0.): """Weight-normalized Linear layer (input: N x T x C)""" m = nn.Linear(in_features, out_features) nn.init.normal_(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features)) @@ -589,7 +593,7 @@ def Linear(in_features, out_features, dropout=0): return nn.utils.weight_norm(m) -def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwargs): """Weight-normalized Conv1d layer optimized for decoding""" m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) @@ -598,7 +602,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs return nn.utils.weight_norm(m, dim=2) -def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs): +def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs): """Weight-normalized Conv1d layer""" from fairseq.modules import ConvTBC m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs) diff --git a/fairseq/models/fconv_self_att.py b/fairseq/models/fconv_self_att.py index 4f0c0a4cdc..c3582da96f 100644 --- a/fairseq/models/fconv_self_att.py +++ b/fairseq/models/fconv_self_att.py @@ -21,6 +21,7 @@ register_model_architecture, ) from fairseq.modules import ( + FairseqDropout, DownsampledMultiHeadAttention, GradMultiply, LayerNorm, @@ -137,7 +138,7 @@ def build_model(cls, args, task): dropout=args.dropout, max_positions=args.max_source_positions, attention=eval(args.encoder_attention), - attention_nheads=args.encoder_attention_nheads + attention_nheads=args.encoder_attention_nheads, ) decoder = FConvDecoder( @@ -155,7 +156,7 @@ def build_model(cls, args, task): gated_attention=eval(args.gated_attention), downsample=eval(args.downsample), pretrained=pretrained, - 
trained_decoder=trained_decoder + trained_decoder=trained_decoder, ) model = FConvModelSelfAtt(encoder, decoder, trained_encoder) @@ -174,7 +175,9 @@ def __init__( attention_nheads=1, ): super().__init__(dictionary) - self.dropout = dropout + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.num_attention_layers = None num_embeddings = len(dictionary) @@ -218,7 +221,7 @@ def expand_bool_array(val): def forward(self, src_tokens, src_lengths): # embed tokens and positions x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) input_embedding = x.transpose(0, 1) # project to size of convolution @@ -238,7 +241,7 @@ def forward(self, src_tokens, src_lengths): if encoder_padding_mask is not None: x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) padding_l = (conv.kernel_size[0] - 1) // 2 padding_r = conv.kernel_size[0] // 2 x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r)) @@ -305,7 +308,9 @@ def __init__( self.register_buffer('version', torch.Tensor([2])) self.pretrained = pretrained self.pretrained_decoder = trained_decoder - self.dropout = dropout + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.need_attn = True in_channels = convolutions[0][0] @@ -410,7 +415,7 @@ def forward(self, prev_output_tokens, encoder_out): # embed tokens and positions x = self.embed_tokens(prev_output_tokens) + positions - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) target_embedding = x.transpose(0, 1) # project to size of convolution @@ -426,7 +431,7 @@ def forward(self, prev_output_tokens, encoder_out): ): residual = x if proj is None else proj(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = conv(x) x = F.glu(x, dim=2) @@ 
-451,7 +456,7 @@ def forward(self, prev_output_tokens, encoder_out): # project back to size of vocabulary x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) if not self.pretrained: x = self.fc3(x) @@ -538,7 +543,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwarg return m -def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs): +def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs): """Weight-normalized Conv1d layer""" from fairseq.modules import ConvTBC m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs) diff --git a/fairseq/models/lightconv.py b/fairseq/models/lightconv.py index 8100e37ec4..05939e1c75 100644 --- a/fairseq/models/lightconv.py +++ b/fairseq/models/lightconv.py @@ -20,6 +20,7 @@ from fairseq.modules import ( AdaptiveSoftmax, DynamicConv, + FairseqDropout, LayerNorm, PositionalEmbedding, LightweightConv, @@ -214,7 +215,7 @@ class LightConvEncoder(FairseqEncoder): def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) - self.dropout = args.dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx @@ -254,7 +255,7 @@ def forward(self, src_tokens, **unused): x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -317,7 +318,7 @@ class LightConvDecoder(FairseqIncrementalDecoder): def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True): super().__init__(dictionary) - self.dropout = args.dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.share_input_output_embed = 
args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim @@ -402,7 +403,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, if positions is not None: x += positions - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -487,9 +488,9 @@ def __init__(self, args, kernel_size=0): raise NotImplementedError self.linear2 = Linear(self.conv_dim, self.embed_dim) - self.dropout = args.dropout - self.relu_dropout = args.relu_dropout - self.input_dropout = args.input_dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__) + self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__) self.normalize_before = args.encoder_normalize_before self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) @@ -507,7 +508,7 @@ def forward(self, x, encoder_padding_mask): """ residual = x x = self.maybe_layer_norm(0, x, before=True) - x = F.dropout(x, p=self.input_dropout, training=self.training) + x = self.input_dropout_module(x) x = self.linear1(x) if self.act is not None: x = self.act(x) @@ -515,16 +516,16 @@ def forward(self, x, encoder_padding_mask): x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(2), 0) x = self.conv(x) x = self.linear2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x x = self.maybe_layer_norm(0, x, after=True) residual = x x = self.maybe_layer_norm(1, x, before=True) x = F.relu(self.fc1(x)) - x = F.dropout(x, p=self.relu_dropout, training=self.training) + x = self.relu_dropout_module(x) x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x x = 
self.maybe_layer_norm(1, x, after=True) return x @@ -538,7 +539,7 @@ def maybe_layer_norm(self, i, x, before=False, after=False): def extra_repr(self): return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format( - self.dropout, self.relu_dropout, self.input_dropout, self.normalize_before) + self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before) class LightConvDecoderLayer(nn.Module): @@ -575,9 +576,9 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0): raise NotImplementedError self.linear2 = Linear(self.conv_dim, self.embed_dim) - self.dropout = args.dropout - self.relu_dropout = args.relu_dropout - self.input_dropout = args.input_dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__) + self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__) self.normalize_before = args.decoder_normalize_before self.conv_layer_norm = LayerNorm(self.embed_dim) @@ -588,7 +589,7 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0): else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, - dropout=args.attention_dropout, encoder_decoder_attention=True + dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) @@ -616,13 +617,13 @@ def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, if incremental_state is None: incremental_state = {} self.conv._set_input_buffer(incremental_state, prev_conv_state) - x = F.dropout(x, p=self.input_dropout, training=self.training) + x = self.input_dropout_module(x) x = self.linear1(x) if self.act is not None: x = self.act(x) x = self.conv(x, incremental_state=incremental_state) x = self.linear2(x) - x = F.dropout(x, p=self.dropout, 
training=self.training) + x = self.dropout_module(x) x = residual + x x = self.maybe_layer_norm(self.conv_layer_norm, x, after=True) @@ -645,16 +646,16 @@ def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = F.relu(self.fc1(x)) - x = F.dropout(x, p=self.relu_dropout, training=self.training) + x = self.relu_dropout_module(x) x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) return x, attn @@ -671,7 +672,7 @@ def make_generation_fast_(self, need_attn=False, **kwargs): def extra_repr(self): return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format( - self.dropout, self.relu_dropout, self.input_dropout, self.normalize_before) + self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before) def Embedding(num_embeddings, embedding_dim, padding_idx): diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 83baf7f065..850428a32d 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -15,7 +15,7 @@ register_model, register_model_architecture, ) -from fairseq.modules import AdaptiveSoftmax +from fairseq.modules import AdaptiveSoftmax, FairseqDropout from torch import Tensor from typing import Dict, List, Optional, Tuple @@ -158,7 +158,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): dropout_out=args.encoder_dropout_out, bidirectional=args.encoder_bidirectional, pretrained_embed=pretrained_encoder_embed, - max_source_positions=max_source_positions + 
max_source_positions=max_source_positions, ) decoder = LSTMDecoder( dictionary=task.target_dictionary, @@ -177,7 +177,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.criterion == 'adaptive_loss' else None ), max_target_positions=max_target_positions, - residuals=False + residuals=False, ) return cls(encoder, decoder) @@ -201,12 +201,12 @@ def __init__( self, dictionary, embed_dim=512, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, left_pad=True, pretrained_embed=None, padding_idx=None, - max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS + max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__(dictionary) self.num_layers = num_layers - self.dropout_in = dropout_in - self.dropout_out = dropout_out + self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) + self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) self.bidirectional = bidirectional self.hidden_size = hidden_size self.max_source_positions = max_source_positions @@ -222,7 +222,7 @@ def __init__( input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers, - dropout=self.dropout_out if num_layers > 1 else 0., + dropout=self.dropout_out_module.p if num_layers > 1 else 0., bidirectional=bidirectional, ) self.left_pad = left_pad @@ -261,7 +261,7 @@ def forward( # embed tokens x = self.embed_tokens(src_tokens) - x = F.dropout(x, p=self.dropout_in, training=self.training) + x = self.dropout_in_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -282,7 +282,7 @@ def forward( # unpack outputs and apply dropout x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_idx*1.0) - x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.dropout_out_module(x) assert list(x.size()) == [seqlen, bsz, self.output_units] if self.bidirectional: @@ -356,11 +356,11 @@ def __init__( encoder_output_units=512, 
pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, - residuals=False + residuals=False, ): super().__init__(dictionary) - self.dropout_in = dropout_in - self.dropout_out = dropout_out + self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) + self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed self.need_attn = True @@ -406,7 +406,7 @@ def __init__( if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined self.adaptive_softmax = AdaptiveSoftmax( - num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out + num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out, ) elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) @@ -452,7 +452,7 @@ def extract_features( # embed tokens x = self.embed_tokens(prev_output_tokens) - x = F.dropout(x, p=self.dropout_in, training=self.training) + x = self.dropout_in_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -491,8 +491,9 @@ def extract_features( hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer - input = F.dropout(hidden, p=self.dropout_out, training=self.training) - if self.residuals: input = input + prev_hiddens[i] + input = self.dropout_out_module(hidden) + if self.residuals: + input = input + prev_hiddens[i] # save state for next time step prev_hiddens[i] = hidden @@ -504,7 +505,7 @@ def extract_features( out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_mask) else: out = hidden - out = F.dropout(out, p=self.dropout_out, training=self.training) + out = self.dropout_out_module(out) # input feeding if input_feed is not None: @@ 
-534,7 +535,7 @@ def extract_features( if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: x = self.additional_fc(x) - x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.dropout_out_module(x) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and self.need_attn and self.attention is not None: assert attn_scores is not None @@ -621,7 +622,7 @@ def LSTMCell(input_size, hidden_size, **kwargs): return m -def Linear(in_features, out_features, bias=True, dropout=0): +def Linear(in_features, out_features, bias=True, dropout=0.): """Linear layer (input: N x T x C)""" m = nn.Linear(in_features, out_features, bias=bias) m.weight.data.uniform_(-0.1, 0.1) diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py index a015d9eb3b..e1748145c3 100644 --- a/fairseq/models/nat/levenshtein_transformer.py +++ b/fairseq/models/nat/levenshtein_transformer.py @@ -333,7 +333,7 @@ def extract_features( if positions is not None: x += positions - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) diff --git a/fairseq/models/nat/nonautoregressive_transformer.py b/fairseq/models/nat/nonautoregressive_transformer.py index 46459732af..050755c308 100644 --- a/fairseq/models/nat/nonautoregressive_transformer.py +++ b/fairseq/models/nat/nonautoregressive_transformer.py @@ -317,7 +317,7 @@ def forward_embedding(self, prev_output_tokens, states=None): if positions is not None: x += positions - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) return x, decoder_padding_mask diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 9171aaf4a2..82eaceb8bc 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn -import 
torch.nn.functional as F from fairseq import options, utils from fairseq.models import ( FairseqEncoder, @@ -20,6 +19,7 @@ from fairseq.models.fairseq_encoder import EncoderOut from fairseq.modules import ( AdaptiveSoftmax, + FairseqDropout, LayerDropModuleList, LayerNorm, PositionalEmbedding, @@ -309,7 +309,7 @@ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - self.dropout = args.dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim @@ -368,7 +368,7 @@ def forward_embedding(self, src_tokens): x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed @@ -531,7 +531,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) - self.dropout = args.dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.decoder_layerdrop = args.decoder_layerdrop self.share_input_output_embed = args.share_decoder_input_output_embed @@ -757,7 +757,7 @@ def extract_features_scriptable( if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 69b9f6962b..94bb86880b 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -12,6 +12,7 @@ from .downsampled_multihead_attention import DownsampledMultiHeadAttention from .dynamic_convolution import DynamicConv, 
DynamicConv1dTBC from .dynamic_crf_layer import DynamicCRF +from .fairseq_dropout import FairseqDropout from .fp32_group_norm import Fp32GroupNorm from .gelu import gelu, gelu_accurate from .grad_multiply import GradMultiply @@ -43,6 +44,7 @@ 'DynamicConv1dTBC', 'DynamicConv', 'DynamicCRF', + 'FairseqDropout', 'Fp32GroupNorm', 'Fp32LayerNorm', 'gelu', diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py index 1789e85d44..96f8b75ad3 100644 --- a/fairseq/modules/adaptive_softmax.py +++ b/fairseq/modules/adaptive_softmax.py @@ -9,6 +9,7 @@ import torch import torch.nn.functional as F from fairseq.modules.quant_noise import quant_noise +from fairseq.modules.fairseq_dropout import FairseqDropout from torch import nn @@ -55,7 +56,8 @@ class AdaptiveSoftmax(nn.Module): approximation for GPUs" (http://arxiv.org/abs/1609.04309). """ - def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_inputs=None, tie_proj=False, q_noise=0, qn_block_size=8): + def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_inputs=None, tie_proj=False, + q_noise=0, qn_block_size=8): super().__init__() if vocab_size > cutoff[-1]: @@ -68,7 +70,7 @@ def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_i self.vocab_size = vocab_size self.cutoff = cutoff - self.dropout = dropout + self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) self.input_dim = input_dim self.factor = factor self.q_noise = q_noise @@ -114,7 +116,7 @@ def _make_tail(self, adaptive_inputs=None, tie_proj=False): m = nn.Sequential( proj, - nn.Dropout(self.dropout), + nn.Dropout(self.dropout_module.p), quant_noise(out_proj, self.q_noise, self.qn_block_size), ) @@ -160,7 +162,7 @@ def forward(self, input, target): """ input = input.contiguous().view(-1, input.size(-1)) - input = F.dropout(input, p=self.dropout, training=self.training) + input = self.dropout_module(input) new_target, target_idxs = 
self.adapt_target(target) output = [self.head(input)] diff --git a/fairseq/modules/downsampled_multihead_attention.py b/fairseq/modules/downsampled_multihead_attention.py index 5c401e4f8e..eeaf9bbdd3 100644 --- a/fairseq/modules/downsampled_multihead_attention.py +++ b/fairseq/modules/downsampled_multihead_attention.py @@ -10,6 +10,7 @@ import torch.nn as nn import torch.nn.functional as F from fairseq.modules.scalar_bias import scalar_bias +from fairseq.modules.fairseq_dropout import FairseqDropout class SingleHeadAttention(nn.Module): @@ -23,7 +24,7 @@ def __init__( ): super().__init__() self.embed_dim = embed_dim - self.dropout = dropout + self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) self.head_index = head_index self.head_dim = head_dim self.project_input = project_input @@ -134,7 +135,7 @@ def forward( ) attn_weights = attn_weights.view(size, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1) - attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) + attn_weights = self.dropout_module(attn_weights) attn = torch.bmm(attn_weights, v) if self.downsample: @@ -157,7 +158,6 @@ def __init__( ): self.embed_dim = embed_dim self.num_heads = num_heads - self.dropout = dropout self.head_dim = embed_dim // num_heads self.downsample = downsample self.gated = gated @@ -170,7 +170,7 @@ def __init__( attention_heads.append( SingleHeadAttention( out_channels, self.embed_dim, self.head_dim, index, - self.dropout, bias, self.project_input, self.gated, + dropout, bias, self.project_input, self.gated, self.downsample, self.num_heads, ) ) @@ -181,7 +181,7 @@ def __init__( # if not being downsampled, we can do the heads with one linear layer instead of separate ones super().__init__() self.attention_module = SingleHeadAttention( - out_channels, self.embed_dim, self.head_dim, 1, self.dropout, + out_channels, self.embed_dim, self.head_dim, 1, dropout, bias, self.project_input, self.gated, self.downsample, 
self.num_heads, ) diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py index 3de931f0de..5a8ecb99a8 100644 --- a/fairseq/modules/dynamic_convolution.py +++ b/fairseq/modules/dynamic_convolution.py @@ -10,6 +10,7 @@ from fairseq import utils from .unfold import unfold1d from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1, @@ -74,7 +75,7 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, self.kernel_size = kernel_size self.padding_l = padding_l self.num_heads = num_heads - self.weight_dropout = weight_dropout + self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) self.weight_softmax = weight_softmax self.renorm_padding = renorm_padding @@ -166,7 +167,7 @@ def _forward_unfolded(self, x, incremental_state, query): if self.weight_softmax and self.renorm_padding: weight = F.softmax(weight, dim=1) - weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) + weight = self.weight_dropout_module(weight, inplace=False) output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 output = output.view(T, B, C) @@ -191,7 +192,7 @@ def _forward_expanded(self, x, incremental_stat, query): if not self.renorm_padding: if self.weight_softmax: weight = F.softmax(weight, dim=1) - weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) + weight = self.weight_dropout_module(weight, inplace=False) weight = weight.narrow(1, 0, K).contiguous() weight = weight.view(T, B*H, K).transpose(0, 1) @@ -203,7 +204,7 @@ def _forward_expanded(self, x, incremental_stat, query): weight_expanded = weight_expanded.narrow(2, self.padding_l, T) # normalize the weight over valid positions like self-attention weight_expanded = F.softmax(weight_expanded, dim=2) - weight_expanded = 
F.dropout(weight_expanded, self.weight_dropout, training=self.training, inplace=False) + weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) else: P = self.padding_l # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length @@ -239,6 +240,6 @@ def extra_repr(self): if self.query_size != self.input_size: s += ', query_size={}'.format(self.query_size) - if self.weight_dropout > 0.: - s += ', weight_dropout={}'.format(self.weight_dropout) + if self.weight_dropout_module.p > 0.: + s += ', weight_dropout={}'.format(self.weight_dropout_module.p) return s diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py index 975840fa73..52cc1e8118 100644 --- a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py @@ -12,6 +12,7 @@ from fairseq import utils from fairseq.modules.unfold import unfold1d from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout class dynamicconvFunction(Function): @@ -47,7 +48,8 @@ def __init__( bias=False, renorm_padding=False, conv_bias=False, - query_size=None): + query_size=None, + ): super(DynamicconvLayer, self).__init__() self.input_size = input_size @@ -56,7 +58,7 @@ def __init__( self.padding_l = padding_l self.num_heads = num_heads self.weight_softmax = weight_softmax - self.weight_dropout = weight_dropout + self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) self.renorm_padding = renorm_padding self.bias = bias @@ -102,8 +104,8 @@ def forward(self, x, incremental_state=None, query=None, unfold=None): weight = self.weight_linear(x).view(T, B, H, K) if self.weight_softmax: weight = F.softmax(weight, dim=-1) - if self.weight_dropout: - weight = F.dropout(weight, self.weight_dropout, training=self.training) + if 
self.weight_dropout_module.p: + weight = self.weight_dropout_module(weight) weight = weight.permute(1, 2, 3, 0).contiguous() self.filters = weight @@ -166,7 +168,7 @@ def _forward_unfolded(self, x, incremental_state, query): if self.weight_softmax and self.renorm_padding: weight = F.softmax(weight, dim=1) - weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) + weight = self.weight_dropout_module(weight, inplace=False) output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 output = output.view(T, B, C) @@ -186,7 +188,7 @@ def _forward_expanded(self, x, incremental_stat, query): if not self.renorm_padding: if self.weight_softmax: weight = F.softmax(weight, dim=1) - weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) + weight = self.weight_dropout_module(weight, inplace=False) weight = weight.narrow(1, 0, K).contiguous() weight = weight.view(T, B*H, K).transpose(0, 1) @@ -198,7 +200,7 @@ def _forward_expanded(self, x, incremental_stat, query): weight_expanded = weight_expanded.narrow(2, self.padding_l, T) # normalize the weight over valid positions like self-attention weight_expanded = F.softmax(weight_expanded, dim=2) - weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training, inplace=False) + weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) else: P = self.padding_l # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py new file mode 100644 index 0000000000..cbfacf477f --- /dev/null +++ b/fairseq/modules/fairseq_dropout.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from typing import List, Optional + +import torch.nn as nn +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +class FairseqDropout(nn.Module): + + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace: bool = False): + if self.training or self.apply_during_inference: + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + def make_generation_fast_( + self, + name: str, + retain_dropout: bool = False, + retain_dropout_modules: Optional[List[str]] = None, + **kwargs + ): + if retain_dropout: + if retain_dropout_modules is not None and self.module_name is None: + logger.warning( + 'Cannot enable dropout during inference for module {} ' + 'because module_name was not set'.format(name) + ) + elif ( + retain_dropout_modules is None # if None, apply to all modules + or self.module_name in retain_dropout_modules + ): + logger.info( + 'Enabling dropout during inference for module: {}'.format(name) + ) + self.apply_during_inference = True + else: + logger.info('Disabling dropout for module: {}'.format(name)) diff --git a/fairseq/modules/lightconv_layer/lightconv_layer.py b/fairseq/modules/lightconv_layer/lightconv_layer.py index 1b046a1611..9b4c9a951e 100644 --- a/fairseq/modules/lightconv_layer/lightconv_layer.py +++ b/fairseq/modules/lightconv_layer/lightconv_layer.py @@ -11,6 +11,7 @@ import lightconv_cuda from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout class lightconvFunction(Function): @@ -43,14 +44,15 @@ def __init__( weight_softmax=False, num_heads=1, weight_dropout=0., - bias=False): + bias=False, + ): super(LightconvLayer, self).__init__() self.input_size = input_size self.kernel_size = kernel_size self.padding_l = padding_l self.num_heads = num_heads 
self.weight_softmax = weight_softmax - self.weight_dropout = weight_dropout + self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size)) if bias: @@ -96,7 +98,7 @@ def forward(self, x, incremental_state=None): weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1) - weight = F.dropout(weight, self.weight_dropout, training=self.training) + weight = self.weight_dropout_module(weight) output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 output = output.view(T, B, C) return output @@ -107,8 +109,8 @@ def forward(self, x, incremental_state=None): weight = self.weight if self.weight_softmax: weight = F.softmax(self.weight, -1) - if self.weight_dropout: - weight = F.dropout(weight, self.weight_dropout, training=self.training) + if self.weight_dropout_module.p: + weight = self.weight_dropout_module(weight) return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1) def reorder_incremental_state(self, incremental_state, new_order): diff --git a/fairseq/modules/lightweight_convolution.py b/fairseq/modules/lightweight_convolution.py index 5c77843ecf..3d4cddb134 100644 --- a/fairseq/modules/lightweight_convolution.py +++ b/fairseq/modules/lightweight_convolution.py @@ -10,6 +10,7 @@ from fairseq import utils from fairseq.modules.unfold import unfold1d from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1, @@ -66,7 +67,7 @@ def __init__(self, input_size, kernel_size=1, padding=0, num_heads=1, self.bias = nn.Parameter(torch.Tensor(input_size)) else: self.bias = None - self.weight_dropout = weight_dropout + self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) self.reset_parameters() def reset_parameters(self): @@ -86,7 +87,7 @@ def forward(self, 
input): if self.weight_softmax: weight = F.softmax(weight, dim=-1) - weight = F.dropout(weight, self.weight_dropout, training=self.training) + weight = self.weight_dropout_module(weight) # Merge every C/H entries into the batch dimension (C = self.input_size) # B x C x T -> (B * C/H) x H x T # One can also expand the weight to C x 1 x K by a factor of C/H @@ -128,7 +129,7 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, self.kernel_size = kernel_size self.padding_l = padding_l self.num_heads = num_heads - self.weight_dropout = weight_dropout + self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) self.weight_softmax = weight_softmax self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) @@ -197,7 +198,7 @@ def _forward_unfolded(self, x, incremental_state): weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1) - weight = F.dropout(weight, self.weight_dropout, training=self.training) + weight = self.weight_dropout_module(weight) output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 output = output.view(T, B, C) return output @@ -227,7 +228,7 @@ def _forward_expanded(self, x, incremental_state): weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) weight_expanded = weight_expanded.narrow(2, P, T) - weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training) + weight_expanded = self.weight_dropout_module(weight_expanded) output = torch.bmm(weight_expanded, x) output = output.transpose(0, 1).contiguous().view(T, B, C) @@ -250,6 +251,6 @@ def extra_repr(self): self.input_size, self.kernel_size, self.padding_l, self.num_heads, self.weight_softmax, self.bias is not None ) - if self.weight_dropout > 0.: - s += ', weight_dropout={}'.format(self.weight_dropout) + if self.weight_dropout_module.p > 0.: + s += ', 
weight_dropout={}'.format(self.weight_dropout_module.p) return s diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 1b87881359..e33dd450ee 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -8,10 +8,12 @@ import torch import torch.nn.functional as F -from fairseq import utils from torch import Tensor, nn from torch.nn import Parameter + +from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.quant_noise import quant_noise @@ -44,7 +46,10 @@ def __init__( self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads - self.dropout = dropout + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim @@ -161,10 +166,10 @@ def forward( self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, + self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, - self.training, + self.training or self.dropout_module.apply_during_inference, key_padding_mask, need_weights, attn_mask, @@ -343,11 +348,8 @@ def forward( attn_weights, dim=-1, onnx_trace=self.onnx_trace ) attn_weights = attn_weights_float.type_as(attn_weights) - attn_probs = F.dropout( - attn_weights, - p=self.dropout, - training=self.training, - ) + attn_probs = self.dropout_module(attn_weights) + assert v is not None attn = torch.bmm(attn_probs, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] diff --git a/fairseq/modules/sparse_transformer_sentence_encoder.py b/fairseq/modules/sparse_transformer_sentence_encoder.py index 9589f2b41e..3d50d5a882 100644 --- a/fairseq/modules/sparse_transformer_sentence_encoder.py +++ b/fairseq/modules/sparse_transformer_sentence_encoder.py @@ -57,7 +57,7 @@ def __init__( 
embedding_dim=self.embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, - dropout=self.dropout, + dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 854e2437c8..037d8e88ae 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -7,10 +7,10 @@ import torch import torch.nn as nn -import torch.nn.functional as F from fairseq import utils from fairseq.modules import LayerNorm, MultiheadAttention from fairseq.modules.quant_noise import quant_noise +from fairseq.modules.fairseq_dropout import FairseqDropout from torch import Tensor @@ -36,14 +36,17 @@ def __init__(self, args): self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn_layer_norm = LayerNorm(self.embed_dim) - self.dropout = args.dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu") ) - self.activation_dropout = getattr(args, "activation_dropout", 0) - if self.activation_dropout == 0: + activation_dropout_p = getattr(args, "activation_dropout", 0) + if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - self.activation_dropout = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) self.normalize_before = args.encoder_normalize_before self.fc1 = self.build_fc1( self.embed_dim, args.encoder_ffn_embed_dim, self.quant_noise, self.quant_noise_block_size @@ -118,7 +121,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: 
Optional[Tensor] = None): key_padding_mask=encoder_padding_mask, attn_mask=attn_mask, ) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x) @@ -128,9 +131,9 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) - x = F.dropout(x, p=float(self.activation_dropout), training=self.training) + x = self.activation_dropout_module(x) x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x) @@ -159,7 +162,7 @@ def __init__( ): super().__init__() self.embed_dim = args.decoder_embed_dim - self.dropout = args.dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.quant_noise = getattr(args, "quant_noise_pq", 0) self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) @@ -174,10 +177,12 @@ def __init__( self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu") ) - self.activation_dropout = getattr(args, "activation_dropout", 0) - if self.activation_dropout == 0: + activation_dropout_p = getattr(args, "activation_dropout", 0) + if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - self.activation_dropout = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. 
@@ -314,7 +319,7 @@ def forward( need_weights=False, attn_mask=self_attn_mask, ) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x) @@ -344,7 +349,7 @@ def forward( need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x if not self.normalize_before: x = self.encoder_attn_layer_norm(x) @@ -354,9 +359,9 @@ def forward( x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) - x = F.dropout(x, p=float(self.activation_dropout), training=self.training) + x = self.activation_dropout_module(x) x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 414035f2bc..8a6994181b 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -7,8 +7,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F from fairseq.modules import ( + FairseqDropout, LayerDropModuleList, LayerNorm, MultiheadAttention, @@ -16,7 +16,6 @@ TransformerSentenceEncoderLayer, ) from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ -import random def init_bert_params(module): @@ -103,7 +102,7 @@ def __init__( super().__init__() self.padding_idx = padding_idx self.vocab_size = vocab_size - self.dropout = dropout + self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) self.layerdrop = layerdrop self.max_seq_len = max_seq_len self.embedding_dim = embedding_dim @@ -154,13 +153,13 @@ def __init__( embedding_dim=self.embedding_dim, ffn_embedding_dim=ffn_embedding_dim, 
num_attention_heads=num_attention_heads, - dropout=self.dropout, + dropout=self.dropout_module.p, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, export=export, q_noise=q_noise, - qn_block_size=qn_block_size + qn_block_size=qn_block_size, ) for _ in range(num_encoder_layers) ]) @@ -250,7 +249,7 @@ def forward( if self.emb_layer_norm is not None: x = self.emb_layer_norm(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) # account for padding while computing the representation if padding_mask is not None: diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py index cadcc89981..383938f68f 100644 --- a/fairseq/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from fairseq import utils from fairseq.modules import ( @@ -15,6 +14,8 @@ MultiheadAttention, ) from fairseq.modules.quant_noise import quant_noise +from fairseq.modules.fairseq_dropout import FairseqDropout + class TransformerSentenceEncoderLayer(nn.Module): @@ -44,8 +45,8 @@ def __init__( # Initialize parameters self.embedding_dim = embedding_dim - self.dropout = dropout - self.activation_dropout = activation_dropout + self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) + self.activation_dropout_module = FairseqDropout(activation_dropout, module_name=self.__class__.__name__) # Initialize blocks self.activation_fn = utils.get_activation_fn(activation_fn) @@ -124,15 +125,15 @@ def forward( need_weights=False, attn_mask=self_attn_mask, ) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x x = self.self_attn_layer_norm(x) residual = x x = self.activation_fn(self.fc1(x)) - x = F.dropout(x, p=self.activation_dropout, 
training=self.training) + x = self.activation_dropout_module(x) x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x x = self.final_layer_norm(x) return x, attn diff --git a/fairseq/options.py b/fairseq/options.py index 62aba383d3..ec1b284307 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -188,6 +188,12 @@ def parse_args_and_arch( if args.tpu and args.fp16: raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs") + if getattr(args, "seed", None) is None: + args.seed = 1 # default seed for training + args.no_seed_provided = True + else: + args.no_seed_provided = False + # Apply architecture configuration. if hasattr(args, "arch"): ARCH_CONFIG_REGISTRY[args.arch](args) @@ -216,7 +222,7 @@ def get_parser(desc, default_task="translation"): parser.add_argument('--tensorboard-logdir', metavar='DIR', default='', help='path to save logs for tensorboard, should match --logdir ' 'of running tensorboard (default: no tensorboard logging)') - parser.add_argument('--seed', default=1, type=int, metavar='N', + parser.add_argument('--seed', default=None, type=int, metavar='N', help='pseudo random number generator seed') parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA') parser.add_argument('--tpu', action='store_true', help='use TPU instead of CUDA') @@ -607,6 +613,11 @@ def add_generation_args(parser): help='if set, the last checkpoint are assumed to be a reranker to rescore the translations'), group.add_argument('--retain-iter-history', action='store_true', help='if set, decoding returns the whole history of iterative refinement') + group.add_argument('--retain-dropout', action='store_true', + help='Use dropout at inference time') + group.add_argument('--retain-dropout-modules', default=None, nargs='+', type=str, + help='if set, only retain dropout for the specified modules; ' + 'if not set, then dropout will be retained for all modules') # special 
decoding format for advanced decoding. group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs']) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index aa0d98ff49..70a1c93a88 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -27,7 +27,6 @@ def __init__( normalize_scores=True, len_penalty=1.0, unk_penalty=0.0, - retain_dropout=False, temperature=1.0, match_source_len=False, no_repeat_ngram_size=0, @@ -50,8 +49,6 @@ def __init__( shorter, >1.0 favors longer sentences (default: 1.0) unk_penalty (float, optional): unknown word penalty, where <0 produces more unks, >0 produces fewer (default: 0.0) - retain_dropout (bool, optional): use dropout when generating - (default: False) temperature (float, optional): temperature, where values >1.0 produce more uniform samples and values <1.0 produce sharper samples (default: 1.0) @@ -77,7 +74,6 @@ def __init__( self.normalize_scores = normalize_scores self.len_penalty = len_penalty self.unk_penalty = unk_penalty - self.retain_dropout = retain_dropout self.temperature = temperature self.match_source_len = match_source_len self.no_repeat_ngram_size = no_repeat_ngram_size @@ -90,8 +86,8 @@ def __init__( # As a module attribute, setting it would break in multithread # settings when the model is shared. 
self.should_set_src_lengths = hasattr(self.search, 'needs_src_lengths') and self.search.needs_src_lengths - if not self.retain_dropout: - self.model.eval() + + self.model.eval() def cuda(self): self.model.cuda() diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 031a9c58fa..1dc51aa9ee 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -314,12 +314,15 @@ def valid_step(self, sample, model, criterion): def inference_step(self, generator, models, sample, prefix_tokens=None): with torch.no_grad(): + if self.args.decoder_langtok: + bos_token = _lang_token_index(self.target_dictionary, self.args.target_lang) + else: + bos_token = self.target_dictionary.eos() return generator.generate( - models, - sample, - prefix_tokens=prefix_tokens, - bos_token=_lang_token_index(self.target_dictionary, self.args.target_lang) - if self.args.decoder_langtok else self.target_dictionary.eos(), + models, + sample, + prefix_tokens=prefix_tokens, + bos_token=bos_token, ) def reduce_metrics(self, logging_outputs, criterion): diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index e538d7eeec..cf5b62c366 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -104,7 +104,7 @@ def main(parsed_args, **unused_kwargs): # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) for model in models: - model.make_generation_fast_() + model.prepare_for_inference_(args) if args.fp16: model.half() if use_cuda: diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 81cf86b337..ab7deaa1b5 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -12,6 +12,8 @@ import os import sys +import numpy as np + import torch from fairseq import bleu, checkpoint_utils, options, tasks, utils @@ -51,6 +53,11 @@ def _main(args, output_file): args.max_tokens = 12000 logger.info(args) + # Fix seed for stochastic decoding + 
if args.seed is not None and not args.no_seed_provided: + np.random.seed(args.seed) + torch.manual_seed(args.seed) + use_cuda = torch.cuda.is_available() and not args.cpu # Load dataset splits @@ -75,10 +82,7 @@ def _main(args, output_file): # Optimize ensemble for generation for model in models: - model.make_generation_fast_( - beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, - need_attn=args.print_alignment, - ) + model.prepare_for_inference_(args) if args.fp16: model.half() if use_cuda: diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index df6120a6cb..40e07d46b2 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -14,6 +14,8 @@ import sys import os +import numpy as np + import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils @@ -83,6 +85,11 @@ def main(args): logger.info(args) + # Fix seed for stochastic decoding + if args.seed is not None and not args.no_seed_provided: + np.random.seed(args.seed) + torch.manual_seed(args.seed) + use_cuda = torch.cuda.is_available() and not args.cpu # Setup task, e.g., translation @@ -103,10 +110,7 @@ def main(args): # Optimize ensemble for generation for model in models: - model.make_generation_fast_( - beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, - need_attn=args.print_alignment, - ) + model.prepare_for_inference_(args) if args.fp16: model.half() if use_cuda: diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 4eca1debf6..f5e53fd000 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -112,6 +112,7 @@ def test_generation(self): '--match-source-len', ]) generate_main(data_dir, ['--prefix-size', '2']) + generate_main(data_dir, ['--retain-dropout']) def test_eval_bleu(self): with contextlib.redirect_stdout(StringIO()): diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py new file mode 100644 index 0000000000..89e05473f5 --- /dev/null +++ 
b/tests/test_inference_dropout.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from tests.test_sequence_generator import get_dummy_task_and_parser +from fairseq.models.transformer import TransformerModel + + +class TestInferenceDropout(unittest.TestCase): + + def setUp(self): + self.task, self.parser = get_dummy_task_and_parser() + TransformerModel.add_args(self.parser) + self.args = self.parser.parse_args([]) + self.args.encoder_layers = 2 + self.args.decoder_layers = 1 + + def test_sets_inference_dropout_to_true(self): + self.args.retain_dropout = True + self.transformer_model = TransformerModel.build_model(self.args, self.task) + self.transformer_model.prepare_for_inference_(self.args) + assert self.transformer_model.encoder.dropout_module.apply_during_inference + assert self.transformer_model.decoder.dropout_module.apply_during_inference + for layer in self.transformer_model.encoder.layers: + assert layer.dropout_module.apply_during_inference + + def test_inference_dropout_false_by_default(self): + self.transformer_model = TransformerModel.build_model(self.args, self.task) + self.transformer_model.prepare_for_inference_(self.args) + assert not self.transformer_model.encoder.dropout_module.apply_during_inference + assert not self.transformer_model.decoder.dropout_module.apply_during_inference + for layer in self.transformer_model.encoder.layers: + assert not layer.dropout_module.apply_during_inference + for layer in self.transformer_model.decoder.layers: + assert not layer.dropout_module.apply_during_inference + + def test_applies_training_mode(self): + self.transformer_model = TransformerModel.build_model(self.args, self.task) + assert self.transformer_model.encoder.dropout_module.training + for layer in self.transformer_model.encoder.layers: + assert layer.dropout_module.training + + 
self.transformer_model.eval() + assert not self.transformer_model.decoder.dropout_module.training + for layer in self.transformer_model.encoder.layers: + assert not layer.dropout_module.training + + def test_retain_modules(self): + self.args.retain_dropout = True + self.args.retain_dropout_modules = ['TransformerEncoder', 'TransformerEncoderLayer'] + self.transformer_model = TransformerModel.build_model(self.args, self.task) + self.transformer_model.prepare_for_inference_(self.args) + assert self.transformer_model.encoder.dropout_module.apply_during_inference + assert not self.transformer_model.decoder.dropout_module.apply_during_inference + for layer in self.transformer_model.decoder.layers: + assert not layer.dropout_module.apply_during_inference From 16e9661bd968cf66b02d7870c038d7219da3a5b9 Mon Sep 17 00:00:00 2001 From: Mandeep Baines Date: Wed, 8 Jul 2020 14:46:02 -0700 Subject: [PATCH 054/707] avoid fp16 unscales and multiply_grads (#1201) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The object of this patch is to avoid fp16 unscale calls which can potentially under/over-flow by 1) scaling grad_norm instead of unscaling grads before calculating grad_norm and 2) using scale argument to step (if supported by optimizer). By letting the optimizer scale we avoid multiply_grads (saving on GPU compute/mem). We also get better precision since the unscale occurs in the kernel resulting in an FP32 unscaled grad instead of an FP16 unscaled grad. A side-effect of this patch is a noticeable WPS win due to a multi-tensor kernel being used for grad_norm and because we avoid multiply_grads. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1201 Test Plan: Verified grad_norm and loss before and after.
Before: epoch 001 | loss 19.506 | ppl 744403 | wps 13966.7 | ups 0.21 | wpb 65536 | bsz 128 | num_updates 50 | lr 6.34875e-06 | gnorm 8.173 | loss_scale 10 | train_wall 250 | wall 259 After: epoch 001 | loss 19.506 | ppl 744363 | wps 14003 | ups 0.21 | wpb 65536 | bsz 128 | num_updates 50 | lr 6.34875e-06 | gnorm 8.173 | loss_scale 10 | train_wall 250 | wall 258 # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Reviewed By: myleott Differential Revision: D22251842 Pulled By: msbaines fbshipit-source-id: e6d82cdd3c95e7770835abe054db4b50e6ad569e --- fairseq/optim/fairseq_optimizer.py | 13 +++++++-- fairseq/optim/fp16_optimizer.py | 30 ++++++++++++++------- fairseq/optim/fused_adam.py | 4 +++ fairseq/utils.py | 43 +++++++++++++++++++++++++++--- 4 files changed, 76 insertions(+), 14 deletions(-) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index cc1daa82b6..3242a92a35 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -90,9 +90,12 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm.""" return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) - def step(self, closure=None): + def step(self, closure=None, scale=1.): """Performs a single optimization step.""" - self.optimizer.step(closure) + if self.supports_step_with_scale: + self.optimizer.step(closure, scale=scale) + else: + self.optimizer.step(closure) def
zero_grad(self): """Clears the gradients of all optimized parameters.""" @@ -106,6 +109,12 @@ def supports_memory_efficient_fp16(self): return self.optimizer.supports_memory_efficient_fp16 return False + @property + def supports_step_with_scale(self): + if hasattr(self.optimizer, 'supports_step_with_scale'): + return self.optimizer.supports_step_with_scale + return False + @property def supports_flat_params(self): """ diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 277b21eb94..8154c72111 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -342,32 +342,37 @@ def backward(self, loss): if self.scaler is not None: loss = loss * self.scaler.loss_scale self._grads_are_scaled = True + self._multiply_factor = 1 loss.backward() - def _unscale_grads(self, multiply_grads=1.): + def _unscale_grads(self): if self._grads_are_scaled: self._grads_are_scaled = False # correct for dynamic loss scaler - self.wrapped_optimizer.multiply_grads(multiply_grads / self.scaler.loss_scale) + self.wrapped_optimizer.multiply_grads(self._multiply_factor / self.scaler.loss_scale) + self._multiply_factor = 1 else: - assert multiply_grads == 1. 
+ assert self._multiply_factor == 1 def multiply_grads(self, c): """Multiplies grads by a constant *c*.""" if self._grads_are_scaled: - self._unscale_grads(c) + self._multiply_factor *= c else: self.wrapped_optimizer.multiply_grads(c) def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm and updates dynamic loss scaler.""" - self._unscale_grads() - grad_norm = self.wrapped_optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) # detect overflow and adjust loss scale if self.scaler is not None: - overflow = DynamicLossScaler.has_overflow(grad_norm) + scale = self._multiply_factor / self.scaler.loss_scale + grad_norm = self.wrapped_optimizer.clip_grad_norm(0, aggregate_norm_fn) * scale + grad_norm_cpu = float(grad_norm) + if grad_norm_cpu > max_norm: + self._multiply_factor *= max_norm / grad_norm_cpu + overflow = DynamicLossScaler.has_overflow(grad_norm_cpu) prev_scale = self.scaler.loss_scale self.scaler.update_scale(overflow) if overflow: @@ -381,13 +386,20 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): 'increasing the batch size.' 
).format(self.min_loss_scale)) raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) + else: + self._unscale_grads() + grad_norm = self.wrapped_optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) return grad_norm def step(self, closure=None): """Performs a single optimization step.""" - self._unscale_grads() - self.wrapped_optimizer.step(closure) + if self.supports_step_with_scale and self._grads_are_scaled: + scale = self._multiply_factor / self.scaler.loss_scale + self.wrapped_optimizer.step(closure, scale=scale) + else: + self._unscale_grads() + self.wrapped_optimizer.step(closure) def zero_grad(self): """Clears the gradients of all optimized parameters.""" diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py index 7d6c0c4a5e..9024451aff 100644 --- a/fairseq/optim/fused_adam.py +++ b/fairseq/optim/fused_adam.py @@ -96,6 +96,10 @@ def supports_memory_efficient_fp16(self): def supports_flat_params(self): return True + @property + def supports_step_with_scale(self): + return True + def step(self, closure=None, grads=None, scale=1., grad_norms=None): """Performs a single optimization step. 
Arguments: diff --git a/fairseq/utils.py b/fairseq/utils.py index aecc68d52d..8a215e94c2 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -23,6 +23,12 @@ from fairseq.modules.multihead_attention import MultiheadAttention from torch import Tensor +try: + from amp_C import multi_tensor_l2norm + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + logger = logging.getLogger(__name__) @@ -250,6 +256,30 @@ def item(tensor): return tensor +def multi_tensor_total_norm(grads, chunk_size=2048*32) -> torch.Tensor: + per_device_grads = {} + norms = [] + for grad in grads: + device = grad.device + cur_device_grads = per_device_grads.get(device) + if cur_device_grads is None: + cur_device_grads = [] + per_device_grads[device] = cur_device_grads + cur_device_grads.append(grad) + for device in per_device_grads.keys(): + cur_device_grads = per_device_grads[device] + if device.type == "cuda": + # TODO(msb) return has_inf + has_inf = torch.zeros((1, 1), dtype=torch.int, device=device) + with torch.cuda.device(device): + norm = multi_tensor_l2norm(chunk_size, has_inf, [cur_device_grads], False) + norms.append(norm[0]) + else: + norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads] + total_norm = torch.norm(torch.stack(norms)) + return total_norm + + def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if isinstance(params, torch.Tensor): params = [params] @@ -264,9 +294,16 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if len(grads) == 1: total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) else: - total_norm = torch.norm( - torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in grads]) - ) + if multi_tensor_l2norm_available: + total_norm = multi_tensor_total_norm(grads) + else: + warnings.warn( + "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " + "you may get better performance by installing NVIDIA's apex library" + 
) + total_norm = torch.norm( + torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in grads]) + ) if aggregate_norm_fn is not None: total_norm = aggregate_norm_fn(total_norm) From 5d88d379cab13c8351b39341306d7a30a2a8acb8 Mon Sep 17 00:00:00 2001 From: Aditya Pillai Date: Thu, 9 Jul 2020 13:20:51 -0700 Subject: [PATCH 055/707] bug fix: use cls.load_dictionary for multilingual translation Summary: Currently, multilingual translation imports Dictionary and calls its load function. However, this does not permit extending the class with a different load_dictionary function to modify its behavior. Reviewed By: myleott, chtran Differential Revision: D22441356 fbshipit-source-id: b0ef159182b15adb479b117581ddcd2f65724980 --- fairseq/tasks/multilingual_translation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 1dc51aa9ee..59634131fc 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -138,7 +138,7 @@ def prepare(cls, args, **kargs): for lang in sorted_langs: paths = utils.split_paths(args.data) assert len(paths) > 0 - dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) + dicts[lang] = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) if len(dicts) > 0: assert dicts[lang].pad() == dicts[sorted_langs[0]].pad() assert dicts[lang].eos() == dicts[sorted_langs[0]].eos() From ffecb4e3496379edf5ecae1483df5b7e0886c264 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 14 Jul 2020 14:11:59 -0700 Subject: [PATCH 056/707] Small fixes (#1215) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1215 Reviewed By: ngoyal2707, msbaines Differential Revision: D22514719 Pulled By: myleott fbshipit-source-id: 5f15ba501fd66af1eb49b5702aff940f06c3d91f --- docs/command_line_tools.rst | 12 +++--- examples/byte_level_bpe/README.md | 6 +-- 
examples/byte_level_bpe/get_bitext.py | 4 +- examples/mbart/README.md | 2 +- examples/scaling_nmt/README.md | 28 +++++++++++-- examples/wav2vec/README.md | 2 +- fairseq/data/base_wrapper_dataset.py | 17 ++++++++ fairseq/data/encoders/hf_bert_bpe.py | 9 ++-- fairseq/data/encoders/sentencepiece_bpe.py | 8 ++-- fairseq/data/fairseq_dataset.py | 4 +- fairseq/data/iterators.py | 2 +- fairseq/data/language_pair_dataset.py | 4 +- fairseq/data/legacy/masked_lm_dictionary.py | 2 +- fairseq/hub_utils.py | 2 +- fairseq/logging/progress_bar.py | 10 ++++- .../model_parallel/models/transformer_lm.py | 6 --- fairseq/models/transformer.py | 9 ++-- fairseq/optim/fp16_optimizer.py | 41 ++++++++----------- fairseq/options.py | 2 +- fairseq/tasks/translation.py | 3 ++ fairseq/utils.py | 13 ++---- fairseq_cli/generate.py | 4 +- fairseq_cli/interactive.py | 2 +- fairseq_cli/train.py | 6 +++ scripts/average_checkpoints.py | 2 +- scripts/{sacrebleu_pregen.sh => sacrebleu.sh} | 13 +++--- 26 files changed, 121 insertions(+), 92 deletions(-) rename scripts/{sacrebleu_pregen.sh => sacrebleu.sh} (55%) diff --git a/docs/command_line_tools.rst b/docs/command_line_tools.rst index 28e011a4ef..c16300ff5c 100644 --- a/docs/command_line_tools.rst +++ b/docs/command_line_tools.rst @@ -17,7 +17,7 @@ Fairseq provides several command-line tools for training and evaluating models: fairseq-preprocess ~~~~~~~~~~~~~~~~~~ -.. automodule:: preprocess +.. automodule:: fairseq_cli.preprocess .. argparse:: :module: fairseq.options @@ -29,7 +29,7 @@ fairseq-preprocess fairseq-train ~~~~~~~~~~~~~ -.. automodule:: train +.. automodule:: fairseq_cli.train .. argparse:: :module: fairseq.options @@ -41,7 +41,7 @@ fairseq-train fairseq-generate ~~~~~~~~~~~~~~~~ -.. automodule:: generate +.. automodule:: fairseq_cli.generate .. argparse:: :module: fairseq.options @@ -53,7 +53,7 @@ fairseq-generate fairseq-interactive ~~~~~~~~~~~~~~~~~~~ -.. automodule:: interactive +.. automodule:: fairseq_cli.interactive .. 
argparse:: :module: fairseq.options @@ -65,7 +65,7 @@ fairseq-interactive fairseq-score ~~~~~~~~~~~~~ -.. automodule:: score +.. automodule:: fairseq_cli.score .. argparse:: :module: fairseq_cli.score @@ -77,7 +77,7 @@ fairseq-score fairseq-eval-lm ~~~~~~~~~~~~~~~ -.. automodule:: eval_lm +.. automodule:: fairseq_cli.eval_lm .. argparse:: :module: fairseq.options diff --git a/examples/byte_level_bpe/README.md b/examples/byte_level_bpe/README.md index fb44a87e50..d8c4cb6747 100644 --- a/examples/byte_level_bpe/README.md +++ b/examples/byte_level_bpe/README.md @@ -38,10 +38,10 @@ fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_le # BPE=--bpe bytes # BPE=--bpe characters BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model -# BPE=--bpe sentencepiece --sentencepiece-vocab data/spm_bpe2048.model +# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe2048.model # BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model -# BPE=--bpe sentencepiece --sentencepiece-vocab data/spm_bpe4096.model -# BPE=--bpe sentencepiece --sentencepiece-vocab data/spm_bpe16384.model +# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model +# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model ``` ```bash diff --git a/examples/byte_level_bpe/get_bitext.py b/examples/byte_level_bpe/get_bitext.py index a114391143..7770ea667b 100644 --- a/examples/byte_level_bpe/get_bitext.py +++ b/examples/byte_level_bpe/get_bitext.py @@ -95,8 +95,8 @@ def _apply_bbpe(model_path: str, in_path: str, out_path: str): def _apply_bpe(model_path: str, in_path: str, out_path: str): - Args = namedtuple('Args', ['sentencepiece_vocab']) - args = Args(sentencepiece_vocab=model_path) + Args = namedtuple('Args', ['sentencepiece_model']) + args = Args(sentencepiece_model=model_path) tokenizer = SentencepieceBPE(args) with open(in_path) as f, open(out_path, 'w') as f_o: for s in f: diff --git a/examples/mbart/README.md 
b/examples/mbart/README.md index 08a2c6bee0..e68ba09c1c 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -98,7 +98,7 @@ fairseq-generate path_2_data \ --task translation_from_pretrained_bart \ --gen-subset test \ -t ro_RO -s en_XX \ - --bpe 'sentencepiece' --sentencepiece-vocab sentence.bpe.model \ + --bpe 'sentencepiece' --sentencepiece-model sentence.bpe.model \ --sacrebleu --remove-bpe 'sentencepiece'\ --max-sentences 32 --langs $langs > en_ro diff --git a/examples/scaling_nmt/README.md b/examples/scaling_nmt/README.md index 0782ef07ba..0cc3360c3b 100644 --- a/examples/scaling_nmt/README.md +++ b/examples/scaling_nmt/README.md @@ -70,16 +70,36 @@ good, but you may need to adjust this depending on how long you've trained: ```bash python scripts/average_checkpoints \ --inputs /path/to/checkpoints \ - --num-epoch-checkpoints 5 \ - --output checkpoint.avg5.pt + --num-epoch-checkpoints 10 \ + --output checkpoint.avg10.pt ``` Next, generate translations using a beam width of 4 and length penalty of 0.6: ```bash fairseq-generate \ data-bin/wmt16_en_de_bpe32k \ - --path checkpoint.avg5.pt \ - --beam 4 --lenpen 0.6 --remove-bpe + --path checkpoint.avg10.pt \ + --beam 4 --lenpen 0.6 --remove-bpe > gen.out +``` + +Finally, we apply the ["compound splitting" script](/scripts/compound_split_bleu.sh) to +add spaces around dashes. For example "Café-Liebhaber" would become three tokens: +"Café - Liebhaber". This typically results in larger BLEU scores, but it is not +appropriate to compare these inflated scores to work which does not include this trick. +This trick was used in the [original AIAYN code](https://github.com/tensorflow/tensor2tensor/blob/fc9335c0203685cbbfe2b30c92db4352d8f60779/tensor2tensor/utils/get_ende_bleu.sh), +so we used it in the Scaling NMT paper as well. That said, it's strongly advised to +report [sacrebleu](https://github.com/mjpost/sacrebleu) scores instead. 
+ +To compute "compound split" tokenized BLEU (not recommended!): +```bash +bash scripts/compound_split_bleu.sh gen.out +# BLEU4 = 29.29, 60.3/35.0/22.8/15.3 (BP=1.000, ratio=1.004, syslen=64763, reflen=64496) +``` + +To compute detokenized BLEU with sacrebleu (preferred): +```bash +bash scripts/sacrebleu.sh wmt14/full en de gen.out +# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.13a+version.1.4.3 = 28.6 59.3/34.3/22.1/14.9 (BP = 1.000 ratio = 1.016 hyp_len = 63666 ref_len = 62688) ``` ## Citation diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 963c133f83..1ea3b4fc2f 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -30,7 +30,7 @@ Given a directory containing wav files to be used for pretraining (we recommend ### Prepare training data manifest: ``` -$ python scripts/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav ``` ### Train a wav2vec model: diff --git a/fairseq/data/base_wrapper_dataset.py b/fairseq/data/base_wrapper_dataset.py index d14c3c76f3..8b5326a635 100644 --- a/fairseq/data/base_wrapper_dataset.py +++ b/fairseq/data/base_wrapper_dataset.py @@ -46,6 +46,23 @@ def supports_prefetch(self): def prefetch(self, indices): self.dataset.prefetch(indices) + def get_batch_shapes(self): + return self.dataset.get_batch_shapes() + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + return self.dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + def set_epoch(self, epoch): super().set_epoch(epoch) if hasattr(self.dataset, 'set_epoch'): diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py index 40c69d53c5..16adc45aee 100644 --- a/fairseq/data/encoders/hf_bert_bpe.py +++ 
b/fairseq/data/encoders/hf_bert_bpe.py @@ -21,12 +21,10 @@ def add_args(parser): def __init__(self, args): try: - from pytorch_transformers import BertTokenizer - from pytorch_transformers.tokenization_utils import clean_up_tokenization + from transformers import BertTokenizer except ImportError: raise ImportError( - 'Please install 1.0.0 version of pytorch_transformers' - 'with: pip install pytorch-transformers' + 'Please install transformers with: pip install transformers' ) if 'bpe_vocab_file' in args: @@ -37,13 +35,12 @@ def __init__(self, args): else: vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased' self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) - self.clean_up_tokenization = clean_up_tokenization def encode(self, x: str) -> str: return ' '.join(self.bert_tokenizer.tokenize(x)) def decode(self, x: str) -> str: - return self.clean_up_tokenization( + return self.bert_tokenizer.clean_up_tokenization( self.bert_tokenizer.convert_tokens_to_string(x.split(' ')) ) diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py index 7ae4fc57d1..e5ff5db389 100644 --- a/fairseq/data/encoders/sentencepiece_bpe.py +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -13,16 +13,16 @@ class SentencepieceBPE(object): @staticmethod def add_args(parser): # fmt: off - parser.add_argument('--sentencepiece-vocab', type=str, - help='path to sentencepiece vocab') + parser.add_argument('--sentencepiece-model', type=str, + help='path to sentencepiece model') # fmt: on def __init__(self, args): - vocab = file_utils.cached_path(args.sentencepiece_vocab) + sentencepiece_model = file_utils.cached_path(args.sentencepiece_model) try: import sentencepiece as spm self.sp = spm.SentencePieceProcessor() - self.sp.Load(vocab) + self.sp.Load(sentencepiece_model) except ImportError: raise ImportError('Please install sentencepiece with: pip install sentencepiece') diff --git a/fairseq/data/fairseq_dataset.py 
b/fairseq/data/fairseq_dataset.py index b03c90ed43..5786d5c851 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -9,9 +9,9 @@ class EpochListening: """Mixin for receiving updates whenever the epoch increments.""" + def set_epoch(self, epoch): - """Will receive the updated epoch number at the beginning of the epoch. - """ + """Will receive the updated epoch number at the beginning of the epoch.""" pass diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index cd53885d7d..95aed8f295 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -489,7 +489,7 @@ def __next__(self): self._create_consumer() # Notify the user if there is a data loading bottleneck - if self._queue.qsize() < max(1, self._queue.maxsize // 2): + if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): if time.time() - self.start_time > 5 * 60: if self.warning_time is None or time.time() - self.warning_time > 15 * 60: logger.info( diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 63c95911fd..d70c0c332d 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -73,7 +73,9 @@ def compute_alignment_weights(alignments): ]).index_select(0, sort_order) ntokens = tgt_lengths.sum().item() - if input_feeding: + if samples[0].get('prev_output_tokens', None) is not None: + prev_output_tokens = merge('prev_output_tokens', left_pad=left_pad_target) + elif input_feeding: # we create a shifted version of targets for feeding the # previous output token(s) into the next decoder step prev_output_tokens = merge( diff --git a/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq/data/legacy/masked_lm_dictionary.py index 68f9b83fbd..bff4bcb5ec 100644 --- a/fairseq/data/legacy/masked_lm_dictionary.py +++ b/fairseq/data/legacy/masked_lm_dictionary.py @@ -42,7 +42,7 @@ def __init__( cls='', sep='' ): - super().__init__(pad=pad, eos=eos, unk=unk) + 
super().__init__(pad=pad, eos=eos, unk=unk, mask=mask) self.cls_word = cls self.sep_word = sep self.cls_index = self.add_symbol(cls) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 9a4a28da15..e040a8c3f3 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -59,7 +59,7 @@ def from_pretrained( for file, arg in { 'code': 'bpe_codes', 'bpecodes': 'bpe_codes', - 'sentencepiece.bpe.model': 'sentencepiece_vocab', + 'sentencepiece.bpe.model': 'sentencepiece_model', }.items(): path = os.path.join(model_path, file) if os.path.exists(path): diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 84c208f323..97e4162ea0 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -275,7 +275,12 @@ class TqdmProgressBar(BaseProgressBar): def __init__(self, iterable, epoch=None, prefix=None): super().__init__(iterable, epoch, prefix) from tqdm import tqdm - self.tqdm = tqdm(iterable, self.prefix, leave=False) + self.tqdm = tqdm( + iterable, + self.prefix, + leave=False, + disable=(logger.getEffectiveLevel() > logging.INFO), + ) def __iter__(self): return iter(self.tqdm) @@ -287,7 +292,8 @@ def log(self, stats, tag=None, step=None): def print(self, stats, tag=None, step=None): """Print end-of-epoch stats.""" postfix = self._str_pipes(self._format_stats(stats)) - self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix)) + with rename_logger(logger, tag): + logger.info('{} | {}'.format(self.prefix, postfix)) try: diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 37d3a26336..81bc93bc0a 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -3,10 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch import torch.nn as nn -from fairseq import utils from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer_lm import ( base_lm_architecture, @@ -14,11 +12,7 @@ ) from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder try: - from fairseq.model_parallel.megatron.mpu import get_model_parallel_group - from fairseq.model_parallel.megatron.mpu import get_model_parallel_rank - from fairseq.model_parallel.megatron.mpu import get_model_parallel_world_size from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding - from fairseq.model_parallel.megatron.mpu.utils import VocabUtility has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 82eaceb8bc..6fd5c2bd05 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -331,6 +331,11 @@ def __init__(self, args, dictionary, embed_tokens): else None ) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), @@ -353,10 +358,6 @@ def __init__(self, args, dictionary, embed_tokens): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None - if getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim) - else: - self.layernorm_embedding = None def build_encoder_layer(self, args): return TransformerEncoderLayer(args) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 8154c72111..abd4b5c0fc 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -286,7 +286,7 @@ def set_lr(self, lr): class _MemoryEfficientFP16OptimizerMixin(object): def __init__(self, 
*args, **kwargs): - # forward __init__ call to the next class in mro(method resolution order) + # forward __init__ call to the next class in MRO (method resolution order) super().__init__(*args, **kwargs) @property @@ -341,37 +341,29 @@ def backward(self, loss): """ if self.scaler is not None: loss = loss * self.scaler.loss_scale - self._grads_are_scaled = True - self._multiply_factor = 1 + self._multiply_factor /= float(self.scaler.loss_scale) loss.backward() def _unscale_grads(self): - if self._grads_are_scaled: - self._grads_are_scaled = False - - # correct for dynamic loss scaler - self.wrapped_optimizer.multiply_grads(self._multiply_factor / self.scaler.loss_scale) - self._multiply_factor = 1 - else: - assert self._multiply_factor == 1 + if self._multiply_factor != 1.: + self.wrapped_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1. def multiply_grads(self, c): """Multiplies grads by a constant *c*.""" - if self._grads_are_scaled: - self._multiply_factor *= c - else: - self.wrapped_optimizer.multiply_grads(c) + self._multiply_factor *= c def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm and updates dynamic loss scaler.""" + max_norm = float(max_norm) + grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(0, aggregate_norm_fn) - # detect overflow and adjust loss scale if self.scaler is not None: - scale = self._multiply_factor / self.scaler.loss_scale - grad_norm = self.wrapped_optimizer.clip_grad_norm(0, aggregate_norm_fn) * scale grad_norm_cpu = float(grad_norm) - if grad_norm_cpu > max_norm: + if grad_norm_cpu > max_norm > 0.: self._multiply_factor *= max_norm / grad_norm_cpu + + # detect overflow and adjust loss scale overflow = DynamicLossScaler.has_overflow(grad_norm_cpu) prev_scale = self.scaler.loss_scale self.scaler.update_scale(overflow) @@ -387,16 +379,15 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): ).format(self.min_loss_scale)) raise 
OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) else: - self._unscale_grads() - grad_norm = self.wrapped_optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef return grad_norm def step(self, closure=None): """Performs a single optimization step.""" - if self.supports_step_with_scale and self._grads_are_scaled: - scale = self._multiply_factor / self.scaler.loss_scale - self.wrapped_optimizer.step(closure, scale=scale) + if self.supports_step_with_scale: + self.wrapped_optimizer.step(closure, scale=self._multiply_factor) else: self._unscale_grads() self.wrapped_optimizer.step(closure) @@ -404,7 +395,7 @@ def step(self, closure=None): def zero_grad(self): """Clears the gradients of all optimized parameters.""" self.wrapped_optimizer.zero_grad() - self._grads_are_scaled = False + self._multiply_factor = 1. class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer): diff --git a/fairseq/options.py b/fairseq/options.py index ec1b284307..77af81c3e4 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -32,7 +32,7 @@ def get_training_parser(default_task="translation"): def get_generation_parser(interactive=False, default_task="translation"): parser = get_parser("Generation", default_task) add_dataset_args(parser, gen=True) - add_distributed_training_args(parser) + add_distributed_training_args(parser, default_world_size=1) add_generation_args(parser) if interactive: add_interactive_args(parser) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index c3237aa968..7077943c1e 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -41,6 +41,7 @@ def load_langpair_dataset( max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, + shuffle=True, ): def split_exists(split, src, tgt, lang, data_path): @@ 
-127,6 +128,7 @@ def split_exists(split, src, tgt, lang, data_path): left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, + shuffle=shuffle, ) @@ -260,6 +262,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): load_alignments=self.args.load_alignments, truncate_source=self.args.truncate_source, num_buckets=self.args.num_batch_buckets, + shuffle=(split != 'test'), ) def build_dataset_for_inference(self, src_tokens, src_lengths): diff --git a/fairseq/utils.py b/fairseq/utils.py index 8a215e94c2..739ba49f9e 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -556,12 +556,6 @@ def get_tpu_device(args): return xm.xla_device() -def logging_multiple_line_messages(msg): - msg_arr = msg.split("\n") - for line in msg_arr: - logger.info(line) - - class CudaEnvironment(object): def __init__(self): cur_device = torch.cuda.current_device() @@ -580,13 +574,12 @@ def pretty_print_cuda_env_list(cuda_env_list): center = "CUDA enviroments for all {} workers".format(num_workers) banner_len = 40 - len(center) // 2 first_line = "*" * banner_len + center + "*" * banner_len - msg_arr = [first_line] + logger.info(first_line) for r, env in enumerate(cuda_env_list): - msg_arr.append( + logger.info( "rank {:3d}: ".format(r) + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor) + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB) + "name = {:40s}".format(env.name) ) - msg_arr.append(first_line) - logging_multiple_line_messages("\n".join(msg_arr)) + logger.info(first_line) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index ab7deaa1b5..ef259b6f2d 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -32,7 +32,7 @@ def main(args): if args.results_path is not None: os.makedirs(args.results_path, exist_ok=True) output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset)) - with open(output_path, 'w', buffering=1) as h: + with open(output_path, 'w', 
buffering=1, encoding='utf-8') as h: return _main(args, h) else: return _main(args, sys.stdout) @@ -56,7 +56,7 @@ def _main(args, output_file): # Fix seed for stochastic decoding if args.seed is not None and not args.no_seed_provided: np.random.seed(args.seed) - torch.manual_seed(args.seed) + utils.set_torch_seed(args.seed) use_cuda = torch.cuda.is_available() and not args.cpu diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 40e07d46b2..24e9630d44 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -88,7 +88,7 @@ def main(args): # Fix seed for stochastic decoding if args.seed is not None and not args.no_seed_provided: np.random.seed(args.seed) - torch.manual_seed(args.seed) + utils.set_torch_seed(args.seed) use_cuda = torch.cuda.is_available() and not args.cpu diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 02ebeae6c1..2def237e49 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -193,6 +193,8 @@ def tpu_data_loader(args, itr): @metrics.aggregate("train") def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" + logger.info("begin training epoch {}".format(epoch_itr.epoch)) + # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, @@ -245,6 +247,7 @@ def train(args, trainer, task, epoch_itr): break # log end-of-epoch stats + logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch)) stats = get_training_stats(metrics.get_smoothed_values("train")) progress.print(stats, tag="train", step=num_updates) @@ -279,6 +282,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc # Save checkpoint if do_save or should_stop: + logger.info("begin save checkpoint") checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) return valid_losses, should_stop @@ -298,6 +302,8 @@ def validate(args, trainer, task, epoch_itr, subsets): 
valid_losses = [] for subset in subsets: + logger.info("begin validation on \"{}\" subset".format(subset)) + # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) if getattr(args, "tpu", False): diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py index edda69fb8f..9d69671e7e 100644 --- a/scripts/average_checkpoints.py +++ b/scripts/average_checkpoints.py @@ -143,7 +143,7 @@ def main(): new_state = average_checkpoints(args.inputs) with PathManager.open(args.output, 'wb') as f: torch.save(new_state, f) - print('Finished writing averaged checkpoint to {}.'.format(args.output)) + print('Finished writing averaged checkpoint to {}'.format(args.output)) if __name__ == '__main__': diff --git a/scripts/sacrebleu_pregen.sh b/scripts/sacrebleu.sh similarity index 55% rename from scripts/sacrebleu_pregen.sh rename to scripts/sacrebleu.sh index 6fd3dd3c04..c10bf2b76e 100644 --- a/scripts/sacrebleu_pregen.sh +++ b/scripts/sacrebleu.sh @@ -11,18 +11,17 @@ TGTLANG=$3 GEN=$4 -echo 'Cloning Moses github repository (for tokenization scripts)...' -git clone https://github.com/moses-smt/mosesdecoder.git - -SCRIPTS=mosesdecoder/scripts -DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl +if ! 
command -v sacremoses &> /dev/null +then + echo "sacremoses could not be found, please install with: pip install sacremoses" + exit +fi grep ^H $GEN \ | sed 's/^H\-//' \ | sort -n -k 1 \ | cut -f 3 \ -| perl $DETOKENIZER -l $TGTLANG \ -| sed "s/ - /-/g" \ +| sacremoses detokenize \ > $GEN.sorted.detok sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok From a541b19d853cf4a5209d3b8f77d5d1261554a1d9 Mon Sep 17 00:00:00 2001 From: Mandeep Baines Date: Wed, 15 Jul 2020 16:07:02 -0700 Subject: [PATCH 057/707] Add dummy task for translation benchmarking (#1212) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1212 Test Plan: python train.py \ -a transformer \ --clip-norm 0.4 --optimizer adam --lr 0.001 \ --dropout 0.0 \ --decoder-layers 7 \ --encoder-layers 7 \ --encoder-ffn-embed-dim 2048 \ --decoder-ffn-embed-dim 2048 \ --encoder-embed-dim 1024 \ --decoder-embed-dim 1024 \ --max-tokens 8192 \ --criterion cross_entropy --max-update 50 \ --attention-dropout 0.0 \ --adam-betas '(0.9, 0.98)' \ --disable-validation --no-save \ --task dummy_mt # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Reviewed By: myleott Differential Revision: D22484873 Pulled By: msbaines fbshipit-source-id: bc61165ab91290d0b6aa2077c968ab537bce8a6a --- fairseq/benchmark/__init__.py | 1 + fairseq/benchmark/dummy_mt.py | 120 ++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 fairseq/benchmark/dummy_mt.py diff --git a/fairseq/benchmark/__init__.py b/fairseq/benchmark/__init__.py index ea67844ef7..926f3ce739 100644 --- a/fairseq/benchmark/__init__.py +++ b/fairseq/benchmark/__init__.py @@ -8,4 +8,5 @@ dummy_lm, dummy_masked_lm, dummy_model, + dummy_mt, ) diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py new file mode 100644 index 0000000000..09f2f0c119 --- /dev/null +++ b/fairseq/benchmark/dummy_mt.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch + +from fairseq.data import Dictionary, FairseqDataset +from fairseq.tasks import FairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task('dummy_mt') +class DummyMTTask(FairseqTask): + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument('--dict-size', default=49996, type=int) + parser.add_argument('--dataset-size', default=100000, type=int) + parser.add_argument('--tokens-per-sample', default=512, type=int, + help='max number of total tokens over all segments ' + 'per sample for BERT dataset') + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1 + + self.dummy_src = seq[:-1] + self.dummy_tgt = seq[1:] + + @classmethod + def setup_task(cls, 
args, **kwargs): + """Setup the task. """ + dictionary = Dictionary() + for i in range(args.dict_size): + dictionary.add_symbol('word{}'.format(i)) + logger.info('dictionary: {} types'.format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if self.args.max_sentences is not None: + bsz = self.args.max_sentences + else: + bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) + tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) + self.datasets[split] = DummyDataset( + { + 'id': 1, + 'net_input': { + 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]), + 'src_lengths': torch.full( + (bsz, ), self.args.tokens_per_sample, dtype=torch.long + ), + 'prev_output_tokens': tgt.clone(), + }, + 'target': tgt, + 'nsentences': bsz, + 'ntokens': bsz * self.args.tokens_per_sample, + }, + num_items=self.args.dataset_size, + item_size=self.args.tokens_per_sample, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class DummyDataset(FairseqDataset): + + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False From 9c21a715d66e7833b0af80895c0550d555f2fd0d Mon Sep 17 00:00:00 2001 From: Mandeep Baines Date: Wed, 15 Jul 2020 17:12:24 
-0700 Subject: [PATCH 058/707] Fix regression in memory-efficient-fp16 (#1216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The fused_adam optimizer divides by the scale while our logic multiplies by the scale. I'm surprised this even worked. The first few iterations had similar losses with the old code and even converged. However, Jun Ru noticed that the losses are very different after more iterations. # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1216 Reviewed By: myleott, shruti-bh Differential Revision: D22536377 Pulled By: msbaines fbshipit-source-id: 9328a1764a1895572c18567f99bee3330f25179e --- fairseq/optim/fp16_optimizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index abd4b5c0fc..0b0132119f 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -387,7 +387,8 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): def step(self, closure=None): """Performs a single optimization step.""" if self.supports_step_with_scale: - self.wrapped_optimizer.step(closure, scale=self._multiply_factor) + # NOTE(msb) optimizer divides by scale factor + self.wrapped_optimizer.step(closure, scale=(1. 
/ self._multiply_factor)) else: self._unscale_grads() self.wrapped_optimizer.step(closure) From 84896af72c01f2b5f7e3c1c65ed28be697a3b32f Mon Sep 17 00:00:00 2001 From: Mandeep Baines Date: Thu, 16 Jul 2020 08:08:37 -0700 Subject: [PATCH 059/707] Fix memory-efficient-fp16 when using update_freq other than 1 (#1219) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Tested the following model and verified that gnorms and losses match the following commit: commit 3b7cf7558499cf70690913b76f35d0bc755e62ae Author: m_fomicheva Date: Wed Jul 8 13:04:55 2020 -0700 The loss and gnorm are identical to the number of digits reported in the logs and the ppl is very close to many significant digits. Thanks again to Jun Ru for reporting. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1219 Test Plan: CUDA_VISIBLE_DEVICES=0 fairseq-train --task language_modeling data-bin/wikitext-103 --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm --share-decoder-input-output-embed --dropout 0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 --tokens-per-sample 512 --sample-break-mode none --max-tokens 2048 --update-freq 16 --max-update 50000 --memory-efficient-fp16 --no-progress-bar --log-interval 1 --seed 4 Before (commit 3b7cf755): 2020-07-15 12:17:28 | INFO | train_inner | epoch 001: 45 / 3151 loss=19.083, ppl=555252, wps=7165.8, ups=0.22, wpb=32768, bsz=64, num_updates=41, lr=5.22398e-06, gnorm=6.895, loss_scale=8, train_wall=5, wall=208 2020-07-15 12:17:33 | INFO | train_inner | epoch 001: 46 / 3151 loss=19.042, ppl=539620, wps=7176.6, ups=0.22, wpb=32768, bsz=64, num_updates=42, lr=5.34895e-06, gnorm=6.662, loss_scale=8, train_wall=5, wall=213 2020-07-15 12:17:37 | INFO | train_inner | epoch 001: 47 / 3151 loss=18.908, ppl=492042, wps=7188.8, ups=0.22, wpb=32768, bsz=64, 
num_updates=43, lr=5.47393e-06, gnorm=6.231, loss_scale=8, train_wall=5, wall=217 2020-07-15 12:17:42 | INFO | train_inner | epoch 001: 48 / 3151 loss=18.894, ppl=487224, wps=7192, ups=0.22, wpb=32768, bsz=64, num_updates=44, lr=5.5989e-06, gnorm=6.078, loss_scale=8, train_wall=5, wall=222 2020-07-15 12:17:47 | INFO | train_inner | epoch 001: 49 / 3151 loss=18.829, ppl=465781, wps=7182.5, ups=0.22, wpb=32768, bsz=64, num_updates=45, lr=5.72388e-06, gnorm=5.819, loss_scale=8, train_wall=5, wall=226 2020-07-15 12:17:51 | INFO | train_inner | epoch 001: 50 / 3151 loss=18.752, ppl=441564, wps=7185.4, ups=0.22, wpb=32768, bsz=64, num_updates=46, lr=5.84885e-06, gnorm=5.521, loss_scale=8, train_wall=5, wall=231 After: 2020-07-15 15:13:10 | INFO | train_inner | epoch 001: 45 / 3151 loss=19.083, ppl=555249, wps=7220.5, ups=0.22, wpb=32768, bsz=64, num_updates=41, lr=5.22398e-06, gnorm=6.895, loss_scale=8, train_wall=5, wall=207 2020-07-15 15:13:14 | INFO | train_inner | epoch 001: 46 / 3151 loss=19.042, ppl=539617, wps=7216.3, ups=0.22, wpb=32768, bsz=64, num_updates=42, lr=5.34895e-06, gnorm=6.662, loss_scale=8, train_wall=5, wall=212 2020-07-15 15:13:19 | INFO | train_inner | epoch 001: 47 / 3151 loss=18.908, ppl=492041, wps=7220.8, ups=0.22, wpb=32768, bsz=64, num_updates=43, lr=5.47393e-06, gnorm=6.231, loss_scale=8, train_wall=5, wall=216 2020-07-15 15:13:24 | INFO | train_inner | epoch 001: 48 / 3151 loss=18.894, ppl=487228, wps=7229.4, ups=0.22, wpb=32768, bsz=64, num_updates=44, lr=5.5989e-06, gnorm=6.078, loss_scale=8, train_wall=5, wall=221 2020-07-15 15:13:28 | INFO | train_inner | epoch 001: 49 / 3151 loss=18.829, ppl=465783, wps=7231.2, ups=0.22, wpb=32768, bsz=64, num_updates=45, lr=5.72388e-06, gnorm=5.819, loss_scale=8, train_wall=5, wall=225 2020-07-15 15:13:33 | INFO | train_inner | epoch 001: 50 / 3151 loss=18.752, ppl=441559, wps=7224.5, ups=0.22, wpb=32768, bsz=64, num_updates=46, lr=5.84885e-06, gnorm=5.521, loss_scale=8, train_wall=5, wall=230 # 
Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Reviewed By: myleott Differential Revision: D22560914 Pulled By: msbaines fbshipit-source-id: f2fdc3daa46de0b75f26cb4d5712e92d1a820d60 --- fairseq/optim/fp16_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 0b0132119f..58e8e9e0de 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -341,7 +341,6 @@ def backward(self, loss): """ if self.scaler is not None: loss = loss * self.scaler.loss_scale - self._multiply_factor /= float(self.scaler.loss_scale) loss.backward() def _unscale_grads(self): @@ -396,7 +395,8 @@ def step(self, closure=None): def zero_grad(self): """Clears the gradients of all optimized parameters.""" self.wrapped_optimizer.zero_grad() - self._multiply_factor = 1. + if self.scaler is not None: + self._multiply_factor = 1. 
/ float(self.scaler.loss_scale) class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer): From c0b52268539b861d93fde2a5931d909028a085b4 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Thu, 16 Jul 2020 09:32:44 -0700 Subject: [PATCH 060/707] Multilingual v1: Multilingual Training with multiple bitext and monolingual datasets: new datasets (#1205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: A first version of XLNMT multilingual project code release: Multilingual Training with multiple bitext - Major work is in fairseq/data/multilingual - fairseq/data/multilingual/sampled_multi_dataset.py to enable sampling and virtual data sizes - fairseq/data/multilingual/sampled_multi_epoch_dataset.py to enable virtual epoch data size to start training without going through the whole data (which reduces loading time from 1.5 hours to <30 seconds) - [next diff] fairseq/data/multilingual/multilingual_data_manager.py to support a few sophisticated multilingual data combinations - [next diff] fairseq/data/multilingual/sampling_method.py to support basic sampling functions - [next diff] A new task to glue all things together: fairseq/tasks/translation_multi_simple_epoch.py - Minor changes to - fairseq/data/language_pair_dataset.py to (1) have language IDs in the batch if they are set, (2) allow a preset max_size of batch; (3) corresponding changes to fairseq/data/data_utils.py - [next diff] fairseq/data/denoising_dataset.py to (1) allow additional transformation; (2) allow a preset max_size of batch; - [next diff] fairseq/data/iterators.py to allow dynamic batch sampler - [next diff] fairseq/checkpoint_utils.py to add a finetuning option instead of using restore_file, which restores from the original model when requeued. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1205 Test Plan: buck test mode/dev //deeplearning/projects/fairseq-py:test_cpu -- 'test_translation_multi_simple_epoch \(tests\.test_binaries\.TestTranslation\)' https://our.intern.facebook.com/intern/testinfra/testrun/3659174727046259 Started new test run: https://our.intern.facebook.com/intern/testinfra/testrun/3659174727046259 ✓ deeplearning/projects/fairseq-py:test_cpu - test_translation_multi_simple_epoch (tests.test_binaries.TestTranslation) 331.967 1/1 (passed) Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/3659174727046259 Summary (total time 352.88s): PASS: 1 FAIL: 0 SKIP: 0 FATAL: 0 TIMEOUT: 0 OMIT: 0 Reviewed By: myleott Differential Revision: D22463947 Pulled By: tangyuq fbshipit-source-id: e430c040231035af73141dc736960bd972bd4b6e --- fairseq/data/__init__.py | 5 +- fairseq/data/data_utils.py | 8 +- fairseq/data/language_pair_dataset.py | 50 ++- fairseq/data/multilingual/__init__.py | 4 + .../multilingual/sampled_multi_dataset.py | 379 ++++++++++++++++++ .../sampled_multi_epoch_dataset.py | 239 +++++++++++ 6 files changed, 676 insertions(+), 9 deletions(-) create mode 100644 fairseq/data/multilingual/__init__.py create mode 100644 fairseq/data/multilingual/sampled_multi_dataset.py create mode 100644 fairseq/data/multilingual/sampled_multi_epoch_dataset.py diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 9bdb7a74ae..e35bb5646c 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -46,7 +46,8 @@ from .transform_eos_dataset import TransformEosDataset from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset from .shorten_dataset import TruncateDataset, RandomCropDataset - +from .multilingual.sampled_multi_dataset import SampledMultiDataset +from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset from .iterators import ( CountingIterator, EpochBatchIterator, @@ -97,6 +98,8 @@ 
'ResamplingDataset', 'RightPadDataset', 'RoundRobinZipDatasets', + 'SampledMultiDataset', + 'SampledMultiEpochDataset', 'ShardedIterator', 'SortDataset', 'StripTokenDataset', diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index b3eee1cb9c..28410dc21e 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -30,9 +30,10 @@ def infer_language_pair(path): return src, dst -def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False): +def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False, pad_to_length=None): """Convert a list of 1d tensors into a padded 2d tensor.""" size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) res = values[0].new(len(values), size).fill_(pad_idx) def copy_tensor(src, dst): @@ -132,6 +133,9 @@ def collect_filtered(function, iterable, filtered): def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False): + def compare_leq(a, b): + return a <= b if not isinstance(a, tuple) else max(a) <= b + def check_size(idx): if isinstance(max_positions, float) or isinstance(max_positions, int): return size_fn(idx) <= max_positions @@ -148,7 +152,7 @@ def check_size(idx): # Hacky as heck, for the specific case of multilingual training with RoundRobin. 
if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple): return all( - a is None or b is None or a <= b + a is None or b is None or compare_leq(a, b) for a, b in zip(size_fn(idx).values(), max_positions) ) # For MultiCorpusSampledDataset, will generalize it later diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index d70c0c332d..df48a8052b 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -21,14 +21,16 @@ def collate( left_pad_source=True, left_pad_target=False, input_feeding=True, + pad_to_length=None, ): if len(samples) == 0: return {} - def merge(key, left_pad, move_eos_to_beginning=False): + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning, + pad_to_length=pad_to_length, ) def check_alignment(alignment, src_len, tgt_len): @@ -54,7 +56,10 @@ def compute_alignment_weights(alignments): return 1. 
/ align_weights.float() id = torch.LongTensor([s['id'] for s in samples]) - src_tokens = merge('source', left_pad=left_pad_source) + src_tokens = merge( + 'source', left_pad=left_pad_source, + pad_to_length=pad_to_length['source'] if pad_to_length is not None else None + ) # sort by descending source length src_lengths = torch.LongTensor([ s['source'].ne(pad_idx).long().sum() for s in samples @@ -66,7 +71,10 @@ def compute_alignment_weights(alignments): prev_output_tokens = None target = None if samples[0].get('target', None) is not None: - target = merge('target', left_pad=left_pad_target) + target = merge( + 'target', left_pad=left_pad_target, + pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + ) target = target.index_select(0, sort_order) tgt_lengths = torch.LongTensor([ s['target'].ne(pad_idx).long().sum() for s in samples @@ -82,6 +90,7 @@ def compute_alignment_weights(alignments): 'target', left_pad=left_pad_target, move_eos_to_beginning=True, + pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, ) prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: @@ -157,6 +166,12 @@ class LanguagePairDataset(FairseqDataset): source/target sentence. num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. + src_lang_id (int, optional): source language ID, if set, the collated batch + will contain a field 'src_lang_id' in 'net_input' which indicates the + source language of the samples. + tgt_lang_id (int, optional): target language ID, if set, the collated batch + will contain a field 'tgt_lang_id' which indicates the target language + of the samples. 
""" def __init__( @@ -168,6 +183,8 @@ def __init__( align_dataset=None, append_bos=False, eos=None, num_buckets=0, + src_lang_id=None, + tgt_lang_id=None, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() @@ -192,7 +209,8 @@ def __init__( assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" self.append_bos = append_bos self.eos = (eos if eos is not None else src_dict.eos()) - + self.src_lang_id = src_lang_id + self.tgt_lang_id = tgt_lang_id if num_buckets > 0: from fairseq.data import BucketPadLengthDataset self.src = BucketPadLengthDataset( @@ -267,11 +285,14 @@ def __getitem__(self, index): def __len__(self): return len(self.src) - def collater(self, samples): + def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {'source': source_pad_to_length, 'target': target_pad_to_length} + to indicate the max length to pad to in source and target respectively. Returns: dict: a mini-batch with the following keys: @@ -291,19 +312,36 @@ def collater(self, samples): This key will not be present if *input_feeding* is ``False``. Padding will appear on the left if *left_pad_target* is ``True``. + - `src_lang_id` (LongTensor): a long Tensor which contains source + language IDs of each sample in the batch - `target` (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the left if *left_pad_target* is ``True``. 
+ - `tgt_lang_id` (LongTensor): a long Tensor which contains target language + IDs of each sample in the batch """ - return collate( + res = collate( samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos, left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, + pad_to_length=pad_to_length, ) + if self.src_lang_id is not None or self.tgt_lang_id is not None: + src_tokens = res['net_input']['src_tokens'] + bsz = src_tokens.size(0) + if self.src_lang_id is not None: + res['net_input']['src_lang_id'] = torch.LongTensor( + [[self.src_lang_id]] + ).expand(bsz, 1).to(src_tokens) + if self.tgt_lang_id is not None: + res['tgt_lang_id'] = torch.LongTensor( + [[self.tgt_lang_id]] + ).expand(bsz, 1).to(src_tokens) + return res def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to diff --git a/fairseq/data/multilingual/__init__.py b/fairseq/data/multilingual/__init__.py new file mode 100644 index 0000000000..6264236915 --- /dev/null +++ b/fairseq/data/multilingual/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py new file mode 100644 index 0000000000..bf6051abc5 --- /dev/null +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -0,0 +1,379 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List +from enum import Enum +from collections import OrderedDict +from collections import defaultdict +from bisect import bisect_right +import hashlib +import logging +import datetime +import time + +import numpy as np +import torch + +from fairseq import distributed_utils +from fairseq.data import plasma_utils, FairseqDataset + + +def get_time_gap(s, e): + return (datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)).__str__() + + +logger = logging.getLogger(__name__) + + +def default_virtual_size_func(datasets, ratios, max_scale_up=1.5): + sizes = [len(d) for d in datasets] + if ratios is None: + return sum(sizes) + largest_idx = np.argmax(sizes) + largest_r = ratios[largest_idx] + largest_s = sizes[largest_idx] + # set virtual sizes relative to the largest dataset + virtual_sizes = [(r / largest_r) * largest_s for r in ratios] + vsize = sum(virtual_sizes) + max_size = sum(sizes) * max_scale_up + return int(vsize if vsize < max_size else max_size) + + +class CollateFormat(Enum): + single = 1 + ordered_dict = 2 + + +class SampledMultiDataset(FairseqDataset): + """Samples from multiple sub-datasets according to given sampling ratios. + Args: + datasets ( + List[~torch.utils.data.Dataset] + or OrderedDict[str, ~torch.utils.data.Dataset] + ): datasets + sampling_ratios (List[float]): list of probability of each dataset to be sampled + (default: None, which corresponds to concating all dataset together). + batch_by_size (bool): whether or not to batch by sequence length + (default: True). + seed (int): RNG seed to use (default: 2). + epoch (int): starting epoch number (default: 1). + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*. 
+ collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or + CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures + the collater to output batches of data mixed from all sub-datasets, + and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys + of sub-datasets. + Note that not all sub-datasets will present in a single batch in both formats. + virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). + split (str): the split of the data, e.g. 'train', 'valid' or 'test'. + shared_collater (bool): whether or not to all sub-datasets have the same collater. + """ + + def __init__( + self, + datasets, + sampling_ratios=None, + batch_by_size=False, + seed=2, + epoch=1, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=default_virtual_size_func, + split='', + shared_collater=False, + ): + super().__init__() + self.batch_by_size = batch_by_size + self.shared_collater = shared_collater + + if isinstance(datasets, OrderedDict): + self.keys = list(datasets.keys()) + datasets = list(datasets.values()) + elif isinstance(datasets, List): + self.keys = list(range(len(datasets))) + else: + raise AssertionError() + self.datasets = datasets + self.split = split + + self.eval_key = eval_key + if self.eval_key is not None: + self.collate_format = CollateFormat.single + else: + self.collate_format = collate_format + + self.seed = seed + self._cur_epoch = None + self._cur_indices = None + self._sizes = None + self._ordered_indices = None + self.virtual_size_per_dataset = None + # caching properties + self._reset_cached_properties() + self.setup_sampling(sampling_ratios, virtual_size) + self.cumulated_sizes = None + self.virtual_size_per_dataset = None + self._size_cache = {} + self.set_epoch(epoch) + + def _reset_cached_properties(self): + self._sizes = None + self._ordered_indices = None + 
self._cur_indices = None + + def setup_sampling(self, sample_ratios, virtual_size): + sizes = [len(d) for d in self.datasets] + if sample_ratios is None: + # default back to concating datasets + self.sample_ratios = None + self.virtual_size = sum(sizes) + else: + if not isinstance(sample_ratios, np.ndarray): + sample_ratios = np.array(sample_ratios) + self.sample_ratios = plasma_utils.PlasmaArray(sample_ratios) + virtual_size = default_virtual_size_func if virtual_size is None else virtual_size + self.virtual_size = ( + virtual_size(self.datasets, self.sample_ratios.array) if callable(virtual_size) + else virtual_size) + + def adjust_sampling(self, epoch, sampling_ratios, virtual_size): + if sampling_ratios is not None: + sampling_ratios = self._sync_sample_ratios(sampling_ratios) + self.setup_sampling(sampling_ratios, virtual_size) + + def _sync_sample_ratios(self, ratios): + # in case the ratios are not precisely the same across processes + # also to ensure every procresses update the ratios in the same pace + ratios = torch.DoubleTensor(ratios) + if torch.distributed.is_initialized(): + if torch.cuda.is_available(): + distributed_utils.all_reduce(ratios.cuda()) + else: + distributed_utils.all_reduce(ratios) + ret = ratios.cpu() + ret = ret.numpy() + return ret + + def random_choice_in_dataset(self, rng, dataset, choice_size): + if hasattr(dataset, 'random_choice_in_dataset'): + return dataset.random_choice_in_dataset(rng, choice_size) + dataset_size = len(dataset) + return rng.choice(dataset_size, choice_size, replace=(choice_size > dataset_size)) + + def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size): + def get_counts(sample_ratios): + counts = np.array([virtual_size * r for r in sample_ratios], dtype=np.int64) + diff = virtual_size - counts.sum() + assert diff >= 0 + # due to round-offs, the size might not match the desired sizes + if diff > 0: + dataset_indices = rng.choice(len(sample_ratios), size=diff, p=sample_ratios) + for i in 
dataset_indices: + counts[i] += 1 + return counts + + def get_in_dataset_indices(datasets, sizes, sample_ratios): + counts = get_counts(sample_ratios) + # uniformally sample desired counts for each dataset + # if the desired counts are large, sample with replacement: + indices = [ + self.random_choice_in_dataset(rng, d, c) + for c, d in zip(counts, datasets)] + return indices + + sizes = [len(d) for d in datasets] + if sample_ratios is None: + # default back to concating datasets + in_dataset_indices = [list(range(s)) for s in sizes] + virtual_sizes_per_dataset = sizes + else: + sample_ratios = sample_ratios.array + ratios = sample_ratios / sample_ratios.sum() + in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios) + virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices] + virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64) + cumulative_sizes = np.cumsum(virtual_sizes_per_dataset) + assert sum(virtual_sizes_per_dataset) == virtual_size + assert cumulative_sizes[-1] == virtual_size + if virtual_size < sum(sizes): + logger.warning( + f'virtual data size ({virtual_size}) is less than real data size ({sum(sizes)}).' + ' If virtual size << real data size, there could be data coverage issue.' 
+ ) + in_dataset_indices = np.hstack(in_dataset_indices) + return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset + + def _get_dataset_and_index(self, index): + i = bisect_right(self.cumulated_sizes.array, index) + return i, self._cur_indices.array[index] + + def __getitem__(self, index): + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + return (ds_idx, self.datasets[ds_idx][ds_sample_idx]) + + def num_tokens(self, index): + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + return self.datasets[ds_idx].num_tokens(ds_sample_idx) + + def size(self, index): + if self._sizes is not None: + return self._sizes[index] + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + return self.datasets[ds_idx].size(ds_sample_idx) + + def __len__(self): + return self.virtual_size + + def collater(self, samples, **extra_args): + """Merge a list of samples to form a mini-batch.""" + if len(samples) == 0: + return None + if self.collate_format == 'ordered_dict': + collect_samples = [[] for _ in range(len(self.datasets))] + for (i, sample) in samples: + collect_samples[i].append(sample) + return OrderedDict([ + (self.keys[i], dataset.collater(collect_samples[i])) + for i, (key, dataset) in enumerate(zip(self.keys, self.datasets)) + if len(collect_samples[i]) > 0 + ]) + elif self.shared_collater: + return self.datasets[0].collater( + [s for _, s in samples] + ) + else: + samples_dict = defaultdict(list) + max_size = defaultdict(int) + for ds_idx, s in samples: + max_size['source'] = max(max_size['source'], s['source'].size(0)) + if s['target'] is not None: + max_size['target'] = max(max_size['target'], s['target'].size(0)) + samples_dict[ds_idx].append(s) + batches = [ + self.datasets[i].collater(samples_dict[i], max_size=max_size) + for i in range(len(self.datasets)) + if len(samples_dict[i]) > 0 + ] + + def straight_data(tensors): + batch = torch.cat(tensors, dim=0) + return batch + + src_lengths = 
straight_data([b['net_input']['src_lengths'] for b in batches]) + src_lengths, sort_order = src_lengths.sort(descending=True) + + def straight_order(tensors): + batch = straight_data(tensors) + return batch.index_select(0, sort_order) + + batch = { + 'id': straight_order([b['id'] for b in batches]), + 'nsentences': sum(b['nsentences'] for b in batches), + 'ntokens': sum(b['ntokens'] for b in batches), + 'net_input': { + 'src_tokens': straight_order([b['net_input']['src_tokens'] for b in batches]), + 'src_lengths': src_lengths, + }, + 'target': straight_order([b['target'] for b in batches]) if batches[0]['target'] is not None else None, + } + if 'prev_output_tokens' in batches[0]['net_input']: + batch['net_input']['prev_output_tokens'] = straight_order( + [b['net_input']['prev_output_tokens'] for b in batches]) + if 'src_lang_id' in batches[0]['net_input']: + batch['net_input']['src_lang_id'] = straight_order([b['net_input']['src_lang_id'] for b in batches]) + if 'tgt_lang_id' in batches[0]: + batch['tgt_lang_id'] = straight_order([b['tgt_lang_id'] for b in batches]) + return batch + + @property + def sizes(self): + if self._sizes is not None: + return self._sizes + start_time = time.time() + size_cache = self._size_cache + ret = [] + for i in range(len(self)): + ds_idx, ds_sample_idx = self._get_dataset_and_index(i) + if (ds_idx, ds_sample_idx) in size_cache: + ret.append(size_cache[(ds_idx, ds_sample_idx)]) + else: + s = self.datasets[ds_idx].size(ds_sample_idx) + size_cache[(ds_idx, ds_sample_idx)] = s + ret.append(s) + logger.debug(f'sizes() calling time: {get_time_gap(start_time, time.time())}') + self._sizes = np.array(ret, np.int64) + return self._sizes + + def ordered_indices(self): + if self._ordered_indices is not None: + return self._ordered_indices + + if self.batch_by_size: + # No need to do shuffle as the data items are already randomized + indices = np.arange(len(self)) + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and 
sizes.shape[1] > 1 else None + src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[ + np.argsort(tgt_sizes[indices], kind='mergesort') + ] + sort_indices = indices[np.argsort(src_sizes[indices], kind='mergesort')] + else: + sort_indices = np.arange(len(self)) + self._ordered_indices = sort_indices + return sort_indices + + def prefetch(self, indices): + prefetch_indices = [[] for _ in range(len(self.datasets))] + for i in indices: + ds_idx, ds_sample_idx = self._get_dataset_and_index(i) + prefetch_indices[ds_idx].append(ds_sample_idx) + for i in range(len(prefetch_indices)): + self.datasets[i].prefetch(prefetch_indices[i]) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if epoch == self._cur_epoch: + # re-enter so return + return + for d in self.datasets: + if hasattr(d, 'set_epoch'): + d.set_epoch(epoch) + self._cur_epoch = epoch + self._establish_virtual_datasets() + + def _establish_virtual_datasets(self): + if self.sample_ratios is None and self._cur_indices is not None: + # not a samping dataset, no need to resample if indices are already established + return + self._reset_cached_properties() + + start_time = time.time() + # Generate a weighted sample of indices as a function of the + # random seed and the current epoch. 
+ rng = np.random.RandomState( + [ + int(hashlib.sha1(str(self.__class__.__name__).encode('utf-8')).hexdigest(), 16) % (2 ** 32), + self.seed % (2 ** 32), # global seed + self._cur_epoch, # epoch index, + ] + ) + indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices( + rng, self.datasets, self.sample_ratios, self.virtual_size) + self._cur_indices = plasma_utils.PlasmaArray(indices) + self.cumulated_sizes = plasma_utils.PlasmaArray(cumulated_sizes) + self.virtual_size_per_dataset = plasma_utils.PlasmaArray(virtual_size_per_dataset) + + logger.info(f'[{self.split}] Raw sizes: {str(dict(zip(self.keys, [len(d) for d in self.datasets])))}') + logger.info(f'[{self.split}] Resampled sizes: {str(dict(zip(self.keys, self.virtual_size_per_dataset.array)))}') + if self.sample_ratios is not None: + logger.info(f'[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios.array)))}') + else: + logger.info(f'[{self.split}] A concat dataset') + logger.debug(f'[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}') diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py new file mode 100644 index 0000000000..65660407bd --- /dev/null +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -0,0 +1,239 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import hashlib +import math +import logging +import time + +import numpy as np +import torch +from fairseq import distributed_utils +from fairseq.data import plasma_utils, SampledMultiDataset +from .sampled_multi_dataset import default_virtual_size_func, get_time_gap, CollateFormat + + +logger = logging.getLogger(__name__) + + +class SampledMultiEpochDataset(SampledMultiDataset): + """Samples from multiple sub-datasets according to sampling ratios + using virtual epoch sizes to speed up dataloading. + Args: + datasets ( + List[~torch.utils.data.Dataset] + or OrderedDict[str, ~torch.utils.data.Dataset] + ): datasets + sampling_ratios (List[float]): list of probability of each dataset to be sampled + (default: None, which corresponds to concating all dataset together). + batch_by_size (bool): whether or not to batch by sequence length + (default: True). + seed (int): RNG seed to use (default: 2). + epoch (int): starting epoch number (default: 1). + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*. + collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or + CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures + the collater to output batches of data mixed from all sub-datasets, + and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys + of sub-datasets. + Note that not all sub-datasets will present in a single batch in both formats. + virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). + split (str): the split of the data, e.g. 'train', 'valid' or 'test'. + virtual_epoch_size (int): virtual epoch size, the dataset will go through the data by + this virtual epoch size one by one to speed up data loading, e.g. 
indicing and filtering + can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded. + shared_collater (bool): whether or not to all sub-datasets have the same collater. + shard_epoch (int): the real epoch number for shard selection. + """ + def __init__( + self, + datasets, + sampling_ratios=None, + batch_by_size=False, + seed=2, + epoch=1, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=default_virtual_size_func, + split='', + virtual_epoch_size=None, + shared_collater=False, + shard_epoch=1, + ): + self.virtual_epoch_size = virtual_epoch_size + self._current_epoch_start_index = None + self._epoch_sizes = None + self._epoch_ordered_indices = None + self._random_globa_indices = None + self.shard_epoch = shard_epoch if shard_epoch is not None else 1 + self.load_next_shard = None + super().__init__( + datasets=datasets, + sampling_ratios=sampling_ratios, + batch_by_size=batch_by_size, + seed=seed, + epoch=epoch, + eval_key=eval_key, + collate_format=collate_format, + virtual_size=virtual_size, + split=split, + shared_collater=shared_collater, + ) + + def _setup(self, epoch): + self.virtual_epoch_size = self.virtual_epoch_size if self.virtual_epoch_size is not None else self.virtual_size + if self.virtual_epoch_size > self.virtual_size: + logger.warning(f'virtual epoch size {self.virtual_epoch_size} ' + f'is greater than virtual dataset size {self.virtual_size}') + self.virtual_epoch_size = self.virtual_size + self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size) + self._current_epoch_start_index = self._get_epoch_start_index(epoch) + logger.info(f'virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}') + + def _map_epoch_index_to_global(self, index): + index = self._current_epoch_start_index + index + # add randomness + return self._random_globa_indices.array[index] + + def __getitem__(self, index): + i = 
self._map_epoch_index_to_global(index) + return super().__getitem__(i) + + def num_tokens(self, index): + i = self._map_epoch_index_to_global(index) + return super().num_tokens(i) + + def size(self, index): + if self._epoch_sizes is not None: + return self._epoch_sizes[index] + index = self._map_epoch_index_to_global(index) + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + return self.datasets[ds_idx].size(ds_sample_idx) + + def __len__(self): + return ( + self.virtual_epoch_size + if self._current_epoch_start_index + self.virtual_epoch_size < self.virtual_size + else self.virtual_size - self._current_epoch_start_index + ) + + @property + def sizes(self): + if self._epoch_sizes is not None: + return self._epoch_sizes + start_time = time.time() + + size_cache = self._size_cache + ret = [] + for i in range(len(self)): + index = self._map_epoch_index_to_global(i) + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + + if (ds_idx, ds_sample_idx) in size_cache: + ret.append(size_cache[(ds_idx, ds_sample_idx)]) + else: + s = self.datasets[ds_idx].size(ds_sample_idx) + s = (s, s) if not isinstance(s, tuple) else s + size_cache[(ds_idx, ds_sample_idx)] = s + ret.append(s) + logger.debug(f'sizes() calling time: {get_time_gap(start_time, time.time())}') + self._epoch_sizes = np.array(ret, np.int64) + return self._epoch_sizes + + def ordered_indices(self): + if self._epoch_ordered_indices is not None: + return self._epoch_ordered_indices + + if self.batch_by_size: + # No need to do shuffle as the data items are already randomized + indices = np.arange(len(self)) + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[ + np.argsort(tgt_sizes[indices], kind='mergesort') + ] + sort_indices = indices[np.argsort(src_sizes[indices], 
kind='mergesort')] + else: + sort_indices = np.arange(len(self)) + self._epoch_ordered_indices = sort_indices + return sort_indices + + def prefetch(self, indices): + prefetch_indices = [[] for _ in range(len(self.datasets))] + for i in indices: + index = self._map_epoch_index_to_global(i) + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + prefetch_indices[ds_idx].append(ds_sample_idx) + for i in range(len(prefetch_indices)): + self.datasets[i].prefetch(prefetch_indices[i]) + + def set_epoch(self, epoch): + if self._current_epoch_start_index is None: + self._setup(epoch) + self._next_virtual_epoch(epoch) + if epoch == self._cur_epoch: + # re-enter so return + return + self._next_virtual_epoch(epoch) + + def _get_epoch_start_index(self, epoch): + assert epoch >= 1 # fairseq is using 1-based epoch everywhere + return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size + + def _next_global_indices(self, epoch): + rng = np.random.RandomState( + [ + int(hashlib.sha1(str(self.__class__.__name__).encode('utf-8')).hexdigest(), 16) % (2 ** 32), + self.seed % (2 ** 32), # global seed + epoch, # epoch index, + ] + ) + self._random_globa_indices = plasma_utils.PlasmaArray( + rng.choice(self.virtual_size, self.virtual_size, replace=False)) + if self.load_next_shard is None: + self.load_next_shard = False + else: + # increase shard epoch for next loading + self.shard_epoch += 1 + self.load_next_shard = True + # a hack to avoid possible out of sync of shard epoch number + # TODO: to confirm whether this is needed; without it, CUDA event error is occassionally observed + synced_shard_epoch = self._sync_shard_epoch(self.shard_epoch) + logger.info('to load next epoch/shard in next load_dataset: ' + f'epoch={epoch}/shard_epoch={self.shard_epoch}[synced={synced_shard_epoch}]') + + def _sync_shard_epoch(self, shard_epoch): + # in case the ratios are not precisely the same across processes + # also to ensure every procresses update the ratios in the same pace + 
shard_epoch = torch.DoubleTensor([shard_epoch]) + if torch.distributed.is_initialized(): + if torch.cuda.is_available(): + distributed_utils.all_reduce(shard_epoch.cuda()) + else: + distributed_utils.all_reduce(shard_epoch) + ret = shard_epoch.cpu() + ret = ret.numpy() + return ret + + def _next_virtual_epoch(self, epoch): + index = self._get_epoch_start_index(epoch) + if index == 0 or self._random_globa_indices is None: + # need to start from the beginning, + # so call super().set_epoch(epoch) to establish the global virtual indices + logger.info('establishing a new set of global virtual indices for ' + f'epoch={epoch}/shard_epoch={self.shard_epoch}') + super().set_epoch(epoch) + self._next_global_indices(epoch) + else: + self._cur_epoch = epoch + # reset cache sizes and ordered_indices for the epoch after moving to a new epoch + self._epoch_sizes = None + self._epoch_ordered_indices = None + self._current_epoch_start_index = index From 033daef0fc4fd8b44ae350a4ce2c7da299bbeb7f Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Thu, 16 Jul 2020 09:32:44 -0700 Subject: [PATCH 061/707] Multilingual v1: Multilingual Training with multiple bitext and monolingual datasets: multiligual dataset manager Summary: A first version of XLNMT multilingual project code release: Multilingual Training with multiple bitext - Major work is in fairseq/data/multilingual - fairseq/data/multilingual/multilingual_data_manager.py to support a few sophisticated multilingual data combinations - fairseq/data/multilingual/sampling_method.py to support basic sampling functions Reviewed By: pipibjc Differential Revision: D22483471 fbshipit-source-id: 3d9d2643877a29333915975020e419508887b3ae --- .../multilingual/multilingual_data_manager.py | 749 ++++++++++++++++++ fairseq/data/multilingual/sampling_method.py | 66 ++ 2 files changed, 815 insertions(+) create mode 100644 fairseq/data/multilingual/multilingual_data_manager.py create mode 100644 fairseq/data/multilingual/sampling_method.py diff --git 
a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py new file mode 100644 index 0000000000..c313a7be6c --- /dev/null +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -0,0 +1,749 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os + +import numpy as np +from collections import OrderedDict + +import json +from fairseq import options +from fairseq.options import eval_str_dict, eval_str_list + +from fairseq.data import ( + Dictionary, + AppendTokenDataset, + ConcatDataset, + data_utils, + indexed_dataset, + LanguagePairDataset, + PrependTokenDataset, + StripTokenDataset, + TruncateDataset, + SampledMultiDataset, + TransformEosLangPairDataset, + SampledMultiEpochDataset, +) +from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat +from fairseq.file_io import PathManager + +logger = logging.getLogger(__name__) + + +def _lang_token(lang: str, style='__{}__'): + return style.format(lang) + + +def _lang_token_index(dic: Dictionary, lang: str, style='__{}__'): + """Return language token index.""" + idx = dic.index(_lang_token(lang, style)) + assert idx != dic.unk_index, \ + 'cannot find language token for lang {}'.format(lang) + return idx + + +def _lang_id(dic: Dictionary, lang: str): + """Return language ID index.""" + idx = dic.index(lang) + assert idx != dic.unk_index, \ + 'cannot find language ID for lang {}'.format(lang) + return idx + + +def load_sampling_weights(from_file): + with open(from_file) as f: + weights = json.load(f) + return weights + + +class MultilingualDatasetManager(object): + def __init__(self, args, lang_pairs, langs, dicts, sampling_method): + super().__init__() + self.args = args + self.seed = args.seed + self.lang_pairs = lang_pairs + self.langs = langs + self.dicts = dicts + 
self.lang_dict = self.create_lang_dictionary(self.langs) + self.sampling_method = sampling_method + self.sampling_scheduler = None + self._has_sharded_data = False + self._num_shards = {} + + @classmethod + def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): + return MultilingualDatasetManager(args, lang_pairs, langs, dicts, sampling_method) + + @staticmethod + def add_args(parser): + parser.add_argument('data', help='colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner') + parser.add_argument('--lang-dict', default=None, type=str, + help='language dictionary path with a list of ' + 'languages which can appear in lang-pairs') + parser.add_argument('--lang-tok-style', default='multilingual', + type=str, choices=['multilingual', 'mbart'], + help='language token styles') + + parser.add_argument('--load-alignments', action='store_true', + help='load the binarized alignments') + parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', + help='pad the source on the left') + parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', + help='pad the target on the left') + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + parser.add_argument('--upsample-primary', default=1, type=int, + help='amount to upsample primary dataset') + parser.add_argument('--truncate-source', action='store_true', default=False, + help='truncate source to max-source-positions') + parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'], + metavar='SRCTGT', + help='prepend to the beginning of source sentence the source or target ' + 'language token. 
(src/tgt)') + parser.add_argument('--decoder-langtok', action='store_true', + help='prepend to the beginning of target sentence the target language token') + parser.add_argument('--lang-tok-replacing-bos-eos', action='store_true', default=False) + parser.add_argument('--enable-lang-ids', default=False, action='store_true', + help='whether to include language IDs in samples') + parser.add_argument('--enable-reservsed-directions-shared-datasets', default=False, action='store_true', + help='whether to allow datasets be used in reversed directions') + + parser.add_argument('--extra-data', help='a dictionary of data name to this path, \ + e.g. {"mined", path_to_mined_data, "denoised": path_to_denoised_data}', + type=lambda uf: eval_str_dict(uf, type=str), + default=None) + parser.add_argument('--extra-lang-pairs', help='a dictionary of data name to the language pairs they serve, \ + e.g. {"mined": comma-separated-lang-pairs, "denoised": comma-separated-lang-pairs}', + type=lambda uf: eval_str_dict(uf, type=str), + default=None) + parser.add_argument('--langtoks-specs', help='a list of comma separated language tokens specifictions', + default='main', + type=lambda uf: eval_str_list(uf, type=str), + ) + parser.add_argument('--langtoks', help='a dictionary of how to add language tokens, \ + e.g. {"mined": (None, "tgt"), "mono_dae": ("src.dae", "tgt"), "main": \ + ("src", "tgt")}, or {"mined": ("src.mined", "tgt")}', + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument('--sampling-weights-from-file', + help='a file contain a python dictionary of how to sample data sets, \ + e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ + "mono_dae:es_XX-es_XX: 0.3, "main:en_xx-fr_XX": 0.8 }', + default=None, type=str, + ) + parser.add_argument('--sampling-weights', help='a dictionary of how to sample data sets, \ + e.g. 
{ "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ + "mono_dae:es_XX-es_XX: 0.3, "main:en_xx-fr_XX": 0.8 }', + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument('--virtual-epoch-size', default=1000000, type=int, + help='virtual epoch size to speed up data loading') + parser.add_argument('--virtual-data-size', default=None, type=int, + help='virtual data size of the whole joint dataset to speed' + 'up data loading and have specific dynamic sampling strategy interval') + + @classmethod + def load_langs(cls, args, **kwargs): + if args.lang_dict is None: + logger.warning( + 'External language dictionary is not provided; ' + 'use lang-pairs to infer the set of supported languages. ' + 'The language ordering is not stable which might cause ' + 'misalignment in pretraining and finetuning.') + # infer from lang_pairs as it is + langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}) + langs = sorted(langs) + logger.info(f'inferred language list: {langs}') + else: + with PathManager.open(args.lang_dict, "r", encoding="utf-8") as f: + langs = [lang.strip() for lang in f.readlines() if lang.strip()] + logger.info(f'loaded language list from {args.lang_dict} as they are ordered in file') + return langs + + def has_sharded_data(self, split): + return split == 'train' and self._has_sharded_data + + def _shared_collater(self): + return ( + not (self.args.extra_data and 'mono_dae' in self.args.extra_data) + and (not self.args.lang_tok_replacing_bos_eos) + ) + + @classmethod + def prepare(cls, args, **kargs): + args.left_pad_source = options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + if not hasattr(args, 'shuffle_instance'): + args.shuffle_instance = False + if args.langtoks is None: + args.langtoks = {} + if 'main' not in args.langtoks: + src_langtok_spec = args.encoder_langtok if args.encoder_langtok else None + tgt_langtok_spec = 'tgt' if 
args.decoder_langtok else None + args.langtoks['main'] = (src_langtok_spec, tgt_langtok_spec) + + def check_langs(langs, pairs): + messages = [] + for src, tgt in pairs: + if src not in langs or tgt not in langs: + messages.append(f'language pair {src}-{tgt} contains languages ' + 'that are not in the language dictionary') + if len(messages) > 0: + raise ValueError(' '.join(messages) + f"; langs: {langs}") + + if args.lang_pairs is None: + raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.') + if isinstance(args.lang_pairs, str): + args.lang_pairs = args.lang_pairs.split(',') + if args.source_lang is not None or args.target_lang is not None: + training = False + else: + training = True + sorted_langs = cls.load_langs(args, **kargs) + check_langs( + sorted_langs, + ([p.split('-') for p in args.lang_pairs] if training + else [(args.source_lang, args.target_lang)]) + ) + + # load dictionaries + if training: + extra_lang_pairs = ( + list({p for _, v in args.extra_lang_pairs.items() for p in v.split(',')}) + if args.extra_lang_pairs else [] + ) + langs_to_load_dicts = sorted({x for p in args.lang_pairs + extra_lang_pairs for x in p.split('-')}) + else: + langs_to_load_dicts = sorted([args.source_lang, args.target_lang]) + + dicts = OrderedDict() + supported_langtok_specs = args.langtoks_specs + for lang in langs_to_load_dicts: + paths = args.data.split(os.pathsep) + assert len(paths) > 0 + dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) + if len(dicts) > 0: + assert dicts[lang].pad() == dicts[langs_to_load_dicts[0]].pad() + assert dicts[lang].eos() == dicts[langs_to_load_dicts[0]].eos() + assert dicts[lang].unk() == dicts[langs_to_load_dicts[0]].unk() + + # keep the langs consistent for all experiments with the same lang dict + # for finetuning regardless of whether lang_tok is required or not just add the tokens to the dicts + for spec in supported_langtok_specs: + for lang_to_add in 
sorted_langs: + dicts[lang].add_symbol( + MultilingualDatasetManager.get_lang_tok(lang_to_add, args, spec) + ) + if args.lang_tok_style == 'mbart' or (args.extra_data and 'mono_dae' in args.extra_data): + dicts[lang].add_symbol('<mask>') + logger.info('[{}] dictionary: {} types'.format(lang, len(dicts[lang]))) + return sorted_langs, dicts, training + + TOKEN_STYLES = { + 'mbart': '[{}]', + 'multilingual': '__{}__' + } + + @classmethod + def create_lang_dictionary(cls, langs): + unk = '<unk>' + # hack to remove symbols other than unk as they are not needed by lang dict + lang_dict = Dictionary( + pad=unk, + eos=unk, + unk=unk, + bos=unk, + ) + for lang in langs: + lang_dict.add_symbol(lang) + return lang_dict + + @classmethod + def get_lang_tok_style(cls, args): + return cls.TOKEN_STYLES[args.lang_tok_style] + + @classmethod + def get_lang_tok(cls, lang, args, spec=''): + if spec is None: + return None + if spec.endswith('dae'): + lang = f'{lang}_dae' + elif spec.endswith('mined'): + lang = f'{lang}_mined' + return _lang_token(lang, cls.get_lang_tok_style(args)) + + @classmethod + def get_langtok_index(cls, lang_tok, dic): + idx = dic.index(lang_tok) + assert idx != dic.unk_index, \ + 'cannot find language token {} in the dictionary'.format(lang_tok) + return idx + + def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): + if spec is None: + return None + if spec and spec.startswith('src'): + if src_lang is None: + return None + langtok = self.get_lang_tok(src_lang, self.args, spec) + else: + if tgt_lang is None: + return None + langtok = self.get_lang_tok(tgt_lang, self.args, spec) + return self.get_langtok_index(langtok, self.dicts[src_lang if src_lang else tgt_lang]) + + def get_decoder_langtok(self, tgt_lang, spec=None): + if spec is None: + return None + langtok = self.get_lang_tok(tgt_lang, self.args, spec) + return self.get_langtok_index(langtok, self.dicts[tgt_lang]) + + @classmethod + def load_data(cls, path, vdict, impl): + dataset =
data_utils.load_indexed_dataset(path, vdict, impl) + return dataset + + @classmethod + def split_exists(cls, split, src, tgt, lang, data_path, dataset_impl): + filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + @classmethod + def mono_split_exists(cls, split, lang, data_path, dataset_impl): + filename = os.path.join(data_path, '{}.{}'.format(split, lang)) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + def load_lang_dataset( + self, + data_path, split, + src, src_dict, + tgt, tgt_dict, + combine, dataset_impl, upsample_primary, + max_source_positions, + prepend_bos=False, load_alignments=False, + truncate_source=False, + ): + + src_datasets = [] + tgt_datasets = [] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else '') + + # infer langcode + if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl): + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) + elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl): + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) + else: + if k > 0: + break + else: + logger.error(f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}") + raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) + + src_dataset = self.load_data(prefix + src, src_dict, dataset_impl) + if truncate_source: + src_dataset = AppendTokenDataset( + TruncateDataset( + StripTokenDataset(src_dataset, src_dict.eos()), + max_source_positions - 1, + ), + src_dict.eos(), + ) + src_datasets.append(src_dataset) + tgt_datasets.append( + self.load_data(prefix + tgt, tgt_dict, dataset_impl) + ) + + logger.info('{} {} {}-{} {} examples'.format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + )) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) + + if len(src_datasets) == 1: + src_dataset, 
tgt_dataset = src_datasets[0], tgt_datasets[0] + else: + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + + if prepend_bos: + assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") + src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) + tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + + align_dataset = None + if load_alignments: + align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) + if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): + align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) + + return src_dataset, tgt_dataset, align_dataset + + def load_langpair_dataset( + self, + data_path, split, + src, src_dict, + tgt, tgt_dict, + combine, dataset_impl, upsample_primary, + left_pad_source, left_pad_target, max_source_positions, + max_target_positions, prepend_bos=False, load_alignments=False, + truncate_source=False, + src_dataset_transform_func=lambda dataset: dataset, + tgt_dataset_transform_func=lambda dataset: dataset, + src_lang_id=None, + tgt_lang_id=None, + langpairs_sharing_datasets=None, + ): + if langpairs_sharing_datasets is not None: + src_dataset = langpairs_sharing_datasets.get((data_path, split, src), 'NotInCache') + tgt_dataset = langpairs_sharing_datasets.get((data_path, split, tgt), 'NotInCache') + align_dataset = langpairs_sharing_datasets.get((data_path, split, src, tgt), 'NotInCache') + + # a hack: any one is not in cache, we need to reload them + if ( + langpairs_sharing_datasets is None + or src_dataset == 'NotInCache' + or tgt_dataset == 'NotInCache' + or align_dataset == 'NotInCache' + or split != 'train' + ): + # source and target datasets can be reused in reversed directions to save memory + # reversed directions of valid and test data will not share source and target datasets + 
src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset( + data_path, split, + src, src_dict, + tgt, tgt_dict, + combine, dataset_impl, upsample_primary, + max_source_positions=max_source_positions, + prepend_bos=prepend_bos, load_alignments=load_alignments, + truncate_source=truncate_source, + ) + src_dataset = src_dataset_transform_func(src_dataset) + tgt_dataset = tgt_dataset_transform_func(tgt_dataset) + if langpairs_sharing_datasets is not None: + langpairs_sharing_datasets[(data_path, split, src)] = src_dataset + langpairs_sharing_datasets[(data_path, split, tgt)] = tgt_dataset + langpairs_sharing_datasets[(data_path, split, src, tgt)] = align_dataset + + return LanguagePairDataset( + src_dataset, src_dataset.sizes, src_dict, + tgt_dataset, tgt_dataset.sizes, tgt_dict, + left_pad_source=left_pad_source, + left_pad_target=left_pad_target, + align_dataset=align_dataset, + src_lang_id=src_lang_id, + tgt_lang_id=tgt_lang_id, + ) + + def src_dataset_tranform_func(self, src_lang, tgt_lang, dataset, spec=None): + if self.args.lang_tok_replacing_bos_eos: + # it is handled by self.alter_dataset_langtok + # TODO: Unifiy with alter_dataset_langtok + return dataset + if spec is None: + return dataset + tok = self.get_encoder_langtok(src_lang, tgt_lang, spec) + if tok: + return PrependTokenDataset(dataset, tok) + return dataset + + def tgt_dataset_tranform_func(self, source_lang, target_lang, dataset, spec=None): + if self.args.lang_tok_replacing_bos_eos: + # TODO: Unifiy with alter_dataset_langtok + # It is handled by self.alter_dataset_langtok. + # The complication in self.alter_dataset_langtok + # makes a unified framework difficult. 
+ return dataset + # if not self.args.decoder_langtok: + if not spec: + return dataset + tok = self.get_decoder_langtok(target_lang, spec) + if tok: + return PrependTokenDataset(dataset, tok) + return dataset + + def alter_dataset_langtok(self, lang_pair_dataset, + src_eos=None, src_lang=None, + tgt_eos=None, tgt_lang=None, + src_langtok_spec=None, tgt_langtok_spec=None, + ): + if src_langtok_spec is None and tgt_langtok_spec is None: + return lang_pair_dataset + + new_src_eos = None + if src_langtok_spec is not None and src_eos is not None \ + and (src_lang is not None or tgt_lang is not None): + new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang, src_langtok_spec) + else: + src_eos = None + + new_tgt_bos = None + if tgt_langtok_spec and tgt_eos is not None and tgt_lang is not None: + new_tgt_bos = self.get_decoder_langtok(tgt_lang, tgt_langtok_spec) + else: + tgt_eos = None + + return TransformEosLangPairDataset( + lang_pair_dataset, + src_eos=src_eos, + new_src_eos=new_src_eos, + tgt_bos=tgt_eos, + new_tgt_bos=new_tgt_bos, + ) + + def load_a_dataset( + self, + split, + data_path, + src, src_dict, + tgt, tgt_dict, + combine, + prepend_bos=False, + langpairs_sharing_datasets=None, + data_category=None, + **extra_kwargs, + ): + dataset_impl = self.args.dataset_impl + upsample_primary = self.args.upsample_primary + left_pad_source = self.args.left_pad_source + left_pad_target = self.args.left_pad_target + max_source_positions = self.args.max_source_positions + max_target_positions = self.args.max_target_positions + load_alignments = self.args.load_alignments + truncate_source = self.args.truncate_source + src_dataset_transform_func = self.src_dataset_tranform_func + tgt_dataset_transform_func = self.tgt_dataset_tranform_func + enable_lang_ids = self.args.enable_lang_ids + lang_dictionary = self.lang_dict + src_langtok_spec, tgt_langtok_spec = extra_kwargs['langtok_spec'] + + src_langtok = self.get_encoder_langtok(src, tgt, src_langtok_spec) + tgt_langtok = 
self.get_decoder_langtok(tgt, tgt_langtok_spec) + logger.info(f'{data_category}:{src}-{tgt} src_langtok: {src_langtok}; tgt_langtok: {tgt_langtok}') + + langpair_ds = self.load_langpair_dataset( + data_path, split, + src, src_dict, + tgt, tgt_dict, + combine, dataset_impl, upsample_primary, + left_pad_source, left_pad_target, max_source_positions, + max_target_positions, prepend_bos, load_alignments, + truncate_source, + src_dataset_transform_func=lambda dataset: src_dataset_transform_func(src, tgt, dataset, src_langtok_spec), + tgt_dataset_transform_func=lambda dataset: tgt_dataset_transform_func(src, tgt, dataset, tgt_langtok_spec), + src_lang_id=_lang_id(lang_dictionary, src) if enable_lang_ids and lang_dictionary is not None else None, + tgt_lang_id=_lang_id(lang_dictionary, tgt) if enable_lang_ids and lang_dictionary is not None else None, + langpairs_sharing_datasets=langpairs_sharing_datasets, + ) + if langpair_ds.tgt_sizes is None: + # hack to use src_sizes as the sizes for the whole pair dataset for ConcatDataset + langpair_ds.sizes = langpair_ds.src_sizes + else: + # use the max of two sides to define the size to help max positions filtering + langpair_ds.sizes = np.vstack([langpair_ds.src_sizes, langpair_ds.tgt_sizes]).max(axis=0) + assert langpair_ds.sizes.shape == langpair_ds.src_sizes.shape + # TODO: handle modified lang toks for mined data and dae data + if self.args.lang_tok_replacing_bos_eos: + ds = self.alter_dataset_langtok( + langpair_ds, + src_eos=self.dicts[src if src else tgt].eos(), + src_lang=src, + tgt_eos=self.dicts[tgt].eos(), + tgt_lang=tgt, + src_langtok_spec=src_langtok_spec, + tgt_langtok_spec=tgt_langtok_spec, + ) + else: + ds = langpair_ds + return ds + + def load_split_langpair_datasets( + self, + split, + data_param_list, + ): + datasets = [] + langpairs_sharing_datasets = {} if self.args.enable_reservsed_directions_shared_datasets else None + for param in data_param_list: + ds = self.load_a_dataset(split=split, 
langpairs_sharing_datasets=langpairs_sharing_datasets, **param) + datasets.append(ds) + return datasets + + def get_data_paths_and_lang_pairs(self, split): + datapaths = { + 'main': self.args.data, + } + lang_pairs = { + 'main': self.lang_pairs + } + if split == 'train': + # only training data can have extra data and extra language pairs + if self.args.extra_data: + extra_datapaths = self.args.extra_data + datapaths.update(extra_datapaths) + if self.args.extra_lang_pairs: + extra_lang_pairs = {k: v.split(',') for k, v in self.args.extra_lang_pairs.items()} + lang_pairs.update(extra_lang_pairs) + return datapaths, lang_pairs + + def get_split_data_param_list(self, split, epoch, shard_epoch=None): + def get_epoch(epoch, shard_epoch): + return epoch if shard_epoch is None else shard_epoch + # TODO: to extend with extra datasets and keys and loop over different shard data paths + param_list = [] + data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) + logger.info(f'langtoks settings: {self.args.langtoks}') + for data_category, paths in data_paths.items(): + if data_category not in lang_pairs: + continue + # paths = self.args.data.split(os.pathsep) + paths = paths.split(os.pathsep) + assert len(paths) > 0 + if len(paths) > 1: + self._has_sharded_data = True + self._num_shards[data_category] = len(paths) + # epoch starts with 1 now: + data_path = paths[(get_epoch(epoch, shard_epoch) - 1) % len(paths)] + if data_category in self.args.langtoks: + lang_tok_spec = self.args.langtoks[data_category] + else: + # default to None + lang_tok_spec = (None, None) + + # infer langcode + lang_dirs = [lang_pair.split('-') for lang_pair in lang_pairs[data_category]] + lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] + for src, tgt in lang_dirs: + assert src is not None or data_category == 'mono_dae', (f'error: src={src}, ' + f'tgt={tgt} for data_category={data_category}') + # logger.info(f"preparing param for {data_category}: {src} - {tgt}") +
param_list.append( + { + 'key': f'{data_category}:{src}-{tgt}', + 'data_path': data_path, + 'split': split, + 'src': src, + 'src_dict': self.dicts[src] if src and data_category != 'mono_dae' else None, + 'tgt': tgt, + 'tgt_dict': self.dicts[tgt], + 'data_category': data_category, + 'langtok_spec': lang_tok_spec, + } + ) + return param_list + + def get_train_sampling_ratios(self, datasets, epoch=1): + data_sizes = [len(d) for _, d in datasets] + sampling_func = self.sampling_method.sampling_method_selector() + sample_ratios = sampling_func(data_sizes) if sampling_func is not None else None + return sample_ratios + + def get_sampling_ratios(self, data_param_list, datasets, epoch): + if self.args.sampling_weights_from_file: + weights = load_sampling_weights(self.args.sampling_weights_from_file) + sample_ratios = [weights[k] for k, _ in datasets] + logger.info('| ignoring --sampling-weights when loadding sampling weights ' + f'from file {self.args.sampling_weights_from_file}') + elif self.args.sampling_weights: + sample_ratios = [self.args.sampling_weights[k] for k, _ in datasets] + else: + # TODO: modify to provide sampling function more information other than sizes + sample_ratios = self.get_train_sampling_ratios(datasets, epoch) + + if sample_ratios is not None: + logger.info('| Upsample ratios: {}'.format( + list(zip(map(lambda x: x['key'], data_param_list), sample_ratios)) + )) + assert len(sample_ratios) == len(datasets) + return sample_ratios + + def load_split_datasets( + self, + split, + training, + epoch=1, combine=False, shard_epoch=None, **kwargs, + ): + data_param_list = self.get_split_data_param_list( + split, epoch, shard_epoch=shard_epoch, + ) + langpairs_sharing_datasets = {} if self.args.enable_reservsed_directions_shared_datasets else None + datasets = [ + ( + param['key'], + self.load_a_dataset( + combine=combine, + langpairs_sharing_datasets=langpairs_sharing_datasets, + **param + ), + ) + for param in data_param_list + ] + return datasets, 
data_param_list + + def load_into_sampled_multi_epoch_dataset( + self, split, datasets, data_param_list, + epoch, shard_epoch=None + ): + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiEpochDataset( + OrderedDict(datasets), + epoch=epoch, + shard_epoch=shard_epoch, + # valid and test datasets will be degerate to concating datasets: + sampling_ratios=sample_ratios, + eval_key=None, + batch_by_size=True, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + virtual_epoch_size=self.args.virtual_epoch_size, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), + ) + + def load_into_concat_dataset(self, split, datasets, data_param_list): + if self.args.lang_tok_replacing_bos_eos: + # TODO: to investigate why TransformEosLangPairDataset doesn't work with ConcatDataset + return SampledMultiDataset( + OrderedDict(datasets), + sampling_ratios=None, + eval_key=None, + batch_by_size=True, + collate_format=CollateFormat.single, + virtual_size=None, + split=split, + ) + return ConcatDataset([d for _, d in datasets]) + + def load_sampled_multi_epoch_dataset( + self, + split, + training, + epoch=0, combine=False, shard_epoch=None, **kwargs + ): + datasets, data_param_list = self.load_split_datasets( + split, training, + epoch, combine, shard_epoch=shard_epoch, **kwargs + ) + if training and split == 'train': + return self.load_into_sampled_multi_epoch_dataset( + split, datasets, data_param_list, epoch, shard_epoch=shard_epoch) + else: + return self.load_into_concat_dataset(split, datasets, data_param_list) diff --git a/fairseq/data/multilingual/sampling_method.py b/fairseq/data/multilingual/sampling_method.py new file mode 100644 index 0000000000..6a9d39f7a6 --- /dev/null +++ b/fairseq/data/multilingual/sampling_method.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List +import logging + + +logger = logging.getLogger(__name__) + + +def uniform(dataset_sizes: List[int]): + return [1.0] * len(dataset_sizes) + + +def temperature_sampling(dataset_sizes, temp): + total_size = sum(dataset_sizes) + return [(size / total_size) ** (1.0/temp) for size in dataset_sizes] + + +def make_temperature_sampling(temp=1.0): + def sampling_func(dataset_sizes): + return temperature_sampling(dataset_sizes, temp) + return sampling_func + + +def make_ratio_sampling(ratios): + def sampling_func(dataset_sizes): + return ratios + return sampling_func + + +class SamplingMethod: + @staticmethod + def add_arguments(parser): + parser.add_argument( + '--sampling-method', + choices=['uniform', 'temperature', 'concat', 'RoundRobin', ], + type=str, + default='concat', + help='The method to sample data per language pairs') + parser.add_argument('--sampling-temperature', default=1.5, type=float, + help='only work with --sampling-method temperature') + + @staticmethod + def build_sampler(args, task): + return SamplingMethod(args, task) + + def __init__(self, args, task): + self.args = args + self.task = task + + def is_adaptive(self): + return False + + def sampling_method_selector(self): + args = self.args + logger.info(f'selected sampler: {args.sampling_method}') + if args.sampling_method == 'uniform': + return uniform + elif args.sampling_method == 'temperature' or self.is_adaptive(): + return make_temperature_sampling(float(args.sampling_temperature)) + else: + # default to concating all data set together + return None From e52d071ee8b24fd371b5235938abeb1d0ae8e1ec Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Thu, 16 Jul 2020 09:32:44 -0700 Subject: [PATCH 062/707] Multilingual v1: Multilingual Training with multiple bitext and monolingual datasets: new multiligual task Summary: A first version of XLNMT 
multilingual project code release: Multilingual Training with multiple bitext - A new task to glue all things together: fairseq/tasks/translation_multi_simple_epoch.py - Minor changes to - fairseq/data/iterators.py to allow dynamic batch sampler - fairseq/checkpoint_utils.py to add finetuning option instead of using restore_file which will restore from original model when being requeued. Reviewed By: pipibjc Differential Revision: D22483484 fbshipit-source-id: 283b67e538508f330b0968609b7dae64d26bea05 --- fairseq/data/iterators.py | 20 +- .../multilingual/multilingual_data_manager.py | 10 +- fairseq/options.py | 12 + fairseq/tasks/fairseq_task.py | 4 + .../tasks/translation_multi_simple_epoch.py | 300 ++++++++++++++++++ fairseq_cli/train.py | 4 +- tests/test_binaries.py | 39 +++ 7 files changed, 380 insertions(+), 9 deletions(-) create mode 100644 fairseq/tasks/translation_multi_simple_epoch.py diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 95aed8f295..5f4a616c65 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -188,8 +188,10 @@ class EpochBatchIterator(EpochBatchIterating): Args: dataset (~torch.utils.data.Dataset): dataset from which to load the data collate_fn (callable): merges a list of samples to form a mini-batch - batch_sampler (~torch.utils.data.Sampler): an iterator over batches of - indices + batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of + indices, or a callable to create such an iterator (~torch.utils.data.Sampler). + A callable batch_sampler will be called for each epoch to enable per epoch dynamic + batch iterators defined by this callable batch_sampler. seed (int, optional): seed for random number generator for reproducibility (default: 1). 
num_shards (int, optional): shard the data iterator into N @@ -215,7 +217,8 @@ def __init__( assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset self.collate_fn = collate_fn - self.frozen_batches = tuple(batch_sampler) + self.batch_sampler = batch_sampler + self._frozen_batches = tuple(batch_sampler) if not callable(batch_sampler) else None self.seed = seed self.num_shards = num_shards self.shard_id = shard_id @@ -231,6 +234,12 @@ def __init__( self._next_epoch_itr = None self._supports_prefetch = getattr(dataset, 'supports_prefetch', False) + @property + def frozen_batches(self): + if self._frozen_batches is None: + self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) + return self._frozen_batches + def __len__(self): return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) @@ -259,14 +268,17 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): that :attr:`dataset` supports prefetching (default: False). """ self.epoch = self.next_epoch_idx + self.dataset.set_epoch(self.epoch) if self._next_epoch_itr is not None: self._cur_epoch_itr = self._next_epoch_itr self._next_epoch_itr = None else: + if callable(self.batch_sampler): + # reset _frozen_batches to refresh the next epoch + self._frozen_batches = None self._cur_epoch_itr = self._get_iterator_for_epoch( self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, ) - self.dataset.set_epoch(self.epoch) self.shuffle = shuffle return self._cur_epoch_itr diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index c313a7be6c..14695a2879 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -12,7 +12,7 @@ import json from fairseq import options -from fairseq.options import eval_str_dict, eval_str_list +from fairseq.options import eval_str_dict, csv_str_list from fairseq.data import ( Dictionary, 
@@ -123,9 +123,13 @@ def add_args(parser): e.g. {"mined": comma-separated-lang-pairs, "denoised": comma-separated-lang-pairs}', type=lambda uf: eval_str_dict(uf, type=str), default=None) - parser.add_argument('--langtoks-specs', help='a list of comma separated language tokens specifictions', + parser.add_argument('--langtoks-specs', + help='a list of comma separated data types that a set of language tokens to be specialized for, \ + e.g. "main,dae,mined". There will be a set of language tokens added to the vocab to \ + distinguish languages in different training data types. If not specified, default language \ + tokens per languages will be added', default='main', - type=lambda uf: eval_str_list(uf, type=str), + type=csv_str_list, ) parser.add_argument('--langtoks', help='a dictionary of how to add language tokens, \ e.g. {"mined": (None, "tgt"), "mono_dae": ("src.dae", "tgt"), "main": \ diff --git a/fairseq/options.py b/fairseq/options.py index 77af81c3e4..b4dbf3a7e5 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -60,6 +60,10 @@ def get_validation_parser(default_task=None): return parser +def csv_str_list(x): + return x.split(',') + + def eval_str_list(x, type=float): if x is None: return None @@ -71,6 +75,14 @@ def eval_str_list(x, type=float): return [type(x)] +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + def eval_bool(x, default=False): if x is None: return default diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index f58ccea8cc..d5578fcc32 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import warnings +import os import torch @@ -78,6 +79,9 @@ def setup_task(cls, args, **kwargs): """ return cls(args, **kwargs) + def has_sharded_data(self, split): + return (os.pathsep in getattr(self.args, 'data', '')) + def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split. diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py new file mode 100644 index 0000000000..d19e4f6222 --- /dev/null +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -0,0 +1,300 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import datetime +import time + +import torch +from fairseq.data import ( + data_utils, + FairseqDataset, + iterators, + LanguagePairDataset, + ListDataset, +) + +from fairseq.tasks import FairseqTask, register_task +from fairseq.data.multilingual.sampling_method import SamplingMethod +from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager + + +### +def get_time_gap(s, e): + return (datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)).__str__() +### + + +logger = logging.getLogger(__name__) + + +@register_task('translation_multi_simple_epoch') +class TranslationMultiSimpleEpochTask(FairseqTask): + """ + Translate from one (source) language to another (target) language. + + Args: + langs (List[str]): a list of languages that are being supported + dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries + training (bool): whether the task should be configured for training or not + + .. note:: + + The translation task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate` and :mod:`fairseq-interactive`. + + The translation task provides the following additional command-line + arguments: + + .. 
argparse:: + :ref: fairseq.tasks.translation_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', + help='inference source language') + parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', + help='inference target language') + parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', + help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr') + + SamplingMethod.add_arguments(parser) + MultilingualDatasetManager.add_args(parser) + # fmt: on + + def __init__(self, args, langs, dicts, training): + super().__init__(args) + self.langs = langs + self.dicts = dicts + self.training = training + if training: + self.lang_pairs = args.lang_pairs + else: + self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)] + # eval_lang_pairs for multilingual translation is usually all of the + # lang_pairs. However for other multitask settings or when we want to + # optimize for certain languages we want to use a different subset. Thus + # the eval_lang_pairs class variable is provided for classes that extend + # this class. + self.eval_lang_pairs = self.lang_pairs + # model_lang_pairs will be used to build encoder-decoder model pairs in + # models.build_model(). 
This allows multitask type of sub-class can + # build models other than the input lang_pairs + self.model_lang_pairs = self.lang_pairs + self.sampling_method = SamplingMethod.build_sampler(args, self) + self.data_manager = MultilingualDatasetManager.setup_data_manager( + args, self.lang_pairs, langs, dicts, self.sampling_method) + + @classmethod + def setup_task(cls, args, **kwargs): + langs, dicts, training = MultilingualDatasetManager.prepare(args, **kwargs) + return cls(args, langs, dicts, training) + + def has_sharded_data(self, split): + return self.data_manager.has_sharded_data(split) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if split in self.datasets: + dataset = self.datasets[split] + if self.has_sharded_data(split) and dataset.load_next_shard: + shard_epoch = dataset.shard_epoch + else: + # no need to load next shard so skip loading + # also this avoid always loading from beginning of the data + return + else: + shard_epoch = None + self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset( + split, + self.training, + epoch=epoch, combine=combine, shard_epoch=shard_epoch, **kwargs + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths): + src_data = ListDataset(src_tokens, src_lengths) + dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary) + src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main'] + if self.args.lang_tok_replacing_bos_eos: + dataset = self.data_manager.alter_dataset_langtok( + dataset, + src_eos=self.source_dictionary.eos(), + src_lang=self.args.source_lang, + tgt_eos=self.target_dictionary.eos(), + tgt_lang=self.args.target_lang, + src_langtok_spec=src_langtok_spec, + tgt_langtok_spec=tgt_langtok_spec, + ) + else: + dataset.src = self.data_manager.src_dataset_tranform_func( + self.args.source_lang, + self.args.target_lang, + 
dataset=dataset.src, + spec=src_langtok_spec, + ) + return dataset + + def build_model(self, args): + return super().build_model(args) + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + return loss, sample_size, logging_output + + def inference_step(self, generator, models, sample, prefix_tokens=None): + with torch.no_grad(): + _, tgt_langtok_spec = self.args.langtoks['main'] + if not self.args.lang_tok_replacing_bos_eos: + if prefix_tokens is None and tgt_langtok_spec: + tgt_lang_tok = self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec) + src_tokens = sample['net_input']['src_tokens'] + bsz = src_tokens.size(0) + prefix_tokens = torch.LongTensor( + [[tgt_lang_tok]] + ).expand(bsz, 1).to(src_tokens) + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + ) + else: + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + bos_token=self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec) + if tgt_langtok_spec else self.target_dictionary.eos(), + ) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + @property + def source_dictionary(self): + if self.training: + return next(iter(self.dicts.values())) + else: + return self.dicts[self.args.source_lang] + + @property + def target_dictionary(self): + if self.training: + return next(iter(self.dicts.values())) + else: + return self.dicts[self.args.target_lang] + + def create_batch_sampler_func( + self, max_positions, ignore_invalid_inputs, + max_tokens, max_sentences + ): + def construct_batch_sampler( + dataset, epoch + ): + splits = [s for s, _ in self.datasets.items() if self.datasets[s] == dataset] + split = 
splits[0] if len(splits) > 0 else None + + if epoch is not None: + dataset.set_epoch(epoch) + start_time = time.time() + # get indices ordered by example size + indices = dataset.ordered_indices() + logger.debug(f'[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}') + + # filter examples that are too large + if max_positions is not None: + my_time = time.time() + indices = data_utils.filter_by_size( + indices, dataset, max_positions, raise_exception=(not ignore_invalid_inputs), + ) + logger.debug(f'[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}') + + # create mini-batches with given size constraints + my_time = time.time() + batch_sampler = data_utils.batch_by_size( + indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences, + ) + logger.debug(f'[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}') + logger.debug(f'[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}') + return batch_sampler + return construct_batch_sampler + + # we need to override get_batch_iterator because we want to reset the epoch iterator each time + def get_batch_iterator( + self, dataset, max_tokens=None, max_sentences=None, max_positions=None, + ignore_invalid_inputs=False, required_batch_size_multiple=1, + seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, + ): + """ + Get an iterator that yields batches of data from the given dataset. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). 
+ required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 0). + Returns: + ~fairseq.iterators.EpochBatchIterator: a batched iterator over the + given dataset split + """ + # initialize the dataset with the correct starting epoch + assert isinstance(dataset, FairseqDataset) + if dataset in self.dataset_to_epoch_iter: + return self.dataset_to_epoch_iter[dataset] + if ( + self.args.sampling_method == 'RoundRobin' + ): + batch_iter = super().get_batch_iterator( + dataset, max_tokens=max_tokens, max_sentences=max_sentences, max_positions=max_positions, + ignore_invalid_inputs=ignore_invalid_inputs, required_batch_size_multiple=required_batch_size_multiple, + seed=seed, num_shards=num_shards, shard_id=shard_id, num_workers=num_workers, epoch=epoch, + ) + self.dataset_to_epoch_iter[dataset] = batch_iter + return batch_iter + + construct_batch_sampler = self.create_batch_sampler_func( + max_positions, ignore_invalid_inputs, + max_tokens, max_sentences) + + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=construct_batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + ) + return epoch_iter diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 2def237e49..a6b73088b3 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -10,7 +10,6 @@ import argparse import logging import math -import os import random import sys 
from typing import Callable, Optional @@ -130,6 +129,7 @@ def main( lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() + while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: # train for one epoch valid_losses, should_stop = train(args, trainer, task, epoch_itr) @@ -142,7 +142,7 @@ def main( epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch - load_dataset=(os.pathsep in getattr(args, "data", "")), + load_dataset=task.has_sharded_data('train'), ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index f5e53fd000..a6133b1b41 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -206,6 +206,45 @@ def test_multilingual_transformer(self): ] + enc_ltok_flag + dec_ltok_flag, ) + def test_translation_multi_simple_epoch(self): + # test with all combinations of encoder/decoder lang tokens + encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']] + decoder_langtok_flags = [[], ['--decoder-langtok']] + with contextlib.redirect_stdout(StringIO()): + for i in range(len(encoder_langtok_flags)): + for j in range(len(decoder_langtok_flags)): + enc_ltok_flag = encoder_langtok_flags[i] + dec_ltok_flag = decoder_langtok_flags[j] + with tempfile.TemporaryDirectory(f'test_translation_multi_simple_epoch_{i}_{j}') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + arch='transformer', + task='translation_multi_simple_epoch', + extra_flags=[ + '--encoder-layers', '2', + '--decoder-layers', '2', + '--encoder-embed-dim', '8', + '--decoder-embed-dim', '8', + '--sampling-method', 'temperature', + '--sampling-temperature', '1.5', + '--virtual-epoch-size', '1000', + ] + enc_ltok_flag + dec_ltok_flag, + lang_flags=['--lang-pairs', 'in-out,out-in'], + run_validation=True, + 
extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + '--task', 'translation_multi_simple_epoch', + '--lang-pairs', 'in-out,out-in', + '--source-lang', 'in', + '--target-lang', 'out', + ] + enc_ltok_flag + dec_ltok_flag, + ) + def test_transformer_cross_self_attention(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir: From 77df83ab6e07468dbc8b31a5b84766f7fedd250d Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 16 Jul 2020 09:58:43 -0700 Subject: [PATCH 063/707] Consolidate distributed init code into distributed_utils.call_main (#1218) Summary: We use `distributed_utils.call_main` in most of the other CLI tools (e.g., generate.py, eval_lm.py), but not train.py. The only place where they're different is that train.py supports TPUs and the `after_distributed_init_fn` hook. We can add that support to `distributed_utils.call_main` and merge them. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1218 Reviewed By: jhcross Differential Revision: D22556771 Pulled By: myleott fbshipit-source-id: 4f7110155f5f5d96905ef0bd17a4aa243ec8c443 --- fairseq/distributed_utils.py | 45 +++++++++++++--------- fairseq_cli/train.py | 74 ++---------------------------------- 2 files changed, 31 insertions(+), 88 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 62f70991e7..86ec49075e 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -27,7 +27,7 @@ def is_master(args): return args.distributed_rank == 0 -def infer_init_method(args): +def infer_init_method(args, force_distributed=False): if args.distributed_init_method is not None or getattr(args, 'tpu', False): return @@ -75,6 +75,12 @@ def infer_init_method(args): except FileNotFoundError: # Slurm is not installed pass + elif args.distributed_world_size > 1 or force_distributed: + # fallback for single node with 
multiple GPUs + assert args.distributed_world_size <= torch.cuda.device_count() + port = random.randint(10000, 20000) + args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + def distributed_init(args): if not getattr(args, 'tpu', False): @@ -132,14 +138,19 @@ def distributed_init(args): return args.distributed_rank -def _distributed_main(i, main, args, kwargs): +def distributed_main(i, main, args, kwargs): args.device_id = i - if torch.cuda.is_available() and not args.cpu: + if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): torch.cuda.set_device(args.device_id) if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = kwargs.get('start_rank', 0) + i + args.distributed_rank = kwargs.pop('start_rank', 0) + i args.distributed_rank = distributed_init(args) + + after_distributed_init_fn = kwargs.pop('after_distributed_init_fn', None) + if after_distributed_init_fn: + args = after_distributed_init_fn(args) + main(args, **kwargs) @@ -149,27 +160,27 @@ def call_main(args, main, **kwargs): if args.distributed_init_method is not None: # distributed main - if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: + if not args.distributed_no_spawn: start_rank = args.distributed_rank args.distributed_rank = None # assign automatically kwargs['start_rank'] = start_rank torch.multiprocessing.spawn( - fn=_distributed_main, + fn=distributed_main, args=(main, args, kwargs), - nprocs=torch.cuda.device_count(), + nprocs=min( + torch.cuda.device_count(), + args.distributed_world_size, + ), ) else: - _distributed_main(args.device_id, main, args, kwargs) - elif args.distributed_world_size > 1: - # fallback for single node with multiple GPUs - assert args.distributed_world_size <= torch.cuda.device_count() - port = random.randint(10000, 20000) - args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) - args.distributed_rank = None # set based on device id - 
torch.multiprocessing.spawn( - fn=_distributed_main, + distributed_main(args.device_id, main, args, kwargs) + elif getattr(args, "tpu", False): + import torch_xla.distributed.xla_multiprocessing as xmp + torch.multiprocessing.set_sharing_strategy("file_system") + xmp.spawn( + fn=distributed_main, args=(main, args, kwargs), - nprocs=args.distributed_world_size, + nprocs=8, # use all 8 TPU cores ) else: # single GPU main diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index a6b73088b3..88460c30a9 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -12,7 +12,6 @@ import math import random import sys -from typing import Callable, Optional import numpy as np import torch @@ -39,13 +38,7 @@ logger = logging.getLogger("fairseq_cli.train") -def main( - args, - init_distributed=False, - after_distributed_init_fn: Optional[ - Callable[[argparse.Namespace], argparse.Namespace] - ] = None, -): +def main(args): utils.import_user_module(args) assert ( @@ -53,15 +46,8 @@ def main( ), "Must specify batch size either with --max-tokens or --max-sentences" metrics.reset() - # Initialize CUDA and distributed training - if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): - torch.cuda.set_device(args.device_id) np.random.seed(args.seed) utils.set_torch_seed(args.seed) - if init_distributed: - args.distributed_rank = distributed_utils.distributed_init(args) - if after_distributed_init_fn: - args = after_distributed_init_fn(args) if distributed_utils.is_master(args): checkpoint_utils.verify_checkpoint_directory(args.save_dir) @@ -345,69 +331,15 @@ def get_valid_stats(args, trainer, stats): return stats -def distributed_main( - i, - args, - start_rank=0, - after_distributed_init_fn: Optional[ - Callable[[argparse.Namespace], argparse.Namespace] - ] = None, -): - args.device_id = i - if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = start_rank + i - main( - args, init_distributed=True, 
after_distributed_init_fn=after_distributed_init_fn - ) - - def cli_main(modify_parser=None): parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): - cli_main_helper(args) - else: - cli_main_helper(args) - - -def cli_main_helper(args): - if args.distributed_init_method is None: - distributed_utils.infer_init_method(args) - - if args.distributed_init_method is not None: - # distributed training - if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: - start_rank = args.distributed_rank - args.distributed_rank = None # assign automatically - torch.multiprocessing.spawn( - fn=distributed_main, - args=(args, start_rank), - nprocs=torch.cuda.device_count(), - ) - else: - distributed_main(args.device_id, args) - elif args.distributed_world_size > 1: - if not getattr(args, "tpu", False): - # fallback for single node with multiple GPUs - assert args.distributed_world_size <= torch.cuda.device_count() - port = random.randint(10000, 20000) - args.distributed_init_method = "tcp://localhost:{port}".format(port=port) - args.distributed_rank = None # set based on device id - torch.multiprocessing.spawn( - fn=distributed_main, args=(args,), nprocs=args.distributed_world_size - ) - else: - import torch_xla.distributed.xla_multiprocessing as xmp - - torch.multiprocessing.set_sharing_strategy("file_system") - xmp.spawn( - fn=distributed_main, args=(args,), nprocs=8 # use all 8 TPU cores - ) + distributed_utils.call_main(args, main) else: - # single GPU training - main(args) + distributed_utils.call_main(args, main) if __name__ == "__main__": From 75d354c92ba6d96d45a8d6fb6f28183817efe203 Mon Sep 17 00:00:00 2001 From: Duc Le Date: Thu, 16 Jul 2020 10:55:49 -0700 Subject: [PATCH 064/707] NNLM training in PySpeech Summary: Enable support for NNLM training in PySpeech. 
This implementation slightly modifies Fairseq's `LanguageModelingTask` in a few ways: 1. `source` and `input` used during training are slightly different (see `_maybe_add_bos` under `PySpeechLMDataset`). 2. The underlying model is `PySpeechEncoderModel` instead of `FairseqDecoder`. This lets us interface more easily with PySpeech, and the jitted model can easily be used in C++. Reviewed By: jay-mahadeokar Differential Revision: D22077479 fbshipit-source-id: 4918b26ba78de8786870060ada0bc3d3a28d64b0 --- fairseq/tasks/language_modeling.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 942477efb0..a4a98e07bc 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -106,12 +106,7 @@ def __init__(self, args, dictionary, output_dictionary=None, targets=None): self.targets = targets @classmethod - def setup_task(cls, args, **kwargs): - """Setup the task (e.g., load dictionaries). - - Args: - args (argparse.Namespace): parsed command-line arguments - """ + def setup_dictionary(cls, args, **kwargs): dictionary = None output_dictionary = None if args.data: @@ -124,6 +119,16 @@ def setup_task(cls, args, **kwargs): output_dictionary = TruncatedDictionary( dictionary, args.output_dictionary_size ) + return (dictionary, output_dictionary) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs) # upgrade old checkpoints if hasattr(args, "exclude_self_target"): @@ -197,17 +202,20 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): and self.args.sample_break_mode != "none" ) - self.datasets[split] = MonolingualDataset( - dataset, - dataset.sizes, - self.dictionary, - self.output_dictionary, + self.datasets[split] = self._initialize_dataset( + dataset=dataset, + sizes=dataset.sizes, + src_vocab=self.dictionary, + tgt_vocab=self.output_dictionary, add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True, targets=self.targets, add_bos_token=self.args.add_bos_token, ) + def _initialize_dataset(self, **kwargs): + return MonolingualDataset(**kwargs) + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): """ Generate batches for inference. We prepend an eos token to src_tokens From 3655cf266e32a2272d6deac6069a594977880084 Mon Sep 17 00:00:00 2001 From: James Cross Date: Thu, 16 Jul 2020 17:43:42 -0700 Subject: [PATCH 065/707] optional limit on total training time (#2333) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2333 This change adds a new option (`--stop-time-hours`) which if specified limits the total training time to that number of hours. In order to stop training within the inner training loop (after the first update exceeding the time limit) the starting time is stored on the trainer. In addition, in order to persist the training time when restoring from checkpoints (important because training runs are sometimes killed due to resource constraints), training time already completed is stored as extra state in the checkpoints (though this change is backward compatible with existing checkpoints).
Reviewed By: myleott Differential Revision: D22573166 fbshipit-source-id: 01c59274a1c196acc8a3a0243814167e1d368b1a --- fairseq/options.py | 2 ++ fairseq/trainer.py | 27 +++++++++++++++++++++++++-- fairseq_cli/train.py | 6 ++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/fairseq/options.py b/fairseq/options.py index b4dbf3a7e5..88c0389eba 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -455,6 +455,8 @@ def add_optimization_args(parser): help='force stop training at specified epoch') group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N', help='force stop training at specified update') + group.add_argument('--stop-time-hours', default=0, type=float, metavar='N', + help='force stop training after specified cumulative time (if >0)') group.add_argument('--clip-norm', default=0.0, type=float, metavar='NORM', help='clip threshold of gradients') group.add_argument('--sentence-avg', action='store_true', diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d456bceecf..a91d12fdc2 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -11,6 +11,7 @@ from itertools import chain import logging import sys +import time from typing import Any, Dict, List import torch @@ -110,6 +111,10 @@ def __init__(self, args, task, model, criterion, quantizer=None): metrics.log_start_time("wall", priority=790, round=0) + self._start_time = time.time() + self._previous_training_time = 0 + self._cumulative_training_time = None + def reinitialize(self): """Reinitialize the Trainer, typically after model params change.""" self._lr_scheduler = None @@ -218,6 +223,7 @@ def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" if self.is_data_parallel_master: # only save one checkpoint extra_state["metrics"] = metrics.state_dict() + extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( filename, self.args, @@ -291,6 +297,10 @@ def load_checkpoint( ) 
) + if "previous_training_time" in extra_state: + self._previous_training_time = extra_state["previous_training_time"] + self._start_time = time.time() + self.lr_step(epoch) if "metrics" in extra_state and not reset_meters: @@ -468,9 +478,11 @@ def maybe_no_sync(): # gather logging outputs from all replicas if self._sync_stats(): - logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs( - logging_outputs, sample_size, ooms, ignore=is_dummy_batch, + train_time = self._local_cumulative_training_time() + logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs( + logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, ) + self._cumulative_training_time = total_train_time / self.data_parallel_world_size overflow = False try: @@ -716,6 +728,17 @@ def set_num_updates(self, num_updates): def clip_grad_norm(self, clip_norm): return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None) + def cumulative_training_time(self): + if self._cumulative_training_time is None: + # single GPU + return self._local_cumulative_training_time() + else: + return self._cumulative_training_time + + def _local_cumulative_training_time(self): + """Aggregate training time in seconds.""" + return time.time() - self._start_time + self._previous_training_time + def _prepare_sample(self, sample): if sample == "DUMMY": raise Exception( diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 88460c30a9..770cbb97e5 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -229,6 +229,12 @@ def train(args, trainer, task, epoch_itr): valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) + + if args.stop_time_hours > 0: + elapsed_hours = trainer.cumulative_training_time() / (60 * 60) + if elapsed_hours > args.stop_time_hours: + should_stop = True + if should_stop: break From 7ea5e3b341d9ccf0005ba7695f1b9c39a11a145e Mon Sep 17 00:00:00 2001 From: Stanislau Hlebik 
Date: Fri, 17 Jul 2020 17:07:23 -0700 Subject: [PATCH 066/707] remediation of S205607 fbshipit-source-id: 5113fe0c527595e4227ff827253b7414abbdf7ac --- fairseq/data/audio/__init__.py | 1 + fairseq/logging/__init__.py | 1 + fairseq/modules/quantization/__init__.py | 1 + fairseq_cli/__init__.py | 1 + scripts/__init__.py | 1 + tests/__init__.py | 1 + tests/gpu/__init__.py | 1 + tests/speech_recognition/__init__.py | 1 + 8 files changed, 8 insertions(+) diff --git a/fairseq/data/audio/__init__.py b/fairseq/data/audio/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/fairseq/data/audio/__init__.py +++ b/fairseq/data/audio/__init__.py @@ -0,0 +1 @@ +S205607 diff --git a/fairseq/logging/__init__.py b/fairseq/logging/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/fairseq/logging/__init__.py +++ b/fairseq/logging/__init__.py @@ -0,0 +1 @@ +S205607 diff --git a/fairseq/modules/quantization/__init__.py b/fairseq/modules/quantization/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/fairseq/modules/quantization/__init__.py +++ b/fairseq/modules/quantization/__init__.py @@ -0,0 +1 @@ +S205607 diff --git a/fairseq_cli/__init__.py b/fairseq_cli/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/fairseq_cli/__init__.py +++ b/fairseq_cli/__init__.py @@ -0,0 +1 @@ +S205607 diff --git a/scripts/__init__.py b/scripts/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/scripts/__init__.py +++ b/scripts/__init__.py @@ -0,0 +1 @@ +S205607 diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +S205607 diff --git a/tests/gpu/__init__.py b/tests/gpu/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/tests/gpu/__init__.py +++ b/tests/gpu/__init__.py @@ -0,0 +1 @@ +S205607 diff --git a/tests/speech_recognition/__init__.py b/tests/speech_recognition/__init__.py index e69de29bb2..56de9c5ee1 100644 --- a/tests/speech_recognition/__init__.py +++ 
b/tests/speech_recognition/__init__.py @@ -0,0 +1 @@ +S205607 From 698e3b91ffa832c286c48035bdff78238b0de8ae Mon Sep 17 00:00:00 2001 From: Stanislau Hlebik Date: Fri, 17 Jul 2020 17:07:23 -0700 Subject: [PATCH 067/707] remediation of S205607 fbshipit-source-id: 798decc90db4f13770e97cdce3c0df7d5421b2a3 --- fairseq/data/audio/__init__.py | 1 - fairseq/logging/__init__.py | 1 - fairseq/modules/quantization/__init__.py | 1 - fairseq_cli/__init__.py | 1 - scripts/__init__.py | 1 - tests/__init__.py | 1 - tests/gpu/__init__.py | 1 - tests/speech_recognition/__init__.py | 1 - 8 files changed, 8 deletions(-) diff --git a/fairseq/data/audio/__init__.py b/fairseq/data/audio/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/fairseq/data/audio/__init__.py +++ b/fairseq/data/audio/__init__.py @@ -1 +0,0 @@ -S205607 diff --git a/fairseq/logging/__init__.py b/fairseq/logging/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/fairseq/logging/__init__.py +++ b/fairseq/logging/__init__.py @@ -1 +0,0 @@ -S205607 diff --git a/fairseq/modules/quantization/__init__.py b/fairseq/modules/quantization/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/fairseq/modules/quantization/__init__.py +++ b/fairseq/modules/quantization/__init__.py @@ -1 +0,0 @@ -S205607 diff --git a/fairseq_cli/__init__.py b/fairseq_cli/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/fairseq_cli/__init__.py +++ b/fairseq_cli/__init__.py @@ -1 +0,0 @@ -S205607 diff --git a/scripts/__init__.py b/scripts/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/scripts/__init__.py +++ b/scripts/__init__.py @@ -1 +0,0 @@ -S205607 diff --git a/tests/__init__.py b/tests/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +0,0 @@ -S205607 diff --git a/tests/gpu/__init__.py b/tests/gpu/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/tests/gpu/__init__.py +++ b/tests/gpu/__init__.py @@ -1 +0,0 @@ -S205607 diff --git a/tests/speech_recognition/__init__.py 
b/tests/speech_recognition/__init__.py index 56de9c5ee1..e69de29bb2 100644 --- a/tests/speech_recognition/__init__.py +++ b/tests/speech_recognition/__init__.py @@ -1 +0,0 @@ -S205607 From 93f5128509278f425afb6bcf0da574c0af0e0c16 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 20 Jul 2020 08:24:30 -0700 Subject: [PATCH 068/707] Misc fixes (#2342) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2342 Reviewed By: ngoyal2707 Differential Revision: D22601110 Pulled By: myleott fbshipit-source-id: 7a704c07d507692f274c31ec74b090134fa9dee3 --- examples/translation_moe/score.py | 9 ++++++--- fairseq/distributed_utils.py | 4 +++- fairseq/models/bart/hub_interface.py | 2 +- fairseq/optim/fairseq_optimizer.py | 12 ++++++++---- fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py | 9 ++++++--- fairseq/tasks/translation_lev.py | 7 +++++++ fairseq/utils.py | 9 +++++---- 7 files changed, 36 insertions(+), 16 deletions(-) diff --git a/examples/translation_moe/score.py b/examples/translation_moe/score.py index 8e207093db..b68cc828a7 100644 --- a/examples/translation_moe/score.py +++ b/examples/translation_moe/score.py @@ -55,13 +55,16 @@ def load_sys(paths): with open(path) as f: for line in f: line = line.rstrip() - if line.startswith(('S-', 'T-', 'H-')): + # S: source + # T: target + # D: detokenized system output + if line.startswith(('S-', 'T-', 'D-')): i = int(line[line.find('-')+1:line.find('\t')]) if line.startswith('S-'): src[i] = line.split('\t')[1] if line.startswith('T-'): tgt[i] = line.split('\t')[1] - if line.startswith('H-'): + if line.startswith('D-'): if i not in hypos: hypos[i] = [] log_probs[i] = [] @@ -115,7 +118,7 @@ def sentence_bleu(hypothesis, reference): bleu = compute_bleu( bleu.counts, bleu.totals, bleu.sys_len, bleu.ref_len, - smooth='exp', smooth_floor=0.0, + smooth_method='exp', ) return bleu.score diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 86ec49075e..7ee89adce9 100644 --- 
a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -38,6 +38,8 @@ def infer_init_method(args, force_distributed=False): args.distributed_init_method = 'env://' args.distributed_world_size = int(os.environ['WORLD_SIZE']) args.distributed_rank = int(os.environ['RANK']) + # processes are created by torch.distributed.launch + args.distributed_no_spawn = True # we can determine the init method automatically for Slurm elif args.distributed_port > 0: @@ -159,7 +161,7 @@ def call_main(args, main, **kwargs): infer_init_method(args) if args.distributed_init_method is not None: - # distributed main + # distributed training if not args.distributed_no_spawn: start_rank = args.distributed_rank args.distributed_rank = None # assign automatically diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index f87291bfbd..48c59cb91d 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -23,7 +23,7 @@ class BARTHubInterface(nn.Module): """A simple PyTorch Hub interface to BART. 
- Usage: https://github.com/pytorch/fairseq/tree/master/examples/BART + Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart """ def __init__(self, args, task, model): diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 3242a92a35..b1b9c76edb 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -41,20 +41,24 @@ def optimizer_config(self): @property def params(self): """Return an iterable of the parameters held by the optimizer.""" - for param_group in self.optimizer.param_groups: + for param_group in self.param_groups: for p in param_group['params']: yield p + @property + def param_groups(self): + return self.optimizer.param_groups + def __getstate__(self): return self._optimizer.__getstate__() def get_lr(self): """Return the current learning rate.""" - return self.optimizer.param_groups[0]['lr'] + return self.param_groups[0]['lr'] def set_lr(self, lr): """Set the learning rate.""" - for param_group in self.optimizer.param_groups: + for param_group in self.param_groups: param_group['lr'] = lr def state_dict(self): @@ -73,7 +77,7 @@ def load_state_dict(self, state_dict, optimizer_overrides=None): if optimizer_overrides is not None and len(optimizer_overrides) > 0: # override learning rate, momentum, etc. with latest values - for group in self.optimizer.param_groups: + for group in self.param_groups: group.update(optimizer_overrides) def backward(self, loss): diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index ef9ef26195..431e784de6 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -34,7 +34,7 @@ def __init__(self, args, optimizer): ' Consider --lr-scheduler=fixed instead.' 
) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer.optimizer, patience=0, factor=args.lr_shrink, + self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink, threshold=args.lr_threshold) warmup_end_lr = args.lr[0] # if no warm up, sets initial lr to be args.lr[0] @@ -59,8 +59,11 @@ def add_args(parser): parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', help='shrink factor for annealing, lr_new = (lr * lr_shrink)') parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT', - help='Threshold for measuring the new optimum, \ - to only focus on significant changes') + help='threshold for measuring the new optimum, ' + 'to only focus on significant changes') + parser.add_argument('--lr-patience', default=0, type=int, + help='number of epochs with no improvement after which ' + 'learning rate will be reduced') parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', help='warmup the learning rate linearly for the first N updates') parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index d07c271569..845dd81644 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -7,6 +7,8 @@ import torch +from fairseq.data import LanguagePairDataset + from fairseq.utils import new_arange from fairseq.tasks import register_task from fairseq.tasks.translation import TranslationTask, load_langpair_dataset @@ -139,6 +141,11 @@ def build_generator(self, models, args): adaptive=not getattr(args, 'iter_decode_force_max_iter', False), retain_history=getattr(args, 'retain_iter_history', False)) + def build_dataset_for_inference(self, src_tokens, src_lengths): + return LanguagePairDataset( + src_tokens, src_lengths, self.source_dictionary, append_bos=True + ) + def train_step(self, sample, model, diff --git a/fairseq/utils.py b/fairseq/utils.py index 
739ba49f9e..d4ed89e357 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -297,10 +297,11 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if multi_tensor_l2norm_available: total_norm = multi_tensor_total_norm(grads) else: - warnings.warn( - "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " - "you may get better performance by installing NVIDIA's apex library" - ) + if torch.cuda.is_available(): + warnings.warn( + "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " + "you may get better performance by installing NVIDIA's apex library" + ) total_norm = torch.norm( torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in grads]) ) From c1e734b2dd7024044c8dee551620146e4f872ad4 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Mon, 20 Jul 2020 19:57:29 -0700 Subject: [PATCH 069/707] Ensure checkpoints are properly saved when hitting stop time limit Summary: Context: https://fburl.com/tasks/3nhdm1rv and https://fburl.com/test/4i6icfbd This should fix test_e2e_base_training_wo_prepare_data breakage due to D22598579: when stop time limit https://fburl.com/diffusion/rrp4jst5 gets triggered, a checkpoint isn't saved as it should in https://fburl.com/diffusion/hbn0c6o5. The error isn't related to Manifold usage in export_ensemble - that's just the first place where we notice that no checkpoint was written. So for toy tests that finish quickly, it's possible for training to end before any checkpoint has been saved. It was affecting dev-nosan but not opt probably because opt training finishes quickly enough to not hit the stop time limit? 
Reviewed By: akinh Differential Revision: D22639199 fbshipit-source-id: ec4da15bb14e14c2066af6946d7a34db333178eb --- fairseq_cli/train.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 770cbb97e5..4184210ae7 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -128,7 +128,7 @@ def main(args): epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch - load_dataset=task.has_sharded_data('train'), + load_dataset=task.has_sharded_data("train"), ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) @@ -210,7 +210,9 @@ def train(args, trainer, task, epoch_itr): valid_subsets = args.valid_subset.split(",") should_stop = False for i, samples in enumerate(progress): - with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function("train_step-%d" % i): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( + "train_step-%d" % i + ): log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... 
continue @@ -230,11 +232,6 @@ def train(args, trainer, task, epoch_itr): args, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) - if args.stop_time_hours > 0: - elapsed_hours = trainer.cumulative_training_time() / (60 * 60) - if elapsed_hours > args.stop_time_hours: - should_stop = True - if should_stop: break @@ -270,6 +267,10 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc should_stop = ( should_stop_early(args, valid_losses[0]) or trainer.get_num_updates() >= max_update + or ( + args.stop_time_hours > 0 + and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours + ) ) # Save checkpoint @@ -294,7 +295,7 @@ def validate(args, trainer, task, epoch_itr, subsets): valid_losses = [] for subset in subsets: - logger.info("begin validation on \"{}\" subset".format(subset)) + logger.info('begin validation on "{}" subset'.format(subset)) # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) From 84e32315af7b4a479450a1e65b9e277226d75576 Mon Sep 17 00:00:00 2001 From: Jun Ru Anderson Date: Thu, 23 Jul 2020 17:05:40 -0700 Subject: [PATCH 070/707] =?UTF-8?q?Move=20DynamicLossScaler=20into=20its?= =?UTF-8?q?=20own=20file=20and=20separate=20opti=E2=80=A6=20(#1221)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …mizer and scaling logic # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Separates (some of) scaling logic from optimizer logic. This increases readability in its own right and also makes the addition of new scalers more straightforward ## PR review Anyone in the community is free to review the PR once the tests have passed. 
If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1221 Reviewed By: myleott Differential Revision: D22696726 Pulled By: andersonic fbshipit-source-id: c3a19184de0f17e75766894286c20feacdb3e010 --- fairseq/optim/dynamic_loss_scaler.py | 63 ++++++++++++++ fairseq/optim/fp16_optimizer.py | 119 +++++++-------------------- 2 files changed, 93 insertions(+), 89 deletions(-) create mode 100644 fairseq/optim/dynamic_loss_scaler.py diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py new file mode 100644 index 0000000000..9d1f0b2c05 --- /dev/null +++ b/fairseq/optim/dynamic_loss_scaler.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +class DynamicLossScaler(object): + + def __init__( + self, init_scale=2.**15, scale_factor=2., scale_window=2000, + tolerance=0.05, threshold=None, min_loss_scale=1e-4 + ): + self.loss_scale = init_scale + self.scale_factor = scale_factor + self.scale_window = scale_window + self.tolerance = tolerance + self.threshold = threshold + self._iter = 0 + self._last_overflow_iter = -1 + self._last_rescale_iter = -1 + self._overflows_since_rescale = 0 + self.min_loss_scale = min_loss_scale + + def scale(self, outputs): + return self.loss_scale * outputs + + def update(self): + if (self._iter - self._last_overflow_iter) % self.scale_window == 0: + self.loss_scale *= self.scale_factor + self._last_rescale_iter = self._iter + self._iter += 1 + + def _decrease_loss_scale(self): + self.loss_scale /= self.scale_factor + if self.threshold is not None: + self.loss_scale = max(self.loss_scale, self.threshold) + + def check_overflow(self, grad_norm): + # detect inf and nan + if grad_norm == float('inf') or
grad_norm != grad_norm: + # overflow has occured + prev_scale = self.loss_scale + iter_since_rescale = self._iter - self._last_rescale_iter + + self._last_overflow_iter = self._iter + self._overflows_since_rescale += 1 + pct_overflow = self._overflows_since_rescale / float(iter_since_rescale) + if pct_overflow >= self.tolerance: + self._decrease_loss_scale() + self._last_rescale_iter = self._iter + self._overflows_since_rescale = 0 + + if self.loss_scale <= self.min_loss_scale: + # Use FloatingPointError as an uncommon error that parent + # functions can safely catch to stop training. + self.loss_scale = prev_scale + raise FloatingPointError(( + 'Minimum loss scale reached ({}). Your loss is probably exploding. ' + 'Try lowering the learning rate, using gradient clipping or ' + 'increasing the batch size.' + ).format(self.min_loss_scale)) + + self._iter += 1 + raise OverflowError('setting loss scale to: ' + str(self.loss_scale)) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 58e8e9e0de..37e94965bb 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -9,49 +9,7 @@ from fairseq import optim, utils - -class DynamicLossScaler(object): - - def __init__( - self, init_scale=2.**15, scale_factor=2., scale_window=2000, - tolerance=0.05, threshold=None, - ): - self.loss_scale = init_scale - self.scale_factor = scale_factor - self.scale_window = scale_window - self.tolerance = tolerance - self.threshold = threshold - self._iter = 0 - self._last_overflow_iter = -1 - self._last_rescale_iter = -1 - self._overflows_since_rescale = 0 - - def update_scale(self, overflow): - iter_since_rescale = self._iter - self._last_rescale_iter - if overflow: - self._last_overflow_iter = self._iter - self._overflows_since_rescale += 1 - pct_overflow = self._overflows_since_rescale / float(iter_since_rescale) - if pct_overflow >= self.tolerance: - self._decrease_loss_scale() - self._last_rescale_iter = self._iter - 
self._overflows_since_rescale = 0 - elif (self._iter - self._last_overflow_iter) % self.scale_window == 0: - self.loss_scale *= self.scale_factor - self._last_rescale_iter = self._iter - self._iter += 1 - - def _decrease_loss_scale(self): - self.loss_scale /= self.scale_factor - if self.threshold is not None: - self.loss_scale = max(self.loss_scale, self.threshold) - - @staticmethod - def has_overflow(grad_norm): - # detect inf and nan - if grad_norm == float('inf') or grad_norm != grad_norm: - return True - return False +from .dynamic_loss_scaler import DynamicLossScaler class _FP16OptimizerMixin(object): @@ -113,7 +71,7 @@ def backward(self, loss): underflow. """ if self.scaler is not None: - loss = loss * self.scaler.loss_scale + loss = self.scaler.scale(loss) loss.backward() self._needs_sync = True @@ -146,6 +104,22 @@ def _sync_fp16_grads_to_fp32(self, multiply_grads=1.): self._needs_sync = False + def _sync_fp32_grads_to_fp16(self): + # copy FP32 params back into FP16 model + if self.has_flat_params: + offset = 0 + for p in self.fp16_params: + if not p.requires_grad: + continue + numel = p.data.numel() + p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + p.data.copy_(p32.data) + def multiply_grads(self, c): """Multiplies grads by a constant ``c``.""" if self._needs_sync: @@ -163,20 +137,7 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): # detect overflow and adjust loss scale if self.scaler is not None: - overflow = DynamicLossScaler.has_overflow(grad_norm) - prev_scale = self.scaler.loss_scale - self.scaler.update_scale(overflow) - if overflow: - if self.scaler.loss_scale <= self.min_loss_scale: - # Use FloatingPointError as an uncommon error that parent - # functions can safely catch to stop training. - self.scaler.loss_scale = prev_scale - raise FloatingPointError(( - 'Minimum loss scale reached ({}). 
Your loss is probably exploding. ' - 'Try lowering the learning rate, using gradient clipping or ' - 'increasing the batch size.' - ).format(self.min_loss_scale)) - raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) + self.scaler.check_overflow(grad_norm) return grad_norm @@ -185,20 +146,10 @@ def step(self, closure=None): self._sync_fp16_grads_to_fp32() self.fp32_optimizer.step(closure) - # copy FP32 params back into FP16 model - if self.has_flat_params: - offset = 0 - for p in self.fp16_params: - if not p.requires_grad: - continue - numel = p.data.numel() - p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) - offset += numel - else: - for p, p32 in zip(self.fp16_params, self.fp32_params): - if not p.requires_grad: - continue - p.data.copy_(p32.data) + if self.scaler is not None: + self.scaler.update() + + self._sync_fp32_grads_to_fp16() def zero_grad(self): """Clears the gradients of all optimized parameters.""" @@ -240,8 +191,8 @@ def __init__(self, args, params, fp32_optimizer, fp32_params): scale_window=scale_window, tolerance=args.fp16_scale_tolerance, threshold=args.threshold_loss_scale, + min_loss_scale=args.min_loss_scale ) - self.min_loss_scale = self.args.min_loss_scale else: # disable loss scaling for bfloat16 self.scaler = None @@ -340,7 +291,7 @@ def backward(self, loss): underflow. """ if self.scaler is not None: - loss = loss * self.scaler.loss_scale + loss = self.scaler.scale(loss) loss.backward() def _unscale_grads(self): @@ -363,20 +314,7 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): self._multiply_factor *= max_norm / grad_norm_cpu # detect overflow and adjust loss scale - overflow = DynamicLossScaler.has_overflow(grad_norm_cpu) - prev_scale = self.scaler.loss_scale - self.scaler.update_scale(overflow) - if overflow: - if self.scaler.loss_scale <= self.min_loss_scale: - # Use FloatingPointError as an uncommon error that parent - # functions can safely catch to stop training. 
- self.scaler.loss_scale = prev_scale - raise FloatingPointError(( - 'Minimum loss scale reached ({}). Your loss is probably exploding. ' - 'Try lowering the learning rate, using gradient clipping or ' - 'increasing the batch size.' - ).format(self.min_loss_scale)) - raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) + self.scaler.check_overflow(grad_norm_cpu) else: clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) self._multiply_factor *= clip_coef @@ -392,6 +330,9 @@ def step(self, closure=None): self._unscale_grads() self.wrapped_optimizer.step(closure) + if self.scaler is not None: + self.scaler.update() + def zero_grad(self): """Clears the gradients of all optimized parameters.""" self.wrapped_optimizer.zero_grad() @@ -441,8 +382,8 @@ def __init__(self, args, params, optimizer): scale_window=scale_window, tolerance=args.fp16_scale_tolerance, threshold=args.threshold_loss_scale, + min_loss_scale=args.min_loss_scale ) - self.min_loss_scale = self.args.min_loss_scale else: # disable loss scaling for bfloat16 self.scaler = None From f448f36462a77036b814e740a44990f8c3ba8760 Mon Sep 17 00:00:00 2001 From: Chau Tran Date: Fri, 24 Jul 2020 16:50:00 -0700 Subject: [PATCH 071/707] fix sentencepiece vocab processor behavior with unknown token Summary: We don't want to use INVALID_ID=-1 for tokens not found in dictionary, use UNKNOWN=3 instead Reviewed By: cndn Differential Revision: D22480966 fbshipit-source-id: 79b232e70efe7a0336a149ae494b24590afe5ea0 --- fairseq/tasks/translation_from_pretrained_bart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index 181ad3c6ff..2b7d589cee 100644 --- a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -6,6 +6,7 @@ import torch from fairseq.data import LanguagePairDataset +from fairseq import utils from .translation 
import load_langpair_dataset, TranslationTask from . import register_task @@ -63,7 +64,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = self.args.data.split(':') + paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] From 4f618a758ccd6b1924508ccbfb32eaacc3ea11c5 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 24 Jul 2020 22:28:34 -0700 Subject: [PATCH 072/707] A more general solution to strip symbols from generation output Summary: An attempt to get more general output stripping in generator. Context: In the mBART pull request: https://github.com/fairinternal/fairseq-py/commit/7fbd17a3a67e1e4950855d646d6f4dd9db76bc63, eos was introduced in sequence_generator.py for removing the lang tokens. However it only serves the special case where eos is modified to be the language token ID. Reviewed By: myleott Differential Revision: D22668233 fbshipit-source-id: 9bdaaefb28508cb74d8b2f3bd99160646b442959 --- fairseq/sequence_generator.py | 4 ++++ fairseq/sequence_scorer.py | 8 +++++++- fairseq/tasks/fairseq_task.py | 11 ++++++----- fairseq_cli/generate.py | 15 +++++++++------ fairseq_cli/interactive.py | 2 ++ 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 70a1c93a88..ed65bd86b6 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -32,6 +32,7 @@ def __init__( no_repeat_ngram_size=0, search_strategy=None, eos=None, + symbols_to_strip_from_output=None, ): """Generates translations of a given source sentence. 
@@ -63,6 +64,9 @@ def __init__( self.pad = tgt_dict.pad() self.unk = tgt_dict.unk() self.eos = tgt_dict.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None else {self.eos}) self.vocab_size = len(tgt_dict) self.beam_size = beam_size # the max beam size is the dictionary size - 1, since we never select pad diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py index deaf86f3ce..343c29acc2 100644 --- a/fairseq/sequence_scorer.py +++ b/fairseq/sequence_scorer.py @@ -12,12 +12,18 @@ class SequenceScorer(object): """Scores the target for a given source sentence.""" - def __init__(self, tgt_dict, softmax_batch=None, compute_alignment=False, eos=None): + def __init__( + self, tgt_dict, softmax_batch=None, compute_alignment=False, eos=None, + symbols_to_strip_from_output=None, + ): self.pad = tgt_dict.pad() self.eos = tgt_dict.eos() if eos is None else eos self.softmax_batch = softmax_batch or sys.maxsize assert self.softmax_batch > 0 self.compute_alignment = compute_alignment + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None else {self.eos}) @torch.no_grad() def generate(self, models, sample, **kwargs): diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index d5578fcc32..ca187b73d8 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -232,7 +232,7 @@ def build_criterion(self, args): return criterions.build_criterion(args, self) - def build_generator(self, models, args): + def build_generator(self, models, args, seq_gen_cls=None): if getattr(args, "score_reference", False): from fairseq.sequence_scorer import SequenceScorer @@ -296,10 +296,11 @@ def build_generator(self, models, args): else: search_strategy = search.BeamSearch(self.target_dictionary) - if getattr(args, "print_alignment", False): - seq_gen_cls = 
SequenceGeneratorWithAlignment - else: - seq_gen_cls = SequenceGenerator + if seq_gen_cls is None: + if getattr(args, "print_alignment", False): + seq_gen_cls = SequenceGeneratorWithAlignment + else: + seq_gen_cls = SequenceGenerator return seq_gen_cls( models, diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index ef259b6f2d..9f67c0271d 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -38,6 +38,13 @@ def main(args): return _main(args, sys.stdout) +def get_symbols_to_strip_from_output(generator): + if hasattr(generator, 'symbols_to_strip_from_output'): + return generator.symbols_to_strip_from_output + else: + return {generator.eos} + + def _main(args, output_file): logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', @@ -174,9 +181,7 @@ def decode_fn(x): target_tokens, args.remove_bpe, escape_unk=True, - extra_symbols_to_ignore={ - generator.eos, - } + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) src_str = decode_fn(src_str) @@ -198,9 +203,7 @@ def decode_fn(x): align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, - extra_symbols_to_ignore={ - generator.eos, - } + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) if not args.quiet: diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 24e9630d44..2258f18326 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -20,6 +20,7 @@ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders +from .generate import get_symbols_to_strip_from_output logging.basicConfig( @@ -186,6 +187,7 @@ def decode_fn(x): align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) score = hypo['score'] / math.log(2) # convert to base 2 From 
108bb2560b1ec01524ba723bc7c69186875afa0a Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 27 Jul 2020 18:09:22 -0700 Subject: [PATCH 073/707] Multilingual v1: [improvement] Reset size cache per epoch in sampled_multi_epoch_dataset.py to save memory Summary: Reset size cache per epoch in sampled_multi_epoch_dataset.py to save memory Reviewed By: akinh Differential Revision: D22754902 fbshipit-source-id: 1001cf37f0f47a90ffd10295b48c3e5a77283bc8 --- fairseq/data/multilingual/sampled_multi_epoch_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py index 65660407bd..ff7f7fa18b 100644 --- a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -237,3 +237,4 @@ def _next_virtual_epoch(self, epoch): self._epoch_sizes = None self._epoch_ordered_indices = None self._current_epoch_start_index = index + self._size_cache = {} From df5bf427348f8a9ccd8dbfbd70362648b82b6ae5 Mon Sep 17 00:00:00 2001 From: Aditya Pillai Date: Tue, 28 Jul 2020 14:53:45 -0700 Subject: [PATCH 074/707] Create safe num_updates access on retrying failed training Summary: The num_updates reference is unsafe since it is only defined in the for-loop without any assertion that the loop will execute at least once. We could either try to add this assertion to make it more clear as to what is happening, or we could set num_updates to some default value if the loop is never executed. I'm open to doing either one. 
Reviewed By: chtran Differential Revision: D22598029 fbshipit-source-id: 4d20b6a415a783fe8c450f3998b7767861da8c06 --- fairseq_cli/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 4184210ae7..750a54bfee 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -209,6 +209,7 @@ def train(args, trainer, task, epoch_itr): valid_subsets = args.valid_subset.split(",") should_stop = False + num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i From 578ee3d456034f12af13e446bba9447844374356 Mon Sep 17 00:00:00 2001 From: Aditya Pillai Date: Tue, 28 Jul 2020 14:53:45 -0700 Subject: [PATCH 075/707] Use cls.load_dictionary for multilingual data manager Summary: In order to dynamically change the dictionary loaded, we extend a class and override the `load_dictionary` class function. To preserve this behavior with `translation_multi_simple_epoch`, we can pass in the `load_dictionary` function from the task to the multilingual data manager. Once the dictionary is loaded and instantiated, there are only non-static class calls, which go through the custom dictionary object created with `load_dictionary`. 
Reviewed By: chtran Differential Revision: D22598008 fbshipit-source-id: 23d7a510fb695df81ebfe4f991e5b5e3db13a1bd --- fairseq/data/multilingual/multilingual_data_manager.py | 4 ++-- fairseq/tasks/translation_multi_simple_epoch.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 14695a2879..de6b16cc98 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -183,7 +183,7 @@ def _shared_collater(self): ) @classmethod - def prepare(cls, args, **kargs): + def prepare(cls, load_dictionary, args, **kargs): args.left_pad_source = options.eval_bool(args.left_pad_source) args.left_pad_target = options.eval_bool(args.left_pad_target) @@ -235,7 +235,7 @@ def check_langs(langs, pairs): for lang in langs_to_load_dicts: paths = args.data.split(os.pathsep) assert len(paths) > 0 - dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) + dicts[lang] = load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) if len(dicts) > 0: assert dicts[lang].pad() == dicts[langs_to_load_dicts[0]].pad() assert dicts[lang].eos() == dicts[langs_to_load_dicts[0]].eos() diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index d19e4f6222..a7c3067d0b 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -93,7 +93,9 @@ def __init__(self, args, langs, dicts, training): @classmethod def setup_task(cls, args, **kwargs): - langs, dicts, training = MultilingualDatasetManager.prepare(args, **kwargs) + langs, dicts, training = MultilingualDatasetManager.prepare( + cls.load_dictionary, args, **kwargs + ) return cls(args, langs, dicts, training) def has_sharded_data(self, split): From ef8d62a8cd005f36cf26ac2d10baa12b2f8b6754 Mon Sep 17 
00:00:00 2001 From: Yuqing Tang Date: Fri, 31 Jul 2020 22:40:05 -0700 Subject: [PATCH 076/707] Enable fairseq fblearner sweep backend to use manifold for checkpoint savings Summary: Default fairseq fblearner sweep backend uses gluser to save checkpoints which is slow for large models. This diff enables fairseq fblearner sweep to use manifold for checkpoint savings. Reviewed By: akinh Differential Revision: D22770343 fbshipit-source-id: ec5174e490c35f2c7d11ccf69f3b6e9adcd8ac7b --- fairseq/file_io.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index b15030c3c8..b57373f8b5 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -95,6 +95,11 @@ def rm(path: str) -> None: return FVCorePathManager.rm(path) os.remove(path) + @staticmethod + def chmod(path: str, mode: int) -> None: + if "manifold" not in path: + os.chmod(path, mode) + @staticmethod def register_handler(handler) -> None: if FVCorePathManager: From 8b9eaacf6b2d502cd7886dd7bf702a46ab37f058 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 31 Jul 2020 22:40:05 -0700 Subject: [PATCH 077/707] Fixed multilingual data manager in handling manifold sharding data Summary: multilingual data manager did not handle manifold sharding data; this is a fix.
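For illustration, the path-splitting helper that this patch touches in `fairseq/utils.py` can be exercised on its own. The sketch below copies the helper as it appears in the diff; the example paths and the `manifold://bucket/...` URIs are made up.

```python
import os
from typing import List

MANIFOLD_PATH_SEP = "|"


def split_paths(paths: str) -> List[str]:
    # URI-style paths contain "://", so splitting on os.pathsep (":" on POSIX)
    # would cut each URI apart; such paths are separated by "|" instead.
    return paths.split(os.pathsep) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP)


# Illustrative inputs only; neither location is a real dataset.
local_shards = split_paths(os.pathsep.join(["/data/shard0", "/data/shard1"]))
remote_shards = split_paths("manifold://bucket/shard0|manifold://bucket/shard1")
```

Either way the caller gets one entry per shard, which is what the `len(paths) > 1` sharding check in the data manager relies on.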
Reviewed By: akinh Differential Revision: D22823795 fbshipit-source-id: 2f3293ff9a5c1db22ef20e29168d35c1518fef1b --- .../data/multilingual/multilingual_data_manager.py | 13 ++++++------- fairseq/tasks/translation_multi_simple_epoch.py | 1 + fairseq/utils.py | 5 ++++- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index de6b16cc98..098fc6fcba 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -11,7 +11,7 @@ from collections import OrderedDict import json -from fairseq import options +from fairseq import options, utils from fairseq.options import eval_str_dict, csv_str_list from fairseq.data import ( @@ -174,7 +174,7 @@ def load_langs(cls, args, **kwargs): return langs def has_sharded_data(self, split): - return split == 'train' and self._has_sharded_data + return self._has_sharded_data and split == getattr(self.args, "train_subset", None) def _shared_collater(self): return ( @@ -426,7 +426,7 @@ def load_langpair_dataset( or src_dataset == 'NotInCache' or tgt_dataset == 'NotInCache' or align_dataset == 'NotInCache' - or split != 'train' + or split != getattr(self.args, "train_subset", None) ): # source and target datasets can be reused in reversed directions to save memory # reversed directions of valid and test data will not share source and target datasets @@ -597,7 +597,7 @@ def get_data_paths_and_lang_pairs(self, split): lang_pairs = { 'main': self.lang_pairs } - if split == 'train': + if split == getattr(self.args, "train_subset", None): # only training data can have extra data and extra language pairs if self.args.extra_data: extra_datapaths = self.args.extra_data @@ -617,8 +617,7 @@ def get_epoch(epoch, shard_epoch): for data_category, paths in data_paths.items(): if data_category not in lang_pairs: continue - # paths = self.args.data.split(os.pathsep) - 
paths = paths.split(os.pathsep) + paths = utils.split_paths(paths) assert len(paths) > 0 if len(paths) > 1: self._has_sharded_data = True @@ -746,7 +745,7 @@ def load_sampled_multi_epoch_dataset( split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs ) - if training and split == 'train': + if training and split == getattr(self.args, "train_subset", None): return self.load_into_sampled_multi_epoch_dataset( split, datasets, data_param_list, epoch, shard_epoch=shard_epoch) else: diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index a7c3067d0b..4a1a757ec0 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -117,6 +117,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): return else: shard_epoch = None + logger.info(f'loading data for {split} epoch={epoch}/{shard_epoch}') self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset( split, self.training, diff --git a/fairseq/utils.py b/fairseq/utils.py index d4ed89e357..f68860330c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -33,8 +33,11 @@ logger = logging.getLogger(__name__) +MANIFOLD_PATH_SEP = "|" + + def split_paths(paths: str) -> List[str]: - return paths.split(os.pathsep) if "://" not in paths else paths.split("|") + return paths.split(os.pathsep) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP) def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): From 488254c88030d4e1fbfb85dbb8a90c97256bf491 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 3 Aug 2020 15:29:02 -0700 Subject: [PATCH 078/707] Add total data size logging to sampled_multi_dataset Summary: Add total data size logging to help debugging Reviewed By: akinh Differential Revision: D22889901 fbshipit-source-id: 37bd0e5cae29398be44874721f84b9670578695f --- fairseq/data/multilingual/sampled_multi_dataset.py | 8 ++++++-- 1 file changed, 6 
insertions(+), 2 deletions(-) diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index bf6051abc5..d6c104031e 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -370,8 +370,12 @@ def _establish_virtual_datasets(self): self.cumulated_sizes = plasma_utils.PlasmaArray(cumulated_sizes) self.virtual_size_per_dataset = plasma_utils.PlasmaArray(virtual_size_per_dataset) - logger.info(f'[{self.split}] Raw sizes: {str(dict(zip(self.keys, [len(d) for d in self.datasets])))}') - logger.info(f'[{self.split}] Resampled sizes: {str(dict(zip(self.keys, self.virtual_size_per_dataset.array)))}') + raw_sizes = [len(d) for d in self.datasets] + sampled_sizes = self.virtual_size_per_dataset.array + logger.info(f'[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; ' + f'raw total size: {sum(raw_sizes)}') + logger.info(f'[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; ' + f'resampled total size: {sum(sampled_sizes)}') if self.sample_ratios is not None: logger.info(f'[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios.array)))}') else: From e0cd7f98f3a2f441783fb5f5eaff1c8b69b14404 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 3 Aug 2020 16:45:30 -0700 Subject: [PATCH 079/707] Enable multilingual task to strip language tokens in generation outputs Summary: Enable multilingual task to strip language tokens in generation outputs by default Reviewed By: pipibjc Differential Revision: D22673703 fbshipit-source-id: df235b54e296265fe9c4fc07ea202ed5fa6713cb --- fairseq/tasks/fairseq_task.py | 8 ++++++-- .../tasks/translation_multi_simple_epoch.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index ca187b73d8..503a008b51 100644 --- a/fairseq/tasks/fairseq_task.py +++ 
b/fairseq/tasks/fairseq_task.py @@ -232,7 +232,10 @@ def build_criterion(self, args): return criterions.build_criterion(args, self) - def build_generator(self, models, args, seq_gen_cls=None): + def build_generator( + self, models, args, + seq_gen_cls=None, extra_gen_cls_kwargs=None + ): if getattr(args, "score_reference", False): from fairseq.sequence_scorer import SequenceScorer @@ -301,7 +304,7 @@ def build_generator(self, models, args, seq_gen_cls=None): seq_gen_cls = SequenceGeneratorWithAlignment else: seq_gen_cls = SequenceGenerator - + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} return seq_gen_cls( models, self.target_dictionary, @@ -316,6 +319,7 @@ def build_generator(self, models, args, seq_gen_cls=None): match_source_len=getattr(args, "match_source_len", False), no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), search_strategy=search_strategy, + **extra_gen_cls_kwargs, ) def train_step( diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 4a1a757ec0..eba32f1759 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -63,6 +63,8 @@ def add_args(parser): help='inference target language') parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr') + parser.add_argument('--keep-inference-langtok', action='store_true', + help='keep language tokens in inference output (e.g. 
for analysis or debugging)') SamplingMethod.add_arguments(parser) MultilingualDatasetManager.add_args(parser) @@ -147,6 +149,23 @@ def build_dataset_for_inference(self, src_tokens, src_lengths): ) return dataset + def build_generator( + self, models, args, + seq_gen_cls=None, extra_gen_cls_kwargs=None, + ): + if not getattr(args, 'keep_inference_langtok', False): + _, tgt_langtok_spec = self.args.langtoks['main'] + if tgt_langtok_spec: + tgt_lang_tok = self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec) + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + extra_gen_cls_kwargs['symbols_to_strip_from_output'] = {tgt_lang_tok} + + return super().build_generator( + models, args, + seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + def build_model(self, args): return super().build_model(args) From 8449c5f4e85d7658e533ffae3dac716d04cb2f0e Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 3 Aug 2020 16:45:30 -0700 Subject: [PATCH 080/707] Add fairseq evaluate flows for multilingual models Summary: Multilingual needs to evaluate multiple directions together with extra arguments. (1) evaluate_multi_flow to evaluate bleu without tokenization (assuming models are no raw texts and depending on sacrebleu) Reviewed By: pipibjc Differential Revision: D22748035 fbshipit-source-id: aa3d21af104965398ec945055c027f2826d5cab4 --- fairseq_cli/generate.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 9f67c0271d..fda3cdc9c0 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -266,7 +266,10 @@ def decode_fn(x): logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization") else: logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.
Use --sacrebleu for standard 13a BLEU tokenization") - logger.info('Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) + # use print to be consistent with other main outputs: S-, H-, T-, D- and so on + print( + 'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()), + file=output_file) return scorer From bc3ea11d4da39d618d8c6abea4b8445d4932c1fd Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 4 Aug 2020 08:13:54 -0700 Subject: [PATCH 081/707] Update language_pair_dataset.py (#2367) Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? add sort_order for `prev_output_tokens` Pull Request resolved: https://github.com/pytorch/fairseq/pull/2367 Reviewed By: pipibjc Differential Revision: D22727492 Pulled By: myleott fbshipit-source-id: 932f073b3b938682d2189e6c072a26bee7169a98 --- fairseq/data/language_pair_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index df48a8052b..b8e71be2ea 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -92,7 +92,6 @@ def compute_alignment_weights(alignments): move_eos_to_beginning=True, pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, ) - prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: ntokens = src_lengths.sum().item() @@ -107,7 +106,7 @@ def compute_alignment_weights(alignments): 'target': target, } if prev_output_tokens is not None: - batch['net_input']['prev_output_tokens'] = prev_output_tokens + batch['net_input']['prev_output_tokens'] = 
prev_output_tokens.index_select(0, sort_order) if samples[0].get('alignment', None) is not None: bsz, tgt_sz = batch['target'].shape From b689b6ff3ab7b806217b8aa41821bb8fc85f7cd8 Mon Sep 17 00:00:00 2001 From: Valentin Malykh Date: Tue, 4 Aug 2020 08:18:53 -0700 Subject: [PATCH 082/707] FIX: bos is always at 0th element (#2369) Summary: small bug fix in dataset creation Pull Request resolved: https://github.com/pytorch/fairseq/pull/2369 Reviewed By: pipibjc Differential Revision: D22727599 Pulled By: myleott fbshipit-source-id: bb7f18d85b72a19667e2a6844bbe172a3397bafb --- fairseq/data/language_pair_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index b8e71be2ea..5c9e09edcf 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -264,7 +264,7 @@ def __getitem__(self, index): tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) bos = self.src_dict.bos() - if self.src[index][-1] != bos: + if self.src[index][0] != bos: src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) if self.remove_eos_from_source: From bddb25ab712e54f0c1eb2d82947e6a1dedd35f42 Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 4 Aug 2020 08:20:51 -0700 Subject: [PATCH 083/707] Update README.md (#2381) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes link ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2381 Reviewed By: pipibjc Differential Revision: D22900368 Pulled By: myleott fbshipit-source-id: c442246662ddfddef9b8fbe616bf480b5c41c21a --- examples/nonautoregressive_translation/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/nonautoregressive_translation/README.md b/examples/nonautoregressive_translation/README.md index bb0f4b8b06..dfc592f0a0 100644 --- a/examples/nonautoregressive_translation/README.md +++ b/examples/nonautoregressive_translation/README.md @@ -13,7 +13,7 @@ We also provided our own implementations for several popular non-autoregressive- ## Dataset -First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#prepare-wmt14en2desh). +First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#wmt14-english-to-german-convolutional). Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
### Knowledge Distillation @@ -143,4 +143,4 @@ Note that we need to make sure the autoregressive model shares the same vocabula journal={arXiv preprint arXiv:1911.02727}, year={2019} } -``` \ No newline at end of file +``` From 33cefe372812f42eb6b1fb5dcc07f3f7f810c5ea Mon Sep 17 00:00:00 2001 From: Oren Amsalem Date: Tue, 4 Aug 2020 08:22:04 -0700 Subject: [PATCH 084/707] Update README.md - spelling (#2360) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2360 Reviewed By: pipibjc Differential Revision: D22727472 Pulled By: myleott fbshipit-source-id: 8f4276edae48bfe6bbf103255bc899c93912312e --- examples/bart/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/bart/README.md b/examples/bart/README.md index 027e2f1ef1..394503f29f 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -4,7 +4,7 @@ ## Introduction -BART is sequence-to-sequence model trained with denoising as pretraining objective. We show that this pretraining objective is more generic and show that we can match [RoBERTa](../roberta) Results on SQuAD and GLUE and gain state-of-the-art results on summarization (XSum, CNN dataset), long form generative question answering (ELI5) and dialog response genration (ConvAI2). See the associated paper for more details. +BART is sequence-to-sequence model trained with denoising as pretraining objective. We show that this pretraining objective is more generic and show that we can match [RoBERTa](../roberta) results on SQuAD and GLUE and gain state-of-the-art results on summarization (XSum, CNN dataset), long form generative question answering (ELI5) and dialog response genration (ConvAI2). See the associated paper for more details. 
## Pre-trained models From b040dae714ff33172731f876e2120848acf10c64 Mon Sep 17 00:00:00 2001 From: Rakesh Chada Date: Tue, 4 Aug 2020 08:24:25 -0700 Subject: [PATCH 085/707] Fixes checkpoint_path while loading a model-parallel checkpoint (#2365) Summary: Fixes https://github.com/pytorch/fairseq/issues/2351 Pull Request resolved: https://github.com/pytorch/fairseq/pull/2365 Reviewed By: pipibjc Differential Revision: D22727384 Pulled By: myleott fbshipit-source-id: e2ff703181a6b8f10df9b4ee7aa3f9e128c04b4e --- fairseq/checkpoint_utils.py | 2 ++ tests/test_train.py | 1 + 2 files changed, 3 insertions(+) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index fe25b0a9dd..2d37e3fc31 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -126,6 +126,8 @@ def load_checkpoint(args, trainer, **passthrough_args): suffix = getattr(args, "checkpoint_suffix", "") if args.restore_file == "checkpoint_last.pt": checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix)) + elif getattr(args, "model_parallel_size", 1) > 1: + checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") else: checkpoint_path = args.restore_file diff --git a/tests/test_train.py b/tests/test_train.py index 734d0e8601..fb935461c8 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -57,6 +57,7 @@ def setUp(self): self.args_mock.reset_dataloader = False self.args_mock.reset_meters = False self.args_mock.reset_optimizer = False + self.args_mock.model_parallel_size = 1 self.patches = { 'os.makedirs': MagicMock(), 'os.path.join': MagicMock(), From c0aefe8fdd87aaaa5c044cbe58f19243b8e48d95 Mon Sep 17 00:00:00 2001 From: Oren Amsalem Date: Tue, 4 Aug 2020 08:25:02 -0700 Subject: [PATCH 086/707] pep 8 - use "x not in y" rather than "not x in y" (#2388) Summary: https://stackoverflow.com/questions/8738388/x-not-in-y-or-not-x-in-y https://www.python.org/dev/peps/pep-0008/#programming-recommendations E713: 
https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes Pull Request resolved: https://github.com/pytorch/fairseq/pull/2388 Reviewed By: pipibjc Differential Revision: D22900393 Pulled By: myleott fbshipit-source-id: 0d89645922a9689e84b9c1a92e2616b08b63d6c8 --- fairseq/data/denoising_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py index ee3a03940d..9d7c318702 100644 --- a/fairseq/data/denoising_dataset.py +++ b/fairseq/data/denoising_dataset.py @@ -128,11 +128,11 @@ def __init__( self.full_stop_index = self.vocab.index('13') self.replace_length = args.replace_length - if not self.replace_length in [-1, 0, 1]: + if self.replace_length not in [-1, 0, 1]: raise ValueError(f'invalid arg: replace_length={self.replace_length}') - if not args.mask_length in ['subword', 'word', 'span-poisson']: + if args.mask_length not in ['subword', 'word', 'span-poisson']: raise ValueError(f'invalid arg: mask-length={args.mask_length}') - if args.mask_length == 'subword' and not args.replace_length in [0, 1]: + if args.mask_length == 'subword' and args.replace_length not in [0, 1]: raise ValueError(f'if using subwords, use replace-length=1 or 0') self.mask_span_distribution = None From e3a5eafe97276fc48ecf6311dc8dcdf98390a774 Mon Sep 17 00:00:00 2001 From: "michal.pietruszka" Date: Tue, 4 Aug 2020 08:26:48 -0700 Subject: [PATCH 087/707] =?UTF-8?q?Fix=20`mode`=20in=20ReduceLROnPlateau?= =?UTF-8?q?=20scheduler=20to=20follow=20`args.maximize=5Fbe=E2=80=A6=20(#2?= =?UTF-8?q?354)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …st_checkpoint_metric` # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? 
- [ ] Did you write any new necessary tests? ## What does this PR do? Fixes #(2205). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Lots of fun! Pull Request resolved: https://github.com/pytorch/fairseq/pull/2354 Reviewed By: pipibjc Differential Revision: D22636665 Pulled By: myleott fbshipit-source-id: c5159db124099ee980bdba46ec23008abad20751 --- fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 431e784de6..8128cf0eb8 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -35,6 +35,7 @@ def __init__(self, args, optimizer): ) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink, + mode='max' if args.maximize_best_checkpoint_metric else 'min', threshold=args.lr_threshold) warmup_end_lr = args.lr[0] # if no warm up, sets initial lr to be args.lr[0] From 2ae88d01acd2a9116667db9a00fb3bdf0962d100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=A4=86=E0=A4=B2=E0=A5=8B=E0=A4=95?= Date: Tue, 4 Aug 2020 08:29:45 -0700 Subject: [PATCH 088/707] typo: Sanskrit* (#2394) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. 
If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2394 Reviewed By: pipibjc Differential Revision: D22900433 Pulled By: myleott fbshipit-source-id: 82b85003ff8df7ad5033da8ae873234ecdd9ef87 --- examples/xlmr/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/xlmr/README.md b/examples/xlmr/README.md index 9e0dfc1254..65d4be13de 100644 --- a/examples/xlmr/README.md +++ b/examples/xlmr/README.md @@ -21,7 +21,7 @@ Kyrgyz | Lao | Latin | Latvian | Lithuanian Macedonian | Malagasy | Malay | Malayalam | Marathi Mongolian | Nepali | Norwegian | Oriya | Oromo Pashto | Persian | Polish | Portuguese | Punjabi -Romanian | Russian | Sanskri | Scottish Gaelic | Serbian +Romanian | Russian | Sanskrit | Scottish Gaelic | Serbian Sindhi | Sinhala | Slovak | Slovenian | Somali Spanish | Sundanese | Swahili | Swedish | Tamil Tamil Romanize | Telugu | Telugu Romanize | Thai | Turkish From 627ccc83700782139a02c429fce87b5c11894b18 Mon Sep 17 00:00:00 2001 From: Fady Essam Date: Tue, 4 Aug 2020 08:40:25 -0700 Subject: [PATCH 089/707] Fix fairseq-generate score printing (issue #2355) (#2356) Summary: I think .format() should be added to the return line as the latest sacrebleu.corpus_bleu() now returns an object not a string Pull Request resolved: https://github.com/pytorch/fairseq/pull/2356 Reviewed By: pipibjc Differential Revision: D22636662 Pulled By: myleott fbshipit-source-id: 3ec7f963069622b0f7b792610244dc7f6fd28800 --- fairseq/bleu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/bleu.py b/fairseq/bleu.py index 36b15286fc..7f053bb853 100644 --- a/fairseq/bleu.py +++ b/fairseq/bleu.py @@ -55,7 +55,7 @@ def score(self, order=4): def result_string(self, order=4): if order != 4: raise NotImplementedError - return self.sacrebleu.corpus_bleu(self.sys, [self.ref]) + return
self.sacrebleu.corpus_bleu(self.sys, [self.ref]).format() class Scorer(object): From cf87f759b9f09769dc416761738cee72382252df Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 4 Aug 2020 10:13:44 -0700 Subject: [PATCH 090/707] cleanup mBART doc (#2391) Summary: - uses downloaded paths in mbart commands, as defaults. - corrects path to `sentencepiece.bpe.model` Pull Request resolved: https://github.com/pytorch/fairseq/pull/2391 Reviewed By: pipibjc Differential Revision: D22900413 Pulled By: myleott fbshipit-source-id: f04df08350257742dd263a48dca960114598059c --- examples/mbart/README.md | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/mbart/README.md b/examples/mbart/README.md index e68ba09c1c..f22d43dba4 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -63,7 +63,7 @@ fairseq-preprocess \ Finetune on mbart CC25 ```bash -PRETRAIN=/path/to/model/mbart.cc25 +PRETRAIN=mbart.cc25 # fix if you moved the downloaded checkpoint langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN fairseq-train path_2_data \ @@ -72,7 +72,6 @@ fairseq-train path_2_data \ --task translation_from_pretrained_bart \ --source-lang en_XX --target-lang ro_RO \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ - --dataset-impl mmap \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ @@ -87,19 +86,22 @@ fairseq-train path_2_data \ ## Generate on EN-RO Get sacrebleu on finetuned en-ro model -get tokenizer [here](https://github.com/rsennrich/wmt16-scripts) +get tokenizer [here](https://github.com/rsennrich/wmt16-scripts) +```bash wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz tar -xzvf 
mbart.cc25.ft.enro.tar.gz +``` ```bash -model=model.pt +model_dir=MBART_finetuned_enro # fix if you moved the checkpoint + fairseq-generate path_2_data \ - --path $model \ + --path $model_dir/model.pt \ --task translation_from_pretrained_bart \ --gen-subset test \ -t ro_RO -s en_XX \ - --bpe 'sentencepiece' --sentencepiece-model sentence.bpe.model \ - --sacrebleu --remove-bpe 'sentencepiece'\ + --bpe 'sentencepiece' --sentencepiece-model $model_dir/sentence.bpe.model \ + --sacrebleu --remove-bpe 'sentencepiece' \ --max-sentences 32 --langs $langs > en_ro cat en_ro | grep -P "^H" |sort -V |cut -f 3- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.hyp From 621e834103b13318cb48d41fc713b580f0da6b24 Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 4 Aug 2020 14:18:03 -0700 Subject: [PATCH 091/707] wav2vec 2.0 (#1220) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1220 Test Plan: Please see examples/wav2vec/README.md for instructions Reviewed By: edunov Differential Revision: D22707565 Pulled By: alexeib fbshipit-source-id: 0c0d4ca7acc933ef7c0062f8dce550b94e414680 --- README.md | 3 + examples/noisychannel/rerank.py | 3 +- examples/noisychannel/rerank_options.py | 2 +- .../speech_recognition/criterions/CTC_loss.py | 197 ---- examples/speech_recognition/infer.py | 365 ++++-- examples/speech_recognition/w2l_decoder.py | 302 ++++- examples/wav2vec/README.md | 149 ++- examples/wav2vec/libri_labels.py | 56 + examples/wav2vec/vq-wav2vec_featurize.py | 2 +- examples/wav2vec/wav2vec_featurize.py | 2 +- fairseq/criterions/ctc.py | 246 ++++ ..._cross_entropy.py => wav2vec_criterion.py} | 51 +- fairseq/data/__init__.py | 2 + fairseq/data/add_target_dataset.py | 49 + fairseq/data/audio/raw_audio_dataset.py | 69 +- fairseq/data/data_utils.py | 145 ++- fairseq/data/dictionary.py | 2 +- fairseq/models/fairseq_encoder.py | 8 + fairseq/models/wav2vec/__init__.py | 8 + fairseq/models/{ => wav2vec}/wav2vec.py | 11 +- fairseq/models/wav2vec/wav2vec2.py | 1017 
+++++++++++++++++ fairseq/models/wav2vec/wav2vec2_asr.py | 679 +++++++++++ fairseq/modules/__init__.py | 4 + fairseq/modules/gumbel_vector_quantizer.py | 34 +- fairseq/modules/same_pad.py | 18 + fairseq/modules/transpose_last.py | 20 + fairseq/options.py | 8 +- fairseq/scoring/__init__.py | 22 + fairseq/{ => scoring}/bleu.py | 62 +- fairseq/scoring/scoring_utils.py | 22 + fairseq/scoring/wer.py | 32 + fairseq/sequence_generator.py | 16 +- fairseq/tasks/audio_pretraining.py | 123 +- fairseq_cli/generate.py | 17 +- fairseq_cli/score.py | 2 +- fairseq_cli/train.py | 2 + setup.py | 1 + 37 files changed, 3292 insertions(+), 459 deletions(-) delete mode 100644 examples/speech_recognition/criterions/CTC_loss.py create mode 100644 examples/wav2vec/libri_labels.py create mode 100644 fairseq/criterions/ctc.py rename fairseq/criterions/{binary_cross_entropy.py => wav2vec_criterion.py} (79%) create mode 100644 fairseq/data/add_target_dataset.py create mode 100644 fairseq/models/wav2vec/__init__.py rename fairseq/models/{ => wav2vec}/wav2vec.py (98%) create mode 100644 fairseq/models/wav2vec/wav2vec2.py create mode 100644 fairseq/models/wav2vec/wav2vec2_asr.py create mode 100644 fairseq/modules/same_pad.py create mode 100644 fairseq/modules/transpose_last.py create mode 100644 fairseq/scoring/__init__.py rename fairseq/{ => scoring}/bleu.py (67%) create mode 100644 fairseq/scoring/scoring_utils.py create mode 100644 fairseq/scoring/wer.py diff --git a/README.md b/README.md index accea254b0..cea586d4f5 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ We provide reference implementations of various sequence modeling papers: - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 
2020)](examples/byte_level_bpe/README.md) + - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2019)](examples/wav2vec/README.md) - **Non-autoregressive Transformers** - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -49,6 +50,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +- August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) - May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) - April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) - April 2020: [Quant-Noise code released](examples/quant_noise/README.md) @@ -142,6 +144,7 @@ We also have more detailed READMEs to reproduce results from specific papers: - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2019)](examples/wav2vec/README.md) - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) - [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py index fd2aef36d6..a5927a53b3 100644 --- a/examples/noisychannel/rerank.py +++ b/examples/noisychannel/rerank.py @@ -8,8 +8,9 @@ import numpy as np -from fairseq import bleu, options 
+from fairseq import options from fairseq.data import dictionary +from fairseq.scoring import bleu from . import ( rerank_generate, diff --git a/examples/noisychannel/rerank_options.py b/examples/noisychannel/rerank_options.py index 41a80d88d1..55c57051ff 100644 --- a/examples/noisychannel/rerank_options.py +++ b/examples/noisychannel/rerank_options.py @@ -64,7 +64,7 @@ def add_reranking_args(parser): help='whether the first model group is a right to left model') group.add_argument('--right-to-left2', action='store_true', help='whether the second model group is a right to left model') - group.add_argument('--remove-bpe', default='@@ ', + group.add_argument('--remove-bpe', '--post-process', default='@@ ', help='the bpe symbol, used for the bitext and LM') group.add_argument('--prefix-len', default=None, type=int, help='the length of the target prefix to use in rescoring (in terms of words wo bpe)') diff --git a/examples/speech_recognition/criterions/CTC_loss.py b/examples/speech_recognition/criterions/CTC_loss.py deleted file mode 100644 index df516f0d6e..0000000000 --- a/examples/speech_recognition/criterions/CTC_loss.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import logging -import math -from itertools import groupby - -import torch -import torch.nn.functional as F -from fairseq import utils -from fairseq.criterions import FairseqCriterion, register_criterion -from examples.speech_recognition.data.data_utils import encoder_padding_mask_to_lengths -from examples.speech_recognition.utils.wer_utils import Code, EditDistance, Token - - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -def arr_to_toks(arr): - toks = [] - for a in arr: - toks.append(Token(str(a), 0.0, 0.0)) - return toks - - -def compute_ctc_uer(logprobs, targets, input_lengths, target_lengths, blank_idx): - """ - Computes utterance error rate for CTC outputs - - Args: - logprobs: (Torch.tensor) N, T1, D tensor of log probabilities out - of the encoder - targets: (Torch.tensor) N, T2 tensor of targets - input_lengths: (Torch.tensor) lengths of inputs for each sample - target_lengths: (Torch.tensor) lengths of targets for each sample - blank_idx: (integer) id of blank symbol in target dictionary - - Returns: - batch_errors: (float) errors in the batch - batch_total: (float) total number of valid samples in batch - """ - batch_errors = 0.0 - batch_total = 0.0 - for b in range(logprobs.shape[0]): - predicted = logprobs[b][: input_lengths[b]].argmax(1).tolist() - target = targets[b][: target_lengths[b]].tolist() - # dedup predictions - predicted = [p[0] for p in groupby(predicted)] - # remove blanks - nonblanks = [] - for p in predicted: - if p != blank_idx: - nonblanks.append(p) - predicted = nonblanks - - # compute the alignment based on EditDistance - alignment = EditDistance(False).align( - arr_to_toks(predicted), arr_to_toks(target) - ) - - # compute the number of errors - # note that alignment.codes can also be used for computing - # deletion, insersion and substitution error breakdowns in future - for a in alignment.codes: - if a != Code.match: - batch_errors += 1 - batch_total += len(target) - - return batch_errors, batch_total 
- - -@register_criterion("ctc_loss") -class CTCCriterion(FairseqCriterion): - def __init__(self, task): - assert hasattr(task, "target_dictionary") - super().__init__(task) - self.blank_idx = task.target_dictionary.index("") - - @classmethod - def build_criterion(cls, args, task): - return cls(task) - - @staticmethod - def add_args(parser): - parser.add_argument( - "--use-source-side-sample-size", - action="store_true", - default=False, - help=( - "when compute average loss, using number of source tokens " - + "as denominator. " - + "This argument will be no-op if sentence-avg is used." - ), - ) - - def forward(self, model, sample, reduce=True, log_probs=True): - """Compute the loss for the given sample. - - Returns a tuple with three elements: - 1) the loss - 2) the sample size, which is used as the denominator for the gradient - 3) logging outputs to display while training - """ - net_output = model(**sample["net_input"]) - lprobs = model.get_normalized_probs(net_output, log_probs=log_probs) - if not hasattr(lprobs, "batch_first"): - logging.warning( - "ERROR: we need to know whether " - "batch first for the encoder output; " - "you need to set batch_first attribute for the return value of " - "model.get_normalized_probs. Now, we assume this is true, but " - "in the future, we will raise exception instead. 
" - ) - - batch_first = getattr(lprobs, "batch_first", True) - - if not batch_first: - max_seq_len = lprobs.size(0) - bsz = lprobs.size(1) - else: - max_seq_len = lprobs.size(1) - bsz = lprobs.size(0) - device = net_output["encoder_out"].device - - input_lengths = encoder_padding_mask_to_lengths( - net_output["encoder_padding_mask"], max_seq_len, bsz, device - ) - target_lengths = sample["target_lengths"] - targets = sample["target"] - - if batch_first: - # N T D -> T N D (F.ctc_loss expects this) - lprobs = lprobs.transpose(0, 1) - - pad_mask = sample["target"] != self.padding_idx - targets_flat = targets.masked_select(pad_mask) - - loss = F.ctc_loss( - lprobs, - targets_flat, - input_lengths, - target_lengths, - blank=self.blank_idx, - reduction="sum", - zero_infinity=True, - ) - - lprobs = lprobs.transpose(0, 1) # T N D -> N T D - errors, total = compute_ctc_uer( - lprobs, targets, input_lengths, target_lengths, self.blank_idx - ) - - if self.args.sentence_avg: - sample_size = sample["target"].size(0) - else: - if self.args.use_source_side_sample_size: - sample_size = torch.sum(input_lengths).item() - else: - sample_size = sample["ntokens"] - - logging_output = { - "loss": utils.item(loss.data) if reduce else loss.data, - "ntokens": sample["ntokens"], - "nsentences": sample["target"].size(0), - "sample_size": sample_size, - "errors": errors, - "total": total, - "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(), - } - return loss, sample_size, logging_output - - @staticmethod - def aggregate_logging_outputs(logging_outputs): - """Aggregate logging outputs from data parallel training.""" - loss_sum = sum(log.get("loss", 0) for log in logging_outputs) - ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) - nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) - sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) - errors = sum(log.get("errors", 0) for log in logging_outputs) - total = 
sum(log.get("total", 0) for log in logging_outputs) - nframes = sum(log.get("nframes", 0) for log in logging_outputs) - agg_output = { - "loss": loss_sum / sample_size / math.log(2), - "ntokens": ntokens, - "nsentences": nsentences, - "nframes": nframes, - "sample_size": sample_size, - "acc": 100.0 - min(errors * 100.0 / total, 100.0), - } - if sample_size != ntokens: - agg_output["nll_loss"] = loss_sum / ntokens / math.log(2) - return agg_output diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index ffa0f1e753..d22acc9c3b 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -8,19 +8,23 @@ Run inference for pre-processed data with a trained model. """ +import editdistance import logging import math import os +import sys -import sentencepiece as spm +import numpy as np import torch -from fairseq import checkpoint_utils, options, utils, tasks -from fairseq.logging import meters, progress_bar -from fairseq.utils import import_user_module +from fairseq import checkpoint_utils, options, progress_bar, utils, tasks +from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.data.data_utils import post_process +logging.basicConfig() +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) def add_asr_eval_argument(parser): @@ -45,34 +49,54 @@ def add_asr_eval_argument(parser): "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" ) parser.add_argument( - "--w2l-decoder", choices=["viterbi", "kenlm"], help="use a w2l decoder" + "--w2l-decoder", choices=["viterbi", "kenlm", "fairseqlm"], help="use a w2l decoder" ) parser.add_argument("--lexicon", help="lexicon for w2l decoder") - parser.add_argument("--kenlm-model", help="kenlm model for w2l decoder") + parser.add_argument("--unit-lm", action='store_true', help="if using a unit lm") + 
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder") parser.add_argument("--beam-threshold", type=float, default=25.0) + parser.add_argument("--beam-size-token", type=float, default=100) parser.add_argument("--word-score", type=float, default=1.0) parser.add_argument("--unk-weight", type=float, default=-math.inf) parser.add_argument("--sil-weight", type=float, default=0.0) + parser.add_argument( + "--dump-emissions", + type=str, + default=None, + help="if present, dumps emissions into this file and exits", + ) + parser.add_argument( + "--dump-features", + type=str, + default=None, + help="if present, dumps features into this file and exits", + ) + parser.add_argument( + "--load-emissions", + type=str, + default=None, + help="if present, loads emissions from this file", + ) return parser def check_args(args): - assert args.path is not None, "--path required for generation!" - assert args.results_path is not None, "--results_path required for generation!" + # assert args.path is not None, "--path required for generation!" + # assert args.results_path is not None, "--results_path required for generation!" 
assert ( - not args.sampling or args.nbest == args.beam + not args.sampling or args.nbest == args.beam ), "--sampling requires --nbest to be equal to --beam" assert ( - args.replace_unk is None or args.dataset_impl == "raw" - ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" + args.replace_unk is None or args.raw_text + ), "--replace-unk requires a raw text dataset (--raw-text)" -def get_dataset_itr(args, task): +def get_dataset_itr(args, task, models): return task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences, - max_positions=(1000000.0, 1000000.0), + max_positions=(sys.maxsize, sys.maxsize), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, num_shards=args.num_shards, @@ -82,29 +106,43 @@ def get_dataset_itr(args, task): def process_predictions( - args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id + args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id ): for hypo in hypos[: min(len(hypos), args.nbest)]: hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu()) - hyp_words = sp.DecodePieces(hyp_pieces.split()) - print( - "{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"] - ) - print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"]) + + if "words" in hypo: + hyp_words = " ".join(hypo["words"]) + else: + hyp_words = post_process(hyp_pieces, args.remove_bpe) + + if res_files is not None: + print( + "{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"] + ) + print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"]) tgt_pieces = tgt_dict.string(target_tokens) - tgt_words = sp.DecodePieces(tgt_pieces.split()) - print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"]) - print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]) - # only score 
top hypothesis - if not args.quiet: - logger.debug("HYPO:" + hyp_words) - logger.debug("TARGET:" + tgt_words) - logger.debug("___________________") + tgt_words = post_process(tgt_pieces, args.remove_bpe) + + if res_files is not None: + print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"]) + print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]) + # only score top hypothesis + if not args.quiet: + logger.debug("HYPO:" + hyp_words) + logger.debug("TARGET:" + tgt_words) + logger.debug("___________________") + + hyp_words = hyp_words.split() + tgt_words = tgt_words.split() + return editdistance.eval(hyp_words, tgt_words), len(tgt_words) def prepare_result_files(args): def get_res_file(file_prefix): + if args.num_shards > 1: + file_prefix = f'{args.shard_id}_{file_prefix}' path = os.path.join( args.results_path, "{}-{}-{}.txt".format( @@ -113,6 +151,9 @@ def get_res_file(file_prefix): ) return open(path, "w", buffering=1) + if not args.results_path: + return None + return { "hypo.words": get_res_file("hypo.word"), "hypo.units": get_res_file("hypo.units"), @@ -121,19 +162,33 @@ def get_res_file(file_prefix): } -def load_models_and_criterions(filenames, arg_overrides=None, task=None): +def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None): models = [] criterions = [] + + if arg_overrides is None: + arg_overrides = {} + + arg_overrides['wer_args'] = None + arg_overrides['data'] = data_path + + if filenames is None: + assert model_state is not None + filenames = [0] + else: + filenames = filenames.split(":") + for filename in filenames: - if not os.path.exists(filename): - raise IOError("Model file not found: {}".format(filename)) - state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides) + if model_state is None: + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + state = 
checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides) + else: + state = model_state args = state["args"] if task is None: task = tasks.setup_task(args) - - # build model for ensemble model = task.build_model(args) model.load_state_dict(state["model"], strict=True) models.append(model) @@ -159,24 +214,40 @@ def optimize_models(args, use_cuda, models): model.cuda() -def main(args): +class ExistingEmissionsDecoder(object): + def __init__(self, decoder, emissions): + self.decoder = decoder + self.emissions = emissions + + def generate(self, models, sample, prefix_tokens=None): + ids = sample["id"].cpu().numpy() + try: + emissions = np.stack(self.emissions[ids]) + except: + print([x.shape for x in self.emissions[ids]]) + raise Exception('invalid sizes') + emissions = torch.from_numpy(emissions) + return self.decoder.decode(emissions) + + +def main(args, task=None, model_state=None): check_args(args) - import_user_module(args) if args.max_tokens is None and args.max_sentences is None: - args.max_tokens = 30000 + args.max_tokens = 4000000 logger.info(args) use_cuda = torch.cuda.is_available() and not args.cpu - # Load dataset splits - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) - logger.info( - "| {} {} {} examples".format( - args.data, args.gen_subset, len(task.dataset(args.gen_subset)) + if task is None: + # Load dataset splits + task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) + logger.info( + "| {} {} {} examples".format( + args.data, args.gen_subset, len(task.dataset(args.gen_subset)) + ) ) - ) # Set dictionary tgt_dict = task.target_dictionary @@ -184,13 +255,19 @@ def main(args): logger.info("| decoding with criterion {}".format(args.criterion)) # Load ensemble - logger.info("| loading model(s) from {}".format(args.path)) - models, criterions, _model_args = load_models_and_criterions( - args.path.split(os.pathsep), - arg_overrides=eval(args.model_overrides), # noqa - task=task, - ) - optimize_models(args, 
use_cuda, models) + + if args.load_emissions: + models, criterions = [], [] + else: + logger.info("| loading model(s) from {}".format(args.path)) + models, criterions, _ = load_models_and_criterions( + args.path, + data_path=args.data, + arg_overrides=eval(args.model_overrides), # noqa + task=task, + model_state=model_state, + ) + optimize_models(args, use_cuda, models) # hack to pass transitions to W2lDecoder if args.criterion == "asg_loss": @@ -198,73 +275,151 @@ def main(args): args.asg_transitions = torch.flatten(trans).tolist() # Load dataset (possibly sharded) - itr = get_dataset_itr(args, task) - progress = progress_bar.progress_bar( - itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=('tqdm' if not args.no_progress_bar else 'none'), - ) + itr = get_dataset_itr(args, task, models) # Initialize generator - gen_timer = meters.StopwatchMeter() - generator = task.build_generator(models, args) + gen_timer = StopwatchMeter() + + def build_generator(args): + w2l_decoder = getattr(args, "w2l_decoder", None) + if w2l_decoder == "viterbi": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, task.target_dictionary) + elif w2l_decoder == "kenlm": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + return W2lKenLMDecoder(args, task.target_dictionary) + elif w2l_decoder == "fairseqlm": + from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(args, task.target_dictionary) + else: + return super().build_generator(args) + + generator = build_generator(args) + + if args.load_emissions: + generator = ExistingEmissionsDecoder( + generator, np.load(args.load_emissions, allow_pickle=True) + ) + logger.info("loaded emissions from " + args.load_emissions) num_sentences = 0 - if not os.path.exists(args.results_path): + if args.results_path is not None and not os.path.exists(args.results_path): 
os.makedirs(args.results_path) - sp = spm.SentencePieceProcessor() - sp.Load(os.path.join(args.data, "spm.model")) - - res_files = prepare_result_files(args) - wps_meter = meters.TimeMeter() - for sample in progress: - sample = utils.move_to_cuda(sample) if use_cuda else sample - if "net_input" not in sample: - continue - - prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample["target"][:, : args.prefix_size] - - gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens) - num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) - gen_timer.stop(num_generated_tokens) - - for i, sample_id in enumerate(sample["id"].tolist()): - speaker = task.dataset(args.gen_subset).speakers[int(sample_id)] - id = task.dataset(args.gen_subset).ids[int(sample_id)] - target_tokens = ( - utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu() - ) - # Process top predictions - process_predictions( - args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, id - ) + max_source_pos = ( + utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ), + ) - wps_meter.update(num_generated_tokens) - progress.log({"wps": round(wps_meter.avg)}) - num_sentences += sample["nsentences"] - - logger.info( - "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}" - "sentences/s, {:.2f} tokens/s)".format( - num_sentences, - gen_timer.n, - gen_timer.sum, - num_sentences / gen_timer.sum, - 1.0 / gen_timer.avg, + if max_source_pos is not None: + max_source_pos = max_source_pos[0] + if max_source_pos is not None: + max_source_pos = max_source_pos[0] - 1 + + if args.dump_emissions: + emissions = {} + if args.dump_features: + features = {} + models[0].bert.proj = None + else: + res_files = prepare_result_files(args) + errs_t = 0 + lengths_t = 0 + with progress_bar.build_progress_bar(args, itr) as t: + wps_meter = TimeMeter() + for sample in t: + sample = utils.move_to_cuda(sample) if use_cuda 
else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample["target"][:, : args.prefix_size] + + gen_timer.start() + if args.dump_emissions: + with torch.no_grad(): + encoder_out = models[0](**sample["net_input"]) + emm = models[0].get_normalized_probs(encoder_out, log_probs=True) + emm = emm.transpose(0, 1).cpu().numpy() + for i, id in enumerate(sample["id"]): + emissions[id.item()] = emm[i] + continue + elif args.dump_features: + with torch.no_grad(): + encoder_out = models[0](**sample["net_input"]) + feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy() + for i, id in enumerate(sample["id"]): + padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() if encoder_out["encoder_padding_mask"] is not None else None + features[id.item()] = (feat[i], padding) + continue + hypos = task.inference_step(generator, models, sample, prefix_tokens) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample["id"].tolist()): + speaker = None + # id = task.dataset(args.gen_subset).ids[int(sample_id)] + id = sample_id + toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :] + target_tokens = ( + utils.strip_pad(toks, tgt_dict.pad()).int().cpu() + ) + # Process top predictions + errs, length = process_predictions( + args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, id + ) + errs_t += errs + lengths_t += length + + wps_meter.update(num_generated_tokens) + t.log({"wps": round(wps_meter.avg)}) + num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + + wer = None + if args.dump_emissions: + emm_arr = [] + for i in range(len(emissions)): + emm_arr.append(emissions[i]) + np.save(args.dump_emissions, emm_arr) + logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}") + elif args.dump_features: + feat_arr = 
[] + for i in range(len(features)): + feat_arr.append(features[i]) + np.save(args.dump_features, feat_arr) + logger.info(f"saved {len(features)} emissions to {args.dump_features}") + else: + if lengths_t > 0: + wer = errs_t * 100.0 / lengths_t + logger.info(f"WER: {wer}") + + logger.info( + "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}" + "sentences/s, {:.2f} tokens/s)".format( + num_sentences, + gen_timer.n, + gen_timer.sum, + num_sentences / gen_timer.sum, + 1.0 / gen_timer.avg, + ) ) - ) - logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam)) + logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam)) + return task, wer -def cli_main(): +def make_parser(): parser = options.get_generation_parser() parser = add_asr_eval_argument(parser) + return parser + +def cli_main(): + parser = make_parser() args = options.parse_args_and_arch(parser) main(args) diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index c56448ce3d..149bec0c49 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -8,10 +8,16 @@ """ Wav2letter decoders. """ -import math + +from collections import namedtuple, deque +import gc import itertools as it +import numpy as np import torch -from fairseq import utils +import os.path as osp +import warnings +from fairseq import tasks +from fairseq.utils import apply_to_sample from examples.speech_recognition.data.replabels import unpack_replabels try: @@ -21,14 +27,19 @@ CriterionType, DecoderOptions, KenLM, + LM, + LMState, SmearingMode, Trie, - WordLMDecoder, + LexiconDecoder, + LexiconFreeDecoder, + ) +except: + warnings.warn( + "wav2letter python bindings are required to use this functionality. 
Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings" ) -except ImportError: - # wav2letter is a required dependency for the speech_recognition - # example, but don't break on import - pass + LM = object + LMState = object class W2lDecoder(object): @@ -38,9 +49,13 @@ def __init__(self, args, tgt_dict): self.nbest = args.nbest # criterion-specific init - if args.criterion == "ctc_loss": + if args.criterion == "ctc": self.criterion_type = CriterionType.CTC - self.blank = tgt_dict.index("") + self.blank = ( + tgt_dict.index("") + if "" in tgt_dict.indices + else tgt_dict.bos() + ) self.asg_transitions = None elif args.criterion == "asg_loss": self.criterion_type = CriterionType.ASG @@ -63,7 +78,8 @@ def generate(self, models, sample, prefix_tokens=None): def get_emissions(self, models, encoder_input): """Run encoder and normalize emissions""" - encoder_out = models[0].encoder(**encoder_input) + # encoder_out = models[0].encoder(**encoder_input) + encoder_out = models[0](**encoder_input) if self.criterion_type == CriterionType.CTC: emissions = models[0].get_normalized_probs(encoder_out, log_probs=True) elif self.criterion_type == CriterionType.ASG: @@ -73,10 +89,10 @@ def get_emissions(self, models, encoder_input): def get_tokens(self, idxs): """Normalize tokens by handling CTC blank, ASG replabels, etc.""" idxs = (g[0] for g in it.groupby(idxs)) - idxs = filter(lambda x: x >= 0, idxs) if self.criterion_type == CriterionType.CTC: idxs = filter(lambda x: x != self.blank, idxs) elif self.criterion_type == CriterionType.ASG: + idxs = filter(lambda x: x >= 0, idxs) idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel) return torch.LongTensor(list(idxs)) @@ -113,8 +129,11 @@ class W2lKenLMDecoder(W2lDecoder): def __init__(self, args, tgt_dict): super().__init__(args, tgt_dict) - self.silence = tgt_dict.index(args.silence_token) - + self.silence = ( + tgt_dict.index("") + if "" in tgt_dict.indices + else tgt_dict.bos() + ) 
self.lexicon = load_words(args.lexicon) self.word_dict = create_word_dict(self.lexicon) self.unk_word = self.word_dict.get_index("") @@ -123,26 +142,36 @@ def __init__(self, args, tgt_dict): self.trie = Trie(self.vocab_size, self.silence) start_state = self.lm.start(False) - for word, spellings in self.lexicon.items(): + for i, (word, spellings) in enumerate(self.lexicon.items()): word_idx = self.word_dict.get_index(word) _, score = self.lm.score(start_state, word_idx) for spelling in spellings: spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) self.decoder_opts = DecoderOptions( args.beam, + int(getattr(args, "beam_size_token", len(tgt_dict))), args.beam_threshold, args.lm_weight, args.word_score, args.unk_weight, - False, args.sil_weight, + 0, + False, self.criterion_type, ) - self.decoder = WordLMDecoder( + if self.asg_transitions is None: + N = 768 + # self.asg_transitions = torch.FloatTensor(N, N).zero_() + self.asg_transitions = [] + + self.decoder = LexiconDecoder( self.decoder_opts, self.trie, self.lm, @@ -150,6 +179,7 @@ def __init__(self, args, tgt_dict): self.blank, self.unk_word, self.asg_transitions, + False, ) def decode(self, emissions): @@ -157,11 +187,247 @@ def decode(self, emissions): hypos = [] for b in range(B): emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) - nbest_results = self.decoder.decode(emissions_ptr, T, N)[: self.nbest] + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] hypos.append( [ - {"tokens": self.get_tokens(result.tokens), "score": result.score} + { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + "words": [ + self.word_dict.get_entry(x) for x in result.words if x >= 0 + ], + } for result in nbest_results ] ) return hypos + + +FairseqLMState = 
namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"]) + + +class FairseqLM(LM): + def __init__(self, dictionary, model): + LM.__init__(self) + self.dictionary = dictionary + self.model = model + self.unk = self.dictionary.unk() + + self.save_incremental = False # this currently does not work properly + self.max_cache = 20_000 + + model.cuda() + model.eval() + model.make_generation_fast_() + + self.states = {} + self.stateq = deque() + + def start(self, start_with_nothing): + state = LMState() + prefix = torch.LongTensor([[self.dictionary.eos()]]) + incremental_state = {} if self.save_incremental else None + with torch.no_grad(): + res = self.model(prefix.cuda(), incremental_state=incremental_state) + probs = self.model.get_normalized_probs(res, log_probs=True, sample=None) + + if incremental_state is not None: + incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state) + self.states[state] = FairseqLMState( + prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy() + ) + self.stateq.append(state) + + return state + + def score(self, state: LMState, token_index: int, no_cache: bool = False): + """ + Evaluate language model based on the current lm state and new word + Parameters: + ----------- + state: current lm state + token_index: index of the word + (can be lexicon index then you should store inside LM the + mapping between indices of lexicon and lm, or lm index of a word) + + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + curr_state = self.states[state] + + def trim_cache(targ_size): + while len(self.stateq) > targ_size: + rem_k = self.stateq.popleft() + rem_st = self.states[rem_k] + rem_st = FairseqLMState(rem_st.prefix, None, None) + self.states[rem_k] = rem_st + + if curr_state.probs is None: + new_incremental_state = ( + curr_state.incremental_state.copy() + if curr_state.incremental_state is not None + else None + ) + with torch.no_grad(): + if new_incremental_state is 
not None: + new_incremental_state = apply_to_sample( + lambda x: x.cuda(), new_incremental_state + ) + elif self.save_incremental: + new_incremental_state = {} + + res = self.model( + torch.from_numpy(curr_state.prefix).cuda(), + incremental_state=new_incremental_state, + ) + probs = self.model.get_normalized_probs( + res, log_probs=True, sample=None + ) + + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cpu(), new_incremental_state + ) + + curr_state = FairseqLMState( + curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy() + ) + + if not no_cache: + self.states[state] = curr_state + self.stateq.append(state) + + score = curr_state.probs[token_index].item() + + trim_cache(self.max_cache) + + outstate = state.child(token_index) + if outstate not in self.states and not no_cache: + prefix = np.concatenate( + [curr_state.prefix, torch.LongTensor([[token_index]])], -1 + ) + incr_state = curr_state.incremental_state + + self.states[outstate] = FairseqLMState(prefix, incr_state, None) + + if token_index == self.unk: + score = float("-inf") + + return outstate, score + + def finish(self, state: LMState): + """ + Evaluate eos for language model based on the current lm state + + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + return self.score(state, self.dictionary.eos()) + + def empty_cache(self): + self.states = {} + self.stateq = deque() + gc.collect() + + +class W2lFairseqLMDecoder(W2lDecoder): + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + + self.silence = tgt_dict.bos() + + self.unit_lm = getattr(args, "unit_lm", False) + + self.lexicon = load_words(args.lexicon) if args.lexicon else None + self.idx_to_wrd = {} + + checkpoint = torch.load(args.kenlm_model, map_location="cpu") + lm_args = checkpoint["args"] + lm_args.data = osp.dirname(args.kenlm_model) + print(lm_args) + task = tasks.setup_task(lm_args) + model = 
task.build_model(lm_args) + model.load_state_dict(checkpoint["model"], strict=False) + + self.trie = Trie(self.vocab_size, self.silence) + + self.word_dict = task.dictionary + self.unk_word = self.word_dict.unk() + self.lm = FairseqLM(self.word_dict, model) + + self.decoder_opts = DecoderOptions( + args.beam, + int(getattr(args, "beam_size_token", len(tgt_dict))), + args.beam_threshold, + args.lm_weight, + args.word_score, + args.unk_weight, + args.sil_weight, + 0, + False, + self.criterion_type, + ) + + if self.lexicon: + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + if self.unit_lm: + word_idx = i + self.idx_to_wrd[i] = word + score = 0 + else: + word_idx = self.word_dict.index(word) + _, score = self.lm.score(start_state, word_idx, no_cache=True) + + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + [], + self.unit_lm, + ) + else: + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + def decode(self, emissions): + B, T, N = emissions.size() + hypos = [] + + def idx_to_word(idx): + if self.unit_lm: + return self.idx_to_wrd[idx] + else: + return self.word_dict[idx] + + def make_hypo(result): + hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score} + if self.lexicon: + hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0] + return hypo + + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append([make_hypo(result) for result in nbest_results]) + 
self.lm.empty_cache() + + return hypos diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 1ea3b4fc2f..d90dedf22e 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -1,12 +1,141 @@ +# wav2vec 2.0 + +wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). + +## Pre-trained models + +Model | Finetuning split | Dataset | Model +|---|---|---|--- +Wav2Vec 2.0 Base | - | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) +Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) +Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) +Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) +Wav2Vec 2.0 Large | - | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) +Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) +Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) +Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) +Wav2Vec 2.0 Large (LV-60) | - | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox.pt) +Wav2Vec 2.0 Large (LV-60) | 10 minutes | 
[Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m.pt) +Wav2Vec 2.0 Large (LV-60) | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h.pt) +Wav2Vec 2.0 Large (LV-60) | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h.pt) + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) + +### Prepare training data manifest: + +$ext should be set to flac or wav, or whatever format your dataset happens to use that soundfile can read. + +$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. +To use a pre-defined validation set (like dev-other from librispeech), set it to 0 and then overwrite valid.tsv with a +separately pre-processed manifest file. 
+ +```shell script +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid +``` + +### Train a wav2vec 2.0 base model: + +Note that this was tested with pytorch 1.4.0 and the input is expected to be single channel, sampled at 16 kHz + +```shell script +$ python train.py --distributed-world-size 64 --distributed-port $PORT /manifest/path \ +--save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ +--log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ +--conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \ +--latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \ +--adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 400000 \ +--lr 0.0005 --warmup-updates 32000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ +--encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ +--loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 \ +--max-sample-size 250000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ +--max-tokens 1400000 --max-update 400000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d +``` + +Note: you can simulate 64 GPUs by using k GPUs and setting --update-freq 64/k + +### Train a wav2vec 2.0 large model: + +This configuration was used for the model trained on the Libri-light dataset in the wav2vec 2.0 paper + +```shell script +$ python train.py --distributed-world-size 128 --distributed-port $PORT /manifest/path \ +--save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ +--log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode 
default \ +--conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 768 --latent-vars 320 \ +--latent-groups 2 --latent-temp '(2.0,0.1,0.999995)' --infonce --optimizer adam \ +--adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 600000 \ +--lr 0.0003 --warmup-updates 32000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ +--encoder-layerdrop 0.0 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.03 \ +--loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --encoder-layers 24 --encoder-embed-dim 1024 \ +--encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --num-negatives 100 --cross-sample-negatives 0 \ +--max-sample-size 320000 --min-sample-size 32000 --dropout 0.0 --attention-dropout 0.1 --weight-decay 0.01 \ +--max-tokens 1200000 --max-update 600000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d +``` + +Note: you can simulate 128 GPUs by using k GPUs and setting --update-freq 128/k + +### Fine-tune a pre-trained model with CTC: + +Fine-tuning a model requires parallel audio and label files, as well as a vocabulary file in fairseq format. +A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). 
+An example script that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: + +```shell script +split=train +$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split +``` + +Fine-tuning on 100h of Librispeech with letter targets: +```shell script +valid_subset=dev_other +python train.py --distributed-world-size 24 --distributed-port $PORT /path/to/training_data --save-dir /model/path --fp16 \ +--wer-args '("/path/to/lm/4-gram.bin","/path/to/lexicon",2,-1)' \ +--post-process letter --valid-subset $valid_subset --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 \ +--max-update 80000 --sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /path/to/pretrained/model \ +--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \ +--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 --zero-infinity \ +--feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 --optimizer adam \ +--adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 --hold-steps 32000 \ +--decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 --activation-dropout 0.1 --criterion ctc \ +--attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json --log-interval 500 --ddp-backend no_c10d +``` + +Note: you can simulate 24 GPUs by using k GPUs and setting --update-freq 24/k + +Note that decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). +Alternatively, simply omit the --wer-args flag. 
+ +### Evaluating a CTC model: + +Evaluating a CTC model with a language model requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings) to be installed. + +The Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/models/sota/2019). +Be sure to upper-case the language model vocab after downloading it. + +The letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). + +Next run the evaluation command: + +```shell script +subset=dev_other +python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \ +--nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ +--lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 --post-process letter +``` + +To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. + # wav2vec Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). ## Pre-trained models -Description | Parameters | Dataset | Model ----|---:|---|--- -Wav2Vec large
([(Schneider et al., 2019)](https://arxiv.org/abs/1904.05862)) | 32.5M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt) +Description | Dataset | Model +---|---|--- +Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt) #### Example usage: ```python @@ -25,7 +154,7 @@ c = model.feature_aggregator(z) ## Training a new model with the CLI tools -Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) ### Prepare training data manifest: @@ -55,13 +184,15 @@ $ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --inp Example to train a vq-wav2vec model as described in [vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/1910.05453). +These models are also used in [Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019)](https://arxiv.org/abs/1911.03912). + ## Pre-trained models -Description | Parameters | Dataset | Model ----|---:|---|--- -vq-wav2vec Gumbel
([(Baevski et al., 2019)](https://arxiv.org/abs/1910.05453)) | 34.1M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt) -vq-wav2vec K-means
([(Baevski et al., 2019)](https://arxiv.org/abs/1910.05453)) | 33.0M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt) -Roberta on K-means codes
([(Baevski et al., 2019)](https://arxiv.org/abs/1910.05453)) | 123.6M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar) +Description | Dataset | Model +---|---|--- +vq-wav2vec Gumbel | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt) +vq-wav2vec K-means | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt) +Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar) #### Example usage: ```python diff --git a/examples/wav2vec/libri_labels.py b/examples/wav2vec/libri_labels.py new file mode 100644 index 0000000000..4feced0a02 --- /dev/null +++ b/examples/wav2vec/libri_labels.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Helper script to generate letter (.ltr.txt) and word (.wrd.txt) label files for Librispeech from a tsv manifest produced by wav2vec_manifest.py +""" + +import argparse +import os + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("tsv") + parser.add_argument("--output-dir", required=True) + parser.add_argument("--output-name", required=True) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + transcriptions = {} + + with open(args.tsv, "r") as tsv, open( + os.path.join(args.output_dir, args.output_name + ".ltr.txt"), "w" + ) as ltr_out, open( + os.path.join(args.output_dir, args.output_name + ".wrd.txt"), "w" + ) as wrd_out: + root = next(tsv).strip() + for line in tsv: + line = line.strip() + dir = os.path.dirname(line) + if dir not in transcriptions: + parts = dir.split("/") + trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt" + path = os.path.join(root, dir, trans_path) + assert os.path.exists(path) + texts = {} + with open(path, "r") as trans_f: + for tline in trans_f: + items = tline.strip().split() + texts[items[0]] = " ".join(items[1:]) + transcriptions[dir] = texts + part = os.path.basename(line).split(".")[0] + assert part in transcriptions[dir] + print(transcriptions[dir][part], file=wrd_out) + print( + " ".join(list(transcriptions[dir][part].replace(" ", "|"))) + " |", + file=ltr_out, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/vq-wav2vec_featurize.py b/examples/wav2vec/vq-wav2vec_featurize.py index ce9afd2431..0d658c07ca 100644 --- a/examples/wav2vec/vq-wav2vec_featurize.py +++ b/examples/wav2vec/vq-wav2vec_featurize.py @@ -19,7 +19,7 @@ except: print("Install tqdm to use --log-format=tqdm") -from fairseq.models.wav2vec import Wav2VecModel +from fairseq.models.wav2vec.wav2vec import Wav2VecModel import tqdm import soundfile as sf diff --git a/examples/wav2vec/wav2vec_featurize.py b/examples/wav2vec/wav2vec_featurize.py index 31e12433f9..445a5d0213 100644 --- 
a/examples/wav2vec/wav2vec_featurize.py +++ 
b/examples/wav2vec/wav2vec_featurize.py @@ -20,7 +20,7 @@ from torch import nn import tqdm -from fairseq.models.wav2vec import Wav2VecModel +from fairseq.models.wav2vec.wav2vec import Wav2VecModel def read_audio(fname): diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py new file mode 100644 index 0000000000..3b8d974387 --- /dev/null +++ b/fairseq/criterions/ctc.py @@ -0,0 +1,246 @@ +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from argparse import Namespace +import math + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.data.data_utils import post_process +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("ctc") +class CtcCriterion(FairseqCriterion): + def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe): + super().__init__(task) + self.blank_idx = task.target_dictionary.bos() + self.pad_idx = task.target_dictionary.pad() + self.eos_idx = task.target_dictionary.eos() + self.post_process = remove_bpe if remove_bpe else "letter" + + if wer_args is not None: + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args) + + dec_args = Namespace() + dec_args.nbest = 1 + dec_args.criterion = "ctc" + dec_args.kenlm_model = wer_compute_kenlm + dec_args.lexicon = wer_lexicon + dec_args.beam = 50 + dec_args.beam_size_token = min(50, len(task.target_dictionary)) + dec_args.beam_threshold = min(50, len(task.target_dictionary)) + dec_args.lm_weight = lm_w + dec_args.word_score = ws_w + dec_args.unk_weight = -math.inf + dec_args.sil_weight = 0 + + self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary) + else: + self.w2l_decoder = None + + self.zero_infinity = 
zero_infinity + self.sentence_avg = sentence_avg + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + parser.add_argument( + "--zero-infinity", action="store_true", help="zero inf loss" + ) + try: + parser.add_argument( + "--remove-bpe", + "--post-process", + default="letter", + help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)", + ) + except: + pass # this option might have been added from eval args + parser.add_argument( + "--wer-args", + type=str, + default=None, + help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \ + path to lexicon, lm score, word score", + ) + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + lprobs = model.get_normalized_probs( + net_output, log_probs=True + ).contiguous() # (T, B, C) from the encoder + + if "src_lengths" in sample["net_input"]: + input_lengths = sample["net_input"]["src_lengths"] + else: + non_padding_mask = ~net_output["padding_mask"] + input_lengths = non_padding_mask.long().sum(-1) + + pad_mask = (sample["target"] != self.pad_idx) & ( + sample["target"] != self.eos_idx + ) + targets_flat = sample["target"].masked_select(pad_mask) + target_lengths = sample["target_lengths"] + + with torch.backends.cudnn.flags(enabled=False): + loss = F.ctc_loss( + lprobs, + targets_flat, + input_lengths, + target_lengths, + blank=self.blank_idx, + reduction="sum", + zero_infinity=self.zero_infinity, + ) + + ntokens = ( + sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item() + ) + + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens + logging_output = { + "loss": utils.item(loss.data), # * sample['ntokens'], + "ntokens": ntokens, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + if not model.training: + import editdistance + + with torch.no_grad(): + lprobs_t = 
lprobs.transpose(0, 1).float().cpu() + + c_err = 0 + c_len = 0 + w_errs = 0 + w_len = 0 + wv_errs = 0 + for lp, t, inp_l in zip( + lprobs_t, + sample["target_label"] + if "target_label" in sample + else sample["target"], + input_lengths, + ): + lp = lp[:inp_l].unsqueeze(0) + + decoded = None + if self.w2l_decoder is not None: + decoded = self.w2l_decoder.decode(lp) + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + + p = (t != self.task.target_dictionary.pad()) & ( + t != self.task.target_dictionary.eos() + ) + targ = t[p] + targ_units = self.task.target_dictionary.string(targ) + targ_units_arr = targ.tolist() + + toks = lp.argmax(dim=-1).unique_consecutive() + pred_units_arr = toks[toks != self.blank_idx].tolist() + + c_err += editdistance.eval(pred_units_arr, targ_units_arr) + c_len += len(targ_units_arr) + + targ_words = post_process(targ_units, self.post_process).split() + + pred_units = self.task.target_dictionary.string(pred_units_arr) + pred_words_raw = post_process(pred_units, self.post_process).split() + + if decoded is not None and "words" in decoded: + pred_words = decoded["words"] + w_errs += editdistance.eval(pred_words, targ_words) + wv_errs += editdistance.eval(pred_words_raw, targ_words) + else: + dist = editdistance.eval(pred_words_raw, targ_words) + w_errs += dist + wv_errs += dist + + w_len += len(targ_words) + + logging_output["wv_errors"] = wv_errs + logging_output["w_errors"] = w_errs + logging_output["w_total"] = w_len + logging_output["c_errors"] = c_err + logging_output["c_total"] = c_len + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + 
sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improve distributed training speed. 
+ """ + return True diff --git a/fairseq/criterions/binary_cross_entropy.py b/fairseq/criterions/wav2vec_criterion.py similarity index 79% rename from fairseq/criterions/binary_cross_entropy.py rename to fairseq/criterions/wav2vec_criterion.py index 557f50bd90..019db62249 100644 --- a/fairseq/criterions/binary_cross_entropy.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -8,12 +8,12 @@ import torch import torch.nn.functional as F -from fairseq import utils +from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -@register_criterion('binary_cross_entropy') -class BinaryCrossEntropyCriterion(FairseqCriterion): +@register_criterion('wav2vec') +class Wav2vecCriterion(FairseqCriterion): def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): super().__init__(task) @@ -60,7 +60,8 @@ def forward(self, model, sample, reduce=True, log_pred=False): sample_size = target.numel() if self.infonce else target.long().sum().item() losses.append(loss) - if self.loss_weights is not None and hasattr(model, "get_extra_losses"): + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") extra_losses = model.get_extra_losses(net_output) if torch.is_tensor(extra_losses): extra_losses = [extra_losses] @@ -76,7 +77,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): logging_output = { 'loss': loss.item() if reduce else loss, 'ntokens': sample_size, - 'nsentences': logits.size(0), + 'nsentences': sample['id'].numel(), 'sample_size': sample_size, } @@ -110,25 +111,31 @@ def forward(self, model, sample, reduce=True, log_pred=False): return loss, sample_size, logging_output @staticmethod - def aggregate_logging_outputs(logging_outputs): + def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs)) ntokens = utils.item(sum(log.get('ntokens', 0) for log in 
logging_outputs)) nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs)) sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs)) - agg_output = { - 'loss': loss_sum / sample_size / math.log(2), - 'ntokens': ntokens, - 'nsentences': nsentences, - 'sample_size': sample_size, - } - if sample_size != ntokens: - agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) + + metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) + metrics.log_scalar('ntokens', ntokens) + metrics.log_scalar('nsentences', nsentences) correct = sum(log.get("correct", 0) for log in logging_outputs) + metrics.log_scalar("_correct", correct) + total = sum(log.get("count", 0) for log in logging_outputs) + metrics.log_scalar("_total", total) + + if total > 0: - agg_output['accuracy'] = correct / total + metrics.log_derived( + "accuracy", + lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5) + if meters["_total"].sum > 0 + else float("nan"), + ) builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'} @@ -136,7 +143,15 @@ def aggregate_logging_outputs(logging_outputs): if k not in builtin_keys: val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs) if k.startswith('loss'): - val = val / ntokens if ntokens > 0 else float('nan') - agg_output[k] = val + metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) + else: + metrics.log_scalar(k, val, round=3) - return agg_output + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improve distributed training speed. 
+ """ + return False diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index e35bb5646c..a99d9280fa 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -9,6 +9,7 @@ from .base_wrapper_dataset import BaseWrapperDataset +from .add_target_dataset import AddTargetDataset from .append_token_dataset import AppendTokenDataset from .audio.raw_audio_dataset import FileAudioDataset from .backtranslation_dataset import BacktranslationDataset @@ -56,6 +57,7 @@ ) __all__ = [ + 'AddTargetDataset', 'AppendTokenDataset', 'BacktranslationDataset', 'BaseWrapperDataset', diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py new file mode 100644 index 0000000000..91cf1a51c4 --- /dev/null +++ b/fairseq/data/add_target_dataset.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset +from . 
import data_utils + + +class AddTargetDataset(BaseWrapperDataset): + def __init__(self, dataset, labels, pad, eos, batch_targets, process_label=None, add_to_input=False): + super().__init__(dataset) + self.labels = labels + self.batch_targets = batch_targets + self.pad = pad + self.eos = eos + self.process_label = process_label + self.add_to_input = add_to_input + + def get_label(self, index): + return self.labels[index] if self.process_label is None else self.process_label(self.labels[index]) + + def __getitem__(self, index): + item = self.dataset[index] + item["label"] = self.get_label(index) + return item + + def size(self, index): + sz = self.dataset.size(index) + own_sz = len(self.get_label(index)) + return (sz, own_sz) + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + if self.batch_targets: + collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) + target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["target"] = target + if self.add_to_input: + eos = target.new_full((target.size(0), 1), self.eos) + collated["target"] = torch.cat([target, eos], dim=-1) + collated["net_input"]["prev_output_tokens"] = torch.cat([eos, target], dim=-1) + return collated \ No newline at end of file diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 40cbc20680..675b095647 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -5,6 +5,7 @@ import os +import logging import numpy as np import sys @@ -13,6 +14,8 @@ from .. 
import FairseqDataset +logger = logging.getLogger(__name__) + class RawAudioDataset(FairseqDataset): def __init__( @@ -22,6 +25,8 @@ def __init__( min_sample_size=None, shuffle=True, min_length=0, + pad=False, + normalize=False, ): super().__init__() @@ -30,11 +35,11 @@ def __init__( self.max_sample_size = ( max_sample_size if max_sample_size is not None else sys.maxsize ) - self.min_sample_size = ( - min_sample_size if min_sample_size is not None else self.max_sample_size - ) + self.min_sample_size = min_sample_size self.min_length = min_length + self.pad = pad self.shuffle = shuffle + self.normalize = normalize def __getitem__(self, index): raise NotImplementedError() @@ -43,17 +48,17 @@ def __len__(self): return len(self.sizes) def postprocess(self, feats, curr_sample_rate): - def resample(x, factor): - return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze() - if feats.dim() == 2: feats = feats.mean(-1) if curr_sample_rate != self.sample_rate: - factor = self.sample_rate / curr_sample_rate - feats = resample(feats, factor) + raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}") assert feats.dim() == 1, feats.dim() + + if self.normalize: + with torch.no_grad(): + feats = F.layer_norm(feats, feats.shape) return feats def crop_to_max_size(self, wav, target_size): @@ -68,34 +73,42 @@ def crop_to_max_size(self, wav, target_size): def collater(self, samples): samples = [ - s for s in samples if s["source"] is not None and len(s["source"]) > 0 + s + for s in samples + if s["source"] is not None ] if len(samples) == 0: return {} sources = [s["source"] for s in samples] sizes = [len(s) for s in sources] - target_size = min(min(sizes), self.max_sample_size) - if target_size < self.min_length: - return {} - - if self.min_sample_size < target_size: - target_size = np.random.randint(self.min_sample_size, target_size + 1) + if self.pad: + target_size = min(max(sizes), self.max_sample_size) + else: + target_size = min(min(sizes), 
self.max_sample_size) collated_sources = sources[0].new(len(sources), target_size) + padding_mask = ( + torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None + ) for i, (source, size) in enumerate(zip(sources, sizes)): diff = size - target_size - assert diff >= 0 if diff == 0: collated_sources[i] = source + elif diff < 0: + assert self.pad + collated_sources[i] = torch.cat( + [source, source.new_full((-diff,), 0.0)] + ) + padding_mask[i, diff:] = True else: collated_sources[i] = self.crop_to_max_size(source, target_size) - return { - "id": torch.LongTensor([s["id"] for s in samples]), - "net_input": {"source": collated_sources}, - } + input = {"source": collated_sources} + if self.pad: + input["padding_mask"] = padding_mask + return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} def num_tokens(self, index): return self.size(index) @@ -103,6 +116,8 @@ def num_tokens(self, index): def size(self, index): """Return an example's size as a float or tuple. 
This value is used when filtering a dataset with ``--max-positions``.""" + if self.pad: + return self.sizes[index] return min(self.sizes[index], self.max_sample_size) def ordered_indices(self): @@ -115,7 +130,7 @@ def ordered_indices(self): order = [np.arange(len(self))] order.append(self.sizes) - return np.lexsort(order) + return np.lexsort(order)[::-1] class FileAudioDataset(RawAudioDataset): @@ -127,6 +142,8 @@ def __init__( min_sample_size=None, shuffle=True, min_length=0, + pad=False, + normalize=False, ): super().__init__( sample_rate=sample_rate, @@ -134,17 +151,25 @@ def __init__( min_sample_size=min_sample_size, shuffle=shuffle, min_length=min_length, + pad=pad, + normalize=normalize, ) self.fnames = [] + skipped = 0 with open(manifest_path, "r") as f: self.root_dir = f.readline().strip() for line in f: items = line.strip().split("\t") assert len(items) == 2, line + sz = int(items[1]) + if min_length is not None and sz < min_length: + skipped += 1 + continue self.fnames.append(items[0]) - self.sizes.append(int(items[1])) + self.sizes.append(sz) + logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") def __getitem__(self, index): import soundfile as sf diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 28410dc21e..57991a8802 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -11,10 +11,11 @@ import itertools import logging import os -import sys -import types + +from typing import Tuple, Optional import numpy as np +import torch logger = logging.getLogger(__name__) @@ -258,11 +259,137 @@ def batch_by_size( return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted) -def process_bpe_symbol(sentence: str, bpe_symbol: str): - if bpe_symbol == 'sentencepiece': - sentence = sentence.replace(' ', '').replace('\u2581', ' ').strip() - elif bpe_symbol == '_EOW': - sentence = sentence.replace(' ', '').replace('_EOW', ' ').strip() - elif bpe_symbol is not None: - sentence = (sentence + ' 
').replace(bpe_symbol, '').rstrip() +def post_process(sentence: str, symbol: str): + if symbol == "sentencepiece": + sentence = sentence.replace(" ", "").replace("\u2581", " ").strip() + elif symbol == 'wordpiece': + sentence = sentence.replace(" ", "").replace("_", " ").strip() + elif symbol == 'letter': + sentence = sentence.replace(" ", "").replace("|", " ").strip() + elif symbol == "_EOW": + sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() + elif symbol is not None and symbol != 'none': + sentence = (sentence + " ").replace(symbol, "").rstrip() return sentence + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from Poisson distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e-length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start-min_space+1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = 
[(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter((e - s if e-s >= length+min_space else 0 for s, e in parts), np.int) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 5ff010d5e2..01a6a81486 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -101,7 +101,7 @@ def token_string(i): if utils.item(i) not in extra_symbols_to_ignore ) - return data_utils.process_bpe_symbol(sent, bpe_symbol) + return data_utils.post_process(sent, bpe_symbol) def unk_string(self, escape=False): """Return unknown string, optionally escaped as: <>""" diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py index 9c73633572..7ddc0fba01 100644 --- a/fairseq/models/fairseq_encoder.py +++ b/fairseq/models/fairseq_encoder.py @@ -81,3 +81,11 @@ def max_positions(self): def upgrade_state_dict(self, state_dict): """Upgrade a (possibly old) state dict for new versions of fairseq.""" return state_dict + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, 'set_num_updates') and m != self: + 
m.set_num_updates(num_updates) + self.apply(_apply) diff --git a/fairseq/models/wav2vec/__init__.py b/fairseq/models/wav2vec/__init__.py new file mode 100644 index 0000000000..06cec18183 --- /dev/null +++ b/fairseq/models/wav2vec/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .wav2vec import * # noqa +from .wav2vec2 import * # noqa +from .wav2vec2_asr import * # noqa diff --git a/fairseq/models/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py similarity index 98% rename from fairseq/models/wav2vec.py rename to fairseq/models/wav2vec/wav2vec.py index 38a25db0a8..905df824f3 100644 --- a/fairseq/models/wav2vec.py +++ b/fairseq/models/wav2vec/wav2vec.py @@ -17,6 +17,7 @@ Fp32LayerNorm, GumbelVectorQuantizer, KmeansVectorQuantizer, + TransposeLast, ) from fairseq.utils import buffered_arange @@ -402,16 +403,6 @@ def get_extra_losses(self, net_output): return loss -class TransposeLast(nn.Module): - def __init__(self, deconstruct_idx=None): - super().__init__() - self.deconstruct_idx = deconstruct_idx - - def forward(self, x): - if self.deconstruct_idx is not None: - x = x[self.deconstruct_idx] - return x.transpose(-2, -1) - def norm_block(is_layer_norm, dim, affine=True): if is_layer_norm: mod = nn.Sequential( diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py new file mode 100644 index 0000000000..226f035ba8 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -0,0 +1,1017 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import List, Tuple + +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + GumbelVectorQuantizer, + LayerNorm, + MultiheadAttention, + SamePad, + TransposeLast, +) +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import buffered_arange + + +@register_model("wav2vec2") +class Wav2Vec2Model(BaseFairseqModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + + parser.add_argument( + "--extractor-mode", + choices=["default", "layer_norm"], + help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)", + ) + + parser.add_argument( + "--encoder-layers", + type=int, + metavar="L", + help="num encoder layers in the transformer", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="H", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="F", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="A", + help="num encoder attention heads", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout probability for the transformer", + ) + + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + + parser.add_argument( + 
"--activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN", + ) + + parser.add_argument( + "--final-dim", + type=int, + metavar="D", + help="project final representations and targets to this many dimensions", + ) + + parser.add_argument( + "--layer-norm-first", + action="store_true", + help="apply layernorm first in the transformer", + ) + + parser.add_argument( + "--encoder-layerdrop", + type=float, + help="probability of dropping a transformer layer", + ) + + parser.add_argument( + "--conv-feature-layers", + type=str, + metavar="EXPR", + help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--logit-temp", type=float, help="temperature to divide logits by" + ) + + parser.add_argument( + "--quantize-targets", action="store_true", help="use quantized targets" + ) + + parser.add_argument( + "--quantize-input", action="store_true", help="use quantized inputs" + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + help="multiply feature extractor var grads by this", + ) + + parser.add_argument( + "--latent-vars", + type=int, + metavar="N", + help="number of latent variables V in each group of the codebook", + ) + + parser.add_argument( + "--latent-groups", + type=int, + metavar="N", + help="number of groups G of latent variables in the codebook", + ) + + parser.add_argument( + "--latent-dim", + type=int, + metavar="N", + help="if set, uses this dimensionality for latent variables. 
otherwise uses final_dim / latent_groups", + ) + + parser.add_argument("--mask-length", type=int, help="mask length") + + parser.add_argument( + "--mask-prob", type=float, help="probability of replacing a token with mask" + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-other", + type=float, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", + ) + + parser.add_argument( + "--no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--mask-channel-length", + type=int, + help="repeat the mask indices multiple times", + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-channel-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + metavar="D", + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--num-negatives", type=int, metavar="N", help="number of negative examples" + ) + 
+ parser.add_argument( + "--negatives-from-everywhere", + action="store_true", + help="sample negatives from everywhere, not just masked states", + ) + + parser.add_argument( + "--cross-sample-negatives", + type=int, + metavar="N", + help="num of cross sampled negatives", + ) + + parser.add_argument( + "--codebook-negatives", + type=int, + metavar="N", + help="num of codebook sampled negatives", + ) + + parser.add_argument( + "--conv-pos", + type=int, + metavar="N", + help="number of filters for convolutional positional embeddings", + ) + + parser.add_argument( + "--conv-pos-groups", + type=int, + metavar="N", + help="number of groups for convolutional positional embedding", + ) + + parser.add_argument( + "--latent-temp", + type=str, + metavar="D", + help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)", + ) + + parser.add_argument( + "--target-glu", action="store_true", help="adds projection + glu to targets" + ) + + parser.add_argument( + "--conv-bias", action="store_true", help="include bias in conv encoder" + ) + + def __init__(self, args): + super().__init__() + self.args = args + + feature_enc_layers = eval(args.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=args.extractor_mode, + conv_bias=args.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, args.encoder_embed_dim) + if self.embed != args.encoder_embed_dim and not args.quantize_input + else None + ) + + self.mask_prob = args.mask_prob + self.mask_selection = args.mask_selection + self.mask_other = args.mask_other + self.mask_length = args.mask_length + self.no_mask_overlap = args.no_mask_overlap + self.mask_min_space = args.mask_min_space + + self.mask_channel_prob = args.mask_channel_prob + self.mask_channel_selection = args.mask_channel_selection + self.mask_channel_other = args.mask_channel_other + 
self.mask_channel_length = args.mask_channel_length + self.no_mask_channel_overlap = args.no_mask_channel_overlap + self.mask_channel_min_space = args.mask_channel_min_space + + self.dropout_input = nn.Dropout(args.dropout_input) + self.dropout_features = nn.Dropout(args.dropout_features) + + self.feature_grad_mult = args.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = args.num_negatives + self.cross_sample_negatives = args.cross_sample_negatives + self.codebook_negatives = args.codebook_negatives + self.negatives_from_everywhere = args.negatives_from_everywhere + + self.logit_temp = args.logit_temp + + if args.quantize_input: + vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim + self.input_quantizer = ( + GumbelVectorQuantizer( + dim=args.encoder_embed_dim, + num_vars=args.latent_vars, + temp=eval(args.latent_temp), + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + if not args.same_quantizer + else self.quantizer + ) + self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim) + + final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim + + if args.quantize_targets: + vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=args.latent_vars, + temp=eval(args.latent_temp), + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + self.project_q = nn.Linear(self.embed, final_dim) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(args.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(args) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if args.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim) + + 
def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + @classmethod + def build_model(cls, args, task=None): + """Build a new model instance.""" + + # make sure all arguments are present + base_architecture(args) + + return cls(args) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * num) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + 
.flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view( + bsz, num, self.n_negatives + self.cross_sample_negatives, fsz + ).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + + logits /= self.logit_temp + + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + + return logits + + def forward(self, source, padding_mask=None, mask=True, features_only=False): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features + + if padding_mask is not None: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + num_vars = None + code_ppl = None + prob_ppl = None + curr_temp = None + + if 
self.input_quantizer: + q = self.input_quantizer(features, produce_targets=False) + features = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + features = self.project_inp(features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + if mask_indices is not None: + y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1)) + else: + y = unmasked_features + else: + x = features + y = unmasked_features + mask_indices = None + + x = self.encoder(x, padding_mask=padding_mask) + + if features_only: + return {"x": x, "padding_mask": padding_mask} + + if self.quantizer: + q = self.quantizer(y, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + + y = self.project_q(y) + + if self.negatives_from_everywhere: + neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) + negs, _ = self.sample_negatives(neg_cands, y.size(1)) + negs = self.project_q(negs) + + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + if self.codebook_negatives > 0: + cb_negs = self.quantizer.sample_from_codebook( + y.size(0) * y.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, y.size(0), y.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + negs = torch.cat([negs, cb_negs], dim=0) + else: + y = self.project_q(y) + + if self.negatives_from_everywhere: + negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + x = self.final_proj(x) + x = self.compute_preds(x, y, negs) + + result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen} 
+ + if prob_ppl is not None: + result["prob_perplexity"] = prob_ppl + result["code_perplexity"] = code_ppl + result["num_vars"] = num_vars + result["temp"] = curr_temp + + return result + + def quantize(self, x): + assert self.quantizer is not None + x = self.feature_extractor(x) + x = x.transpose(1, 2) + x = self.layer_norm(x) + return self.quantizer.forward_idx(x) + + def extract_features(self, source, padding_mask, mask=False): + res = self.forward(source, padding_mask, mask=mask, features_only=True) + return res["x"], res["padding_mask"] + + def get_logits(self, net_output): + logits = net_output["x"] + logits = logits.transpose(0, 2) + logits = logits.reshape(-1, logits.size(-1)) + return logits + + def get_targets(self, sample, net_output, expand_steps=True): + x = net_output["x"] + return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) + + def get_extra_losses(self, net_output): + pen = [] + + if "prob_perplexity" in net_output: + pen.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + + if "features_pen" in net_output: + pen.append(net_output["features_pen"]) + + return pen + + def remove_pretraining_modules(self): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), 
+ nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + for _ in range(args.encoder_layers) + ] + ) + + 
self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None): + x = self.extract_features(x, padding_mask) + + if self.layer_norm_first: + x = self.layer_norm(x) + + return x + + def extract_features(self, x, padding_mask=None): + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) + layer_results.append(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn + + +@register_model_architecture("wav2vec2", "wav2vec2") +def base_architecture(args): + args.extractor_mode = getattr(args, "extractor_mode", "default") + + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + + args.final_dim = getattr(args, "final_dim", 0) + + args.layer_norm_first = getattr(args, "layer_norm_first", False) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + + conv_feature_layers = "[(512, 10, 5)]" + conv_feature_layers += " + [(512, 8, 4)]" + conv_feature_layers += " + [(512, 4, 2)] * 3" + conv_feature_layers += " + [(512, 1, 1)]" + args.conv_feature_layers = getattr(args, "conv_feature_layers", 
conv_feature_layers) + + args.logit_temp = getattr(args, "logit_temp", 0.1) + + args.quantize_targets = getattr(args, "quantize_targets", False) + args.quantize_input = getattr(args, "quantize_input", False) + + args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) + + args.latent_vars = getattr(args, "latent_vars", 320) + args.latent_groups = getattr(args, "latent_groups", 2) + args.latent_dim = getattr(args, "latent_dim", 0) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.65) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_min_space = getattr(args, "mask_min_space", 1) + + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.65) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1) + + args.dropout_input = getattr(args, "dropout_input", 0) + args.dropout_features = getattr(args, "dropout_features", 0) + + args.num_negatives = getattr(args, "num_negatives", 100) + args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False) + args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) + args.codebook_negatives = getattr(args, "codebook_negatives", 0) + + args.conv_pos = getattr(args, "conv_pos", 128) + args.conv_pos_groups = getattr(args, "conv_pos_groups", 16) + + args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)") + + args.target_glu = getattr(args, "target_glu", False) + + args.conv_bias = getattr(args, "conv_bias", False) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py 
b/fairseq/models/wav2vec/wav2vec2_asr.py new file mode 100644 index 0000000000..f50af255a5 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -0,0 +1,679 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import checkpoint_utils, tasks, utils + +from fairseq.models import ( + FairseqEncoder, + FairseqIncrementalDecoder, + FairseqEncoderDecoderModel, + BaseFairseqModel, + register_model, + register_model_architecture, +) +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer + + +def add_common_args(parser): + parser.add_argument("--w2v-path", help="path to wav2vec 2.0 model") + parser.add_argument( + "--no-pretrained-weights", + action="store_true", + help="if true, does not load pretrained weights", + ) + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + parser.add_argument( + "--final-dropout", + type=float, + metavar="D", + help="dropout after transformer and before final projection", + ) + parser.add_argument( + "--apply-mask", action="store_true", help="apply masking during fine-tuning" + ) + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout probability inside wav2vec 2.0 model", + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights inside wav2vec 2.0 model", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN inside wav2vec 2.0 model", + ) + + parser.add_argument( + "--mask-length", type=int, help="repeat the mask indices multiple times" + ) + + 
parser.add_argument( + "--mask-prob", type=float, help="probability of replacing a token with mask" + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--no-mask-overlap", + action="store_true", + help="if set, masks are not allowed to overlap", + ) + + parser.add_argument( + "--mask-channel-length", type=int, help="repeat the mask indices multiple times" + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + action="store_true", + help="if set, channel masks are not allowed to overlap", + ) + + parser.add_argument( + "--freeze-finetune-updates", + default=0, + type=int, + help="don't finetune wav2vec for this many updates", + ) + + parser.add_argument( + "--feature-grad-mult", + default=None, + type=float, + help="reset feature grad mult in wav2vec 2.0 to this", + ) + + parser.add_argument( + "--layerdrop", + default=0.0, + type=float, + help="probability of dropping a layer in wav2vec 2.0", + ) + + +@register_model("wav2vec_ctc") +class Wav2VecCtc(BaseFairseqModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + add_common_args(parser) + + def __init__(self, w2v_encoder, args): + super().__init__() + self.w2v_encoder = w2v_encoder + self.args = args + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) +
return state_dict + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + base_architecture(args) + w2v_encoder = Wav2VecEncoder(args, task.target_dictionary) + return cls(w2v_encoder, args) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + # def max_positions(self): + # return None + + +@register_model("wav2vec_seq2seq") +class TransformerModel(FairseqEncoderDecoderModel): + def __init__(self, args, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + add_common_args(parser) + + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-layerdrop", + type=float, + metavar="D", + help="decoder layerdrop chance", + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the decoder", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--no-token-positional-embeddings", + default=False, + action="store_true", + help="if set, disables positional embeddings (outside self attention)", + ) + + parser.add_argument( + "--decoder-dropout", + type=float, + metavar="D", + 
help="dropout probability in the decoder", + ) + parser.add_argument( + "--decoder-attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights inside the decoder", + ) + parser.add_argument( + "--decoder-activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN inside the decoder", + ) + + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, "max_source_positions"): + args.max_source_positions = 2048 + if not hasattr(args, "max_target_positions"): + args.max_target_positions = 2048 + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + return emb + + decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim) + + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return TransformerModel(args, encoder, decoder) + + @classmethod + def build_encoder(cls, args): + return Wav2VecEncoder(args) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoder(args, tgt_dict, embed_tokens) + + def forward(self, **kwargs): + encoder_out = self.encoder(tbc=False, **kwargs) + decoder_out = self.decoder(encoder_out=encoder_out, **kwargs) + return decoder_out + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + +class Wav2VecEncoder(FairseqEncoder): + def __init__(self, args, tgt_dict=None): + self.apply_mask = args.apply_mask + + arg_overrides = { + "dropout": args.dropout, + "activation_dropout": args.activation_dropout, + "dropout_input": args.dropout_input, + 
"attention_dropout": args.attention_dropout, + "mask_length": args.mask_length, + "mask_prob": args.mask_prob, + "mask_selection": args.mask_selection, + "mask_other": args.mask_other, + "no_mask_overlap": args.no_mask_overlap, + "mask_channel_length": args.mask_channel_length, + "mask_channel_prob": args.mask_channel_prob, + "mask_channel_selection": args.mask_channel_selection, + "mask_channel_other": args.mask_channel_other, + "no_mask_channel_overlap": args.no_mask_channel_overlap, + "encoder_layerdrop": args.layerdrop, + "feature_grad_mult": args.feature_grad_mult, + } + + if getattr(args, "w2v_args", None) is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + args.w2v_path, arg_overrides + ) + w2v_args = state["args"] + else: + state = None + w2v_args = args.w2v_args + + assert args.normalize == w2v_args.normalize, 'Fine-tuning works best when data normalization is the same' + + w2v_args.data = args.data + task = tasks.setup_task(w2v_args) + model = task.build_model(w2v_args) + + if state is not None and not args.no_pretrained_weights: + model.load_state_dict(state["model"], strict=True) + + model.remove_pretraining_modules() + + super().__init__(task.source_dictionary) + + d = w2v_args.encoder_embed_dim + + self.w2v_model = model + + self.final_dropout = nn.Dropout(args.final_dropout) + self.freeze_finetune_updates = args.freeze_finetune_updates + self.num_updates = 0 + + if tgt_dict is not None: + self.proj = Linear(d, len(tgt_dict)) + else: + self.proj = None + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + x, padding_mask = 
self.w2v_model.extract_features(**w2v_args) + + if tbc: + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). 
+ """ + + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__(dictionary) + + self.dropout = args.decoder_dropout + self.share_input_output_embed = args.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + self.output_embed_dim = args.decoder_output_dim + args.encoder_embed_dim = embed_dim + + self.layerdrop = args.decoder_layerdrop + + padding_idx = embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + args = copy.deepcopy(args) + args.dropout = args.decoder_dropout + args.attention_dropout = args.decoder_attention_dropout + args.activation_dropout = args.decoder_activation_dropout + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) + + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + else None + ) + + if not self.share_input_output_embed: + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + 
""" + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + prev_output_tokens = prev_output_tokens.long() + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + x = self.output_layer(x) + return x, extra + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + for layer in self.layers: + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, attn, _ = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["encoder_padding_mask"] + if encoder_out is not None + else None, + incremental_state, + 
self_attn_mask=self.buffered_future_mask(x) + if incremental_state is None + else None, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": attn, "inner_states": inner_states} + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +@register_model_architecture("wav2vec_ctc", "wav2vec_ctc") +def base_architecture(args): + args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False) + args.dropout_input = getattr(args, "dropout_input", 0) + args.final_dropout = 
getattr(args, "final_dropout", 0) + args.apply_mask = getattr(args, "apply_mask", False) + args.dropout = getattr(args, "dropout", 0) + args.attention_dropout = getattr(args, "attention_dropout", 0) + args.activation_dropout = getattr(args, "activation_dropout", 0) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.5) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + + args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 0) + args.layerdrop = getattr(args, "layerdrop", 0.0) + + +@register_model_architecture("wav2vec_seq2seq", "wav2vec_seq2seq") +def seq2seq_architecture(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_layers = getattr(args, "decoder_layers", 10) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.decoder_dropout = getattr(args, "decoder_dropout", 0) + args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0) + 
args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0) + + base_architecture(args) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 94bb86880b..d526d4a92e 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -25,10 +25,12 @@ from .linearized_convolution import LinearizedConvolution from .multihead_attention import MultiheadAttention from .positional_embedding import PositionalEmbedding +from .same_pad import SamePad from .scalar_bias import ScalarBias from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer from .transformer_sentence_encoder import TransformerSentenceEncoder +from .transpose_last import TransposeLast from .unfold import unfold1d from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer from .vggblock import VGGBlock @@ -60,12 +62,14 @@ 'LinearizedConvolution', 'MultiheadAttention', 'PositionalEmbedding', + 'SamePad', 'ScalarBias', 'SinusoidalPositionalEmbedding', 'TransformerSentenceEncoderLayer', 'TransformerSentenceEncoder', 'TransformerDecoderLayer', 'TransformerEncoderLayer', + 'TransposeLast', 'VGGBlock', 'unfold1d', ] diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py index 2efc10e74b..01ddd2298b 100644 --- a/fairseq/modules/gumbel_vector_quantizer.py +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -53,7 +53,7 @@ def __init__( num_groups = groups if not combine_groups else 1 self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) - nn.init.xavier_normal_(self.vars) + nn.init.uniform_(self.vars) if weight_proj_depth > 1: @@ -70,6 +70,8 @@ def block(input_dim, output_dim): ) else: self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) assert len(temp) == 3, temp @@ 
-81,8 +83,7 @@ def set_num_updates(self, num_updates): self.curr_temp = max( self.max_temp * self.temp_decay ** num_updates, self.min_temp ) - - def codebook(self): + def get_codebook_indices(self): if self.codebook_indices is None: from itertools import product @@ -99,13 +100,36 @@ def codebook(self): for b in range(1, self.groups): self.codebook_indices[:, b] += self.num_vars * b self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + def codebook(self): + indices = self.get_codebook_indices() return ( self.vars.squeeze(0) - .index_select(0, self.codebook_indices) - .view(self.num_vars ** self.groups, -1) + .index_select(0, indices) + .view(self.num_vars ** self.groups, -1) ) + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = indices.size(0) + assert ( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = indices.new_full(indices.shape[:-1], 0) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars ** exponent) + return res + def forward_idx(self, x): res = self.forward(x, produce_targets=True) return res["x"], res["targets"] diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py new file mode 100644 index 0000000000..b46f94d635 --- /dev/null +++ b/fairseq/modules/same_pad.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from torch import nn + + +class SamePad(nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = kernel_size % 2 == 0 + + def forward(self, x): + if self.remove: + x = x[:, :, :-1] + return x diff --git a/fairseq/modules/transpose_last.py b/fairseq/modules/transpose_last.py new file mode 100644 index 0000000000..e578b3ec50 --- /dev/null +++ b/fairseq/modules/transpose_last.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +transpose last 2 dimensions of the input +""" + +import torch.nn as nn + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) diff --git a/fairseq/options.py b/fairseq/options.py index 88c0389eba..171c67966d 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -9,7 +9,7 @@ import torch -from fairseq import utils +from fairseq import scoring, utils from fairseq.data.indexed_dataset import get_available_dataset_impl @@ -361,6 +361,10 @@ def add_dataset_args(parser, train=False, gen=False): ' (e.g. 
train, valid, test)') group.add_argument('--validate-interval', type=int, default=1, metavar='N', help='validate every N epochs') + group.add_argument('--validate-interval-updates', type=int, default=0, metavar='N', + help='validate every N updates') + group.add_argument('--validate-after-updates', type=int, default=0, metavar='N', + help='do not validate until reaching this many updates') group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N', help='specified random seed for validation') group.add_argument('--disable-validation', action='store_true', @@ -529,7 +533,7 @@ def add_common_eval_args(group): # fmt: off group.add_argument('--path', metavar='FILE', help='path(s) to model file(s), colon separated') - group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, + group.add_argument('--remove-bpe', '--post-process', nargs='?', const='@@ ', default=None, help='remove BPE tokens before scoring (can be set to sentencepiece)') group.add_argument('--quiet', action='store_true', help='only print final scores') diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py new file mode 100644 index 0000000000..6e5cc287ba --- /dev/null +++ b/fairseq/scoring/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os + +from fairseq import registry + + +build_scoring, register_scoring, SCORING_REGISTRY = registry.setup_registry( + "--scoring", default="bleu" +) + + +# automatically import any Python files in the current directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.scoring."
+ module) diff --git a/fairseq/bleu.py b/fairseq/scoring/bleu.py similarity index 67% rename from fairseq/bleu.py rename to fairseq/scoring/bleu.py index 7f053bb853..40f3440d82 100644 --- a/fairseq/bleu.py +++ b/fairseq/scoring/bleu.py @@ -7,11 +7,14 @@ import math import torch +from fairseq.scoring import register_scoring + try: from fairseq import libbleu except ImportError as e: import sys - sys.stderr.write('ERROR: missing libbleu.so. run `pip install --editable .`\n') + + sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n") raise e @@ -20,22 +23,24 @@ class BleuStat(ctypes.Structure): _fields_ = [ - ('reflen', ctypes.c_size_t), - ('predlen', ctypes.c_size_t), - ('match1', ctypes.c_size_t), - ('count1', ctypes.c_size_t), - ('match2', ctypes.c_size_t), - ('count2', ctypes.c_size_t), - ('match3', ctypes.c_size_t), - ('count3', ctypes.c_size_t), - ('match4', ctypes.c_size_t), - ('count4', ctypes.c_size_t), + ("reflen", ctypes.c_size_t), + ("predlen", ctypes.c_size_t), + ("match1", ctypes.c_size_t), + ("count1", ctypes.c_size_t), + ("match2", ctypes.c_size_t), + ("count2", ctypes.c_size_t), + ("match3", ctypes.c_size_t), + ("count3", ctypes.c_size_t), + ("match4", ctypes.c_size_t), + ("count4", ctypes.c_size_t), ] +@register_scoring("sacrebleu") class SacrebleuScorer(object): - def __init__(self): + def __init__(self, *unused): import sacrebleu + self.sacrebleu = sacrebleu self.reset() @@ -58,6 +63,7 @@ def result_string(self, order=4): return self.sacrebleu.corpus_bleu(self.sys, [self.ref]).format() +@register_scoring("bleu") class Scorer(object): def __init__(self, pad, eos, unk): self.stat = BleuStat() @@ -74,11 +80,9 @@ def reset(self, one_init=False): def add(self, ref, pred): if not isinstance(ref, torch.IntTensor): - raise TypeError('ref must be a torch.IntTensor (got {})' - .format(type(ref))) + raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref))) if not isinstance(pred, torch.IntTensor): - raise 
TypeError('pred must be a torch.IntTensor(got {})' - .format(type(pred))) + raise TypeError("pred must be a torch.IntTensor(got {})".format(type(pred))) # don't match unknown words rref = ref.clone() @@ -95,11 +99,13 @@ def add(self, ref, pred): ctypes.c_size_t(pred.size(0)), ctypes.c_void_p(pred.data_ptr()), ctypes.c_int(self.pad), - ctypes.c_int(self.eos)) + ctypes.c_int(self.eos), + ) def score(self, order=4): - psum = sum(math.log(p) if p > 0 else float('-Inf') - for p in self.precision()[:order]) + psum = sum( + math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order] + ) return self.brevity() * math.exp(psum / order) * 100 def precision(self): @@ -119,11 +125,17 @@ def brevity(self): def result_string(self, order=4): assert order <= 4, "BLEU scores for order > 4 aren't supported" - fmt = 'BLEU{} = {:2.2f}, {:2.1f}' + fmt = "BLEU{} = {:2.2f}, {:2.1f}" for _ in range(1, order): - fmt += '/{:2.1f}' - fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})' + fmt += "/{:2.1f}" + fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})" bleup = [p * 100 for p in self.precision()[:order]] - return fmt.format(order, self.score(order=order), *bleup, - self.brevity(), self.stat.predlen/self.stat.reflen, - self.stat.predlen, self.stat.reflen) + return fmt.format( + order, + self.score(order=order), + *bleup, + self.brevity(), + self.stat.predlen / self.stat.reflen, + self.stat.predlen, + self.stat.reflen + ) diff --git a/fairseq/scoring/scoring_utils.py b/fairseq/scoring/scoring_utils.py new file mode 100644 index 0000000000..0b710d5bb8 --- /dev/null +++ b/fairseq/scoring/scoring_utils.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from . 
import bleu, build_scoring + + +def build_scorer(args, tgt_dict): + if args.sacrebleu: + utils.deprecation_warning( + "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." + ) + args.scoring = "sacrebleu" + + if args.scoring == "bleu": + scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) + else: + return build_scoring(args) + + return scorer diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py new file mode 100644 index 0000000000..6f4521f6cd --- /dev/null +++ b/fairseq/scoring/wer.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import editdistance + +from fairseq.scoring import register_scoring + + +@register_scoring("wer") +class WerScorer(object): + def __init__(self, *unused): + self.reset() + + def reset(self): + self.distance = 0 + self.target_length = 0 + + def add_string(self, ref, pred): + ref_items = ref.split() + pred_items = pred.split() + self.distance += editdistance.eval(ref_items, pred_items) + self.target_length += len(ref_items) + + def result_string(self): + return f"WER: {self.score()}" + + def score(self): + return ( + 100.0 * self.distance / self.target_length if self.target_length > 0 else 0 + ) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index ed65bd86b6..42012fbbb1 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -176,11 +176,17 @@ def _generate( ], ) net_input = sample["net_input"] - src_tokens = net_input["src_tokens"] - # length of the source text being the character length except EndOfSentence and pad - src_lengths = ( - (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) - ) + + if 'src_tokens' in net_input: + src_tokens = net_input['src_tokens'] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = 
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + elif 'source' in net_input: + src_tokens = net_input['source'] + src_lengths = net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1) if net_input['padding_mask'] is not None else torch.tensor(src_tokens.size(-1)) + else: + raise Exception('expected src_tokens or source in net input') + # bsz: total number of sentences in beam input_size = src_tokens.size() bsz, src_len = input_size[0], input_size[1] diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index e161c224e9..f33637468f 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -1,15 +1,28 @@ -# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. # -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. import os +import sys -from fairseq.data import FileAudioDataset +from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset from . import FairseqTask, register_task -@register_task('audio_pretraining') +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + + +@register_task("audio_pretraining") class AudioPretrainingTask(FairseqTask): """ @@ -18,16 +31,53 @@ class AudioPretrainingTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', help='path to data directory') - parser.add_argument('--sample-rate', default=16000, type=int, - help='target sample rate. 
audio files will be up/down sampled to this rate') - parser.add_argument('--max-sample-size', default=None, type=int, - help='max sample size to crop to for batching. default = min sample length') - parser.add_argument('--min-sample-size', default=None, type=int, - help='min sample size to crop to for batching. default = same as --max-sample-size') - - def __init__(self, args): + parser.add_argument("data", help="path to data directory") + parser.add_argument( + "--sample-rate", + default=16000, + type=int, + help="target sample rate. audio files will be up/down sampled to this rate", + ) + parser.add_argument( + "--normalize", + action="store_true", + help="if set, normalizes input to have 0 mean and unit variance", + ) + parser.add_argument( + "--max-sample-size", + default=None, + type=int, + help="max sample size to crop to for batching. default = min sample length", + ) + parser.add_argument( + "--min-sample-size", + default=None, + type=int, + help="min sample size to crop to for batching. 
default = same as --max-sample-size", + ) + + parser.add_argument( + "--enable-padding", + action="store_true", + help="pad shorter samples instead of cropping", + ) + + parser.add_argument( + "--no-min-cropping", action="store_true", help="always crop to max sample size or smallest length" + ) + + parser.add_argument( + "--labels", + type=str, + default=None, + help="extension of the label file to load, if any", + ) + + def __init__(self, args, source_dictionary=None): super().__init__(args) + self._target_dictionary = None + self._source_dictionary = source_dictionary + self.is_ctc = args.criterion == "ctc" @classmethod def setup_task(cls, args, **kwargs): @@ -44,15 +94,48 @@ def load_dataset(self, split, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ + manifest = os.path.join(self.args.data, "{}.tsv".format(split)) + self.datasets[split] = FileAudioDataset( + manifest, + sample_rate=self.args.sample_rate, + max_sample_size=self.args.max_sample_size, + min_sample_size=self.args.min_sample_size if not self.args.no_min_cropping else self.args.max_sample_size, + min_length=self.args.min_sample_size, + pad=self.args.labels is not None or self.args.enable_padding, + normalize=self.args.normalize, + ) + + if self.args.labels: + dict_path = os.path.join(self.args.data, f"dict.{self.args.labels}.txt") + self._target_dictionary = Dictionary.load(dict_path) + label_path = os.path.join(self.args.data, f"{split}.{self.args.labels}") + labels = [] + with open(label_path, "r") as f: + for line in f: + labels.append(line) + + process_label = LabelEncoder(self.target_dictionary) - manifest = os.path.join(self.args.data, '{}.tsv'.format(split)) - self.datasets[split] = FileAudioDataset(manifest, - sample_rate=self.args.sample_rate, - max_sample_size=self.args.max_sample_size, - min_sample_size=self.args.min_sample_size) + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + 
eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + add_to_input=not self.is_ctc, + ) + + @property + def source_dictionary(self): + return self._source_dictionary @property def target_dictionary(self): """Return the :class:`~fairseq.data.Dictionary` for the language model.""" - return None + return self._target_dictionary + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return (sys.maxsize, sys.maxsize) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index fda3cdc9c0..cf472ff252 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -16,7 +16,7 @@ import torch -from fairseq import bleu, checkpoint_utils, options, tasks, utils +from fairseq import checkpoint_utils, options, scoring, tasks, utils from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.data import encoders @@ -136,11 +136,8 @@ def decode_fn(x): x = tokenizer.decode(x) return x - # Generate and compute BLEU score - if args.sacrebleu: - scorer = bleu.SacrebleuScorer() - else: - scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) + scorer = scoring.scoring_utils.build_scorer(args, tgt_dict) + num_sentences = 0 has_target = True wps_meter = TimeMeter() @@ -162,7 +159,11 @@ def decode_fn(x): has_target = sample['target'] is not None # Remove padding - src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) + if 'src_tokens' in sample['net_input']: + src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) + else: + src_tokens = None + target_tokens = None if has_target: target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu() @@ -255,7 +256,7 @@ def decode_fn(x): wps_meter.update(num_generated_tokens) progress.log({'wps': round(wps_meter.avg)}) - num_sentences += sample['nsentences'] + num_sentences += sample["nsentences"] if "nsentences" in 
sample else sample['id'].numel() logger.info('NOTE: hypothesis and token scores are output in base 2') logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py index f7c3dc42b9..59631c2d65 100644 --- a/fairseq_cli/score.py +++ b/fairseq_cli/score.py @@ -11,7 +11,7 @@ import os import sys -from fairseq import bleu +from fairseq.scoring import bleu from fairseq.data import dictionary diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 750a54bfee..806e4bc54b 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -252,10 +252,12 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc args.save_interval_updates > 0 and num_updates > 0 and num_updates % args.save_interval_updates == 0 + and num_updates >= args.validate_after_updates ) or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + or (args.validate_interval_updates > 0 and num_updates % args.validate_interval_updates == 0) ) and not args.disable_validation # Validate diff --git a/setup.py b/setup.py index 7e2358c7a4..a309b90bd1 100644 --- a/setup.py +++ b/setup.py @@ -136,6 +136,7 @@ def include_dirs(self, dirs): install_requires=[ 'cffi', 'cython', + 'editdistance', 'numpy', 'regex', 'sacrebleu', From aa62039d463d95767b69bd1b85e5694a7b2a2d40 Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 4 Aug 2020 18:10:34 -0700 Subject: [PATCH 092/707] fix wav2vec docs (#1236) Summary: fixes wav2vec 2.0 and wav2vec docs (incl issue #2418) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1236 Reviewed By: ngoyal2707 Differential Revision: D22934517 Pulled By: alexeib fbshipit-source-id: aaffd05c5e6d22cf4b5d912ddaa170530b65b378 --- examples/wav2vec/README.md | 25 ++++++++++++++----------- 1 file 
changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index d90dedf22e..e7f8633afb 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -6,15 +6,15 @@ wav2vec 2.0 learns speech representations on unlabeled data as described in [wav Model | Finetuning split | Dataset | Model |---|---|---|--- -Wav2Vec 2.0 Base | - | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) +Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) -Wav2Vec 2.0 Large | - | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) +Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) -Wav2Vec 2.0 Large (LV-60) | - | [Libri-Light](https://github.com/facebookresearch/libri-light) | 
[download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox.pt) +Wav2Vec 2.0 Large (LV-60) | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox.pt) Wav2Vec 2.0 Large (LV-60) | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m.pt) Wav2Vec 2.0 Large (LV-60) | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h.pt) Wav2Vec 2.0 Large (LV-60) | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h.pt) @@ -25,7 +25,7 @@ Given a directory containing wav files to be used for pretraining (we recommend ### Prepare training data manifest: -$ext should be set t flac or wav, or whatever format your dataset happens to use that soundfile can read +$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. $valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. 
To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a @@ -37,6 +37,8 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa ### Train a wav2vec 2.0 base model: +This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper + Note that this was tested with pytorch 1.4.0 and the input is expected to be single channel, sampled at 16 kHz ```shell script @@ -57,7 +59,7 @@ Note: you can simulate 64 GPUs by using k GPUs and setting --update-freq 64/k ### Train a wav2vec 2.0 large model: -This configuration was used for model trained on the Libri-light dataset in the paper wav2vec 2.0 paper +This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper ```shell script $ python train.py --distributed-world-size 128 --distributed-port $PORT /manifest/path \ @@ -79,8 +81,8 @@ Note: you can simulate 128 GPUs by using k GPUs and setting --update-freq 128/k ### Fine-tune a pre-trained model with CTC: Fine-tuning a model requires parallel audio and labels file, as well as a vocabulary file in fairseq format. -A letter vocabulary is can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). -An example script that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: +A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). 
+An example [script](libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: ```shell script split=train @@ -104,7 +106,7 @@ python train.py --distributed-world-size 24 --distributed-port $PORT /path/to/tr Note: you can simulate 24 GPUs by using k GPUs and setting --update-freq 24/k -Note that decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). +Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). Alternatively, simply omit the --wer-args flag. ### Evaluating a CTC model: @@ -116,13 +118,14 @@ Be sure to upper-case the language model vocab after downloading it. Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). -Next run the evaluation command: +Next, run the evaluation command: ```shell script $subset=dev_other python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \ --nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ ---lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 --post-process letter +--lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \ +--post-process letter ``` To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. 
@@ -170,7 +173,7 @@ $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 - --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ --skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \ ---max-sample-size 150000 --max-tokens 1500000 ---skip-invalid-size-inputs-valid-test +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test ``` ### Extract embeddings from the downstream task data: From 54baa2e72d1228ca7d40b78f676ed782285ae6c7 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 5 Aug 2020 02:25:33 -0700 Subject: [PATCH 093/707] fix validation bug (#1237) Summary: clone unmasked features so validation actually works Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1237 Reviewed By: HenryZhou7 Differential Revision: D22942085 Pulled By: alexeib fbshipit-source-id: aeb6519bf2df8b52fde315503043347e7252ebb0 --- fairseq/models/wav2vec/wav2vec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 226f035ba8..342bc5da3f 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -532,7 +532,7 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): features = features.transpose(1, 2) features = self.layer_norm(features) - unmasked_features = features + unmasked_features = features.clone() if padding_mask is not None: extra = padding_mask.size(1) % features.size(1) From 9bcf326093b94cf063aca8c05040f15c6f731341 Mon Sep 17 00:00:00 2001 From: Mandeep Baines Date: Wed, 5 Aug 2020 15:23:41 -0700 Subject: 
[PATCH 094/707] fix LegacyDDP to work with gpipe (#1213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Neither c10d nor no_c10d works with pipeline parallelism and gradient checkpointing. Added a minimal set of changes that get no_c10d working with gpipe. 1. use per-GPU batch buffers (this was not sufficient) 2. simplify when all_reduce happens via an explicit all_reduce call after backward Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1213 Test Plan: fairseq-train --task language_modeling data-bin/wikitext-103 --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm --share-decoder-input-output-embed --dropout 0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 --tokens-per-sample 512 --sample-break-mode none --max-tokens 2048 --update-freq 16 --max-update 50000 --memory-efficient-fp16 --no-progress-bar --log-interval 1 --seed 4 --max-epoch 1 --max-update 50 --ddp-backend no_c10d Before: 2020-07-23 20:22:54 | INFO | train_inner | epoch 001: 50 / 394 loss=18.787, ppl=452312, wps=263610, ups=1.01, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=66 2020-07-23 20:22:55 | INFO | train_inner | epoch 001: 51 / 394 loss=18.74, ppl=437735, wps=259423, ups=0.99, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=67 2020-07-23 20:22:56 | INFO | train_inner | epoch 001: 52 / 394 loss=18.683, ppl=420727, wps=256937, ups=0.98, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=68 2020-07-23 20:22:57 | INFO | train_inner | epoch 001: 53 / 394 loss=18.623, ppl=403792, wps=260323, ups=0.99, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=69 2020-07-23 20:22:58 | INFO | train_inner | 
epoch 001: 54 / 394 loss=18.574, ppl=390255, wps=258095, ups=0.98, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=70 After: 2020-07-23 20:20:14 | INFO | train_inner | epoch 001: 50 / 394 loss=18.787, ppl=452312, wps=268872, ups=1.03, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=66 2020-07-23 20:20:15 | INFO | train_inner | epoch 001: 51 / 394 loss=18.74, ppl=437735, wps=263022, ups=1, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=67 2020-07-23 20:20:16 | INFO | train_inner | epoch 001: 52 / 394 loss=18.683, ppl=420727, wps=265523, ups=1.01, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=68 2020-07-23 20:20:17 | INFO | train_inner | epoch 001: 53 / 394 loss=18.623, ppl=403792, wps=259082, ups=0.99, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=69 2020-07-23 20:20:18 | INFO | train_inner | epoch 001: 54 / 394 loss=18.574, ppl=390255, wps=250407, ups=0.95, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=70 # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Reviewed By: myleott Differential Revision: D22488572 Pulled By: msbaines fbshipit-source-id: a6ab5185fa69c5e0508cdca469e692d1caa8add1 --- fairseq/legacy_distributed_data_parallel.py | 94 +++++++++++---------- fairseq/trainer.py | 3 + 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/legacy_distributed_data_parallel.py index 2e9c82539b..a8840dd4d1 100644 --- a/fairseq/legacy_distributed_data_parallel.py +++ b/fairseq/legacy_distributed_data_parallel.py @@ -14,6 +14,7 @@ training with `--update-freq`. """ +from collections import OrderedDict from contextlib import contextmanager import copy @@ -60,11 +61,15 @@ def __init__(self, module, world_size, process_group=None, buffer_size=2**28): # all-reduce at some later time self.accumulate_grads = False - # For NCCL backend, since every single NCCL call is asynchoronous, we - # therefore directly enqueue all the NCCL reduction calls to the - # default CUDA stream without spawning up other reduction threads. - # This achieves the best performance. 
- self._register_grad_hook() + # make per-device lists of parameters + paramlists = OrderedDict() + for param in self.module.parameters(): + device = param.device + if paramlists.get(device) is None: + paramlists[device] = [] + paramlists[device] += [param] + self.per_device_params = list(paramlists.values()) + def __getstate__(self): attrs = copy.copy(self.__dict__) @@ -72,7 +77,6 @@ def __getstate__(self): def __setstate__(self, state): super().__setstate__(state) - self._register_grad_hook() @contextmanager def no_sync(self): @@ -83,17 +87,22 @@ def no_sync(self): self.accumulate_grads = old_accumulate_grads def forward(self, *inputs, **kwargs): + if self.need_reduction: + raise RuntimeError( + 'LegacyDistributedDataParallel requires explicit reduction, ' + 'must call LegacyDistributedDataParallel.all_reduce' + ) + if not self.accumulate_grads: + self.need_reduction = True return self.module(*inputs, **kwargs) - def _register_grad_hook(self): + def all_reduce(self): """ - This function registers the callback all-reduction function for the - NCCL backend. All gradients will be all reduced in one single step. - The NCCL reduction will directly be enqueued into the default CUDA - stream. Therefore, no synchronization is needed. + This function must be called explicitly after backward to reduce + gradients. There is no automatic hook like c10d. 
""" - def all_reduce(params): + def all_reduce_params(params): buffer = self.buffer nonzero_buffer = False if len(params) > 1: @@ -142,39 +151,32 @@ def reduction_fn(): if self.buffer is None: self.buffer = next(self.module.parameters()).new(self.buffer_size) - # All-reduce the gradients in buckets - offset = 0 - buffered_params = [] - for param in self.module.parameters(): - if not param.requires_grad: - continue - if param.grad is None: - param.grad = torch.zeros_like(param) - if param.grad.requires_grad: - raise RuntimeError("DistributedDataParallel only works " - "with gradients that don't require " - "grad") - sz = param.numel() - if sz > self.buffer.numel(): - # all-reduce big params directly - all_reduce([param]) - else: - if offset + sz > self.buffer.numel(): - all_reduce(buffered_params) - offset = 0 - buffered_params.clear() - buffered_params.append(param) - offset += sz - - if len(buffered_params) > 0: - all_reduce(buffered_params) - - # Now register the reduction hook on the parameters - for p in self.module.parameters(): + for params in self.per_device_params: + # All-reduce the gradients in buckets + offset = 0 + buffered_params = [] + for param in params: + if not param.requires_grad: + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + if param.grad.requires_grad: + raise RuntimeError("DistributedDataParallel only works " + "with gradients that don't require " + "grad") + sz = param.numel() + if sz > self.buffer.numel(): + # all-reduce big params directly + all_reduce_params([param]) + else: + if offset + sz > self.buffer.numel(): + all_reduce_params(buffered_params) + offset = 0 + buffered_params.clear() + buffered_params.append(param) + offset += sz - def allreduce_hook(*unused): - self.need_reduction = True - Variable._execution_engine.queue_callback(reduction_fn) + if len(buffered_params) > 0: + all_reduce_params(buffered_params) - if p.requires_grad: - p.register_hook(allreduce_hook) + reduction_fn() diff --git 
a/fairseq/trainer.py b/fairseq/trainer.py index a91d12fdc2..898edb6d6c 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -484,6 +484,9 @@ def maybe_no_sync(): ) self._cumulative_training_time = total_train_time / self.data_parallel_world_size + if hasattr(self.model, 'all_reduce'): + self.model.all_reduce() + overflow = False try: if self.tpu and self.data_parallel_world_size > 1: From 8aa06aa03b596de58d106d3f55ff43e2b9aa0b80 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Wed, 5 Aug 2020 18:16:08 -0700 Subject: [PATCH 095/707] Enable different directions to have different number of shards Summary: When different directions have very different sizes of data, they will be sharded into different number of shards while keeping the size of each shard roughly the same. This diff enables multilingual sharding mechanism to deal with it. * small datasets can have only one shard; and they will only appear in the first shard * larger dataset will continue to fill in the shards incrementally * the number of shard of the whole combined dataset is defined by the number of shards of the largest direction dataset * temperature-based sampling will be based on an approximation of data sizes: [size(d_i * num_shard_i) for i in range(num_data_sets)] Reviewed By: pipibjc Differential Revision: D22885256 fbshipit-source-id: 87e4f78191dd9f1ee56bc5aba8dbd94e078a9ada --- .../multilingual/multilingual_data_manager.py | 90 ++++++++++++++++--- 1 file changed, 78 insertions(+), 12 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 098fc6fcba..35830fb78d 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -72,7 +72,7 @@ def __init__(self, args, lang_pairs, langs, dicts, sampling_method): self.sampling_method = sampling_method self.sampling_scheduler = None self._has_sharded_data = False - self._num_shards = {} + 
self._num_shards_dict = {} @classmethod def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): @@ -233,7 +233,7 @@ def check_langs(langs, pairs): dicts = OrderedDict() supported_langtok_specs = args.langtoks_specs for lang in langs_to_load_dicts: - paths = args.data.split(os.pathsep) + paths = utils.split_paths(args.data) assert len(paths) > 0 dicts[lang] = load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) if len(dicts) > 0: @@ -327,6 +327,29 @@ def mono_split_exists(cls, split, lang, data_path, dataset_impl): filename = os.path.join(data_path, '{}.{}'.format(split, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + @classmethod + def bitext_split_exists(cls, split, src, tgt, data_path, dataset_impl): + src_exists = cls.split_exists(split, src, tgt, lang=src, data_path=data_path, dataset_impl=dataset_impl) \ + or cls.split_exists(split, tgt, src, lang=src, data_path=data_path, dataset_impl=dataset_impl) + + tgt_exists = cls.split_exists(split, src, tgt, lang=tgt, data_path=data_path, dataset_impl=dataset_impl) \ + or cls.split_exists(split, tgt, src, lang=tgt, data_path=data_path, dataset_impl=dataset_impl) + return src_exists and tgt_exists + + @classmethod + def get_split_num_shards(cls, split, src, tgt, data_paths, dataset_impl): + return sum( + 1 for path in data_paths + if cls.bitext_split_exists(split, src, tgt, path, dataset_impl) + ) + + @classmethod + def get_mono_split_num_shards(cls, split, lang, data_paths, dataset_impl): + return sum( + 1 for path in data_paths + if cls.mono_split_exists(split, lang, path, dataset_impl) + ) + def load_lang_dataset( self, data_path, split, @@ -607,13 +630,49 @@ def get_data_paths_and_lang_pairs(self, split): lang_pairs.update(extra_lang_pairs) return datapaths, lang_pairs + @classmethod + def get_dataset_key(cls, data_category, src, tgt): + return f'{data_category}:{src}-{tgt}' + + def get_split_num_data_shards(self, split): + if split in 
self._num_shards_dict: + return self._num_shards_dict[split] + num_shards_dict = {} + data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) + + for data_category, paths in data_paths.items(): + if data_category not in lang_pairs: + continue + paths = utils.split_paths(paths) + lang_dirs = [lang_pair.split('-') for lang_pair in lang_pairs[data_category]] + lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] + for src, tgt in lang_dirs: + # monolingual data requires tgt only + assert src is not None or 'mono_' in data_category, (f'error: src={src}, ' + f'tgt={tgt} for data_category={data_category}') + key = self.get_dataset_key(data_category, src, tgt) + if 'mono_' in data_category: + num_shards_dict[key] = self.get_mono_split_num_shards( + split, tgt, paths, self.args.dataset_impl) + else: + num_shards_dict[key] = self.get_split_num_shards( + split, src, tgt, paths, self.args.dataset_impl) + self._num_shards_dict[split] = num_shards_dict + logger.info(f"[{split}] num of shards: {num_shards_dict}") + return num_shards_dict + + def get_split_data_path(self, paths, epoch, shard_epoch, num_shards): + shard = epoch if shard_epoch is None else shard_epoch + shard = (shard - 1) % num_shards + path = paths[shard] + return path + def get_split_data_param_list(self, split, epoch, shard_epoch=None): - def get_epoch(epoch, shard_epoch): - return epoch if shard_epoch is None else shard_epoch # TODO: to extend with extra datasets and keys and loop over different shard data paths param_list = [] data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) logger.info(f'langtoks settings: {self.args.langtoks}') + split_num_shards_dict = self.get_split_num_data_shards(split) for data_category, paths in data_paths.items(): if data_category not in lang_pairs: continue @@ -621,9 +680,7 @@ def get_epoch(epoch, shard_epoch): assert len(paths) > 0 if len(paths) > 1: self._has_sharded_data = True - self._num_shards[data_category] = len(paths) - # epoch 
starts with 1 now: - data_path = paths[(get_epoch(epoch, shard_epoch) - 1) % len(paths)] + if data_category in self.args.langtoks: lang_tok_spec = self.args.langtoks[data_category] else: @@ -637,9 +694,12 @@ def get_epoch(epoch, shard_epoch): assert src is not None or data_category == 'mono_dae', (f'error: src={src}, ' 'tgt={tgt} for data_category={data_category}') # logger.info(f"preparing param for {data_category}: {src} - {tgt}") + key = self.get_dataset_key(data_category, src, tgt) + data_path = self.get_split_data_path( + paths, epoch, shard_epoch, split_num_shards_dict[key]) param_list.append( { - 'key': f'{data_category}:{src}-{tgt}', + 'key': key, 'data_path': data_path, 'split': split, 'src': src, @@ -652,8 +712,15 @@ def get_epoch(epoch, shard_epoch): ) return param_list - def get_train_sampling_ratios(self, datasets, epoch=1): - data_sizes = [len(d) for _, d in datasets] + def get_train_dataset_sizes(self, data_param_list, datasets): + num_shards = [ + self.get_split_num_data_shards(param['split'])[param['key']] for param in data_param_list] + data_sizes = [(key, len(d) * num_shard) for (key, d), num_shard in zip(datasets, num_shards)] + logger.info(f'data sizes multiplied by num_shards used in sampling ratios: {data_sizes}') + return [s for _, s in data_sizes] + + def get_train_sampling_ratios(self, data_param_list, datasets, epoch=1): + data_sizes = self.get_train_dataset_sizes(data_param_list, datasets) sampling_func = self.sampling_method.sampling_method_selector() sample_ratios = sampling_func(data_sizes) if sampling_func is not None else None return sample_ratios @@ -667,8 +734,7 @@ def get_sampling_ratios(self, data_param_list, datasets, epoch): elif self.args.sampling_weights: sample_ratios = [self.args.sampling_weights[k] for k, _ in datasets] else: - # TODO: modify to provide sampling function more information other than sizes - sample_ratios = self.get_train_sampling_ratios(datasets, epoch) + sample_ratios = 
self.get_train_sampling_ratios(data_param_list, datasets, epoch) if sample_ratios is not None: logger.info('| Upsample ratios: {}'.format( From 0bb7bc3777b880c282df794fb7edb56d7280449b Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Thu, 6 Aug 2020 10:18:52 -0700 Subject: [PATCH 096/707] Multilingual v1: Multilingual Training with multiple bitext and monolingual datasets: add finetuning options Summary: A first version of XLNMT multilingual project code release: Multilingual Training with multiple bitext - Minor changes to - fairseq/checkpoint_utils.py to add finetuning option instead of using restore_file which will restore from original model when being requeued. Reviewed By: myleott Differential Revision: D22483494 fbshipit-source-id: 733300fd6a4d185e561c793ea668047c96f616c6 --- fairseq/checkpoint_utils.py | 45 ++++++++++++++++---- fairseq/options.py | 3 ++ tests/test_train.py | 83 ++++++++++++++++++++++++++++++++++--- 3 files changed, 117 insertions(+), 14 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 2d37e3fc31..af21db929f 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -122,32 +122,61 @@ def load_checkpoint(args, trainer, **passthrough_args): *passthrough_args* will be passed through to ``trainer.get_train_iterator``. 
""" + reset_optimizer = args.reset_optimizer + reset_lr_scheduler = args.reset_lr_scheduler + optimizer_overrides = eval(args.optimizer_overrides) + reset_meters = args.reset_meters + reset_dataloader = args.reset_dataloader + + if getattr(args, 'finetune_from_model', None) is not None \ + and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader): + raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader") suffix = getattr(args, "checkpoint_suffix", "") - if args.restore_file == "checkpoint_last.pt": + if args.restore_file == "checkpoint_last.pt": # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix)) + first_launch = not PathManager.exists(checkpoint_path) + if getattr(args, 'finetune_from_model', None) is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
+ if PathManager.exists(args.finetune_from_model): + checkpoint_path = args.finetune_from_model + reset_optimizer = True + reset_lr_scheduler = True + reset_meters = True + reset_dataloader = True + logger.info(f'loading pretrained model from {checkpoint_path}: ' + 'optimizer, lr scheduler, meters, dataloader will be reset') + else: + raise ValueError(f'--finetune-from-model {args.finetune_from_model} does not exist') elif getattr(args, "model_parallel_size", 1) > 1: checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") else: checkpoint_path = args.restore_file + if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None): + raise ValueError( + '--finetune-from-model and --restore-file (non-default value) ' + 'can not be specified together: ' + str(args)) + extra_state = trainer.load_checkpoint( checkpoint_path, - args.reset_optimizer, - args.reset_lr_scheduler, - eval(args.optimizer_overrides), - reset_meters=args.reset_meters, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + reset_meters=reset_meters, ) if ( extra_state is not None and "best" in extra_state - and not args.reset_optimizer - and not args.reset_meters + and not reset_optimizer + and not reset_meters ): save_checkpoint.best = extra_state["best"] - if extra_state is not None and not args.reset_dataloader: + if extra_state is not None and not reset_dataloader: # restore iterator from checkpoint itr_state = extra_state["train_iterator"] epoch_itr = trainer.get_train_iterator( diff --git a/fairseq/options.py b/fairseq/options.py index 171c67966d..e889821ee6 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -489,6 +489,9 @@ def add_checkpoint_args(parser): group.add_argument('--restore-file', default='checkpoint_last.pt', help='filename from which to load checkpoint ' '(default: /checkpoint_last.pt') + group.add_argument('--finetune-from-model', default=None, type=str, + help='finetune from a pretrained model; ' + 'note that meters 
and lr scheduler will be reset') group.add_argument('--reset-dataloader', action='store_true', help='if set, does not reload dataloader state from the checkpoint') group.add_argument('--reset-lr-scheduler', action='store_true', diff --git a/tests/test_train.py b/tests/test_train.py index fb935461c8..5be74e415d 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -49,20 +49,28 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc return trainer, epoch_itr +def get_mock_args(finetune_from_model=None): + args_mock = MagicMock() + args_mock.optimizer_overrides = '{}' + args_mock.reset_dataloader = False + args_mock.reset_meters = False + args_mock.reset_optimizer = False + args_mock.reset_lr_scheduler = False + args_mock.finetune_from_model = finetune_from_model + args_mock.model_parallel_size = 1 + return args_mock + + class TestLoadCheckpoint(unittest.TestCase): def setUp(self): - self.args_mock = MagicMock() - self.args_mock.optimizer_overrides = '{}' - self.args_mock.reset_dataloader = False - self.args_mock.reset_meters = False - self.args_mock.reset_optimizer = False - self.args_mock.model_parallel_size = 1 + self.args_mock = get_mock_args() self.patches = { 'os.makedirs': MagicMock(), 'os.path.join': MagicMock(), 'os.path.isfile': MagicMock(return_value=True), 'os.path.isabs': MagicMock(return_value=False), + 'fairseq.file_io.PathManager.exists': MagicMock(return_value=False), } self.applied_patches = [patch(p, d) for p, d in self.patches.items()] [p.start() for p in self.applied_patches] @@ -121,6 +129,69 @@ def test_load_no_checkpoint(self): self.assertEqual(epoch_itr.iterations_in_epoch, 0) self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) + def test_finetune_from_model_args_conflict(self): + with contextlib.redirect_stdout(StringIO()): + trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) + trainer.get_train_iterator = MagicMock(return_value=epoch_itr) + + for arg in ['reset_optimizer', 
'reset_lr_scheduler', 'reset_meters', 'reset_dataloader']: + with self.subTest(arg=arg): + args_mock = get_mock_args("/temp/checkpoint_pretrained.pt") + setattr(args_mock, arg, True) + with self.assertRaises(Exception) as context: + _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) + + self.assertTrue( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" in str(context.exception) + ) + + def test_finetune_from_model(self): + with contextlib.redirect_stdout(StringIO()): + trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) + trainer.get_train_iterator = MagicMock(return_value=epoch_itr) + from_model_path = "/temp/checkpoint_pretrained.pt" + args_mock = get_mock_args(from_model_path) + args_mock.restore_file = "checkpoint_last.pt" + + def mock_finetune_exist(path): + if path == from_model_path: + return True + else: + return False + self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist + _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) + checkpoint_path, reset_optimizer, reset_lr_scheduler, \ + optimizer_overrides = trainer.load_checkpoint.call_args[0] + reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters'] + self.assertTrue(reset_optimizer) + self.assertTrue(reset_lr_scheduler) + self.assertTrue(reset_meters) + + def test_finetune_from_model_resume(self): + with contextlib.redirect_stdout(StringIO()): + trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) + trainer.get_train_iterator = MagicMock(return_value=epoch_itr) + from_model_path = "/temp/checkpoint_pretrained.pt" + args_mock = get_mock_args(from_model_path) + args_mock.restore_file = "checkpoint_last.pt" + + # launch second time + # both restore_file=checkpoint_last.pt and finetune_from_model are set + def mock_finetune_exist(path): + if path == from_model_path or path.endswith('checkpoint_last.pt'): + return True + else: + return False + 
self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist + _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) + checkpoint_path, reset_optimizer, reset_lr_scheduler, \ + optimizer_overrides = trainer.load_checkpoint.call_args[0] + reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters'] + self.assertFalse(reset_optimizer) + self.assertFalse(reset_lr_scheduler) + self.assertFalse(reset_meters) + def tearDown(self): patch.stopall() From 631023200a8e0c8531ccdb9e5d7f54411149efdf Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 7 Aug 2020 14:17:45 -0700 Subject: [PATCH 097/707] Fix LegacyDistributedDataParallel (#2447) Summary: The "need_reduction" check was to help any legacy callers who might not be using the fairseq Trainer. Unfortunately the mechanism doesn't work during validation, since there we call forward multiple times without calling backward. Let's remove the check -- pretty much everyone should be using fairseq Trainer. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2447 Reviewed By: cndn, msbaines Differential Revision: D23005086 Pulled By: myleott fbshipit-source-id: 3a58342024096f2f3abc7bf1cdcce75e56a06aa0 --- fairseq/legacy_distributed_data_parallel.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/legacy_distributed_data_parallel.py index a8840dd4d1..9832f2c97a 100644 --- a/fairseq/legacy_distributed_data_parallel.py +++ b/fairseq/legacy_distributed_data_parallel.py @@ -53,10 +53,6 @@ def __init__(self, module, world_size, process_group=None, buffer_size=2**28): self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) self.buffer = None - # Flag used by the NCCL backend to make sure we only reduce gradients - # one time in the execution engine - self.need_reduction = False - # We can also forcibly accumulate grads locally and only do the # all-reduce at some later time 
self.accumulate_grads = False @@ -87,13 +83,6 @@ def no_sync(self): self.accumulate_grads = old_accumulate_grads def forward(self, *inputs, **kwargs): - if self.need_reduction: - raise RuntimeError( - 'LegacyDistributedDataParallel requires explicit reduction, ' - 'must call LegacyDistributedDataParallel.all_reduce' - ) - if not self.accumulate_grads: - self.need_reduction = True return self.module(*inputs, **kwargs) def all_reduce(self): @@ -144,9 +133,8 @@ def all_reduce_params(params): def reduction_fn(): # This function only needs to be called once - if not self.need_reduction or self.accumulate_grads: + if self.accumulate_grads: return - self.need_reduction = False if self.buffer is None: self.buffer = next(self.module.parameters()).new(self.buffer_size) From b479cd946faa304178c480ae8deda04a94ee54a7 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 7 Aug 2020 15:19:26 -0700 Subject: [PATCH 098/707] Multilingual v1: Multilingual Training with multiple bitext and monolingual datasets: readme and example scripts Summary: A first version of XLNMT multilingual project code release: Multilingual Training with multiple bitext Readme and example scripts. 
Reviewed By: pipibjc Differential Revision: D22483502 fbshipit-source-id: 983a9949a088d9dbbdb24a5c07fa92044e1c65f1 --- examples/multilingual/README.md | 124 ++++++++++++++++++ .../finetune_multilingual_model.sh | 27 ++++ .../multilingual/multilingual_fairseq_gen.sh | 21 +++ .../multilingual/train_multilingual_model.sh | 23 ++++ 4 files changed, 195 insertions(+) create mode 100644 examples/multilingual/README.md create mode 100644 examples/multilingual/finetune_multilingual_model.sh create mode 100644 examples/multilingual/multilingual_fairseq_gen.sh create mode 100644 examples/multilingual/train_multilingual_model.sh diff --git a/examples/multilingual/README.md b/examples/multilingual/README.md new file mode 100644 index 0000000000..392476f3b0 --- /dev/null +++ b/examples/multilingual/README.md @@ -0,0 +1,124 @@ +# Multilingual Translation + +[[Multilingual Translation with Extensible Multilingual Pretraining and Finetuning, https://arxiv.org/abs/2008.00401]](https://arxiv.org/abs/2008.00401) + +## Introduction + +This work is for training multilingual translation models with multiple bitext datasets. 
This multilingual translation framework supports (see [[training section]](#Training) and [[finetuning section]](#Finetuning) for examples) + +* temperature-based sampling over unbalanced datasets of different translation directions + - --sampling-method with + choices=['uniform', 'temperature', 'concat'] + - --sampling-temperature +* configurable automatic addition of source and/or target language tokens to source/target sentences, using data prepared in the same way as for bilingual training + - --encoder-langtok with choices=['src', 'tgt', None] to specify whether to add source or target language tokens to the source sentences + - --decoder-langtok (binary option) to specify whether to add target language tokens to the target sentences or not +* finetuning mBART pretrained models for multilingual translation + - --finetune-from-model to specify the path from which to load the pretrained model + +## Preprocessing data +Multilingual training requires a joint BPE vocab. Please follow [mBART's preprocessing steps](https://github.com/pytorch/fairseq/tree/master/examples/mbart#bpe-data) to reuse our pretrained sentence-piece model. + +You can also train a joint BPE model on your own dataset and then follow the steps in [[link]](https://github.com/pytorch/fairseq/tree/master/examples/translation#multilingual-translation). 
+ +## Training + + +```bash +lang_pairs= +path_2_data= +lang_list= + +fairseq-train $path_2_data \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --sampling-method "temperature" \ + --sampling-temperature 1.5 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 +``` + +## Finetuning +We can also finetune multilingual models from a monolingual pretrained model, e.g. [mBART](https://github.com/pytorch/fairseq/tree/master/examples/mbart). 
+```bash +lang_pairs= +path_2_data= +lang_list= +pretrained_model= + +fairseq-train $path_2_data \ + --finetune-from-model $pretrained_model \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --sampling-method "temperature" \ + --sampling-temperature 1.5 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 +``` +## Generate +The following command uses the multilingual task (translation_multi_simple_epoch) to generate translations from $source_lang to $target_lang on the test dataset. During generation, the source language tokens are added to source sentences and the target language tokens are added as the starting token to decode target sentences. Options --lang-dict and --lang-pairs are needed to tell the generation process the ordered list of languages and translation directions that the trained model is aware of; they need to be consistent with training. 
+ +```bash +model= +source_lang= +target_lang= + +fairseq-generate $path_2_data \ + --path $model \ + --task translation_multi_simple_epoch \ + --gen-subset test \ + --source-lang $source_lang \ + --target-lang $target_lang \ + --sacrebleu --remove-bpe 'sentencepiece' \ + --max-sentences 32 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" > ${source_lang}_${target_lang}.txt +``` +Fairseq will write the translations to the file ${source_lang}_${target_lang}.txt, with the sacreBLEU score at the end. + +You can also use a customized tokenizer to compare the performance with the literature. For example, you can get a tokenizer [here](https://github.com/rsennrich/wmt16-scripts) and do the following: +```bash +TOKENIZER= +TOK_CMD=<"$TOKENIZER $target_lang" or cat for sacrebleu> + +cat ${source_lang}_${target_lang}.txt | grep -P "^H" |sort -V |cut -f 3- |$TOK_CMD > ${source_lang}_${target_lang}.hyp +cat ${source_lang}_${target_lang}.txt | grep -P "^T" |sort -V |cut -f 2- |$TOK_CMD > ${source_lang}_${target_lang}.ref +sacrebleu -tok 'none' -s 'none' ${source_lang}_${target_lang}.ref < ${source_lang}_${target_lang}.hyp +``` + + + +## Citation + +```bibtex +@article{tang2020multilingual, + title={Multilingual Translation with Extensible Multilingual Pretraining and Finetuning}, + author={Yuqing Tang and Chau Tran and Xian Li and Peng-Jen Chen and Naman Goyal and Vishrav Chaudhary and Jiatao Gu and Angela Fan}, + year={2020}, + eprint={2008.00401}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/examples/multilingual/finetune_multilingual_model.sh b/examples/multilingual/finetune_multilingual_model.sh new file mode 100644 index 0000000000..cfa9a86113 --- /dev/null +++ b/examples/multilingual/finetune_multilingual_model.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +path_2_data=$1 # which contains binarized data for each directions +lang_list=$2 # +lang_pairs=$3 #a list language pairs to train multilingual models, e.g. 
"en-fr,en-cs,fr-en,cs-en" +# pretrained can be an mBART pretrained model as well +pretrained_model=$4 # + + +fairseq-train "$path_2_data" \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --finetune-from-model "$pretrained_model" \ + --sampling-method "temperature" \ + --sampling-temperature "1.5" \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 diff --git a/examples/multilingual/multilingual_fairseq_gen.sh b/examples/multilingual/multilingual_fairseq_gen.sh new file mode 100644 index 0000000000..a7487975e4 --- /dev/null +++ b/examples/multilingual/multilingual_fairseq_gen.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +lang_pairs="en-fr,en-cs,fr-en,cs-en" +path_2_data=$1 # +lang_list=$2 # +model=$3 # +source_lang=cs +target_lang=en + +fairseq-generate "$path_2_data" \ + --path "$model" \ + --task translation_multi_simple_epoch \ + --gen-subset test \ + --source-lang "$source_lang" \ + --target-lang "$target_lang" \ + --sacrebleu --remove-bpe 'sentencepiece'\ + --max-sentences 32 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" diff --git a/examples/multilingual/train_multilingual_model.sh b/examples/multilingual/train_multilingual_model.sh new file mode 100644 index 0000000000..09014c8217 --- /dev/null +++ b/examples/multilingual/train_multilingual_model.sh @@ -0,0 +1,23 @@ +#!/bin/bash + 
+path_2_data=$1 # which contains binarized data for each directions +lang_list=$2 # +lang_pairs=$3 #a list language pairs to train multilingual models, e.g. "en-fr,en-cs,fr-en,cs-en" + +fairseq-train "$path_2_data" \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --sampling-method "temperature" \ + --sampling-temperature 1.5 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 From 587c179818d55a95d8b0d8bbd40f245a1b0c3692 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 7 Aug 2020 17:36:17 -0700 Subject: [PATCH 099/707] Multilingual v1: Multilingual Training with multiple bitext and monolingual datasets: added denoising data to training Summary: A first version of XLNMT multilingual project code release: Multilingual Training with multiple bitext - Minor changes to - fairseq/data/denoising_dataset.py to (1) allow additional transformation; (2) allow a preset pad_to_length of batch; Reviewed By: myleott Differential Revision: D22483451 fbshipit-source-id: 2ffa9f95186dde2a42e0c356fea3f33c42711c82 --- fairseq/data/denoising_dataset.py | 29 +++++++++++++++---- .../multilingual/multilingual_data_manager.py | 19 +++++++++--- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py index 9d7c318702..8dc240c1eb 100644 --- a/fairseq/data/denoising_dataset.py +++ 
b/fairseq/data/denoising_dataset.py @@ -18,22 +18,27 @@ def collate( left_pad_source=False, left_pad_target=False, input_feeding=True, + pad_to_length=None, ): assert input_feeding if len(samples) == 0: return {} - def merge(key, left_pad, move_eos_to_beginning=False): + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx=None, # use eos_idx of each sample instead of vocab.eos() left_pad=left_pad, move_eos_to_beginning=move_eos_to_beginning, + pad_to_length=pad_to_length, ) id = torch.LongTensor([s['id'] for s in samples]) - src_tokens = merge('source', left_pad=left_pad_source) + src_tokens = merge( + 'source', left_pad=left_pad_source, + pad_to_length=pad_to_length['source'] if pad_to_length is not None else None, + ) # sort by descending source length src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) @@ -43,7 +48,10 @@ def merge(key, left_pad, move_eos_to_beginning=False): prev_output_tokens = None target = None if samples[0].get('target', None) is not None: - target = merge('target', left_pad=left_pad_target) + target = merge( + 'target', left_pad=left_pad_target, + pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + ) target = target.index_select(0, sort_order) ntokens = sum(len(s['target']) for s in samples) @@ -54,6 +62,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): 'target', left_pad=left_pad_target, move_eos_to_beginning=True, + pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, ) prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: @@ -68,6 +77,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): }, 'target': target, 'nsentences': samples[0]['source'].size(0), + 'sort_order': sort_order, } if prev_output_tokens is not None: batch['net_input']['prev_output_tokens'] = 
prev_output_tokens @@ -103,7 +113,8 @@ def __init__( shuffle, seed, args, - eos=None + eos=None, + item_transform_func=None, ): self.dataset = dataset @@ -120,6 +131,7 @@ def __init__( self.rotate_ratio = args.rotate self.permute_sentence_ratio = args.permute_sentences self.eos = (eos if eos is not None else vocab.eos()) + self.item_transform_func = item_transform_func if args.bpe != 'gpt2': self.full_stop_index = self.vocab.eos() @@ -174,6 +186,9 @@ def __getitem__(self, index): if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio: source = self.add_rolling_noise(source) + # there can be additional changes to make: + if self.item_transform_func is not None: + source, target = self.item_transform_func(source, target) assert (source >= 0).all() assert (source[1:-1] >= 1).all() @@ -348,14 +363,16 @@ def add_insertion_noise(self, tokens, p): assert (result >= 0).all() return result - def collater(self, samples): + def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate Returns: dict: a mini-batch of data """ - return collate(samples, self.vocab.pad(), self.eos, self.vocab) + return collate( + samples, self.vocab.pad(), self.eos, self.vocab, + pad_to_length=pad_to_length) def num_tokens(self, index): """Return the number of tokens in a sample. 
This value is used to diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 35830fb78d..92f6cc659c 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -82,9 +82,15 @@ def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): def add_args(parser): parser.add_argument('data', help='colon separated path to data directories list, \ will be iterated upon during epochs in round-robin manner') + parser.add_argument('--langs', default=None, type=csv_str_list, + help='a comma-separated list of languages which can appear in lang-pairs; ' + 'note that the ordering determines language token IDs', + ) parser.add_argument('--lang-dict', default=None, type=str, - help='language dictionary path with a list of ' - 'languages which can appear in lang-pairs') + help='an external file which contains a list of ' + 'languages which can appear in lang-pairs; ' + 'note that the ordering determines language token IDs; ' + '--langs and --lang-dict are mutually exclusive options') parser.add_argument('--lang-tok-style', default='multilingual', type=str, choices=['multilingual', 'mbart'], help='language token styles') @@ -157,7 +163,9 @@ def add_args(parser): @classmethod def load_langs(cls, args, **kwargs): - if args.lang_dict is None: + if args.lang_dict and args.langs: + raise ValueError('--langs and --lang-dict can not both be specified') + if args.lang_dict is None and args.langs is None: logger.warning( 'External language dictionary is not provided; ' 'use lang-pairs to infer the set of supported languages. 
' @@ -167,10 +175,13 @@ def load_langs(cls, args, **kwargs): langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}) langs = sorted(langs) logger.info(f'inferred language list: {langs}') - else: + elif args.lang_dict: with PathManager.open(args.lang_dict, "r", encoding="utf-8") as f: langs = [lang.strip() for lang in f.readlines() if lang.strip()] logger.info(f'loaded language list from {args.lang_dict} as they are ordered in file') + elif args.langs: + langs = args.langs + logger.info(f'parsed the language list as they are ordered in the option: {langs}') return langs def has_sharded_data(self, split): From e882cfaaa38eafdba5126cad7cbdd399be677ab0 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 7 Aug 2020 19:54:18 -0700 Subject: [PATCH 100/707] Use PlasmaArray for intermediate indices in SampledMultiDataset to avoid copying during pickling Summary: Use PlasmaArray for intermediate indices in SampledMultiDataset to avoid copying during pickling Reviewed By: pipibjc Differential Revision: D22889904 fbshipit-source-id: 795be21ab3d34ca1e993883956ebd6622edab472 --- .../multilingual/sampled_multi_dataset.py | 15 +++++++- .../sampled_multi_epoch_dataset.py | 34 ++++++++++++++----- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index d6c104031e..74e432ea49 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -119,7 +119,15 @@ def __init__( self._size_cache = {} self.set_epoch(epoch) + def _clean_if_not_none(self, var_list): + for v in var_list: + if v is not None: + del v + def _reset_cached_properties(self): + self._clean_if_not_none([ + self._sizes, self._ordered_indices, self._cur_indices + ]) self._sizes = None self._ordered_indices = None self._cur_indices = None @@ -212,7 +220,8 @@ def _get_dataset_and_index(self, index): def __getitem__(self, 
index): ds_idx, ds_sample_idx = self._get_dataset_and_index(index) - return (ds_idx, self.datasets[ds_idx][ds_sample_idx]) + ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx]) + return ret def num_tokens(self, index): ds_idx, ds_sample_idx = self._get_dataset_and_index(index) @@ -366,6 +375,10 @@ def _establish_virtual_datasets(self): ) indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices( rng, self.datasets, self.sample_ratios, self.virtual_size) + + self._clean_if_not_none([ + self.cumulated_sizes, self.virtual_size_per_dataset + ]) self._cur_indices = plasma_utils.PlasmaArray(indices) self.cumulated_sizes = plasma_utils.PlasmaArray(cumulated_sizes) self.virtual_size_per_dataset = plasma_utils.PlasmaArray(virtual_size_per_dataset) diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py index ff7f7fa18b..fdd47e5091 100644 --- a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -108,7 +108,7 @@ def num_tokens(self, index): def size(self, index): if self._epoch_sizes is not None: - return self._epoch_sizes[index] + return self._epoch_sizes.array[index] index = self._map_epoch_index_to_global(index) ds_idx, ds_sample_idx = self._get_dataset_and_index(index) return self.datasets[ds_idx].size(ds_sample_idx) @@ -123,7 +123,7 @@ def __len__(self): @property def sizes(self): if self._epoch_sizes is not None: - return self._epoch_sizes + return self._epoch_sizes.array start_time = time.time() size_cache = self._size_cache @@ -139,13 +139,13 @@ def sizes(self): s = (s, s) if not isinstance(s, tuple) else s size_cache[(ds_idx, ds_sample_idx)] = s ret.append(s) - logger.debug(f'sizes() calling time: {get_time_gap(start_time, time.time())}') - self._epoch_sizes = np.array(ret, np.int64) - return self._epoch_sizes + self._epoch_sizes = plasma_utils.PlasmaArray(np.array(ret, np.int64)) + logger.info(f'sizes() 
calling time: {get_time_gap(start_time, time.time())}') + return self._epoch_sizes.array def ordered_indices(self): if self._epoch_ordered_indices is not None: - return self._epoch_ordered_indices + return self._epoch_ordered_indices.array if self.batch_by_size: # No need to do shuffle as the data items are already randomized @@ -162,8 +162,8 @@ def ordered_indices(self): sort_indices = indices[np.argsort(src_sizes[indices], kind='mergesort')] else: sort_indices = np.arange(len(self)) - self._epoch_ordered_indices = sort_indices - return sort_indices + self._epoch_ordered_indices = plasma_utils.PlasmaArray(sort_indices) + return self._epoch_ordered_indices.array def prefetch(self, indices): prefetch_indices = [[] for _ in range(len(self.datasets))] @@ -195,6 +195,7 @@ def _next_global_indices(self, epoch): epoch, # epoch index, ] ) + del self._random_globa_indices self._random_globa_indices = plasma_utils.PlasmaArray( rng.choice(self.virtual_size, self.virtual_size, replace=False)) if self.load_next_shard is None: @@ -222,6 +223,19 @@ def _sync_shard_epoch(self, shard_epoch): ret = ret.numpy() return ret + def _sync_epoch(self, epoch): + # in case the ratios are not precisely the same across processes + # also to ensure every procresses update the ratios in the same pace + epoch = torch.DoubleTensor([epoch]) + if torch.distributed.is_initialized(): + if torch.cuda.is_available(): + distributed_utils.all_reduce(epoch.cuda()) + else: + distributed_utils.all_reduce(epoch) + ret = epoch.cpu() + ret = ret.numpy() + return ret + def _next_virtual_epoch(self, epoch): index = self._get_epoch_start_index(epoch) if index == 0 or self._random_globa_indices is None: @@ -234,6 +248,10 @@ def _next_virtual_epoch(self, epoch): else: self._cur_epoch = epoch # reset cache sizes and ordered_indices for the epoch after moving to a new epoch + + self._clean_if_not_none([ + self._epoch_sizes, self._epoch_ordered_indices, self._size_cache + ]) self._epoch_sizes = None 
self._epoch_ordered_indices = None self._current_epoch_start_index = index From 522c76ba1646cd5ec2cd4be29392f53d40aec50a Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 8 Aug 2020 21:24:27 -0700 Subject: [PATCH 101/707] no stochastic cropping (#1241) Summary: this matches what i had in my development branch for wav2vec 2.0 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1241 Reviewed By: HenryZhou7 Differential Revision: D23020339 Pulled By: alexeib fbshipit-source-id: 582203d627c6abce2baa4aaf91d4e5601b9a45d1 --- fairseq/models/wav2vec/wav2vec2.py | 2 +- fairseq/tasks/audio_pretraining.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 342bc5da3f..be6d10c7a2 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -993,7 +993,7 @@ def base_architecture(args): args.mask_min_space = getattr(args, "mask_min_space", 1) args.mask_channel_length = getattr(args, "mask_channel_length", 10) - args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.65) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0) args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") args.mask_channel_other = getattr(args, "mask_channel_other", 0) args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index f33637468f..46d164ba98 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -62,10 +62,6 @@ def add_args(parser): help="pad shorter samples instead of cropping", ) - parser.add_argument( - "--no-min-cropping", action="store_true", help="always crop to max sample size or smallest length" - ) - parser.add_argument( "--labels", type=str, @@ -99,7 +95,7 @@ def load_dataset(self, split, **kwargs): manifest, sample_rate=self.args.sample_rate, 
max_sample_size=self.args.max_sample_size, - min_sample_size=self.args.min_sample_size if not self.args.no_min_cropping else self.args.max_sample_size, + min_sample_size=self.args.max_sample_size, min_length=self.args.min_sample_size, pad=self.args.labels is not None or self.args.enable_padding, normalize=self.args.normalize, From c0a187d86e2cfa192f137656072050a0407c71f3 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 10 Aug 2020 09:54:45 -0700 Subject: [PATCH 102/707] Fixed a bug in sharing source and target data set for reversed directions in multilingual data manager Summary: We can reuse source and target datasets for the reversed direction to save CPU memory. However, the old implementation has two bugs: (1) with the align_dataset present in lang_pair, the current logic will never allow reversed source and target data sharing; (2) the key of cached datasets does not include direction, so different directions' source and target datasets can be mixed up with each other. This is a fix. Reviewed By: pipibjc Differential Revision: D22889908 fbshipit-source-id: e64ec36819ec32e711fd009228af5d4a569754d4 --- .../multilingual/multilingual_data_manager.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 92f6cc659c..b31028c191 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -449,10 +449,11 @@ def load_langpair_dataset( tgt_lang_id=None, langpairs_sharing_datasets=None, ): + norm_direction = "-".join(sorted([src, tgt])) if langpairs_sharing_datasets is not None: - src_dataset = langpairs_sharing_datasets.get((data_path, split, src), 'NotInCache') - tgt_dataset = langpairs_sharing_datasets.get((data_path, split, tgt), 'NotInCache') - align_dataset = langpairs_sharing_datasets.get((data_path, split, src, tgt), 'NotInCache') + src_dataset = 
langpairs_sharing_datasets.get((data_path, split, norm_direction, src), 'NotInCache') + tgt_dataset = langpairs_sharing_datasets.get((data_path, split, norm_direction, tgt), 'NotInCache') + align_dataset = langpairs_sharing_datasets.get((data_path, split, norm_direction, src, tgt), 'NotInCache') # a hack: any one is not in cache, we need to reload them if ( @@ -476,9 +477,15 @@ def load_langpair_dataset( src_dataset = src_dataset_transform_func(src_dataset) tgt_dataset = tgt_dataset_transform_func(tgt_dataset) if langpairs_sharing_datasets is not None: - langpairs_sharing_datasets[(data_path, split, src)] = src_dataset - langpairs_sharing_datasets[(data_path, split, tgt)] = tgt_dataset - langpairs_sharing_datasets[(data_path, split, src, tgt)] = align_dataset + langpairs_sharing_datasets[(data_path, split, norm_direction, src)] = src_dataset + langpairs_sharing_datasets[(data_path, split, norm_direction, tgt)] = tgt_dataset + langpairs_sharing_datasets[(data_path, split, norm_direction, src, tgt)] = align_dataset + if align_dataset is None: + # no align data so flag the reverse direction as well in sharing + langpairs_sharing_datasets[(data_path, split, norm_direction, tgt, src)] = align_dataset + else: + logger.info(f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: " + f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}") return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, From 8453a08cd18f3972fa62584cfc3a278bcc3af16c Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 10 Aug 2020 11:09:03 -0700 Subject: [PATCH 103/707] Fixed a bug in fairseq training loop where valid_losses is not evaluated before reference Summary: It is possible that valid_losses is not evaluated in train.py. Here is a fix. # Facebook: see f209151313 for an example. 
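For illustration only, a minimal sketch of the failure mode described above, assuming the problem is an early `continue` that skips the validation call. The names `train_buggy`, `train_fixed` and the stand-in `validate_and_save` are hypothetical and are not the code in `fairseq_cli/train.py`:

```python
def validate_and_save(step):
    # Hypothetical stand-in for the real validation hook.
    return [0.1 * step], False


def train_buggy(log_outputs):
    for step, log_output in enumerate(log_outputs):
        if log_output is None:  # e.g. OOM or overflow
            continue  # also skips the validation call below
        valid_losses, should_stop = validate_and_save(step)
    # Unbound if every step was skipped.
    return valid_losses, should_stop


def train_fixed(log_outputs):
    for step, log_output in enumerate(log_outputs):
        if log_output is not None:
            pass  # mid-epoch logging would go here
        # Validation now runs whether or not the step produced output.
        valid_losses, should_stop = validate_and_save(step)
    return valid_losses, should_stop
```

With a single skipped step, `train_buggy([None])` raises `UnboundLocalError`, while `train_fixed([None])` still returns a value.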
Reviewed By: myleott Differential Revision: D22894276 fbshipit-source-id: a3351d8d77e68ee64ea7c61a1e2c237b5fcfeeae --- fairseq_cli/train.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 806e4bc54b..90574ee8f7 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -215,18 +215,17 @@ def train(args, trainer, task, epoch_itr): "train_step-%d" % i ): log_output = trainer.train_step(samples) - if log_output is None: # OOM, overflow, ... - continue - - # log mid-epoch stats - num_updates = trainer.get_num_updates() - if num_updates % args.log_interval == 0: - stats = get_training_stats(metrics.get_smoothed_values("train_inner")) - progress.log(stats, tag="train_inner", step=num_updates) - - # reset mid-epoch stats after each log interval - # the end-of-epoch stats will still be preserved - metrics.reset_meters("train_inner") + + if log_output is not None: # not OOM, overflow, ... + # log mid-epoch stats + num_updates = trainer.get_num_updates() + if num_updates % args.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + progress.log(stats, tag="train_inner", step=num_updates) + + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters("train_inner") end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( From 4c55744ec4cb26749cf2cf8dac89942f26ce4bd2 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 10 Aug 2020 19:27:07 -0700 Subject: [PATCH 104/707] Multilingual v1: Multilingual Training with multiple bitext and monolingual datasets: correct parameter names to be consistent with suggested revisions made to [1/x] Summary: In revision to [1/x], it was suggested to change max_size into pad_to_length. Here let's correct parameter names in other datasets to be consistent with revisions made to [1/x]. 
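For illustration only, a rough sketch of the forwarding pattern this change relies on: wrapper datasets accept `**extra_args` in `collater` and pass them through, so a `pad_to_length` hint reaches the innermost dataset. `LeafDataset` and `WrapperDataset` are hypothetical names, and `pad_to_length` is simplified here to a single int rather than a per-side mapping:

```python
class LeafDataset:
    """Hypothetical leaf dataset whose collater pads token lists."""

    def __init__(self, pad=0):
        self.pad = pad

    def collater(self, samples, pad_to_length=None):
        # Pad to the longest sample, or to pad_to_length if that is larger.
        target = max(len(s) for s in samples)
        if pad_to_length is not None:
            target = max(target, pad_to_length)
        return [s + [self.pad] * (target - len(s)) for s in samples]


class WrapperDataset:
    """Hypothetical wrapper that forwards extra collater arguments unchanged."""

    def __init__(self, dataset):
        self.dataset = dataset

    def collater(self, samples, **extra_args):
        return self.dataset.collater(samples, **extra_args)
```

Calling `WrapperDataset(LeafDataset()).collater([[1, 2], [3]], pad_to_length=4)` pads both samples to length 4; without the keyword, they are padded to the longest sample only.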
Reviewed By: myleott Differential Revision: D22994888 fbshipit-source-id: cdc51af9eae7f122c644f5727eaf084013855ceb --- fairseq/data/concat_dataset.py | 6 +++--- fairseq/data/multilingual/sampled_multi_dataset.py | 8 ++++---- fairseq/data/transform_eos_lang_pair_dataset.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py index 0376b3ae1d..5ca80631f0 100644 --- a/fairseq/data/concat_dataset.py +++ b/fairseq/data/concat_dataset.py @@ -47,12 +47,12 @@ def _get_dataset_and_sample_index(self, idx: int): sample_idx = sample_idx % self.real_sizes[dataset_idx] return dataset_idx, sample_idx - def collater(self, samples): + def collater(self, samples, **extra_args): # For now only supports datasets with same underlying collater implementations if hasattr(self.datasets[0], 'collater'): - return self.datasets[0].collater(samples) + return self.datasets[0].collater(samples, **extra_args) else: - return default_collate(samples) + return default_collate(samples, **extra_args) def size(self, idx: int): """ diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index 74e432ea49..95eab280f0 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -255,14 +255,14 @@ def collater(self, samples, **extra_args): ) else: samples_dict = defaultdict(list) - max_size = defaultdict(int) + pad_to_length = defaultdict(int) if 'pad_to_length' not in extra_args else extra_args['pad_to_length'] for ds_idx, s in samples: - max_size['source'] = max(max_size['source'], s['source'].size(0)) + pad_to_length['source'] = max(pad_to_length['source'], s['source'].size(0)) if s['target'] is not None: - max_size['target'] = max(max_size['target'], s['target'].size(0)) + pad_to_length['target'] = max(pad_to_length['target'], s['target'].size(0)) samples_dict[ds_idx].append(s) batches = [ - 
self.datasets[i].collater(samples_dict[i], max_size=max_size) + self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length) for i in range(len(self.datasets)) if len(samples_dict[i]) > 0 ] diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py index 40f77c2916..55137ca55c 100644 --- a/fairseq/data/transform_eos_lang_pair_dataset.py +++ b/fairseq/data/transform_eos_lang_pair_dataset.py @@ -45,8 +45,8 @@ def __getitem__(self, index): def __len__(self): return len(self.dataset) - def collater(self, samples): - samples = self.dataset.collater(samples) + def collater(self, samples, **extra_args): + samples = self.dataset.collater(samples, **extra_args) if self.new_src_eos is not None: if self.dataset.left_pad_source: From 9438fb58ecdbfc7a816893e4c2ba15e03062f921 Mon Sep 17 00:00:00 2001 From: Ruslan Mavlyutov Date: Wed, 12 Aug 2020 11:16:46 -0700 Subject: [PATCH 105/707] Optimize `filter_by_size` for `LanguagePairDataset` Summary: Speedup filtering by using numpy native filter methods. Summary from the test plan: Used inputs from f208190331 (as suggested in the task). 
Function runtime: Original run (nothing will be removed): ~10s Remove ~100K from ~200M: ~20s Remove all indices from ~200M samples: ~40s Details are in the test plan Reviewed By: myleott, akinh Differential Revision: D23045469 fbshipit-source-id: 8b755c0c203e51d057130a8df994ff8eb76c9b2b --- fairseq/data/data_utils.py | 3 +- fairseq/data/fairseq_dataset.py | 30 +++++++++++++++++ fairseq/data/language_pair_dataset.py | 32 ++++++++++++++++++ fairseq/tasks/fairseq_task.py | 47 +++++++++++++++++++++++---- 4 files changed, 104 insertions(+), 8 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 57991a8802..4f0234ff93 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -171,7 +171,8 @@ def check_size(idx): def filter_by_size(indices, dataset, max_positions, raise_exception=False): """ - Filter indices based on their size. + [deprecated] Filter indices based on their size. + Use `FairseqDataset::filter_indices_by_size` instead. Args: indices (List[int]): ordered list of dataset indices diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index 5786d5c851..900bfaff10 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -6,6 +6,8 @@ import numpy as np import torch.utils.data +from fairseq.data import data_utils + class EpochListening: """Mixin for receiving updates whenever the epoch increments.""" @@ -122,6 +124,34 @@ def adjust_bsz(bsz, num_tokens): fixed_shapes=fixed_shapes, ) + def filter_indices_by_size(self, indices, max_sizes): + """ Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. 
+ + WARNING: don't update, override method in child classes + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if isinstance(max_sizes, float) or isinstance(max_sizes, int): + if hasattr(self, 'sizes') and isinstance(self.sizes, np.ndarray): + ignored = indices[self.sizes[indices] > max_sizes].tolist() + indices = indices[self.sizes[indices] <= max_sizes] + elif hasattr(self, 'sizes') and isinstance(self.sizes, list) and len(self.sizes) == 1: + ignored = indices[self.sizes[0][indices] > max_sizes].tolist() + indices = indices[self.sizes[0][indices] <= max_sizes] + else: + indices, ignored = data_utils._filter_by_size_dynamic(indices, self.size, max_sizes) + else: + indices, ignored = data_utils._filter_by_size_dynamic(indices, self.size, max_sizes) + return indices, ignored + class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): """For datasets that need to be read sequentially, usually because the data diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 5c9e09edcf..7576e07d34 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -386,3 +386,35 @@ def prefetch(self, indices): self.tgt.prefetch(indices) if self.align_dataset is not None: self.align_dataset.prefetch(indices) + + def filter_indices_by_size(self, indices, max_sizes): + """ Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. 
+ + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if self.tgt_sizes is None: + ignored = indices[self.src_sizes[indices] > max_src_size] + else: + ignored = indices[(self.src_sizes[indices] > max_src_size) | + (self.tgt_sizes[indices] > max_tgt_size)] + if len(ignored) > 0: + if self.tgt_sizes is None: + indices = indices[self.src_sizes[indices] <= max_src_size] + else: + indices = indices[(self.src_sizes[indices] <= max_src_size) & + (self.tgt_sizes[indices] <= max_tgt_size)] + return indices, ignored.tolist() diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 503a008b51..59663b531d 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -3,14 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import warnings +import logging import os +import warnings + import torch from fairseq import metrics, search, tokenizer, utils from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary +logger = logging.getLogger(__name__) + class FairseqTask(object): """ @@ -108,6 +112,37 @@ def dataset(self, split): raise TypeError("Datasets are expected to be of type FairseqDataset") return self.datasets[split] + def filter_indices_by_size(self, + indices, + dataset, + max_positions, + ignore_invalid_inputs): + """ + Filter examples that are too large + + Args: + indices (np.array): original array of sample indices + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + Returns: + np.array: array of filtered sample indices + """ + indices, ignored = dataset.filter_indices_by_size(indices, max_positions) + if len(ignored) > 0: + if not ignore_invalid_inputs: + raise Exception(( + 'Size of sample #{} is invalid (={}) since max_positions={}, ' + 'skip this example with --skip-invalid-size-inputs-valid-test' + ).format(ignored[0], dataset.size(ignored[0]), max_positions)) + logger.warning(( + '{} samples have invalid sizes and will be skipped, ' + 'max_positions={}, first few sample ids={}' + ).format(len(ignored), max_positions, ignored[:10])) + return indices + def get_batch_iterator( self, dataset, @@ -169,12 +204,10 @@ def get_batch_iterator( # filter examples that are too large if max_positions is not None: - indices = data_utils.filter_by_size( - indices, - dataset, - max_positions, - raise_exception=(not ignore_invalid_inputs), - ) + indices = self.filter_indices_by_size(indices, + dataset, + max_positions, + ignore_invalid_inputs) # create mini-batches with given size constraints batch_sampler = dataset.batch_by_size( From 
7f3a5f6d5f6c571cce4fecd453a197c3dc521539 Mon Sep 17 00:00:00 2001 From: Kritika Singh Date: Wed, 12 Aug 2020 19:53:53 -0700 Subject: [PATCH 106/707] Use safe_round in ctc reduce_metrics Summary: Validation step was failing with this error: ``` File "/mnt/xarfuse/uid-188222/f6ace8b6-seed-f166b132-0f4f-40ae-8261-c83ec5d2e63c-ns-4026533670/fairseq/trainer.py", line 644, in valid_step logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/mnt/xarfuse/uid-188222/f6ace8b6-seed-f166b132-0f4f-40ae-8261-c83ec5d2e63c-ns-4026533670/fairseq/trainer.py", line 949, in _reduce_and_log_stats logging_output = agg.get_smoothed_values() File "/mnt/xarfuse/uid-188222/f6ace8b6-seed-f166b132-0f4f-40ae-8261-c83ec5d2e63c-ns-4026533670/fairseq/logging/meters.py", line 268, in get_smoothed_values for key in self.keys() File "/mnt/xarfuse/uid-188222/f6ace8b6-seed-f166b132-0f4f-40ae-8261-c83ec5d2e63c-ns-4026533670/fairseq/logging/meters.py", line 269, in if not key.startswith("_") File "/mnt/xarfuse/uid-188222/f6ace8b6-seed-f166b132-0f4f-40ae-8261-c83ec5d2e63c-ns-4026533670/fairseq/logging/meters.py", line 260, in get_smoothed_value return meter.fn(self) File "/mnt/xarfuse/uid-188222/f6ace8b6-seed-f166b132-0f4f-40ae-8261-c83ec5d2e63c-ns-4026533670/fairseq/criterions/ctc.py", line 234, in if meters["_w_total"].sum > 0 TypeError: type Tensor doesn't define __round__ method ``` This issue was also raised in T71334670. Quoting Myle from https://fb.workplace.com/groups/fairseq/permalink/694410594696062/: "Hmm, yeah, we now return tensors for multi-GPU since the stats are synced over GPU/NCCL. You can use utils.item(trainer.get_meter(...)) as a workaround." 
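For illustration only, a small sketch of the idea behind a "safe" round: plain numbers go through `round`, while objects without `__round__` that expose `.item()` are unwrapped first. `safe_round_sketch` and `FakeScalarTensor` are hypothetical names; this is not the implementation in `fairseq.logging.meters`:

```python
def safe_round_sketch(number, ndigits):
    """Round plain numbers; unwrap single-value containers via .item() first."""
    if hasattr(number, "__round__"):
        return round(number, ndigits)
    if hasattr(number, "item"):
        return safe_round_sketch(number.item(), ndigits)
    # Anything else is returned unchanged rather than raising.
    return number


class FakeScalarTensor:
    """Hypothetical stand-in for a one-element tensor without __round__."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value
```

Here `round(FakeScalarTensor(1.0), 3)` raises the same kind of `TypeError` as in the traceback, whereas `safe_round_sketch(FakeScalarTensor(12.3456), 3)` unwraps the value before rounding.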
Reviewed By: alexeib Differential Revision: D23035154 fbshipit-source-id: 65537f3b5d9d1112d9e3a1964618ec524bc28378 --- fairseq/criterions/ctc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 3b8d974387..cbf712c69d 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -12,6 +12,7 @@ from fairseq import metrics, utils from fairseq.data.data_utils import post_process from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.logging.meters import safe_round @register_criterion("ctc") @@ -218,20 +219,20 @@ def reduce_metrics(logging_outputs) -> None: if c_total > 0: metrics.log_derived( "uer", - lambda meters: round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3) + lambda meters: safe_round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3) if meters["_c_total"].sum > 0 else float("nan"), ) if w_total > 0: metrics.log_derived( "wer", - lambda meters: round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + lambda meters: safe_round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3) if meters["_w_total"].sum > 0 else float("nan"), ) metrics.log_derived( "raw_wer", - lambda meters: round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + lambda meters: safe_round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3) if meters["_w_total"].sum > 0 else float("nan"), ) From fe3b63643c6c15cac70d0958242dbfb6bdf710e3 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 13 Aug 2020 15:26:49 -0700 Subject: [PATCH 107/707] delete windows build which always fails (#1245) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1245 Reviewed By: myleott Differential Revision: D23110691 Pulled By: alexeib fbshipit-source-id: bedf432f8710ba5c1a651410d27b39e282b80e11 --- .github/workflows/build_windows.yml | 44 ----------------------------- 1 file changed, 44 
deletions(-) delete mode 100644 .github/workflows/build_windows.yml diff --git a/.github/workflows/build_windows.yml b/.github/workflows/build_windows.yml deleted file mode 100644 index 3161fd09c7..0000000000 --- a/.github/workflows/build_windows.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: build_windows - -on: - # Trigger the workflow on push to master or any pull request - push: - branches: - - master - pull_request: - -jobs: - build: - - strategy: - max-parallel: 4 - matrix: - platform: [windows-latest] - python-version: [3.6, 3.7] - - runs-on: ${{ matrix.platform }} - - steps: - - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - name: Conditionally install pytorch - if: matrix.platform == 'windows-latest' - run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html - - name: Install locally - run: | - python -m pip install --upgrade pip - python setup.py build_ext --inplace - python -m pip install --editable . - - name: Lint with flake8 - run: | - pip install flake8 - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Run tests - run: | - python setup.py test From 3a7b04fa098f6a94888c6700860afa44a22d9c9d Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 13 Aug 2020 16:06:11 -0700 Subject: [PATCH 108/707] fix wav2vec seq2seq training (#1244) Summary: fixes training wav2vec seq2seq models Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1244 Reviewed By: aconneau, jmp84 Differential Revision: D23110102 Pulled By: alexeib fbshipit-source-id: 7db38d4f59826000eac58fdd7e6cbb8e9cbb5b43 --- fairseq/data/add_target_dataset.py | 11 +++++++++-- fairseq/models/wav2vec/wav2vec2_asr.py | 14 ++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py index 91cf1a51c4..3a42dd7a2e 100644 --- a/fairseq/data/add_target_dataset.py +++ b/fairseq/data/add_target_dataset.py @@ -38,12 +38,19 @@ def collater(self, samples): return collated indices = set(collated["id"].tolist()) target = [s["label"] for s in samples if s["id"] in indices] + if self.batch_targets: collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["ntokens"] = collated["target_lengths"].sum().item() + else: + collated["ntokens"] = sum([len(t) for t in target]) + collated["target"] = target + if self.add_to_input: eos = target.new_full((target.size(0), 1), self.eos) - collated["target"] = torch.cat([target, eos], dim=-1) - collated["net_input"]["prev_output_tokens"] = torch.cat([eos, target], dim=-1) + collated["target"] = torch.cat([target, eos], dim=-1).long() + collated["net_input"]["prev_output_tokens"] = torch.cat([eos, target], dim=-1).long() + collated["ntokens"] += target.size(0) return collated \ No newline at end of file diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 
f50af255a5..e47e1f7009 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -358,6 +358,8 @@ def __init__(self, args, tgt_dict=None): if tgt_dict is not None: self.proj = Linear(d, len(tgt_dict)) + elif getattr(args, 'decoder_embed_dim', d) != d: + self.proj = Linear(d, args.decoder_embed_dim) else: self.proj = None @@ -434,7 +436,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim - self.output_embed_dim = args.decoder_output_dim + self.output_embed_dim = args.decoder_embed_dim args.encoder_embed_dim = embed_dim self.layerdrop = args.decoder_layerdrop @@ -475,12 +477,6 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): ] ) - self.project_out_dim = ( - Linear(embed_dim, self.output_embed_dim, bias=False) - if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights - else None - ) - if not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), self.output_embed_dim) @@ -583,9 +579,6 @@ def extract_features( # T x B x C -> B x T x C x = x.transpose(0, 1) - if self.project_out_dim is not None: - x = self.project_out_dim(x) - return x, {"attn": attn, "inner_states": inner_states} def output_layer(self, features, **kwargs): @@ -675,5 +668,6 @@ def seq2seq_architecture(args): args.decoder_dropout = getattr(args, "decoder_dropout", 0) args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0) args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0) + args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False) base_architecture(args) From 2497a9dc6ebcd58b410e7d8e265219cf0b0dbfbb Mon Sep 17 00:00:00 2001 From: Xu Song Date: Thu, 13 Aug 2020 21:19:09 -0700 Subject: [PATCH 109/707] Small fix for load_indexed_dataset parameter (#2430) MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? `dictionary` is not a required parameter for `load_indexed_dataset`. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2430 Reviewed By: ngoyal2707 Differential Revision: D23006666 Pulled By: myleott fbshipit-source-id: 10f2850ecb294795575573f99efac7f12a118d5d --- fairseq/data/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 4f0234ff93..70d8997b19 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -54,7 +54,7 @@ def copy_tensor(src, dst): return res -def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False, default='cached'): +def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False, default='cached'): """A helper function for loading indexed datasets. Args: From 96d767231dbcd99abaad54ce07f1d25577d41b9f Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Fri, 14 Aug 2020 02:45:33 -0700 Subject: [PATCH 110/707] use utterance level sampling strategy for stella giga data training Summary: Before, we would use batch level sampling with fairseq's MultiCorpusSampledDataset. This means we would sample a batch of portal data, a batch of video data, etc. The main disadvantage of this is that the implementation for this requires the sampling to happen **after** splitting up the data into batches. 
This means we don't know the underlying data of each batch until right before the forward call for that batch, which means when splitting up the data for batching fairseq overestimates the tokens per item by just taking the max across all datasets from which the sample could come from. This causes batching inefficiency because it overestimates the actual batch size. In D21887303 (https://github.com/pytorch/fairseq/commit/8570277f91d6bed03d71cc9c8326f096cd06b0d2) we added utterance level sampling, which samples **before** splitting up the data for batching on an utterance level. This is achieved through a new dataset introduced in that diff and refreshing the batch iterator every epoch. This enables accurate batch size calculation. Reviewed By: jay-mahadeokar Differential Revision: D22280839 fbshipit-source-id: 8d6c8b82267184558fe3ec9e1b77fe6f24e0376f --- fairseq/data/multi_corpus_dataset.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 02d269a17c..bf33cb23b5 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -33,10 +33,16 @@ class MultiCorpusDataset(FairseqDataset): datasets: a OrderedDict of FairseqDataset instances. 
distribution: a List containing the probability of getting an utterance from corresponding dataset + seed: random seed for sampling the datsets + sort_indices: if true, will sort the ordered indices by size """ def __init__( - self, datasets: Dict[str, FairseqDataset], distribution: List[float], seed: int + self, + datasets: Dict[str, FairseqDataset], + distribution: List[float], + seed: int, + sort_indices: bool = False, ): super().__init__() assert isinstance(datasets, OrderedDict) @@ -44,6 +50,7 @@ def __init__( self.datasets = datasets self.distribution = distribution self.seed = seed + self.sort_indices = sort_indices # Avoid repeated conversions to list later self.dataset_list = list(datasets.values()) @@ -68,13 +75,12 @@ def ordered_indices(self): # Keep track of which samples we've used for each dataset counters = [0 for _ in self.datasets] - return np.array( - [ - self._sample(indices, counters) - for _ in range(self.total_num_instances) - ], - dtype=np.int64, - ) + sampled_indices = [ + self._sample(indices, counters) for _ in range(self.total_num_instances) + ] + if self.sort_indices: + sampled_indices.sort(key=lambda i: self.num_tokens(i)) + return np.array(sampled_indices, dtype=np.int64) def _sample(self, indices, counters): # First pick dataset From 983163494663e24b611f1ba8d5d47a3edc00e2e5 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 14 Aug 2020 10:23:45 -0700 Subject: [PATCH 111/707] Misc fixes (#2448) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2448 Reviewed By: ngoyal2707 Differential Revision: D23011193 Pulled By: myleott fbshipit-source-id: 1a29481707108e4465aca78ec1581fb79f05efba --- README.md | 14 +++++---- docs/getting_started.rst | 19 ++++++++++++ docs/tutorial_simple_lstm.rst | 1 + examples/language_model/README.conv.md | 11 ++++--- examples/quant_noise/README.md | 2 +- examples/stories/README.md | 2 +- examples/translation/README.md | 14 +++++---- examples/wav2vec/README.md | 6 ++-- 
fairseq/checkpoint_utils.py | 8 +++-- fairseq/criterions/wav2vec_criterion.py | 3 +- fairseq/data/data_utils.py | 6 ++++ fairseq/data/fairseq_dataset.py | 12 ++++---- fairseq/data/iterators.py | 3 +- fairseq/nan_detector.py | 6 ++-- fairseq/optim/__init__.py | 2 +- fairseq/optim/adafactor.py | 10 +++---- fairseq/optim/fp16_optimizer.py | 1 + fairseq/optim/lr_scheduler/fixed_schedule.py | 7 +++++ .../lr_scheduler/reduce_lr_on_plateau.py | 2 ++ fairseq/options.py | 6 ++-- fairseq/registry.py | 3 ++ fairseq/scoring/__init__.py | 17 ++++++++++- fairseq/scoring/bleu.py | 4 +-- fairseq/scoring/scoring_utils.py | 22 -------------- fairseq/tasks/fairseq_task.py | 19 ++++++------ .../tasks/translation_multi_simple_epoch.py | 7 +++-- fairseq/trainer.py | 13 +++++++- fairseq/utils.py | 15 ++++++++++ fairseq_cli/generate.py | 2 +- fairseq_cli/train.py | 30 ++++--------------- tests/utils.py | 1 + 31 files changed, 166 insertions(+), 102 deletions(-) delete mode 100644 fairseq/scoring/scoring_utils.py diff --git a/README.md b/README.md index cea586d4f5..c6c37f0965 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,8 @@ We provide reference implementations of various sequence modeling papers: - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2019)](examples/wav2vec/README.md) + - [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) + - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - 
**Non-autoregressive Transformers** - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -51,13 +52,14 @@ We provide reference implementations of various sequence modeling papers: ### What's New: - August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) +- July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) - April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) - April 2020: [Quant-Noise code released](examples/quant_noise/README.md) - April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) -- March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)

Previous updates

+- March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) - February 2020: [mBART model and code released](examples/mbart/README.md) - February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) - December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) @@ -136,15 +138,17 @@ as well as example training and evaluation commands. - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available We also have more detailed READMEs to reproduce results from specific papers: -- [Training with Quantization Noise for Extreme Model Compression](examples/quant_noise/README.md) +- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) +- [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) +- [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) -- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) +- [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) +- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 
2019)](examples/wmt19/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2019)](examples/wav2vec/README.md) - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) - [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index c76534163a..416e29531d 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -182,3 +182,22 @@ replacing ``node_rank=0`` with ``node_rank=1`` on the second node: --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --max-tokens 3584 \ --fp16 --distributed-no-spawn + +Sharding very large datasets +---------------------------- + +It can be challenging to train over very large datasets, particularly if your +machine does not have much system RAM. Most tasks in fairseq support training +over "sharded" datasets, in which the original dataset has been preprocessed +into non-overlapping chunks (or "shards"). + +For example, instead of preprocessing all your data into a single "data-bin" +directory, you can split the data and create "data-bin1", "data-bin2", etc. +Then you can adapt your training command like so: + +.. code-block:: console + + > fairseq-train data-bin1:data-bin2:data-bin3 (...) + +Training will now iterate over each shard, one by one, with each shard +corresponding to an "epoch", thus reducing system memory usage. 
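Note on the sharding section added to docs/getting_started.rst above: the split into non-overlapping shards could be sketched roughly as follows. This is a minimal illustration only, not part of the patch. The helper names `shard_lines` and `data_arg` are hypothetical, and the sketch assumes a line-oriented corpus in which source and target files are split at the same boundaries so that sentence pairs stay aligned. Each shard would still need to be binarized into its own directory before training.

```python
from typing import List


def shard_lines(lines: List[str], num_shards: int) -> List[List[str]]:
    """Split lines into num_shards contiguous, non-overlapping chunks.

    Hypothetical helper: earlier shards receive one extra line when the
    corpus size is not an exact multiple of num_shards.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be at least 1")
    base, extra = divmod(len(lines), num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        size = base + (1 if i < extra else 0)
        shards.append(lines[start:start + size])
        start += size
    return shards


def data_arg(prefix: str, num_shards: int) -> str:
    """Build the colon-separated data argument, e.g. data-bin1:data-bin2."""
    return ":".join(f"{prefix}{i}" for i in range(1, num_shards + 1))
```

Under these assumptions, `data_arg("data-bin", 3)` produces the same `data-bin1:data-bin2:data-bin3` string that the documentation passes to `fairseq-train`.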
diff --git a/docs/tutorial_simple_lstm.rst b/docs/tutorial_simple_lstm.rst index 30bdc7213e..f52988507c 100644 --- a/docs/tutorial_simple_lstm.rst +++ b/docs/tutorial_simple_lstm.rst @@ -61,6 +61,7 @@ save the following in a new file named :file:`fairseq/models/simple_lstm.py`:: hidden_size=hidden_dim, num_layers=1, bidirectional=False, + batch_first=True, ) def forward(self, src_tokens, src_lengths): diff --git a/examples/language_model/README.conv.md b/examples/language_model/README.conv.md index 9fccfcc0ea..f0b6a3a921 100644 --- a/examples/language_model/README.conv.md +++ b/examples/language_model/README.conv.md @@ -11,11 +11,14 @@ fairseq-train --task language_modeling \ data-bin/wikitext-103 \ --save-dir checkpoints/fconv_wikitext-103 \ --arch fconv_lm_dauphin_wikitext103 \ - --max-epoch 35 \ --optimizer nag \ + --adaptive-softmax-cutoff 10000,20000,200000 \ + --dropout 0.2 \ + --criterion adaptive_loss \ + --optimizer nag --clip-norm 0.1 --weight-decay 5e-06 \ --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ - --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \ - --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \ - --ddp-backend=no_c10d + --max-tokens 1024 --tokens-per-sample 1024 \ + --ddp-backend no_c10d \ + --max-epoch 35 ``` And evaluate with: diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 1dc7dc619c..bcf0c4c827 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -1,4 +1,4 @@ -# Training with Quantization Noise for Extreme Model Compression ({Fan\*, Stock\*} *et al.*, 2019) +# Training with Quantization Noise for Extreme Model Compression ({Fan\*, Stock\*} *et al.*, 2020) This page contains information for how to train and quantize models with Quantization Noise, for both scalar quantization like `int8` and Iterative Product Quantization. Check out our paper [here](https://arxiv.org/abs/2004.07320). 
diff --git a/examples/stories/README.md b/examples/stories/README.md index 28579de355..588941eddc 100644 --- a/examples/stories/README.md +++ b/examples/stories/README.md @@ -44,7 +44,7 @@ fairseq-preprocess --source-lang wp_source --target-lang wp_target \ --destdir data-bin/writingPrompts --padding-factor 1 --thresholdtgt 10 --thresholdsrc 10 # Train the model: -fairseq-train data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False +fairseq-train data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --optimizer nag --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False # Train a fusion model: # add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint diff --git a/examples/translation/README.md b/examples/translation/README.md index e61d166c6e..3eb8e01310 100644 --- a/examples/translation/README.md +++ b/examples/translation/README.md @@ -175,9 +175,11 @@ mkdir -p checkpoints/fconv_wmt_en_de fairseq-train \ data-bin/wmt17_en_de \ --arch fconv_wmt_en_de \ - --lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ + --dropout 0.2 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ - --lr-scheduler fixed --force-anneal 50 \ + --optimizer nag --clip-norm 0.1 \ + --lr 0.5 --lr-scheduler fixed --force-anneal 50 \ + --max-tokens 4000 \ --save-dir checkpoints/fconv_wmt_en_de # Evaluate @@ -205,10 +207,12 @@ 
fairseq-preprocess \ mkdir -p checkpoints/fconv_wmt_en_fr fairseq-train \ data-bin/wmt14_en_fr \ - --lr 0.5 --clip-norm 0.1 --dropout 0.1 --max-tokens 3000 \ - --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ - --lr-scheduler fixed --force-anneal 50 \ --arch fconv_wmt_en_fr \ + --dropout 0.1 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --optimizer nag --clip-norm 0.1 \ + --lr 0.5 --lr-scheduler fixed --force-anneal 50 \ + --max-tokens 3000 \ --save-dir checkpoints/fconv_wmt_en_fr # Evaluate diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index e7f8633afb..ca01e181cb 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -1,6 +1,6 @@ # wav2vec 2.0 -wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/2006.11477). +wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). 
## Pre-trained models @@ -172,7 +172,7 @@ $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 - --arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \ --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ ---skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ --max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test ``` @@ -234,7 +234,7 @@ $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 - --activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \ --log-keys ["prob_perplexity","code_perplexity","temp"] --vq-type gumbel --vq-groups 2 --vq-depth 2 \ --combine-groups --vq-vars 320 --vq-temp (2,0.5,0.999995) --prediction-steps 12 --warmup-updates 1000 \ ---warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 --max-sample-size 150000 \ +--warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \ --max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test ``` diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index af21db929f..20891b5f30 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -263,10 +263,14 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def 
torch_persistent_save(*args, **kwargs): +def torch_persistent_save(obj, f): + if isinstance(f, str): + with PathManager.open(f, "wb") as h: + torch_persistent_save(obj, h) + return for i in range(3): try: - return torch.save(*args, **kwargs) + return torch.save(obj, f) except Exception: if i == 2: logger.error(traceback.format_exc()) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 019db62249..ceb30458bf 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -10,6 +10,7 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.logging.meters import safe_round @register_criterion('wav2vec') @@ -132,7 +133,7 @@ def reduce_metrics(logging_outputs) -> None: if total > 0: metrics.log_derived( "accuracy", - lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5) + lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5) if meters["_total"].sum > 0 else float("nan"), ) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 70d8997b19..5a00debbc3 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -11,6 +11,7 @@ import itertools import logging import os +import warnings from typing import Tuple, Optional @@ -182,6 +183,11 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False): raise_exception (bool, optional): if ``True``, raise an exception if any elements are filtered (default: False). """ + warnings.warn( + 'data_utils.filter_by_size is deprecated. 
' + 'Use `FairseqDataset::filter_indices_by_size` instead.', + stacklevel=2 + ) if isinstance(max_positions, float) or isinstance(max_positions, int): if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray): ignored = indices[dataset.sizes[indices] > max_positions].tolist() diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index 900bfaff10..2c972127a7 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -125,10 +125,11 @@ def adjust_bsz(bsz, num_tokens): ) def filter_indices_by_size(self, indices, max_sizes): - """ Filter a list of sample indices. Remove those that are longer - than specified in max_sizes. + """ + Filter a list of sample indices. Remove those that are longer than + specified in *max_sizes*. - WARNING: don't update, override method in child classes + WARNING: don't update, override method in child classes Args: indices (np.array): original array of sample indices @@ -154,8 +155,9 @@ def filter_indices_by_size(self, indices, max_sizes): class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): - """For datasets that need to be read sequentially, usually because the data - is being streamed or otherwise can't be manipulated on a single machine. + """ + For datasets that need to be read sequentially, usually because the data is + being streamed or otherwise can't be manipulated on a single machine. """ def __iter__(self): diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 5f4a616c65..fd701e11d3 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -19,7 +19,6 @@ logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) # Object used by _background_consumer to signal the source is exhausted # to the main thread. 
@@ -504,7 +503,7 @@ def __next__(self): if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): if time.time() - self.start_time > 5 * 60: if self.warning_time is None or time.time() - self.warning_time > 15 * 60: - logger.info( + logger.debug( "Data loading buffer is empty or nearly empty. This may " "indicate a data loading bottleneck, and increasing the " "number of workers (--num-workers) may help." diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py index 89ea982f69..789169d2b0 100644 --- a/fairseq/nan_detector.py +++ b/fairseq/nan_detector.py @@ -44,8 +44,10 @@ def reset(self): def _detect(self, tensor, name, backward): err = None if ( - tensor.numel() >= 2 - ): # single value tensors (like the loss) will not provide much info + torch.is_floating_point(tensor) + # single value tensors (like the loss) will not provide much info + and tensor.numel() >= 2 + ): with torch.no_grad(): if torch.isnan(tensor).any(): err = "NaN" diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 2b8334d8c2..273aa5e8f6 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -22,7 +22,7 @@ build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( '--optimizer', base_class=FairseqOptimizer, - default='nag', + required=True, ) diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index f52ec0f139..0456f7d61d 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -204,8 +204,8 @@ def step(self, closure=None): exp_avg_sq_row = state['exp_avg_sq_row'] exp_avg_sq_col = state['exp_avg_sq_col'] - exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1)) - exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2)) + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) # Approximation of exponential moving average of square of gradient update = 
self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) @@ -213,7 +213,7 @@ def step(self, closure=None): else: exp_avg_sq = state['exp_avg_sq'] - exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update) + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) update = exp_avg_sq.rsqrt().mul_(grad) update.div_( @@ -223,11 +223,11 @@ def step(self, closure=None): if use_first_moment: exp_avg = state['exp_avg'] - exp_avg.mul_(group['beta1']).add_(1 - group['beta1'], update) + exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) update = exp_avg if group['weight_decay'] != 0: - p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) p_data_fp32.add_(-update) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 37e94965bb..d6d4f34767 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -239,6 +239,7 @@ class _MemoryEfficientFP16OptimizerMixin(object): def __init__(self, *args, **kwargs): # forward __init__ call to the next class in MRO (method resolution order) super().__init__(*args, **kwargs) + self._multiply_factor = 1. 
@property def has_flat_params(self): diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index cc10db1638..1c3edd0047 100644 --- a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -34,6 +34,13 @@ def add_args(parser): help='warmup the learning rate linearly for the first N updates') # fmt: on + def state_dict(self): + return {'lr': self.lr} + + def load_state_dict(self, state_dict): + if 'lr' in state_dict: + self.lr = state_dict['lr'] + def get_next_lr(self, epoch): lrs = self.args.lr if self.args.force_anneal is None or epoch < self.args.force_anneal: diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 8128cf0eb8..65ac2e3071 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -45,9 +45,11 @@ def __init__(self, args, optimizer): # linearly warmup for the first args.warmup_updates if args.warmup_updates > 0: self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + # this flag is either set from arg when no warm up, or set by # step_update() when warmup finishes self.warmup_end = True if args.warmup_updates <= 0 else False + # initial learning rate # this self.lr is used only during init and/or warm up period self.lr = args.warmup_init_lr diff --git a/fairseq/options.py b/fairseq/options.py index e889821ee6..2a93452de2 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -9,7 +9,7 @@ import torch -from fairseq import scoring, utils +from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl @@ -668,8 +668,8 @@ def add_model_args(parser): # 2) --arch argument # 3) --encoder/decoder-* arguments (highest priority) from fairseq.models import ARCH_MODEL_REGISTRY - group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', + group.add_argument('--arch', '-a', 
metavar='ARCH', choices=ARCH_MODEL_REGISTRY.keys(), - help='Model Architecture') + help='model architecture') # fmt: on return group diff --git a/fairseq/registry.py b/fairseq/registry.py index ed24258c57..3859872420 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -13,6 +13,7 @@ def setup_registry( registry_name: str, base_class=None, default=None, + required=False, ): assert registry_name.startswith('--') registry_name = registry_name[2:].replace('-', '_') @@ -31,6 +32,8 @@ def setup_registry( def build_x(args, *extra_args, **extra_kwargs): choice = getattr(args, registry_name, None) if choice is None: + if required: + raise ValueError('--{} is required!'.format(registry_name)) return None cls = REGISTRY[choice] if hasattr(cls, 'build_' + registry_name): diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 6e5cc287ba..c17aad368a 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -10,11 +10,26 @@ from fairseq import registry -build_scoring, register_scoring, SCORING_REGISTRY = registry.setup_registry( +_build_scoring, register_scoring, SCORING_REGISTRY = registry.setup_registry( "--scoring", default="bleu" ) +def build_scorer(args, tgt_dict): + from fairseq import utils + + if args.sacrebleu: + utils.deprecation_warning( + "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." 
+ ) + args.scoring = "sacrebleu" + if args.scoring == "bleu": + from fairseq.scoring import bleu + return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) + else: + return _build_scoring(args) + + # automatically import any Python files in the current directory for file in os.listdir(os.path.dirname(__file__)): if file.endswith(".py") and not file.startswith("_"): diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index 40f3440d82..476c0f0472 100644 --- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -5,6 +5,8 @@ import ctypes import math +import sys + import torch from fairseq.scoring import register_scoring @@ -12,8 +14,6 @@ try: from fairseq import libbleu except ImportError as e: - import sys - sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n") raise e diff --git a/fairseq/scoring/scoring_utils.py b/fairseq/scoring/scoring_utils.py deleted file mode 100644 index 0b710d5bb8..0000000000 --- a/fairseq/scoring/scoring_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from fairseq import utils -from . import bleu, build_scoring - - -def build_scorer(args, tgt_dict): - if args.sacrebleu: - utils.deprecation_warning( - "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." 
- ) - args.scoring = "sacrebleu" - - if args.scoring == "bleu": - scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) - else: - return build_scoring(args) - - return scorer diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 59663b531d..a88b7600ec 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -112,11 +112,13 @@ def dataset(self, split): raise TypeError("Datasets are expected to be of type FairseqDataset") return self.datasets[split] - def filter_indices_by_size(self, - indices, - dataset, - max_positions, - ignore_invalid_inputs): + def filter_indices_by_size( + self, + indices, + dataset, + max_positions=None, + ignore_invalid_inputs=False, + ): """ Filter examples that are too large @@ -204,10 +206,9 @@ def get_batch_iterator( # filter examples that are too large if max_positions is not None: - indices = self.filter_indices_by_size(indices, - dataset, - max_positions, - ignore_invalid_inputs) + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) # create mini-batches with given size constraints batch_sampler = dataset.batch_by_size( diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index eba32f1759..b517d6f2b7 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -239,8 +239,11 @@ def construct_batch_sampler( # filter examples that are too large if max_positions is not None: my_time = time.time() - indices = data_utils.filter_by_size( - indices, dataset, max_positions, raise_exception=(not ignore_invalid_inputs), + indices = self.filter_indices_by_size( + indices, + dataset, + max_positions, + ignore_invalid_inputs=ignore_invalid_inputs, ) logger.debug(f'[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}') diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 
898edb6d6c..36d558eba3 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -373,12 +373,20 @@ def get_valid_iterator( def begin_epoch(self, epoch): """Called at the beginning of each epoch.""" + logger.info("begin training epoch {}".format(epoch)) + if self.quantizer is not None: self.quantizer.begin_epoch(epoch) # task specific setup per epoch self.task.begin_epoch(epoch, self.get_model()) + if self.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous('begin_epoch') # wait for all workers + xm.mark_step() + @metrics.aggregate("train") def train_step(self, samples, raise_oom=False): """Do forward, backward and parameter update.""" @@ -896,7 +904,10 @@ def _check_grad_norms(self, grad_norm): def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) - return (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + return ( + not torch.isfinite(tensor).any() + or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + ) if not is_consistent(self._grad_norm_buf): pretty_detail = "\n".join( diff --git a/fairseq/utils.py b/fairseq/utils.py index f68860330c..2531896e57 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -18,6 +18,7 @@ import numpy as np import torch import torch.nn.functional as F +from fairseq.data import iterators from fairseq.logging.meters import safe_round from fairseq.modules import gelu, gelu_accurate from fairseq.modules.multihead_attention import MultiheadAttention @@ -560,6 +561,20 @@ def get_tpu_device(args): return xm.xla_device() +def tpu_data_loader(itr): + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl + + xm.rendezvous("tpu_data_loader") # wait for all workers + xm.mark_step() + device = xm.xla_device() + return iterators.CountingIterator( + pl.ParallelLoader(itr, [device]).per_device_loader(device), + start=getattr(itr, "n", 0), + total=len(itr), + ) + + class CudaEnvironment(object): def __init__(self): cur_device = torch.cuda.current_device() diff --git 
a/fairseq_cli/generate.py b/fairseq_cli/generate.py index cf472ff252..03d2b7dfc0 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -136,7 +136,7 @@ def decode_fn(x): x = tokenizer.decode(x) return x - scorer = scoring.scoring_utils.build_scorer(args, tgt_dict) + scorer = scoring.build_scorer(args, tgt_dict) num_sentences = 0 has_target = True diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 90574ee8f7..6c8bd5d4df 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -44,6 +44,7 @@ def main(args): assert ( args.max_tokens is not None or args.max_sentences is not None ), "Must specify batch size either with --max-tokens or --max-sentences" + metrics.reset() np.random.seed(args.seed) @@ -66,8 +67,10 @@ def main(args): model = task.build_model(args) criterion = task.build_criterion(args) logger.info(model) + logger.info("task: {} ({})".format(args.task, task.__class__.__name__)) + logger.info("model: {} ({})".format(args.arch, model.__class__.__name__)) logger.info( - "model {}, criterion {}".format(args.arch, criterion.__class__.__name__) + "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__) ) logger.info( "num. model params: {} (num. 
trained: {})".format( @@ -104,11 +107,6 @@ def main(args): # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) - if args.tpu: - import torch_xla.core.xla_model as xm - - xm.rendezvous("load_checkpoint") # wait for all workers - xm.mark_step() # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -162,25 +160,9 @@ def is_better(a, b): return False -def tpu_data_loader(args, itr): - import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as pl - - xm.rendezvous("tpu_data_loader") # wait for all workers - xm.mark_step() - device = utils.get_tpu_device(args) - return iterators.CountingIterator( - pl.ParallelLoader(itr, [device]).per_device_loader(device), - start=getattr(itr, "n", 0), - total=len(itr), - ) - - @metrics.aggregate("train") def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" - logger.info("begin training epoch {}".format(epoch_itr.epoch)) - # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, @@ -193,7 +175,7 @@ def train(args, trainer, task, epoch_itr): ) itr = iterators.GroupedIterator(itr, update_freq) if getattr(args, "tpu", False): - itr = tpu_data_loader(args, itr) + itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, @@ -302,7 +284,7 @@ def validate(args, trainer, task, epoch_itr, subsets): # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) if getattr(args, "tpu", False): - itr = tpu_data_loader(args, itr) + itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, diff --git a/tests/utils.py b/tests/utils.py index e207575d6f..869a70c5e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -210,6 +210,7 @@ def 
train_translation_model(data_dir, arch, extra_flags=None, task='translation' data_dir, '--save-dir', data_dir, '--arch', arch, + '--optimizer', 'nag', '--lr', '0.05', '--max-tokens', '500', '--max-epoch', '1', From 3217945c43aeb07754d74f87836b64e008504597 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 14 Aug 2020 11:01:01 -0700 Subject: [PATCH 112/707] Always use valid and test data from shard 0 to avoid the need to copy valid and test data to all shards Summary: Always use valid and test data from shard 0 to avoid the need to copy valid and test data to all shards Reviewed By: pipibjc Differential Revision: D23033446 fbshipit-source-id: dba4566321c7283484c94a6042b942da98d28605 --- fairseq/data/multilingual/multilingual_data_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index b31028c191..77731ae203 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -698,6 +698,9 @@ def get_split_data_param_list(self, split, epoch, shard_epoch=None): assert len(paths) > 0 if len(paths) > 1: self._has_sharded_data = True + if split != getattr(self.args, "train_subset", None): + # if not training data set, use the first shard for valid and test + paths = paths[:1] if data_category in self.args.langtoks: lang_tok_spec = self.args.langtoks[data_category] From f1ec983f2e01e54a710a1d2881fec27795d8d6c6 Mon Sep 17 00:00:00 2001 From: Ning Dong Date: Fri, 14 Aug 2020 16:48:17 -0700 Subject: [PATCH 113/707] Set export=True in LayerNorm when under TorchScript execution Summary: When under TorchScript execution, set export=True as FusedLayerNorm doesn't work with JIT yet (See torch.jit.unused decorator). 
Reviewed By: myleott Differential Revision: D23088062 fbshipit-source-id: ba27ac8f598ddf80cf6ae460192ad8a3f83644ca --- fairseq/modules/layer_norm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fairseq/modules/layer_norm.py b/fairseq/modules/layer_norm.py index 4fee32d4fc..7b1d241436 100644 --- a/fairseq/modules/layer_norm.py +++ b/fairseq/modules/layer_norm.py @@ -27,6 +27,8 @@ def forward(self, x): def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting(): + export = True if not export and torch.cuda.is_available() and has_fused_layernorm: return FusedLayerNorm(normalized_shape, eps, elementwise_affine) return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) From bd20dbda918cdec93ab6d1fe5bba0ce064a60103 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Fri, 14 Aug 2020 22:20:21 -0700 Subject: [PATCH 114/707] Fix typos in WER scorer Summary: Fix typos in WER scorer - The typos lead to using prediction length as the denominator in the formula, which is wrong. 
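
For reference, a minimal sketch of the intended computation. It assumes a plain dynamic-programming word-level edit distance in place of the `editdistance` package, and the helper names `edit_distance` and `wer` are illustrative, not part of fairseq. The point is only that the denominator is the accumulated reference length:

```python
def edit_distance(ref_items, pred_items):
    """Word-level Levenshtein distance using a rolling DP row (illustrative helper)."""
    prev = list(range(len(pred_items) + 1))
    for i, r in enumerate(ref_items, start=1):
        cur = [i]
        for j, p in enumerate(pred_items, start=1):
            cost = 0 if r == p else 1
            cur.append(min(
                prev[j] + 1,         # drop a reference word
                cur[j - 1] + 1,      # insert a predicted word
                prev[j - 1] + cost,  # match or substitute
            ))
        prev = cur
    return prev[-1]


def wer(pairs):
    """WER in percent over (ref, pred) string pairs, normalized by reference length."""
    distance = 0
    ref_length = 0
    for ref, pred in pairs:
        ref_items = ref.split()
        pred_items = pred.split()
        distance += edit_distance(ref_items, pred_items)
        ref_length += len(ref_items)
    return 100.0 * distance / ref_length if ref_length > 0 else 0


print(wer([("the cat sat down", "the cat sat")]))  # one deletion over four reference words
```

With the prediction length as the denominator, the same pair would be normalized by three words instead of four.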
Reviewed By: alexeib Differential Revision: D23139261 fbshipit-source-id: d1bba0044365813603ce358388e880c1b3f9ec6b --- fairseq/scoring/wer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 6f4521f6cd..4e09e45614 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -15,18 +15,18 @@ def __init__(self, *unused): def reset(self): self.distance = 0 - self.target_length = 0 + self.ref_length = 0 def add_string(self, ref, pred): - pred_items = ref.split() - targ_items = pred.split() - self.distance += editdistance.eval(pred_items, targ_items) - self.target_length += len(targ_items) + ref_items = ref.split() + pred_items = pred.split() + self.distance += editdistance.eval(ref_items, pred_items) + self.ref_length += len(ref_items) def result_string(self): return f"WER: {self.score()}" def score(self): return ( - 100.0 * self.distance / self.target_length if self.target_length > 0 else 0 + 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 ) From 110f9f0cc781354eee358b28445d2096cdbd4a14 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 17 Aug 2020 16:21:48 -0700 Subject: [PATCH 115/707] Stricter boundary checks on CountingIterator (#2491) Summary: We were previously a bit too lenient with boundary conditions to support `CountingIterator.take`. Let's instead handle this more explicitly. 
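
A simplified, stand-alone sketch of the idea, assuming a hypothetical `CountingIteratorSketch` class rather than the real `fairseq.data.iterators.CountingIterator` (the error message is shortened). The expected length (`total`) stays fixed, while `take()` only lowers a separate early-stop bound:

```python
class CountingIteratorSketch:
    """Illustrative only: separates the expected length from the take() limit."""

    def __init__(self, iterable, start=0, total=None):
        self.iterable = iterable
        self.n = start
        # expected number of items; never modified by take()
        self.total = total if total is not None else start + len(iterable)
        # soft limit that take() may lower
        self.early_stop = self.total

    def __len__(self):
        return self.total

    def __iter__(self):
        for x in self.iterable:
            if self.n >= self.total:
                # the underlying iterable produced more items than expected
                raise RuntimeError("Mismatch between actual and expected iterable length.")
            elif self.n >= self.early_stop:
                return  # early stop requested via take()
            self.n += 1
            yield x

    def take(self, n):
        """Truncate iteration to at most n elements without changing len()."""
        self.early_stop = min(self.early_stop, n)


it = CountingIteratorSketch(range(5))
it.take(2)
print(list(it), len(it))
```

Under this split, an iterable that is longer than `total` is reported as an error instead of being silently truncated.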
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2491 Reviewed By: ngoyal2707 Differential Revision: D23172408 Pulled By: myleott fbshipit-source-id: 90d24b044812982f7d3eb4cdb39f3db3016a884d --- fairseq/data/iterators.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index fd701e11d3..3733f95cf0 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -53,13 +53,20 @@ def __init__(self, iterable, start=None, total=None): else: self.total = total + self.early_stop = self.total + def __len__(self): return self.total def __iter__(self): for x in self.iterable: if self.n >= self.total: - return + raise RuntimeError( + 'Mismatch between actual and expected iterable length. ' + 'Please report this to the fairseq developers.' + ) + elif self.n >= self.early_stop: + return # early stop based on take() self.n += 1 yield x @@ -79,7 +86,7 @@ def take(self, n): """ Truncates the iterator to n elements at most. """ - self.total = min(self.total, n) + self.early_stop = min(self.early_stop, n) # Propagate this change to the underlying iterator if hasattr(self.iterable, "take"): From d1d28793cdc05cf26d61a7b42c69ead046611ba8 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 17 Aug 2020 16:27:55 -0700 Subject: [PATCH 116/707] Stricter boundary checks for iterator sizes (#2490) Summary: There have been issues with some dynamic datasets where the iteration count stored in the checkpoint overflows the actual iterator size, but we've been unable to reproduce it in any reliable way. This overflow can apparently cause the epoch to advance when loading checkpoints, which is undesirable. This PR changes two things. First at the end of an epoch we advance the iterator to the next epoch directly in state_dict, so that we can distinguish this overflow corner case and the more typical end-of-epoch situation. 
We then raise an exception in the case of iterator overflow, which will hopefully help us (via the community) find a more reliable repro for the underlying issue. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2490 Reviewed By: ngoyal2707 Differential Revision: D23172070 Pulled By: myleott fbshipit-source-id: 6905cde8e83e56881d2583c74667717e08edf95e --- fairseq/data/iterators.py | 23 +++++++++++++++++++---- fairseq/tasks/fairseq_task.py | 4 ++-- fairseq/trainer.py | 4 ++-- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 3733f95cf0..b6ab54d4de 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -303,9 +303,16 @@ def iterations_in_epoch(self): def state_dict(self): """Returns a dictionary containing a whole state of the iterator.""" + if self.end_of_epoch(): + epoch = self.epoch + 1 + iter_in_epoch = 0 + else: + epoch = self.epoch + iter_in_epoch = self.iterations_in_epoch return { - 'epoch': self.epoch, - 'iterations_in_epoch': self.iterations_in_epoch, + 'version': 2, + 'epoch': epoch, + 'iterations_in_epoch': iter_in_epoch, 'shuffle': self.shuffle, } @@ -313,6 +320,7 @@ def load_state_dict(self, state_dict): """Copies the state of the iterator from the given *state_dict*.""" self.epoch = state_dict['epoch'] itr_pos = state_dict.get('iterations_in_epoch', 0) + version = state_dict.get('version', 1) if itr_pos > 0: # fast-forward epoch iterator self._next_epoch_itr = self._get_iterator_for_epoch( @@ -321,8 +329,15 @@ def load_state_dict(self, state_dict): offset=itr_pos, ) if self._next_epoch_itr is None: - # we finished the epoch, increment epoch counter - self.epoch += 1 + if version == 1: + # legacy behavior: we finished the epoch, increment epoch counter + self.epoch += 1 + else: + raise RuntimeError( + 'Cannot resume training due to dataloader mismatch, please ' + 'report this to the fairseq developers. 
You can relaunch ' + 'training with `--reset-dataloader` and it should work.' + ) else: self._next_epoch_itr = None diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index a88b7600ec..5128d4b3f4 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -157,7 +157,7 @@ def get_batch_iterator( num_shards=1, shard_id=0, num_workers=0, - epoch=1 + epoch=1, ): """ Get an iterator that yields batches of data from the given dataset. @@ -228,7 +228,7 @@ def get_batch_iterator( shard_id=shard_id, num_workers=num_workers, epoch=epoch, - buffer_size=getattr(self.args, 'data_buffer_size', 0) + buffer_size=getattr(self.args, 'data_buffer_size', 0), ) self.dataset_to_epoch_iter[dataset] = epoch_iter return epoch_iter diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 36d558eba3..263e2e0393 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -347,7 +347,7 @@ def get_train_iterator( num_shards=self.data_parallel_world_size if shard_batch_itr else 1, shard_id=self.data_parallel_rank if shard_batch_itr else 0, num_workers=self.args.num_workers, - epoch=epoch + epoch=epoch, ) def get_valid_iterator( @@ -368,7 +368,7 @@ def get_valid_iterator( seed=self.args.seed, num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, - num_workers=self.args.num_workers + num_workers=self.args.num_workers, ) def begin_epoch(self, epoch): From 5af2fdd32a4b3cce31d248c647d6f4152e72d525 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Tue, 18 Aug 2020 16:05:51 -0700 Subject: [PATCH 117/707] Have translation task always use valid/test in the first shard to avoid copying valid/test data across all shards Summary: Have translation task always use valid/test in the first shard to avoid copying valid/test data across all shards Reviewed By: pipibjc Differential Revision: D23180874 fbshipit-source-id: 20fe431438f2bb22fc955773397a7f4c08a1f014 --- fairseq/tasks/translation.py | 7 ++++++- 1 file changed, 6 insertions(+), 
1 deletion(-) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 7077943c1e..ab1ff3cf34 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -159,7 +159,9 @@ def add_args(parser): """Add task-specific arguments to the parser.""" # fmt: off parser.add_argument('data', help='colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner') + will be iterated upon during epochs in round-robin manner; \ + however, valid and test data are always in the first directory to \ + avoid the need for repeating them in all directories') parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language') parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', @@ -246,6 +248,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 + if split != getattr(self.args, "train_subset", None): + # if not training data set, use the first shard for valid and test + paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] # infer langcode From 77983ee1a52c4e011e54cc6bfa5352b7811ec96d Mon Sep 17 00:00:00 2001 From: James Cross Date: Tue, 18 Aug 2020 17:36:24 -0700 Subject: [PATCH 118/707] fbtranslate: fairseq logging to manifold Summary: Manifold logging for Fairseq training in fbtranslate is currently broken because `distributed_train` dispatches separate processes for each training GPU. If these write to a local temporary file, we must copy to Manifold in the spawned process rather than from the dispatching function since they are run on different machines. (Previously this was coordinated by using Gluster, a shared filesystem available from all machines.) 
Reviewed By: akinh Differential Revision: D23200236 fbshipit-source-id: 8dbeef9ed3cc04d73321827200c16df48dddff61 --- fairseq/file_io.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index b57373f8b5..d667256922 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -104,3 +104,13 @@ def chmod(path: str, mode: int) -> None: def register_handler(handler) -> None: if FVCorePathManager: return FVCorePathManager.register_handler(handler=handler) + + @staticmethod + def copy_from_local( + local_path: str, dst_path: str, overwrite: bool = False, **kwargs + ) -> None: + if FVCorePathManager: + return FVCorePathManager.copy_from_local( + local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs + ) + return shutil.copyfile(local_path, dst_path) From 68c87f0abf95e84b2c9105911503f604611429d6 Mon Sep 17 00:00:00 2001 From: Jun Ru Anderson Date: Wed, 19 Aug 2020 16:03:14 -0700 Subject: [PATCH 119/707] optimize mixed precision (#1248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Implements the multiply_factor optimization used in memory efficient fp16 training to mixed precision training. The methods multiply_grads and clip_grad_norm do not touch each gradient, but rather a "multiply factor" that is then factored in when unscaling gradients. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1248 Reviewed By: myleott Differential Revision: D23201396 Pulled By: andersonic fbshipit-source-id: 6c6f64542893e0ecac72e132464bb334dcb9874d --- fairseq/optim/fp16_optimizer.py | 46 ++++++++++++-------- tests/test_fp16_optimizer.py | 77 +++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 19 deletions(-) create mode 100644 tests/test_fp16_optimizer.py diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index d6d4f34767..a815eeb085 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -75,12 +75,8 @@ def backward(self, loss): loss.backward() self._needs_sync = True - def _sync_fp16_grads_to_fp32(self, multiply_grads=1.): + def _sync_fp16_grads_to_fp32(self): if self._needs_sync: - if self.scaler is not None: - # correct for dynamic loss scaler - multiply_grads /= self.scaler.loss_scale - # copy FP16 grads to FP32 if self.has_flat_params: offset = 0 @@ -91,20 +87,18 @@ def _sync_fp16_grads_to_fp32(self, multiply_grads=1.): numel = grad_data.numel() self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1)) offset += numel - self.fp32_params.grad.data.mul_(multiply_grads) else: for p, p32 in zip(self.fp16_params, self.fp32_params): if not p.requires_grad: continue if p.grad is not None: p32.grad.data.copy_(p.grad.data) - p32.grad.data.mul_(multiply_grads) else: p32.grad = torch.zeros_like(p.data, dtype=torch.float) self._needs_sync = False - def _sync_fp32_grads_to_fp16(self): + def _sync_fp32_params_to_fp16(self): # copy FP32 params back into FP16 model if self.has_flat_params: offset = 0 @@ -120,36 +114,47 @@ def _sync_fp32_grads_to_fp16(self): continue p.data.copy_(p32.data) + def _unscale_grads(self): + self._sync_fp16_grads_to_fp32() + if self._multiply_factor != 1.: + self.fp32_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1. 
+ def multiply_grads(self, c): """Multiplies grads by a constant ``c``.""" - if self._needs_sync: - self._sync_fp16_grads_to_fp32(c) - elif self.has_flat_params: - self.fp32_params.grad.data.mul_(c) - else: - for p32 in self.fp32_params: - p32.grad.data.mul_(c) + self._multiply_factor *= c def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm and updates dynamic loss scaler.""" self._sync_fp16_grads_to_fp32() - grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm, aggregate_norm_fn) - # detect overflow and adjust loss scale + grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(0, aggregate_norm_fn) + if self.scaler is not None: + if grad_norm > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm + self.scaler.check_overflow(grad_norm) + else: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef return grad_norm def step(self, closure=None): """Performs a single optimization step.""" self._sync_fp16_grads_to_fp32() - self.fp32_optimizer.step(closure) + + if self.supports_step_with_scale: + self.fp32_optimizer.step(closure, scale=(1. / self._multiply_factor)) + else: + self._unscale_grads() + self.fp32_optimizer.step(closure) if self.scaler is not None: self.scaler.update() - self._sync_fp32_grads_to_fp16() + self._sync_fp32_params_to_fp16() def zero_grad(self): """Clears the gradients of all optimized parameters.""" @@ -162,6 +167,9 @@ def zero_grad(self): p32.grad.zero_() self._needs_sync = False + if self.scaler is not None: + self._multiply_factor = 1. / float(self.scaler.loss_scale) + class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): """ diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py new file mode 100644 index 0000000000..ae7b797ec8 --- /dev/null +++ b/tests/test_fp16_optimizer.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import copy +import unittest + +import torch + +from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer + + +@unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') +class TestGradientScaling(unittest.TestCase): + + def setUp(self): + self.x = torch.tensor([2.0]).cuda().half() + weight = 3.0 + bias = 5.0 + self.error = 1.0 + self.target = torch.tensor([self.x * weight + bias + self.error]).cuda().half() + self.loss_fn = torch.nn.L1Loss() + + self.model = torch.nn.Linear(1, 1) + self.model.weight.data = torch.tensor([[weight]]) + self.model.bias.data = torch.tensor([bias]) + self.model.cuda().half() + self.params = list(self.model.parameters()) + + self.namespace_dls = argparse.Namespace( + optimizer='adam', + lr=[0.1], + adam_betas='(0.9, 0.999)', + adam_eps=1e-8, + weight_decay=0.0, + fp16_init_scale=1, + fp16_scale_window=1, + fp16_scale_tolerance=1, + threshold_loss_scale=1, + min_loss_scale=1e-4 + ) + + def run_iter(self, model, params, optimizer): + optimizer.zero_grad() + y = model(self.x) + loss = self.loss_fn(y, self.target) + optimizer.backward(loss) + self.assertEqual(loss, torch.tensor(1., device='cuda:0', dtype=torch.float16)) + + grad_norm = optimizer.clip_grad_norm(0) + self.assertAlmostEqual(grad_norm.item(), 2.2361, 4) + + optimizer.step() + self.assertEqual(model.weight, torch.tensor([[3.0996]], device='cuda:0', dtype=torch.float16, requires_grad=True)) + self.assertEqual(model.bias, torch.tensor([5.1016], device='cuda:0', dtype=torch.float16, requires_grad=True)) + self.assertEqual(optimizer.scaler.loss_scale, 2.) 
+ + def test_mixed_precision(self): + model = copy.deepcopy(self.model) + params = list(model.parameters()) + optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params) + + self.run_iter(model, params, optimizer) + self.assertTrue(torch.all(optimizer.fp32_params.eq(torch.tensor([3.1000, 5.1000], device='cuda:0', requires_grad=True)))) + + def test_memory_efficient(self): + model = copy.deepcopy(self.model) + params = list(model.parameters()) + optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.namespace_dls, params) + + self.run_iter(model, params, optimizer) + + +if __name__ == '__main__': + unittest.main() From 54b934417d95baa1b0076089c61bde32728e34cf Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 19 Aug 2020 20:09:02 -0700 Subject: [PATCH 120/707] libri_labels fix + zero padding (#1249) Summary: fix libri_labels.py to output files without .txt extension zero-pad examples instead of initializing with random values and hoping they wont get used Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1249 Reviewed By: ngoyal2707 Differential Revision: D23227559 Pulled By: alexeib fbshipit-source-id: c30fa4b8242c6b52098b3f9d9c4ccb23902be2e6 --- examples/wav2vec/libri_labels.py | 4 ++-- fairseq/criterions/wav2vec_criterion.py | 2 +- fairseq/data/audio/raw_audio_dataset.py | 2 +- fairseq/models/wav2vec/wav2vec2.py | 3 +++ fairseq/trainer.py | 2 +- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/wav2vec/libri_labels.py b/examples/wav2vec/libri_labels.py index 4feced0a02..812528732f 100644 --- a/examples/wav2vec/libri_labels.py +++ b/examples/wav2vec/libri_labels.py @@ -24,9 +24,9 @@ def main(): transcriptions = {} with open(args.tsv, "r") as tsv, open( - os.path.join(args.output_dir, args.output_name + ".ltr.txt"), "w" + os.path.join(args.output_dir, args.output_name + ".ltr"), "w" ) as ltr_out, open( - os.path.join(args.output_dir, args.output_name + ".wrd.txt"), "w" + os.path.join(args.output_dir, 
args.output_name + ".wrd"), "w" ) as wrd_out: root = next(tsv).strip() for line in tsv: diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index ceb30458bf..85403cb428 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -59,7 +59,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",) sample_size = target.numel() if self.infonce else target.long().sum().item() - losses.append(loss) + losses.append(loss.detach().clone()) if self.loss_weights is not None: assert hasattr(model, "get_extra_losses") diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 675b095647..09838a54e0 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -88,7 +88,7 @@ def collater(self, samples): else: target_size = min(min(sizes), self.max_sample_size) - collated_sources = sources[0].new(len(sources), target_size) + collated_sources = sources[0].new_zeros(len(sources), target_size) padding_mask = ( torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None ) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index be6d10c7a2..ea6f901020 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -809,6 +809,9 @@ def forward(self, x, padding_mask=None): def extract_features(self, x, padding_mask=None): + if padding_mask is not None: + x[padding_mask] = 0 + x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) x += x_conv diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 263e2e0393..de5df45105 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -592,7 +592,7 @@ def maybe_no_sync(): torch.cuda.empty_cache() if self.args.fp16: - metrics.log_scalar("loss_scale", 
self.optimizer.scaler.loss_scale, priority=700, round=0) + metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=4) metrics.log_stop_time("train_wall") From adbd89fd4be9e68100bf9a4ba9eed1e7fb2e4040 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 20 Aug 2020 06:40:45 -0700 Subject: [PATCH 121/707] Misc fixes (#2492) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2492 Reviewed By: ngoyal2707 Differential Revision: D23177728 Pulled By: myleott fbshipit-source-id: 32424f61cab57f759f87e16e8d5144d3eed5ae36 --- docs/getting_started.rst | 3 ++ examples/wav2vec/README.md | 17 ++++++--- fairseq/__init__.py | 1 + fairseq/clib/libnat_cuda/edit_dist.cu | 36 +++++++++--------- fairseq/criterions/legacy_masked_lm.py | 18 +++------ fairseq/data/data_utils_fast.pyx | 37 ++++++++++--------- fairseq/data/token_block_utils_fast.pyx | 4 +- fairseq/iterative_refinement_generator.py | 2 +- .../transformer_sentence_encoder_layer.py | 6 +-- fairseq/models/nat/levenshtein_utils.py | 2 +- fairseq/modules/adaptive_softmax.py | 2 +- fairseq/scoring/bleu.py | 24 ++++++------ fairseq/scoring/wer.py | 3 +- fairseq/search.py | 4 +- fairseq/sequence_generator.py | 17 +++++---- fairseq/tasks/masked_lm.py | 11 ++---- fairseq/tasks/sentence_prediction.py | 14 ++++--- fairseq/utils.py | 4 +- tests/test_inference_dropout.py | 5 +++ tests/test_train.py | 11 ++++-- 20 files changed, 121 insertions(+), 100 deletions(-) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 416e29531d..fa5971dd31 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -47,6 +47,9 @@ hypothesis along with an average log-likelihood; and *P* is the positional score per token position, including the end-of-sentence marker which is omitted from the text. +Other types of output lines you might see are *D*, the detokenized hypothesis, +*T*, the reference target, *A*, alignment info, *E* the history of generation steps. 
+ See the `README `__ for a full list of pre-trained models available. diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index ca01e181cb..2e59798ead 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -25,16 +25,23 @@ Given a directory containing wav files to be used for pretraining (we recommend ### Prepare training data manifest: -$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. +First, install the `soundfile` library: +```shell script +pip install soundfile +``` -$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. -To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a -separately pre-processed manifest file. +Next, run: ```shell script $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid ``` +$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. + +$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. +To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a +separately pre-processed manifest file. 
+ ### Train a wav2vec 2.0 base model: This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper @@ -43,7 +50,7 @@ Note that this was tested with pytorch 1.4.0 and the input is expected to be sin ```shell script $ python train.py --distributed-world-size 64 --distributed-port $PORT /manifest/path \ ---save-dir /model/path fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ +--save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \ --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \ diff --git a/fairseq/__init__.py b/fairseq/__init__.py index 3dd29637af..1ba63fcaaa 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -20,6 +20,7 @@ import fairseq.optim # noqa import fairseq.optim.lr_scheduler # noqa import fairseq.pdb # noqa +import fairseq.scoring # noqa import fairseq.tasks # noqa import fairseq.benchmark # noqa diff --git a/fairseq/clib/libnat_cuda/edit_dist.cu b/fairseq/clib/libnat_cuda/edit_dist.cu index b6486a8c22..22de16b270 100644 --- a/fairseq/clib/libnat_cuda/edit_dist.cu +++ b/fairseq/clib/libnat_cuda/edit_dist.cu @@ -253,11 +253,11 @@ torch::Tensor GenerateDeletionLabelCuda( AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] { generate_deletion_label_kernel<<>>( - source.data(), + source.data_ptr(), source.size(1), operations.size(1), - operations.data(), - labels.data()); + operations.data_ptr(), + labels.data_ptr()); })); return labels; @@ -276,12 +276,12 @@ auto stream = at::cuda::getCurrentCUDAStream(target.device().index()); AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] { generate_insertion_label_kernel<<>>( - 
target.data(), + target.data_ptr(), target.size(1), operations.size(1), - operations.data(), - labels.data(), - masks.data()); + operations.data_ptr(), + labels.data_ptr(), + masks.data_ptr()); })); return std::make_pair(labels, masks); @@ -306,25 +306,25 @@ torch::Tensor LevenshteinDistanceCuda( auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options); AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] { levenshtein_distance_kernel<<>>( - source.data(), - target.data(), - source_length.data(), - target_length.data(), + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), source.size(1), target.size(1), - operations.data(), - distances.data()); + operations.data_ptr(), + distances.data_ptr()); })); } else { AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] { faster_levenshtein_distance_kernel<<>>( - source.data(), - target.data(), - source_length.data(), - target_length.data(), + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), source.size(1), target.size(1), - operations.data()); + operations.data_ptr()); })); } diff --git a/fairseq/criterions/legacy_masked_lm.py b/fairseq/criterions/legacy_masked_lm.py index 10dea76e4b..3dbfdfbe46 100644 --- a/fairseq/criterions/legacy_masked_lm.py +++ b/fairseq/criterions/legacy_masked_lm.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F -from fairseq import utils +from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion @@ -127,7 +127,7 @@ def forward(self, model, sample, reduce=True): return loss, sample_size, logging_output @staticmethod - def aggregate_logging_outputs(logging_outputs): + def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs) sentence_loss_sum = sum( @@ 
-137,16 +137,10 @@ def aggregate_logging_outputs(logging_outputs): sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) agg_loss = sum(log.get('loss', 0) for log in logging_outputs) - agg_output = { - 'loss': agg_loss / sample_size / math.log(2) if sample_size > 0 else 0., - 'lm_loss': lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., - 'sentence_loss': sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0., - 'nll_loss': lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., - 'ntokens': ntokens, - 'nsentences': nsentences, - 'sample_size': sample_size, - } - return agg_output + metrics.log_scalar('loss', agg_loss / sample_size / math.log(2) if sample_size > 0 else 0., sample_size, round=3) + metrics.log_scalar('lm_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3) + metrics.log_scalar('sentence_loss', sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0., nsentences, round=3) + metrics.log_scalar('nll_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx index c1f97bf5b6..38b4aa67dd 100644 --- a/fairseq/data/data_utils_fast.pyx +++ b/fairseq/data/data_utils_fast.pyx @@ -9,11 +9,12 @@ import numpy as np cimport cython cimport numpy as np -DTYPE = np.int64 -ctypedef np.int64_t DTYPE_t +from libc.stdint cimport int32_t, int64_t +ctypedef int64_t DTYPE_t -cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences): + +cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences): if num_sentences == 0: return 0 if max_sentences > 0 and num_sentences == max_sentences: @@ -27,18 +28,18 @@ cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long m cpdef list batch_by_size_fast( np.ndarray[DTYPE_t, ndim=1] 
indices, num_tokens_fn, - long max_tokens, - long max_sentences, - int bsz_mult, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, ): - cdef long sample_len = 0 + cdef int64_t sample_len = 0 cdef list sample_lens = [] cdef list batch = [] cdef list batches = [] - cdef long mod_len - cdef long i - cdef long idx - cdef long num_tokens + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens cdef DTYPE_t[:] indices_view = indices for i in range(len(indices_view)): @@ -70,8 +71,8 @@ cpdef list batch_by_size_fast( cdef _find_valid_shape( DTYPE_t[:, :] shapes_view, - long num_sentences, - long num_tokens, + int64_t num_sentences, + int64_t num_tokens, ): """Return index of first valid shape of -1 if none is found.""" for i in range(shapes_view.shape[0]): @@ -86,14 +87,14 @@ cpdef list batch_fixed_shapes_fast( num_tokens_fn, np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted, ): - cdef long sample_len = 0 + cdef int64_t sample_len = 0 cdef list sample_lens = [] cdef list batch = [] cdef list batches = [] - cdef long mod_len - cdef long i - cdef long idx - cdef long num_tokens + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens cdef DTYPE_t[:] indices_view = indices cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted diff --git a/fairseq/data/token_block_utils_fast.pyx b/fairseq/data/token_block_utils_fast.pyx index 5563b973e9..5a2f16ec34 100644 --- a/fairseq/data/token_block_utils_fast.pyx +++ b/fairseq/data/token_block_utils_fast.pyx @@ -12,8 +12,10 @@ from libc.math cimport ceil cimport cython cimport numpy as np +from libc.stdint cimport int32_t, int64_t + DTYPE = np.int64 -ctypedef np.int64_t DTYPE_t +ctypedef int64_t DTYPE_t @cython.boundscheck(False) diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py index c7a267d258..97e66fabe9 100644 --- a/fairseq/iterative_refinement_generator.py +++ b/fairseq/iterative_refinement_generator.py @@ -266,7 
+266,7 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): if decoder_out.history is not None else None, ) - encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze()) + encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()) sent_idxs = sent_idxs[not_terminated] prev_output_tokens = prev_decoder_out.output_tokens.clone() diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py index 0e1ea2b7d7..d09158b7f1 100644 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py @@ -66,14 +66,14 @@ def forward( need_weights=False, attn_mask=self_attn_mask, ) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x residual = x x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) - x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.activation_dropout_module(x) x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x return x, None diff --git a/fairseq/models/nat/levenshtein_utils.py b/fairseq/models/nat/levenshtein_utils.py index e29b1fa27c..11fb29578b 100644 --- a/fairseq/models/nat/levenshtein_utils.py +++ b/fairseq/models/nat/levenshtein_utils.py @@ -250,7 +250,7 @@ def _skip_encoder_out(encoder, encoder_out, mask): if not mask.any(): return encoder_out else: - return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze()) + return encoder.reorder_encoder_out(encoder_out, mask.nonzero(as_tuple=False).squeeze()) def _fill(x, mask, y, padding_idx): diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py index 96f8b75ad3..8e47134a70 100644 --- a/fairseq/modules/adaptive_softmax.py +++ 
b/fairseq/modules/adaptive_softmax.py @@ -144,7 +144,7 @@ def adapt_target(self, target): new_target[0][mask] = self.cutoff[0] + i if mask.any(): - target_idxs.append(mask.nonzero().squeeze(1)) + target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1)) new_target.append(target[mask].add(-self.cutoff[i])) else: target_idxs.append(None) diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index 476c0f0472..15275d94c9 100644 --- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -11,15 +11,6 @@ from fairseq.scoring import register_scoring -try: - from fairseq import libbleu -except ImportError as e: - sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n") - raise e - - -C = ctypes.cdll.LoadLibrary(libbleu.__file__) - class BleuStat(ctypes.Structure): _fields_ = [ @@ -70,13 +61,22 @@ def __init__(self, pad, eos, unk): self.pad = pad self.eos = eos self.unk = unk + + try: + from fairseq import libbleu + except ImportError as e: + sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n") + raise e + + self.C = ctypes.cdll.LoadLibrary(libbleu.__file__) + self.reset() def reset(self, one_init=False): if one_init: - C.bleu_one_init(ctypes.byref(self.stat)) + self.C.bleu_one_init(ctypes.byref(self.stat)) else: - C.bleu_zero_init(ctypes.byref(self.stat)) + self.C.bleu_zero_init(ctypes.byref(self.stat)) def add(self, ref, pred): if not isinstance(ref, torch.IntTensor): @@ -92,7 +92,7 @@ def add(self, ref, pred): rref = rref.contiguous().view(-1) pred = pred.contiguous().view(-1) - C.bleu_add( + self.C.bleu_add( ctypes.byref(self.stat), ctypes.c_size_t(rref.size(0)), ctypes.c_void_p(rref.data_ptr()), diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 4e09e45614..3aee5f69db 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -3,8 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import editdistance - from fairseq.scoring import register_scoring @@ -18,6 +16,7 @@ def reset(self): self.ref_length = 0 def add_string(self, ref, pred): + import editdistance ref_items = ref.split() pred_items = pred.split() self.distance += editdistance.eval(ref_items, pred_items) diff --git a/fairseq/search.py b/fairseq/search.py index 9e18581a97..8aa196a3cc 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -133,7 +133,9 @@ def step(self, step: int, lprobs, scores): # apply diversity penalty if g > 0: lprobs_g = torch.add( - lprobs_g, self.diversity_strength, diversity_buf.unsqueeze(1) + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, ) else: lprobs_g = lprobs_g.contiguous() diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 42012fbbb1..26e4c287b2 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -183,7 +183,11 @@ def _generate( src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) elif 'source' in net_input: src_tokens = net_input['source'] - src_lengths = net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1) if net_input['padding_mask'] is not None else torch.tensor(src_tokens.size(-1)) + src_lengths = ( + net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1) + if net_input['padding_mask'] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) else: raise Exception('expected src_tokens or source in net input') @@ -372,11 +376,10 @@ def _generate( new_bsz = bsz - len(finalized_sents) # construct batch_idxs which holds indices of batches to keep for the next pass - batch_mask = torch.ones(bsz).to(cand_indices) - batch_mask[ - torch.tensor(finalized_sents).to(cand_indices) - ] = torch.tensor(0).to(batch_mask) - batch_idxs = batch_mask.nonzero().squeeze(-1) + batch_mask = torch.ones(bsz, dtype=torch.bool, device=cand_indices.device) + batch_mask[finalized_sents] = False + # TODO 
replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask) eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] @@ -665,7 +668,7 @@ def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): for bbsz_idx in range(bsz * beam_size): lprobs[bbsz_idx][ torch.tensor(banned_tokens[bbsz_idx]).long() - ] = torch.tensor(-math.inf, dtype=torch.float) + ] = torch.tensor(-math.inf).to(lprobs) return lprobs diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 4d7ea54b64..4a6e6a2d37 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -16,7 +16,7 @@ NestedDictionaryDataset, NumelDataset, NumSamplesDataset, - PadDataset, + RightPadDataset, PrependTokenDataset, SortDataset, TokenBlockDataset, @@ -150,17 +150,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): { 'id': IdDataset(), 'net_input': { - 'src_tokens': PadDataset( + 'src_tokens': RightPadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), - left_pad=False, ), 'src_lengths': NumelDataset(src_dataset, reduce=False), }, - 'target': PadDataset( + 'target': RightPadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), - left_pad=False, ), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_dataset, reduce=True), @@ -174,7 +172,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): ) def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): - src_dataset = PadDataset( + src_dataset = RightPadDataset( TokenBlockDataset( src_tokens, src_lengths, @@ -184,7 +182,6 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): break_mode='eos', ), pad_idx=self.source_dictionary.pad(), - left_pad=False, ) src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) src_dataset = NestedDictionaryDataset( diff --git a/fairseq/tasks/sentence_prediction.py 
b/fairseq/tasks/sentence_prediction.py index b50c9922cc..cf5eae38b1 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -192,16 +192,20 @@ def make_dataset(type, dictionary): else: label_path = "{0}.label".format(get_path('label', split)) if os.path.exists(label_path): + def parse_regression_target(i, line): values = line.split() assert len(values) == self.args.num_classes, \ f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"' return [float(x) for x in values] - dataset.update( - target=RawLabelDataset([ - parse_regression_target(i, line.strip()) for i, line in enumerate(open(label_path).readlines()) - ]) - ) + + with open(label_path) as h: + dataset.update( + target=RawLabelDataset([ + parse_regression_target(i, line.strip()) + for i, line in enumerate(h.readlines()) + ]) + ) nested_dataset = NestedDictionaryDataset( dataset, diff --git a/fairseq/utils.py b/fairseq/utils.py index 2531896e57..d10ed2f28a 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -527,8 +527,8 @@ def get_token_to_word_mapping(tokens, exclude_list): def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): - tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero().squeeze(dim=-1) - src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero().squeeze(dim=-1) + tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1) + src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1) src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) alignment = [] diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py index 89e05473f5..4857bc7a87 100644 --- a/tests/test_inference_dropout.py +++ b/tests/test_inference_dropout.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the 
root directory of this source tree. +import logging import unittest from tests.test_sequence_generator import get_dummy_task_and_parser @@ -17,6 +18,10 @@ def setUp(self): self.args = self.parser.parse_args([]) self.args.encoder_layers = 2 self.args.decoder_layers = 1 + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) def test_sets_inference_dropout_to_true(self): self.args.retain_dropout = True diff --git a/tests/test_train.py b/tests/test_train.py index 5be74e415d..048acaca54 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. import contextlib -from io import StringIO +import logging import unittest +from io import StringIO from unittest.mock import MagicMock, patch import torch @@ -74,6 +75,11 @@ def setUp(self): } self.applied_patches = [patch(p, d) for p, d in self.patches.items()] [p.start() for p in self.applied_patches] + logging.disable(logging.CRITICAL) + + def tearDown(self): + patch.stopall() + logging.disable(logging.NOTSET) def test_load_partial_checkpoint(self): with contextlib.redirect_stdout(StringIO()): @@ -192,9 +198,6 @@ def mock_finetune_exist(path): self.assertFalse(reset_lr_scheduler) self.assertFalse(reset_meters) - def tearDown(self): - patch.stopall() - if __name__ == '__main__': unittest.main() From bd1b35d9b7cb21b2e7c17201d831c17560265b67 Mon Sep 17 00:00:00 2001 From: Matt Post Date: Thu, 20 Aug 2020 11:58:27 -0700 Subject: [PATCH 122/707] Added constrained decoding (#1536) (#2402) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? 
This PR implements constrained decoding ([Hokamp & Liu, 2017](https://www.aclweb.org/anthology/P17-1141/); [Post & Vilar, 2018](https://www.aclweb.org/anthology/N18-1119/)) with vectorization for batching ([Hu et al., 2019](https://www.aclweb.org/anthology/N19-1090/)). In addition, it adds *ordered constraints*, where the constraints are generated on the target side in order, with zero or more unconstrained tokens in between. This variant allows for optimizations that increase speed and BLEU scores (when testing with random scraps from the references). ### Usage and quick start It works with `fairseq-interactive` via a new command-line option: `fairseq-interactive --constraints [ordered,unordered]`, defaulting to `ordered` if nothing is provided. When active, it will split lines from STDIN on `\t`, with separate constraints each separated by a tab. For example (after downloading the [Fairseq WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md)): ```bash echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\tinfluence" \ | [normalize.py](https://gist.github.com/mjpost/4c54446b7030d7c64b57461d27090650) \ | [tok.py](https://gist.github.com/mjpost/ed7456f6a987c533102fc121678ed302) \ | PYTHONPATH=$HOME/code/fairseq-constraints fairseq-interactive $modeldir \ --bpe fastbpe \ --bpe-codes $modeldir/bpecodes \ --constraints \ --constraints-both -s de -t en \ --path $modeldir/model1.pt \ --max-tokens 1000 \ --beam 5 \ ``` Adding the `--constraints-both` option causes it to batch-decode the input sentence both with and without the constraints. When run with the Fairseq WMT19 German--English model, the following results are produced (here run on a CPU, don't be alarmed by the times!) ```text S-0 Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren . W-0 1.844 seconds C-0 hard C-0 influence H-0 -1.5333266258239746 Mach@@ ine trans@@ lation is hard to influence .
D-0 -1.5333266258239746 Machine translation is hard to influence . P-0 -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511 S-0 Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren . W-0 1.844 seconds H-0 -0.3731671869754791 Mach@@ ine trans@@ lation is difficult to control . D-0 -0.3731671869754791 Machine translation is difficult to control . P-0 -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.1430 -0.1665 -0.8482 -0.1678 -0.1514 2020-07-31 12:17:55 | INFO | fairseq_cli.interactive | Total time: 12.803 seconds; translation time: 3.688 ``` Note the new tags present in the output: * `C-#` records active constraints (after applying preprocessing) for a sentence * `W-#` reports the sentence-level translation time (a useful unrelated feature I hope you'll accept) Some unit tests are written (`fairseq/test_constraints.py`) but not yet integrated. Advice here on where to place this is welcome. I also have not run this through lint; if someone can tell me the command to run, I'd appreciate it. ### Implementation notes This is largely self-contained, implemented in a new `LexicallyConstrainedBeamSearch` class in `search.py`. It does require a few minimal hooks from `_generate()` in `sequence_generator.py`, to ensure that constraints are updated at each timestep. (Edit: most changes in that file are documentation clarifications, corrections, and updates). Unconstrained sentences that are intermingled with constrained ones will not incur any time penalty, so long as they do not occur in the same batch. Addresses https://github.com/pytorch/fairseq/issues/1536. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2402 Reviewed By: alexeib Differential Revision: D23188945 Pulled By: myleott fbshipit-source-id: 9f5ed855f7a1dcf535b091c0ccf98b07fb9cbdd6 --- README.md | 2 + examples/constrained_decoding/README.md | 124 +++++ examples/constrained_decoding/normalize.py | 26 + examples/constrained_decoding/tok.py | 31 ++ .../translation_moe/src/translation_moe.py | 3 +- fairseq/__init__.py | 1 + fairseq/data/dictionary.py | 2 +- fairseq/data/language_pair_dataset.py | 16 + fairseq/iterative_refinement_generator.py | 4 +- fairseq/options.py | 2 + fairseq/search.py | 337 ++++++++++++ fairseq/sequence_generator.py | 85 ++- fairseq/tasks/fairseq_task.py | 7 +- fairseq/tasks/language_modeling.py | 5 +- fairseq/tasks/multilingual_translation.py | 8 +- fairseq/tasks/translation.py | 8 +- .../tasks/translation_from_pretrained_bart.py | 6 +- fairseq/tasks/translation_lev.py | 6 +- .../tasks/translation_multi_simple_epoch.py | 8 +- fairseq/token_generation_constraints.py | 500 ++++++++++++++++++ fairseq_cli/generate.py | 6 +- fairseq_cli/interactive.py | 85 ++- scripts/constraints/extract.py | 83 +++ scripts/constraints/validate.py | 33 ++ tests/test_constraints.py | 254 +++++++++ 25 files changed, 1598 insertions(+), 44 deletions(-) create mode 100644 examples/constrained_decoding/README.md create mode 100755 examples/constrained_decoding/normalize.py create mode 100755 examples/constrained_decoding/tok.py create mode 100644 fairseq/token_generation_constraints.py create mode 100755 scripts/constraints/extract.py create mode 100755 scripts/constraints/validate.py create mode 100755 tests/test_constraints.py diff --git a/README.md b/README.md index c6c37f0965..f77e56e75b 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ We provide reference implementations of various sequence modeling papers: - [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - 
[Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) - [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/transformer_lm/README.md) + - [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) @@ -51,6 +52,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +- August 2020: [Lexically constrained decoding(examples/constrained_decoding/README.md) - August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) - July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) diff --git a/examples/constrained_decoding/README.md b/examples/constrained_decoding/README.md new file mode 100644 index 0000000000..d101c032da --- /dev/null +++ b/examples/constrained_decoding/README.md @@ -0,0 +1,124 @@ +# (Vectorized) Lexically constrained decoding with dynamic beam allocation + +This page provides instructions for how to use lexically constrained decoding in Fairseq. +Fairseq implements the code described in the following papers: + +* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/) +* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/) + +## Quick start + +Constrained search is enabled by adding the command-line argument `--constraints` to `fairseq-interactive`. 
+Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens) +is a separate field. + +The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md), +translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints +"hard" and "to influence". + + echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\ttoinfluence" \ + | normalize.py | tok.py \ + | fairseq-interactive /path/to/model \ + --path /path/to/model/model1.pt \ + --bpe fastbpe \ + --bpe-codes /path/to/model/bpecodes \ + --constraints \ + -s de -t en \ + --beam 10 + +(tok.py and normalize.py can be found in the same directory as this README; they are just shortcuts around Fairseq's WMT19 preprocessing). +This will generate the following output: + + [snip] + S-0 Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren . + W-0 1.844 seconds + C-0 hard + C-0 influence + H-0 -1.5333266258239746 Mach@@ ine trans@@ lation is hard to influence . + D-0 -1.5333266258239746 Machine translation is hard to influence . + P-0 -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511 + +By default, constraints are generated in the order supplied, with any number (zero or more) of tokens generated +between constraints. If you wish for the decoder to order the constraints, then use `--constraints unordered`. +Note that you may want to use a larger beam. + +## Implementation details + +The heart of the implementation is in `fairseq/search.py`, which adds a `LexicallyConstrainedBeamSearch` instance. +This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints +provided for each input sentence. 
It does this using one of two classes, both found in `fairseq/token_generation_contstraints.py`: + +* OrderedConstraintState: assumes the C input constraints will be generated in the provided order +* UnorderedConstraintState: tries to apply C (phrasal) constraints in all C! orders + +## Differences from Sockeye + +There are a number of [differences from Sockeye's implementation](https://awslabs.github.io/sockeye/inference.html#lexical-constraints). + +* Generating constraints in the order supplied (the default option here) is not available in Sockeye. +* Due to an improved beam allocation method, there is no need to prune the beam. +* Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient. +* [The extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged + into the main branch. +* Sockeye 2, released in July 2020, no longer supports constrained decoding. + +## Citation + +The paper first describing lexical constraints for seq2seq decoding is: + +```bibtex +@inproceedings{hokamp-liu-2017-lexically, + title = "Lexically Constrained Decoding for Sequence Generation Using Grid Beam Search", + author = "Hokamp, Chris and + Liu, Qun", + booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", + month = jul, + year = "2017", + address = "Vancouver, Canada", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/P17-1141", + doi = "10.18653/v1/P17-1141", + pages = "1535--1546", +} +``` + +The fairseq implementation uses the extensions described in + +```bibtex +@inproceedings{post-vilar-2018-fast, + title = "Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation", + author = "Post, Matt and + Vilar, David", + booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for 
Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers)", + month = jun, + year = "2018", + address = "New Orleans, Louisiana", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/N18-1119", + doi = "10.18653/v1/N18-1119", + pages = "1314--1324", +} +``` + +and + +```bibtex +@inproceedings{hu-etal-2019-improved, + title = "Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting", + author = "Hu, J. Edward and + Khayrallah, Huda and + Culkin, Ryan and + Xia, Patrick and + Chen, Tongfei and + Post, Matt and + Van Durme, Benjamin", + booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)", + month = jun, + year = "2019", + address = "Minneapolis, Minnesota", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/N19-1090", + doi = "10.18653/v1/N19-1090", + pages = "839--850", +} +``` diff --git a/examples/constrained_decoding/normalize.py b/examples/constrained_decoding/normalize.py new file mode 100755 index 0000000000..2a7ae03102 --- /dev/null +++ b/examples/constrained_decoding/normalize.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import sys + +from sacremoses.normalize import MosesPunctNormalizer + + +def main(args): + normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn) + for line in sys.stdin: + print(normalizer.normalize(line.rstrip()), flush=True) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--lang', '-l', default='en') + parser.add_argument('--penn', '-p', action='store_true') + args = parser.parse_args() + + main(args) diff --git a/examples/constrained_decoding/tok.py b/examples/constrained_decoding/tok.py new file mode 100755 index 0000000000..9215a66538 --- /dev/null +++ b/examples/constrained_decoding/tok.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import sacremoses + + +def main(args): + """Tokenizes, preserving tabs""" + mt = sacremoses.MosesTokenizer(lang=args.lang) + def tok(s): + return mt.tokenize(s, return_str=True) + + for line in sys.stdin: + parts = list(map(tok, line.split("\t"))) + print(*parts, sep="\t", flush=True) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--lang', '-l', default='en') + parser.add_argument('--penn', '-p', action='store_true') + parser.add_argument('--fields', '-f', help="fields to tokenize") + args = parser.parse_args() + + main(args) diff --git a/examples/translation_moe/src/translation_moe.py b/examples/translation_moe/src/translation_moe.py index 61e4bed809..b60175f093 100644 --- a/examples/translation_moe/src/translation_moe.py +++ b/examples/translation_moe/src/translation_moe.py @@ -201,13 +201,14 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = self._get_loss(sample, model, criterion) return loss, sample_size, logging_output - def inference_step(self, generator, 
models, sample, prefix_tokens=None, expert=None): + def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None): expert = expert or self.args.gen_expert with torch.no_grad(): return generator.generate( models, sample, prefix_tokens=prefix_tokens, + constraints=constraints, bos_token=self.expert_index(expert), ) diff --git a/fairseq/__init__.py b/fairseq/__init__.py index 1ba63fcaaa..f7d7793349 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -22,6 +22,7 @@ import fairseq.pdb # noqa import fairseq.scoring # noqa import fairseq.tasks # noqa +import fairseq.token_generation_constraints # noqa import fairseq.benchmark # noqa import fairseq.model_parallel # noqa diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 01a6a81486..77ef9538a3 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -21,10 +21,10 @@ class Dictionary(object): def __init__( self, *, # begin keyword-only arguments + bos="", pad="", eos="", unk="", - bos="", extra_special_symbols=None, ): self.unk_word, self.pad_word, self.eos_word = unk, pad, eos diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 7576e07d34..aed54d618f 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -133,6 +133,16 @@ def compute_alignment_weights(alignments): batch['alignments'] = alignments batch['align_weights'] = align_weights + if samples[0].get("constraints", None) is not None: + # Collate the packed constraints across the samples, padding to + # the length of the longest sample. 
+ lens = [sample.get("constraints").size(0) for sample in samples] + max_len = max(lens) + constraints = torch.zeros((len(samples), max(lens))).long() + for i, sample in enumerate(samples): + constraints[i, 0:lens[i]] = samples[i].get("constraints") + batch["constraints"] = constraints + return batch @@ -161,6 +171,8 @@ class LanguagePairDataset(FairseqDataset): target if it's absent (default: False). align_dataset (torch.utils.data.Dataset, optional): dataset containing alignments. + constraints (Tensor, optional): 2d tensor with a concatenated, zero- + delimited list of constraints for each sentence. append_bos (bool, optional): if set, appends bos to the beginning of source/target sentence. num_buckets (int, optional): if set to a value greater than 0, then @@ -180,6 +192,7 @@ def __init__( shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, + constraints=None, append_bos=False, eos=None, num_buckets=0, src_lang_id=None, @@ -206,6 +219,7 @@ def __init__( self.align_dataset = align_dataset if self.align_dataset is not None: assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" + self.constraints = constraints self.append_bos = append_bos self.eos = (eos if eos is not None else src_dict.eos()) self.src_lang_id = src_lang_id @@ -279,6 +293,8 @@ def __getitem__(self, index): } if self.align_dataset is not None: example['alignment'] = self.align_dataset[index] + if self.constraints is not None: + example["constraints"] = self.constraints[index] return example def __len__(self): diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py index 97e66fabe9..6ac805988a 100644 --- a/fairseq/iterative_refinement_generator.py +++ b/fairseq/iterative_refinement_generator.py @@ -105,7 +105,9 @@ def generate_batched_itr( @torch.no_grad() - def generate(self, models, sample, prefix_tokens=None): + def generate(self, models, sample, 
prefix_tokens=None, constraints=None): + if constraints is not None: + raise NotImplementedError("Constrained decoding with the IterativeRefinementGenerator is not supported") # TODO: iterative refinement generator does not support ensemble for now. if not self.retain_dropout: diff --git a/fairseq/options.py b/fairseq/options.py index 2a93452de2..74c1499e06 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -609,6 +609,8 @@ def add_generation_args(parser): help='sample from top K likely next words instead of all words') group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS', help='sample from the smallest set whose cumulative probability mass exceeds p for next words') + group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"], + help='enables lexically constrained decoding') group.add_argument('--temperature', default=1., type=float, metavar='N', help='temperature for generation') group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N', diff --git a/fairseq/search.py b/fairseq/search.py index 8aa196a3cc..ecb4764a82 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -10,6 +10,8 @@ import torch.nn as nn from torch import Tensor +from fairseq.token_generation_constraints import ConstraintState, UnorderedConstraintState, OrderedConstraintState + class Search(nn.Module): def __init__(self, tgt_dict): @@ -19,6 +21,7 @@ def __init__(self, tgt_dict): self.eos = tgt_dict.eos() self.vocab_size = len(tgt_dict) self.src_lengths = torch.tensor(-1) + self.supports_constraints = False def step(self, step, lprobs, scores): """Take a single search step. @@ -46,10 +49,49 @@ def step(self, step, lprobs, scores): def set_src_lengths(self, src_lengths): self.src_lengths = src_lengths + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + """Initialize constraint states for constrained decoding (if supported). 
+ + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + None (this base implementation is a no-op) + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses is reduced to the beam size. + + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. + """ + pass + + class BeamSearch(Search): def __init__(self, tgt_dict): super().__init__(tgt_dict) + self.constraint_states = None @torch.jit.export def step(self, step: int, lprobs, scores: Optional[Tensor]): @@ -75,11 +117,306 @@ def step(self, step: int, lprobs, scores: Optional[Tensor]): ) scores_buf = top_prediction[0] indices_buf = top_prediction[1] + # Project back into relative indices and beams beams_buf = indices_buf // vocab_size indices_buf = indices_buf.fmod(vocab_size) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices return scores_buf, indices_buf, beams_buf +class LexicallyConstrainedBeamSearch(Search): + """Implements lexically constrained beam search as described in + + Fast Lexically Constrained Decoding with Dynamic Beam + Allocation for Neural Machine Translation. Post & Vilar, + NAACL 2018. https://www.aclweb.org/anthology/N18-1119/ + + and + + Improved Lexically Constrained Decoding for Translation and + Monolingual Rewriting. Hu et al, NAACL + 2019.
https://www.aclweb.org/anthology/N19-1090/ + + This is accomplished by maintaining, for each beam hypothesis, a + ConstraintState object (see token_generation_constraints.py) that tracks which + constraints have been generated and using this information to + shape the beam for each input sentence. + """ + def __init__(self, tgt_dict, representation): + super().__init__(tgt_dict) + self.representation = representation + self.vocab_size = len(tgt_dict) + self.num_cands = 0 + self.supports_constraints = True + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + self.constraint_states = [] + for constraint_tensor in batch_constraints: + if self.representation == "ordered": + constraint_state = OrderedConstraintState.create(constraint_tensor) + elif self.representation == "unordered": + constraint_state = UnorderedConstraintState.create(constraint_tensor) + + self.constraint_states.append([constraint_state for i in range(beam_size)]) + + @torch.jit.export + def prune_sentences(self, batch_idxs: Tensor): + self.constraint_states = [self.constraint_states[i] for i in batch_idxs.tolist()] + + @torch.jit.export + def update_constraints(self, active_hypos: Tensor): + if self.constraint_states: + batch_size = active_hypos.size(0) + for sentid in range(batch_size): + self.constraint_states[sentid] = [self.constraint_states[sentid][i] for i in active_hypos[sentid]] + + @torch.jit.export + def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]): + """ + A constrained step builds a large candidates list from the following: + - the top 2 * {beam_size} items over the whole beam + - for each item in the beam + - the top {each_k} (default 1) + - all next constraints + We then compute the constrained state of each beam item, and assign + stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so + on. We then sort by (stripe, score), and truncate the list at + 2 * beam size.
+ + Args: + step: the decoder step + lprobs: (batch size, beam size, target vocab) + the target-vocab distributions for each item in the beam. + Return: A tuple of (scores, indices, beams, constraints) where: + scores: (batch, output beam size) + the scores of the chosen elements + indices: (batch, output beam size) + the target vocab indices of the chosen elements + beams: (batch, output beam size) + the 0-indexed hypothesis ids of the chosen elements + constraints: (batch, output beam size) + the new constraint states + """ + each_k = 1 + device = lprobs.device + + batch_size, beam_size, vocab_size = lprobs.size() + + self.num_cands = min( + # Just take the k-best. We'll get another k from the 1-best from each + # row, plus more from the constraints + beam_size * 2, + lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad + ) + + # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items + constraint_states = self.constraint_states + if constraint_states and step > 0: + not_finished_indices = [] + for sentno, sent_constraints in enumerate(constraint_states): + for beamno, state in enumerate(sent_constraints): + index = sentno * beam_size + beamno + if not state.finished: + not_finished_indices.append(index) + not_finished_indices = torch.tensor(not_finished_indices) + if not_finished_indices.numel() > 0: + lprobs.view(batch_size * beam_size, -1)[not_finished_indices, self.eos] = -math.inf + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam entry for each batch item + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(batch_size, -1), + self.num_cands, + ) + scores_buf, indices_buf = top_prediction + # Project back into relative indices and beams + beams_buf = indices_buf //
vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # Short circuit if there are no constraints in this batch + if not constraint_states: + return scores_buf, indices_buf, beams_buf + + # STEP 1: get top-1 from each hypothesis across all sentences in the batch + if step > 0: + top_scores, top_indices = torch.topk( + lprobs.view(batch_size * beam_size, -1), + k=each_k, + dim=1, + ) + top_scores = top_scores.view(batch_size, -1) + top_indices = top_indices.view(batch_size, -1) + scores_buf = torch.cat((scores_buf, top_scores), dim=1) + indices_buf = torch.cat((indices_buf, top_indices), dim=1) + new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1) + beams_buf = torch.cat((beams_buf, new_beams), dim=1) + + # Now, process sentences in the batch one by one. + new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device) + new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + for sentno, states in enumerate(constraint_states): + scores, indices, beams, new_states = self.step_sentence(step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone()) + new_scores_buf[sentno] = scores + new_indices_buf[sentno] = indices + new_beams_buf[sentno] = beams + self.constraint_states[sentno] = new_states + + return new_scores_buf, new_indices_buf, new_beams_buf + + @torch.jit.export + def step_sentence(self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor): + """Does per-sentence processing. Adds all constraints for each + hypothesis to the list of candidates; then removes duplicates, + sorts, and dynamically stripes across the banks. All tensor inputs + are collapsed to those pertaining to a single input sentence. 
+ """ + device = lprobs.device + + # STEP 2: Add all constraints for each beam item + for beamno, state in enumerate(constraint_states): + next_tokens = torch.tensor(list(state.next_tokens()), device=device).long() + if next_tokens.numel() != 0: + indices_buf = torch.cat((indices_buf, next_tokens)) + next_beams = torch.tensor(beamno, device=device).repeat(next_tokens.size(0)).long() + beams_buf = torch.cat((beams_buf, next_beams)) + next_values = lprobs[beamno].take(next_tokens.view(-1)) + scores_buf = torch.cat((scores_buf, next_values)) + + # At the 0th time step, there is just one beam item + if step == 0: + break + + # STEP 3: Compute the "bank" for each candidate. This is the + # number of constraints it's generated. We need this so that + # we can do round-robin allocation of the beam across these + # banks. If C is the number of constraints, we select the best + # item in bank C, then the best in bank C-1, etc, followed by + # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so + # on, until the maximum beam size. We accomplish this by + # creating a sort key and striping across the banks. + + # Compute the new states for all candidates + cands_size = indices_buf.size(0) + constraint_states = [constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size)] + + banks = torch.tensor([state.bank for state in constraint_states], device=device) + + # STEP 4: Sort + num_constraint_tokens = len(state.tokens) + + # Sort by keys (bank, score) (i.e., sort banks together, and scores + # within banks). AFAIK pytorch doesn't support either stable sort or + # multi-key sorting, so we have to hack this. 
+ MAX_SCORE = -100 + sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf + sort_values, sort_indices = sort_key.sort(dim=0, descending=True) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + banks = banks[sort_indices] + + # Sort the constraints to follow suit + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 5: Remove duplicates. The topk calls (overall and + # per-row) plus the per-row generation of constraints will + # produce duplicates. Here we remove them. + + def roll(t): + """Rolls a 1d tensor right by 1. + + [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] + """ + return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) + + # We map candidates (beam, token_id) to a single dimension. + # This is then shifted by 1. We can then easily identify + # duplicates and create a mask that identifies unique + # extensions. + uniques_mask = (beams_buf * (self.vocab_size + 1) + indices_buf) + uniques_mask = roll(uniques_mask) != uniques_mask + + # Use the mask to pare down the data structures + scores_buf = torch.masked_select(scores_buf, uniques_mask) + indices_buf = torch.masked_select(indices_buf, uniques_mask) + beams_buf = torch.masked_select(beams_buf, uniques_mask) + banks = torch.masked_select(banks, uniques_mask) + i = 1 + for mask in uniques_mask[1:]: + if not mask: + constraint_states.pop(i) + i += mask + + # STEP 6: Assign IDs round-robin across banks, sort, and + # truncate. Now that the candidates are sorted by (bank, + # score) and uniqed, we dynamically allocate the {beam_size} + # beam by striping across the candidates. These stripes will + # be used as sort keys to do round-robin selection. This is + # accomplished in a single pass with offsets. Sorting by + # highest-banks (furthest-along hypotheses) first ensures + # progress through the constraints.
+ # + # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 + # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 + # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 + # = 0 5 10 1 6 11 13 2 7 12 3 8 + # + # Sorting by this then gives the following banks: + # + # 3 2 1 0 3 2 1 0 3 2 1 2 + # + # We'll take the top {beam_size} of these. + stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)] + stripes = torch.zeros_like(banks) + cur_bank_count = -1 + cur_bank = banks[0] + for i, bank in enumerate(banks): + if bank != cur_bank: + cur_bank_count = 0 + cur_bank = bank + else: + cur_bank_count += 1 + stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count] + + # STEP 7: Sort by the stripes values + sort_values, sort_indices = stripes.sort(dim=0) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 8: Truncate to the candidates size! 
+ scores_buf = scores_buf[:self.num_cands] + indices_buf = indices_buf[:self.num_cands] + beams_buf = beams_buf[:self.num_cands] + + return scores_buf, indices_buf, beams_buf, constraint_states + + class LengthConstrainedBeamSearch(Search): def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): super().__init__(tgt_dict) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 26e4c287b2..27435986f4 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -61,6 +61,7 @@ def __init__( self.model = models else: self.model = EnsembleModel(models) + self.tgt_dict = tgt_dict self.pad = tgt_dict.pad() self.unk = tgt_dict.unk() self.eos = tgt_dict.eos() if eos is None else eos @@ -113,7 +114,7 @@ def forward( bos_token (int, optional): beginning of sentence token (default: self.eos) """ - return self._generate(sample, prefix_tokens, bos_token) + return self._generate(sample, prefix_tokens, bos_token=bos_token) # TODO(myleott): unused, deprecate after pytorch-translate migration def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): @@ -157,6 +158,8 @@ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): sample (dict): batch prefix_tokens (torch.LongTensor, optional): force decoder to begin with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints bos_token (int, optional): beginning of sentence token (default: self.eos) """ @@ -166,6 +169,7 @@ def _generate( self, sample: Dict[str, Dict[str, Tensor]], prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, bos_token: Optional[int] = None, ): incremental_states = torch.jit.annotate( @@ -192,10 +196,15 @@ def _generate( raise Exception('expected src_tokens or source in net input') # bsz: total number of sentences in beam - input_size = src_tokens.size() - bsz, src_len = input_size[0], input_size[1] + bsz, src_len = 
src_tokens.size() beam_size = self.beam_size + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError("Target-side constraints were provided, but search method doesn't support them") + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + max_len: int = -1 if self.match_source_len: max_len = src_lengths.max().item() @@ -221,7 +230,7 @@ def _generate( # initialize buffers scores = ( torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float() - ) # +1 for eos; pad is never choosed for scoring + ) # +1 for eos; pad is never chosen for scoring tokens = ( torch.zeros(bsz * beam_size, max_len + 2) .to(src_tokens) @@ -327,6 +336,7 @@ def _generate( if self.no_repeat_ngram_size > 0: lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step) + # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( step, lprobs.view(bsz, -1, self.vocab_size), @@ -339,10 +349,13 @@ def _generate( cand_bbsz_idx = cand_beams.add(bbsz_offsets) # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered eos_bbsz_idx = torch.masked_select( cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] ) @@ -352,6 +365,7 @@ def _generate( eos_scores = torch.masked_select( cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] ) + finalized_sents = self.finalize_hypos( step, eos_bbsz_idx, @@ -372,6 +386,8 @@ def _generate( break assert step < max_len + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. 
if len(finalized_sents) > 0: new_bsz = bsz - len(finalized_sents) @@ -381,6 +397,9 @@ def _generate( # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask) + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] bbsz_offsets.resize_(new_bsz, 1) @@ -402,7 +421,8 @@ def _generate( bsz = new_bsz else: batch_idxs = None - # set active_mask so that values > cand_size indicate eos hypos + + # Set active_mask so that values > cand_size indicate eos hypos # and values < cand_size indicate candidate active hypos. # After, the min values per row are the top candidate active hypos @@ -414,16 +434,24 @@ def _generate( cand_offsets[: eos_mask.size(1)], ) - # get the top beam_size active hypotheses, which are just the hypos - # with the smallest values in active_mask + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) new_cands_to_ignore, active_hypos = torch.topk( active_mask, k=beam_size, dim=1, largest=False ) - # update cands_to_ignore to ignore any finalized hypos + # update cands_to_ignore to ignore any finalized hypos. cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. assert (~cands_to_ignore).any(dim=1).all() + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). 
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) @@ -431,9 +459,12 @@ def _generate( active_scores = active_scores.view(-1) # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) tokens[:, : step + 1] = torch.index_select( tokens[:, : step + 1], dim=0, index=active_bbsz_idx ) + # Select the next token for each of them tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( cand_indices, dim=1, index=active_hypos ) @@ -445,6 +476,9 @@ def _generate( cand_scores, dim=1, index=active_hypos ) + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + # copy attention for active hypotheses if attn is not None: attn[:, :, : step + 2] = torch.index_select( @@ -517,13 +551,18 @@ def finalize_hypos( max_len: int, ): """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. - Returns number of sentences being finalized. + A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. Args: bbsz_idx (Tensor): """ assert bbsz_idx.numel() == eos_scores.numel() - # clone relevant token and attention tensors + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} tokens_clone = tokens.index_select(0, bbsz_idx)[ :, 1 : step + 2 ] # skip the first index, which is EOS @@ -545,6 +584,10 @@ def finalize_hypos( if self.normalize_scores: eos_scores /= (step + 1) ** self.len_penalty + # cum_unfin records which sentences in the batch are finished. 
+ # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. cum_unfin: List[int] = [] prev = 0 for f in finished: @@ -554,12 +597,22 @@ def finalize_hypos( cum_unfin.append(prev) # set() is not supported in script export + + # The keys here are of the form "{sent}_{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch sents_seen: Dict[str, Optional[Tensor]] = {} + + # For every finished beam item for i in range(bbsz_idx.size()[0]): idx = bbsz_idx[i] score = eos_scores[i] + # sentence index in the current (possibly reduced) batch unfin_idx = idx // beam_size + # sentence index in the original (unreduced) batch sent = unfin_idx + cum_unfin[unfin_idx] + # print(f"{step} FINISHED {idx} {score} {sent}={unfin_idx} {cum_unfin}") # Cannot create dict for key type '(int, int)' in torchscript. # The workaround is to cast int to string seen = str(sent.item()) + "_" + str(unfin_idx.item()) @@ -569,12 +622,15 @@ def finalize_hypos( if self.match_source_len and step > src_lengths[unfin_idx]: score = torch.tensor(-math.inf).to(score) + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it if len(finalized[sent]) < beam_size: if attn_clone is not None: # remove padding tokens from attn scores hypo_attn = attn_clone[i] else: hypo_attn = torch.empty(0) + finalized[sent].append( { "tokens": tokens_clone[i], @@ -586,15 +642,18 @@ def finalize_hypos( ) newly_finished: List[int] = [] + for seen in sents_seen.keys(): # check termination conditions for this sentence sent: int = int(float(seen.split("_")[0])) unfin_idx: int = int(float(seen.split("_")[1])) + if not finished[sent] and self.is_finished( step, unfin_idx, max_len, len(finalized[sent]), beam_size ): finished[sent] = True newly_finished.append(unfin_idx) + return newly_finished def 
is_finished( @@ -606,9 +665,9 @@ def is_finished( beam_size: int, ): """ - Check whether we've finished generation for a given sentence, by - comparing the worst score among finalized hypotheses to the best - possible score among unfinalized hypotheses. + Check whether decoding for a sentence is finished, which + occurs when the list of finalized hypotheses for it has reached the + beam size, or when we reach the maximum length. """ assert finalized_sent_len <= beam_size if finalized_sent_len == beam_size or step == max_len: diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 5128d4b3f4..ddc6760842 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -291,6 +291,7 @@ def build_generator( diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) match_source_len = getattr(args, "match_source_len", False) diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) if ( sum( int(cond) @@ -330,6 +331,8 @@ def build_generator( search_strategy = search.DiverseSiblingsSearch( self.target_dictionary, diversity_rate ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch(self.target_dictionary, args.constraints) else: search_strategy = search.BeamSearch(self.target_dictionary) @@ -395,9 +398,9 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = criterion(model, sample) return loss, sample_size, logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None): + def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): with torch.no_grad(): - return generator.generate(models, sample, prefix_tokens=prefix_tokens) + return generator.generate(models, sample, prefix_tokens=prefix_tokens, constraints=constraints) def begin_epoch(self, epoch, model): """Hook function called before the start of each epoch.""" diff --git a/fairseq/tasks/language_modeling.py
b/fairseq/tasks/language_modeling.py index a4a98e07bc..1916a1550c 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -258,7 +258,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): sizes=[np.array(src_lengths)], ) - def inference_step(self, generator, models, sample, prefix_tokens=None): + def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): with torch.no_grad(): # Generation will always be conditioned on bos_token if getattr(self.args, "add_bos_token", False): @@ -266,6 +266,9 @@ def inference_step(self, generator, models, sample, prefix_tokens=None): else: bos_token = self.source_dictionary.eos() + if constraints is not None: + raise NotImplementedError("Constrained decoding with the language_modeling task is not supported") + # SequenceGenerator doesn't use src_tokens directly, we need to # pass the `prefix_tokens` argument instead if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 59634131fc..272bcf1ae1 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -221,7 +221,10 @@ def language_pair_dataset(lang_pair): eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang), ) - def build_dataset_for_inference(self, src_tokens, src_lengths): + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported") + lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang) return RoundRobinZipDatasets( OrderedDict([( @@ -312,7 +315,7 @@ def valid_step(self, sample, model, criterion): agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k] return agg_loss, agg_sample_size, agg_logging_output 
- def inference_step(self, generator, models, sample, prefix_tokens=None): + def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): with torch.no_grad(): if self.args.decoder_langtok: bos_token = _lang_token_index(self.target_dictionary, self.args.target_lang) @@ -322,6 +325,7 @@ def inference_step(self, generator, models, sample, prefix_tokens=None): models, sample, prefix_tokens=prefix_tokens, + constraints=constraints, bos_token=bos_token, ) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index ab1ff3cf34..a01768ecb6 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -270,8 +270,10 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): shuffle=(split != 'test'), ) - def build_dataset_for_inference(self, src_tokens, src_lengths): - return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary) + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints) def build_model(self, args): model = super().build_model(args) @@ -377,7 +379,7 @@ def decode(toks, escape_unk=False): s = self.tokenizer.decode(s) return s - gen_out = self.inference_step(generator, [model], sample, None) + gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) hyps, refs = [], [] for i in range(len(gen_out)): hyps.append(decode(gen_out[i][0]['tokens'])) diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index 2b7d589cee..b3c9f8e440 100644 --- a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -109,11 +109,13 @@ def build_generator(self, models, args): eos=self.tgt_dict.index('[{}]'.format(self.args.target_lang)) ) - def build_dataset_for_inference(self, src_tokens, 
src_lengths): + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): src_lang_id = self.source_dictionary.index('[{}]'.format(self.args.source_lang)) source_tokens = [] for s_t in src_tokens: s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)]) source_tokens.append(s_t) - dataset = LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary) + dataset = LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints) return dataset diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index 845dd81644..be362a1881 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -141,7 +141,11 @@ def build_generator(self, models, args): adaptive=not getattr(args, 'iter_decode_force_max_iter', False), retain_history=getattr(args, 'retain_iter_history', False)) - def build_dataset_for_inference(self, src_tokens, src_lengths): + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + # Though see Susanto et al. 
(ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/ + raise NotImplementedError("Constrained decoding with the translation_lev task is not supported") + return LanguagePairDataset( src_tokens, src_lengths, self.source_dictionary, append_bos=True ) diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index b517d6f2b7..63a3d3ab86 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -126,7 +126,10 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): epoch=epoch, combine=combine, shard_epoch=shard_epoch, **kwargs ) - def build_dataset_for_inference(self, src_tokens, src_lengths): + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported") + src_data = ListDataset(src_tokens, src_lengths) dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary) src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main'] @@ -173,7 +176,7 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) return loss, sample_size, logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None): + def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): with torch.no_grad(): _, tgt_langtok_spec = self.args.langtoks['main'] if not self.args.lang_tok_replacing_bos_eos: @@ -188,6 +191,7 @@ def inference_step(self, generator, models, sample, prefix_tokens=None): models, sample, prefix_tokens=prefix_tokens, + constraints=constraints, ) else: return generator.generate( diff --git a/fairseq/token_generation_constraints.py b/fairseq/token_generation_constraints.py new file mode 100644 index 0000000000..7077199fd9 --- /dev/null +++ 
b/fairseq/token_generation_constraints.py @@ -0,0 +1,500 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Implements tracking of constraints for a beam item. + +A list of constraints is given as a list of one or more token +sequences, each of length at least one token. For example, for an input sentence + +> Die maschinelle Übersetzung ist schwer zu kontrollieren. + +We could have the constraints: +* to influence +* hard + +There are two implementations: +* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. +* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. + +The difference is that in the first, the constraints are assumed to be +in order; the algorithm will permit zero or more tokens between them. +In the second, the constraints are not ordered, so many orderings will +be explored. + +The same sequence can be present any number of times, and will appear +that many times in the output. +""" + +import torch + +from collections import Counter +from typing import Tuple, List, Optional, Set + +class ConstraintState: + def __init__(self): + pass + + +def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: + """Takes a list of list of constraints in tensor form (a list of + tensor constraints for each sentence) and transforms it into a + packed Tensor. For example, here is a batch of size 3 with 3, 0, + and 1 constraints: + + [ [ [3 1 2], [3], [4 5 6 7], ] + [], + [ [1 8 9 10 1 4 11 12], ] + ] + + Its corresponding packed structure is: + + [ [ 3 3 1 2 0 3 0 4 5 6 7 0], + [ 0 0 0 0 0 0 0 0 0 0 0 0], + [ 1 1 8 9 10 1 4 11 12 0 0 0] ] + + The packed tensor has shape (batch size, maxlen), where + maxlen is defined below. 
Each row contains concatenated + constraint tokens for that sentence, with 0 appended after + each constraint. The first item in each row is the number + of constraints for that sentence. So maxlen is the maximum + of + + (number of constraints) + (sum length of constraints) + 1. + + across all sentences in the batch. + """ + # The maximum word length of concatenated constraints for any sentence + max_constraints_len = 1 + for sentence_constraints in batch_constraints: + if len(sentence_constraints): + # number of constraints, plus sum of constraint lengths, plus a zero after each + constraints_len = 1 + sum([c.size(0) for c in sentence_constraints]) + len(sentence_constraints) + max_constraints_len = max(max_constraints_len, constraints_len) + + batch_size = len(batch_constraints) + constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() + for i, sentence_constraints in enumerate(batch_constraints): + constraints_tensor[i, 0] = len(sentence_constraints) + offset = 1 + for j, constraint in enumerate(sentence_constraints): + this_len = constraint.size(0) + constraints_tensor[i, offset:offset+this_len] = constraint + offset += this_len + 1 + + return constraints_tensor.long() + + +def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Transforms *one row* of a packed constraint tensor (e.g., for one + sentence in the batch) into a list of constraint tensors. + """ + constraint_list = [] + num_constraints = constraint_tensor[0] + constraints = constraint_tensor.tolist() + offset = 1 + for i in range(num_constraints): + where = constraints.index(0, offset) + constraint_list.append(constraint_tensor[offset:where]) + offset = where + 1 + + return constraint_list + + +class ConstraintNode: + """ + Represents a node in a trie managing unordered constraints. 
+ """ + def __init__(self, token: int = None, parent=None): + # The token associated with this node (None for the root) + self.token = int(token) if token is not None else None + # The parent (None at the root) + self.parent = parent + # Whether this node is a completed constraint + self.terminal = 0 + # List of child nodes + self.children = {} + + # The cumulative number of constraints from this point in the + # trie forward + self.num_constraints = 0 + + @property + def id(self): + return self.token + + def __str__(self): + term = self.terminal != 0 + return f"[{self.token}].{term}#{self.num_constraints}" + + def __getitem__(self, key: int): + return self.children.get(key, None) + + def next_tokens(self) -> Set[int]: + """The set of child labels.""" + return set(self.children.keys()) + + @staticmethod + def create(constraints: List[List[int]]): + root = ConstraintNode() + for sequence in constraints: + root.add_sequence(sequence) + + return root + + @staticmethod + def print_graph(node: "ConstraintNode"): + if len(node.children) == 0: + return str(node) + else: + s = f"({node}" + for child in node.children.values(): + s += " " + ConstraintNode.print_graph(child) + s += ")" + return s + + def token_counts(self) -> Counter: + """Returns a counter of the number of times each token is used + in a constraint. 
+ """ + token_counts = Counter() + kids = list(self.children.values()) + while len(kids) > 0: + kid = kids.pop() + token_counts[kid.id] += kid.num_constraints + kids += list(kid.children.values()) + + return token_counts + + def tokens(self) -> Set[int]: + """Returns the set of tokens in constraints.""" + return set(self.token_counts().keys()) + + def add_sequence(self, sequence: List[int]): + """Adds a constraint, represented as a list of integers, to + the trie.""" + assert len(sequence) > 0 + + token = int(sequence[0]) + if token not in self.children: + self.children[token] = ConstraintNode(token, parent=self) + + node = self.children[token] + if len(sequence) == 1: + node.terminal += 1 + node.num_constraints += 1 + parent = node.parent + while parent is not None: + parent.num_constraints += 1 + parent = parent.parent + else: + node.add_sequence(sequence[1:]) + + +class UnorderedConstraintState(ConstraintState): + """ + Records progress through the set of constraints for each item in the beam + using a trie. + """ + def __init__(self, + node: ConstraintNode, + copy_from: "ConstraintState" = None): + self.node = node + + if copy_from is None: + # The root node + self.root = node + # The set of states in the graph that have been completed + self.completed = Counter() + # The... 
+ self.generated = Counter() + # The list of tokens we need to generate + self.needed_tokens = self.root.tokens() + else: + self.completed = Counter(copy_from.completed) + self.generated = Counter(copy_from.generated) + self.root = copy_from.root + + # Mark the node as generated + if self.node != self.root: + self.generated[node] += 1 + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + constraint_trie_root = ConstraintNode.create(constraint_list) + return UnorderedConstraintState(constraint_trie_root) + + def __str__(self): + gen_str = ",".join([str(node) for node in self.generated]) + return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}" + + def __copy__(self): + copied_state = UnorderedConstraintState(self.node, copy_from=self) + return copied_state + + def copy(self): + return self.__copy__() + + @property + def name(self): + if self.node.id is None: + return "ROOT" + else: + return str(self.node.id) + + @property + def is_root(self): + return self.node == self.root + + @property + def bank(self): + return sum(self.generated.values()) + + @property + def num_completed(self): + """The number of constraints (not constraint tokens) that are completed. + In addition to the already-completed states, we need to account for the + current state, which might get marked as completed when another token + is generated. + """ + in_final = self.node.terminal and self.completed[self.node] < self.node.terminal + return sum(self.completed.values()) + in_final + + @property + def finished(self): + return self.root.num_constraints - self.num_completed == 0 + + @property + def token_counts(self): + return self.root.token_counts() + + @property + def tokens(self): + return self.root.tokens() + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. 
+ These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + if self.node != self.root: + return self.root.next_tokens().union(self.node.next_tokens()) + else: + return self.root.next_tokens() + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + + next_state = None + child = self.node[token] + if child is not None and self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + + def rewind(): + """If we're mid-trie and an "illegal" token is chosen next, we need + to reset our state to the root state. However, along the way, we need + to check whether a prefix of the current trie state represents a state + we could mark as completed. 
+ """ + node = self.node + while node != self.root: + if node.terminal and self.completed[node] < node.terminal: + next_state.completed[node] += 1 + return + + next_state.generated[node] -= 1 + node = node.parent + + # Fall off the graph, check the root + if next_state is None and token in self.root.next_tokens(): + child = self.root[token] + # We can only traverse this edge if it's not saturated + if self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + else: + next_state = UnorderedConstraintState(self.root, copy_from=self) + + # Rewind + rewind() + + elif next_state is None: + next_state = UnorderedConstraintState(self.root, copy_from=self) + # Rewind + rewind() + + return next_state + + +class ConstraintSequence: + def __init__(self, sequences: List[List[int]]): + """Represents a set of possibly multitoken constraints by + concatenating them and internally recording the end points. + """ + self.sequences = [] + self.endpoints = [] + self.num_tokens = 0 + self.tokens = set() + for sequence in sequences: + for token in sequence: + self.tokens.add(token) + self.num_tokens += len(sequence) + self.endpoints += [False for x in range(len(sequence) - 1)] + [True] + self.sequences += sequence + + def __getitem__(self, key: int): + return self.sequences[key] + + def __len__(self): + return len(self.sequences) + + def __str__(self): + return str(self.sequences) + + +class OrderedConstraintState(ConstraintState): + """ + Records progress through the set of linear nonbranching constraints with gaps. 
+ """ + def __init__(self, + sequence: ConstraintSequence, + state: int = -1): + self.sequence = sequence + self.state = state + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + return OrderedConstraintState(ConstraintSequence(constraint_list), -1) + + def __str__(self): + return f"{self.state}/{self.bank}x{self.num_completed}" + + def __copy__(self): + return OrderedConstraintState(self.sequence, self.state) + + def copy(self): + return self.__copy__() + + @property + def num_completed(self): + if self.state == -1: + return 0 + count = len(list(filter(lambda x: x, self.sequence.endpoints[0:self.state+1]))) + return count + + @property + def is_root(self): + return self.state == -1 + + @property + def name(self): + if self.state == -1: + return "ROOT" + else: + return str(self.sequence[self.state]) + + @property + def bank(self) -> int: + return self.state + 1 + + @property + def finished(self): + return self.state + 1 == len(self.sequence) + + @property + def token_counts(self): + return self.sequence.token_counts() + + @property + def tokens(self): + return self.sequence.tokens + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + tokens = set() + if self.state > 0: + tokens.add(self.sequence[0]) + if not self.finished: + tokens.add(self.sequence[self.state + 1]) + return tokens + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. 
+ + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") + + if self.finished: + # Accept anything + next_state = self.copy() + + elif self.sequence[self.state + 1] == token: + # Advance to the next token + next_state = OrderedConstraintState(self.sequence, self.state + 1) + + elif self.sequence.endpoints[self.state]: + # Accept anything between constraints (*) + next_state = self.copy() + + elif token == self.sequence[0]: + # Start over having generated the first token + next_state = OrderedConstraintState(self.sequence, 0) + else: + # Start over from the root + next_state = OrderedConstraintState(self.sequence, -1) + + return next_state diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 03d2b7dfc0..f4c86f6479 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -150,8 +150,12 @@ def decode_fn(x): if args.prefix_size > 0: prefix_tokens = sample['target'][:, :args.prefix_size] + constraints = None + if "constraints" in sample: + constraints = sample["constraints"] + gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens) + hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 2258f18326..032966051a 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -12,6 +12,7 @@ import logging import math import sys 
+import time import os import numpy as np @@ -20,9 +21,9 @@ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders +from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from .generate import get_symbols_to_strip_from_output - logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -32,7 +33,7 @@ logger = logging.getLogger('fairseq_cli.interactive') -Batch = namedtuple('Batch', 'ids src_tokens src_lengths') +Batch = namedtuple('Batch', 'ids src_tokens src_lengths constraints') Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') @@ -50,28 +51,63 @@ def buffered_read(input, buffer_size): def make_batches(lines, args, task, max_positions, encode_fn): + def encode_fn_target(x): + return encode_fn(x) + + if args.constraints: + # Strip (tab-delimited) constraints, if present, from input lines, + # store them in batch_constraints + batch_constraints = [list() for _ in lines] + for i, line in enumerate(lines): + if "\t" in line: + lines[i], *batch_constraints[i] = line.split("\t") + + # Convert each List[str] to List[Tensor] + for i, constraint_list in enumerate(batch_constraints): + batch_constraints[i] = [task.target_dictionary.encode_line( + encode_fn_target(constraint), + append_eos=False, + add_if_not_exist=False, + ) for constraint in constraint_list] + tokens = [ task.source_dictionary.encode_line( encode_fn(src_str), add_if_not_exist=False ).long() for src_str in lines ] + + if args.constraints: + constraints_tensor = pack_constraints(batch_constraints) + else: + constraints_tensor = None + lengths = [t.numel() for t in tokens] itr = task.get_batch_iterator( - dataset=task.build_dataset_for_inference(tokens, lengths), + dataset=task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor), max_tokens=args.max_tokens, max_sentences=args.max_sentences, 
max_positions=max_positions, ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test ).next_epoch_itr(shuffle=False) for batch in itr: + ids = batch['id'] + src_tokens = batch['net_input']['src_tokens'] + src_lengths = batch['net_input']['src_lengths'] + constraints = batch.get("constraints", None) + yield Batch( - ids=batch['id'], - src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'], + ids=ids, + src_tokens=src_tokens, + src_lengths=src_lengths, + constraints=constraints, ) def main(args): + start_time = time.time() + total_translate_time = 0 + utils.import_user_module(args) if args.buffer_size < 1: @@ -147,6 +183,9 @@ def decode_fn(x): *[model.max_positions() for model in models] ) + if args.constraints: + logger.warning("NOTE: Constrained decoding currently assumes a shared subword vocabulary.") + if args.buffer_size > 1: logger.info('Sentence buffer size: %s', args.buffer_size) logger.info('NOTE: hypothesis and token scores are output in base 2') @@ -155,11 +194,15 @@ def decode_fn(x): for inputs in buffered_read(args.input, args.buffer_size): results = [] for batch in make_batches(inputs, args, task, max_positions, encode_fn): + bsz = batch.src_tokens.size(0) src_tokens = batch.src_tokens src_lengths = batch.src_lengths + constraints = batch.constraints if use_cuda: src_tokens = src_tokens.cuda() src_lengths = src_lengths.cuda() + if constraints is not None: + constraints = constraints.cuda() sample = { 'net_input': { @@ -167,16 +210,29 @@ def decode_fn(x): 'src_lengths': src_lengths, }, } - translations = task.inference_step(generator, models, sample) + translate_start_time = time.time() + translations = task.inference_step(generator, models, sample, constraints=constraints) + translate_time = time.time() - translate_start_time + total_translate_time += translate_time + list_constraints = [[] for _ in range(bsz)] + if args.constraints: + list_constraints = [unpack_constraints(c) for c in constraints] for i, (id, 
hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) - results.append((start_id + id, src_tokens_i, hypos)) + constraints = list_constraints[i] + results.append((start_id + id, src_tokens_i, hypos, + { "constraints": constraints, + "time": translate_time / len(translations) } + )) # sort output to match input order - for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]): + for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): if src_dict is not None: src_str = src_dict.string(src_tokens, args.remove_bpe) - print('S-{}\t{}'.format(id, src_str)) + print('S-{}\t{}'.format(id_, src_str)) + print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) + for constraint in info["constraints"]: + print("C-{}\t{}".format(id_, tgt_dict.string(constraint, args.remove_bpe))) # Process top predictions for hypo in hypos[:min(len(hypos), args.nbest)]: @@ -192,11 +248,11 @@ def decode_fn(x): detok_hypo_str = decode_fn(hypo_str) score = hypo['score'] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) - print('H-{}\t{}\t{}'.format(id, score, hypo_str)) + print('H-{}\t{}\t{}'.format(id_, score, hypo_str)) # detokenized hypothesis - print('D-{}\t{}\t{}'.format(id, score, detok_hypo_str)) + print('D-{}\t{}\t{}'.format(id_, score, detok_hypo_str)) print('P-{}\t{}'.format( - id, + id_, ' '.join(map( lambda x: '{:.4f}'.format(x), # convert from base e to base 2 @@ -206,13 +262,14 @@ def decode_fn(x): if args.print_alignment: alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment]) print('A-{}\t{}'.format( - id, + id_, alignment_str )) - # update running id counter + # update running id_ counter start_id += len(inputs) + logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(time.time() - start_time, total_translate_time)) def cli_main(): parser = options.get_interactive_generation_parser() diff --git 
a/scripts/constraints/extract.py b/scripts/constraints/extract.py new file mode 100755 index 0000000000..8f9bc4ad14 --- /dev/null +++ b/scripts/constraints/extract.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Extracts random constraints from reference files.""" + +import argparse +import random +import sys +from sacrebleu import extract_ngrams + + +def get_phrase(words, index, length): + assert(index < len(words) - length + 1) + phr = ' '.join(words[index:index+length]) + for i in range(index, index + length): + words.pop(index) + return phr + + +def main(args): + + if args.seed: + random.seed(args.seed) + + for line in sys.stdin: + constraints = [] + + def add_constraint(constraint): + constraints.append(constraint) + + source = line.rstrip() + if '\t' in line: + source, target = line.split('\t') + if args.add_sos: + target = f"<s> {target}" + if args.add_eos: + target = f"{target} </s>" + + if len(target.split()) >= args.len: + words = [target] + + num = args.number + + choices = {} + for i in range(num): + if len(words) == 0: + break + segmentno = random.choice(range(len(words))) + segment = words.pop(segmentno) + tokens = segment.split() + phrase_index = random.choice(range(len(tokens))) + choice = " ".join(tokens[phrase_index:min(len(tokens), phrase_index + args.len)]) + for j in range(phrase_index, min(len(tokens), phrase_index + args.len)): + tokens.pop(phrase_index) + if phrase_index > 0: + words.append(" ".join(tokens[0:phrase_index])) + if phrase_index + 1 < len(tokens): + words.append(" ".join(tokens[phrase_index:])) + choices[target.find(choice)] = choice + + # mask out with spaces + target = target.replace(choice, " " * len(choice), 1) + + for key in sorted(choices.keys()): + add_constraint(choices[key]) + + print(source, *constraints, sep="\t") + + +if __name__ == 
"__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--number', '-n', type=int, default=1, help="number of phrases") + parser.add_argument('--len', '-l', type=int, default=1, help="phrase length") + parser.add_argument('--add-sos', default=False, action='store_true', help='add <s> token') + parser.add_argument('--add-eos', default=False, action='store_true', help='add </s> token') + parser.add_argument('--seed', "-s", default=0, type=int) + args = parser.parse_args() + + main(args) diff --git a/scripts/constraints/validate.py b/scripts/constraints/validate.py new file mode 100755 index 0000000000..6d1a4a0885 --- /dev/null +++ b/scripts/constraints/validate.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +"""Reads in a fairseq output file, and verifies that the constraints +(C- lines) are present in the output (the first H- line). Assumes that +constraints are listed prior to the first hypothesis. +""" + +constraints = [] +found = 0 +total = 0 +for line in sys.stdin: + if line.startswith("C-"): + constraints.append(line.rstrip().split("\t")[1]) + elif line.startswith("H-"): + text = line.split("\t")[2] + + for constraint in constraints: + total += 1 + if constraint in text: + found += 1 + else: + print(f"No {constraint} in {text}", file=sys.stderr) + + constraints = [] + +print(f"Found {found} / {total} = {100 * found / total:.1f}%") diff --git a/tests/test_constraints.py b/tests/test_constraints.py new file mode 100755 index 0000000000..3f63c8ace5 --- /dev/null +++ b/tests/test_constraints.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import sys +import torch +import unittest + +from fairseq.token_generation_constraints import * + + +def tensorize(constraints: List[List[int]]) -> torch.Tensor: + return [torch.tensor(x) for x in constraints] + + +class TestHelperRoutines(unittest.TestCase): + def setUp(self): + self.examples = [ + ( + [[]], + torch.tensor([[0]]) + ), + ( + [[], []], + torch.tensor([[0], [0]]) + ), + ( + [[torch.tensor([1, 2])], []], + torch.tensor([[1, 1, 2, 0], [0, 0, 0, 0]]) + ), + ( + [[torch.tensor([3, 1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6, 7])], + [], + [ torch.tensor([1, 8, 9, 10, 1, 4, 11, 12]) ]], + torch.tensor([[3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 8, 9, 10, 1, 4, 11, 12, 0, 0, 0]]) + ) + ] + + def test_packing(self): + """Ensures the list of lists of tensors gets packed correctly.""" + for batch_constraints, expected_tensor in self.examples: + packed = pack_constraints(batch_constraints) + assert torch.equal(packed, expected_tensor) + + +class TestUnorderedConstraintState(unittest.TestCase): + def setUp(self): + # Tuples of (constraint set, expected printed graph, token counts per node) + self.examples = [ + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + "([None].False#6 ([1].True#4 ([2].False#1 [3].True#1) [3].True#1 [4].True#1) ([4].False#2 ([5].True#2 ([6].False#1 [7].True#1))))", + { 1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1 } + ), + ( [], "[None].False#0", {} ), + ( tensorize([[0]]), "([None].False#1 [0].True#1)", { 0: 1 } ), + ( tensorize([[100000, 1, 2, 3, 4, 5]]), "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))", { 100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1 } ), + ( + tensorize([[1, 2], [1, 2]]), + "([None].False#2 ([1].False#2 [2].True#2))", + { 1: 2, 2: 2 }, + ), + ( + tensorize([[1, 2], [3, 4]]), + "([None].False#2 ([1].False#1 [2].True#1) ([3].False#1 [4].True#1))", + { 1: 1, 2: 1, 3: 1, 4: 1}, + ), + ] + + 
self.sequences = [ + ( + self.examples[0][0], + [], + { "bank": 0, "num_completed": 0, "finished": False, "is_root": True }, + ), + ( + self.examples[0][0], + [1, 2], + { "bank": 2, "num_completed": 0, "finished": False, "is_root": False }, + ), + ( + self.examples[0][0], + [1, 2, 94], + { "bank": 1, "num_completed": 1, "finished": False, "is_root": True }, + ), + ( + self.examples[0][0], + [1, 3, 999, 1, 4], + { "bank": 4, "num_completed": 2, "finished": False, "is_root": False }, + ), + ( + self.examples[0][0], + [1, 3, 999, 1, 4, 999], + { "bank": 4, "num_completed": 2, "finished": False, "is_root": True }, + ), + ( + self.examples[0][0], + [4, 5, 6, 8], + { "bank": 2, "num_completed": 1, "finished": False, "is_root": True }, + ), + ( + self.examples[0][0], + # Tricky, because in last three, goes down [1->4] branch, could miss [1] and [4->5] + # [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]], + [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5], + { "bank": 14, "num_completed": 6, "finished": True, "is_root": False }, + ), + ( + self.examples[0][0], + [1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117], + { "bank": 14, "num_completed": 6, "finished": True, "is_root": True }, + ), + ( + tensorize([[1], [2, 3]]), + # Should not be able to get credit for entering 1 a second time + [1, 1], + { "bank": 1, "num_completed": 1, "finished": False, "is_root": True }, + ), + ( + self.examples[4][0], + [1, 2, 1, 2], + { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + ), + ( + self.examples[4][0], + [1, 2, 1, 2, 1], + { "bank": 4, "num_completed": 2, "finished": True, "is_root": True }, + ), + ( + self.examples[5][0], + [1, 2, 3, 4, 5], + { "bank": 4, "num_completed": 2, "finished": True, "is_root": True }, + ), + ] + + def test_graphs(self): + """ + Test whether unordered graph systems are created correctly. 
+ """ + for example in self.examples: + constraints, expected, gold_counts = example + c = ConstraintNode.create(constraints) + assert ConstraintNode.print_graph(c) == expected, f"got {ConstraintNode.print_graph(c)}, expected {expected}" + assert c.token_counts() == gold_counts, f"{c} got {c.token_counts()} wanted {gold_counts}" + + def test_next_tokens(self): + """ + Tests that the set of next tokens is correct. + """ + for example in self.examples: + constraints, expected, gold_counts = example + root = ConstraintNode.create(constraints) + + root_tokens = set(root.children.keys()) + for sequence in constraints: + state = UnorderedConstraintState(root) + for token in sequence: + all_tokens = root_tokens.union(state.node.children.keys()) + assert all_tokens == state.next_tokens(), f"ALL {all_tokens} NEXT {state.next_tokens()}" + state = state.advance(token) + + def test_sequences(self): + for constraints, tokens, expected in self.sequences: + state = UnorderedConstraintState.create(pack_constraints([constraints])[0]) + for token in tokens: + state = state.advance(token) + result = {} + for attr in expected.keys(): + result[attr] = getattr(state, attr) + + assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}" + + +class TestOrderedConstraintState(unittest.TestCase): + def setUp(self): + self.sequences = [ + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [], + { "bank": 0, "num_completed": 0, "finished": False, "is_root": True }, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2], + { "bank": 2, "num_completed": 0, "finished": False, "is_root": False }, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 94], + { "bank": 0, "num_completed": 0, "finished": False, "is_root": True }, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 3, 999, 1, 4], + { "bank": 0, "num_completed": 0, "finished": False, "is_root": True 
}, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 3, 999, 999], + { "bank": 3, "num_completed": 1, "finished": False, "is_root": False }, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 3, 77, 1, 3, 1], + { "bank": 6, "num_completed": 2, "finished": False, "is_root": False }, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5], + { "bank": 14, "num_completed": 6, "finished": True, "is_root": False }, + ), + ( + tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), + [1, 2, 999, 1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117], + { "bank": 14, "num_completed": 6, "finished": True, "is_root": False }, + ), + ( + tensorize([[1], [2, 3]]), + [1, 1], + { "bank": 1, "num_completed": 1, "finished": False, "is_root": False }, + ), + ( + tensorize([[1, 2], [1, 2]]), + [1, 2, 1, 2], + { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + ), + ( + tensorize([[1, 2], [1, 2]]), + [1, 2, 1, 2, 1], + { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + ), + ( + tensorize([[1, 2], [3, 4]]), + [1, 2, 3, 4, 5], + { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + ), + ] + + def test_sequences(self): + for i, (constraints, tokens, expected) in enumerate(self.sequences): + state = OrderedConstraintState.create(pack_constraints([constraints])[0]) + for token in tokens: + state = state.advance(token) + result = {} + for attr in expected.keys(): + result[attr] = getattr(state, attr) + assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}" + +if __name__ == "__main__": + unittest.main() + From 703fd48bb1468ed2df22a681511a03f1937ed60f Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 20 Aug 2020 15:42:59 -0700 Subject: [PATCH 123/707] Fix README and #2496 (#2505) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2505 
Reviewed By: shruti-bh Differential Revision: D23247882 Pulled By: myleott fbshipit-source-id: 1cfc9e0128e1aa55a1aca31d8dd30f231558e70f --- README.md | 3 ++- fairseq/criterions/masked_lm.py | 2 +- fairseq_cli/eval_lm.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f77e56e75b..3743571b91 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: -- August 2020: [Lexically constrained decoding(examples/constrained_decoding/README.md) +- August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) - August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) - July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) @@ -84,6 +84,7 @@ We provide reference implementations of various sequence modeling papers: - beam search - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) - sampling (unconstrained, top-k and top-p/nucleus) + - lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md)) - large mini-batch training even on a single GPU via delayed updates - mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) - extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py index 80864693ec..f62ed805f2 100644 --- a/fairseq/criterions/masked_lm.py +++ b/fairseq/criterions/masked_lm.py @@ -18,7 +18,7 @@ class MaskedLmLoss(FairseqCriterion): Implementation for the loss used in masked language model (MLM) training. 
""" - def __init__(self, task, tpu): + def __init__(self, task, tpu=False): super().__init__(task) self.tpu = tpu diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index cf5b62c366..1c12ad6091 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -176,7 +176,7 @@ def main(parsed_args, **unused_kwargs): tgt_len = tokens.numel() pos_scores = hypo['positional_scores'].float() - if args.add_bos_token: + if getattr(args, 'add_bos_token', False): assert hypo['tokens'][0].item() == task.target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] From 83d701ac10a328e54422a9244337d8e569bcf96d Mon Sep 17 00:00:00 2001 From: Weiyi Zheng Date: Thu, 20 Aug 2020 17:03:48 -0700 Subject: [PATCH 124/707] log latest value on loss_scale Summary: the log_scalar defaults to logging average, which isn't very useful. Reviewed By: myleott Differential Revision: D23226068 fbshipit-source-id: 78a67366344608d452141523fb900f99706a8a17 --- fairseq/trainer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index de5df45105..6cd73a631a 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -592,7 +592,13 @@ def maybe_no_sync(): torch.cuda.empty_cache() if self.args.fp16: - metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=4) + metrics.log_scalar( + "loss_scale", + self.optimizer.scaler.loss_scale, + priority=700, + round=4, + weight=0, + ) metrics.log_stop_time("train_wall") From 49940c8d25d61a251e290d96fe3bbbc9f210408f Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Thu, 20 Aug 2020 20:07:45 -0700 Subject: [PATCH 125/707] fix mismatch length of counting iterator when truncated Summary: PySpeech integration training tests have recently been stuck at end of epoch. 
Digging into it, it looks like this is because the end of epoch check relies on this (https://fburl.com/diffusion/xt09z6n9): ``` def end_of_epoch(self) -> bool: """Returns whether the most recent epoch iterator has been exhausted""" return not self._cur_epoch_itr.has_next() ``` which is implemented like this in CountingIterator: def has_next(self): """Whether the iterator has been exhausted.""" return self.n < len(self) It seems like D23172408 (https://github.com/pytorch/fairseq/commit/110f9f0cc781354eee358b28445d2096cdbd4a14) modified CountingIterator such that `len(self) > len(iter(self))` when `take()` is used. This mismatch causes `has_next` to return `True` for some PySpeech processes even when all elements in `iter(self)` have been consumed, causing training to get stuck. My proposed fix is to remove the `self.early_stop` variable and just directly modify `self.total` and `self.iterable`, ensuring `len(self) == len(iter(self))`. Reviewed By: myleott Differential Revision: D23250734 fbshipit-source-id: efb5a38216783bded67f501135b2f68b9246b9dd --- fairseq/data/iterators.py | 8 +++----- tests/test_iterators.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index b6ab54d4de..e902b2fb47 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -53,8 +53,6 @@ def __init__(self, iterable, start=None, total=None): else: self.total = total - self.early_stop = self.total - def __len__(self): return self.total @@ -65,8 +63,6 @@ def __iter__(self): 'Mismatch between actual and expected iterable length. ' 'Please report this to the fairseq developers.' ) - elif self.n >= self.early_stop: - return # early stop based on take() self.n += 1 yield x @@ -86,11 +82,13 @@ def take(self, n): """ Truncates the iterator to n elements at most. 
""" - self.early_stop = min(self.early_stop, n) + self.total = min(self.total, n) # Propagate this change to the underlying iterator if hasattr(self.iterable, "take"): self.iterable.take(n) + else: + self.iterable = itertools.islice(self.iterable, n) class EpochBatchIterating(object): diff --git a/tests/test_iterators.py b/tests/test_iterators.py index 6e935ffc55..7ceef124f5 100644 --- a/tests/test_iterators.py +++ b/tests/test_iterators.py @@ -70,6 +70,21 @@ def test_sharded_iterator(self): itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0) self.test_counting_iterator(ref, itr) + def test_counting_iterator_take(self): + ref = list(range(10)) + itr = iterators.CountingIterator(ref) + itr.take(5) + self.assertEqual(len(itr), len(list(iter(itr)))) + self.assertEqual(len(itr), 5) + + itr = iterators.CountingIterator(ref) + itr.take(5) + self.assertEqual(next(itr), ref[0]) + self.assertEqual(next(itr), ref[1]) + itr.skip(2) + self.assertEqual(next(itr), ref[4]) + self.assertFalse(itr.has_next()) + if __name__ == '__main__': unittest.main() From 39f0911ac6ea9d4fe9f4714d6343a2b1a1343609 Mon Sep 17 00:00:00 2001 From: alexeib Date: Fri, 21 Aug 2020 20:35:54 -0700 Subject: [PATCH 126/707] fix validation interval updates with loss scaling and remove filtering by size from audio pretraining (#1252) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1252 Reviewed By: xuqiantong Differential Revision: D23277400 Pulled By: alexeib fbshipit-source-id: 169fa4d2cb91e9a089fe9d58e8085213e4b800e1 --- fairseq/tasks/audio_pretraining.py | 10 ++++++++++ fairseq_cli/train.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 46d164ba98..2a51279ebc 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -135,3 +135,13 @@ def target_dictionary(self): def max_positions(self): """Maximum input length supported by 
the encoder.""" return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size( + self, + indices, + dataset, + max_positions=None, + ignore_invalid_inputs=False, + ): + # we do not need to filter by size in this task as dataloaders take care of this + return indices diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 6c8bd5d4df..72e95c917f 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -238,7 +238,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) - or (args.validate_interval_updates > 0 and num_updates % args.validate_interval_updates == 0) + or (args.validate_interval_updates > 0 and num_updates > 0 and num_updates % args.validate_interval_updates == 0) ) and not args.disable_validation # Validate From 17fd14d3fb8dc01bac2be44cc72c64e66a41feb2 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Fri, 21 Aug 2020 20:56:28 -0700 Subject: [PATCH 127/707] fix convtransformer test Summary: D23188945 (https://github.com/pytorch/fairseq/commit/bd1b35d9b7cb21b2e7c17201d831c17560265b67) appears to have modified the sequence generator to support only 2D src tokens, but audio inputs seem to need 3D. Reviewed By: zhengwy888 Differential Revision: D23274482 fbshipit-source-id: 2e66c897c9d2c929158c2c9f858eac607af302bf --- fairseq/sequence_generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 27435986f4..6cc4a13e20 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -196,7 +196,8 @@ def _generate( raise Exception('expected src_tokens or source in net input') # bsz: total number of sentences in beam - bsz, src_len = src_tokens.size() + # Note that src_tokens may have more than 2 dimensions (i.e. 
audio features) + bsz, src_len = src_tokens.size()[:2] beam_size = self.beam_size if constraints is not None and not self.search.supports_constraints: From 585c330ea35711e2d6ac7c56bdea44c82669b243 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 24 Aug 2020 10:27:29 -0700 Subject: [PATCH 128/707] Enable no target data inference for multilingual tasks Summary: The multilingual data manager always checks the target dataset, which is not necessary at inference time. Here is a fix. Reviewed By: pipibjc Differential Revision: D23282118 fbshipit-source-id: cbfd6f17919694fbb69a3ae85cda8e9c96df6764 --- .../multilingual/multilingual_data_manager.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 77731ae203..0d02ac1e0a 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -342,10 +342,10 @@ def mono_split_exists(cls, split, lang, data_path, dataset_impl): def bitext_split_exists(cls, split, src, tgt, data_path, dataset_impl): src_exists = cls.split_exists(split, src, tgt, lang=src, data_path=data_path, dataset_impl=dataset_impl) \ or cls.split_exists(split, tgt, src, lang=src, data_path=data_path, dataset_impl=dataset_impl) - - tgt_exists = cls.split_exists(split, src, tgt, lang=tgt, data_path=data_path, dataset_impl=dataset_impl) \ - or cls.split_exists(split, tgt, src, lang=tgt, data_path=data_path, dataset_impl=dataset_impl) - return src_exists and tgt_exists + # check source exists to determine shard number + # also note that during inference time target is not required + # so checking target will fail inference time data loading + return src_exists @classmethod def get_split_num_shards(cls, split, src, tgt, data_paths, dataset_impl): @@ -488,8 +488,10 @@ def load_langpair_dataset( f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt 
length={len(tgt_dataset)}") return LanguagePairDataset( - src_dataset, src_dataset.sizes, src_dict, - tgt_dataset, tgt_dataset.sizes, tgt_dict, + src_dataset, src_dataset.sizes, + src_dict, + tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, + tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, align_dataset=align_dataset, @@ -510,6 +512,9 @@ def src_dataset_tranform_func(self, src_lang, tgt_lang, dataset, spec=None): return dataset def tgt_dataset_tranform_func(self, source_lang, target_lang, dataset, spec=None): + if dataset is None: + # note that target dataset can be None during inference time + return None if self.args.lang_tok_replacing_bos_eos: # TODO: Unifiy with alter_dataset_langtok # It is handled by self.alter_dataset_langtok. From 2239cdadfa5ecf09e08887c77c05f5d7a8d532bd Mon Sep 17 00:00:00 2001 From: alexeib Date: Mon, 24 Aug 2020 12:16:56 -0700 Subject: [PATCH 129/707] fix generate bugs introduced by constraint decoding (#1254) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1254 Reviewed By: ngoyal2707 Differential Revision: D23287443 Pulled By: alexeib fbshipit-source-id: 6fd358019a72abd0863308571d7cb63b47603964 --- examples/speech_recognition/infer.py | 2 +- examples/speech_recognition/w2l_decoder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index d22acc9c3b..e40d37d390 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -219,7 +219,7 @@ def __init__(self, decoder, emissions): self.decoder = decoder self.emissions = emissions - def generate(self, models, sample, prefix_tokens=None): + def generate(self, models, sample, **unused): ids = sample["id"].cpu().numpy() try: emissions = np.stack(self.emissions[ids]) diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 
149bec0c49..020aac5593 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -66,7 +66,7 @@ def __init__(self, args, tgt_dict): else: raise RuntimeError(f"unknown criterion: {args.criterion}") - def generate(self, models, sample, prefix_tokens=None): + def generate(self, models, sample, **unused): """Generate a batch of inferences.""" # model.forward normally channels prev_output_tokens into the decoder # separately, but SequenceGenerator directly calls model.encoder From 226f0e45391ac1dcacb72ff68b523bf7b2ebceda Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 24 Aug 2020 17:34:31 -0700 Subject: [PATCH 130/707] fix #2483 by explicitly specifying a cross-platform dtype. (#2521) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2483 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2521 Reviewed By: michaelauli, lematt1991 Differential Revision: D23303214 Pulled By: alexeib fbshipit-source-id: 4b60afc7b902c07bc36fad2eeb116a8ed7f5ffe2 --- fairseq/data/fairseq_dataset.py | 2 +- fairseq/data/language_pair_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index 2c972127a7..f196aff14f 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -50,7 +50,7 @@ def size(self, index): def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" - return np.arange(len(self)) + return np.arange(len(self), dtype=np.int64) @property def supports_prefetch(self): diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index aed54d618f..fba3d37bc5 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -372,9 +372,9 @@ def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" if self.shuffle: - indices = np.random.permutation(len(self)) + indices = np.random.permutation(len(self)).astype(np.int64) else: - indices = np.arange(len(self)) + indices = np.arange(len(self), dtype=np.int64) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: From fc27170a9e70c6485331d8c84d56142a98de8a84 Mon Sep 17 00:00:00 2001 From: Jun Ru Anderson Date: Wed, 26 Aug 2020 09:26:38 -0700 Subject: [PATCH 131/707] change supports_step_with_scale check (#1255) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Changes the supports_step_with_scale check in fp16_optimizer.py to use fp32_optimizer.supports_step_with_scale / wrapped_optimizer.supports_step_with_scale rather than self.supports_step_with_scale. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1255 Reviewed By: myleott, stanvp Differential Revision: D23298512 Pulled By: andersonic fbshipit-source-id: a346375a3d8b0c2fec33322e2adfaee268c25423 --- fairseq/optim/fp16_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index a815eeb085..30d486c393 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -145,7 +145,7 @@ def step(self, closure=None): """Performs a single optimization step.""" self._sync_fp16_grads_to_fp32() - if self.supports_step_with_scale: + if getattr(self, 'supports_step_with_scale', False): self.fp32_optimizer.step(closure, scale=(1. / self._multiply_factor)) else: self._unscale_grads() @@ -332,7 +332,7 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): def step(self, closure=None): """Performs a single optimization step.""" - if self.supports_step_with_scale: + if getattr(self, 'supports_step_with_scale', False): # NOTE(msb) optimizer divides by scale factor self.wrapped_optimizer.step(closure, scale=(1. 
/ self._multiply_factor)) else: From 59ac9c0c12c68afdf1f03fdfc4437b70182406db Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Thu, 27 Aug 2020 17:23:13 -0700 Subject: [PATCH 132/707] Add a FastaDataset (#1187) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Adds some fantastic work done by Zeming Lin ebetica. FASTA is the predominant format used by biologists for DNA, RNA and proteins. It looks something like this: ``` >name of your protein1 MSHFAHSDFAHSDFHWEHJW FHDSJFASJDAHASFASDFIAA >name of your protein2 MAHASDFMASFJADSFMSMSM MASDFJASDJ ``` There's no need for BPE or other fancy preprocessing, so we can read the FASTA file directly in fairseq with no speed hit compared to binarized data. Building the index is important, but we can just cache that, similar to the other cached indexed datasets. We hope this reduces the barrier for biologists to use fairseq, making this great framework even more accessible to the computational biology community! This dataset is used internally in proteinseq, [see here for an example](https://github.com/fairinternal/proteinseq/pull/178/files). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1187 Reviewed By: myleott Differential Revision: D22020223 Pulled By: ebetica fbshipit-source-id: 372ebc199c0c9200645c79fa7722aded931e9038 --- fairseq/data/__init__.py | 4 ++ fairseq/data/fasta_dataset.py | 107 ++++++++++++++++++++++++++++++++ fairseq/data/indexed_dataset.py | 10 ++- 3 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 fairseq/data/fasta_dataset.py diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index a99d9280fa..d195e59493 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -49,6 +49,8 @@ from .shorten_dataset import TruncateDataset, RandomCropDataset from .multilingual.sampled_multi_dataset import SampledMultiDataset from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset +from .fasta_dataset import FastaDataset, EncodedFastaDataset + from .iterators import ( CountingIterator, EpochBatchIterator, @@ -68,9 +70,11 @@ 'CountingIterator', 'DenoisingDataset', 'Dictionary', + 'EncodedFastaDataset', 'EpochBatchIterator', 'FairseqDataset', 'FairseqIterableDataset', + 'FastaDataset', 'GroupedIterator', 'IdDataset', 'IndexedCachedDataset', diff --git a/fairseq/data/fasta_dataset.py b/fairseq/data/fasta_dataset.py new file mode 100644 index 0000000000..007011974a --- /dev/null +++ b/fairseq/data/fasta_dataset.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import subprocess +import threading +from pathlib import Path + +import numpy as np +import torch + + +def fasta_file_path(prefix_path): + return prefix_path + ".fasta" + + +class FastaDataset(torch.utils.data.Dataset): + """ + For loading protein sequence datasets in the common FASTA data format + """ + + def __init__(self, path: str, cache_indices=False): + self.fn = fasta_file_path(path) + self.threadlocal = threading.local() + self.cache = Path(f"{path}.fasta.idx.npy") + if cache_indices: + if self.cache.exists(): + self.offsets, self.sizes = np.load(self.cache) + else: + self.offsets, self.sizes = self._build_index(path) + np.save(self.cache, np.stack([self.offsets, self.sizes])) + else: + self.offsets, self.sizes = self._build_index(path) + + def _get_file(self): + if not hasattr(self.threadlocal, "f"): + self.threadlocal.f = open(self.fn, "r") + return self.threadlocal.f + + def __getitem__(self, idx): + f = self._get_file() + f.seek(self.offsets[idx]) + desc = f.readline().strip() + line = f.readline() + seq = "" + while line != "" and line[0] != ">": + seq += line.strip() + line = f.readline() + return desc, seq + + def __len__(self): + return self.offsets.size + + def _build_index(self, path: str): + # Use grep and awk to get 100M/s on local SSD. + # Should process your enormous 100G fasta in ~10 min single core... 
+ path = fasta_file_path(path) + bytes_offsets = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| grep --byte-offset '^>' -o | cut -d: -f1", + shell=True, + ) + fasta_lengths = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'", + shell=True, + ) + bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ") + sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ") + return bytes_np, sizes_np + + def __setstate__(self, state): + self.__dict__ = state + self.threadlocal = threading.local() + + def __getstate__(self): + d = {} + for i, v in self.__dict__.items(): + if i != "threadlocal": + d[i] = v + return d + + def __del__(self): + if hasattr(self.threadlocal, "f"): + self.threadlocal.f.close() + del self.threadlocal.f + + @staticmethod + def exists(path): + return os.path.exists(fasta_file_path(path)) + + +class EncodedFastaDataset(FastaDataset): + """ + The FastaDataset returns raw sequences - this allows us to return + indices with a dictionary instead. + """ + + def __init__(self, path, dictionary): + super().__init__(path, cache_indices=True) + self.dictionary = dictionary + + def __getitem__(self, idx): + desc, seq = super().__getitem__(idx) + return self.dictionary.encode_line(seq, line_tokenizer=list).long() diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 12497989bb..5b6155a2df 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -12,6 +12,7 @@ import torch from . 
import FairseqDataset +from fairseq.data.fasta_dataset import FastaDataset def __best_fitting_dtype(vocab_size=None): @@ -22,7 +23,7 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ['raw', 'lazy', 'cached', 'mmap'] + return ['raw', 'lazy', 'cached', 'mmap', 'fasta'] def infer_dataset_impl(path): @@ -37,6 +38,8 @@ def infer_dataset_impl(path): return 'mmap' else: return None + elif FastaDataset.exists(path): + return 'fasta' else: return None @@ -44,6 +47,8 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): if impl == 'mmap': return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + elif impl == 'fasta': + raise NotImplementedError else: return IndexedDatasetBuilder(out_file) @@ -58,6 +63,9 @@ def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None): return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing) elif impl == 'mmap' and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path) + elif impl == 'fasta' and FastaDataset.exists(path): + from fairseq.data.fasta_dataset import EncodedFastaDataset + return EncodedFastaDataset(path, dictionary) return None From 0cde6b4e508ef82f7c7a4df01ab6231ce1d257ee Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 28 Aug 2020 10:09:43 -0700 Subject: [PATCH 133/707] Added shared dictionary check for translation_multi_simple_epoch task. Summary: translation_multi_simple_epoch task only supports shared dictionary across all languages, so add the check in the task setup. 
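A minimal sketch of the check described above, not the code in this diff. The helper name `check_shared_dictionary` is hypothetical, and plain Python dicts stand in for the per-language dictionary objects, which are assumed to compare equal when their contents match.

```python
from typing import Mapping, Optional, TypeVar

D = TypeVar("D")


def check_shared_dictionary(dicts: Mapping[str, D]) -> Optional[D]:
    """Return the dictionary shared by all languages, or raise if they differ.

    Hypothetical helper for illustration only; returns None for an empty mapping.
    """
    shared: Optional[D] = None
    first_lang: Optional[str] = None
    for lang, lang_dict in dicts.items():
        if first_lang is None:
            # The first language's dictionary becomes the reference.
            first_lang, shared = lang, lang_dict
        elif lang_dict != shared:
            raise ValueError(
                f"Different dictionaries are specified for {first_lang!r} and {lang!r}; "
                "only one shared dictionary across all languages is supported"
            )
    return shared


if __name__ == "__main__":
    vocab = {"<pad>": 0, "hello": 1}
    # Equal contents for every language: the shared dictionary is returned.
    print(check_shared_dictionary({"en": vocab, "de": dict(vocab)}))
    # A differing dictionary is rejected.
    try:
        check_shared_dictionary({"en": vocab, "fr": {"<pad>": 0}})
    except ValueError as err:
        print(err)
```

The sketch raises an exception instead of using `assert`, so the check would stay active when Python runs with `-O`; whether that trade-off suits the task setup is a design choice left open here.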
Reviewed By: pipibjc Differential Revision: D23288388 fbshipit-source-id: 4236a096bcb75429b486ef8a9244e3ef0d5095f0 --- .../tasks/translation_multi_simple_epoch.py | 19 +++++++++++-------- tests/test_binaries.py | 5 ++++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 63a3d3ab86..b10e696f9b 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -98,6 +98,15 @@ def setup_task(cls, args, **kwargs): langs, dicts, training = MultilingualDatasetManager.prepare( cls.load_dictionary, args, **kwargs ) + dict0 = None + for _, lang_dict in dicts.items(): + if dict0 is None: + dict0 = lang_dict + else: + assert dict0 == lang_dict, ( + "Different dictionaries are specified for different languages; " + "TranslationMultiSimpleEpochTask only supports one shared dictionary across all languages" + ) return cls(args, langs, dicts, training) def has_sharded_data(self, split): @@ -211,17 +220,11 @@ def max_positions(self): @property def source_dictionary(self): - if self.training: - return next(iter(self.dicts.values())) - else: - return self.dicts[self.args.source_lang] + return next(iter(self.dicts.values())) @property def target_dictionary(self): - if self.training: - return next(iter(self.dicts.values())) - else: - return self.dicts[self.args.target_lang] + return next(iter(self.dicts.values())) def create_batch_sampler_func( self, max_positions, ignore_invalid_inputs, diff --git a/tests/test_binaries.py b/tests/test_binaries.py index a6133b1b41..16887aacbd 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -217,7 +217,10 @@ def test_translation_multi_simple_epoch(self): dec_ltok_flag = decoder_langtok_flags[j] with tempfile.TemporaryDirectory(f'test_translation_multi_simple_epoch_{i}_{j}') as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir) + 
preprocess_translation_data( + data_dir, + extra_flags=['--joined-dictionary'] + ) train_translation_model( data_dir, arch='transformer', From 4bfd70d400c1b4c474ce617b1359a0250329d6f0 Mon Sep 17 00:00:00 2001 From: alexeib Date: Fri, 28 Aug 2020 16:34:51 -0700 Subject: [PATCH 134/707] fix transformer lm defaults (#1258) Summary: these changes allow loading older checkpoints created before these flags were added Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1258 Reviewed By: edunov Differential Revision: D23388753 Pulled By: alexeib fbshipit-source-id: 12b48ebf1a36bd4b24034d8ed68398f6e1526052 --- fairseq/models/transformer_lm.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index b59363900e..e24452ff8a 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -111,16 +111,16 @@ def add_args(parser): parser.add_argument('--no-scale-embedding', action='store_true', help='if True, dont scale embeddings') # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) - parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, + parser.add_argument('--decoder-layerdrop', type=float, metavar='D', help='LayerDrop probability for decoder') - parser.add_argument('--decoder-layers-to-keep', default=None, + parser.add_argument('--decoder-layers-to-keep', help='which layers to *keep* when pruning as a comma-separated list') # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) - parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, + parser.add_argument('--quant-noise-pq', type=float, metavar='D', help='iterative PQ quantization noise at training time') - parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, + parser.add_argument('--quant-noise-pq-block-size', 
type=int, metavar='D', help='block size of quantization noise at training time') - parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, + parser.add_argument('--quant-noise-scalar', type=float, metavar='D', help='scalar quantization noise and scalar quantization at training time') # fmt: on @@ -196,6 +196,12 @@ def base_lm_architecture(args): args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) args.activation_fn = getattr(args, 'activation_fn', 'relu') + args.decoder_layerdrop = getattr(args, 'decoder_layerdrop', 0) + args.decoder_layers_to_keep = getattr(args, 'decoder_layers_to_keep', None) + args.quant_noise_pq = getattr(args, 'quant_noise_pq', 0) + args.quant_noise_pq_block_size = getattr(args, 'quant_noise_pq_block_size', 8) + args.quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0) + args.add_bos_token = getattr(args, 'add_bos_token', False) args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) From 0989eca746ada5a2439010ffb60b17efdc378270 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Fri, 28 Aug 2020 17:49:48 -0700 Subject: [PATCH 135/707] add more detailed logging for fp16 diverging Summary: We often get a generic "minimum loss scale reached" when fp16 training diverges. Would be useful to have a breakdown on where exactly the gradient norm becomes too big. 
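The per-parameter breakdown this summary asks for can be sketched without PyTorch as follows. This is an illustrative stand-in under stated assumptions, not the `NanDetector` change itself: `find_bad_grad_norms` is a hypothetical name, gradients are plain lists of floats rather than tensors, and parameters with no gradient (`None`) are skipped.

```python
import math


def find_bad_grad_norms(named_grads):
    """Compute an L2 norm per parameter and flag nan/inf norms.

    `named_grads` maps a parameter name to a flat list of gradient
    values, or to None when the parameter received no gradient
    (an assumption of this sketch).
    """
    norms = {}
    bad = []
    for name, grad in named_grads.items():
        if grad is None:
            continue  # nothing to measure for this parameter
        norm = math.sqrt(sum(g * g for g in grad))
        norms[name] = norm
        if math.isnan(norm) or math.isinf(norm):
            bad.append(name)
    return norms, bad


norms, bad = find_bad_grad_norms({
    "w": [3.0, 4.0],
    "b": [float("nan")],
    "e": None,
    "v": [float("inf"), 1.0],
})
```

Reporting the names of the offending parameters next to the full table of norms is what turns a generic "minimum loss scale reached" message into something that points at a specific layer.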
Reviewed By: myleott Differential Revision: D23297774 fbshipit-source-id: 69da1cca1be22f15af633f8efe4e7b491cf4f6f9 --- fairseq/nan_detector.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py index 789169d2b0..df4e28ec89 100644 --- a/fairseq/nan_detector.py +++ b/fairseq/nan_detector.py @@ -19,6 +19,7 @@ def __init__(self, model, forward=True, backward=True): self.fhooks = [] self.forward = forward self.backward = backward + self.model = model self.reset() for name, mod in model.named_modules(): @@ -29,6 +30,19 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, exc_traceback): + # Dump out all model gnorms to enable better debugging + norm = {} + gradients = {} + for name, param in self.model.named_parameters(): + grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32) + norm[name] = grad_norm.item() + if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): + gradients[name] = param.grad.data + if len(gradients) > 0: + logger.info("Detected nan/inf grad norm, dumping norms...") + logger.info(f"norms: {norm}") + logger.info(f"gradients: {gradients}") + self.close() def add_hooks(self, module): From fe1b1bbe17a5fa6c2b3505735deb3c61fe5b68cc Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 31 Aug 2020 11:28:14 -0700 Subject: [PATCH 136/707] Misc fixes (#2524) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2524 Reviewed By: ngoyal2707 Differential Revision: D23318746 Pulled By: myleott fbshipit-source-id: 6db6a87aac178847bd0da26db09b1a63632a724f --- fairseq/benchmark/dummy_mt.py | 24 ++++++++-------- fairseq/data/denoising_dataset.py | 4 +-- fairseq/distributed_utils.py | 4 +-- .../model_parallel/models/roberta/model.py | 7 ++++- fairseq/models/lstm.py | 2 +- fairseq/models/roberta/model.py | 6 +++- fairseq/optim/adafactor.py | 6 ++-- fairseq/optim/fp16_optimizer.py | 3 +- fairseq_cli/eval_lm.py | 2 +- fairseq_cli/generate.py | 2 +- 
fairseq_cli/interactive.py | 2 +- fairseq_cli/preprocess.py | 2 +- fairseq_cli/train.py | 28 +++++++++++++------ fairseq_cli/validate.py | 3 +- 14 files changed, 58 insertions(+), 37 deletions(-) diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py index 09f2f0c119..9fba9bb520 100644 --- a/fairseq/benchmark/dummy_mt.py +++ b/fairseq/benchmark/dummy_mt.py @@ -23,9 +23,8 @@ def add_args(parser): """Add task-specific arguments to the parser.""" parser.add_argument('--dict-size', default=49996, type=int) parser.add_argument('--dataset-size', default=100000, type=int) - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments ' - 'per sample for BERT dataset') + parser.add_argument('--src-len', default=30, type=int) + parser.add_argument('--tgt-len', default=30, type=int) def __init__(self, args, dictionary): super().__init__(args) @@ -34,10 +33,8 @@ def __init__(self, args, dictionary): dictionary.pad_to_multiple_(8) # often faster if divisible by 8 - seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1 - - self.dummy_src = seq[:-1] - self.dummy_tgt = seq[1:] + self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1 + self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1 @classmethod def setup_task(cls, args, **kwargs): @@ -46,6 +43,10 @@ def setup_task(cls, args, **kwargs): for i in range(args.dict_size): dictionary.add_symbol('word{}'.format(i)) logger.info('dictionary: {} types'.format(len(dictionary))) + + args.max_source_positions = args.src_len + dictionary.pad() + 2 + args.max_target_positions = args.tgt_len + dictionary.pad() + 2 + return cls(args, dictionary) def load_dataset(self, split, epoch=1, combine=False, **kwargs): @@ -53,10 +54,11 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ + item_size = max(self.args.src_len, self.args.tgt_len) 
if self.args.max_sentences is not None: bsz = self.args.max_sentences else: - bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) + bsz = max(1, self.args.max_tokens // item_size) tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) self.datasets[split] = DummyDataset( { @@ -64,16 +66,16 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): 'net_input': { 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]), 'src_lengths': torch.full( - (bsz, ), self.args.tokens_per_sample, dtype=torch.long + (bsz, ), self.args.src_len, dtype=torch.long ), 'prev_output_tokens': tgt.clone(), }, 'target': tgt, 'nsentences': bsz, - 'ntokens': bsz * self.args.tokens_per_sample, + 'ntokens': bsz * self.args.tgt_len, }, num_items=self.args.dataset_size, - item_size=self.args.tokens_per_sample, + item_size=item_size, ) @property diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py index 8dc240c1eb..c55ce1ba49 100644 --- a/fairseq/data/denoising_dataset.py +++ b/fairseq/data/denoising_dataset.py @@ -210,7 +210,7 @@ def permute_sentences(self, source, p=1.0): full_stops[-2] = 1 # Tokens that are full stops, where the previous token is not - sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero() + 2 + sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2 result = source.clone() num_sentences = sentence_ends.size(0) @@ -271,7 +271,7 @@ def add_whole_word_mask(self, source, p): else: lengths = torch.ones((num_to_mask,)).long() assert is_word_start[-1] == 0 - word_starts = is_word_start.nonzero() + word_starts = is_word_start.nonzero(as_tuple=False) indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1) mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 7ee89adce9..16f3edaeef 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -115,9 
+115,7 @@ def distributed_init(args): xm.rendezvous('distributed_init') # wait for all workers xm.mark_step() - if is_master(args): - logging.getLogger().setLevel(logging.INFO) - else: + if not is_master(args): logging.getLogger().setLevel(logging.WARNING) if args.model_parallel_size > 1: diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py index e0ae4a2c8f..6ba097b14d 100644 --- a/fairseq/model_parallel/models/roberta/model.py +++ b/fairseq/model_parallel/models/roberta/model.py @@ -69,6 +69,11 @@ def build_model(cls, args, task): if not hasattr(args, 'max_positions'): args.max_positions = args.tokens_per_sample + if getattr(args, 'untie_weights_roberta', False): + raise NotImplementedError( + '--untie-weights-roberta is not supported in model parallel mode' + ) + encoder = ModelParallelRobertaEncoder(args, task.source_dictionary) return cls(args, encoder) @@ -127,7 +132,7 @@ def forward(self, features, masked_tokens=None, **kwargs): x = self.activation_fn(x) x = self.layer_norm(x) - features = copy_to_model_parallel_region(features) + x = copy_to_model_parallel_region(x) # project back to size of vocabulary with bias x = F.linear(x, self.weight) x = gather_from_model_parallel_region(x).contiguous() diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 850428a32d..72bd815bcc 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -268,7 +268,7 @@ def forward( # pack embedded source tokens into a PackedSequence packed_x = nn.utils.rnn.pack_padded_sequence( - x, src_lengths.data, enforce_sorted=enforce_sorted + x, src_lengths.cpu(), enforce_sorted=enforce_sorted ) # apply LSTM diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 2303fbe26e..e9008076b3 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -314,7 +314,11 @@ def __init__(self, args, dictionary): embed_dim=args.encoder_embed_dim, output_dim=len(dictionary), 
activation_fn=args.activation_fn, - weight=self.sentence_encoder.embed_tokens.weight if not args.untie_weights_roberta else None, + weight=( + self.sentence_encoder.embed_tokens.weight + if not args.untie_weights_roberta + else None + ), ) def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused): diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index 0456f7d61d..b0fb3a9f5e 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -138,9 +138,9 @@ def _rms(self, tensor): def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): r_factor = ( exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True) - ).rsqrt_() - c_factor = exp_avg_sq_col.rsqrt() - return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0)) + ).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) def step(self, closure=None): """Performs a single optimization step. diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 30d486c393..960c3a67eb 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -17,6 +17,7 @@ class _FP16OptimizerMixin(object): def __init__(self, *args, **kwargs): # forward __init__ call to the next class in mro(method resolution order) super().__init__(*args, **kwargs) + self._multiply_factor = 1. 
@property def has_flat_params(self): @@ -135,7 +136,7 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): self._multiply_factor *= max_norm / grad_norm self.scaler.check_overflow(grad_norm) - else: + elif max_norm > 0.0: clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) self._multiply_factor *= clip_coef diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 1c12ad6091..e59674b530 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -25,7 +25,7 @@ logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S', - level=logging.INFO, + level=os.environ.get('LOGLEVEL', 'INFO').upper(), ) logger = logging.getLogger('fairseq_cli.eval_lm') diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index f4c86f6479..a6e48f927e 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -49,7 +49,7 @@ def _main(args, output_file): logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S', - level=logging.INFO, + level=os.environ.get('LOGLEVEL', 'INFO').upper(), stream=output_file, ) logger = logging.getLogger('fairseq_cli.generate') diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 032966051a..f8ee0197dd 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -27,7 +27,7 @@ logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S', - level=logging.INFO, + level=os.environ.get('LOGLEVEL', 'INFO').upper(), stream=sys.stdout, ) logger = logging.getLogger('fairseq_cli.interactive') diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py index b107b9fa18..3fe5131324 100644 --- a/fairseq_cli/preprocess.py +++ b/fairseq_cli/preprocess.py @@ -23,7 +23,7 @@ logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S', - level=logging.INFO, + 
level=os.environ.get('LOGLEVEL', 'INFO').upper(), stream=sys.stdout, ) logger = logging.getLogger('fairseq_cli.preprocess') diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 72e95c917f..05cffd5a7e 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -10,6 +10,7 @@ import argparse import logging import math +import os import random import sys @@ -32,7 +33,7 @@ logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) logger = logging.getLogger("fairseq_cli.train") @@ -229,16 +230,26 @@ def train(args, trainer, task, epoch_itr): def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): num_updates = trainer.get_num_updates() + max_update = args.max_update or math.inf do_save = ( - args.save_interval_updates > 0 - and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates >= args.validate_after_updates - ) or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + or num_updates >= max_update + or ( + args.save_interval_updates > 0 + and num_updates > 0 + and num_updates % args.save_interval_updates == 0 + and num_updates >= args.validate_after_updates + ) + ) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) - or (args.validate_interval_updates > 0 and num_updates > 0 and num_updates % args.validate_interval_updates == 0) + or num_updates >= max_update + or ( + args.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % args.validate_interval_updates == 0 + ) ) and not args.disable_validation # Validate @@ -247,10 +258,9 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc valid_losses = validate(args, trainer, task, epoch_itr, 
valid_subsets) # Stopping conditions - max_update = args.max_update or math.inf should_stop = ( should_stop_early(args, valid_losses[0]) - or trainer.get_num_updates() >= max_update + or num_updates >= max_update or ( args.stop_time_hours > 0 and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index b339a056a0..304aecee9e 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -7,6 +7,7 @@ from itertools import chain import logging +import os import sys import torch @@ -18,7 +19,7 @@ logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S', - level=logging.INFO, + level=os.environ.get('LOGLEVEL', 'INFO').upper(), stream=sys.stdout, ) logger = logging.getLogger('fairseq_cli.validate') From 251c86940600b932f9925d630e1a77b55cc24d02 Mon Sep 17 00:00:00 2001 From: Mu Tian Date: Mon, 31 Aug 2020 23:02:05 -0700 Subject: [PATCH 137/707] hydra fairseq - add yaml files Summary: hydra fairseq - add yaml files Reviewed By: alexeib Differential Revision: D22403786 fbshipit-source-id: 81fb5902c1fbcf7b03d111037327ab0f8bfb57f2 --- config/config.yaml | 7 ++ config/config_eval_lm.yaml | 7 ++ config/criterion/adaptive_loss.yaml | 3 + config/criterion/cross_entropy.yaml | 3 + config/lr_scheduler/cosine.yaml | 7 ++ config/lr_scheduler/inverse_sqrt.yaml | 3 + config/model/transformer_lm.yaml | 36 ++++++ config/model/transformer_lm_baevski_gbw.yaml | 36 ++++++ .../model/transformer_lm_baevski_wiki103.yaml | 36 ++++++ config/model/transformer_lm_big.yaml | 36 ++++++ config/model/transformer_lm_gbw.yaml | 36 ++++++ config/model/transformer_lm_gpt.yaml | 36 ++++++ config/model/transformer_lm_gpt2_big.yaml | 36 ++++++ config/model/transformer_lm_gpt2_medium.yaml | 36 ++++++ config/model/transformer_lm_gpt2_small.yaml | 36 ++++++ config/model/transformer_lm_wiki103.yaml | 36 ++++++ config/optimizer/adam.yaml | 5 + config/optimizer/nag.yaml | 
3 + config/params/eval_lm_params.yaml | 106 ++++++++++++++++++ config/params/training_params.yaml | 96 ++++++++++++++++ config/task/language_modeling.yaml | 10 ++ 21 files changed, 610 insertions(+) create mode 100644 config/config.yaml create mode 100644 config/config_eval_lm.yaml create mode 100644 config/criterion/adaptive_loss.yaml create mode 100644 config/criterion/cross_entropy.yaml create mode 100644 config/lr_scheduler/cosine.yaml create mode 100644 config/lr_scheduler/inverse_sqrt.yaml create mode 100644 config/model/transformer_lm.yaml create mode 100644 config/model/transformer_lm_baevski_gbw.yaml create mode 100644 config/model/transformer_lm_baevski_wiki103.yaml create mode 100644 config/model/transformer_lm_big.yaml create mode 100644 config/model/transformer_lm_gbw.yaml create mode 100644 config/model/transformer_lm_gpt.yaml create mode 100644 config/model/transformer_lm_gpt2_big.yaml create mode 100644 config/model/transformer_lm_gpt2_medium.yaml create mode 100644 config/model/transformer_lm_gpt2_small.yaml create mode 100644 config/model/transformer_lm_wiki103.yaml create mode 100644 config/optimizer/adam.yaml create mode 100644 config/optimizer/nag.yaml create mode 100644 config/params/eval_lm_params.yaml create mode 100644 config/params/training_params.yaml create mode 100644 config/task/language_modeling.yaml diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000000..66723e706c --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,7 @@ +defaults: + - params: training_params + - task: language_modeling + - model: transformer_lm + - criterion: cross_entropy + - optimizer: adam + - lr_scheduler: inverse_sqrt diff --git a/config/config_eval_lm.yaml b/config/config_eval_lm.yaml new file mode 100644 index 0000000000..5a93cb5d92 --- /dev/null +++ b/config/config_eval_lm.yaml @@ -0,0 +1,7 @@ +defaults: + - params: eval_lm_params + - task: language_modeling + - model: transformer_lm + - criterion: cross_entropy + - 
optimizer: adam + - lr_scheduler: inverse_sqrt diff --git a/config/criterion/adaptive_loss.yaml b/config/criterion/adaptive_loss.yaml new file mode 100644 index 0000000000..a85a7eed1c --- /dev/null +++ b/config/criterion/adaptive_loss.yaml @@ -0,0 +1,3 @@ +# @package _group_ +sentence_avg: ${params.optimization.sentence_avg} +ddp_backend: ${params.distributed_training.ddp_backend} diff --git a/config/criterion/cross_entropy.yaml b/config/criterion/cross_entropy.yaml new file mode 100644 index 0000000000..a85a7eed1c --- /dev/null +++ b/config/criterion/cross_entropy.yaml @@ -0,0 +1,3 @@ +# @package _group_ +sentence_avg: ${params.optimization.sentence_avg} +ddp_backend: ${params.distributed_training.ddp_backend} diff --git a/config/lr_scheduler/cosine.yaml b/config/lr_scheduler/cosine.yaml new file mode 100644 index 0000000000..0f91e0d240 --- /dev/null +++ b/config/lr_scheduler/cosine.yaml @@ -0,0 +1,7 @@ +# @package _group_ +warmup_updates: 0 +warmup_init_lr: -1 +max_lr: 1.0 +t_mult: 1.0 +lr_period_updates: -1 +lr_shrink: 0.1 diff --git a/config/lr_scheduler/inverse_sqrt.yaml b/config/lr_scheduler/inverse_sqrt.yaml new file mode 100644 index 0000000000..0eac7d88eb --- /dev/null +++ b/config/lr_scheduler/inverse_sqrt.yaml @@ -0,0 +1,3 @@ +# @package _group_ +warmup_updates: 4000 +warmup_init_lr: -1 diff --git a/config/model/transformer_lm.yaml b/config/model/transformer_lm.yaml new file mode 100644 index 0000000000..3837ea54e1 --- /dev/null +++ b/config/model/transformer_lm.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.0 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 2048 +decoder_layers: 6 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false 
+share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_baevski_gbw.yaml b/config/model/transformer_lm_baevski_gbw.yaml new file mode 100644 index 0000000000..30b1a4f1e0 --- /dev/null +++ b/config/model/transformer_lm_baevski_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_baevski_wiki103.yaml b/config/model/transformer_lm_baevski_wiki103.yaml new file mode 100644 index 0000000000..1154cfa660 --- 
/dev/null +++ b/config/model/transformer_lm_baevski_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_big.yaml b/config/model/transformer_lm_big.yaml new file mode 100644 index 0000000000..309575310b --- /dev/null +++ b/config/model/transformer_lm_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.0 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 
+char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gbw.yaml b/config/model/transformer_lm_gbw.yaml new file mode 100644 index 0000000000..30b1a4f1e0 --- /dev/null +++ b/config/model/transformer_lm_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt.yaml b/config/model/transformer_lm_gpt.yaml new file mode 100644 index 0000000000..2c6cb7be38 --- /dev/null +++ b/config/model/transformer_lm_gpt.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 768 +decoder_output_dim: 768 
+decoder_input_dim: 768 +decoder_ffn_embed_dim: 3072 +decoder_layers: 12 +decoder_attention_heads: 12 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt2_big.yaml b/config/model/transformer_lm_gpt2_big.yaml new file mode 100644 index 0000000000..a08769a178 --- /dev/null +++ b/config/model/transformer_lm_gpt2_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1600 +decoder_output_dim: 1600 +decoder_input_dim: 1600 +decoder_ffn_embed_dim: 6400 +decoder_layers: 48 +decoder_attention_heads: 25 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null 
+layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt2_medium.yaml b/config/model/transformer_lm_gpt2_medium.yaml new file mode 100644 index 0000000000..64261d793c --- /dev/null +++ b/config/model/transformer_lm_gpt2_medium.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1280 +decoder_output_dim: 1280 +decoder_input_dim: 1280 +decoder_ffn_embed_dim: 5120 +decoder_layers: 36 +decoder_attention_heads: 20 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt2_small.yaml b/config/model/transformer_lm_gpt2_small.yaml new file mode 100644 index 0000000000..702e81f466 --- /dev/null +++ b/config/model/transformer_lm_gpt2_small.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 24 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: 
null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_wiki103.yaml b/config/model/transformer_lm_wiki103.yaml new file mode 100644 index 0000000000..1154cfa660 --- /dev/null +++ b/config/model/transformer_lm_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/optimizer/adam.yaml 
b/config/optimizer/adam.yaml new file mode 100644 index 0000000000..e5264f895e --- /dev/null +++ b/config/optimizer/adam.yaml @@ -0,0 +1,5 @@ +# @package _group_ +adam_betas: "(0.9, 0.999)" +adam_eps: 1.0e-8 +weight_decay: 0 +use_old_adam: false diff --git a/config/optimizer/nag.yaml b/config/optimizer/nag.yaml new file mode 100644 index 0000000000..4ab2745686 --- /dev/null +++ b/config/optimizer/nag.yaml @@ -0,0 +1,3 @@ +# @package _group_ +momentum: 0.99 +weight_decay: 0.0 diff --git a/config/params/eval_lm_params.yaml b/config/params/eval_lm_params.yaml new file mode 100644 index 0000000000..4a0259bca6 --- /dev/null +++ b/config/params/eval_lm_params.yaml @@ -0,0 +1,106 @@ +# @package _group_ +common: + no_progress_bar: false + log_interval: 100 + log_format: null + tensorboard_logdir: null + seed: 1 + cpu: false + fp16: false + memory_efficient_fp16: false + fp16_no_flatten_grads: false + fp16_init_scale: 128 + fp16_scale_window: null + fp16_scale_tolerance: 0.0 + min_loss_scale: 1.0e-4 + threshold_loss_scale: null + user_dir: null + empty_cache_freq: 0 + all_gather_list_size: 16384 + model_parallel_size: 1 + checkpoint_suffix: "" + quantization_config_path: null +distributed_training: + distributed_rank: 0 + distributed_backend: "nccl" + distributed_init_method: null + distributed_port: -1 + device_id: 0 + local_rank: 0 + distributed_no_spawn: false + ddp_backend: "c10d" + bucket_cap_mb: 25 + fix_batches_to_gpus: false + find_unused_parameters: false + fast_stat_sync: false + broadcast_buffers: false + distributed_wrapper: "DDP" + slowmo_momentum: null + slowmo_algorithm: "LocalSGD" + localsgd_frequency: 3 +dataset: + num_workers: 1 + skip_invalid_size_inputs_valid_test: false + max_tokens: null + max_sentences: null + batch_size: ${params.dataset.max_sentences} + required_batch_size_multiple: 8 + dataset_impl: null + data_buffer_size: 10 + train_subset: "train" + valid_subset: "valid" + validate_interval: 1 + fixed_validation_seed: null + disable_validation: 
false + curriculum: 0 + gen_subset: "test" + num_shards: 1 + shard_id: 0 + max_tokens_valid: ${params.dataset.max_tokens} + max_sentences_valid: ${params.dataset.max_sentences} +optimization: + max_epoch: 0 + max_update: 0 + clip_norm: 25.0 + sentence_avg: false + update_freq: [1] + lr: [0.25] + min_lr: -1.0 + use_bmuf: false +checkpoint: + save_dir: "checkpoints" + restore_file: "checkpoint_last.pt" + reset_dataloader: false + reset_lr_scheduler: false + reset_meters: false + reset_optimizer: false + optimizer_overrides: "{}" + save_interval: 1 + save_interval_updates: 0 + keep_interval_updates: -1 + keep_last_epochs: -1 + keep_best_checkpoints: -1 + no_save: false + no_epoch_checkpoints: false + no_last_checkpoints: false + no_save_optimizer_state: false + best_checkpoint_metric: "loss" + maximize_best_checkpoint_metric: false + patience: -1 +common_eval: + path: null + remove_bpe: null + quiet: false + model_overrides: "{}" + results_path: null +eval_lm: + output_word_probs: false + output_word_stats: false + context_window: 0 +bmuf: + block_lr: 1 + block_momentum: 0.875 + global_sync_iter: 50 + warmup_iterations: 500 + use_nbm: false + average_sync: false diff --git a/config/params/training_params.yaml b/config/params/training_params.yaml new file mode 100644 index 0000000000..3d52a82ac4 --- /dev/null +++ b/config/params/training_params.yaml @@ -0,0 +1,96 @@ +# @package _group_ +common: + no_progress_bar: false + log_interval: 100 + log_format: null + tensorboard_logdir: null + seed: 1 + cpu: false + fp16: false + memory_efficient_fp16: false + fp16_no_flatten_grads: false + fp16_init_scale: 128 + fp16_scale_window: null + fp16_scale_tolerance: 0.0 + min_loss_scale: 1.0e-4 + threshold_loss_scale: null + user_dir: null + empty_cache_freq: 0 + all_gather_list_size: 16384 + model_parallel_size: 1 + checkpoint_suffix: "" + quantization_config_path: null +distributed_training: + distributed_rank: 0 + distributed_backend: "nccl" + distributed_init_method: null + 
distributed_port: -1 + device_id: 0 + local_rank: 0 + distributed_no_spawn: false + ddp_backend: "c10d" + bucket_cap_mb: 25 + fix_batches_to_gpus: false + find_unused_parameters: false + fast_stat_sync: false + broadcast_buffers: false + distributed_wrapper: "DDP" + slowmo_momentum: null + slowmo_algorithm: "LocalSGD" + localsgd_frequency: 3 +dataset: + num_workers: 1 + skip_invalid_size_inputs_valid_test: false + max_tokens: null + max_sentences: null + batch_size: ${params.dataset.max_sentences} + required_batch_size_multiple: 8 + dataset_impl: null + data_buffer_size: 10 + train_subset: "train" + valid_subset: "valid" + validate_interval: 1 + fixed_validation_seed: null + disable_validation: false + curriculum: 0 + gen_subset: "test" + num_shards: 1 + shard_id: 0 + max_tokens_valid: ${params.dataset.max_tokens} + max_sentences_valid: ${params.dataset.max_sentences} +optimization: + max_epoch: 0 + max_update: 0 + clip_norm: 25.0 + sentence_avg: false + update_freq: [1] + lr: [0.25] + min_lr: -1.0 + use_bmuf: false +checkpoint: + save_dir: "checkpoints" + restore_file: "checkpoint_last.pt" + reset_dataloader: false + reset_lr_scheduler: false + reset_meters: false + reset_optimizer: false + optimizer_overrides: "{}" + save_interval: 1 + save_interval_updates: 0 + keep_interval_updates: -1 + keep_last_epochs: -1 + keep_best_checkpoints: -1 + no_save: false + no_epoch_checkpoints: false + no_last_checkpoints: false + no_save_optimizer_state: false + best_checkpoint_metric: "loss" + maximize_best_checkpoint_metric: false + patience: -1 +bmuf: + block_lr: 1 + block_momentum: 0.875 + global_sync_iter: 50 + warmup_iterations: 500 + use_nbm: false + average_sync: false diff --git a/config/task/language_modeling.yaml b/config/task/language_modeling.yaml new file mode 100644 index 0000000000..58a2ad1358 --- /dev/null +++ b/config/task/language_modeling.yaml @@ -0,0 +1,10 @@ +# @package _group_ +data: ??? 
+sample_break_mode: "none" +tokens_per_sample: 1024 +output_dictionary_size: -1 +self_target: false +future_target: false +past_target: false +add_bos_token: false +max_target_positions: null From 5d7ed6ab4f92d20ad10f8f792b8703e260a938ac Mon Sep 17 00:00:00 2001 From: Mandeep Singh Baines Date: Tue, 1 Sep 2020 18:16:15 -0700 Subject: [PATCH 138/707] Initial support for ZeRO optimizer state sharding (#1259) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: FairseqOSS will work with any optimizer and dtype. TODO(future PR): * support reduce instead of all_reduce * support gradient sharding * support parameter sharding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1259 Test Plan: Verified that checkpoint save and restore work. Verified that grad_norm, loss, and ppl are identical with and without sharding enable. Before: $ fairseq-train --task language_modeling data-bin/wikitext-103 --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm --share-decoder-input-output-embed --dropout 0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 --tokens-per-sample 512 --sample-break-mode none --max-tokens 2048 --update-freq 16 --max-update 50000 --memory-efficient-fp16 --no-progress-bar --log-interval 1 --seed 4 --max-epoch 1 --max-update 50 ... 
2020-08-27 22:24:51 | INFO | train_inner | epoch 001: 49 / 394 loss=18.84, ppl=469411, wps=269226, ups=1.03, wpb=262144, bsz=512, num_updates=45, lr=5.72388e-06, gnorm=5.769, loss_scale=8, train_wall=1, wall=68 2020-08-27 22:24:52 | INFO | train_inner | epoch 001: 50 / 394 loss=18.787, ppl=452312, wps=256992, ups=0.98, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=69 2020-08-27 22:24:53 | INFO | train_inner | epoch 001: 51 / 394 loss=18.74, ppl=437735, wps=259178, ups=0.99, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=70 2020-08-27 22:24:54 | INFO | train_inner | epoch 001: 52 / 394 loss=18.683, ppl=420727, wps=257710, ups=0.98, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=71 2020-08-27 22:24:55 | INFO | train_inner | epoch 001: 53 / 394 loss=18.623, ppl=403794, wps=269279, ups=1.03, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=72 2020-08-27 22:24:56 | INFO | train_inner | epoch 001: 54 / 394 loss=18.574, ppl=390255, wps=264616, ups=1.01, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=73 2020-08-27 22:24:56 | INFO | fairseq_cli.train | begin save checkpoint 2020-08-27 22:24:56 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) 2020-08-27 22:24:56 | INFO | train | epoch 001 | loss 19.736 | ppl 873122 | wps 264825 | ups 1.01 | wpb 262144 | bsz 512 | num_updates 50 | lr 6.34875e-06 | gnorm 8.898 | loss_scale 8 | train_wall 66 | wall 73 2020-08-27 22:24:56 | INFO | fairseq_cli.train | done training in 72.2 seconds After: $ fairseq-train --task language_modeling data-bin/wikitext-103 --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm --share-decoder-input-output-embed --dropout 0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 --lr 
0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 --tokens-per-sample 512 --sample-break-mode none --max-tokens 2048 --update-freq 16 --max-update 50000 --memory-efficient-fp16 --no-progress-bar --log-interval 1 --seed 4 --max-epoch 1 --max-update 50 --zero-sharding os ... 2020-08-27 22:22:55 | INFO | train_inner | epoch 001: 49 / 394 loss=18.84, ppl=469411, wps=267663, ups=1.02, wpb=262144, bsz=512, num_updates=45, lr=5.72388e-06, gnorm=5.769, loss_scale=8, train_wall=1, wall=68 2020-08-27 22:22:56 | INFO | train_inner | epoch 001: 50 / 394 loss=18.787, ppl=452312, wps=252797, ups=0.96, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=69 2020-08-27 22:22:57 | INFO | train_inner | epoch 001: 51 / 394 loss=18.74, ppl=437735, wps=267692, ups=1.02, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=70 2020-08-27 22:22:58 | INFO | train_inner | epoch 001: 52 / 394 loss=18.683, ppl=420727, wps=267507, ups=1.02, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=71 2020-08-27 22:22:59 | INFO | train_inner | epoch 001: 53 / 394 loss=18.623, ppl=403794, wps=254410, ups=0.97, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=72 2020-08-27 22:23:00 | INFO | train_inner | epoch 001: 54 / 394 loss=18.574, ppl=390255, wps=268234, ups=1.02, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=73 2020-08-27 22:23:00 | INFO | fairseq_cli.train | begin save checkpoint 2020-08-27 22:23:00 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) 2020-08-27 22:23:00 | INFO | train | epoch 001 | loss 19.736 | ppl 873122 | wps 263570 | ups 1.01 | wpb 262144 | bsz 512 | num_updates 50 | lr 6.34875e-06 | gnorm 8.898 | loss_scale 8 | train_wall 66 | wall 73 2020-08-27 22:23:00 | INFO | fairseq_cli.train | done 
training in 72.3 seconds # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Reviewed By: myleott Differential Revision: D23432082 Pulled By: msbaines fbshipit-source-id: 6a020b25e36a3d9283582b7d89a6a53038e5b181 --- fairseq/checkpoint_utils.py | 7 +++++- fairseq/optim/__init__.py | 2 ++ fairseq/optim/fairseq_optimizer.py | 9 ++++++++ fairseq/optim/fp16_optimizer.py | 35 +++++++++++++++++++----------- fairseq/optim/shard.py | 33 ++++++++++++++++++++++++++++ fairseq/options.py | 4 ++++ fairseq/trainer.py | 17 +++++++++++++++ 7 files changed, 93 insertions(+), 14 deletions(-) create mode 100644 fairseq/optim/shard.py diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 20891b5f30..3b9e6bfd27 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -32,7 +32,12 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): best_function = max if args.maximize_best_checkpoint_metric else min save_checkpoint.best = best_function(val_loss, prev_best) - if args.no_save or not trainer.is_data_parallel_master: + if args.no_save: + return + + trainer.consolidate_optimizer() + + if not trainer.is_data_parallel_master: return def is_better(a, b): diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 273aa5e8f6..dff140d580 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -10,12 +10,14 @@ from fairseq.optim.fairseq_optimizer import FairseqOptimizer from
fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer from fairseq.optim.bmuf import FairseqBMUF # noqa +from fairseq.optim.shard import shard_ __all__ = [ 'FairseqOptimizer', 'FP16Optimizer', 'MemoryEfficientFP16Optimizer', + 'shard_', ] diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index b1b9c76edb..e00a04dd1b 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -28,6 +28,15 @@ def optimizer(self): raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') return self._optimizer + @optimizer.setter + def optimizer(self, optimizer): + """Reset optimizer instance.""" + if not hasattr(self, '_optimizer'): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') + self._optimizer = optimizer + @property def optimizer_config(self): """ diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 960c3a67eb..777d43a713 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -232,6 +232,10 @@ def build_optimizer(cls, args, params): def optimizer(self): return self.fp32_optimizer.optimizer + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + @property def optimizer_config(self): return self.fp32_optimizer.optimizer_config @@ -279,19 +283,20 @@ def load_state_dict(self, state_dict, optimizer_overrides=None): # params are FP16 while the optimizer state is FP32 and we don't want # to cast. A workaround is to manually copy back the original state # after the optimizer has been loaded. 
- groups = self.optimizer.param_groups - saved_groups = state_dict['param_groups'] - id_map = { - old_id: p - for old_id, p in zip( - chain(*(g['params'] for g in saved_groups)), - chain(*(g['params'] for g in groups)) - ) - } - for k, v in state_dict['state'].items(): - if k in id_map: - param = id_map[k] - self.optimizer.state[param] = v + if not getattr(self.optimizer, 'disable_mem_eff_fp16_loading_hack', False): + groups = self.optimizer.param_groups + saved_groups = state_dict['param_groups'] + id_map = { + old_id: p + for old_id, p in zip( + chain(*(g['params'] for g in saved_groups)), + chain(*(g['params'] for g in groups)) + ) + } + for k, v in state_dict['state'].items(): + if k in id_map: + param = id_map[k] + self.optimizer.state[param] = v def backward(self, loss): """Computes the sum of gradients of the given tensor w.r.t. graph leaves. @@ -412,6 +417,10 @@ def build_optimizer(cls, args, params): def optimizer(self): return self.wrapped_optimizer.optimizer + @optimizer.setter + def optimizer(self, optimizer): + self.wrapped_optimizer.optimizer = optimizer + @property def optimizer_config(self): return self.wrapped_optimizer.optimizer_config diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py new file mode 100644 index 0000000000..4f35dbda47 --- /dev/null +++ b/fairseq/optim/shard.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +try: + from fairscale.optim import OSS + _has_fairscale = True +except ImportError: + _has_fairscale = False + + +def shard_(args, optimizer): + if not _has_fairscale: + raise ImportError( + '\n\nPlease install the fairscale package:' + '\n\n pip install fairscale' + ) + + class FairseqOSS(OSS): + @property + def disable_mem_eff_fp16_loading_hack(self): + return True + + def __getattr__(self, name): + if name.startswith("supports") and hasattr(self.optim, name): + return getattr(self.optim, name) + raise AttributeError("'FairseqOSS' object has no attribute {0!r}".format(name)) + + torch_optimizer = optimizer.optimizer + optim_cls = type(torch_optimizer) + optimizer.optimizer = FairseqOSS(torch_optimizer.param_groups, optim_cls, **optimizer.optimizer_config) diff --git a/fairseq/options.py b/fairseq/options.py index 74c1499e06..01150bda4a 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -448,6 +448,10 @@ def add_distributed_training_args(parser, default_world_size=None): help='number of GPUs in each node. An allreduce operation across GPUs in ' 'a node is very fast. Hence, we do allreduce across GPUs in a node, ' 'and gossip across different nodes') + # Add argument for ZeRO sharding of OptimizerState(os), gradients(g) and parameters(p) + group.add_argument('--zero-sharding', default='none', type=str, + choices=['none', 'os'], + help='ZeRO sharding') # fmt: on return group diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 6cd73a631a..5022ceea2d 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -214,11 +214,28 @@ def _build_optimizer(self): if self.args.use_bmuf: self._optimizer = optim.FairseqBMUF(self.args, self._optimizer) + if self.args.zero_sharding == 'os': + if (self.args.fp16 + and not self.args.memory_efficient_fp16 + and not self.args.memory_efficient_bf16 + ) and not self.args.fp16_no_flatten_grads: + raise ValueError( + "ZeRO is incomptabile with fp16 and flattened grads. 
" + "Please use --fp16-no-flatten-grads" + ) + else: + optim.shard_(self.args, self._optimizer) + # We should initialize the learning rate scheduler immediately after # building the optimizer, so that the initial learning rate is set. self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) self._lr_scheduler.step_update(0) + def consolidate_optimizer(self): + """For OSS, we need to consolidate the state dict.""" + if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): + self.optimizer.optimizer.consolidate_state_dict() + def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" if self.is_data_parallel_master: # only save one checkpoint From d8e2cf28e22584a3e7606783e7abf38135a05eff Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 3 Sep 2020 20:36:09 -0700 Subject: [PATCH 139/707] wav2vec 2.0 fixes (#1266) Summary: Fixes #2563 (input quantizer) Fixes #2538 (documentation for fine-tuning) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1266 Reviewed By: aconneau, vineelpratap Differential Revision: D23521820 Pulled By: alexeib fbshipit-source-id: 61f2c9baf126554dcf5b7a315e4f4f54577c24bf --- examples/wav2vec/README.md | 5 ++++ fairseq/models/wav2vec/wav2vec2.py | 47 ++++++++++++++++++------------ 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 2e59798ead..d849dde85d 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -116,6 +116,11 @@ Note: you can simulate 24 GPUs by using k GPUs and setting --update-freq 24/k Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). Alternatively, simply omit the --wer-args flag. +For hyper-parameters to fine-tune other Librispeech splits (10 minutes, 1 hour, etc) please refer to the table in Appendix B in the wav2vec 2.0 paper. 
+The main changes to make are adjusting --max-update, and then adjusting --warmup-steps, --hold-steps, and --decay steps so that they use 0.1/0.4/0.5 of max-update respectively. You then need to adjust --mask-prob and --mask-channel-prob. This should be set to the mask-length * x where x is the number in the table and mask-length is what you use for --mask-length (10 in this example. Use --mask-channel-length value for --mask-channel-prob). + +For example, for 10 hours, we see in the paper that timestep mask prob should be 0.065, so we set --mask-prob to 10* 0.065 = 0.65. channel mask prob is 0.004, so we set it to 64 * 0.004 = 0.256. then we set --max-updates to 20000 and change --warmup-steps to 20000 * 0.1 = 2000, --hold-steps to 8000 and --decay-steps to 10000. + ### Evaluating a CTC model: Evaluating a CTC model with a language model requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings) to be installed. diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index ea6f901020..4f1ab2277f 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -131,6 +131,12 @@ def add_args(parser): "--quantize-input", action="store_true", help="use quantized inputs" ) + parser.add_argument( + "--same-quantizer", + action="store_true", + help="use same quantizer for inputs and targets", + ) + parser.add_argument( "--feature-grad-mult", type=float, @@ -342,23 +348,6 @@ def __init__(self, args): self.logit_temp = args.logit_temp - if args.quantize_input: - vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim - self.input_quantizer = ( - GumbelVectorQuantizer( - dim=args.encoder_embed_dim, - num_vars=args.latent_vars, - temp=eval(args.latent_temp), - groups=args.latent_groups, - combine_groups=False, - vq_dim=vq_dim, - time_first=True, - ) - if not args.same_quantizer - else self.quantizer - ) - self.project_inp = nn.Linear(vq_dim, 
args.encoder_embed_dim) - final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim if args.quantize_targets: @@ -376,6 +365,25 @@ def __init__(self, args): else: self.project_q = nn.Linear(self.embed, final_dim) + if args.quantize_input: + if args.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = ( + args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim + ) + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=args.latent_vars, + temp=eval(args.latent_temp), + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim) + self.mask_emb = nn.Parameter( torch.FloatTensor(args.encoder_embed_dim).uniform_() ) @@ -564,7 +572,9 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): if mask: x, mask_indices = self.apply_mask(features, padding_mask) if mask_indices is not None: - y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1)) + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) else: y = unmasked_features else: @@ -981,6 +991,7 @@ def base_architecture(args): args.quantize_targets = getattr(args, "quantize_targets", False) args.quantize_input = getattr(args, "quantize_input", False) + args.same_quantizer = getattr(args, "same_quantizer", False) args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) From ac730aec61523031016afc8922c04bfc7a7ac42b Mon Sep 17 00:00:00 2001 From: Juan Miguel Pino Date: Thu, 3 Sep 2020 21:17:41 -0700 Subject: [PATCH 140/707] Remove `BeamContainer` from sequence generator (#2567) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2567 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1263 The PyTorch Mobile liter interpreter does not support 
object creation (CREATE_OBJECT). I tried using `sorted` and `lambda` to replace the sorting code but TorchScript does not support lambda functions. This change makes it possible to save a scripted model for the lite interpreter. I can do as much testing as possible and push this as a core change, which would benefit on-device MT and on-device sequence modeling in general. Or I'm happy to make this an experimental change with an `fb_` file. Reviewed By: myleott, jhcross Differential Revision: D23440771 fbshipit-source-id: c6e5381159857f613b143935b806c5f89464a33b --- fairseq/sequence_generator.py | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 6cc4a13e20..6cfcc90baf 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -491,16 +491,10 @@ def _generate( # sort by score descending for sent in range(len(finalized)): - # make into beam container - BCList = [ - BeamContainer(elem["score"].item(), elem) for elem in finalized[sent] - ] - BCList.sort() - BCList.reverse() - finalized[sent] = torch.jit.annotate( - List[Dict[str, Tensor]], [x.elem for x in BCList] - ) - + scores = torch.tensor([float(elem["score"].item()) for elem in finalized[sent]]) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], finalized[sent]) return finalized def _prefix_tokens( @@ -966,17 +960,3 @@ def forward_align(self, src_tokens, src_lengths, prev_output_tokens): if len(self.models) > 1: avg_attn.div_(len(self.models)) return avg_attn - - -@torch.jit.script -class BeamContainer(object): - def __init__(self, score: float, elem: Dict[str, Tensor]): - self.score = score - self.elem = elem - - def __lt__(self, other): - # type: (BeamContainer) -> bool - # Due to https://github.com/pytorch/pytorch/issues/20388, - # this
has to use old style type annotations - # Match original behavior of sorted function when two scores are equal. - return self.score <= other.score From 7f4c7481a54a9d1da0a3c65084cb79c70196ddd2 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Fri, 4 Sep 2020 04:45:03 -0700 Subject: [PATCH 141/707] Apply Black auto-formatter to multilingual_data_manager.py Reviewed By: jmp84 Differential Revision: D23496278 fbshipit-source-id: 60b77a17a227e3a6e547600fabc0cd49c4f8543d --- .../multilingual/multilingual_data_manager.py | 814 +++++++++++------- 1 file changed, 509 insertions(+), 305 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 0d02ac1e0a..01ba6ece8f 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -4,53 +4,50 @@ # LICENSE file in the root directory of this source tree. import itertools +import json import logging import os - -import numpy as np from collections import OrderedDict -import json +import numpy as np from fairseq import options, utils -from fairseq.options import eval_str_dict, csv_str_list - from fairseq.data import ( - Dictionary, AppendTokenDataset, ConcatDataset, - data_utils, - indexed_dataset, + Dictionary, LanguagePairDataset, PrependTokenDataset, - StripTokenDataset, - TruncateDataset, SampledMultiDataset, - TransformEosLangPairDataset, SampledMultiEpochDataset, + StripTokenDataset, + TransformEosLangPairDataset, + TruncateDataset, + data_utils, + indexed_dataset, ) from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat from fairseq.file_io import PathManager +from fairseq.options import csv_str_list, eval_str_dict + logger = logging.getLogger(__name__) -def _lang_token(lang: str, style='__{}__'): +def _lang_token(lang: str, style="__{}__"): return style.format(lang) -def _lang_token_index(dic: Dictionary, lang: str, style='__{}__'): +def 
_lang_token_index(dic: Dictionary, lang: str, style="__{}__"): """Return language token index.""" idx = dic.index(_lang_token(lang, style)) - assert idx != dic.unk_index, \ - 'cannot find language token for lang {}'.format(lang) + assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang) return idx def _lang_id(dic: Dictionary, lang: str): """Return language ID index.""" idx = dic.index(lang) - assert idx != dic.unk_index, \ - 'cannot find language ID for lang {}'.format(lang) + assert idx != dic.unk_index, "cannot find language ID for lang {}".format(lang) return idx @@ -76,121 +73,215 @@ def __init__(self, args, lang_pairs, langs, dicts, sampling_method): @classmethod def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): - return MultilingualDatasetManager(args, lang_pairs, langs, dicts, sampling_method) + return MultilingualDatasetManager( + args, lang_pairs, langs, dicts, sampling_method + ) @staticmethod def add_args(parser): - parser.add_argument('data', help='colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner') - parser.add_argument('--langs', default=None, type=csv_str_list, - help='a list of languages comma sperated languages which can appear in lang-pairs; ' - 'note that the ordering determines language token IDs', - ) - parser.add_argument('--lang-dict', default=None, type=str, - help='an external file which contains a list of ' - 'languages which can appear in lang-pairs; ' - 'note that the ordering determines language token IDs; ' - '--langs and --lang-dict are two exclusive options') - parser.add_argument('--lang-tok-style', default='multilingual', - type=str, choices=['multilingual', 'mbart'], - help='language token styles') - - parser.add_argument('--load-alignments', action='store_true', - help='load the binarized alignments') - parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', - help='pad the source on the 
left') - parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', - help='pad the target on the left') - parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the source sequence') - parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the target sequence') - parser.add_argument('--upsample-primary', default=1, type=int, - help='amount to upsample primary dataset') - parser.add_argument('--truncate-source', action='store_true', default=False, - help='truncate source to max-source-positions') - parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'], - metavar='SRCTGT', - help='prepend to the beginning of source sentence the source or target ' - 'language token. (src/tgt)') - parser.add_argument('--decoder-langtok', action='store_true', - help='prepend to the beginning of target sentence the target language token') - parser.add_argument('--lang-tok-replacing-bos-eos', action='store_true', default=False) - parser.add_argument('--enable-lang-ids', default=False, action='store_true', - help='whether to include language IDs in samples') - parser.add_argument('--enable-reservsed-directions-shared-datasets', default=False, action='store_true', - help='whether to allow datasets be used in reversed directions') - - parser.add_argument('--extra-data', help='a dictionary of data name to this path, \ + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--langs", + default=None, + type=csv_str_list, + help="a list of languages comma sperated languages which can appear in lang-pairs; " + "note that the ordering determines language token IDs", + ) + parser.add_argument( + "--lang-dict", + default=None, + type=str, + help="an external file which contains a list of " + 
"languages which can appear in lang-pairs; " + "note that the ordering determines language token IDs; " + "--langs and --lang-dict are two exclusive options", + ) + parser.add_argument( + "--lang-tok-style", + default="multilingual", + type=str, + choices=["multilingual", "mbart"], + help="language token styles", + ) + + parser.add_argument( + "--load-alignments", + action="store_true", + help="load the binarized alignments", + ) + parser.add_argument( + "--left-pad-source", + default="True", + type=str, + metavar="BOOL", + help="pad the source on the left", + ) + parser.add_argument( + "--left-pad-target", + default="False", + type=str, + metavar="BOOL", + help="pad the target on the left", + ) + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + parser.add_argument( + "--upsample-primary", + default=1, + type=int, + help="amount to upsample primary dataset", + ) + parser.add_argument( + "--truncate-source", + action="store_true", + default=False, + help="truncate source to max-source-positions", + ) + parser.add_argument( + "--encoder-langtok", + default=None, + type=str, + choices=["src", "tgt"], + metavar="SRCTGT", + help="prepend to the beginning of source sentence the source or target " + "language token. 
(src/tgt)", + ) + parser.add_argument( + "--decoder-langtok", + action="store_true", + help="prepend to the beginning of target sentence the target language token", + ) + parser.add_argument( + "--lang-tok-replacing-bos-eos", action="store_true", default=False + ) + parser.add_argument( + "--enable-lang-ids", + default=False, + action="store_true", + help="whether to include language IDs in samples", + ) + parser.add_argument( + "--enable-reservsed-directions-shared-datasets", + default=False, + action="store_true", + help="whether to allow datasets be used in reversed directions", + ) + + parser.add_argument( + "--extra-data", + help='a dictionary of data name to this path, \ e.g. {"mined", path_to_mined_data, "denoised": path_to_denoised_data}', - type=lambda uf: eval_str_dict(uf, type=str), - default=None) - parser.add_argument('--extra-lang-pairs', help='a dictionary of data name to the language pairs they serve, \ + type=lambda uf: eval_str_dict(uf, type=str), + default=None, + ) + parser.add_argument( + "--extra-lang-pairs", + help='a dictionary of data name to the language pairs they serve, \ e.g. {"mined": comma-separated-lang-pairs, "denoised": comma-separated-lang-pairs}', - type=lambda uf: eval_str_dict(uf, type=str), - default=None) - parser.add_argument('--langtoks-specs', - help='a list of comma separated data types that a set of language tokens to be specialized for, \ + type=lambda uf: eval_str_dict(uf, type=str), + default=None, + ) + parser.add_argument( + "--langtoks-specs", + help='a list of comma separated data types that a set of language tokens to be specialized for, \ e.g. "main,dae,mined". There will be a set of language tokens added to the vocab to \ distinguish languages in different training data types. 
If not specified, default language \ tokens per languages will be added', - default='main', - type=csv_str_list, - ) - parser.add_argument('--langtoks', help='a dictionary of how to add language tokens, \ + default="main", + type=csv_str_list, + ) + parser.add_argument( + "--langtoks", + help='a dictionary of how to add language tokens, \ e.g. {"mined": (None, "tgt"), "mono_dae": ("src.dae", "tgt"), "main": \ ("src", "tgt")}, or {"mined": ("src.mined", "tgt")}', - default=None, - type=lambda uf: eval_str_dict(uf, type=str), - ) - parser.add_argument('--sampling-weights-from-file', - help='a file contain a python dictionary of how to sample data sets, \ + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument( + "--sampling-weights-from-file", + help='a file contain a python dictionary of how to sample data sets, \ e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ "mono_dae:es_XX-es_XX: 0.3, "main:en_xx-fr_XX": 0.8 }', - default=None, type=str, - ) - parser.add_argument('--sampling-weights', help='a dictionary of how to sample data sets, \ + default=None, + type=str, + ) + parser.add_argument( + "--sampling-weights", + help='a dictionary of how to sample data sets, \ e.g. 
{ "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ "mono_dae:es_XX-es_XX: 0.3, "main:en_xx-fr_XX": 0.8 }', - default=None, - type=lambda uf: eval_str_dict(uf, type=str), - ) - parser.add_argument('--virtual-epoch-size', default=1000000, type=int, - help='virtual epoch size to speed up data loading') - parser.add_argument('--virtual-data-size', default=None, type=int, - help='virtual data size of the whole joint dataset to speed' - 'up data loading and have specific dynamic sampling strategy interval') + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument( + "--virtual-epoch-size", + default=1000000, + type=int, + help="virtual epoch size to speed up data loading", + ) + parser.add_argument( + "--virtual-data-size", + default=None, + type=int, + help="virtual data size of the whole joint dataset to speed" + "up data loading and have specific dynamic sampling strategy interval", + ) @classmethod def load_langs(cls, args, **kwargs): if args.lang_dict and args.langs: - raise ValueError('--langs and --lang-dict can not both be specified') + raise ValueError("--langs and --lang-dict can not both be specified") if args.lang_dict is None and args.langs is None: logger.warning( - 'External language dictionary is not provided; ' - 'use lang-pairs to infer the set of supported languages. ' - 'The language ordering is not stable which might cause ' - 'misalignment in pretraining and finetuning.') + "External language dictionary is not provided; " + "use lang-pairs to infer the set of supported languages. " + "The language ordering is not stable which might cause " + "misalignment in pretraining and finetuning." 
+ ) # infer from lang_pairs as it is - langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}) + langs = list( + {x for lang_pair in args.lang_pairs for x in lang_pair.split("-")} + ) langs = sorted(langs) - logger.info(f'inferred language list: {langs}') + logger.info(f"inferred language list: {langs}") elif args.lang_dict: with PathManager.open(args.lang_dict, "r", encoding="utf-8") as f: langs = [lang.strip() for lang in f.readlines() if lang.strip()] - logger.info(f'loaded language list from {args.lang_dict} as they are ordered in file') + logger.info( + f"loaded language list from {args.lang_dict} as they are ordered in file" + ) elif args.langs: langs = args.langs - logger.info(f'parsed the language list as they are ordered in the option: {langs}') + logger.info( + f"parsed the language list as they are ordered in the option: {langs}" + ) return langs def has_sharded_data(self, split): - return self._has_sharded_data and split == getattr(self.args, "train_subset", None) + return self._has_sharded_data and split == getattr( + self.args, "train_subset", None + ) def _shared_collater(self): - return ( - not (self.args.extra_data and 'mono_dae' in self.args.extra_data) - and (not self.args.lang_tok_replacing_bos_eos) + return not (self.args.extra_data and "mono_dae" in self.args.extra_data) and ( + not self.args.lang_tok_replacing_bos_eos ) @classmethod @@ -198,28 +289,32 @@ def prepare(cls, load_dictionary, args, **kargs): args.left_pad_source = options.eval_bool(args.left_pad_source) args.left_pad_target = options.eval_bool(args.left_pad_target) - if not hasattr(args, 'shuffle_instance'): + if not hasattr(args, "shuffle_instance"): args.shuffle_instance = False if args.langtoks is None: args.langtoks = {} - if 'main' not in args.langtoks: + if "main" not in args.langtoks: src_langtok_spec = args.encoder_langtok if args.encoder_langtok else None - tgt_langtok_spec = 'tgt' if args.decoder_langtok else None - args.langtoks['main'] = 
(src_langtok_spec, tgt_langtok_spec) + tgt_langtok_spec = "tgt" if args.decoder_langtok else None + args.langtoks["main"] = (src_langtok_spec, tgt_langtok_spec) def check_langs(langs, pairs): messages = [] for src, tgt in pairs: if src not in langs or tgt not in langs: - messages.append(f'language pair {src}-{tgt} contains languages ' - 'that are not in the language dictionary') + messages.append( + f"language pair {src}-{tgt} contains languages " + "that are not in the language dictionary" + ) if len(messages) > 0: - raise ValueError(' '.join(messages) + f"; langs: {langs}") + raise ValueError(" ".join(messages) + f"; langs: {langs}") if args.lang_pairs is None: - raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.') + raise ValueError( + "--lang-pairs is required. List all the language pairs in the training objective." + ) if isinstance(args.lang_pairs, str): - args.lang_pairs = args.lang_pairs.split(',') + args.lang_pairs = args.lang_pairs.split(",") if args.source_lang is not None or args.target_lang is not None: training = False else: @@ -227,17 +322,25 @@ def check_langs(langs, pairs): sorted_langs = cls.load_langs(args, **kargs) check_langs( sorted_langs, - ([p.split('-') for p in args.lang_pairs] if training - else [(args.source_lang, args.target_lang)]) + ( + [p.split("-") for p in args.lang_pairs] + if training + else [(args.source_lang, args.target_lang)] + ), ) # load dictionaries if training: extra_lang_pairs = ( - list({p for _, v in args.extra_lang_pairs.items() for p in v.split(',')}) - if args.extra_lang_pairs else [] + list( + {p for _, v in args.extra_lang_pairs.items() for p in v.split(",")} + ) + if args.extra_lang_pairs + else [] + ) + langs_to_load_dicts = sorted( + {x for p in args.lang_pairs + extra_lang_pairs for x in p.split("-")} ) - langs_to_load_dicts = sorted({x for p in args.lang_pairs + extra_lang_pairs for x in p.split('-')}) else: langs_to_load_dicts = sorted([args.source_lang, 
args.target_lang]) @@ -246,7 +349,9 @@ def check_langs(langs, pairs): for lang in langs_to_load_dicts: paths = utils.split_paths(args.data) assert len(paths) > 0 - dicts[lang] = load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) + dicts[lang] = load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(lang)) + ) if len(dicts) > 0: assert dicts[lang].pad() == dicts[langs_to_load_dicts[0]].pad() assert dicts[lang].eos() == dicts[langs_to_load_dicts[0]].eos() @@ -259,26 +364,20 @@ def check_langs(langs, pairs): dicts[lang].add_symbol( MultilingualDatasetManager.get_lang_tok(lang_to_add, args, spec) ) - if args.lang_tok_style == 'mbart' or (args.extra_data and 'mono_dae' in args.extra_data): - dicts[lang].add_symbol('<mask>') - logger.info('[{}] dictionary: {} types'.format(lang, len(dicts[lang]))) + if args.lang_tok_style == "mbart" or ( + args.extra_data and "mono_dae" in args.extra_data + ): + dicts[lang].add_symbol("<mask>") + logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) return sorted_langs, dicts, training - TOKEN_STYLES = { - 'mbart': '[{}]', - 'multilingual': '__{}__' - } + TOKEN_STYLES = {"mbart": "[{}]", "multilingual": "__{}__"} @classmethod def create_lang_dictionary(cls, langs): - unk = '<unk>' + unk = "<unk>" # hack to remove symbols other than unk as they are not needed by lang dict - lang_dict = Dictionary( - pad=unk, - eos=unk, - unk=unk, - bos=unk, - ) + lang_dict = Dictionary(pad=unk, eos=unk, unk=unk, bos=unk) for lang in langs: lang_dict.add_symbol(lang) return lang_dict @@ -288,26 +387,27 @@ def get_lang_tok_style(cls, args): return cls.TOKEN_STYLES[args.lang_tok_style] @classmethod - def get_lang_tok(cls, lang, args, spec=''): + def get_lang_tok(cls, lang, args, spec=""): if spec is None: return None - if spec.endswith('dae'): - lang = f'{lang}_dae' - elif spec.endswith('mined'): - lang = f'{lang}_mined' + if spec.endswith("dae"): + lang = f"{lang}_dae" + elif spec.endswith("mined"): + lang = f"{lang}_mined" return
_lang_token(lang, cls.get_lang_tok_style(args)) @classmethod def get_langtok_index(cls, lang_tok, dic): idx = dic.index(lang_tok) - assert idx != dic.unk_index, \ - 'cannot find language token {} in the dictionary'.format(lang_tok) + assert ( + idx != dic.unk_index + ), "cannot find language token {} in the dictionary".format(lang_tok) return idx def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): if spec is None: return None - if spec and spec.startswith('src'): + if spec and spec.startswith("src"): if src_lang is None: return None langtok = self.get_lang_tok(src_lang, self.args, spec) @@ -315,7 +415,9 @@ def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): if tgt_lang is None: return None langtok = self.get_lang_tok(tgt_lang, self.args, spec) - return self.get_langtok_index(langtok, self.dicts[src_lang if src_lang else tgt_lang]) + return self.get_langtok_index( + langtok, self.dicts[src_lang if src_lang else tgt_lang] + ) def get_decoder_langtok(self, tgt_lang, spec=None): if spec is None: @@ -330,18 +432,21 @@ def load_data(cls, path, vdict, impl): @classmethod def split_exists(cls, split, src, tgt, lang, data_path, dataset_impl): - filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) + filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) @classmethod def mono_split_exists(cls, split, lang, data_path, dataset_impl): - filename = os.path.join(data_path, '{}.{}'.format(split, lang)) + filename = os.path.join(data_path, "{}.{}".format(split, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) @classmethod def bitext_split_exists(cls, split, src, tgt, data_path, dataset_impl): - src_exists = cls.split_exists(split, src, tgt, lang=src, data_path=data_path, dataset_impl=dataset_impl) \ - or cls.split_exists(split, tgt, src, lang=src, data_path=data_path, dataset_impl=dataset_impl) + src_exists = 
cls.split_exists( + split, src, tgt, lang=src, data_path=data_path, dataset_impl=dataset_impl + ) or cls.split_exists( + split, tgt, src, lang=src, data_path=data_path, dataset_impl=dataset_impl + ) # check source exists to determine shard number # also note that during inference time target is not required # so checking target will fail inference time data loading @@ -350,45 +455,57 @@ def bitext_split_exists(cls, split, src, tgt, data_path, dataset_impl): @classmethod def get_split_num_shards(cls, split, src, tgt, data_paths, dataset_impl): return sum( - 1 for path in data_paths + 1 + for path in data_paths if cls.bitext_split_exists(split, src, tgt, path, dataset_impl) ) @classmethod def get_mono_split_num_shards(cls, split, lang, data_paths, dataset_impl): return sum( - 1 for path in data_paths + 1 + for path in data_paths if cls.mono_split_exists(split, lang, path, dataset_impl) ) def load_lang_dataset( - self, - data_path, split, - src, src_dict, - tgt, tgt_dict, - combine, dataset_impl, upsample_primary, - max_source_positions, - prepend_bos=False, load_alignments=False, - truncate_source=False, + self, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + max_source_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, ): src_datasets = [] tgt_datasets = [] for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') + split_k = split + (str(k) if k > 0 else "") # infer langcode if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: - logger.error(f"Dataset 
not found: {data_path}, {split_k}, {src}, {tgt}") - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) + logger.error( + f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}" + ) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) src_dataset = self.load_data(prefix + src, src_dict, dataset_impl) if truncate_source: @@ -400,13 +517,13 @@ def load_lang_dataset( src_dict.eos(), ) src_datasets.append(src_dataset) - tgt_datasets.append( - self.load_data(prefix + tgt, tgt_dict, dataset_impl) - ) + tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl)) - logger.info('{} {} {}-{} {} examples'.format( - data_path, split_k, src, tgt, len(src_datasets[-1]) - )) + logger.info( + "{} {} {}-{} {} examples".format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + ) + ) if not combine: break @@ -428,20 +545,33 @@ def load_lang_dataset( align_dataset = None if load_alignments: - align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) + align_path = os.path.join( + data_path, "{}.align.{}-{}".format(split, src, tgt) + ) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): - align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) + align_dataset = data_utils.load_indexed_dataset( + align_path, None, dataset_impl + ) return src_dataset, tgt_dataset, align_dataset def load_langpair_dataset( self, - data_path, split, - src, src_dict, - tgt, tgt_dict, - combine, dataset_impl, upsample_primary, - left_pad_source, left_pad_target, max_source_positions, - max_target_positions, prepend_bos=False, load_alignments=False, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos=False, + load_alignments=False, truncate_source=False, src_dataset_transform_func=lambda dataset: dataset, 
tgt_dataset_transform_func=lambda dataset: dataset, @@ -451,46 +581,70 @@ def load_langpair_dataset( ): norm_direction = "-".join(sorted([src, tgt])) if langpairs_sharing_datasets is not None: - src_dataset = langpairs_sharing_datasets.get((data_path, split, norm_direction, src), 'NotInCache') - tgt_dataset = langpairs_sharing_datasets.get((data_path, split, norm_direction, tgt), 'NotInCache') - align_dataset = langpairs_sharing_datasets.get((data_path, split, norm_direction, src, tgt), 'NotInCache') + src_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, src), "NotInCache" + ) + tgt_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, tgt), "NotInCache" + ) + align_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, src, tgt), "NotInCache" + ) # a hack: any one is not in cache, we need to reload them if ( langpairs_sharing_datasets is None - or src_dataset == 'NotInCache' - or tgt_dataset == 'NotInCache' - or align_dataset == 'NotInCache' + or src_dataset == "NotInCache" + or tgt_dataset == "NotInCache" + or align_dataset == "NotInCache" or split != getattr(self.args, "train_subset", None) ): # source and target datasets can be reused in reversed directions to save memory # reversed directions of valid and test data will not share source and target datasets src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset( - data_path, split, - src, src_dict, - tgt, tgt_dict, - combine, dataset_impl, upsample_primary, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, max_source_positions=max_source_positions, - prepend_bos=prepend_bos, load_alignments=load_alignments, + prepend_bos=prepend_bos, + load_alignments=load_alignments, truncate_source=truncate_source, ) src_dataset = src_dataset_transform_func(src_dataset) tgt_dataset = tgt_dataset_transform_func(tgt_dataset) if langpairs_sharing_datasets is not None: - 
langpairs_sharing_datasets[(data_path, split, norm_direction, src)] = src_dataset - langpairs_sharing_datasets[(data_path, split, norm_direction, tgt)] = tgt_dataset - langpairs_sharing_datasets[(data_path, split, norm_direction, src, tgt)] = align_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, src) + ] = src_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, tgt) + ] = tgt_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, src, tgt) + ] = align_dataset if align_dataset is None: # no align data so flag the reverse direction as well in sharing - langpairs_sharing_datasets[(data_path, split, norm_direction, tgt, src)] = align_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, tgt, src) + ] = align_dataset else: - logger.info(f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: " - f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}") + logger.info( + f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: " + f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}" + ) return LanguagePairDataset( - src_dataset, src_dataset.sizes, + src_dataset, + src_dataset.sizes, src_dict, - tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, + tgt_dataset, + tgt_dataset.sizes if tgt_dataset is not None else None, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, @@ -529,17 +683,25 @@ def tgt_dataset_tranform_func(self, source_lang, target_lang, dataset, spec=None return PrependTokenDataset(dataset, tok) return dataset - def alter_dataset_langtok(self, lang_pair_dataset, - src_eos=None, src_lang=None, - tgt_eos=None, tgt_lang=None, - src_langtok_spec=None, tgt_langtok_spec=None, - ): + def alter_dataset_langtok( + self, + lang_pair_dataset, + src_eos=None, + src_lang=None, + tgt_eos=None, + tgt_lang=None, + 
src_langtok_spec=None, + tgt_langtok_spec=None, + ): if src_langtok_spec is None and tgt_langtok_spec is None: return lang_pair_dataset new_src_eos = None - if src_langtok_spec is not None and src_eos is not None \ - and (src_lang is not None or tgt_lang is not None): + if ( + src_langtok_spec is not None + and src_eos is not None + and (src_lang is not None or tgt_lang is not None) + ): new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang, src_langtok_spec) else: src_eos = None @@ -559,16 +721,18 @@ def alter_dataset_langtok(self, lang_pair_dataset, ) def load_a_dataset( - self, - split, - data_path, - src, src_dict, - tgt, tgt_dict, - combine, - prepend_bos=False, - langpairs_sharing_datasets=None, - data_category=None, - **extra_kwargs, + self, + split, + data_path, + src, + src_dict, + tgt, + tgt_dict, + combine, + prepend_bos=False, + langpairs_sharing_datasets=None, + data_category=None, + **extra_kwargs, ): dataset_impl = self.args.dataset_impl upsample_primary = self.args.upsample_primary @@ -582,24 +746,43 @@ def load_a_dataset( tgt_dataset_transform_func = self.tgt_dataset_tranform_func enable_lang_ids = self.args.enable_lang_ids lang_dictionary = self.lang_dict - src_langtok_spec, tgt_langtok_spec = extra_kwargs['langtok_spec'] + src_langtok_spec, tgt_langtok_spec = extra_kwargs["langtok_spec"] src_langtok = self.get_encoder_langtok(src, tgt, src_langtok_spec) tgt_langtok = self.get_decoder_langtok(tgt, tgt_langtok_spec) - logger.info(f'{data_category}:{src}-{tgt} src_langtok: {src_langtok}; tgt_langtok: {tgt_langtok}') + logger.info( + f"{data_category}:{src}-{tgt} src_langtok: {src_langtok}; tgt_langtok: {tgt_langtok}" + ) langpair_ds = self.load_langpair_dataset( - data_path, split, - src, src_dict, - tgt, tgt_dict, - combine, dataset_impl, upsample_primary, - left_pad_source, left_pad_target, max_source_positions, - max_target_positions, prepend_bos, load_alignments, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + 
dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos, + load_alignments, truncate_source, - src_dataset_transform_func=lambda dataset: src_dataset_transform_func(src, tgt, dataset, src_langtok_spec), - tgt_dataset_transform_func=lambda dataset: tgt_dataset_transform_func(src, tgt, dataset, tgt_langtok_spec), - src_lang_id=_lang_id(lang_dictionary, src) if enable_lang_ids and lang_dictionary is not None else None, - tgt_lang_id=_lang_id(lang_dictionary, tgt) if enable_lang_ids and lang_dictionary is not None else None, + src_dataset_transform_func=lambda dataset: src_dataset_transform_func( + src, tgt, dataset, src_langtok_spec + ), + tgt_dataset_transform_func=lambda dataset: tgt_dataset_transform_func( + src, tgt, dataset, tgt_langtok_spec + ), + src_lang_id=_lang_id(lang_dictionary, src) + if enable_lang_ids and lang_dictionary is not None + else None, + tgt_lang_id=_lang_id(lang_dictionary, tgt) + if enable_lang_ids and lang_dictionary is not None + else None, langpairs_sharing_datasets=langpairs_sharing_datasets, ) if langpair_ds.tgt_sizes is None: @@ -607,7 +790,9 @@ def load_a_dataset( langpair_ds.sizes = langpair_ds.src_sizes else: # use the max of two sides to define the size to help max positions filtering - langpair_ds.sizes = np.vstack([langpair_ds.src_sizes, langpair_ds.tgt_sizes]).max(axis=0) + langpair_ds.sizes = np.vstack( + [langpair_ds.src_sizes, langpair_ds.tgt_sizes] + ).max(axis=0) assert langpair_ds.sizes.shape == langpair_ds.src_sizes.shape # TODO: handle modified lang toks for mined data and dae data if self.args.lang_tok_replacing_bos_eos: @@ -624,38 +809,38 @@ def load_a_dataset( ds = langpair_ds return ds - def load_split_langpair_datasets( - self, - split, - data_param_list, - ): + def load_split_langpair_datasets(self, split, data_param_list): datasets = [] - langpairs_sharing_datasets = {} if self.args.enable_reservsed_directions_shared_datasets else None + 
langpairs_sharing_datasets = ( + {} if self.args.enable_reservsed_directions_shared_datasets else None + ) for param in data_param_list: - ds = self.load_a_dataset(split=split, langpairs_sharing_datasets=langpairs_sharing_datasets, **param) + ds = self.load_a_dataset( + split=split, + langpairs_sharing_datasets=langpairs_sharing_datasets, + **param, + ) datasets.append(ds) return datasets def get_data_paths_and_lang_pairs(self, split): - datapaths = { - 'main': self.args.data, - } - lang_pairs = { - 'main': self.lang_pairs - } + datapaths = {"main": self.args.data} + lang_pairs = {"main": self.lang_pairs} if split == getattr(self.args, "train_subset", None): # only training data can have extra data and extra language pairs if self.args.extra_data: extra_datapaths = self.args.extra_data datapaths.update(extra_datapaths) if self.args.extra_lang_pairs: - extra_lang_pairs = {k: v.split(',') for k, v in self.args.extra_lang_pairs.items()} + extra_lang_pairs = { + k: v.split(",") for k, v in self.args.extra_lang_pairs.items() + } lang_pairs.update(extra_lang_pairs) return datapaths, lang_pairs @classmethod def get_dataset_key(cls, data_category, src, tgt): - return f'{data_category}:{src}-{tgt}' + return f"{data_category}:{src}-{tgt}" def get_split_num_data_shards(self, split): if split in self._num_shards_dict: @@ -667,19 +852,24 @@ def get_split_num_data_shards(self, split): if data_category not in lang_pairs: continue paths = utils.split_paths(paths) - lang_dirs = [lang_pair.split('-') for lang_pair in lang_pairs[data_category]] + lang_dirs = [ + lang_pair.split("-") for lang_pair in lang_pairs[data_category] + ] lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] for src, tgt in lang_dirs: # monolingual data ruqires tgt only - assert src is not None or 'mono_' in data_category, (f'error: src={src}, ' - 'tgt={tgt} for data_category={data_category}') + assert src is not None or "mono_" in data_category, ( + f"error: src={src}, " "tgt={tgt} for 
data_category={data_category}" + ) key = self.get_dataset_key(data_category, src, tgt) - if 'mono_' in data_category: + if "mono_" in data_category: num_shards_dict[key] = self.get_mono_split_num_shards( - split, tgt, paths, self.args.dataset_impl) + split, tgt, paths, self.args.dataset_impl + ) else: num_shards_dict[key] = self.get_split_num_shards( - split, src, tgt, paths, self.args.dataset_impl) + split, src, tgt, paths, self.args.dataset_impl + ) self._num_shards_dict[split] = num_shards_dict logger.info(f"[{split}] num of shards: {num_shards_dict}") return num_shards_dict @@ -694,7 +884,7 @@ def get_split_data_param_list(self, split, epoch, shard_epoch=None): # TODO: to extend with extra datasets and keys and loop over different shard data paths param_list = [] data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) - logger.info(f'langtoks settings: {self.args.langtoks}') + logger.info(f"langtoks settings: {self.args.langtoks}") split_num_shards_dict = self.get_split_num_data_shards(split) for data_category, paths in data_paths.items(): if data_category not in lang_pairs: @@ -714,35 +904,48 @@ def get_split_data_param_list(self, split, epoch, shard_epoch=None): lang_tok_spec = (None, None) # infer langcode - lang_dirs = [lang_pair.split('-') for lang_pair in lang_pairs[data_category]] + lang_dirs = [ + lang_pair.split("-") for lang_pair in lang_pairs[data_category] + ] lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] for src, tgt in lang_dirs: - assert src is not None or data_category == 'mono_dae', (f'error: src={src}, ' - 'tgt={tgt} for data_category={data_category}') + assert src is not None or data_category == "mono_dae", ( + f"error: src={src}, " "tgt={tgt} for data_category={data_category}" + ) # logger.info(f"preparing param for {data_category}: {src} - {tgt}") key = self.get_dataset_key(data_category, src, tgt) data_path = self.get_split_data_path( - paths, epoch, shard_epoch, split_num_shards_dict[key]) + paths, epoch, 
shard_epoch, split_num_shards_dict[key] + ) param_list.append( - { - 'key': key, - 'data_path': data_path, - 'split': split, - 'src': src, - 'src_dict': self.dicts[src] if src and data_category != 'mono_dae' else None, - 'tgt': tgt, - 'tgt_dict': self.dicts[tgt], - 'data_category': data_category, - 'langtok_spec': lang_tok_spec, - } + { + "key": key, + "data_path": data_path, + "split": split, + "src": src, + "src_dict": self.dicts[src] + if src and data_category != "mono_dae" + else None, + "tgt": tgt, + "tgt_dict": self.dicts[tgt], + "data_category": data_category, + "langtok_spec": lang_tok_spec, + } ) return param_list def get_train_dataset_sizes(self, data_param_list, datasets): num_shards = [ - self.get_split_num_data_shards(param['split'])[param['key']] for param in data_param_list] - data_sizes = [(key, len(d) * num_shard) for (key, d), num_shard in zip(datasets, num_shards)] - logger.info(f'data sizes multiplied by num_shards used in sampling ratios: {data_sizes}') + self.get_split_num_data_shards(param["split"])[param["key"]] + for param in data_param_list + ] + data_sizes = [ + (key, len(d) * num_shard) + for (key, d), num_shard in zip(datasets, num_shards) + ] + logger.info( + f"data sizes multiplied by num_shards used in sampling ratios: {data_sizes}" + ) return [s for _, s in data_sizes] def get_train_sampling_ratios(self, data_param_list, datasets, epoch=1): @@ -755,37 +958,42 @@ def get_sampling_ratios(self, data_param_list, datasets, epoch): if self.args.sampling_weights_from_file: weights = load_sampling_weights(self.args.sampling_weights_from_file) sample_ratios = [weights[k] for k, _ in datasets] - logger.info('| ignoring --sampling-weights when loadding sampling weights ' - f'from file {self.args.sampling_weights_from_file}') + logger.info( + "| ignoring --sampling-weights when loadding sampling weights " + f"from file {self.args.sampling_weights_from_file}" + ) elif self.args.sampling_weights: sample_ratios = [self.args.sampling_weights[k] for 
k, _ in datasets] else: - sample_ratios = self.get_train_sampling_ratios(data_param_list, datasets, epoch) + sample_ratios = self.get_train_sampling_ratios( + data_param_list, datasets, epoch + ) if sample_ratios is not None: - logger.info('| Upsample ratios: {}'.format( - list(zip(map(lambda x: x['key'], data_param_list), sample_ratios)) - )) + logger.info( + "| Upsample ratios: {}".format( + list(zip(map(lambda x: x["key"], data_param_list), sample_ratios)) + ) + ) assert len(sample_ratios) == len(datasets) return sample_ratios def load_split_datasets( - self, - split, - training, - epoch=1, combine=False, shard_epoch=None, **kwargs, + self, split, training, epoch=1, combine=False, shard_epoch=None, **kwargs ): data_param_list = self.get_split_data_param_list( - split, epoch, shard_epoch=shard_epoch, + split, epoch, shard_epoch=shard_epoch + ) + langpairs_sharing_datasets = ( + {} if self.args.enable_reservsed_directions_shared_datasets else None ) - langpairs_sharing_datasets = {} if self.args.enable_reservsed_directions_shared_datasets else None datasets = [ ( - param['key'], + param["key"], self.load_a_dataset( combine=combine, langpairs_sharing_datasets=langpairs_sharing_datasets, - **param + **param, ), ) for param in data_param_list @@ -793,52 +1001,48 @@ def load_split_datasets( return datasets, data_param_list def load_into_sampled_multi_epoch_dataset( - self, split, datasets, data_param_list, - epoch, shard_epoch=None + self, split, datasets, data_param_list, epoch, shard_epoch=None ): sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) return SampledMultiEpochDataset( - OrderedDict(datasets), - epoch=epoch, - shard_epoch=shard_epoch, - # valid and test datasets will be degerate to concating datasets: - sampling_ratios=sample_ratios, - eval_key=None, - batch_by_size=True, - collate_format=CollateFormat.single, - virtual_size=self.args.virtual_data_size, - split=split, - virtual_epoch_size=self.args.virtual_epoch_size, - # if not 
using lang_tok altering, simplified to use the same collater - shared_collater=self._shared_collater(), + OrderedDict(datasets), + epoch=epoch, + shard_epoch=shard_epoch, + # valid and test datasets will be degerate to concating datasets: + sampling_ratios=sample_ratios, + eval_key=None, + batch_by_size=True, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + virtual_epoch_size=self.args.virtual_epoch_size, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), ) def load_into_concat_dataset(self, split, datasets, data_param_list): if self.args.lang_tok_replacing_bos_eos: # TODO: to investigate why TransformEosLangPairDataset doesn't work with ConcatDataset return SampledMultiDataset( - OrderedDict(datasets), - sampling_ratios=None, - eval_key=None, - batch_by_size=True, - collate_format=CollateFormat.single, - virtual_size=None, - split=split, - ) + OrderedDict(datasets), + sampling_ratios=None, + eval_key=None, + batch_by_size=True, + collate_format=CollateFormat.single, + virtual_size=None, + split=split, + ) return ConcatDataset([d for _, d in datasets]) def load_sampled_multi_epoch_dataset( - self, - split, - training, - epoch=0, combine=False, shard_epoch=None, **kwargs + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs ): datasets, data_param_list = self.load_split_datasets( - split, training, - epoch, combine, shard_epoch=shard_epoch, **kwargs - ) + split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs + ) if training and split == getattr(self.args, "train_subset", None): return self.load_into_sampled_multi_epoch_dataset( - split, datasets, data_param_list, epoch, shard_epoch=shard_epoch) + split, datasets, data_param_list, epoch, shard_epoch=shard_epoch + ) else: return self.load_into_concat_dataset(split, datasets, data_param_list) From e171c8d86a939cf4ebc483cd649bee1935379771 Mon Sep 17 00:00:00 2001 From: 
Alex Xiao Date: Fri, 4 Sep 2020 14:25:14 -0700 Subject: [PATCH 142/707] Account for checkpoint updates when calling take on CountingIterator Summary: Recently some of our runs are getting: "RuntimeError: Mismatch between actual and expected iterable length. Please report this to the fairseq developers." f214567466 We never ran into this before because this is a new check by fairseq to be more strict with iterators. Fix is to: 1. Account for the offset (i.e. load from checkpoint mid epoch) when propagating `take`. This fixes the issue of `next` returning too many things, which is what causes the error. 2. Update the underlying iterator when calling `take` on `BufferedIterator` and the length of the `BufferedIterator`. Although this doesn't cause the error, it is necessary to maintain consistency. Reviewed By: myleott Differential Revision: D23443012 fbshipit-source-id: 73c26db8392e5508a61acfda7ca40a24df89fabb --- fairseq/data/iterators.py | 30 +++++++++++++++++++++--------- tests/test_iterators.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index e902b2fb47..20c76f8398 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -85,10 +85,17 @@ def take(self, n): self.total = min(self.total, n) # Propagate this change to the underlying iterator + # Only take after what we have already consumed (i.e. 
after restarting + # from checkpoint mid epoch, we have to subtract self.n which is the + # starting point) + # + # This is to maintain the invariant self.total = self.n + len(iterable), + # before calling __next__ or __iter__ + propagated_take = max(n - self.n, 0) if hasattr(self.iterable, "take"): - self.iterable.take(n) + self.iterable.take(propagated_take) else: - self.iterable = itertools.islice(self.iterable, n) + self.iterable = itertools.islice(self.iterable, propagated_take) class EpochBatchIterating(object): @@ -468,9 +475,7 @@ def __init__(self, queue, source, max_len): def run(self): try: - self._source_iter = iter(self._source) - for _ in range(len(self._source)): - item = next(self._source_iter) + for item in self._source: self._queue.put(item) # Stop if we reached the maximum length @@ -490,17 +495,18 @@ class BufferedIterator(object): def __init__(self, size, iterable): self._queue = queue.Queue(size) self._iterable = iterable - self.max_len = None self._consumer = None self.start_time = time.time() self.warning_time = None + self.total = len(iterable) + def _create_consumer(self): self._consumer = BackgroundConsumer( self._queue, self._iterable, - self.max_len + self.total, ) self._consumer.daemon = True self._consumer.start() @@ -509,10 +515,16 @@ def __iter__(self): return self def __len__(self): - return len(self._iterable) + return self.total def take(self, n): - self.max_len = n + self.total = min(self.total, n) + + # Propagate this change to the underlying iterator + if hasattr(self._iterable, "take"): + self._iterable.take(n) + else: + self._iterable = itertools.islice(self._iterable, n) def __next__(self): # Create consumer if not created yet diff --git a/tests/test_iterators.py b/tests/test_iterators.py index 7ceef124f5..9e444d154b 100644 --- a/tests/test_iterators.py +++ b/tests/test_iterators.py @@ -85,6 +85,40 @@ def test_counting_iterator_take(self): self.assertEqual(next(itr), ref[4]) self.assertFalse(itr.has_next()) + def
test_counting_iterator_buffered_iterator_take(self): + ref = list(range(10)) + buffered_itr = iterators.BufferedIterator(2, ref) + itr = iterators.CountingIterator(buffered_itr) + itr.take(5) + self.assertEqual(len(itr), len(list(iter(itr)))) + self.assertEqual(len(itr), 5) + + buffered_itr = iterators.BufferedIterator(2, ref) + itr = iterators.CountingIterator(buffered_itr) + itr.take(5) + self.assertEqual(len(buffered_itr), 5) + self.assertEqual(len(list(iter(buffered_itr))), 5) + + buffered_itr = iterators.BufferedIterator(2, ref) + itr = iterators.CountingIterator(buffered_itr) + itr.take(5) + self.assertEqual(next(itr), ref[0]) + self.assertEqual(next(itr), ref[1]) + itr.skip(2) + self.assertEqual(next(itr), ref[4]) + self.assertFalse(itr.has_next()) + self.assertRaises(StopIteration, next, buffered_itr) + + ref = list(range(4,10)) + buffered_itr = iterators.BufferedIterator(2, ref) + itr = iterators.CountingIterator(buffered_itr, start=4) + itr.take(5) + self.assertEqual(len(itr), 5) + self.assertEqual(len(buffered_itr), 1) + self.assertEqual(next(itr), ref[0]) + self.assertFalse(itr.has_next()) + self.assertRaises(StopIteration, next, buffered_itr) + if __name__ == '__main__': unittest.main() From e7f76c44817e9766acca3aacf0d4e8807a6a2d03 Mon Sep 17 00:00:00 2001 From: Mu Tian Date: Fri, 4 Sep 2020 17:05:34 -0700 Subject: [PATCH 143/707] hydra-fairseq - add dataclass Summary: hydra fairseq - add main common dataclasses as structured config Reviewed By: alexeib Differential Revision: D23375458 fbshipit-source-id: 4cb2802e523990d4e2b1a87e3cf1bc4dc852bc5b --- fairseq/dataclass/__init__.py | 0 fairseq/dataclass/data_class.py | 689 ++++++++++++++++++++++++++++++++ fairseq/dataclass/utils.py | 194 +++++++++ setup.py | 2 + 4 files changed, 885 insertions(+) create mode 100644 fairseq/dataclass/__init__.py create mode 100644 fairseq/dataclass/data_class.py create mode 100644 fairseq/dataclass/utils.py diff --git a/fairseq/dataclass/__init__.py 
b/fairseq/dataclass/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py new file mode 100644 index 0000000000..5c8a25dfa2 --- /dev/null +++ b/fairseq/dataclass/data_class.py @@ -0,0 +1,689 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Any, List, Optional, Dict, Type, Tuple +import torch +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass.utils import FairseqDataclass, ChoiceEnum +import sys +from fairseq.tasks import TASK_DATACLASS_REGISTRY +from fairseq.models import ARCH_DATACLASS_REGISTRY +from fairseq.criterions import CRITERION_DATACLASS_REGISTRY +from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY +from fairseq.optim.bmuf import FairseqBMUFConfig +from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY +from hydra.core.config_store import ConfigStore +from argparse import Namespace + + +@dataclass +class CommonParams(FairseqDataclass): + # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were + # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. 
+ no_progress_bar: bool = field( + default=False, metadata={"help": "disable progress bar"} + ) + log_interval: int = field( + default=100, + metadata={ + "help": "log progress every N batches (when progress bar is disabled)" + }, + ) + log_format: Optional[ChoiceEnum(["json", "none", "simple", "tqdm"])] = field( + default=None, metadata={"help": "log format to use"} + ) + tensorboard_logdir: Optional[str] = field( + default=None, + metadata={ + "help": "path to save logs for tensorboard, should match --logdir " + "of running tensorboard (default: no tensorboard logging)" + }, + ) + seed: int = field( + default=1, metadata={"help": "pseudo random number generator seed"} + ) + cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"}) + tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"}) + bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"}) + memory_efficient_bf16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of BF16 training; implies --bf16" + }, + ) + fp16: bool = field(default=False, metadata={"help": "use FP16"}) + memory_efficient_fp16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of FP16 training; implies --fp16" + }, + ) + fp16_no_flatten_grads: bool = field( + default=False, metadata={"help": "don't flatten FP16 grads tensor"} + ) + fp16_init_scale: int = field( + default=2 ** 7, metadata={"help": "default FP16 loss scale"} + ) + fp16_scale_window: Optional[int] = field( + default=None, + metadata={"help": "number of updates before increasing loss scale"}, + ) + fp16_scale_tolerance: float = field( + default=0.0, + metadata={ + "help": "pct of updates that can overflow before decreasing the loss scale" + }, + ) + min_loss_scale: float = field( + default=1e-4, + metadata={"help": "minimum FP16 loss scale, after which training is stopped"}, + ) + threshold_loss_scale: Optional[float] = field( + 
default=None, metadata={"help": "threshold FP16 loss scale from below"} + ) + user_dir: Optional[str] = field( + default=None, + metadata={ + "help": "path to a python module containing custom extensions (tasks and/or architectures)" + }, + ) + empty_cache_freq: int = field( + default=0, + metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"}, + ) + all_gather_list_size: int = field( + default=16384, + metadata={"help": "number of bytes reserved for gathering stats from workers"}, + ) + model_parallel_size: int = field( + default=1, metadata={"help": "total number of GPUs to parallelize model over"} + ) + checkpoint_suffix: str = field( + default="", metadata={"help": "suffix to add to the checkpoint file name"} + ) + quantization_config_path: Optional[str] = field( + default=None, metadata={"help": "path to quantization config file"} + ) + profile: bool = field( + default=False, metadata={"help": "enable autograd profiler emit_nvtx"} + ) + + +@dataclass +class DistributedTrainingParams(FairseqDataclass): + distributed_world_size: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "total number of GPUs across all nodes (default: all visible GPUs)" + }, + ) + distributed_rank: Optional[int] = field( + default=0, metadata={"help": "rank of the current worker"} + ) + distributed_backend: str = field( + default="nccl", metadata={"help": "distributed backend"} + ) + distributed_init_method: Optional[str] = field( + default=None, + metadata={ + "help": "typically tcp://hostname:port that will be used to " + "establish initial connection" + }, + ) + distributed_port: int = field( + default=-1, + metadata={ + "help": "port number (not required if using --distributed-init-method)" + }, + ) + device_id: int = field( + default=0, + metadata={"help": "which GPU to use (usually configured automatically)"}, + ) + local_rank: int = field( + default=0, + metadata={"help": "which GPU to use (usually configured automatically)"}, +
) + distributed_no_spawn: bool = field( + default=False, + metadata={ + "help": "do not spawn multiple processes even if multiple GPUs are visible" + }, + ) + ddp_backend: ChoiceEnum(["c10d", "no_c10d"]) = field( + default="c10d", metadata={"help": "DistributedDataParallel backend"} + ) + bucket_cap_mb: int = field( + default=25, metadata={"help": "bucket size for reduction"} + ) + fix_batches_to_gpus: bool = field( + default=False, + metadata={ + "help": "don't shuffle batches between GPUs; this reduces overall " + "randomness and may affect precision but avoids the cost of re-reading the data" + }, + ) + find_unused_parameters: bool = field( + default=False, + metadata={ + "help": "disable unused parameter detection (not applicable to " + "no_c10d ddp-backend)" + }, + ) + fast_stat_sync: bool = field( + default=False, + metadata={"help": "[deprecated] this is now defined per Criterion"}, + ) + broadcast_buffers: bool = field( + default=False, + metadata={ + "help": "Copy non-trainable parameters between GPUs, such as " + "batchnorm population statistics" + }, + ) + distributed_wrapper: ChoiceEnum(["DDP", "SlowMo"]) = field( + default="DDP", metadata={"help": "DistributedDataParallel backend"} + ) + slowmo_momentum: Optional[float] = field( + default=None, + metadata={ + "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, " + "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs" + }, + ) + slowmo_algorithm: str = field( + default="LocalSGD", metadata={"help": "whether to use LocalSGD or SGP"} + ) + localsgd_frequency: int = field( + default=3, metadata={"help": "Local SGD allreduce frequency"} + ) + nprocs_per_node: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "number of GPUs in each node. An allreduce operation across GPUs in " + "a node is very fast.
Hence, we do allreduce across GPUs in a node, " + "and gossip across different nodes" + }, + ) + + +@dataclass +class DatasetParams(FairseqDataclass): + num_workers: int = field( + default=1, metadata={"help": "how many subprocesses to use for data loading"} + ) + skip_invalid_size_inputs_valid_test: bool = field( + default=False, + metadata={"help": "ignore too long or too short lines in valid and test set"}, + ) + max_tokens: Optional[int] = field( + default=None, metadata={"help": "maximum number of tokens in a batch"} + ) + max_sentences: Optional[int] = field( + default=None, metadata={"help": "maximum number of sentences in a batch"} + ) + batch_size: Optional[int] = field( + default=None, metadata={"help": "maximum number of sentences in a batch"} + ) + required_batch_size_multiple: int = field( + default=8, metadata={"help": "batch size will be a multiplier of this value"} + ) + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = field( + default=None, metadata={"help": "output dataset implementation"} + ) + data_buffer_size: int = field( + default=10, metadata={"help": "Number of batches to preload"} + ) + train_subset: str = field( + default="train", + metadata={"help": "data subset to use for training (e.g. train, valid, test)"}, + ) + valid_subset: str = field( + default="valid", + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. 
train, valid, test)" + }, + ) + validate_interval: int = field( + default=1, metadata={"help": "validate every N epochs"} + ) + validate_interval_updates: int = field( + default=0, metadata={"help": "validate every N updates"} + ) + validate_after_updates: int = field( + default=0, metadata={"help": "dont validate until reaching this many updates"} + ) + fixed_validation_seed: Optional[int] = field( + default=None, metadata={"help": "specified random seed for validation"} + ) + disable_validation: bool = field( + default=False, metadata={"help": "disable validation"} + ) + max_tokens_valid: Optional[int] = field( + default=None, + metadata={ + "help": "maximum number of tokens in a validation batch" + " (defaults to --max-tokens)" + }, + ) + max_sentences_valid: Optional[int] = field( + default=None, + metadata={ + "help": "maximum number of sentences in a validation batch" + " (defaults to --max-sentences)" + }, + ) + curriculum: int = field( + default=0, metadata={"help": "don't shuffle batches for first N epochs"} + ) + gen_subset: str = field( + default="test", + metadata={"help": "data subset to generate (train, valid, test)"}, + ) + num_shards: int = field( + default=1, metadata={"help": "shard generation over N shards"} + ) + shard_id: int = field( + default=0, metadata={"help": "id of the shard to generate (id < num_shards)"} + ) + + +@dataclass +class OptimizationParams(FairseqDataclass): + max_epoch: int = field( + default=0, metadata={"help": "force stop training at specified epoch"} + ) + max_update: int = field( + default=0, metadata={"help": "force stop training at specified update"} + ) + stop_time_hours: float = field( + default=0, + metadata={ + "help": "force stop training after specified cumulative time (if >0)" + }, + ) + clip_norm: float = field( + default=25.0, metadata={"help": "clip threshold of gradients"} + ) + sentence_avg: bool = field( + default=False, + metadata={ + "help": "normalize gradients by the number of sentences in a batch" + 
" (default is to normalize by number of tokens)" + }, + ) + update_freq: List[int] = field( + default_factory=lambda: [1], + metadata={"help": "update parameters every N_i batches, when in epoch i"}, + ) + lr: List[float] = field( + default_factory=lambda: [0.25], + metadata={ + "help": "learning rate for the first N epochs; all epochs >N using LR_N" + " (note: this may be interpreted differently depending on --lr-scheduler)" + }, + ) + min_lr: float = field( + default=-1.0, + metadata={"help": "stop training when the learning rate reaches this minimum"}, + ) + use_bmuf: bool = field( + default=False, + metadata={ + "help": "specify global optimizer for syncing models on different GPUs/shards" + }, + ) + + +@dataclass +class CheckpointParams(FairseqDataclass): + save_dir: str = field( + default="checkpoints", metadata={"help": "path to save checkpoints"} + ) + restore_file: str = field( + default="checkpoint_last.pt", + metadata={ + "help": "filename from which to load checkpoint " + "(default: /checkpoint_last.pt" + }, + ) + finetune_from_model: Optional[str] = field( + default=None, + metadata={ + "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset" + }, + ) + reset_dataloader: bool = field( + default=False, + metadata={ + "help": "if set, does not reload dataloader state from the checkpoint" + }, + ) + reset_lr_scheduler: bool = field( + default=False, + metadata={ + "help": "if set, does not load lr scheduler state from the checkpoint" + }, + ) + reset_meters: bool = field( + default=False, + metadata={"help": "if set, does not load meters from the checkpoint"}, + ) + reset_optimizer: bool = field( + default=False, + metadata={"help": "if set, does not load optimizer state from the checkpoint"}, + ) + optimizer_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary used to override optimizer args when loading a checkpoint" + }, + ) + save_interval: int = field( + default=1, metadata={"help": "save a 
checkpoint every N epochs"} + ) + save_interval_updates: int = field( + default=0, metadata={"help": "save a checkpoint (and validate) every N updates"} + ) + keep_interval_updates: int = field( + default=-1, + metadata={ + "help": "keep the last N checkpoints saved with --save-interval-updates" + }, + ) + keep_last_epochs: int = field( + default=-1, metadata={"help": "keep last N epoch checkpoints"} + ) + keep_best_checkpoints: int = field( + default=-1, metadata={"help": "keep best N checkpoints based on scores"} + ) + no_save: bool = field( + default=False, metadata={"help": "don't save models or checkpoints"} + ) + no_epoch_checkpoints: bool = field( + default=False, metadata={"help": "only store last and best checkpoints"} + ) + no_last_checkpoints: bool = field( + default=False, metadata={"help": "don't store last checkpoints"} + ) + no_save_optimizer_state: bool = field( + default=False, + metadata={"help": "don't save optimizer-state as part of checkpoint"}, + ) + best_checkpoint_metric: str = field( + default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'} + ) + maximize_best_checkpoint_metric: bool = field( + default=False, + metadata={ + "help": 'select the largest metric value for saving "best" checkpoints' + }, + ) + patience: int = field( + default=-1, + metadata={ + "help": ( + "early stop training if valid performance doesn't " + "improve for N consecutive validation runs; note " + "that this is influenced by --validate-interval" + ) + }, + ) + + +@dataclass +class CommonEvalParams(FairseqDataclass): + path: Optional[str] = field( + default=None, metadata={"help": "path(s) to model file(s), colon separated"} + ) + remove_bpe: Optional[str] = field( + default=None, + metadata={ + "help": "remove BPE tokens before scoring (can be set to sentencepiece)" + }, + ) + quiet: bool = field(default=False, metadata={"help": "only print final scores"}) + model_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary 
used to override model args at generation that were used during model training" + }, + ) + results_path: Optional[str] = field( + default=None, metadata={"help": "path to save eval results (optional)"} + ) + + +@dataclass +class EvalLMParams(FairseqDataclass): + output_word_probs: bool = field( + default=False, + metadata={ + "help": "if set, outputs words and their predicted log probabilities to standard output" + }, + ) + output_word_stats: bool = field( + default=False, + metadata={ + "help": "if set, outputs word statistics such as word count, average probability, etc" + }, + ) + context_window: int = field( + default=0, + metadata={ + "help": "ensures that every evaluated token has access to a context of at least this size, if possible" + }, + ) + softmax_batch: int = field( + default=sys.maxsize, + metadata={ + "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory" + }, + ) + + +@dataclass +class TrainingConfig(FairseqDataclass): + """Config for training, a composition of training params""" + + common: CommonParams = CommonParams() + distributed_training: DistributedTrainingParams = DistributedTrainingParams() + dataset: DatasetParams = DatasetParams() + optimization: OptimizationParams = OptimizationParams() + checkpoint: CheckpointParams = CheckpointParams() + bmuf: FairseqBMUFConfig = FairseqBMUFConfig() + + +@dataclass +class EvalLMConfig(FairseqDataclass): + """Config for eval lm, a composition of eval_lm params""" + + common: CommonParams = CommonParams() + distributed_training: DistributedTrainingParams = DistributedTrainingParams() + dataset: DatasetParams = DatasetParams() + optimization: OptimizationParams = OptimizationParams() + checkpoint: CheckpointParams = CheckpointParams() + bmuf: FairseqBMUFConfig = FairseqBMUFConfig() + common_eval: CommonEvalParams = CommonEvalParams() + eval_lm: EvalLMParams = EvalLMParams() + + +def register_params_dataclass( + cs: ConfigStore, 
name: str, group: str, data_class: Type[FairseqDataclass] +) -> None: + """register params dataclass in config store""" + node_ = data_class(_name=data_class.name()) + cs.store(name=name, group=group, node=node_) + + +def register_module_dataclass( + cs: ConfigStore, registry: Dict[str, Any], group: str +) -> None: + """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" + # note that if `group == model`, we register all model archs, not the model name. + for k, v in registry.items(): + if v is not None: + node_ = v(_name=k) + cs.store(name=k, group=group, node=node_) + + +def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: + """cs: config store instance, register common training configs""" + + register_params_dataclass( + cs, name="training_params", group="params", data_class=TrainingConfig + ) + + register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") + register_module_dataclass(cs, ARCH_DATACLASS_REGISTRY, "model") + register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") + register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") + register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") + + +def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: + """cs: config store instance, register common training configs""" + + register_params_dataclass( + cs, name="eval_lm_params", group="params", data_class=EvalLMConfig + ) + + register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") + register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") + register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") + register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") + + +def _override_attr( + sub_node: str, data_class: Type[FairseqDataclass], args: Namespace +) -> List[str]: + overrides = [] + for k in data_class.__dataclass_fields__.keys(): + if k == 
"_name": + # private member, skip + continue + if not hasattr(args, k): + # print(f"cannot override {sub_node}.{k} since args does not have attribute {k}") + continue + if getattr(args, k) is None: + overrides.append("{}.{}=null".format(sub_node, k)) + elif getattr(args, k) == "": + overrides.append("{}.{}=''".format(sub_node, k)) + elif isinstance(getattr(args, k), str): + if ( + getattr(args, k).startswith("[") + or getattr(args, k).startswith("(") + or getattr(args, k).startswith("{") + or ("," in getattr(args, k)) + ): + overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k))) + else: + overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + else: + overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + return overrides + + +def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]: + overrides = [] + + overrides.extend(_override_attr("params.common", CommonParams, args)) + overrides.extend(_override_attr("params.dataset", DatasetParams, args)) + overrides.extend( + _override_attr("params.distributed_training", DistributedTrainingParams, args) + ) + overrides.extend(_override_attr("params.optimization", OptimizationParams, args)) + overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args)) + overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) + module_overrides, module_deletes = override_module_args(args) + overrides.extend(module_overrides) + + return overrides, module_deletes + + +def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]: + overrides = [] + + overrides.extend(_override_attr("params.common", CommonParams, args)) + overrides.extend(_override_attr("params.dataset", DatasetParams, args)) + overrides.extend( + _override_attr("params.distributed_training", DistributedTrainingParams, args) + ) + overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args)) + overrides.extend(_override_attr("params.eval_lm", EvalLMParams, 
args)) + overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) + module_overrides, module_deletes = override_module_args(args) + overrides.extend(module_overrides) + + return overrides, module_deletes + + +def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: + """use the field in args to overrides those in cfg""" + overrides = [] + deletes = [] + + if args is not None: + assert ( + hasattr(args, "task") + and hasattr(args, "criterion") + and hasattr(args, "optimizer") + and hasattr(args, "lr_scheduler") + ) + if args.task in TASK_DATACLASS_REGISTRY: + overrides.append("task={}".format(args.task)) + overrides.append("task._name={}".format(args.task)) + overrides.extend( + _override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args) + ) + else: + deletes.append("task") + if args.criterion in CRITERION_DATACLASS_REGISTRY: + overrides.append("criterion={}".format(args.criterion)) + overrides.append("criterion._name={}".format(args.criterion)) + overrides.extend( + _override_attr( + "criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args + ) + ) + else: + deletes.append("criterion") + if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY: + overrides.append("optimizer={}".format(args.optimizer)) + overrides.append("optimizer._name={}".format(args.optimizer)) + overrides.extend( + _override_attr( + "optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args + ) + ) + else: + deletes.append("optimizer") + if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY: + overrides.append("lr_scheduler={}".format(args.lr_scheduler)) + overrides.append("lr_scheduler._name={}".format(args.lr_scheduler)) + overrides.extend( + _override_attr( + "lr_scheduler", + LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler], + args, + ) + ) + else: + deletes.append("lr_scheduler") + + if hasattr(args, "arch"): + if args.arch in ARCH_DATACLASS_REGISTRY: + overrides.append("model={}".format(args.arch)) + 
overrides.append("model._name={}".format(args.arch)) + # override model params with those exist in args + overrides.extend( + _override_attr("model", ARCH_DATACLASS_REGISTRY[args.arch], args) + ) + else: + deletes.append("model") + + return overrides, deletes diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py new file mode 100644 index 0000000000..b910d8353d --- /dev/null +++ b/fairseq/dataclass/utils.py @@ -0,0 +1,194 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, MISSING +from typing import Any, List, Optional, Dict +from enum import Enum +from argparse import ArgumentParser + + +def eval_str_list(x, x_type=float): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + try: + return list(map(x_type, x)) + except TypeError: + return [x_type(x)] + + +class StrEnum(Enum): + def __str__(self): + return self.value + + def __eq__(self, other: str): + return self.value == other + + def __repr__(self): + return self.value + + +def ChoiceEnum(choices: List[str]): + """return the Enum class used to enforce list of choices""" + return StrEnum("Choices", {k: k for k in choices}) + + +@dataclass +class FairseqDataclass: + """fairseq base dataclass that supported fetching attributes and metas""" + + _name: Optional[str] = None + + @staticmethod + def name(): + return None + + def _get_all_attributes(self) -> List[str]: + return [k for k in self.__dataclass_fields__.keys()] + + def _get_meta( + self, attribute_name: str, meta: str, default: Optional[Any] = None + ) -> Any: + return self.__dataclass_fields__[attribute_name].metadata.get(meta, default) + + def _get_name(self, attribute_name: str) -> str: + return self.__dataclass_fields__[attribute_name].name + + def _get_default(self, attribute_name: str) -> Any: + if hasattr(self, attribute_name): + if str(getattr(self, 
attribute_name)).startswith("${"): + return str(getattr(self, attribute_name)) + elif str(self.__dataclass_fields__[attribute_name].default).startswith( + "${" + ): + return str(self.__dataclass_fields__[attribute_name].default) + elif ( + getattr(self, attribute_name) + != self.__dataclass_fields__[attribute_name].default + ): + return getattr(self, attribute_name) + return self.__dataclass_fields__[attribute_name].default + + def _get_default_factory(self, attribute_name: str) -> Any: + if hasattr(self, attribute_name): + if str(getattr(self, attribute_name)).startswith("${"): + return str(getattr(self, attribute_name)) + elif str(self.__dataclass_fields__[attribute_name].default).startswith( + "${" + ): + return str(self.__dataclass_fields__[attribute_name].default) + elif ( + getattr(self, attribute_name) + != self.__dataclass_fields__[attribute_name].default_factory() + ): + return getattr(self, attribute_name) + return self.__dataclass_fields__[attribute_name].default_factory() + + def _get_type(self, attribute_name: str) -> Any: + return self.__dataclass_fields__[attribute_name].type + + def _get_help(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "help") + + def _get_choices(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "choices") + + +def gen_parser_from_dataclass( + parser: ArgumentParser, + dataclass_instance: FairseqDataclass, + delete_default: bool = False, +) -> None: + """convert a dataclass instance to tailing parser arguments""" + import re + + def argparse_name(name: str): + if name == "data": + # normally data is positional args + return name + if name == "_name": + # private member, skip + return None + return "--" + name.replace("_", "-") + + def interpret_dc_type(field_type): + if isinstance(field_type, str): + raise RuntimeError() + typestring = str(field_type) + if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): + return field_type.__args__[0] + return field_type + + def 
get_kwargs_from_dc( + dataclass_instance: FairseqDataclass, k: str + ) -> Dict[str, Any]: + """k: dataclass attributes""" + field_type = dataclass_instance._get_type(k) + inter_type = interpret_dc_type(field_type) + if isinstance(inter_type, type) and issubclass(inter_type, List): + field_default = dataclass_instance._get_default_factory(k) + else: + field_default = dataclass_instance._get_default(k) + + if isinstance(inter_type, type) and issubclass(inter_type, Enum): + field_choices = [t.value for t in list(inter_type)] + else: + field_choices = None + + field_help = dataclass_instance._get_help(k) + kwargs = {} + if isinstance(field_default, str) and field_default.startswith("${"): + kwargs["default"] = field_default + else: + if field_default is MISSING: + kwargs["required"] = True + if field_choices is not None: + kwargs["choices"] = field_choices + if (isinstance(inter_type, type) and issubclass(inter_type, List)) or ( + "List" in str(inter_type) + ): + if "int" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, int) + elif "float" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, float) + elif "str" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, str) + else: + raise NotImplementedError() + if field_default is not MISSING: + kwargs["default"] = ",".join(map(str, field_default)) + elif (isinstance(inter_type, type) and issubclass(inter_type, Enum)) or ( + "Enum" in str(inter_type) + ): + kwargs["type"] = str + if field_default is not MISSING: + if isinstance(field_default, Enum): + kwargs["default"] = field_default.value + else: + kwargs["default"] = field_default + elif inter_type is bool: + kwargs["action"] = ( + "store_false" if field_default is True else "store_true" + ) + kwargs["default"] = field_default + else: + kwargs["type"] = inter_type + if field_default is not MISSING: + kwargs["default"] = field_default + + kwargs["help"] = field_help + return kwargs + + for k in 
dataclass_instance._get_all_attributes(): + field_name = argparse_name(dataclass_instance._get_name(k)) + if field_name is None: + continue + kwargs = get_kwargs_from_dc(dataclass_instance, k) + if isinstance(kwargs["default"], str) and kwargs["default"].startswith("${"): + continue + if delete_default: + del kwargs["default"] + parser.add_argument(field_name, **kwargs) diff --git a/setup.py b/setup.py index a309b90bd1..fa59acef78 100644 --- a/setup.py +++ b/setup.py @@ -136,6 +136,8 @@ def include_dirs(self, dirs): install_requires=[ 'cffi', 'cython', + 'hydra-core', + 'dataclasses', 'editdistance', 'numpy', 'regex', From 0ffb94151f597ecb677551289e7046a21fb5ebaf Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Fri, 4 Sep 2020 22:11:06 -0700 Subject: [PATCH 144/707] fix deleting source_iter Summary: D23443012 (https://github.com/pytorch/fairseq/commit/e171c8d86a939cf4ebc483cd649bee1935379771) removed this iterator, so we need to remove this line otherwise jobs will fail. I guess we never ran into this during testing since we never actually finished an entire epoch and consumed everything since we call `take` Reviewed By: tangyuq Differential Revision: D23554132 fbshipit-source-id: 232e950a0a436419f6c5139e35caa81e3594fe38 --- fairseq/data/iterators.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 20c76f8398..19add56afa 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -488,8 +488,6 @@ def run(self): except Exception as e: self._queue.put(e) - del self._source_iter - class BufferedIterator(object): def __init__(self, size, iterable): From 87350759e399d73d13596c7c26539ba18a4145ea Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 9 Sep 2020 06:18:51 -0700 Subject: [PATCH 145/707] Add FairseqDataset.can_reuse_epoch_itr_across_epochs (#2525) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2525 Reviewed By: ngoyal2707 Differential Revision: D23318762 Pulled By: 
myleott fbshipit-source-id: c9c7236a2c9dc127716f5078d92d60df1fe5f716 --- fairseq/data/base_wrapper_dataset.py | 10 ++++++++++ fairseq/data/concat_dataset.py | 4 ++++ fairseq/data/denoising_dataset.py | 4 ++++ fairseq/data/fairseq_dataset.py | 13 +++++++++++++ fairseq/data/mask_tokens_dataset.py | 4 ++++ fairseq/data/multi_corpus_dataset.py | 4 ++++ .../multilingual/sampled_multi_dataset.py | 4 ++++ .../sampled_multi_epoch_dataset.py | 4 ++++ fairseq/data/nested_dictionary_dataset.py | 4 ++++ fairseq/data/resampling_dataset.py | 4 ++++ fairseq/data/shorten_dataset.py | 4 ++++ fairseq/tasks/fairseq_task.py | 19 ++++++++++++++----- fairseq/tasks/multilingual_masked_lm.py | 16 ---------------- 13 files changed, 73 insertions(+), 21 deletions(-) diff --git a/fairseq/data/base_wrapper_dataset.py b/fairseq/data/base_wrapper_dataset.py index 8b5326a635..680dcce9ae 100644 --- a/fairseq/data/base_wrapper_dataset.py +++ b/fairseq/data/base_wrapper_dataset.py @@ -43,6 +43,9 @@ def ordered_indices(self): def supports_prefetch(self): return getattr(self.dataset, 'supports_prefetch', False) + def attr(self, attr: str, index: int): + return self.dataset.attr(attr, index) + def prefetch(self, indices): self.dataset.prefetch(indices) @@ -63,6 +66,13 @@ def batch_by_size( required_batch_size_multiple=required_batch_size_multiple, ) + def filter_indices_by_size(self, indices, max_sizes): + return self.dataset.filter_indices_by_size(indices, max_sizes) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return self.dataset.can_reuse_epoch_itr_across_epochs + def set_epoch(self, epoch): super().set_epoch(epoch) if hasattr(self.dataset, 'set_epoch'): diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py index 5ca80631f0..2c3306d6f5 100644 --- a/fairseq/data/concat_dataset.py +++ b/fairseq/data/concat_dataset.py @@ -98,6 +98,10 @@ def prefetch(self, indices): ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) frm = to + @property + def 
can_reuse_epoch_itr_across_epochs(self): + return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets) + def set_epoch(self, epoch): super().set_epoch(epoch) for ds in self.datasets: diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py index c55ce1ba49..4fe560b0a7 100644 --- a/fairseq/data/denoising_dataset.py +++ b/fairseq/data/denoising_dataset.py @@ -166,6 +166,10 @@ def __init__( self.epoch = 0 + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + def set_epoch(self, epoch, **unused): self.epoch = epoch diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index f196aff14f..a4a0985210 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -12,6 +12,19 @@ class EpochListening: """Mixin for receiving updates whenever the epoch increments.""" + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for + this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies on ``set_epoch`` then you should consider setting + this to ``False``. 
+ """ + return True + def set_epoch(self, epoch): """Will receive the updated epoch number at the beginning of the epoch.""" pass diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 28bc3bc9cf..31f5459307 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -91,6 +91,10 @@ def __init__( self.epoch = 0 + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + def set_epoch(self, epoch, **unused): super().set_epoch(epoch) self.epoch = epoch diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index bf33cb23b5..d2457666d6 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -146,6 +146,10 @@ def size(self, index: int): index, key = self._map_index(index) return self.datasets[key].size(index) + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + def set_epoch(self, epoch, **unused): super().set_epoch(epoch) self.epoch = epoch diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index 95eab280f0..14090ac8c5 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -346,6 +346,10 @@ def prefetch(self, indices): for i in range(len(prefetch_indices)): self.datasets[i].prefetch(prefetch_indices[i]) + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + def set_epoch(self, epoch): super().set_epoch(epoch) if epoch == self._cur_epoch: diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py index fdd47e5091..289a117a00 100644 --- a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -174,6 +174,10 @@ def prefetch(self, indices): for i in 
range(len(prefetch_indices)): self.datasets[i].prefetch(prefetch_indices[i]) + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + def set_epoch(self, epoch): if self._current_epoch_start_index is None: self._setup(epoch) diff --git a/fairseq/data/nested_dictionary_dataset.py b/fairseq/data/nested_dictionary_dataset.py index 2795f895dd..ebc56303b9 100644 --- a/fairseq/data/nested_dictionary_dataset.py +++ b/fairseq/data/nested_dictionary_dataset.py @@ -110,6 +110,10 @@ def prefetch(self, indices): if getattr(ds, 'supports_prefetch', False): ds.prefetch(indices) + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values()) + def set_epoch(self, epoch): super().set_epoch(epoch) for ds in self.defn.values(): diff --git a/fairseq/data/resampling_dataset.py b/fairseq/data/resampling_dataset.py index a2c9b31d79..ffb25ac668 100644 --- a/fairseq/data/resampling_dataset.py +++ b/fairseq/data/resampling_dataset.py @@ -107,6 +107,10 @@ def ordered_indices(self): def prefetch(self, indices): self.dataset.prefetch(self._cur_indices.array[indices]) + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + def set_epoch(self, epoch): logger.debug('ResamplingDataset.set_epoch: {}'.format(epoch)) super().set_epoch(epoch) diff --git a/fairseq/data/shorten_dataset.py b/fairseq/data/shorten_dataset.py index 9c84219dc7..85659d101e 100644 --- a/fairseq/data/shorten_dataset.py +++ b/fairseq/data/shorten_dataset.py @@ -43,6 +43,10 @@ def __init__(self, dataset, truncation_length, seed=1): self.seed = seed self.epoch = 0 + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the crop changes, not item sizes + def set_epoch(self, epoch, **unused): super().set_epoch(epoch) self.epoch = epoch diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index ddc6760842..7067734ba5 100644 --- a/fairseq/tasks/fairseq_task.py +++ 
b/fairseq/tasks/fairseq_task.py @@ -145,6 +145,13 @@ def filter_indices_by_size( ).format(len(ignored), max_positions, ignored[:10])) return indices + def can_reuse_epoch_itr(self, dataset): + # We can reuse the epoch iterator across epochs as long as the dataset + # hasn't disabled it. We default to ``False`` here, although in practice + # this will be ``True`` for most datasets that inherit from + # ``FairseqDataset`` due to the base implementation there. + return getattr(dataset, 'can_reuse_epoch_itr_across_epochs', False) + def get_batch_iterator( self, dataset, @@ -189,10 +196,9 @@ def get_batch_iterator( ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split """ - # For default fairseq task, return same iterator across epochs - # as datasets are not dynamic, can be overridden in task specific - # setting. - if dataset in self.dataset_to_epoch_iter: + can_reuse_epoch_itr = self.can_reuse_epoch_itr(dataset) + if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: + logger.debug('reusing EpochBatchIterator for epoch {}'.format(epoch)) return self.dataset_to_epoch_iter[dataset] assert isinstance(dataset, FairseqDataset) @@ -230,7 +236,10 @@ def get_batch_iterator( epoch=epoch, buffer_size=getattr(self.args, 'data_buffer_size', 0), ) - self.dataset_to_epoch_iter[dataset] = epoch_iter + + if can_reuse_epoch_itr: + self.dataset_to_epoch_iter[dataset] = epoch_iter + return epoch_iter def build_model(self, args): diff --git a/fairseq/tasks/multilingual_masked_lm.py b/fairseq/tasks/multilingual_masked_lm.py index 248724bd56..5d96a608b5 100644 --- a/fairseq/tasks/multilingual_masked_lm.py +++ b/fairseq/tasks/multilingual_masked_lm.py @@ -292,22 +292,6 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) return src_dataset - def get_batch_iterator( - self, dataset, max_tokens=None, max_sentences=None, max_positions=None, - 
ignore_invalid_inputs=False, required_batch_size_multiple=1, - seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, - ): - # Recreate epoch iterator every epoch cause the underlying - # datasets are dynamic due to sampling. - self.dataset_to_epoch_iter = {} - epoch_iter = super().get_batch_iterator( - dataset, max_tokens, max_sentences, max_positions, - ignore_invalid_inputs, required_batch_size_multiple, - seed, num_shards, shard_id, num_workers, epoch, - ) - self.dataset_to_epoch_iter = {} - return epoch_iter - @property def source_dictionary(self): return self.dictionary From 1cc8e95cece54152b6960e7880a65da98d8ac58a Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 9 Sep 2020 06:18:51 -0700 Subject: [PATCH 146/707] Don't cache epoch iterators when using sharded datasets (#1268) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1268 We previously had a memory leak when using sharded datasets. In particular, each sharded dataset is a new FairseqDataset instance, and the cache is keyed by the `dataset` instance. Since we never clear the cache, this would eventually cause the system to run out of CPU RAM. This diff disables caching when using sharded datasets. Note that we also change the signature to `get_batch_iterator`, which needs to propagate to many places. We previously avoided this update when adding `data_buffer_size`, so I'm also adding that everywhere. 
Reviewed By: ngoyal2707 Differential Revision: D23319135 fbshipit-source-id: 6bcd6aee141ad9cc234448c49106a8dbf8ea1800 --- docs/tutorial_classifying_names.rst | 3 ++- examples/speech_recognition/infer.py | 1 + fairseq/hub_utils.py | 1 + fairseq/tasks/fairseq_task.py | 14 ++++++++++-- .../tasks/translation_multi_simple_epoch.py | 22 ++++++++++++++++--- fairseq/trainer.py | 6 +++++ fairseq_cli/eval_lm.py | 1 + fairseq_cli/generate.py | 1 + fairseq_cli/train.py | 9 +++++++- fairseq_cli/validate.py | 1 + 10 files changed, 52 insertions(+), 7 deletions(-) diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index e2b5a67168..40a3cb6f25 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -313,7 +313,8 @@ following contents:: # def get_batch_iterator( # self, dataset, max_tokens=None, max_sentences=None, max_positions=None, # ignore_invalid_inputs=False, required_batch_size_multiple=1, - # seed=1, num_shards=1, shard_id=0, + # seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, + # data_buffer_size=0, disable_iterator_cache=False, # ): # (...) 
diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index e40d37d390..19f2c2ed03 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -102,6 +102,7 @@ def get_dataset_itr(args, task, models): num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, ).next_epoch_itr(shuffle=False) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index e040a8c3f3..4e499e141d 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -233,6 +233,7 @@ def _build_batches( max_sentences=self.args.max_sentences, max_positions=self.max_positions, ignore_invalid_inputs=skip_invalid_size_inputs, + disable_iterator_cache=True, ).next_epoch_itr(shuffle=False) return batch_iterator diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 7067734ba5..c7b39f5b62 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -165,6 +165,8 @@ def get_batch_iterator( shard_id=0, num_workers=0, epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, ): """ Get an iterator that yields batches of data from the given dataset. @@ -192,11 +194,19 @@ def get_batch_iterator( (default: 0). epoch (int, optional): the epoch to start the iterator from (default: 1). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). 
Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split """ - can_reuse_epoch_itr = self.can_reuse_epoch_itr(dataset) + can_reuse_epoch_itr = ( + not disable_iterator_cache + and self.can_reuse_epoch_itr(dataset) + ) if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: logger.debug('reusing EpochBatchIterator for epoch {}'.format(epoch)) return self.dataset_to_epoch_iter[dataset] @@ -234,7 +244,7 @@ def get_batch_iterator( shard_id=shard_id, num_workers=num_workers, epoch=epoch, - buffer_size=getattr(self.args, 'data_buffer_size', 0), + buffer_size=data_buffer_size, ) if can_reuse_epoch_itr: diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index b10e696f9b..e13c9fd88b 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -269,6 +269,7 @@ def get_batch_iterator( self, dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, + data_buffer_size=0, disable_iterator_cache=False, ): """ Get an iterator that yields batches of data from the given dataset. @@ -296,6 +297,11 @@ def get_batch_iterator( (default: 0). epoch (int, optional): the epoch to start the iterator from (default: 0). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). 
Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split @@ -308,9 +314,19 @@ def get_batch_iterator( self.args.sampling_method == 'RoundRobin' ): batch_iter = super().get_batch_iterator( - dataset, max_tokens=max_tokens, max_sentences=max_sentences, max_positions=max_positions, - ignore_invalid_inputs=ignore_invalid_inputs, required_batch_size_multiple=required_batch_size_multiple, - seed=seed, num_shards=num_shards, shard_id=shard_id, num_workers=num_workers, epoch=epoch, + dataset, + max_tokens=max_tokens, + max_sentences=max_sentences, + max_positions=max_positions, + ignore_invalid_inputs=ignore_invalid_inputs, + required_batch_size_multiple=required_batch_size_multiple, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + data_buffer_size=data_buffer_size, + disable_iterator_cache=disable_iterator_cache, ) self.dataset_to_epoch_iter[dataset] = batch_iter return batch_iter diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5022ceea2d..64ad44c1df 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -339,6 +339,7 @@ def get_train_iterator( load_dataset=True, data_selector=None, shard_batch_itr=True, + disable_iterator_cache=False, ): """Return an EpochBatchIterator over the training set for a given epoch.""" if load_dataset: @@ -365,11 +366,14 @@ def get_train_iterator( shard_id=self.data_parallel_rank if shard_batch_itr else 0, num_workers=self.args.num_workers, epoch=epoch, + data_buffer_size=self.args.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, ) def get_valid_iterator( self, subset, + disable_iterator_cache=False, ): """Return an EpochBatchIterator over given validation subset for a given epoch.""" return self.task.get_batch_iterator( @@ -386,6 +390,8 @@ def get_valid_iterator( num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, num_workers=self.args.num_workers, + data_buffer_size=self.args.data_buffer_size, 
+ disable_iterator_cache=disable_iterator_cache, ) def begin_epoch(self, epoch): diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index e59674b530..c5dd7fe4ce 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -125,6 +125,7 @@ def main(parsed_args, **unused_kwargs): num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index a6e48f927e..90d270f460 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -113,6 +113,7 @@ def _main(args, output_file): num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 05cffd5a7e..f7dd527166 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -107,7 +107,12 @@ def main(args): # Load the latest checkpoint if one is available and restore the # corresponding train iterator - extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) + extra_state, epoch_itr = checkpoint_utils.load_checkpoint( + args, + trainer, + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), + ) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -128,6 +133,8 @@ def main(args): epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch load_dataset=task.has_sharded_data("train"), + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 304aecee9e..510560d968 100644 --- 
a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -88,6 +88,7 @@ def main(args, override_args=None): num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, From df45f42efdaab751e698726bc1922f7643aa9276 Mon Sep 17 00:00:00 2001 From: lematt1991 Date: Wed, 9 Sep 2020 13:13:47 -0700 Subject: [PATCH 147/707] Fix `ChoiceEnum` Lint Error (#2596) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: For some reason, this fixes flake8 errors, no idea why... # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2596 Reviewed By: myleott Differential Revision: D23602637 Pulled By: lematt1991 fbshipit-source-id: b6070a8693eda79f0598fdce92b94c4de569c4fa --- fairseq/dataclass/data_class.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index 5c8a25dfa2..19e22f4c58 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -19,6 +19,11 @@ from argparse import Namespace +LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) +DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) +DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) + + @dataclass class CommonParams(FairseqDataclass): # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were @@ -32,7 +37,7 @@ class CommonParams(FairseqDataclass): "help": "log progress every N batches (when progress bar is disabled)" }, ) - log_format: Optional[ChoiceEnum(["json", "none", "simple", "tqdm"])] = field( + log_format: Optional[LOG_FORMAT_CHOICES] = field( default=None, metadata={"help": "log format to use"} ) tensorboard_logdir: Optional[str] = field( @@ -153,7 +158,7 @@ class DistributedTrainingParams(FairseqDataclass): "help": "do not spawn multiple processes even if multiple GPUs are visible" }, ) - ddp_backend: ChoiceEnum(["c10d", "no_c10d"]) = field( + ddp_backend: DDP_BACKEND_CHOICES = field( default="c10d", metadata={"help": "DistributedDataParallel backend"} ) bucket_cap_mb: int = field( @@ -184,7 +189,7 @@ class DistributedTrainingParams(FairseqDataclass): "batchnorm population statistics" }, ) - distributed_wrapper: ChoiceEnum(["DDP", "SlowMo"]) = field( + distributed_wrapper: DISTRIBUTED_WRAPPER_CHOICES = field( default="DDP", metadata={"help": "DistributedDataParallel backend"} ) slowmo_momentum: 
Optional[float] = field( From 42c5dcbd18c85dcdc9424886f3880c184d589f0d Mon Sep 17 00:00:00 2001 From: Mu Tian Date: Wed, 9 Sep 2020 17:00:56 -0700 Subject: [PATCH 148/707] hydra fairseq 3 - inherit from legacy for fairseq classes Summary: hydra fairseq 3 - inherit from legacy for fairseq classes Reviewed By: alexeib Differential Revision: D23375457 fbshipit-source-id: ef9d19f2d02f2326eea44a70f1f6e1668b420840 --- docs/hydra_integration.md | 113 ++++++++++++++++++ .../commonsense_qa/commonsense_qa_task.py | 4 +- examples/roberta/wsc/wsc_task.py | 4 +- .../tasks/speech_recognition.py | 4 +- fairseq/benchmark/dummy_lm.py | 4 +- fairseq/benchmark/dummy_masked_lm.py | 4 +- fairseq/benchmark/dummy_mt.py | 4 +- fairseq/optim/__init__.py | 2 +- fairseq/optim/adadelta.py | 4 +- fairseq/optim/adafactor.py | 4 +- fairseq/optim/adagrad.py | 4 +- fairseq/optim/adamax.py | 4 +- fairseq/optim/fairseq_optimizer.py | 6 + fairseq/optim/fused_lamb.py | 4 +- fairseq/optim/lr_scheduler/__init__.py | 2 +- .../lr_scheduler/fairseq_lr_scheduler.py | 11 ++ fairseq/optim/lr_scheduler/fixed_schedule.py | 4 +- .../lr_scheduler/polynomial_decay_schedule.py | 4 +- .../lr_scheduler/reduce_lr_on_plateau.py | 4 +- .../lr_scheduler/tri_stage_lr_scheduler.py | 4 +- .../lr_scheduler/triangular_lr_scheduler.py | 4 +- fairseq/optim/sgd.py | 4 +- fairseq/tasks/__init__.py | 2 +- fairseq/tasks/audio_pretraining.py | 4 +- fairseq/tasks/cross_lingual_lm.py | 4 +- fairseq/tasks/denoising.py | 4 +- fairseq/tasks/fairseq_task.py | 54 +++++++++ fairseq/tasks/legacy_masked_lm.py | 4 +- fairseq/tasks/masked_lm.py | 4 +- fairseq/tasks/multilingual_masked_lm.py | 4 +- fairseq/tasks/multilingual_translation.py | 9 +- fairseq/tasks/sentence_prediction.py | 4 +- fairseq/tasks/sentence_ranking.py | 4 +- fairseq/tasks/translation.py | 8 +- .../tasks/translation_multi_simple_epoch.py | 4 +- tests/speech_recognition/asr_test_base.py | 4 +- tests/test_export.py | 4 +- tests/test_lstm_jitable.py | 4 +- 
tests/test_sequence_generator.py | 4 +- tests/utils.py | 3 +- 40 files changed, 257 insertions(+), 73 deletions(-) create mode 100644 docs/hydra_integration.md diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md new file mode 100644 index 0000000000..9b77dd8351 --- /dev/null +++ b/docs/hydra_integration.md @@ -0,0 +1,113 @@ + + +## Hydra + +Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads. + +## Train models with hydra interface + +#### Provide parameters in `.yaml` files +For example, to train a transformer language model, we can provide the parameters in yaml files. Note that the modules used in training (task, model, criterion, optimizer, lr scheduler) must already be migrated to the hydra interface (see the migration section below). + +- Provide the top-level choices of which generic parameter file and which modules to use in `config/config.yaml`, for example: + +``` +defaults: + - params: training_params + - task: language_modeling + - model: transformer_lm + - criterion: cross_entropy + - optimizer: adam + - lr_scheduler: inverse_sqrt +``` + +- Provide generic parameters common across different training jobs: `config/params/training_params.yaml` +- Provide task parameters: `config/task/language_modeling.yaml` +- Provide model parameters: `config/model/transformer_lm.yaml` +- Provide criterion parameters: `config/criterion/cross_entropy.yaml` +- Provide optimizer parameters: `config/optimizer/adam.yaml` +- Provide lr_scheduler parameters: `config/lr_scheduler/inverse_sqrt.yaml` + +#### Command line overriding +`train_hydra.py` is the main entry point for training with the hydra interface.
If all the parameters we want are specified in `.yaml` files, we can simply run: + +``` +# task.data is a required field, marked by `???` in the yaml +python fairseq_cli/train_hydra.py \ +task.data=/private/home/abaevski/data/wiki103 +``` + +Alternatively, if we need to override certain params from the command line, we can do so as below (note where each parameter sits in the config structure): + +``` +python fairseq_cli/train_hydra.py \ +params=training_params \ +task=language_modeling \ +task.data=/private/home/abaevski/data/wiki103 \ +task.tokens_per_sample=512 \ +task.sample_break_mode=none \ +model=transformer_lm \ +model.share_decoder_input_output_embed=true \ +model.dropout=0.1 \ +optimizer=adam \ +optimizer.adam_betas="'(0.9, 0.98)'" \ +optimizer.weight_decay=0.01 \ +lr_scheduler=inverse_sqrt \ +lr_scheduler.warmup_updates=4000 \ +lr_scheduler.warmup_init_lr=1e-07 \ +criterion=cross_entropy \ +params.common.fp16=true \ +params.common.log_format=json \ +params.common.log_interval=1 \ +params.dataset.max_tokens=1024 \ +params.dataset.num_workers=4 \ +params.optimization.update_freq=[16] \ +params.optimization.max_update=50000 \ +params.optimization.clip_norm=0.0 \ +params.optimization.lr=[0.0005] \ +params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ +params.checkpoint.save_interval_updates=10 +``` + +## Migrating existing/creating new modules to the hydra interface + +For each module we want to migrate or create with the hydra interface, we fundamentally need to: + +- Provide a dataclass that lays out the parameters used in the module. + +- Modify the builder and/or constructor that previously took an `argparse.Namespace` argument `args` to take `omegaconf.DictConfig` config objects instead. For now we allow `Union[omegaconf.DictConfig, argparse.Namespace]` for backward compatibility. + +- For `add_args()`, we need to extract the arguments from the dataclass defined in the same file and append them to `parser`.
This is also for backward compatibility, and is supported by the `gen_parser_from_dataclass` API; see the example files below. + +#### Migrated examples: + +- Task: `fairseq/tasks/language_modeling.py` + +- Model: `fairseq/models/transformer_lm.py` + +- Criterion: `fairseq/criterions/adaptive_loss.py` and `fairseq/criterions/cross_entropy.py` + +- Optimizer: `fairseq/optim/adam.py` and `fairseq/optim/nag.py` + +- LR scheduler: `fairseq/optim/lr_scheduler/cosine_lr_scheduler.py` and `fairseq/optim/lr_scheduler/inverse_square_root_schedule.py` + + +## Interpolate parameters across different places + +## Support of legacy interface +If you would still like to pass legacy-style arguments on the command line, `fairseq_cli/train.py` supports this. Internally it converts `args` into hydra config objects wherever the corresponding modules have been migrated. + +``` +python fairseq_cli/train.py --task language_modeling \ +/private/home/abaevski/data/wiki103 \ +--save-dir /checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ +--arch transformer_lm --share-decoder-input-output-embed \ +--dropout 0.1 \ +--optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \ +--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ +--tokens-per-sample 512 --sample-break-mode none \ +--max-tokens 1024 --update-freq 16 \ +--fp16 \ +--max-update 50000 --log-format json --log-interval 1 --num-workers 4 \ +--save-interval-updates 10 +``` diff --git a/examples/roberta/commonsense_qa/commonsense_qa_task.py b/examples/roberta/commonsense_qa/commonsense_qa_task.py index 274e8d39aa..7ed2bc36a4 100644 --- a/examples/roberta/commonsense_qa/commonsense_qa_task.py +++ b/examples/roberta/commonsense_qa/commonsense_qa_task.py @@ -22,11 +22,11 @@ RightPadDataset, SortDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask @register_task('commonsense_qa') -class CommonsenseQATask(FairseqTask):
+class CommonsenseQATask(LegacyFairseqTask): """Task to finetune RoBERTa for Commonsense QA.""" @staticmethod diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py index fbba0d8964..058e3eea23 100644 --- a/examples/roberta/wsc/wsc_task.py +++ b/examples/roberta/wsc/wsc_task.py @@ -24,13 +24,13 @@ PadDataset, SortDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from . import wsc_utils @register_task('wsc') -class WSCTask(FairseqTask): +class WSCTask(LegacyFairseqTask): """Task to finetune RoBERTa for Winograd Schemas.""" @staticmethod diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py index e5717c0ef8..dde0b12577 100644 --- a/examples/speech_recognition/tasks/speech_recognition.py +++ b/examples/speech_recognition/tasks/speech_recognition.py @@ -10,7 +10,7 @@ import torch from fairseq.data import Dictionary -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from examples.speech_recognition.data import AsrDataset from examples.speech_recognition.data.replabels import replabel_symbol @@ -66,7 +66,7 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict): @register_task("speech_recognition") -class SpeechRecognitionTask(FairseqTask): +class SpeechRecognitionTask(LegacyFairseqTask): """ Task for training speech recognition model. 
""" diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index 92e9dc8df5..f33a1adcf6 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -9,14 +9,14 @@ import torch from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @register_task('dummy_lm') -class DummyLMTask(FairseqTask): +class DummyLMTask(LegacyFairseqTask): @staticmethod def add_args(parser): diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index f2e459caa2..3b0bdc51f5 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -9,14 +9,14 @@ import torch from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @register_task('dummy_masked_lm') -class DummyMaskedLMTask(FairseqTask): +class DummyMaskedLMTask(LegacyFairseqTask): @staticmethod def add_args(parser): diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py index 9fba9bb520..0371b3e754 100644 --- a/fairseq/benchmark/dummy_mt.py +++ b/fairseq/benchmark/dummy_mt.py @@ -9,14 +9,14 @@ import torch from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @register_task('dummy_mt') -class DummyMTTask(FairseqTask): +class DummyMTTask(LegacyFairseqTask): @staticmethod def add_args(parser): diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index dff140d580..b172b270a7 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -7,7 +7,7 @@ import os from fairseq import registry -from fairseq.optim.fairseq_optimizer import FairseqOptimizer +from 
fairseq.optim.fairseq_optimizer import FairseqOptimizer, LegacyFairseqOptimizer # noqa from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer from fairseq.optim.bmuf import FairseqBMUF # noqa from fairseq.optim.shard import shard_ diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py index 0a76e27fe4..9b311ae38a 100644 --- a/fairseq/optim/adadelta.py +++ b/fairseq/optim/adadelta.py @@ -5,11 +5,11 @@ import torch.optim -from . import FairseqOptimizer, register_optimizer +from . import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adadelta') -class Adadelta(FairseqOptimizer): +class Adadelta(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config) diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index b0fb3a9f5e..ab69e0e58d 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -7,11 +7,11 @@ import torch import torch.optim -from . import FairseqOptimizer, register_optimizer +from . import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adafactor') -class FairseqAdafactor(FairseqOptimizer): +class FairseqAdafactor(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = Adafactor(params, **self.optimizer_config) diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py index 57f83258cf..5056752776 100644 --- a/fairseq/optim/adagrad.py +++ b/fairseq/optim/adagrad.py @@ -5,11 +5,11 @@ import torch.optim -from . import FairseqOptimizer, register_optimizer +from . 
import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adagrad') -class Adagrad(FairseqOptimizer): +class Adagrad(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py index 856215a3ba..195e7a90d8 100644 --- a/fairseq/optim/adamax.py +++ b/fairseq/optim/adamax.py @@ -6,11 +6,11 @@ import torch import torch.optim -from . import FairseqOptimizer, register_optimizer +from . import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adamax') -class FairseqAdamax(FairseqOptimizer): +class FairseqAdamax(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = Adamax(params, **self.optimizer_config) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index e00a04dd1b..18c26a3a39 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -140,3 +140,9 @@ def supports_flat_params(self): def average_params(self): pass + + +class LegacyFairseqOptimizer(FairseqOptimizer): + + def __init__(self, args): + self.args = args diff --git a/fairseq/optim/fused_lamb.py b/fairseq/optim/fused_lamb.py index f9b0409c53..d48ecbc8e0 100644 --- a/fairseq/optim/fused_lamb.py +++ b/fairseq/optim/fused_lamb.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from fairseq.optim import FairseqOptimizer, register_optimizer +from fairseq.optim import register_optimizer, LegacyFairseqOptimizer @register_optimizer('lamb') -class FairseqLAMB(FairseqOptimizer): +class FairseqLAMB(LegacyFairseqOptimizer): """LAMB optimizer.""" def __init__(self, args, params): diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index edd0a6a13e..76c5357189 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -7,7 +7,7 @@ import os from fairseq import registry -from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler +from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler, LegacyFairseqLRScheduler # noqa build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index 8b7884829a..5569de3db8 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .. 
import FairseqOptimizer +from argparse import Namespace class FairseqLRScheduler(object): @@ -40,3 +41,13 @@ def step(self, epoch, val_loss=None): def step_update(self, num_updates): """Update the learning rate after each update.""" return self.optimizer.get_lr() + + +class LegacyFairseqLRScheduler(FairseqLRScheduler): + + def __init__(self, args: Namespace, optimizer): + if not isinstance(optimizer, FairseqOptimizer): + raise ValueError('optimizer must be an instance of FairseqOptimizer') + self.args = args + self.optimizer = optimizer + self.best = None diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index 1c3edd0047..9a30195fab 100644 --- a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import FairseqLRScheduler, register_lr_scheduler +from . import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('fixed') -class FixedSchedule(FairseqLRScheduler): +class FixedSchedule(LegacyFairseqLRScheduler): """Decay the LR on a fixed schedule.""" def __init__(self, args, optimizer): diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index aff57f9b93..73e8b170bc 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import FairseqLRScheduler, register_lr_scheduler +from . 
import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('polynomial_decay') -class PolynomialDecaySchedule(FairseqLRScheduler): +class PolynomialDecaySchedule(LegacyFairseqLRScheduler): """Decay the LR on a fixed schedule.""" def __init__(self, args, optimizer): diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 65ac2e3071..5199b09a3e 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -5,11 +5,11 @@ import torch.optim.lr_scheduler -from . import FairseqLRScheduler, register_lr_scheduler +from . import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('reduce_lr_on_plateau') -class ReduceLROnPlateau(FairseqLRScheduler): +class ReduceLROnPlateau(LegacyFairseqLRScheduler): """ Decay the LR by a factor every time the validation loss plateaus. Also comes with optional warmup phase, where we linearly increase diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index 3460fa1226..95c5576f20 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -3,12 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import FairseqLRScheduler, register_lr_scheduler +from . 
import register_lr_scheduler, LegacyFairseqLRScheduler import math @register_lr_scheduler('tri_stage') -class TriStageLRSchedule(FairseqLRScheduler): +class TriStageLRSchedule(LegacyFairseqLRScheduler): """Tristage learning rate schedulr Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py index fed0cf7ef1..67e1df65e1 100644 --- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -5,11 +5,11 @@ import math -from . import FairseqLRScheduler, register_lr_scheduler +from . import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('triangular') -class TriangularSchedule(FairseqLRScheduler): +class TriangularSchedule(LegacyFairseqLRScheduler): """Assign LR based on a triangular cyclical schedule. See https://arxiv.org/pdf/1506.01186.pdf for details. diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py index 8c4e3e0a80..b558f41ab0 100644 --- a/fairseq/optim/sgd.py +++ b/fairseq/optim/sgd.py @@ -5,11 +5,11 @@ import torch.optim -from . import FairseqOptimizer, register_optimizer +from . 
import register_optimizer, LegacyFairseqOptimizer @register_optimizer('sgd') -class SGD(FairseqOptimizer): +class SGD(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = torch.optim.SGD(params, **self.optimizer_config) diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index b1bb404f1c..69231a8522 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -7,7 +7,7 @@ import importlib import os -from .fairseq_task import FairseqTask +from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa TASK_REGISTRY = {} TASK_CLASS_NAMES = set() diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 2a51279ebc..75bcfaa8db 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -9,7 +9,7 @@ import sys from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset -from . import FairseqTask, register_task +from . import LegacyFairseqTask, register_task class LabelEncoder(object): @@ -23,7 +23,7 @@ def __call__(self, label): @register_task("audio_pretraining") -class AudioPretrainingTask(FairseqTask): +class AudioPretrainingTask(LegacyFairseqTask): """ """ diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py index 3589492f11..a7ce1f1ad5 100644 --- a/fairseq/tasks/cross_lingual_lm.py +++ b/fairseq/tasks/cross_lingual_lm.py @@ -21,14 +21,14 @@ ) from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils logger = logging.getLogger(__name__) @register_task('cross_lingual_lm') -class CrossLingualLMTask(FairseqTask): +class CrossLingualLMTask(LegacyFairseqTask): """ Task for training cross-lingual language models. 
diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index 28beb517f2..ea6db45c75 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -16,7 +16,7 @@ TokenBlockDataset, ) from fairseq.data.encoders.utils import get_whole_word_mask -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils @@ -24,7 +24,7 @@ @register_task('denoising') -class DenoisingTask(FairseqTask): +class DenoisingTask(LegacyFairseqTask): """ Denoising task for applying sequence to sequence denoising. (ie. BART) """ diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index c7b39f5b62..8da07bf8bb 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -12,6 +12,7 @@ from fairseq import metrics, search, tokenizer, utils from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary +from argparse import Namespace logger = logging.getLogger(__name__) @@ -486,3 +487,56 @@ def target_dictionary(self): """Return the target :class:`~fairseq.data.Dictionary` (if applicable for this task).""" raise NotImplementedError + + +class LegacyFairseqTask(FairseqTask): + + def __init__(self, args: Namespace): + self.args = args + self.datasets = {} + self.dataset_to_epoch_iter = {} + + @classmethod + def setup_task(cls, args: Namespace, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + return cls(args, **kwargs) + + def has_sharded_data(self, split): + return (os.pathsep in getattr(self.args, 'data', '')) + + def build_model(self, args: Namespace): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + model = models.build_model(args, self) + if getattr(args, 'tpu', False): + model.prepare_for_tpu_() + model = quantization_utils.quantize_model_scalar(model, args) + return model + + def build_criterion(self, args: Namespace): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(args, self) diff --git a/fairseq/tasks/legacy_masked_lm.py b/fairseq/tasks/legacy_masked_lm.py index 40e2724953..4e0390cdca 100644 --- a/fairseq/tasks/legacy_masked_lm.py +++ b/fairseq/tasks/legacy_masked_lm.py @@ -20,7 +20,7 @@ from fairseq.data.legacy.block_pair_dataset import BlockPairDataset from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset from fairseq.data.legacy.masked_lm_dictionary import BertDictionary -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils @@ -28,7 +28,7 @@ @register_task('legacy_masked_lm') -class LegacyMaskedLMTask(FairseqTask): +class LegacyMaskedLMTask(LegacyFairseqTask): """ Task for training Masked LM (BERT) model. 
Args: diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 4a6e6a2d37..10b234a96b 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -21,8 +21,8 @@ SortDataset, TokenBlockDataset, ) +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import FairseqTask, register_task from fairseq.data.encoders.utils import get_whole_word_mask from fairseq import utils @@ -31,7 +31,7 @@ @register_task('masked_lm') -class MaskedLMTask(FairseqTask): +class MaskedLMTask(LegacyFairseqTask): """Task for training masked language models (e.g., BERT, RoBERTa).""" @staticmethod diff --git a/fairseq/tasks/multilingual_masked_lm.py b/fairseq/tasks/multilingual_masked_lm.py index 5d96a608b5..110e580a73 100644 --- a/fairseq/tasks/multilingual_masked_lm.py +++ b/fairseq/tasks/multilingual_masked_lm.py @@ -26,7 +26,7 @@ SortDataset, TokenBlockDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils @@ -34,7 +34,7 @@ @register_task('multilingual_masked_lm') -class MultiLingualMaskedLMTask(FairseqTask): +class MultiLingualMaskedLMTask(LegacyFairseqTask): """Task for training masked language models (e.g., BERT, RoBERTa).""" @staticmethod diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 272bcf1ae1..784b438ca9 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -6,11 +6,11 @@ from collections import OrderedDict import logging import os - +from fairseq import options import contextlib import torch -from fairseq import metrics, options +from fairseq import metrics, utils from fairseq.data import ( Dictionary, LanguagePairDataset, @@ -20,8 +20,7 @@ from fairseq.models import FairseqMultiModel from fairseq.tasks.translation import load_langpair_dataset -from . 
import FairseqTask, register_task -from fairseq import utils +from . import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @@ -39,7 +38,7 @@ def _lang_token_index(dic: Dictionary, lang: str): @register_task('multilingual_translation') -class MultilingualTranslationTask(FairseqTask): +class MultilingualTranslationTask(LegacyFairseqTask): """A task for training multiple translation models simultaneously. We iterate round-robin over batches from multiple language pairs, ordered diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index cf5eae38b1..fec19e0a75 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -25,15 +25,15 @@ SortDataset, StripTokenDataset, ) +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import FairseqTask, register_task logger = logging.getLogger(__name__) @register_task('sentence_prediction') -class SentencePredictionTask(FairseqTask): +class SentencePredictionTask(LegacyFairseqTask): """ Sentence (or sentence pair) prediction (classification or regression) task. diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py index ea4b50a294..a1d332a3ca 100644 --- a/fairseq/tasks/sentence_ranking.py +++ b/fairseq/tasks/sentence_ranking.py @@ -23,15 +23,15 @@ SortDataset, TruncateDataset ) +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import FairseqTask, register_task logger = logging.getLogger(__name__) @register_task('sentence_ranking') -class SentenceRankingTask(FairseqTask): +class SentenceRankingTask(LegacyFairseqTask): """ Ranking task on multiple sentences. 
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index a01768ecb6..6eac293659 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -8,10 +8,10 @@ import itertools import logging import os - +from fairseq import options import numpy as np -from fairseq import metrics, options, utils +from fairseq import metrics, utils from fairseq.data import ( AppendTokenDataset, ConcatDataset, @@ -24,7 +24,7 @@ TruncateDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask EVAL_BLEU_ORDER = 4 @@ -133,7 +133,7 @@ def split_exists(split, src, tgt, lang, data_path): @register_task('translation') -class TranslationTask(FairseqTask): +class TranslationTask(LegacyFairseqTask): """ Translate from one (source) language to another (target) language. diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index e13c9fd88b..94f1fd32af 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -16,7 +16,7 @@ ListDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.multilingual.sampling_method import SamplingMethod from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager @@ -31,7 +31,7 @@ def get_time_gap(s, e): @register_task('translation_multi_simple_epoch') -class TranslationMultiSimpleEpochTask(FairseqTask): +class TranslationMultiSimpleEpochTask(LegacyFairseqTask): """ Translate from one (source) language to another (target) language. 
diff --git a/tests/speech_recognition/asr_test_base.py b/tests/speech_recognition/asr_test_base.py index 7482858ffc..4f3d3fceb7 100644 --- a/tests/speech_recognition/asr_test_base.py +++ b/tests/speech_recognition/asr_test_base.py @@ -17,7 +17,7 @@ FairseqEncoderModel, FairseqModel, ) -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask @@ -37,7 +37,7 @@ def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE): return dummy_dict -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git a/tests/test_export.py b/tests/test_export.py index 7b0e7fcf1d..87e52bd7c1 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -12,13 +12,13 @@ from fairseq.data.dictionary import Dictionary from fairseq.models.transformer import TransformerModel from fairseq.modules import multihead_attention, sinusoidal_positional_embedding -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git a/tests/test_lstm_jitable.py b/tests/test_lstm_jitable.py index d0d812ceac..d97652fb77 100644 --- a/tests/test_lstm_jitable.py +++ b/tests/test_lstm_jitable.py @@ -10,13 +10,13 @@ import torch from fairseq.data.dictionary import Dictionary from fairseq.models.lstm import LSTMModel -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git 
a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py index 36560bcca6..517aa77d59 100644 --- a/tests/test_sequence_generator.py +++ b/tests/test_sequence_generator.py @@ -14,13 +14,13 @@ from fairseq.models.transformer import TransformerModel from fairseq.sequence_generator import SequenceGenerator, EnsembleModel -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git a/tests/utils.py b/tests/utils.py index 869a70c5e9..ef546fa58a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,6 +20,7 @@ FairseqIncrementalDecoder, ) from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.tasks import LegacyFairseqTask from fairseq.tasks import FairseqTask from fairseq_cli import ( generate, @@ -284,7 +285,7 @@ def __len__(self): return len(self.data) -class TestTranslationTask(FairseqTask): +class TestTranslationTask(LegacyFairseqTask): def __init__(self, args, src_dict, tgt_dict, model): super().__init__(args) From d47067937abacfe87f2963adca8daeada3c631fe Mon Sep 17 00:00:00 2001 From: Shruti Bhosale Date: Thu, 10 Sep 2020 21:20:51 -0700 Subject: [PATCH 149/707] Pipeline parallel transformer in fairseq (#1257) Summary: ## What does this PR do? 
* Add pipeline parallelism (inter-layer model parallelism) in fairseq for Transformer models * This involves creating a modular version of the Transformer model that can be expressed as an `nn.Sequential` * This uses `fairscale`'s `Pipe` for pipeline parallel execution * Some other relevant code changes include - * updating fp16_optimizer buffers to be per-device so that we can scale to a model size that does not fit on a single GPU * ability to convert the state dict of a regular Transformer checkpoint into a pipeline parallel Transformer ## Testing ### Regular Transformer ``` 2020-08-25 20:42:40 | INFO | train_inner | epoch 001: 105 / 884 loss=10.123, nll_loss=9.657, ppl=807.41, wps=30505.2, ups=6.77, wpb=4503.2, bsz=173.9, num_updates=100, lr=0.00012, gnorm=2.382, loss_scale=4, train_wall=14, wall=0 2020-08-25 20:42:53 | INFO | train_inner | epoch 001: 205 / 884 loss=9.8, nll_loss=9.305, ppl=632.39, wps=34882.4, ups=7.66, wpb=4553, bsz=194.6, num_updates=200, lr=8.48528e-05, gnorm=1.29, loss_scale=4, train_wall=13, wall=0 2020-08-25 20:43:06 | INFO | train_inner | epoch 001: 305 / 884 loss=9.91, nll_loss=9.434, ppl=691.51, wps=33456.9, ups=7.62, wpb=4391.7, bsz=148.7, num_updates=300, lr=6.9282e-05, gnorm=1.117, loss_scale=4, train_wall=13, wall=0 2020-08-25 20:43:19 | INFO | train_inner | epoch 001: 405 / 884 loss=9.81, nll_loss=9.319, ppl=638.75, wps=34137.3, ups=7.66, wpb=4457.2, bsz=191.2, num_updates=400, lr=6e-05, gnorm=1.347, loss_scale=4, train_wall=13, wall=0 ``` ### 1xMP (single-gpu) Pipeline Parallel Transformer ``` 2020-08-25 20:44:16 | INFO | train_inner | epoch 001: 105 / 884 loss=10.123, nll_loss=9.657, ppl=807.41, wps=31227.6, ups=6.93, wpb=4503.2, bsz=173.9, num_updates=100, lr=0.00012, gnorm=2.382, loss_scale=4, train_wall=13, wall=0 2020-08-25 20:44:29 | INFO | train_inner | epoch 001: 205 / 884 loss=9.8, nll_loss=9.305, ppl=632.39, wps=35378.9, ups=7.77, wpb=4553, bsz=194.6, num_updates=200, lr=8.48528e-05, gnorm=1.29, loss_scale=4, 
train_wall=13, wall=0 2020-08-25 20:44:42 | INFO | train_inner | epoch 001: 305 / 884 loss=9.91, nll_loss=9.434, ppl=691.51, wps=34017.8, ups=7.75, wpb=4391.7, bsz=148.7, num_updates=300, lr=6.9282e-05, gnorm=1.117, loss_scale=4, train_wall=13, wall=0 2020-08-25 20:44:55 | INFO | train_inner | epoch 001: 405 / 884 loss=9.81, nll_loss=9.319, ppl=638.75, wps=34661.1, ``` ### 2xMP Pipeline Parallel Transformer ``` 2020-08-26 12:10:13 | INFO | train_inner | epoch 001: 105 / 884 loss=10.185, nll_loss=9.728, ppl=848.2, wps=29247.7, ups=6.5, wpb=4488.8, bsz=172.4, num_updates=100, lr=0.00012, gnorm=2.539, loss_scale=4, train_wall=15, wall=17 2020-08-26 12:10:28 | INFO | train_inner | epoch 001: 205 / 884 loss=9.798, nll_loss=9.303, ppl=631.54, wps=30663.7, ups=6.73, wpb=4553, bsz=194.6, num_updates=200, lr=8.48528e-05, gnorm=1.263, loss_scale=4, train_wall=15, wall=32 2020-08-26 12:10:42 | INFO | train_inner | epoch 001: 305 / 884 loss=9.908, nll_loss=9.432, ppl=690.89, wps=30600.4, ups=6.97, wpb=4391.7, bsz=148.7, num_updates=300, lr=6.9282e-05, gnorm=1.091, loss_scale=4, train_wall=14, wall=46 2020-08-26 12:10:57 | INFO | train_inner | epoch 001: 405 / 884 loss=9.81, nll_loss=9.319, ppl=638.57, wps=29345.2, ups=6.58, wpb=4457.2, bsz=191.2, num_updates=400, lr=6e-05, gnorm=1.344, loss_scale=4, train_wall=15, wall=62 ``` ### 4xMP Pipeline Parallel Transformer ``` 2020-08-26 13:27:25 | INFO | train_inner | epoch 001: 105 / 884 loss=10.185, nll_loss=9.728, ppl=848.1, wps=11158.8, ups=2.48, wpb=4488.8, bsz=172.4, num_updates=100, lr=0.00012, gnorm=2.538, loss_scale=4, train_wall=41, wall=44 2020-08-26 13:28:03 | INFO | train_inner | epoch 001: 205 / 884 loss=9.798, nll_loss=9.303, ppl=631.51, wps=12078, ups=2.65, wpb=4553, bsz=194.6, num_updates=200, lr=8.48528e-05, gnorm=1.263, loss_scale=4, train_wall=38, wall=82 2020-08-26 13:28:40 | INFO | train_inner | epoch 001: 305 / 884 loss=9.908, nll_loss=9.432, ppl=690.88, wps=11738.5, ups=2.67, wpb=4391.7, bsz=148.7, 
num_updates=300, lr=6.9282e-05, gnorm=1.091, loss_scale=4, train_wall=37, wall=119 2020-08-26 13:29:18 | INFO | train_inner | epoch 001: 405 / 884 loss=9.81, nll_loss=9.319, ppl=638.58, wps=11810.4, ups=2.65, wpb=4457.2, bsz=191.2, num_updates=400, lr=6e-05, gnorm=1.344, loss_scale=4, train_wall=38, wall=157 ``` ### 8xMP Pipeline Parallel Transformer ``` 2020-08-26 13:37:57 | INFO | train_inner | epoch 001: 105 / 884 loss=10.185, nll_loss=9.728, ppl=848.13, wps=5077.9, ups=1.13, wpb=4488.8, bsz=172.4, num_updates=100, lr=0.00012, gnorm=2.539, loss_scale=4, train_wall=90, wall=96 2020-08-26 13:39:20 | INFO | train_inner | epoch 001: 205 / 884 loss=9.798, nll_loss=9.303, ppl=631.53, wps=5472.2, ups=1.2, wpb=4553, bsz=194.6, num_updates=200, lr=8.48528e-05, gnorm=1.263, loss_scale=4, train_wall=83, wall=179 2020-08-26 13:40:43 | INFO | train_inner | epoch 001: 305 / 884 loss=9.908, nll_loss=9.432, ppl=690.88, wps=5313.2, ups=1.21, wpb=4391.7, bsz=148.7, num_updates=300, lr=6.9282e-05, gnorm=1.091, loss_scale=4, train_wall=83, wall=262 2020-08-26 13:42:05 | INFO | train_inner | epoch 001: 405 / 884 loss=9.81, nll_loss=9.319, ppl=638.57, wps=5409.3, ups=1.21, wpb=4457.2, bsz=191.2, num_updates=400, lr=6e-05, gnorm=1.344, loss_scale=4, train_wall=82, wall=344 ``` ### 8xMP Pipeline Parallel Transformer + Checkpointing ``` 2020-08-26 13:56:58 | INFO | train_inner | epoch 001: 105 / 884 loss=10.185, nll_loss=9.728, ppl=848.1, wps=3908, ups=0.87, wpb=4488.8, bsz=172.4, num_updates=100, lr=0.00012, gnorm=2.539, loss_scale=4, train_wall=115, wall=120 2020-08-26 13:58:48 | INFO | train_inner | epoch 001: 205 / 884 loss=9.798, nll_loss=9.303, ppl=631.52, wps=4152.5, ups=0.91, wpb=4553, bsz=194.6, num_updates=200, lr=8.48528e-05, gnorm=1.263, loss_scale=4, train_wall=110, wall=230 2020-08-26 14:00:37 | INFO | train_inner | epoch 001: 305 / 884 loss=9.908, nll_loss=9.432, ppl=690.88, wps=4026.4, ups=0.92, wpb=4391.7, bsz=148.7, num_updates=300, lr=6.9282e-05, gnorm=1.091, 
loss_scale=4, train_wall=109, wall=339 2020-08-26 14:02:26 | INFO | train_inner | epoch 001: 405 / 884 loss=9.81, nll_loss=9.319, ppl=638.56, wps=4101.6, ups=0.92, wpb=4457.2, bsz=191.2, num_updates=400, lr=6e-05, gnorm=1.343, loss_scale=4, train_wall=109, wall=448 ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1257 Reviewed By: myleott Differential Revision: D23593137 Pulled By: shruti-bh fbshipit-source-id: 840227be6d8bec438e8360b30fe1fdbc4d97dd9c --- fairseq/distributed_utils.py | 79 ++- .../pipeline_parallel_transformer/__init__.py | 6 + .../pipeline_parallel_transformer/layers.py | 547 +++++++++++++++++ .../pipeline_parallel_transformer/model.py | 572 ++++++++++++++++++ fairseq/optim/__init__.py | 9 +- fairseq/optim/fp16_optimizer.py | 83 ++- fairseq/options.py | 21 + fairseq/trainer.py | 15 +- fairseq/utils.py | 12 +- 9 files changed, 1305 insertions(+), 39 deletions(-) create mode 100644 fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py create mode 100644 fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py create mode 100644 fairseq/model_parallel/models/pipeline_parallel_transformer/model.py diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 16f3edaeef..b25232f386 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -31,6 +31,19 @@ def infer_init_method(args, force_distributed=False): if args.distributed_init_method is not None or getattr(args, 'tpu', False): return + if args.pipeline_model_parallel: + if args.pipeline_balance is None: + raise ValueError('--pipeline-balance is currently required for pipeline model parallelism') + if args.pipeline_devices is None: + raise ValueError('--pipeline-devices is currently required for pipeline model parallelism') + gpus_per_node = torch.cuda.device_count() + num_pipeline_devices = len(set(args.pipeline_devices)) + assert gpus_per_node >= num_pipeline_devices and gpus_per_node % 
num_pipeline_devices == 0, ( + 'the number of unique device IDs in --pipeline-devices must evenly divide ' + 'the number of GPUs per node (multi-node pipelining is not yet supported)' + ) + num_pipelines_per_node = gpus_per_node // num_pipeline_devices + # support torch.distributed.launch if all(key in os.environ for key in [ 'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK' @@ -63,10 +76,28 @@ def infer_init_method(args, force_distributed=False): assert ntasks % nnodes == 0 ntasks_per_node = int(ntasks / nnodes) if ntasks_per_node == 1: - assert args.distributed_world_size % nnodes == 0 - gpus_per_node = args.distributed_world_size // nnodes + gpus_per_node = torch.cuda.device_count() node_id = int(os.environ.get('SLURM_NODEID')) args.distributed_rank = node_id * gpus_per_node + args.distributed_world_size = nnodes * gpus_per_node + elif args.pipeline_model_parallel: + assert ntasks_per_node == num_pipelines_per_node, ( + 'SLURM --ntasks-per-node must match number of pipelines per ' + 'node (={})'.format(num_pipelines_per_node) + ) + args.distributed_no_spawn = True + # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on + # the first node, [2, 3] on the second node, etc. This + # matches torch.distributed.launch. + node_id = int(os.environ.get('SLURM_NODEID')) + local_id = int(os.environ.get('SLURM_LOCALID')) + args.distributed_rank = node_id * num_pipelines_per_node + local_id + # In the above example, device_id will always be in [0, 1], + # which also matches torch.distributed.launch. + args.device_id = local_id + # We also want to set distributed_world_size to be the total + # number of pipelines across all nodes. 
+ args.distributed_world_size = nnodes * num_pipelines_per_node else: assert ntasks_per_node == args.distributed_world_size // nnodes args.distributed_no_spawn = True @@ -83,6 +114,45 @@ def infer_init_method(args, force_distributed=False): port = random.randint(10000, 20000) args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + if args.pipeline_model_parallel: + if not args.distributed_no_spawn: + # When distributed_no_spawn is False, we expect distributed_rank and + # distributed_world_size to be based on the total number of GPUs, so + # we need to correct them to be based on the number of pipelines. + assert args.distributed_world_size % num_pipeline_devices == 0 + args.distributed_world_size = args.distributed_world_size // num_pipeline_devices + # In the case of 4-way MP on nodes with 8 GPUs, we want + # distributed_rank to be the starting GPU index for each pipeline + # i.e., 0, 2, ... + assert args.distributed_rank % gpus_per_node == 0 + assert args.distributed_rank % num_pipeline_devices == 0 + args.distributed_rank = args.distributed_rank // num_pipeline_devices + # launch one process per pipeline + args.distributed_num_procs = num_pipelines_per_node + + # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 + # and 4, indicating the starting device IDs for each pipeline + args.device_id *= num_pipeline_devices + + if args.device_id > 0: + # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 + # GPU node), we need to adjust pipeline_devices accordingly + logger.debug( + "setting CUDA device={} on rank {}" + .format(args.device_id, args.distributed_rank) + ) + torch.cuda.set_device(args.device_id) + args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices] + logger.info( + "setting pipeline_devices={} on rank {}" + .format(args.pipeline_devices, args.distributed_rank), + ) + elif not args.distributed_no_spawn: + args.distributed_num_procs = min( + torch.cuda.device_count(), + 
args.distributed_world_size, + ) + def distributed_init(args): if not getattr(args, 'tpu', False): @@ -167,10 +237,7 @@ def call_main(args, main, **kwargs): torch.multiprocessing.spawn( fn=distributed_main, args=(main, args, kwargs), - nprocs=min( - torch.cuda.device_count(), - args.distributed_world_size, - ), + nprocs=args.distributed_num_procs, ) else: distributed_main(args.device_id, main, args, kwargs) diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py new file mode 100644 index 0000000000..117827c3e9 --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py new file mode 100644 index 0000000000..70551ca900 --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py @@ -0,0 +1,547 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections import namedtuple +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils +from fairseq.modules import ( + AdaptiveSoftmax, + LayerNorm, + PositionalEmbedding, + MultiheadAttention, +) + +EncoderOut = namedtuple('TransformerEncoderOut', [ + 'encoder_out', # T x B x C + 'encoder_padding_mask', # B x T + 'encoder_embedding', # B x T x C + 'encoder_states', # List[T x B x C] +]) + + +class TransformerEncoderEmbedding(nn.Module): + """ Encoder Embedding + Positional Embedding """ + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.max_source_positions = args.max_source_positions + self.embed_tokens = embed_tokens + if isinstance(embed_tokens, nn.ModuleList): + self.padding_idx = embed_tokens[0].padding_idx + embed_dim = sum(e.embedding_dim for e in embed_tokens) + else: + self.padding_idx = embed_tokens.padding_idx + embed_dim = embed_tokens.embedding_dim + self.embed_scale = math.sqrt(embed_dim) + self.embed_positions = PositionalEmbedding( + args.max_source_positions, embed_dim, self.padding_idx, + learned=args.encoder_learned_pos, + ) if not args.no_token_positional_embeddings else None + if getattr(args, 'layernorm_embedding', False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + def forward(self, input): + # embed tokens and positions + src_tokens = input[0] + prev_output_tokens = input[2] + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(src_tokens)) + + embedded = torch.cat(x_embed_list, dim=-1) + else: + embedded = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * embedded + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding: + x = self.layernorm_embedding(x) + x = F.dropout(x, p=self.dropout, 
training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerEncoderLayerNorm(nn.Module): + """ + Layer norm at the end of all encoder layers if + args.encoder_normalize_before = True + """ + def __init__(self, args, embed_dim): + super().__init__() + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input): + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + if self.layer_norm: + x = self.layer_norm(x) + # keeping track of the incremental_state is not supported yet + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerDecoderEmbedding(nn.Module): + """ Decoder Embedding + Positional Embedding """ + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.share_input_output_embed = args.share_decoder_input_output_embed + input_embed_dim = sum(e.embedding_dim for e in embed_tokens) \ + if isinstance(embed_tokens, nn.ModuleList) \ + else embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + self.output_embed_dim = args.decoder_output_dim + + padding_idx = embed_tokens[0].padding_idx \ + if isinstance(embed_tokens, nn.ModuleList) \ + else embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None + + self.embed_positions = PositionalEmbedding( + args.max_target_positions, embed_dim, padding_idx, + learned=args.decoder_learned_pos, + ) if not args.no_token_positional_embeddings else None + + def forward(self, input): + mt_task = False + if isinstance(input, tuple): 
+ if len(input) == 3: + encoder_out = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + incremental_state = None # Hardcoding to avoid passing of None objects + mt_task = True + else: + # HACK for now, need to fix (TODO sidgoyal) + prev_output_tokens = input[0] + # discard "src_lengths" + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + else: + prev_output_tokens = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + positions = self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) if self.embed_positions is not None else None + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(prev_output_tokens)) + + x = self.embed_scale * torch.cat(x_embed_list, dim=-1) + else: + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + +class TransformerDecoderOutputLayer(nn.Module): + def __init__(self, args, embed_tokens, dictionary): + super().__init__() + self.share_input_output_embed = args.share_decoder_input_output_embed + self.output_embed_dim = args.decoder_output_dim + embed_dim = args.decoder_embed_dim + + self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \ + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None + self.adaptive_softmax = None + if args.adaptive_softmax_cutoff is not None: + assert not 
isinstance(embed_tokens, nn.ModuleList) + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + options.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif self.share_input_output_embed: + self.output_projection = nn.Linear( + embed_tokens.weight.shape[1], + embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = embed_tokens.weight + elif not self.share_input_output_embed: + self.output_projection = nn.Linear( + self.output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 + ) + + if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input, apply_final_proj=True): + if isinstance(input, tuple): + x = input[0] + else: + x = input + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + if apply_final_proj: + x = self.output_layer(x) + return x + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + return self.output_projection(features) + else: + return features + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer block. + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. 
We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.encoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, args): + super().__init__() + self.embed_dim = args.encoder_embed_dim + self.self_attn = MultiheadAttention( + self.embed_dim, args.encoder_attention_heads, + dropout=args.attention_dropout, self_attention=True + ) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, 'activation_fn', 'relu') + ) + self.activation_dropout = getattr(args, 'activation_dropout', 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, 'relu_dropout', 0) + self.normalize_before = args.encoder_normalize_before + self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) + self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def upgrade_state_dict_named(self, state_dict, name): + """ + Rename layer norm states from `...layer_norms.0.weight` to + `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to + `...final_layer_norm.weight` + """ + layer_norm_map = { + '0': 'self_attn_layer_norm', + '1': 'final_layer_norm' + } + for old, new in layer_norm_map.items(): + for m in ('weight', 'bias'): + k = '{}.layer_norms.{}.{}'.format(name, old, m) + if k in state_dict: + state_dict[ + '{}.{}.{}'.format(name, new, m) + ] = state_dict[k] + del state_dict[k] + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. 
+ input[2] (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)` (for teacher forcing) + Returns: + output (Tuple): + output[0] (Tensor): encoded output of shape `(seq_len, batch, embed_dim)` + output[1] (ByteTensor/FloatTensor): encoder padding mask + output[2] (LongTensor): previous decoder outputs + """ + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + return (x, encoder_padding_mask, prev_output_tokens) + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + +class TransformerDecoderLayer(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.decoder_normalize_before* to ``True``. 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.self_attn = MultiheadAttention( + embed_dim=self.embed_dim, + num_heads=args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=True + ) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, 'activation_fn', 'relu') + ) + self.activation_dropout = getattr(args, 'activation_dropout', 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, 'relu_dropout', 0) + self.normalize_before = args.decoder_normalize_before + + # use layerNorm rather than FusedLayerNorm for exporting. + # char_inputs can be used to determine this. 
+ # TODO remove this once we update apex with the fix + export = getattr(args, 'char_inputs', False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = MultiheadAttention( + self.embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, 'encoder_embed_dim', None), + vdim=getattr(args, 'encoder_embed_dim', None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) + self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)` + input[2] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. 
+ Returns: + output (Tuple): + output[0] (Tensor): decoder layer output of shape `(seq_len, batch, embed_dim)` + output[1] (Tensor): encoder output, passed through unchanged + output[2] (ByteTensor/FloatTensor): encoder padding mask, passed through unchanged + """ + # Note: incremental state is not yet supported + mt_task = False + if isinstance(input, tuple): + x = input[0] + encoder_out = input[1] + encoder_padding_mask = input[2] + incremental_state = None + mt_task = True + else: + x = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + # TODO: add back prev_self_attn_state, prev_attn_state, + # self_attn_padding_mask + prev_self_attn_state = None + prev_attn_state = None + self_attn_padding_mask = None + + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + if prev_self_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_self_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.self_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + if self.encoder_attn is not None: + residual = x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) + if prev_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.encoder_attn( + query=x, + key=encoder_out, 
value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=(not self.training and self.need_attn), + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: + self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) + return self._future_mask[:dim, :dim] + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.) 
+ return m diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py new file mode 100644 index 0000000000..ca1c2698fb --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -0,0 +1,572 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + BaseFairseqModel, + FairseqDecoder, + FairseqEncoder, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + base_architecture, + transformer_iwslt_de_en, + transformer_wmt_en_de_big, +) +from fairseq.modules import SinusoidalPositionalEmbedding +from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import ( + Embedding, + TransformerEncoderLayer, + TransformerDecoderLayer, + TransformerEncoderEmbedding, + TransformerEncoderLayerNorm, + TransformerDecoderEmbedding, + TransformerDecoderOutputLayer, +) +import torch +import torch.nn as nn +import torch.nn.functional as F + + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@register_model('pipeline_parallel_transformer') +class PipelineParallelTransformerModel(BaseFairseqModel): + def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): + try: + from fairscale.nn import Pipe + except ImportError: + raise ImportError('Please install fairscale with: pip install fairscale') + super().__init__() + assert isinstance(encoder, FairseqEncoder) + assert isinstance(decoder, FairseqDecoder) + module_list = nn.Sequential( + encoder.embedding_layer, + *list(encoder.encoder_layers), + encoder.final_layer_norm, + decoder.embedding_layer, + *list(decoder.decoder_layers), + decoder.decoder_output_layer, + ) + self.devices = devices + self.model = Pipe( + module_list, + 
balance=balance, + devices=devices, + chunks=chunks, + checkpoint=checkpoint, + ) + self.encoder_max_positions = self.max_positions_helper( + encoder.embedding_layer, + 'max_source_positions' + ) + self.decoder_max_positions = self.max_positions_helper( + decoder.embedding_layer, + 'max_target_positions' + ) + self.adaptive_softmax = getattr(decoder, 'adaptive_softmax', None) + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + input_lst = [src_tokens, src_lengths, prev_output_tokens] + input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) + return self.model(input) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--activation-fn', + choices=utils.get_available_activation_fns(), + help='activation function to use') + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', + help='dropout probability after activation in FFN.') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', + help='encoder embedding dimension for FFN') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='num encoder layers') + parser.add_argument('--encoder-attention-heads', type=int, metavar='N', + help='num encoder attention heads') + parser.add_argument('--encoder-normalize-before', action='store_true', + help='apply layernorm before each encoder block') + parser.add_argument('--encoder-learned-pos', action='store_true', + help='use learned positional embeddings in the 
encoder') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + parser.add_argument('--decoder-normalize-before', action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--share-decoder-input-output-embed', action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', + help='if set, disables positional embeddings (outside self attention)') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. 
' + 'Must be used with adaptive_loss criterion'), + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + # fmt: on + + @classmethod + def build_model_base(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, 'max_source_positions'): + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if not hasattr(args, 'max_target_positions'): + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError('--share-all-embeddings requires a joined dictionary') + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim') + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path): + raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path') + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path, + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = build_embedding( + tgt_dict, args.decoder_embed_dim, args.decoder_embed_path + ) + + encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) + decoder = 
cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return (encoder, decoder) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoder(args, tgt_dict, embed_tokens) + + @classmethod + def build_model(cls, args, task): + encoder, decoder = cls.build_model_base(args, task) + return PipelineParallelTransformerModel( + encoder=encoder, + decoder=decoder, + balance=args.pipeline_balance, + devices=args.pipeline_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder_max_positions, self.decoder_max_positions) + + def max_positions_helper(self, embedding_layer, + max_positions_field='max_source_positions'): + """Maximum input length supported by the encoder or decoder.""" + if embedding_layer.embed_positions is None: + return getattr(embedding_layer, max_positions_field) + return min(getattr(embedding_layer, max_positions_field), + embedding_layer.embed_positions.max_positions) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: + if sample is not None: + assert 'target' in sample + target = sample['target'] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output, target=target) + return out.exp_() if not log_probs else out + + logits = net_output + if log_probs: + return utils.log_softmax(logits, dim=-1, onnx_trace=False) + else: + return utils.softmax(logits, dim=-1, onnx_trace=False) + + def 
max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + def load_state_dict(self, state_dict, strict=True, args=None): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + self.upgrade_state_dict(state_dict) + is_regular_transformer = not any('model.partitions' in k for k in state_dict) + if is_regular_transformer: + state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict) + return super().load_state_dict(state_dict, strict) + + def convert_to_pipeline_parallel_state_dict(self, state_dict): + new_state_dict = self.state_dict() + encoder_layer_idx = 0 + decoder_layer_idx = 0 + encoder_key_suffixes = [ + 'self_attn.k_proj.weight', 'self_attn.k_proj.bias', + 'self_attn.v_proj.weight', 'self_attn.v_proj.bias', + 'self_attn.q_proj.weight', 'self_attn.q_proj.bias', + 'self_attn.out_proj.weight', 'self_attn.out_proj.bias', + 'self_attn_layer_norm.weight', 'self_attn_layer_norm.bias', 'fc1.weight', + 'fc1.bias', 'fc2.weight', 'fc2.bias', 'final_layer_norm.weight', + 'final_layer_norm.bias', + ] + decoder_key_suffixes = [ + 'self_attn.k_proj.weight', 'self_attn.k_proj.bias', + 'self_attn.v_proj.weight', 'self_attn.v_proj.bias', + 'self_attn.q_proj.weight', 'self_attn.q_proj.bias', + 'self_attn.out_proj.weight', 'self_attn.out_proj.bias', + 'self_attn_layer_norm.weight', 'self_attn_layer_norm.bias', + 'encoder_attn.k_proj.weight', 'encoder_attn.k_proj.bias', + 'encoder_attn.v_proj.weight', 'encoder_attn.v_proj.bias', + 'encoder_attn.q_proj.weight', 'encoder_attn.q_proj.bias', + 'encoder_attn.out_proj.weight', 'encoder_attn.out_proj.bias', + 'encoder_attn_layer_norm.weight', 'encoder_attn_layer_norm.bias', + 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', + 'final_layer_norm.weight', 'final_layer_norm.bias' + ] 
+ for pid, partition in enumerate(self.model.partitions): + print(f"Begin Partition {pid}") + for mid, module in enumerate(partition): + # fmt: off + if isinstance(module, TransformerEncoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor'] + if isinstance(module, TransformerEncoderLayer): + for suffix in encoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}'] + encoder_layer_idx += 1 + if isinstance(module, TransformerDecoderLayer): + for suffix in decoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}'] + decoder_layer_idx += 1 + if isinstance(module, TransformerEncoderLayerNorm): + if 'encoder.layer_norm.weight' in state_dict: + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias'] + if isinstance(module, TransformerDecoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor'] + if isinstance(module, TransformerDecoderOutputLayer): + new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight'] + # fmt: on + return new_state_dict + + +class TransformerEncoder(FairseqEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens): + super().__init__(dictionary) + self.register_buffer('version', torch.Tensor([3])) + self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) + layers = [ + TransformerEncoderLayer(args) for i in range(args.encoder_layers) + ] + # Note: layer drop not supported yet + # Note: layer wise attention not supported yet + self.encoder_layers = nn.Sequential(*layers) + if isinstance(embed_tokens, nn.ModuleList): + emb_dim = sum(e.embedding_dim for e in embed_tokens) + else: + emb_dim = embed_tokens.embedding_dim + self.final_layer_norm = \ + TransformerEncoderLayerNorm(args, emb_dim) + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + """ + Args: + input_tuple( + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + prev_output_tokens + ) + + Returns: + output_tuple( + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - prev_output_tokens + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + ) + """ + input_tuple = (src_tokens, src_lengths, prev_output_tokens) + encoder_embed_output_tuple = self.embedding_layer(input_tuple) + encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) + return self.final_layer_norm(encoder_layers_output) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. 
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if encoder_out.encoder_out is not None: + encoder_out = encoder_out._replace( + encoder_out=encoder_out.encoder_out.index_select(1, new_order) + ) + if encoder_out.encoder_padding_mask is not None: + encoder_out = encoder_out._replace( + encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order) + ) + if encoder_out.encoder_embedding is not None: + encoder_out = encoder_out._replace( + encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order) + ) + if encoder_out.encoder_states is not None: + for idx, state in enumerate(encoder_out.encoder_states): + encoder_out.encoder_states[idx] = state.index_select(1, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_source_positions + return min(self.embedding_layer.max_source_positions, + self.embedding_layer.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: + self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = '{}.embed_positions.weights'.format(name) + if weights_key in state_dict: + print('deleting {0}'.format(weights_key)) + del state_dict[weights_key] + 
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) + for i in range(len(self.layers)): + # update layer norms + self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) + + version_key = '{}.version'.format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + return state_dict + + +class TransformerDecoder(FairseqDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__(dictionary) + self.register_buffer('version', torch.Tensor([3])) + self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) + layers = [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + self.decoder_layers = nn.Sequential(*layers) + self.decoder_output_layer = TransformerDecoderOutputLayer(args, embed_tokens, dictionary) + + def forward(self, prev_output_tokens, encoder_out=None,): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). 
+ + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + input = (prev_output_tokens, encoder_out) + embed_layer_output = self.embedding_layer(input) + state = self.decoder_layers(embed_layer_output) + return self.decoder_output_layer(state) + + def extract_features(self, prev_output_tokens, encoder_out=None,): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + input = (prev_output_tokens, encoder_out) + embed_layer_output = self.embedding_layer(input) + state = self.decoder_layers(embed_layer_output) + return self.decoder_output_layer(state, apply_final_proj=False) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_target_positions + return min(self.embedding_layer.max_target_positions, + self.embedding_layer.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, 
'_future_mask') + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = '{}.embed_positions.weights'.format(name) + if weights_key in state_dict: + del state_dict[weights_key] + state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) + + for i in range(len(self.layers)): + # update layer norms + layer_norm_map = { + '0': 'self_attn_layer_norm', + '1': 'encoder_attn_layer_norm', + '2': 'final_layer_norm' + } + for old, new in layer_norm_map.items(): + for m in ('weight', 'bias'): + k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m) + if k in state_dict: + state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k] + del state_dict[k] + + version_key = '{}.version'.format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +@register_model_architecture('pipeline_parallel_transformer', + 'transformer_iwslt_de_en_pipeline_parallel') +def transformer_iwslt_de_en_dist(args): + transformer_iwslt_de_en(args) + + +@register_model_architecture('pipeline_parallel_transformer', + 'transformer_wmt_en_de_big_pipeline_parallel') +def transformer_wmt_en_de_big_dist(args): + transformer_wmt_en_de_big(args) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index b172b270a7..2f723866dc 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -21,13 +21,20 @@ ] -build_optimizer, 
register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( +_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( '--optimizer', base_class=FairseqOptimizer, required=True, ) +def build_optimizer(args, params, *extra_args, **extra_kwargs): + if all(isinstance(p, dict) for p in params): + params = [t for p in params for t in p.values()] + params = list(filter(lambda p: p.requires_grad, params)) + return _build_optimizer(args, params, *extra_args, **extra_kwargs) + + # automatically import any Python files in the optim/ directory for file in os.listdir(os.path.dirname(__file__)): if file.endswith('.py') and not file.startswith('_'): diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 777d43a713..593519eb7f 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from itertools import chain +from collections import defaultdict import torch @@ -21,21 +22,38 @@ def __init__(self, *args, **kwargs): @property def has_flat_params(self): - return torch.is_tensor(self.fp32_params) + return ( + torch.is_tensor(self.fp32_params) or + ( + isinstance(self.fp32_params, dict) and + all(torch.is_tensor(t) for t in self.fp32_params.values()) + ) + ) @classmethod - def build_fp32_params(cls, params, flatten=True): + def build_fp32_params(cls, args, params, flatten=True): # create FP32 copy of parameters and grads if flatten: total_param_size = sum(p.data.numel() for p in params) - fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=params[0].device) - offset = 0 - for p in params: - numel = p.data.numel() - fp32_params[offset:offset+numel].copy_(p.data.view(-1)) - offset += numel - fp32_params = torch.nn.Parameter(fp32_params) - fp32_params.grad = fp32_params.data.new(total_param_size) + devices = [torch.cuda.current_device()] + if args.pipeline_model_parallel and args.distributed_no_spawn: + 
devices = list(set(args.pipeline_devices)) + fp32_params = {} + for device in devices: + if args.pipeline_model_parallel and args.distributed_no_spawn: + device_param_size = sum(p.data.numel() for p in params if p.device.index == device) + device_params = [p for p in params if p.device.index == device] + else: + device_param_size = total_param_size + device_params = params + fp32_params[device] = device_params[0].new(0).float().new(device_param_size) + offset = 0 + for p in device_params: + numel = p.data.numel() + fp32_params[device][offset:offset+numel].copy_(p.data.view(-1)) + offset += numel + fp32_params[device] = torch.nn.Parameter(fp32_params[device]) + fp32_params[device].grad = fp32_params[device].data.new(device_param_size) return fp32_params else: fp32_params = [] @@ -80,14 +98,19 @@ def _sync_fp16_grads_to_fp32(self): if self._needs_sync: # copy FP16 grads to FP32 if self.has_flat_params: - offset = 0 + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) for p in self.fp16_params: - if not p.requires_grad: - continue - grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape) - numel = grad_data.numel() - self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1)) - offset += numel + if p.requires_grad: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape) + numel = grad_data.numel() + self.fp32_params[device].grad.data[offset:offset+numel].copy_(grad_data.view(-1)) + offset += numel else: for p, p32 in zip(self.fp16_params, self.fp32_params): if not p.requires_grad: @@ -102,13 +125,17 @@ def _sync_fp16_grads_to_fp32(self): def _sync_fp32_params_to_fp16(self): # copy FP32 params back into FP16 model if self.has_flat_params: - offset = 0 + devices = list(self.fp32_params.keys()) + device_params_dict = 
defaultdict(list) for p in self.fp16_params: - if not p.requires_grad: - continue - numel = p.data.numel() - p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) - offset += numel + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + numel = p.data.numel() + p.data.copy_(self.fp32_params[device].data[offset:offset+numel].view_as(p.data)) + offset += numel else: for p, p32 in zip(self.fp16_params, self.fp32_params): if not p.requires_grad: @@ -162,7 +189,13 @@ def zero_grad(self): for p in self.fp16_params: p.grad = None if self.has_flat_params: - self.fp32_params.grad.zero_() + if torch.is_tensor(self.fp32_params): + self.fp32_params.grad.zero_() + elif isinstance(self.fp32_params, dict): + for fp32_params in self.fp32_params.values(): + fp32_params.grad.zero_() + else: + raise TypeError("self.fp32_params must be a tensor or dict") else: for p32 in self.fp32_params: p32.grad.zero_() @@ -216,7 +249,7 @@ def build_optimizer(cls, args, params): flatten = not getattr(args, 'fp16_no_flatten_grads', False) if getattr(args, 'bf16', False): flatten = False # mixed precision is faster on TPUs without flat grads - fp32_params = cls.build_fp32_params(params, flatten=flatten) + fp32_params = cls.build_fp32_params(args, params, flatten=flatten) if flatten: fp32_optimizer = optim.build_optimizer(args, [fp32_params]) else: diff --git a/fairseq/options.py b/fairseq/options.py index 01150bda4a..f74d12073e 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -409,6 +409,8 @@ def add_distributed_training_args(parser, default_world_size=None): help='which GPU to use (usually configured automatically)') group.add_argument('--distributed-no-spawn', action='store_true', help='do not spawn multiple processes even if multiple GPUs are visible') + group.add_argument('--distributed-num-procs', default=None, type=int, + help='number of processes to spawn (usually configured 
automatically)') # "c10d" is PyTorch's DDP implementation and provides the fastest # training. "no_c10d" is a more robust, but slightly slower DDP # implementation. Try this if you get warning messages about @@ -448,6 +450,25 @@ def add_distributed_training_args(parser, default_world_size=None): help='number of GPUs in each node. An allreduce operation across GPUs in ' 'a node is very fast. Hence, we do allreduce across GPUs in a node, ' 'and gossip across different nodes') + # Pipeline Parallel Arguments + group.add_argument('--pipeline-model-parallel', default=False, action='store_true', + help='if set, use pipeline model parallelism across GPUs') + group.add_argument('--pipeline-balance', metavar='N1,N2,...,N_K', + type=lambda x: eval_str_list(x, type=int), + help='partition the model into N_K pieces, where each piece ' + 'contains N_i layers. The sum(args.pipeline_balance) ' + 'should equal the total number of layers in the model') + group.add_argument('--pipeline-devices', metavar='N1,N2,...,N_K', + type=lambda x: eval_str_list(x, type=int), + help='a list of device indices indicating which device to place ' + 'each of the N_K partitions. 
The length of this list should ' + 'equal the length of the --pipeline-balance argument') + group.add_argument('--pipeline-chunks', type=int, metavar='N', + help='microbatch count for pipeline model parallelism') + group.add_argument('--pipeline-checkpoint', type=str, metavar='STR', + choices=['always', 'never', 'except_last'], + default='never', + help='checkpointing mode for pipeline model parallelism') # Add argument for ZeRO sharding of OptimizerState(os), gradients(g) and parameters(p) group.add_argument('--zero-sharding', default='none', type=str, choices=['none', 'os'], diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 64ad44c1df..60fa161d02 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -64,8 +64,13 @@ def __init__(self, args, task, model, criterion, quantizer=None): elif args.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) - self._criterion = self._criterion.to(device=self.device) - self._model = self._model.to(device=self.device) + if not args.pipeline_model_parallel: + self._criterion = self._criterion.to(device=self.device) + self._model = self._model.to(device=self.device) + self.pipeline_model_parallel = args.pipeline_model_parallel + self.last_device = None + if self.cuda and self.pipeline_model_parallel: + self.last_device = torch.device(args.pipeline_devices[-1]) # check that shared parameters are preserved after device transfer for shared_param in shared_params: @@ -791,7 +796,11 @@ def _prepare_sample(self, sample): return None if self.cuda: - sample = utils.move_to_cuda(sample) + if self.pipeline_model_parallel: + if 'target' in sample: + sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device) + else: + sample = utils.move_to_cuda(sample) def apply_half(t): if t.dtype is torch.float32: diff --git a/fairseq/utils.py b/fairseq/utils.py index d10ed2f28a..b8914a78dc 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -74,9 +74,13 @@ 
def _apply(x): return _apply(sample) -def move_to_cuda(sample): +def move_to_cuda(sample, device=None): + device = device or torch.cuda.current_device() + def _move_to_cuda(tensor): - return tensor.cuda() + # non_blocking is ignored if tensor is not pinned, so we can always set + # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) + return tensor.cuda(device=device, non_blocking=True) return apply_to_sample(_move_to_cuda, sample) @@ -277,7 +281,7 @@ def multi_tensor_total_norm(grads, chunk_size=2048*32) -> torch.Tensor: has_inf = torch.zeros((1, 1), dtype=torch.int, device=device) with torch.cuda.device(device): norm = multi_tensor_l2norm(chunk_size, has_inf, [cur_device_grads], False) - norms.append(norm[0]) + norms.append(norm[0].to(torch.cuda.current_device())) else: norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads] total_norm = torch.norm(torch.stack(norms)) @@ -307,7 +311,7 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: "you may get better performance by installing NVIDIA's apex library" ) total_norm = torch.norm( - torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in grads]) + torch.stack([torch.norm(g, p=2, dtype=torch.float32).cuda(torch.cuda.current_device()) for g in grads]) ) if aggregate_norm_fn is not None: From fd080b308e1e3361d6c498b235496080fa6599e5 Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Thu, 10 Sep 2020 22:58:57 -0700 Subject: [PATCH 150/707] hydra fairseq 4 - migrated several common utils from options to utils Summary: hydra fairseq 4 - migrated several common utils from options to utils Reviewed By: myleott Differential Revision: D23413473 fbshipit-source-id: 01e377de0fdca77a321924ab3768b7fafe3da32e --- .../translation_moe/src/translation_moe.py | 2 +- .../multilingual/multilingual_data_manager.py | 4 +-- fairseq/data/token_block_dataset.py | 2 +- fairseq/models/fconv_lm.py | 4 +-- fairseq/models/lightconv.py | 14 ++++---- 
fairseq/models/lightconv_lm.py | 10 +++--- fairseq/models/lstm.py | 6 ++-- fairseq/models/lstm_lm.py | 4 +-- fairseq/models/roberta/hub_interface.py | 4 +-- fairseq/models/transformer.py | 5 ++- fairseq/modules/transformer_layer.py | 4 +-- fairseq/options.py | 35 ++----------------- fairseq/tasks/multilingual_translation.py | 4 +-- fairseq/tasks/translation.py | 4 +-- fairseq/utils.py | 34 +++++++++++++++++- 15 files changed, 68 insertions(+), 68 deletions(-) diff --git a/examples/translation_moe/src/translation_moe.py b/examples/translation_moe/src/translation_moe.py index b60175f093..5455dd6681 100644 --- a/examples/translation_moe/src/translation_moe.py +++ b/examples/translation_moe/src/translation_moe.py @@ -162,7 +162,7 @@ def get_lprob_yz(winners=None): return lprob_yz # compute responsibilities without dropout - with utils.eval(model): # disable dropout + with utils.model_eval(model): # disable dropout with torch.no_grad(): # disable autograd lprob_yz = get_lprob_yz() # B x K prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 01ba6ece8f..361d41b438 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -286,8 +286,8 @@ def _shared_collater(self): @classmethod def prepare(cls, load_dictionary, args, **kargs): - args.left_pad_source = options.eval_bool(args.left_pad_source) - args.left_pad_target = options.eval_bool(args.left_pad_target) + args.left_pad_source = utils.eval_bool(args.left_pad_source) + args.left_pad_target = utils.eval_bool(args.left_pad_target) if not hasattr(args, "shuffle_instance"): args.shuffle_instance = False diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index 4e2f5cc482..cae872c310 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -75,7 +75,7 
@@ def __init__( if break_mode == "eos" and block_size is None: block_size = 0 - slice_indices = _get_slice_indices_fast(sizes, break_mode, block_size, document_sep_len) + slice_indices = _get_slice_indices_fast(sizes, str(break_mode), block_size, document_sep_len) self._sizes = slice_indices[:, 1] - slice_indices[:, 0] # build index mapping block indices to the underlying dataset indices diff --git a/fairseq/models/fconv_lm.py b/fairseq/models/fconv_lm.py index f2320b1700..4c3c5c66dd 100644 --- a/fairseq/models/fconv_lm.py +++ b/fairseq/models/fconv_lm.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq import options +from fairseq import utils from fairseq.models import ( FairseqLanguageModel, register_model, @@ -56,7 +56,7 @@ def build_model(cls, args, task): share_embed=False, positional_embeddings=False, adaptive_softmax_cutoff=( - options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == 'adaptive_loss' else None ), adaptive_softmax_dropout=args.adaptive_softmax_dropout, diff --git a/fairseq/models/lightconv.py b/fairseq/models/lightconv.py index 05939e1c75..09d4d0be2e 100644 --- a/fairseq/models/lightconv.py +++ b/fairseq/models/lightconv.py @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F -from fairseq import options, utils +from fairseq import utils from fairseq.models import ( FairseqEncoder, FairseqIncrementalDecoder, @@ -133,13 +133,13 @@ def add_args(parser): help='sets adaptive softmax dropout for the tail projections') """LightConv and DynamicConv arguments""" - parser.add_argument('--encoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int), + parser.add_argument('--encoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int), help='list of kernel size (default: "[3,7,15,31,31,31,31]")') - 
parser.add_argument('--decoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int), + parser.add_argument('--decoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int), help='list of kernel size (default: "[3,7,15,31,31,31]")') - parser.add_argument('--encoder-glu', type=options.eval_bool, + parser.add_argument('--encoder-glu', type=utils.eval_bool, help='glu after in proj') - parser.add_argument('--decoder-glu', type=options.eval_bool, + parser.add_argument('--decoder-glu', type=utils.eval_bool, help='glu after in proj') parser.add_argument('--encoder-conv-type', default='dynamic', type=str, choices=['dynamic', 'lightweight'], @@ -147,7 +147,7 @@ def add_args(parser): parser.add_argument('--decoder-conv-type', default='dynamic', type=str, choices=['dynamic', 'lightweight'], help='type of convolution') - parser.add_argument('--weight-softmax', default=True, type=options.eval_bool) + parser.add_argument('--weight-softmax', default=True, type=utils.eval_bool) parser.add_argument('--weight-dropout', type=float, metavar='D', help='dropout probability for conv weights') @@ -353,7 +353,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_ self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), output_embed_dim, - options.eval_str_list(args.adaptive_softmax_cutoff, type=int), + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, diff --git a/fairseq/models/lightconv_lm.py b/fairseq/models/lightconv_lm.py index a268ddd859..861f6430e9 100644 --- a/fairseq/models/lightconv_lm.py +++ b/fairseq/models/lightconv_lm.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from fairseq import options +from fairseq import utils from fairseq.models import ( FairseqLanguageModel, register_model, @@ -83,14 +83,14 @@ def add_args(parser): help='use learned positional embeddings in the decoder') """LightConv and DynamicConv arguments""" - parser.add_argument('--decoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int), + parser.add_argument('--decoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int), help='list of kernel size (default: "[3,7,15,31,31,31]")') - parser.add_argument('--decoder-glu', type=options.eval_bool, + parser.add_argument('--decoder-glu', type=utils.eval_bool, help='glu after in proj') parser.add_argument('--decoder-conv-type', default='dynamic', type=str, choices=['dynamic', 'lightweight'], help='type of convolution') - parser.add_argument('--weight-softmax', default=True, type=options.eval_bool) + parser.add_argument('--weight-softmax', default=True, type=utils.eval_bool) parser.add_argument('--weight-dropout', type=float, metavar='D', help='dropout probability for conv weights') @@ -115,7 +115,7 @@ def build_model(cls, args, task): elif args.adaptive_input: embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim, args.adaptive_input_factor, args.decoder_embed_dim, - options.eval_str_list(args.adaptive_input_cutoff, type=int)) + utils.eval_str_list(args.adaptive_input_cutoff, type=int)) else: embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()) diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 72bd815bcc..8404cafe1d 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F -from fairseq import options, utils +from fairseq import utils from fairseq.models import ( FairseqEncoder, FairseqIncrementalDecoder, @@ -168,12 +168,12 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): 
num_layers=args.decoder_layers, dropout_in=args.decoder_dropout_in, dropout_out=args.decoder_dropout_out, - attention=options.eval_bool(args.decoder_attention), + attention=utils.eval_bool(args.decoder_attention), encoder_output_units=encoder.output_units, pretrained_embed=pretrained_decoder_embed, share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( - options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == 'adaptive_loss' else None ), max_target_positions=max_target_positions, diff --git a/fairseq/models/lstm_lm.py b/fairseq/models/lstm_lm.py index 9f6758a4bc..82bd02f6f7 100644 --- a/fairseq/models/lstm_lm.py +++ b/fairseq/models/lstm_lm.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq import options, utils +from fairseq import utils from fairseq.models import ( FairseqLanguageModel, register_model, register_model_architecture ) @@ -104,7 +104,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_embed=pretrained_decoder_embed, share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( - options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == 'adaptive_loss' else None ), max_target_positions=max_target_positions, diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index baf0bf28b9..20456b3f5c 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -155,7 +155,7 @@ def fill_mask(self, masked_input: str, topk: int = 5): if tokens.dim() == 1: tokens = tokens.unsqueeze(0) - with utils.eval(self.model): + with utils.model_eval(self.model): features, extra = self.model( 
tokens.long().to(device=self.device), features_only=False, @@ -200,5 +200,5 @@ def disambiguate_pronoun(self, sentence: str) -> bool: """ assert hasattr(self.task, 'disambiguate_pronoun'), \ 'roberta.disambiguate_pronoun() requires a model trained with the WSC task.' - with utils.eval(self.model): + with utils.model_eval(self.model): return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda') diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 6fd5c2bd05..a55c47c155 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from fairseq import options, utils +from fairseq import utils from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -30,7 +30,6 @@ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from torch import Tensor - DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -612,7 +611,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, - options.eval_str_list(args.adaptive_softmax_cutoff, type=int), + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 037d8e88ae..ced8d933f5 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -13,7 +13,6 @@ from fairseq.modules.fairseq_dropout import FairseqDropout from torch import Tensor - class TransformerEncoderLayer(nn.Module): """Encoder layer block. 
@@ -174,8 +173,9 @@ def __init__( add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) + self.activation_fn = utils.get_activation_fn( - activation=getattr(args, "activation_fn", "relu") + activation=str(args.activation_fn) if getattr(args, "activation_fn", None) is not None else "relu" ) activation_dropout_p = getattr(args, "activation_dropout", 0) if activation_dropout_p == 0: diff --git a/fairseq/options.py b/fairseq/options.py index f74d12073e..8e50064d76 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -8,7 +8,8 @@ from typing import Callable, List, Optional import torch - +# this import is for backward compatibility +from fairseq.utils import csv_str_list, eval_str_list, eval_str_dict, eval_bool # noqa from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl @@ -60,38 +61,6 @@ def get_validation_parser(default_task=None): return parser -def csv_str_list(x): - return x.split(',') - - -def eval_str_list(x, type=float): - if x is None: - return None - if isinstance(x, str): - x = eval(x) - try: - return list(map(type, x)) - except TypeError: - return [type(x)] - - -def eval_str_dict(x, type=dict): - if x is None: - return None - if isinstance(x, str): - x = eval(x) - return x - - -def eval_bool(x, default=False): - if x is None: - return default - try: - return bool(eval(x)) - except TypeError: - return default - - def parse_args_and_arch( parser: argparse.ArgumentParser, input_args: List[str] = None, diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 784b438ca9..7c7e18ec87 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -119,8 +119,8 @@ def setup_task(cls, args, **kwargs): @classmethod def prepare(cls, args, **kargs): - args.left_pad_source = options.eval_bool(args.left_pad_source) - args.left_pad_target = options.eval_bool(args.left_pad_target) + args.left_pad_source = 
utils.eval_bool(args.left_pad_source) + args.left_pad_target = utils.eval_bool(args.left_pad_target) if args.lang_pairs is None: raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.') diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 6eac293659..8181c1a650 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -218,8 +218,8 @@ def setup_task(cls, args, **kwargs): Args: args (argparse.Namespace): parsed command-line arguments """ - args.left_pad_source = options.eval_bool(args.left_pad_source) - args.left_pad_target = options.eval_bool(args.left_pad_target) + args.left_pad_source = utils.eval_bool(args.left_pad_source) + args.left_pad_target = utils.eval_bool(args.left_pad_target) paths = utils.split_paths(args.data) assert len(paths) > 0 diff --git a/fairseq/utils.py b/fairseq/utils.py index b8914a78dc..e95e76ca03 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -467,7 +467,7 @@ def get_available_activation_fns() -> List: @contextlib.contextmanager -def eval(model): +def model_eval(model): is_training = model.training model.eval() yield @@ -606,3 +606,35 @@ def pretty_print_cuda_env_list(cuda_env_list): + "name = {:40s}".format(env.name) ) logger.info(first_line) + + +def csv_str_list(x): + return x.split(',') + + +def eval_str_list(x, type=float): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + try: + return list(map(type, x)) + except TypeError: + return [type(x)] + + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + +def eval_bool(x, default=False): + if x is None: + return default + try: + return bool(eval(x)) + except TypeError: + return default From 23d8502bdde88a3e58e0910e2ee49834f8478b39 Mon Sep 17 00:00:00 2001 From: lematt1991 Date: Fri, 11 Sep 2020 13:27:19 -0700 Subject: [PATCH 151/707] Only move to CUDA device if available in clip_grad_norm_ (#2607) 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fix tests from d47067937abacfe87f2963adca8daeada3c631fe ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2607 Reviewed By: shruti-bh Differential Revision: D23650531 Pulled By: lematt1991 fbshipit-source-id: ac2e8b07968535e7454836db9cc0b659a0cdf887 --- fairseq/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fairseq/utils.py b/fairseq/utils.py index e95e76ca03..888e4d95e4 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -305,13 +305,15 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if multi_tensor_l2norm_available: total_norm = multi_tensor_total_norm(grads) else: + device = torch.device('cpu') if torch.cuda.is_available(): warnings.warn( "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " "you may get better performance by installing NVIDIA's apex library" ) + device = torch.cuda.current_device() total_norm = torch.norm( - torch.stack([torch.norm(g, p=2, dtype=torch.float32).cuda(torch.cuda.current_device()) for g in grads]) + torch.stack([torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]) ) if aggregate_norm_fn is not None: From 380227cc477f731f2315d62aa3d867720231538c Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Thu, 17 Sep 2020 20:42:59 -0700 Subject: [PATCH 152/707] Optimize
sampled_multi_dataset.py and sampled_multi_epoch_dataset.py Summary: # Facebook: The research implementation is not stable for multilingual data at a scale of 1.5 - 5 billion due to * large inter-process memory consumption in PlasmaArray * caching of a few indices and sizes arrays For example, the following runs failed after training for a while: f217270435 and f216525055 This diff simplifies the implementation and optimizes it to have more stable training at scale: * removed PlasmaArray * introduced the lang_pair_dataset.latest_filter_by_size implementation to sampled_multi_dataset * optimized the sizes() implementation Reviewed By: akinh Differential Revision: D23634750 fbshipit-source-id: c3ded98ef2c84e3d4512f1693c0cbf493b91c1df --- fairseq/data/concat_dataset.py | 15 +- fairseq/data/data_utils.py | 42 +++++ fairseq/data/language_pair_dataset.py | 25 +-- .../multilingual/multilingual_data_manager.py | 46 ++---- .../multilingual/sampled_multi_dataset.py | 154 +++++++++-------- .../sampled_multi_epoch_dataset.py | 156 ++++-------------- .../tasks/translation_multi_simple_epoch.py | 50 ++++-- 7 files changed, 224 insertions(+), 264 deletions(-) diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py index 2c3306d6f5..0091a28e47 100644 --- a/fairseq/data/concat_dataset.py +++ b/fairseq/data/concat_dataset.py @@ -88,7 +88,20 @@ def ordered_indices(self): """ Returns indices sorted by length. So less padding is needed.
""" - return np.argsort(self.sizes) + if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1: + # special handling for concatenating lang_pair_datasets + indices = np.arange(len(self)) + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[ + np.argsort(tgt_sizes[indices], kind='mergesort') + ] + return indices[np.argsort(src_sizes[indices], kind='mergesort')] + else: + return np.argsort(self.sizes) def prefetch(self, indices): frm = 0 diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 5a00debbc3..11217020f2 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -213,6 +213,41 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False): return indices +def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes): + """ Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. 
+ + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if tgt_sizes is None: + ignored = indices[src_sizes[indices] > max_src_size] + else: + ignored = indices[ + (src_sizes[indices] > max_src_size) | + (tgt_sizes[indices] > max_tgt_size)] + if len(ignored) > 0: + if tgt_sizes is None: + indices = indices[src_sizes[indices] <= max_src_size] + else: + indices = indices[ + (src_sizes[indices] <= max_src_size) & + (tgt_sizes[indices] <= max_tgt_size)] + return indices, ignored.tolist() + + def batch_by_size( indices, num_tokens_fn, max_tokens=None, max_sentences=None, required_batch_size_multiple=1, fixed_shapes=None, @@ -400,3 +435,10 @@ def arrange(s, e, length, keep_length): mask[i, mask_idc] = True return mask + + +def get_mem_usage(): + # for debug + import psutil + mb = 1024 * 1024 + return f'used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb' diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index fba3d37bc5..5cc5087b2d 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -208,6 +208,7 @@ def __init__( self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes self.src_dict = src_dict self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source @@ -416,21 +417,9 @@ def filter_indices_by_size(self, indices, max_sizes): np.array: filtered sample array list: 
list of removed indices """ - if max_sizes is None: - return indices, [] - if type(max_sizes) in (int, float): - max_src_size, max_tgt_size = max_sizes, max_sizes - else: - max_src_size, max_tgt_size = max_sizes - if self.tgt_sizes is None: - ignored = indices[self.src_sizes[indices] > max_src_size] - else: - ignored = indices[(self.src_sizes[indices] > max_src_size) | - (self.tgt_sizes[indices] > max_tgt_size)] - if len(ignored) > 0: - if self.tgt_sizes is None: - indices = indices[self.src_sizes[indices] <= max_src_size] - else: - indices = indices[(self.src_sizes[indices] <= max_src_size) & - (self.tgt_sizes[indices] <= max_tgt_size)] - return indices, ignored.tolist() + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 361d41b438..18b5b96b28 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -785,15 +785,6 @@ def load_a_dataset( else None, langpairs_sharing_datasets=langpairs_sharing_datasets, ) - if langpair_ds.tgt_sizes is None: - # hack to use src_sizes as the sizes for the whole pair dataset for ConcatDataset - langpair_ds.sizes = langpair_ds.src_sizes - else: - # use the max of two sides to define the size to help max positions filtering - langpair_ds.sizes = np.vstack( - [langpair_ds.src_sizes, langpair_ds.tgt_sizes] - ).max(axis=0) - assert langpair_ds.sizes.shape == langpair_ds.src_sizes.shape # TODO: handle modified lang toks for mined data and dae data if self.args.lang_tok_replacing_bos_eos: ds = self.alter_dataset_langtok( @@ -1000,26 +991,6 @@ def load_split_datasets( ] return datasets, data_param_list - def load_into_sampled_multi_epoch_dataset( - self, split, datasets, data_param_list, epoch, shard_epoch=None - ): - sample_ratios = 
self.get_sampling_ratios(data_param_list, datasets, epoch) - return SampledMultiEpochDataset( - OrderedDict(datasets), - epoch=epoch, - shard_epoch=shard_epoch, - # valid and test datasets will be degerate to concating datasets: - sampling_ratios=sample_ratios, - eval_key=None, - batch_by_size=True, - collate_format=CollateFormat.single, - virtual_size=self.args.virtual_data_size, - split=split, - virtual_epoch_size=self.args.virtual_epoch_size, - # if not using lang_tok altering, simplified to use the same collater - shared_collater=self._shared_collater(), - ) - def load_into_concat_dataset(self, split, datasets, data_param_list): if self.args.lang_tok_replacing_bos_eos: # TODO: to investigate why TransformEosLangPairDataset doesn't work with ConcatDataset @@ -1027,7 +998,6 @@ def load_into_concat_dataset(self, split, datasets, data_param_list): OrderedDict(datasets), sampling_ratios=None, eval_key=None, - batch_by_size=True, collate_format=CollateFormat.single, virtual_size=None, split=split, @@ -1041,8 +1011,20 @@ def load_sampled_multi_epoch_dataset( split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs ) if training and split == getattr(self.args, "train_subset", None): - return self.load_into_sampled_multi_epoch_dataset( - split, datasets, data_param_list, epoch, shard_epoch=shard_epoch + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiEpochDataset( + OrderedDict(datasets), + epoch=epoch, + shard_epoch=shard_epoch, + # valid and test datasets will be degenerate to concating datasets: + sampling_ratios=sample_ratios, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + virtual_epoch_size=self.args.virtual_epoch_size, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), ) else: return self.load_into_concat_dataset(split, datasets, data_param_list) diff --git 
a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index 14090ac8c5..af8643c042 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -17,7 +17,7 @@ import torch from fairseq import distributed_utils -from fairseq.data import plasma_utils, FairseqDataset +from fairseq.data import FairseqDataset, data_utils def get_time_gap(s, e): @@ -54,9 +54,7 @@ class SampledMultiDataset(FairseqDataset): or OrderedDict[str, ~torch.utils.data.Dataset] ): datasets sampling_ratios (List[float]): list of probability of each dataset to be sampled - (default: None, which corresponds to concating all dataset together). - batch_by_size (bool): whether or not to batch by sequence length - (default: True). + (default: None, which corresponds to concatenating all dataset together). seed (int): RNG seed to use (default: 2). epoch (int): starting epoch number (default: 1). eval_key (str, optional): a key used at evaluation time that causes @@ -70,24 +68,25 @@ class SampledMultiDataset(FairseqDataset): virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). split (str): the split of the data, e.g. 'train', 'valid' or 'test'. shared_collater (bool): whether or not to all sub-datasets have the same collater. + shuffle (bool): whether or not to shuffle data (default: True). 
""" def __init__( - self, - datasets, - sampling_ratios=None, - batch_by_size=False, - seed=2, - epoch=1, - eval_key=None, - collate_format=CollateFormat.single, - virtual_size=default_virtual_size_func, - split='', - shared_collater=False, + self, + datasets, + sampling_ratios=None, + seed=2, + epoch=1, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=default_virtual_size_func, + split='', + shared_collater=False, + shuffle=True, ): super().__init__() - self.batch_by_size = batch_by_size self.shared_collater = shared_collater + self.shuffle = shuffle if isinstance(datasets, OrderedDict): self.keys = list(datasets.keys()) @@ -107,16 +106,18 @@ def __init__( self.seed = seed self._cur_epoch = None + + self.cumulated_sizes = None + # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset + # namely, data item i is sampled from the kth sub-dataset self.datasets[k] + # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k] self._cur_indices = None + self._sizes = None - self._ordered_indices = None self.virtual_size_per_dataset = None # caching properties self._reset_cached_properties() self.setup_sampling(sampling_ratios, virtual_size) - self.cumulated_sizes = None - self.virtual_size_per_dataset = None - self._size_cache = {} self.set_epoch(epoch) def _clean_if_not_none(self, var_list): @@ -126,10 +127,9 @@ def _clean_if_not_none(self, var_list): def _reset_cached_properties(self): self._clean_if_not_none([ - self._sizes, self._ordered_indices, self._cur_indices + self._sizes, self._cur_indices ]) self._sizes = None - self._ordered_indices = None self._cur_indices = None def setup_sampling(self, sample_ratios, virtual_size): @@ -141,10 +141,10 @@ def setup_sampling(self, sample_ratios, virtual_size): else: if not isinstance(sample_ratios, np.ndarray): sample_ratios = np.array(sample_ratios) - self.sample_ratios = plasma_utils.PlasmaArray(sample_ratios) + self.sample_ratios = sample_ratios virtual_size = 
default_virtual_size_func if virtual_size is None else virtual_size self.virtual_size = ( - virtual_size(self.datasets, self.sample_ratios.array) if callable(virtual_size) + virtual_size(self.datasets, self.sample_ratios) if callable(virtual_size) else virtual_size) def adjust_sampling(self, epoch, sampling_ratios, virtual_size): @@ -198,7 +198,6 @@ def get_in_dataset_indices(datasets, sizes, sample_ratios): in_dataset_indices = [list(range(s)) for s in sizes] virtual_sizes_per_dataset = sizes else: - sample_ratios = sample_ratios.array ratios = sample_ratios / sample_ratios.sum() in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios) virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices] @@ -215,23 +214,21 @@ def get_in_dataset_indices(datasets, sizes, sample_ratios): return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset def _get_dataset_and_index(self, index): - i = bisect_right(self.cumulated_sizes.array, index) - return i, self._cur_indices.array[index] + i = bisect_right(self.cumulated_sizes, index) + return i, self._cur_indices[index] def __getitem__(self, index): + # self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]] + # where k satisfies self.cumulated_sizes[k - 1] <= k < self.cumulated_sizes[k] ds_idx, ds_sample_idx = self._get_dataset_and_index(index) ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx]) return ret def num_tokens(self, index): - ds_idx, ds_sample_idx = self._get_dataset_and_index(index) - return self.datasets[ds_idx].num_tokens(ds_sample_idx) + return self.sizes[index].max() def size(self, index): - if self._sizes is not None: - return self._sizes[index] - ds_idx, ds_sample_idx = self._get_dataset_and_index(index) - return self.datasets[ds_idx].size(ds_sample_idx) + return self.sizes[index] def __len__(self): return self.virtual_size @@ -244,13 +241,13 @@ def collater(self, samples, **extra_args): collect_samples = [[] for _ in range(len(self.datasets))] for (i, sample) in 
samples: collect_samples[i].append(sample) - return OrderedDict([ + batch = OrderedDict([ (self.keys[i], dataset.collater(collect_samples[i])) for i, (key, dataset) in enumerate(zip(self.keys, self.datasets)) if len(collect_samples[i]) > 0 ]) elif self.shared_collater: - return self.datasets[0].collater( + batch = self.datasets[0].collater( [s for _, s in samples] ) else: @@ -295,47 +292,41 @@ def straight_order(tensors): batch['net_input']['src_lang_id'] = straight_order([b['net_input']['src_lang_id'] for b in batches]) if 'tgt_lang_id' in batches[0]: batch['tgt_lang_id'] = straight_order([b['tgt_lang_id'] for b in batches]) - return batch + return batch @property def sizes(self): if self._sizes is not None: return self._sizes start_time = time.time() - size_cache = self._size_cache - ret = [] - for i in range(len(self)): - ds_idx, ds_sample_idx = self._get_dataset_and_index(i) - if (ds_idx, ds_sample_idx) in size_cache: - ret.append(size_cache[(ds_idx, ds_sample_idx)]) - else: - s = self.datasets[ds_idx].size(ds_sample_idx) - size_cache[(ds_idx, ds_sample_idx)] = s - ret.append(s) - logger.debug(f'sizes() calling time: {get_time_gap(start_time, time.time())}') - self._sizes = np.array(ret, np.int64) + in_sub_dataset_indices = [ + self._cur_indices[0 if i == 0 else self.cumulated_sizes[i-1]:self.cumulated_sizes[i]] + for i in range(len(self.datasets)) + ] + sub_dataset_sizes = [ + d.sizes[indices] + for d, indices in zip(self.datasets, in_sub_dataset_indices) + ] + self._sizes = np.vstack(sub_dataset_sizes) + logger.info(f'sizes() calling time: {get_time_gap(start_time, time.time())}') return self._sizes def ordered_indices(self): - if self._ordered_indices is not None: - return self._ordered_indices - - if self.batch_by_size: - # No need to do shuffle as the data items are already randomized - indices = np.arange(len(self)) - sizes = self.sizes - tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None - src_sizes = sizes[:, 0] if 
len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes - - # sort by target length, then source length - if tgt_sizes is not None: - indices = indices[ - np.argsort(tgt_sizes[indices], kind='mergesort') - ] - sort_indices = indices[np.argsort(src_sizes[indices], kind='mergesort')] + if self.shuffle: + indices = np.random.permutation(len(self)) else: - sort_indices = np.arange(len(self)) - self._ordered_indices = sort_indices + indices = np.arange(len(self)) + + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[ + np.argsort(tgt_sizes[indices], kind='mergesort') + ] + sort_indices = indices[np.argsort(src_sizes[indices], kind='mergesort')] return sort_indices def prefetch(self, indices): @@ -383,18 +374,37 @@ def _establish_virtual_datasets(self): self._clean_if_not_none([ self.cumulated_sizes, self.virtual_size_per_dataset ]) - self._cur_indices = plasma_utils.PlasmaArray(indices) - self.cumulated_sizes = plasma_utils.PlasmaArray(cumulated_sizes) - self.virtual_size_per_dataset = plasma_utils.PlasmaArray(virtual_size_per_dataset) + self._cur_indices = indices + self.cumulated_sizes = cumulated_sizes + self.virtual_size_per_dataset = virtual_size_per_dataset raw_sizes = [len(d) for d in self.datasets] - sampled_sizes = self.virtual_size_per_dataset.array + sampled_sizes = self.virtual_size_per_dataset logger.info(f'[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; ' f'raw total size: {sum(raw_sizes)}') logger.info(f'[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; ' f'resampled total size: {sum(sampled_sizes)}') if self.sample_ratios is not None: - logger.info(f'[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios.array)))}') + logger.info(f'[{self.split}] Upsampling 
ratios: {str(dict(zip(self.keys, self.sample_ratios)))}') else: logger.info(f'[{self.split}] A concat dataset') - logger.debug(f'[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}') + logger.info(f'[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}') + + def filter_indices_by_size(self, indices, max_sizes): + """ Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + + return data_utils.filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes) diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py index 289a117a00..9442ed460e 100644 --- a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -6,13 +6,10 @@ import hashlib import math import logging -import time import numpy as np -import torch -from fairseq import distributed_utils -from fairseq.data import plasma_utils, SampledMultiDataset -from .sampled_multi_dataset import default_virtual_size_func, get_time_gap, CollateFormat +from fairseq.data import SampledMultiDataset +from .sampled_multi_dataset import default_virtual_size_func, CollateFormat logger = logging.getLogger(__name__) @@ -28,8 +25,6 @@ class SampledMultiEpochDataset(SampledMultiDataset): ): datasets sampling_ratios (List[float]): list of probability of each dataset to be sampled (default: None, which corresponds to 
concating all dataset together). - batch_by_size (bool): whether or not to batch by sequence length - (default: True). seed (int): RNG seed to use (default: 2). epoch (int): starting epoch number (default: 1). eval_key (str, optional): a key used at evaluation time that causes @@ -47,12 +42,12 @@ class SampledMultiEpochDataset(SampledMultiDataset): can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded. shared_collater (bool): whether or not to all sub-datasets have the same collater. shard_epoch (int): the real epoch number for shard selection. + shuffle (bool): whether or not to shuffle data (default: True). """ def __init__( self, datasets, sampling_ratios=None, - batch_by_size=False, seed=2, epoch=1, eval_key=None, @@ -62,18 +57,16 @@ def __init__( virtual_epoch_size=None, shared_collater=False, shard_epoch=1, + shuffle=True, ): self.virtual_epoch_size = virtual_epoch_size self._current_epoch_start_index = None - self._epoch_sizes = None - self._epoch_ordered_indices = None - self._random_globa_indices = None + self._random_global_indices = None self.shard_epoch = shard_epoch if shard_epoch is not None else 1 self.load_next_shard = None super().__init__( datasets=datasets, sampling_ratios=sampling_ratios, - batch_by_size=batch_by_size, seed=seed, epoch=epoch, eval_key=eval_key, @@ -81,6 +74,7 @@ def __init__( virtual_size=virtual_size, split=split, shared_collater=shared_collater, + shuffle=shuffle, ) def _setup(self, epoch): @@ -96,22 +90,19 @@ def _setup(self, epoch): def _map_epoch_index_to_global(self, index): index = self._current_epoch_start_index + index # add randomness - return self._random_globa_indices.array[index] + return self._random_global_indices[index] - def __getitem__(self, index): - i = self._map_epoch_index_to_global(index) - return super().__getitem__(i) + @property + def sizes(self): + _sizes = super().sizes + indices = self._random_global_indices[ + 
self._current_epoch_start_index:self._current_epoch_start_index + len(self) + ] + return _sizes[indices] - def num_tokens(self, index): + def _get_dataset_and_index(self, index): i = self._map_epoch_index_to_global(index) - return super().num_tokens(i) - - def size(self, index): - if self._epoch_sizes is not None: - return self._epoch_sizes.array[index] - index = self._map_epoch_index_to_global(index) - ds_idx, ds_sample_idx = self._get_dataset_and_index(index) - return self.datasets[ds_idx].size(ds_sample_idx) + return super()._get_dataset_and_index(i) def __len__(self): return ( @@ -120,72 +111,17 @@ def __len__(self): else self.virtual_size - self._current_epoch_start_index ) - @property - def sizes(self): - if self._epoch_sizes is not None: - return self._epoch_sizes.array - start_time = time.time() - - size_cache = self._size_cache - ret = [] - for i in range(len(self)): - index = self._map_epoch_index_to_global(i) - ds_idx, ds_sample_idx = self._get_dataset_and_index(index) - - if (ds_idx, ds_sample_idx) in size_cache: - ret.append(size_cache[(ds_idx, ds_sample_idx)]) - else: - s = self.datasets[ds_idx].size(ds_sample_idx) - s = (s, s) if not isinstance(s, tuple) else s - size_cache[(ds_idx, ds_sample_idx)] = s - ret.append(s) - self._epoch_sizes = plasma_utils.PlasmaArray(np.array(ret, np.int64)) - logger.info(f'sizes() calling time: {get_time_gap(start_time, time.time())}') - return self._epoch_sizes.array - - def ordered_indices(self): - if self._epoch_ordered_indices is not None: - return self._epoch_ordered_indices.array - - if self.batch_by_size: - # No need to do shuffle as the data items are already randomized - indices = np.arange(len(self)) - sizes = self.sizes - tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None - src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes - - # sort by target length, then source length - if tgt_sizes is not None: - indices = indices[ - 
np.argsort(tgt_sizes[indices], kind='mergesort') - ] - sort_indices = indices[np.argsort(src_sizes[indices], kind='mergesort')] - else: - sort_indices = np.arange(len(self)) - self._epoch_ordered_indices = plasma_utils.PlasmaArray(sort_indices) - return self._epoch_ordered_indices.array - - def prefetch(self, indices): - prefetch_indices = [[] for _ in range(len(self.datasets))] - for i in indices: - index = self._map_epoch_index_to_global(i) - ds_idx, ds_sample_idx = self._get_dataset_and_index(index) - prefetch_indices[ds_idx].append(ds_sample_idx) - for i in range(len(prefetch_indices)): - self.datasets[i].prefetch(prefetch_indices[i]) - - @property - def can_reuse_epoch_itr_across_epochs(self): - return False - def set_epoch(self, epoch): if self._current_epoch_start_index is None: + # initializing epoch indices of a virtual dataset self._setup(epoch) self._next_virtual_epoch(epoch) - if epoch == self._cur_epoch: - # re-enter so return - return - self._next_virtual_epoch(epoch) + else: + # working on already initialized epoch indices + if epoch == self._cur_epoch: + # re-enter so return + return + self._next_virtual_epoch(epoch) def _get_epoch_start_index(self, epoch): assert epoch >= 1 # fairseq is using 1-based epoch everywhere @@ -199,50 +135,20 @@ def _next_global_indices(self, epoch): epoch, # epoch index, ] ) - del self._random_globa_indices - self._random_globa_indices = plasma_utils.PlasmaArray( - rng.choice(self.virtual_size, self.virtual_size, replace=False)) + del self._random_global_indices + self._random_global_indices = rng.choice(self.virtual_size, self.virtual_size, replace=False) if self.load_next_shard is None: self.load_next_shard = False else: # increase shard epoch for next loading self.shard_epoch += 1 self.load_next_shard = True - # a hack to avoid possible out of sync of shard epoch number - # TODO: to confirm whether this is needed; without it, CUDA event error is occassionally observed - synced_shard_epoch = 
self._sync_shard_epoch(self.shard_epoch) logger.info('to load next epoch/shard in next load_dataset: ' - f'epoch={epoch}/shard_epoch={self.shard_epoch}[synced={synced_shard_epoch}]') - - def _sync_shard_epoch(self, shard_epoch): - # in case the ratios are not precisely the same across processes - # also to ensure every procresses update the ratios in the same pace - shard_epoch = torch.DoubleTensor([shard_epoch]) - if torch.distributed.is_initialized(): - if torch.cuda.is_available(): - distributed_utils.all_reduce(shard_epoch.cuda()) - else: - distributed_utils.all_reduce(shard_epoch) - ret = shard_epoch.cpu() - ret = ret.numpy() - return ret - - def _sync_epoch(self, epoch): - # in case the ratios are not precisely the same across processes - # also to ensure every procresses update the ratios in the same pace - epoch = torch.DoubleTensor([epoch]) - if torch.distributed.is_initialized(): - if torch.cuda.is_available(): - distributed_utils.all_reduce(epoch.cuda()) - else: - distributed_utils.all_reduce(epoch) - ret = epoch.cpu() - ret = ret.numpy() - return ret + f'epoch={epoch}/shard_epoch={self.shard_epoch}') def _next_virtual_epoch(self, epoch): index = self._get_epoch_start_index(epoch) - if index == 0 or self._random_globa_indices is None: + if index == 0 or self._random_global_indices is None: # need to start from the beginning, # so call super().set_epoch(epoch) to establish the global virtual indices logger.info('establishing a new set of global virtual indices for ' @@ -251,12 +157,10 @@ def _next_virtual_epoch(self, epoch): self._next_global_indices(epoch) else: self._cur_epoch = epoch - # reset cache sizes and ordered_indices for the epoch after moving to a new epoch + # reset cache sizes and ordered_indices for the epoch after moving to a new epoch self._clean_if_not_none([ - self._epoch_sizes, self._epoch_ordered_indices, self._size_cache + self._sizes, ]) - self._epoch_sizes = None - self._epoch_ordered_indices = None + self._sizes = None 
self._current_epoch_start_index = index - self._size_cache = {} diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 94f1fd32af..2858d0ad7b 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -129,6 +129,11 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): else: shard_epoch = None logger.info(f'loading data for {split} epoch={epoch}/{shard_epoch}') + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + if split in self.datasets: + del self.datasets[split] + logger.info('old dataset deleted manually') + logger.info(f"mem usage: {data_utils.get_mem_usage()}") self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset( split, self.training, @@ -228,39 +233,51 @@ def target_dictionary(self): def create_batch_sampler_func( self, max_positions, ignore_invalid_inputs, - max_tokens, max_sentences + max_tokens, max_sentences, + required_batch_size_multiple=1, + seed=1, ): def construct_batch_sampler( dataset, epoch ): splits = [s for s, _ in self.datasets.items() if self.datasets[s] == dataset] split = splits[0] if len(splits) > 0 else None - + # NEW implementation if epoch is not None: + # initialize the dataset with the correct starting epoch dataset.set_epoch(epoch) - start_time = time.time() + # get indices ordered by example size - indices = dataset.ordered_indices() - logger.debug(f'[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}') + start_time = time.time() + logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}") + + with data_utils.numpy_seed(seed): + indices = dataset.ordered_indices() + logger.info(f'[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}') + logger.info(f"mem usage: {data_utils.get_mem_usage()}") # filter examples that are too large if max_positions is not None: my_time = time.time() indices = 
self.filter_indices_by_size( - indices, - dataset, - max_positions, - ignore_invalid_inputs=ignore_invalid_inputs, + indices, dataset, max_positions, ignore_invalid_inputs ) - logger.debug(f'[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}') + logger.info(f'[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}') + logger.info(f"mem usage: {data_utils.get_mem_usage()}") # create mini-batches with given size constraints my_time = time.time() - batch_sampler = data_utils.batch_by_size( - indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences, + batch_sampler = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, ) - logger.debug(f'[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}') - logger.debug(f'[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}') + + logger.info(f'[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}') + logger.info(f'[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}') + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + return batch_sampler return construct_batch_sampler @@ -333,7 +350,10 @@ def get_batch_iterator( construct_batch_sampler = self.create_batch_sampler_func( max_positions, ignore_invalid_inputs, - max_tokens, max_sentences) + max_tokens, max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + seed=seed, + ) epoch_iter = iterators.EpochBatchIterator( dataset=dataset, From 4bd52fb55b28ac99b987a755c729ac1a79e3b5d3 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 18 Sep 2020 14:40:39 -0700 Subject: [PATCH 153/707] Bring back cached sizes() of sampled_multi_epoch_dataset Summary: # Facebook: D23634750 (https://github.com/pytorch/fairseq/commit/380227cc477f731f2315d62aa3d867720231538c) 
removed the cache of SampledMultiEpochDataset. As a result, num_tokens repeatedly recomputes sizes(), which makes batch_by_size impractical on large datasets. This diff brings the cache back. Reviewed By: pipibjc Differential Revision: D23781668 fbshipit-source-id: b0eabe639500ff87d47abe53a9256644d7ee4d1e --- fairseq/data/multilingual/sampled_multi_dataset.py | 9 +++++---- .../multilingual/sampled_multi_epoch_dataset.py | 13 ++++++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index af8643c042..5270675124 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -368,12 +368,13 @@ def _establish_virtual_datasets(self): self._cur_epoch, # epoch index, ] ) - indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices( - rng, self.datasets, self.sample_ratios, self.virtual_size) - self._clean_if_not_none([ - self.cumulated_sizes, self.virtual_size_per_dataset + self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes ]) + self._sizes = None + + indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices( + rng, self.datasets, self.sample_ratios, self.virtual_size) self._cur_indices = indices self.cumulated_sizes = cumulated_sizes self.virtual_size_per_dataset = virtual_size_per_dataset diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py index 9442ed460e..81ff78f705 100644 --- a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -64,6 +64,7 @@ def __init__( self._random_global_indices = None self.shard_epoch = shard_epoch if shard_epoch is not None else 1 self.load_next_shard = None + self._epoch_sizes = None super().__init__( datasets=datasets, sampling_ratios=sampling_ratios, @@ -94,11 
+95,17 @@ def _map_epoch_index_to_global(self, index): @property def sizes(self): + if self._epoch_sizes is not None: + return self._epoch_sizes _sizes = super().sizes indices = self._random_global_indices[ self._current_epoch_start_index:self._current_epoch_start_index + len(self) ] - return _sizes[indices] + self._epoch_sizes = _sizes[indices] + # del super()._sizes to save memory + del self._sizes + self._sizes = None + return self._epoch_sizes def _get_dataset_and_index(self, index): i = self._map_epoch_index_to_global(index) @@ -160,7 +167,7 @@ def _next_virtual_epoch(self, epoch): # reset cache sizes and ordered_indices for the epoch after moving to a new epoch self._clean_if_not_none([ - self._sizes, + self._epoch_sizes, ]) - self._sizes = None + self._epoch_sizes = None self._current_epoch_start_index = index From 732ddcc5ae0244ca95fdb2edbd8c4843441abb50 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 18 Sep 2020 14:56:48 -0700 Subject: [PATCH 154/707] Better estimation of shard data sizes for sampling ratios computation in multilingual training Summary: Previously we used current_data_size * num_shards as the total data size of a sharded dataset, which can be a poor estimate when shards are small. Now, as training goes on, multilingual_data_manager caches each shard's size so that it has the exact total data size after one pass over the data. 
Differential Revision: D23683717 fbshipit-source-id: fcd7fd170327091bfd0bed599f1d930c71f62570 --- .../multilingual/multilingual_data_manager.py | 45 +++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 18b5b96b28..f3527c1f59 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -7,7 +7,7 @@ import json import logging import os -from collections import OrderedDict +from collections import OrderedDict, defaultdict import numpy as np from fairseq import options, utils @@ -70,6 +70,7 @@ def __init__(self, args, lang_pairs, langs, dicts, sampling_method): self.sampling_scheduler = None self._has_sharded_data = False self._num_shards_dict = {} + self._training_data_sizes = defaultdict(lambda: {}) @classmethod def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): @@ -865,10 +866,14 @@ def get_split_num_data_shards(self, split): logger.info(f"[{split}] num of shards: {num_shards_dict}") return num_shards_dict - def get_split_data_path(self, paths, epoch, shard_epoch, num_shards): + @classmethod + def get_shard_id(cls, num_shards, epoch, shard_epoch=None): shard = epoch if shard_epoch is None else shard_epoch shard = (shard - 1) % num_shards - path = paths[shard] + return shard + + def get_split_data_path(self, paths, epoch, shard_epoch, num_shards): + path = paths[self.get_shard_id(num_shards, epoch, shard_epoch)] return path def get_split_data_param_list(self, split, epoch, shard_epoch=None): @@ -925,27 +930,41 @@ def get_split_data_param_list(self, split, epoch, shard_epoch=None): ) return param_list - def get_train_dataset_sizes(self, data_param_list, datasets): + def get_train_dataset_sizes(self, data_param_list, datasets, epoch, shard_epoch=None): num_shards = [ self.get_split_num_data_shards(param["split"])[param["key"]] 
for param in data_param_list ] - data_sizes = [ - (key, len(d) * num_shard) - for (key, d), num_shard in zip(datasets, num_shards) - ] + data_sizes = [] + for (key, d), num_shard in zip(datasets, num_shards): + my_data_sizes = self._training_data_sizes[key] + shard_ind = self.get_shard_id(num_shard, epoch, shard_epoch) + if shard_ind not in my_data_sizes: + my_data_sizes[shard_ind] = len(d) + known_size = max(my_data_sizes.values()) + data_sizes.append( + # If we don't know the data size of the shard yet, + # use the max known data size to approximate. + # Note that we preprocess shards by a designated shard size + # and put any remaining data at the end into the last shard so + # the max shard size approximation is almost correct before loading + # the last shard; after loading the last shard, it will have the + # exact size of the whole dataset. + (key, sum(my_data_sizes.get(i, known_size) for i in range(num_shard))) + ) logger.info( - f"data sizes multiplied by num_shards used in sampling ratios: {data_sizes}" + f"estimated total data sizes of all shards used in sampling ratios: {data_sizes}. 
" + "Note that if the data a shard has not been loaded yet, use the max known data size to approximate" ) return [s for _, s in data_sizes] - def get_train_sampling_ratios(self, data_param_list, datasets, epoch=1): - data_sizes = self.get_train_dataset_sizes(data_param_list, datasets) + def get_train_sampling_ratios(self, data_param_list, datasets, epoch=1, shard_epoch=None): + data_sizes = self.get_train_dataset_sizes(data_param_list, datasets, epoch, shard_epoch) sampling_func = self.sampling_method.sampling_method_selector() sample_ratios = sampling_func(data_sizes) if sampling_func is not None else None return sample_ratios - def get_sampling_ratios(self, data_param_list, datasets, epoch): + def get_sampling_ratios(self, data_param_list, datasets, epoch, shard_epoch=None): if self.args.sampling_weights_from_file: weights = load_sampling_weights(self.args.sampling_weights_from_file) sample_ratios = [weights[k] for k, _ in datasets] @@ -957,7 +976,7 @@ def get_sampling_ratios(self, data_param_list, datasets, epoch): sample_ratios = [self.args.sampling_weights[k] for k, _ in datasets] else: sample_ratios = self.get_train_sampling_ratios( - data_param_list, datasets, epoch + data_param_list, datasets, epoch, shard_epoch ) if sample_ratios is not None: From 30fe5f5a6b715cf583a9482fc7fc20eb1b7f7bd6 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Fri, 18 Sep 2020 18:01:35 -0700 Subject: [PATCH 155/707] Add to-many decoding support to FBTranslate TorchScript model Summary: Add support for our TorchScript model to decode to multiple output languages by specifying decoder language tokens Reviewed By: cndn Differential Revision: D23188553 fbshipit-source-id: 2a985dc3aa0b24e7297f23e449b3479c155e6611 --- .../multilingual/multilingual_data_manager.py | 80 +++++++------------ .../data/multilingual/multilingual_utils.py | 63 +++++++++++++++ 2 files changed, 93 insertions(+), 50 deletions(-) create mode 100644 fairseq/data/multilingual/multilingual_utils.py diff --git 
a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index f3527c1f59..1ddb19eb7b 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -25,6 +25,13 @@ data_utils, indexed_dataset, ) +from fairseq.data.multilingual.multilingual_utils import ( + EncoderLangtok, + LangTokSpec, + LangTokStyle, + augment_dictionary, + get_lang_tok, +) from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat from fairseq.file_io import PathManager from fairseq.options import csv_str_list, eval_str_dict @@ -33,17 +40,6 @@ logger = logging.getLogger(__name__) -def _lang_token(lang: str, style="__{}__"): - return style.format(lang) - - -def _lang_token_index(dic: Dictionary, lang: str, style="__{}__"): - """Return language token index.""" - idx = dic.index(_lang_token(lang, style)) - assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang) - return idx - - def _lang_id(dic: Dictionary, lang: str): """Return language ID index.""" idx = dic.index(lang) @@ -103,9 +99,9 @@ def add_args(parser): ) parser.add_argument( "--lang-tok-style", - default="multilingual", + default=LangTokStyle.multilingual.value, type=str, - choices=["multilingual", "mbart"], + choices=[LangTokStyle.multilingual.value, LangTokStyle.mbart.value], help="language token styles", ) @@ -158,7 +154,7 @@ def add_args(parser): "--encoder-langtok", default=None, type=str, - choices=["src", "tgt"], + choices=[EncoderLangtok.src.value, EncoderLangtok.tgt.value], metavar="SRCTGT", help="prepend to the beginning of source sentence the source or target " "language token. (src/tgt)", @@ -204,7 +200,7 @@ def add_args(parser): e.g. "main,dae,mined". There will be a set of language tokens added to the vocab to \ distinguish languages in different training data types. 
If not specified, default language \ tokens per languages will be added', - default="main", + default=LangTokSpec.main.value, type=csv_str_list, ) parser.add_argument( @@ -320,9 +316,9 @@ def check_langs(langs, pairs): training = False else: training = True - sorted_langs = cls.load_langs(args, **kargs) + language_list = cls.load_langs(args, **kargs) check_langs( - sorted_langs, + language_list, ( [p.split("-") for p in args.lang_pairs] if training @@ -346,33 +342,25 @@ def check_langs(langs, pairs): langs_to_load_dicts = sorted([args.source_lang, args.target_lang]) dicts = OrderedDict() - supported_langtok_specs = args.langtoks_specs for lang in langs_to_load_dicts: paths = utils.split_paths(args.data) assert len(paths) > 0 dicts[lang] = load_dictionary( os.path.join(paths[0], "dict.{}.txt".format(lang)) ) + augment_dictionary( + dictionary=dicts[lang], + language_list=language_list, + lang_tok_style=args.lang_tok_style, + langtoks_specs=args.langtoks_specs, + extra_data=args.extra_data, + ) if len(dicts) > 0: assert dicts[lang].pad() == dicts[langs_to_load_dicts[0]].pad() assert dicts[lang].eos() == dicts[langs_to_load_dicts[0]].eos() assert dicts[lang].unk() == dicts[langs_to_load_dicts[0]].unk() - - # keep the langs consistent for all experiments with the same lang dict - # for finetuning regardless of whether lang_tok is required or not just add the tokens to the dicts - for spec in supported_langtok_specs: - for lang_to_add in sorted_langs: - dicts[lang].add_symbol( - MultilingualDatasetManager.get_lang_tok(lang_to_add, args, spec) - ) - if args.lang_tok_style == "mbart" or ( - args.extra_data and "mono_dae" in args.extra_data - ): - dicts[lang].add_symbol("<mask>") logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) - return sorted_langs, dicts, training - - TOKEN_STYLES = {"mbart": "[{}]", "multilingual": "__{}__"} + return language_list, dicts, training @classmethod def create_lang_dictionary(cls, langs): @@ -383,20 +371,6 @@ def 
create_lang_dictionary(cls, langs): lang_dict.add_symbol(lang) return lang_dict - @classmethod - def get_lang_tok_style(cls, args): - return cls.TOKEN_STYLES[args.lang_tok_style] - - @classmethod - def get_lang_tok(cls, lang, args, spec=""): - if spec is None: - return None - if spec.endswith("dae"): - lang = f"{lang}_dae" - elif spec.endswith("mined"): - lang = f"{lang}_mined" - return _lang_token(lang, cls.get_lang_tok_style(args)) - @classmethod def get_langtok_index(cls, lang_tok, dic): idx = dic.index(lang_tok) @@ -411,11 +385,15 @@ def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): if spec and spec.startswith("src"): if src_lang is None: return None - langtok = self.get_lang_tok(src_lang, self.args, spec) + langtok = get_lang_tok( + lang=src_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) else: if tgt_lang is None: return None - langtok = self.get_lang_tok(tgt_lang, self.args, spec) + langtok = get_lang_tok( + lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) return self.get_langtok_index( langtok, self.dicts[src_lang if src_lang else tgt_lang] ) @@ -423,7 +401,9 @@ def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): def get_decoder_langtok(self, tgt_lang, spec=None): if spec is None: return None - langtok = self.get_lang_tok(tgt_lang, self.args, spec) + langtok = get_lang_tok( + lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) return self.get_langtok_index(langtok, self.dicts[tgt_lang]) @classmethod diff --git a/fairseq/data/multilingual/multilingual_utils.py b/fairseq/data/multilingual/multilingual_utils.py new file mode 100644 index 0000000000..b4e0f9828c --- /dev/null +++ b/fairseq/data/multilingual/multilingual_utils.py @@ -0,0 +1,63 @@ +from enum import Enum +from typing import Dict, List, Optional, Sequence + +import torch +from fairseq.data import Dictionary + + +class EncoderLangtok(Enum): + """ + Prepend to the beginning of source sentence either the + source or target 
language token. (src/tgt). + """ + + src = "src" + tgt = "tgt" + + +class LangTokSpec(Enum): + main = "main" + mono_dae = "mono_dae" + + +class LangTokStyle(Enum): + multilingual = "multilingual" + mbart = "mbart" + + +@torch.jit.export +def get_lang_tok( + lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value +) -> str: + # TOKEN_STYLES can't be defined outside this fn since it needs to be + # TorchScriptable. + TOKEN_STYLES: Dict[str, str] = { + LangTokStyle.mbart.value: "[{}]", + LangTokStyle.multilingual.value: "__{}__", + } + + if spec.endswith("dae"): + lang = f"{lang}_dae" + elif spec.endswith("mined"): + lang = f"{lang}_mined" + style = TOKEN_STYLES[lang_tok_style] + return style.format(lang) + + +def augment_dictionary( + dictionary: Dictionary, + language_list: List[str], + lang_tok_style: str, + langtoks_specs: Sequence[str] = (LangTokSpec.main.value,), + extra_data: Optional[Dict[str, str]] = None, +) -> None: + for spec in langtoks_specs: + for language in language_list: + dictionary.add_symbol( + get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec) + ) + + if lang_tok_style == LangTokStyle.mbart.value or ( + extra_data is not None and LangTokSpec.mono_dae.value in extra_data + ): + dictionary.add_symbol("") From 1bd84232571da80c6ca7fea2b260091258d8c718 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 18 Sep 2020 21:40:35 -0700 Subject: [PATCH 156/707] Lazy manifold download for translation and multilingual translation Summary: # Facebook: Downloading all shards to local is formidable for 5 billion pairs of bitext. 
This diff enables manifold lazy download to avoid downloading all shards before training starts * Cache the downloading of individual files as needed * For small files, directly use PathManager.open Changes in the following files will serve most translation tasks: * deeplearning/projects/fairseq-py/fairseq/data/data_utils.py * deeplearning/projects/fairseq-py/fairseq/data/indexed_dataset.py Reviewed By: theweiho Differential Revision: D23670815 fbshipit-source-id: 4f90c184d8832edd14f3c30cdc8e1bfad59946a9 --- fairseq/data/data_utils.py | 3 +- fairseq/data/indexed_dataset.py | 20 ++++- .../multilingual/multilingual_data_manager.py | 73 +++++++------------ 3 files changed, 45 insertions(+), 51 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 11217020f2..c8cc90fb08 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -71,15 +71,14 @@ def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False """ from fairseq.data.concat_dataset import ConcatDataset import fairseq.data.indexed_dataset as indexed_dataset - datasets = [] for k in itertools.count(): path_k = path + (str(k) if k > 0 else '') + path_k = indexed_dataset.get_indexed_dataset_to_local(path_k) dataset_impl_k = dataset_impl if dataset_impl_k is None: dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k) - dataset = indexed_dataset.make_dataset( path_k, impl=dataset_impl_k or default, diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 5b6155a2df..55bf0ca585 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -13,6 +13,7 @@ from . 
import FairseqDataset from fairseq.data.fasta_dataset import FastaDataset +from fairseq.file_io import PathManager def __best_fitting_dtype(vocab_size=None): @@ -179,7 +180,7 @@ def size(self, index): @staticmethod def exists(path): return ( - os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + PathManager.exists(index_file_path(path)) and PathManager.exists(data_file_path(path)) ) @property @@ -287,7 +288,7 @@ def size(self, index): @staticmethod def exists(path): - return os.path.exists(path) + return PathManager.exists(path) class IndexedDatasetBuilder(object): @@ -497,10 +498,23 @@ def supports_prefetch(self): @staticmethod def exists(path): return ( - os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + PathManager.exists(index_file_path(path)) and PathManager.exists(data_file_path(path)) ) +def get_indexed_dataset_to_local(path): + local_index_path = PathManager.get_local_path(index_file_path(path)) + local_data_path = PathManager.get_local_path(data_file_path(path)) + + assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), \ + "PathManager.get_local_path does not return files with expected patterns: " \ + f"{local_index_path} and {local_data_path}" + + local_path = local_data_path[:-4] # stripping surfix ".bin" + assert local_path == local_index_path[:-4] # stripping surfix ".idx" + return local_path + + class MMapIndexedDatasetBuilder(object): def __init__(self, out_file, dtype=np.int64): self._data_file = open(out_file, 'wb') diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 1ddb19eb7b..6240cf76d5 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -9,8 +9,7 @@ import os from collections import OrderedDict, defaultdict -import numpy as np -from fairseq import options, utils +from fairseq import utils from fairseq.data 
import ( AppendTokenDataset, ConcatDataset, @@ -342,9 +341,9 @@ def check_langs(langs, pairs): langs_to_load_dicts = sorted([args.source_lang, args.target_lang]) dicts = OrderedDict() + paths = utils.split_paths(args.data) + assert len(paths) > 0 for lang in langs_to_load_dicts: - paths = utils.split_paths(args.data) - assert len(paths) > 0 dicts[lang] = load_dictionary( os.path.join(paths[0], "dict.{}.txt".format(lang)) ) @@ -416,39 +415,6 @@ def split_exists(cls, split, src, tgt, lang, data_path, dataset_impl): filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) - @classmethod - def mono_split_exists(cls, split, lang, data_path, dataset_impl): - filename = os.path.join(data_path, "{}.{}".format(split, lang)) - return indexed_dataset.dataset_exists(filename, impl=dataset_impl) - - @classmethod - def bitext_split_exists(cls, split, src, tgt, data_path, dataset_impl): - src_exists = cls.split_exists( - split, src, tgt, lang=src, data_path=data_path, dataset_impl=dataset_impl - ) or cls.split_exists( - split, tgt, src, lang=src, data_path=data_path, dataset_impl=dataset_impl - ) - # check source exists to determine shard number - # also note that during inference time target is not required - # so checking target will fail inference time data loading - return src_exists - - @classmethod - def get_split_num_shards(cls, split, src, tgt, data_paths, dataset_impl): - return sum( - 1 - for path in data_paths - if cls.bitext_split_exists(split, src, tgt, path, dataset_impl) - ) - - @classmethod - def get_mono_split_num_shards(cls, split, lang, data_paths, dataset_impl): - return sum( - 1 - for path in data_paths - if cls.mono_split_exists(split, lang, path, dataset_impl) - ) - def load_lang_dataset( self, data_path, @@ -814,6 +780,20 @@ def get_data_paths_and_lang_pairs(self, split): def get_dataset_key(cls, data_category, src, tgt): return f"{data_category}:{src}-{tgt}" + 
@classmethod + def _get_shard_num_dict(cls, split, paths): + shards = defaultdict(int) + for path in paths: + files = PathManager.ls(path) + for f in files: + if f.startswith(split) and f.endswith('.idx'): + # idx files of the form "{split}.{src}-{tgt}.{lang}.idx" + direction = f.split('.')[-3] + shards[direction] += 1 + # each direction has two '.idx' files + # one for source language and one for target language, so: + return {k: v // 2 for k, v in shards.items()} + def get_split_num_data_shards(self, split): if split in self._num_shards_dict: return self._num_shards_dict[split] @@ -824,24 +804,25 @@ def get_split_num_data_shards(self, split): if data_category not in lang_pairs: continue paths = utils.split_paths(paths) + shards_dict = self._get_shard_num_dict(split, paths) lang_dirs = [ lang_pair.split("-") for lang_pair in lang_pairs[data_category] ] lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] for src, tgt in lang_dirs: - # monolingual data ruqires tgt only - assert src is not None or "mono_" in data_category, ( - f"error: src={src}, " "tgt={tgt} for data_category={data_category}" - ) key = self.get_dataset_key(data_category, src, tgt) if "mono_" in data_category: - num_shards_dict[key] = self.get_mono_split_num_shards( - split, tgt, paths, self.args.dataset_impl + # monolingual data requires tgt only + assert src is None or src == tgt, ( + f"error: src={src}, " "tgt={tgt} for data_category={data_category}" ) + num_shards_dict[key] = shards_dict[tgt] else: - num_shards_dict[key] = self.get_split_num_shards( - split, src, tgt, paths, self.args.dataset_impl - ) + if f"{src}-{tgt}" in shards_dict: + num_shards_dict[key] = shards_dict[f"{src}-{tgt}"] + elif f'{tgt}-{src}' in shards_dict: + # follow the fairseq tradition to use reversed direction data if it is not available + num_shards_dict[key] = shards_dict[f'{tgt}-{src}'] self._num_shards_dict[split] = num_shards_dict logger.info(f"[{split}] num of shards: {num_shards_dict}") return 
num_shards_dict From 9ae39465daff48650ed7956adf2980a697016d50 Mon Sep 17 00:00:00 2001 From: Ning Dong Date: Sat, 19 Sep 2020 17:56:52 -0700 Subject: [PATCH 157/707] Replace string literals for special symbols in Torchscript decoder Summary: repeat title. Due to TorchScript constraints it's not trivial to sync fairseq dictionary symbols in VocabConstants. Reviewed By: theweiho Differential Revision: D23774348 fbshipit-source-id: ec8db5ca41ee448ed317e76eabd4971f833015cc --- fairseq/data/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 77ef9538a3..d4b88024b0 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -27,7 +27,7 @@ def __init__( unk="", extra_special_symbols=None, ): - self.unk_word, self.pad_word, self.eos_word = unk, pad, eos + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos self.symbols = [] self.count = [] self.indices = {} From d5f7b50e1cf7d99abbf1ddfd1a985969c13ff4c3 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Mon, 21 Sep 2020 13:45:35 -0700 Subject: [PATCH 158/707] fix memory leak in pyspeech Summary: Recently pyspeech users have reported memory leaks, where the RAM used by the job would spike up at the start of every epoch. Kritika and I debugged this and we ran bisect: hg bisect -c "./script.sh" (script.sh: P143079305, training json: P143079313) This identified my diff in D23443012 (https://github.com/pytorch/fairseq/commit/e171c8d86a939cf4ebc483cd649bee1935379771) as the culprit. I then found out it was specifically the line where we wrap the torch dataloader with itertools.islice that causes issues. 
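For context, a minimal sketch of what the fallback removed by this patch did, using a hypothetical `TakeWrapper` class rather than fairseq's own iterator classes. It only shows that `itertools.islice` replaces the stored iterable with a new wrapper object that stops after `n` items. It does not reproduce the leak, whose mechanism this summary does not describe.

```python
import itertools


class TakeWrapper:
    """Hypothetical stand-in for an iterator that exposes take(n)."""

    def __init__(self, iterable):
        self._iterable = iter(iterable)

    def take(self, n):
        # Propagate to the underlying iterator if it supports take();
        # otherwise fall back to wrapping it in islice (the removed branch).
        if hasattr(self._iterable, "take"):
            self._iterable.take(n)
        else:
            self._iterable = itertools.islice(self._iterable, n)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterable)


it = TakeWrapper(range(10))
it.take(3)
first_three = list(it)                      # iteration now stops after 3 items
wrapped_type = type(it._iterable).__name__  # the stored iterable is now an islice
```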
Differential Revision: D23817986 fbshipit-source-id: 22eee7d384944febddc197bb734e9db79d73e6ea --- fairseq/data/iterators.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 19add56afa..196eb30887 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -521,8 +521,6 @@ def take(self, n): # Propagate this change to the underlying iterator if hasattr(self._iterable, "take"): self._iterable.take(n) - else: - self._iterable = itertools.islice(self._iterable, n) def __next__(self): # Create consumer if not created yet From 7c96648dea371b50c37582c860b9c63ac0242514 Mon Sep 17 00:00:00 2001 From: Matt Post Date: Tue, 22 Sep 2020 08:19:06 -0700 Subject: [PATCH 159/707] Fix inaccuracy in constrained decoding README (#2641) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? * The current README **incorrectly states** that Sockeye 2 no longer supports constrained decoding. * Other minor updates ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2641 Reviewed By: ngoyal2707 Differential Revision: D23833539 Pulled By: myleott fbshipit-source-id: 0ee09830fb3566a1291d3cb41184a1ae400a7836 --- examples/constrained_decoding/README.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/constrained_decoding/README.md b/examples/constrained_decoding/README.md index d101c032da..cfca9c91fd 100644 --- a/examples/constrained_decoding/README.md +++ b/examples/constrained_decoding/README.md @@ -3,8 +3,8 @@ This page provides instructions for how to use lexically constrained decoding in Fairseq. Fairseq implements the code described in the following papers: -* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/) -* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/) +* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/) (Post & Vilar, 2018) +* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/) (Hu et al., 2019) ## Quick start @@ -48,8 +48,8 @@ The heart of the implementation is in `fairseq/search.py`, which adds a `Lexical This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints provided for each input sentence. It does this using one of two classes, both found in `fairseq/token_generation_contstraints.py`: -* OrderedConstraintState: assumes the C input constraints will be generated in the provided order -* UnorderedConstraintState: tries to apply C (phrasal) constraints in all C!
orders +* OrderedConstraintState: assumes the `C` input constraints will be generated in the provided order +* UnorderedConstraintState: tries to apply `C` (phrasal) constraints in all `C!` orders ## Differences from Sockeye @@ -58,9 +58,8 @@ There are a number of [differences from Sockeye's implementation](https://awslab * Generating constraints in the order supplied (the default option here) is not available in Sockeye. * Due to an improved beam allocation method, there is no need to prune the beam. * Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient. -* [The extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged - into the main branch. -* Sockeye 2, released in July 2020, no longer supports constrained decoding. +* [The vector extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged + into the main Sockeye branch. ## Citation From cad08709450714c8a815836c98fcfcee6aff09b0 Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 22 Sep 2020 08:25:34 -0700 Subject: [PATCH 160/707] Update transformer_layer.py (#2611) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? In some cases, we want to use the decoder-only model from a pretrained transformer (`encoder_attn` is not None). This commit is minor but important: it not only makes `TransformerDecoderLayer` more robust but also makes it compatible with a decoder-only model taken from a pretrained transformer.
If you want to use the decoder from a pretrained transformer, you can just set `encoder_out` to None. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2611 Reviewed By: ngoyal2707 Differential Revision: D23833554 Pulled By: myleott fbshipit-source-id: c8b910ec051d8cb0bda4e41b5d9ff249ff88d32b --- fairseq/modules/transformer_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index ced8d933f5..a803f581a5 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -324,7 +324,7 @@ def forward( if not self.normalize_before: x = self.self_attn_layer_norm(x) - if self.encoder_attn is not None: + if self.encoder_attn is not None and encoder_out is not None: residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x) From 22007c4419da1108af2f5ac54560c73f047e7b36 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Tue, 22 Sep 2020 09:14:32 -0700 Subject: [PATCH 161/707] Pad sequence lengths to a multiple (#2642) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Provides an option to pad sequence lengths to a multiple of a given value. We found this provided a significant speed-up when using tensor cores (multiple of 8). Currently the code only allows the number of sequences in a batch to be a multiple of 8, which we found to be less important than padding the sequence length in a batch to a multiple of 8. ## PR review Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2642 Reviewed By: ngoyal2707 Differential Revision: D23833542 Pulled By: myleott fbshipit-source-id: 7af53c3ee3d6388eafdd2f4d29f6b696cbc4fa3b --- fairseq/data/data_utils.py | 5 ++++- fairseq/data/language_pair_dataset.py | 5 +++++ fairseq/options.py | 2 ++ fairseq/tasks/translation.py | 3 +++ 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index c8cc90fb08..a3f93cb8fc 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -32,10 +32,13 @@ def infer_language_pair(path): return src, dst -def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False, pad_to_length=None): +def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False, + pad_to_length=None, pad_to_multiple=1): """Convert a list of 1d tensors into a padded 2d tensor.""" size = max(v.size(0) for v in values) size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size-0.1)//pad_to_multiple + 1) * pad_to_multiple) res = values[0].new(len(values), size).fill_(pad_idx) def copy_tensor(src, dst): diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 5cc5087b2d..3014354e7c 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -22,6 +22,7 @@ def collate( left_pad_target=False, input_feeding=True, pad_to_length=None, + pad_to_multiple=1, ): if len(samples) == 0: return {} @@ -31,6 +32,7 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning, pad_to_length=pad_to_length, +
pad_to_multiple=pad_to_multiple, ) def check_alignment(alignment, src_len, tgt_len): @@ -197,6 +199,7 @@ def __init__( num_buckets=0, src_lang_id=None, tgt_lang_id=None, + pad_to_multiple=1, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() @@ -257,6 +260,7 @@ def __init__( ] else: self.buckets = None + self.pad_to_multiple = pad_to_multiple def get_batch_shapes(self): return self.buckets @@ -345,6 +349,7 @@ def collater(self, samples, pad_to_length=None): left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, pad_to_length=pad_to_length, + pad_to_multiple=self.pad_to_multiple, ) if self.src_lang_id is not None or self.tgt_lang_id is not None: src_tokens = res['net_input']['src_tokens'] diff --git a/fairseq/options.py b/fairseq/options.py index 8e50064d76..4add9dd5fa 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -317,6 +317,8 @@ def add_dataset_args(parser, train=False, gen=False): group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N', help='batch size will either be less than this value, ' 'or a multiple of this value') + group.add_argument('--required-seq-len-multiple', default=1, type=int, metavar='N', + help='maximum sequence length in batch will be a multiplier of this value') parser.add_argument('--dataset-impl', metavar='FORMAT', choices=get_available_dataset_impl(), help='output dataset implementation') diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 8181c1a650..a04924605c 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -42,6 +42,7 @@ def load_langpair_dataset( truncate_source=False, append_source_id=False, num_buckets=0, shuffle=True, + pad_to_multiple=1, ): def split_exists(split, src, tgt, lang, data_path): @@ -129,6 +130,7 @@ def split_exists(split, src, tgt, lang, data_path): align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, shuffle=shuffle, + pad_to_multiple=pad_to_multiple, ) @@ -268,6 +270,7 @@ 
def load_dataset(self, split, epoch=1, combine=False, **kwargs): truncate_source=self.args.truncate_source, num_buckets=self.args.num_batch_buckets, shuffle=(split != 'test'), + pad_to_multiple=self.args.required_seq_len_multiple, ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): From 9e4088bc3d2630d5a4285138662fed4426190e73 Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 22 Sep 2020 09:31:39 -0700 Subject: [PATCH 162/707] Update sentence_prediction.py (#2594) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? It is preferable to use `cls` rather than the actual class name to construct an instance. Hard-coding the class name may cause issues if you subclass it and use the class method. ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2594 Reviewed By: myleott Differential Revision: D23607595 Pulled By: lematt1991 fbshipit-source-id: 2ca5bb455876817094e8d6153db466e2c8f65192 --- fairseq/tasks/sentence_prediction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index fec19e0a75..d9a82faddd 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -110,7 +110,7 @@ def setup_task(cls, args, **kwargs): logger.info('[label] dictionary: {} types'.format(len(label_dict))) else: label_dict = data_dict - return SentencePredictionTask(args, data_dict, label_dict) + return cls(args, data_dict, label_dict) def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" From 66f66a4fc7139fdaa724d0854520ef5f68eb42ea Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 22 Sep 2020 11:44:59 -0700 Subject: [PATCH 163/707] Gate psutil import to make tests pass (#1282) Summary: Gate psutil import to make tests pass Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1282 Reviewed By: tangyuq Differential Revision: D23822037 Pulled By: myleott fbshipit-source-id: c652c7931147ecd377d78322840e343c55cb85a2 --- fairseq/data/data_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index a3f93cb8fc..224169c366 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -440,7 +440,9 @@ def arrange(s, e, length, keep_length): def get_mem_usage(): - # for debug - import psutil - mb = 1024 * 1024 - return f'used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb' + try: + import psutil + mb = 1024 * 1024 + return f'used={psutil.virtual_memory().used / mb}Mb;
avail={psutil.virtual_memory().available / mb}Mb' + except ImportError: + return 'N/A' From 3f484789de756ed03a19df9c2ca0bcae205aae97 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Tue, 22 Sep 2020 11:46:23 -0700 Subject: [PATCH 164/707] Add sizes to TransformEosLangPairDataset Summary: LangPairDataset was updated with sizes to improve efficiency, so also added it to TransformEosLangPairDataset. Reviewed By: pipibjc Differential Revision: D23817426 fbshipit-source-id: d62a4ac3535a06618ff70c717ab55f43fc6d48c6 --- fairseq/data/transform_eos_lang_pair_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py index 55137ca55c..2783824838 100644 --- a/fairseq/data/transform_eos_lang_pair_dataset.py +++ b/fairseq/data/transform_eos_lang_pair_dataset.py @@ -78,6 +78,11 @@ def num_tokens(self, index): def size(self, index): return self.dataset.size(index) + @property + def sizes(self): + # dataset.sizes can be a dynamically computed sizes: + return self.dataset.sizes + def ordered_indices(self): return self.dataset.ordered_indices() From 3a979a40e66a3a4006da55809b2f4dd6bd517d20 Mon Sep 17 00:00:00 2001 From: Kaushik Rangadurai Date: Tue, 22 Sep 2020 14:04:38 -0700 Subject: [PATCH 165/707] Remove in-place operations in TransformerSentenceEncoder for captum insights Summary: When using captum insights, we can't take gradients on leaf operations that have in-place operations. 
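A minimal sketch of the restriction described in this summary, assuming PyTorch is installed. It is not taken from the referenced notebook and does not use Captum or `TransformerSentenceEncoder`. It only shows the simplest case, a leaf tensor that requires grad, where an in-place update is rejected while the out-of-place form adopted in this patch works.

```python
import torch

# A leaf tensor that requires grad, standing in for an input we want gradients for.
x = torch.ones(3, requires_grad=True)

# In-place update: autograd refuses to modify a leaf tensor that requires grad.
try:
    x *= 2.0
    in_place_failed = False
except RuntimeError:
    in_place_failed = True

# Out-of-place update (the form this patch switches to) creates a new tensor,
# so the leaf is left untouched and gradients can flow back to it.
y = x * 2.0
y.sum().backward()
grad = x.grad.tolist()  # d(sum(2 * x)) / dx is 2 for every element
```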
This Notebook - N355930 demonstrates this with a simple example Reviewed By: myleott Differential Revision: D23827055 fbshipit-source-id: d49b8ec9a016ccc7195438be9b1782e08479d90d --- fairseq/modules/transformer_sentence_encoder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 8a6994181b..9562430dfa 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -235,13 +235,13 @@ def forward( x = self.embed_tokens(tokens) if self.embed_scale is not None: - x *= self.embed_scale + x = x * self.embed_scale if self.embed_positions is not None: - x += self.embed_positions(tokens, positions=positions) + x = x + self.embed_positions(tokens, positions=positions) if self.segment_embeddings is not None and segment_labels is not None: - x += self.segment_embeddings(segment_labels) + x = x + self.segment_embeddings(segment_labels) if self.quant_noise is not None: x = self.quant_noise(x) @@ -253,7 +253,7 @@ def forward( # account for padding while computing the representation if padding_mask is not None: - x *= 1 - padding_mask.unsqueeze(-1).type_as(x) + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) From 66101b0564d99daa598ead2cc75615e96b362411 Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Tue, 22 Sep 2020 21:15:34 -0700 Subject: [PATCH 166/707] hydra fairseq 5 - add module dataclasses Summary: hydra fairseq 5 - add module dataclasses Reviewed By: myleott Differential Revision: D23416496 fbshipit-source-id: f8cf811d73ed81391eb0e48ecf5b6a4104341b08 --- fairseq/criterions/__init__.py | 2 + fairseq/criterions/adaptive_loss.py | 70 ++- fairseq/criterions/cross_entropy.py | 10 + fairseq/models/__init__.py | 2 + fairseq/models/transformer_lm.py | 414 ++++++++++++------ fairseq/optim/__init__.py | 2 + fairseq/optim/adam.py | 106 +++-- 
fairseq/optim/bmuf.py | 36 +- fairseq/optim/lr_scheduler/__init__.py | 2 + .../optim/lr_scheduler/cosine_lr_scheduler.py | 61 ++- .../inverse_square_root_schedule.py | 34 +- fairseq/optim/nag.py | 43 +- fairseq/tasks/__init__.py | 3 + fairseq/tasks/language_modeling.py | 92 +++- 14 files changed, 650 insertions(+), 227 deletions(-) diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index 1c28780111..b3663d6394 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -10,6 +10,8 @@ from fairseq.criterions.fairseq_criterion import FairseqCriterion, LegacyFairseqCriterion +CRITERION_DATACLASS_REGISTRY = {} + build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry( '--criterion', base_class=FairseqCriterion, diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 1916131bb1..4e6506337e 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -4,14 +4,25 @@ # LICENSE file in the root directory of this source tree. 
import math +from dataclasses import dataclass import torch.nn.functional as F - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES +from fairseq.dataclass.utils import FairseqDataclass +from omegaconf import II + +@dataclass +class AdaptiveLossConfig(FairseqDataclass): + sentence_avg: bool = II("params.optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II( + "params.distributed_training.ddp_backend" + ) -@register_criterion('adaptive_loss') + +@register_criterion("adaptive_loss") class AdaptiveLoss(FairseqCriterion): """This is an implementation of the loss function accompanying the adaptive softmax approximation for graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" @@ -23,11 +34,11 @@ def __init__(self, task, sentence_avg): @classmethod def build_criterion(cls, args, task): - if getattr(args, 'ddp_backend', None) == 'c10d': + if getattr(args, "ddp_backend", None) == "c10d": raise Exception( - 'AdaptiveLoss is not compatible with the c10d ' - 'version of DistributedDataParallel. Please use ' - '`--ddp-backend=no_c10d` instead.' + "AdaptiveLoss is not compatible with the c10d " + "version of DistributedDataParallel. Please use " + "`--ddp-backend=no_c10d` instead." 
) return cls(task, args.sentence_avg) @@ -40,10 +51,13 @@ def forward(self, model, sample, reduce=True): 3) logging outputs to display while training """ - assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None + assert ( + hasattr(model.decoder, "adaptive_softmax") + and model.decoder.adaptive_softmax is not None + ) adaptive_softmax = model.decoder.adaptive_softmax - net_output = model(**sample['net_input']) + net_output = model(**sample["net_input"]) orig_target = model.get_targets(sample, net_output) nsentences = orig_target.size(0) @@ -58,38 +72,48 @@ def forward(self, model, sample, reduce=True): for i in range(len(target)): if target[i] is not None: - assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1)) + assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1) loss += F.cross_entropy( logits[i], target[i], ignore_index=self.padding_idx, - reduction='sum' if reduce else 'none', + reduction="sum" if reduce else "none", ) orig = utils.strip_pad(orig_target, self.padding_idx) ntokens = orig.numel() - sample_size = sample['target'].size(0) if self.sentence_avg else ntokens + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens logging_output = { - 'loss': loss.data, - 'ntokens': ntokens, - 'nsentences': nsentences, - 'sample_size': sample_size, + "loss": loss.data, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, } return loss, sample_size, logging_output @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" - loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs)) - ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs)) - sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs)) - - metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) + loss_sum = utils.item(sum(log.get("loss", 
0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) if sample_size != ntokens: - metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3) - metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg)) + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) else: - metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg)) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index 4e750f62e3..7b690dcff2 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -9,6 +9,16 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES +from dataclasses import dataclass +from fairseq.dataclass.utils import FairseqDataclass +from omegaconf import II + + +@dataclass +class CrossEntropyCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("params.optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") @register_criterion('cross_entropy') diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index f7d8eaafad..332ac822a7 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -25,8 +25,10 @@ MODEL_REGISTRY = {} ARCH_MODEL_REGISTRY = {} +ARCH_MODEL_NAME_REGISTRY = {} ARCH_MODEL_INV_REGISTRY = {} 
ARCH_CONFIG_REGISTRY = {} +ARCH_DATACLASS_REGISTRY = {} __all__ = [ diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index e24452ff8a..88718dd5af 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -3,43 +3,184 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +from typing import Optional + from fairseq import options, utils +from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass from fairseq.models import ( FairseqLanguageModel, register_model, register_model_architecture, ) -from fairseq.models.transformer import ( - Embedding, - TransformerDecoder, -) -from fairseq.modules import ( - AdaptiveInput, - CharacterTokenEmbedder, -) +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder +from omegaconf import II + DEFAULT_MAX_TARGET_POSITIONS = 1024 -@register_model('transformer_lm') +@dataclass +class TransformerLanguageModelConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", metadata={"help": "activation function to use"} + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + attention_dropout: float = field( + default=0.0, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + relu_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + decoder_embed_dim: int = field( + default=512, metadata={"help": "decoder embedding dimension"} + ) + decoder_output_dim: int = field( + default=512, metadata={"help": "decoder output dimension"} + ) + decoder_input_dim: int = field( + default=512, metadata={"help": "decoder 
input dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=2048, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"}) + decoder_attention_heads: int = field( + default=8, metadata={"help": "num decoder attention heads"} + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_decoder_final_norm: bool = field( + default=False, + metadata={"help": "don't add an extra layernorm after the last decoder block"}, + ) + adaptive_softmax_cutoff: Optional[str] = field( + default=None, + metadata={ + "help": "comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion" + }, + ) + adaptive_softmax_dropout: float = field( + default=0, + metadata={"help": "sets adaptive softmax dropout for the tail projections"}, + ) + adaptive_softmax_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + character_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, uses character embedding convolutions to produce token embeddings" + }, + ) + character_filters: str = field( + default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + metadata={"help": "size of character embeddings"}, + ) + character_embedding_dim: int = field( + default=4, metadata={"help": "size of character embeddings"} + ) + char_embedder_highway_layers: int = field( + default=2, + metadata={"help": "number of highway layers for character token embeddder"}, + ) + adaptive_input: bool = field( + default=False, metadata={"help": "if set, uses 
adaptive input"} + ) + adaptive_input_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + adaptive_input_cutoff: Optional[str] = field( + default=None, + metadata={"help": "comma separated list of adaptive input cutoff points."}, + ) + tie_adaptive_weights: bool = field( + default=False, + metadata={ + "help": "if set, ties the weights of adaptive softmax and adaptive input" + }, + ) + tie_adaptive_proj: bool = field( + default=False, + metadata={ + "help": "if set, ties the projection weights of adaptive softmax and adaptive input" + }, + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + layernorm_embedding: bool = field( + default=False, metadata={"help": "add layernorm to embedding"} + ) + no_scale_embedding: bool = field( + default=False, metadata={"help": "if True, dont scale embeddings"} + ) + quant_noise_pq: float = field( + default=0.0, + metadata={"help": "iterative PQ quantization noise at training time"}, + ) + quant_noise_pq_block_size: int = field( + default=8, + metadata={"help": "block size of quantization noise at training time"}, + ) + # TODO common var add to parent + quant_noise_scalar: float = field( + default=0.0, + metadata={ + "help": "scalar quantization noise and scalar quantization at training time" + }, + ) + add_bos_token: bool = II("task.add_bos_token") + tokens_per_sample: int = II("task.tokens_per_sample") + max_target_positions: Optional[int] = II("task.max_target_positions") + # TODO common var add to parent + tpu: bool = II("params.common.tpu") + + +@register_model("transformer_lm") class TransformerLanguageModel(FairseqLanguageModel): - @classmethod 
def hub_models(cls): - def moses_fastbpe(path): - return { - 'path': path, - 'tokenizer': 'moses', - 'bpe': 'fastbpe', - } + return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"} return { - 'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2', - 'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2', - 'transformer_lm.wmt19.en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2'), - 'transformer_lm.wmt19.de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2'), - 'transformer_lm.wmt19.ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2'), + "transformer_lm.gbw.adaptive_huge": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2", + "transformer_lm.wiki103.adaptive": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2", + "transformer_lm.wmt19.en": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2" + ), + "transformer_lm.wmt19.de": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2" + ), + "transformer_lm.wmt19.ru": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2" + ), } def __init__(self, decoder): @@ -134,34 +275,47 @@ def build_model(cls, args, task): if args.decoder_layers_to_keep: args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) - if getattr(args, 'max_target_positions', None) is None: - args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS) + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) if args.character_embeddings: embed_tokens = CharacterTokenEmbedder( - task.source_dictionary, eval(args.character_filters), - 
args.character_embedding_dim, args.decoder_embed_dim, + task.source_dictionary, + eval(args.character_filters), + args.character_embedding_dim, + args.decoder_embed_dim, args.char_embedder_highway_layers, ) elif args.adaptive_input: embed_tokens = AdaptiveInput( - len(task.source_dictionary), task.source_dictionary.pad(), args.decoder_input_dim, - args.adaptive_input_factor, args.decoder_embed_dim, + len(task.source_dictionary), + task.source_dictionary.pad(), + args.decoder_input_dim, + args.adaptive_input_factor, + args.decoder_embed_dim, options.eval_str_list(args.adaptive_input_cutoff, type=int), - args.quant_noise_pq, args.quant_noise_pq_block_size, + args.quant_noise_pq, + args.quant_noise_pq_block_size, ) else: - embed_tokens = cls.build_embedding(args, task.source_dictionary, args.decoder_input_dim) + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) if args.tie_adaptive_weights: assert args.adaptive_input assert args.adaptive_input_factor == args.adaptive_softmax_factor - assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format( - args.adaptive_softmax_cutoff, args.adaptive_input_cutoff) + assert ( + args.adaptive_softmax_cutoff == args.adaptive_input_cutoff + ), "{} != {}".format( + args.adaptive_softmax_cutoff, args.adaptive_input_cutoff + ) assert args.decoder_input_dim == args.decoder_output_dim decoder = TransformerDecoder( - args, task.target_dictionary, embed_tokens, no_encoder_attn=True, + args, task.target_dictionary, embed_tokens, no_encoder_attn=True ) return cls(decoder) @@ -171,140 +325,148 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): return embed_tokens -@register_model_architecture('transformer_lm', 'transformer_lm') +@register_model_architecture("transformer_lm", "transformer_lm") def base_lm_architecture(args): # backward compatibility for older model checkpoints - if hasattr(args, 'no_tie_adaptive_proj'): + if hasattr(args, 
"no_tie_adaptive_proj"): # previous models defined --no-tie-adaptive-proj, so use the existence of # that option to determine if this is an "old" model checkpoint args.no_decoder_final_norm = True # old models always set this to True if args.no_tie_adaptive_proj is False: args.tie_adaptive_proj = True - if hasattr(args, 'decoder_final_norm'): + if hasattr(args, "decoder_final_norm"): args.no_decoder_final_norm = not args.decoder_final_norm - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.0) - - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048) - args.decoder_layers = getattr(args, 'decoder_layers', 6) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) - args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4) - args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) - args.activation_fn = getattr(args, 'activation_fn', 'relu') - - args.decoder_layerdrop = getattr(args, 'decoder_layerdrop', 0) - args.decoder_layers_to_keep = getattr(args, 'decoder_layers_to_keep', None) - args.quant_noise_pq = getattr(args, 'quant_noise_pq', 0) - args.quant_noise_pq_block_size = getattr(args, 'quant_noise_pq_block_size', 8) - args.quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0) - - args.add_bos_token = getattr(args, 'add_bos_token', False) - args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.character_embeddings = getattr(args, 'character_embeddings', False) - - args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) - 
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.activation_fn = getattr(args, "activation_fn", "relu") + + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + args.add_bos_token = getattr(args, "add_bos_token", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.character_embeddings = getattr(args, "character_embeddings", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) # Model training is not stable without this args.decoder_normalize_before = True - args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', False) + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) - args.adaptive_input = 
getattr(args, 'adaptive_input', False) - args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4) - args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) - args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False) - args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) - args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) - args.layernorm_embedding = getattr(args, 'layernorm_embedding', False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) -@register_model_architecture('transformer_lm', 'transformer_lm_big') +@register_model_architecture("transformer_lm", "transformer_lm_big") def transformer_lm_big(args): - args.decoder_layers = getattr(args, 'decoder_layers', 12) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) base_lm_architecture(args) -@register_model_architecture('transformer_lm', 'transformer_lm_wiki103') -@register_model_architecture('transformer_lm', 'transformer_lm_baevski_wiki103') +@register_model_architecture("transformer_lm", "transformer_lm_wiki103") 
+@register_model_architecture("transformer_lm", "transformer_lm_baevski_wiki103") def transformer_lm_baevski_wiki103(args): - args.decoder_layers = getattr(args, 'decoder_layers', 16) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) - args.dropout = getattr(args, 'dropout', 0.3) - args.adaptive_input = getattr(args, 'adaptive_input', True) - args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', True) - args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', '20000,60000') - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '20000,60000') - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0.2) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_dropout = getattr(args, 'activation_dropout', 0.1) - args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', True) - args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', True) + args.decoder_layers = getattr(args, "decoder_layers", 16) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.dropout = getattr(args, "dropout", 0.3) + args.adaptive_input = getattr(args, "adaptive_input", True) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", True) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", "20000,60000") + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "20000,60000" + ) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0.2) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", True) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", True) transformer_lm_big(args) -@register_model_architecture('transformer_lm', 'transformer_lm_gbw') -@register_model_architecture('transformer_lm', 
'transformer_lm_baevski_gbw') +@register_model_architecture("transformer_lm", "transformer_lm_gbw") +@register_model_architecture("transformer_lm", "transformer_lm_baevski_gbw") def transformer_lm_baevski_gbw(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', True) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", True) transformer_lm_big(args) -@register_model_architecture('transformer_lm', 'transformer_lm_gpt') +@register_model_architecture("transformer_lm", "transformer_lm_gpt") def transformer_lm_gpt(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 3072) - args.decoder_layers = getattr(args, 'decoder_layers', 12) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 12) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) -@register_model_architecture('transformer_lm', 'transformer_lm_gpt2_small') +@register_model_architecture("transformer_lm", 
"transformer_lm_gpt2_small") def transformer_lm_gpt2_small(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096) - args.decoder_layers = getattr(args, 'decoder_layers', 24) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) -@register_model_architecture('transformer_lm', 'transformer_lm_gpt2_medium') +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium") def transformer_lm_gpt2_medium(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1280) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 5120) - args.decoder_layers = getattr(args, 'decoder_layers', 36) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 20) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120) + args.decoder_layers = getattr(args, "decoder_layers", 36) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20) + args.dropout = getattr(args, "dropout", 0.1) + 
args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) -@register_model_architecture('transformer_lm', 'transformer_lm_gpt2_big') +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big") def transformer_lm_gpt2_big(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1600) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 6400) - args.decoder_layers = getattr(args, 'decoder_layers', 48) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 25) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1600) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6400) + args.decoder_layers = getattr(args, "decoder_layers", 48) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 25) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 2f723866dc..773492775a 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -13,6 +13,8 @@ from fairseq.optim.shard import shard_ +OPTIMIZER_DATACLASS_REGISTRY = {} + __all__ = [ 'FairseqOptimizer', 'FP16Optimizer', diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index d5783b258c..1afec99be6 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -5,19 +5,39 @@ import logging import math -import types +from dataclasses import dataclass, field +from typing import List import torch -import torch.optim import torch.distributed as dist - +import torch.optim +from fairseq.dataclass.utils import 
FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from fairseq.optim.fused_adam import get_fused_adam_class +from omegaconf import II + logger = logging.getLogger(__name__) -@register_optimizer('adam') +@dataclass +class FairseqAdamConfig(FairseqDataclass): + adam_betas: str = field( + default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"} + ) + adam_eps: float = field( + default=1e-8, metadata={"help": "epsilon for Adam optimizer"} + ) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + use_old_adam: bool = field( + default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} + ) + # TODO common vars below in parent + tpu: bool = II("params.common.tpu") + lr: List[float] = II("params.optimization.lr") + + +@register_optimizer("adam") class FairseqAdam(FairseqOptimizer): """Adam optimizer for fairseq. @@ -30,16 +50,16 @@ def __init__(self, args, params): super().__init__(args) fused_adam_cls = get_fused_adam_class() use_fused_adam = ( - not getattr(args, 'use_old_adam', False) + not getattr(args, "use_old_adam", False) and fused_adam_cls is not None and torch.cuda.is_available() ) - if getattr(args, 'tpu', False): + if getattr(args, "tpu", False): # on TPUs we use the Adam defined here, since it # automatically casts gradients to FP32 self._optimizer = Adam(params, **self.optimizer_config) elif use_fused_adam: - logger.info('using FusedAdam') + logger.info("using FusedAdam") self._optimizer = fused_adam_cls(params, **self.optimizer_config) else: self._optimizer = Adam(params, **self.optimizer_config) @@ -73,10 +93,10 @@ def optimizer_config(self): different learning rate. 
""" return { - 'lr': self.args.lr[0], - 'betas': eval(self.args.adam_betas), - 'eps': self.args.adam_eps, - 'weight_decay': self.args.weight_decay, + "lr": self.args.lr[0], + "betas": eval(self.args.adam_betas), + "eps": self.args.adam_eps, + "weight_decay": self.args.weight_decay, } def average_params(self): @@ -118,10 +138,18 @@ class Adam(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False): - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + ): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad + ) super(Adam, self).__init__(params, defaults) @property @@ -144,15 +172,17 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - amsgrad = group['amsgrad'] + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group["amsgrad"] p_data_fp32 = p.data if p.data.dtype in {torch.float16, torch.bfloat16}: @@ -162,26 +192,28 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p_data_fp32) + state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) if amsgrad: # Maintains max of all exp. 
moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p_data_fp32) + state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) else: - state['exp_avg'] = state['exp_avg'].to(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].to(p_data_fp32) + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) if amsgrad: - state['max_exp_avg_sq'] = state['max_exp_avg_sq'].to(p_data_fp32) + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( + p_data_fp32 + ) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] if amsgrad: - max_exp_avg_sq = state['max_exp_avg_sq'] - beta1, beta2 = group['betas'] + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] - state['step'] += 1 + state["step"] += 1 # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) @@ -190,16 +222,18 @@ def step(self, closure=None): # Maintains the maximum of all 2nd moment running avg. till now torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) # Use the max. for normalizing running avg. 
of gradient - denom = max_exp_avg_sq.sqrt().add_(group['eps']) + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) else: - denom = exp_avg_sq.sqrt().add_(group['eps']) + denom = exp_avg_sq.sqrt().add_(group["eps"]) - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py index be7bdd74a7..bcdeeee45b 100644 --- a/fairseq/optim/bmuf.py +++ b/fairseq/optim/bmuf.py @@ -3,10 +3,42 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + import torch import torch.distributed as dist - -from . 
import FairseqOptimizer +from fairseq.dataclass.utils import FairseqDataclass +from fairseq.optim.fairseq_optimizer import FairseqOptimizer +from omegaconf import II + + +@dataclass +class FairseqBMUFConfig(FairseqDataclass): + block_lr: float = field( + default=1, metadata={"help": "block learning rate for bmuf"} + ) + block_momentum: float = field( + default=0.875, metadata={"help": "block momentum for bmuf"} + ) + global_sync_iter: int = field( + default=50, metadata={"help": "Iteration for syncing global model"} + ) + warmup_iterations: int = field( + default=500, metadata={"help": "warmup iterations for model to broadcast"} + ) + use_nbm: bool = field( + default=False, + metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"}, + ) + average_sync: bool = field( + default=False, + metadata={ + "help": "Specify whether you want to average the local momentum after each sync" + }, + ) + distributed_world_size: int = II( + "params.distributed_training.distributed_world_size" + ) class FairseqBMUF(FairseqOptimizer): diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index 76c5357189..fe84bc6004 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -10,6 +10,8 @@ from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler, LegacyFairseqLRScheduler # noqa +LR_SCHEDULER_DATACLASS_REGISTRY = {} + build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( '--lr-scheduler', base_class=FairseqLRScheduler, diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 9137e11b78..aeee95b84e 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -4,11 +4,45 @@ # LICENSE file in the root directory of this source tree. 
import math +from dataclasses import dataclass, field +from typing import List + +from fairseq.dataclass.utils import FairseqDataclass +from omegaconf import II from . import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler('cosine') +@dataclass +class CosineConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is args.lr" + }, + ) + max_lr: float = field( + default=1.0, metadata={"help": "max learning rate, must be more than args.lr"} + ) + t_mult: float = field( + default=1.0, metadata={"help": "factor to grow the length of each period"} + ) + lr_period_updates: float = field( + default=-1, metadata={"help": "initial number of updates per period"} + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + # TODO common var for parent class + lr: List[float] = II("params.optimization.lr") + max_update: int = II("params.optimization.max_update") + + +@register_lr_scheduler("cosine") class CosineSchedule(FairseqLRScheduler): """Assign LR based on a cyclical schedule that follows the cosine function. @@ -36,8 +70,8 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) if len(args.lr) > 1: raise ValueError( - 'Cannot use a fixed learning rate schedule with cosine.' - ' Consider --lr-scheduler=fixed instead.' + "Cannot use a fixed learning rate schedule with cosine." + " Consider --lr-scheduler=fixed instead." 
) warmup_end_lr = args.max_lr @@ -47,13 +81,15 @@ def __init__(self, args, optimizer): self.min_lr = args.lr[0] self.max_lr = args.max_lr - assert self.max_lr > self.min_lr, 'max_lr must be more than lr' + assert self.max_lr > self.min_lr, "max_lr must be more than lr" self.t_mult = args.t_mult self.period = args.lr_period_updates if self.period <= 0: - assert args.max_update >= 0, 'Either --max_update or --lr-period-updates must be set' + assert ( + args.max_update >= 0 + ), "Either --max_update or --lr-period-updates must be set" self.period = args.max_update - args.warmup_updates if args.warmup_updates > 0: @@ -100,9 +136,16 @@ def step_update(self, num_updates): else: curr_updates = num_updates - self.args.warmup_updates if self.t_mult != 1: - i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult)) + i = math.floor( + math.log( + 1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult + ) + ) t_i = self.t_mult ** i * self.period - t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period + t_curr = ( + curr_updates + - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period + ) else: i = math.floor(curr_updates / self.period) t_i = self.period @@ -112,7 +155,9 @@ def step_update(self, num_updates): min_lr = self.min_lr * lr_shrink max_lr = self.max_lr * lr_shrink - self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) + self.lr = min_lr + 0.5 * (max_lr - min_lr) * ( + 1 + math.cos(math.pi * t_curr / t_i) + ) self.optimizer.set_lr(self.lr) return self.lr diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index f98a7c3b99..1f59d4c83e 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -3,10 +3,32 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory 
of this source tree. +from dataclasses import dataclass, field +from typing import List + +from fairseq.dataclass.utils import FairseqDataclass +from omegaconf import II + from . import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler('inverse_sqrt') +@dataclass +class InverseSquareRootScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=4000, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is args.lr" + }, + ) + # TODO common vars at parent class + lr: List[float] = II("params.optimization.lr") + + +@register_lr_scheduler("inverse_sqrt") class InverseSquareRootSchedule(FairseqLRScheduler): """Decay the LR based on the inverse square root of the update number. @@ -30,8 +52,8 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) if len(args.lr) > 1: raise ValueError( - 'Cannot use a fixed learning rate schedule with inverse_sqrt.' - ' Consider --lr-scheduler=fixed instead.' + "Cannot use a fixed learning rate schedule with inverse_sqrt." + " Consider --lr-scheduler=fixed instead." ) warmup_end_lr = args.lr[0] if args.warmup_init_lr < 0: @@ -41,7 +63,7 @@ def __init__(self, args, optimizer): self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates # then, decay prop. 
to the inverse square root of the update number - self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 + self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5 # initial learning rate self.lr = args.warmup_init_lr @@ -66,8 +88,8 @@ def step(self, epoch, val_loss=None): def step_update(self, num_updates): """Update the learning rate after each update.""" if num_updates < self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates*self.lr_step + self.lr = self.args.warmup_init_lr + num_updates * self.lr_step else: - self.lr = self.decay_factor * num_updates**-0.5 + self.lr = self.decay_factor * num_updates ** -0.5 self.optimizer.set_lr(self.lr) return self.lr diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index d9b7fb8019..7806e9311e 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -3,13 +3,26 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +from typing import List + import torch +from fairseq.dataclass.utils import FairseqDataclass +from omegaconf import II from torch.optim.optimizer import Optimizer, required from . import FairseqOptimizer, register_optimizer -@register_optimizer('nag') +@dataclass +class FairseqNAGConfig(FairseqDataclass): + momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + # TODO common vars in parent class + lr: List[float] = II("params.optimization.lr") + + +@register_optimizer("nag") class FairseqNAG(FairseqOptimizer): def __init__(self, args, params): super().__init__(args) @@ -34,9 +47,9 @@ def optimizer_config(self): different learning rate. 
""" return { - 'lr': self.args.lr[0], - 'momentum': self.args.momentum, - 'weight_decay': self.args.weight_decay, + "lr": self.args.lr[0], + "momentum": self.args.momentum, + "weight_decay": self.args.weight_decay, } @@ -65,13 +78,13 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - lr = group['lr'] - lr_old = group.get('lr_old', lr) + weight_decay = group["weight_decay"] + momentum = group["momentum"] + lr = group["lr"] + lr_old = group.get("lr_old", lr) lr_correct = lr / lr_old - for p in group['params']: + for p in group["params"]: if p.grad is None: continue @@ -81,12 +94,14 @@ def step(self, closure=None): d_p = p.grad.data.float() param_state = self.state[p] - if 'momentum_buffer' not in param_state: - param_state['momentum_buffer'] = torch.zeros_like(d_p) + if "momentum_buffer" not in param_state: + param_state["momentum_buffer"] = torch.zeros_like(d_p) else: - param_state['momentum_buffer'] = param_state['momentum_buffer'].to(d_p) + param_state["momentum_buffer"] = param_state["momentum_buffer"].to( + d_p + ) - buf = param_state['momentum_buffer'] + buf = param_state["momentum_buffer"] if weight_decay != 0: p_data_fp32.mul_(1 - lr * weight_decay) @@ -98,6 +113,6 @@ def step(self, closure=None): if p.data.dtype in {torch.float16, torch.bfloat16}: p.data.copy_(p_data_fp32) - group['lr_old'] = lr + group["lr_old"] = lr return loss diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 69231a8522..0b8df065a9 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -9,8 +9,11 @@ from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa + +# register dataclass TASK_REGISTRY = {} TASK_CLASS_NAMES = set() +TASK_DATACLASS_REGISTRY = {} def setup_task(args, **kwargs): diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 1916a1550c..0c26866b27 100644 --- 
a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -5,14 +5,14 @@ import logging import os +from dataclasses import dataclass, field +from typing import Optional import numpy as np import torch - from fairseq import utils from fairseq.data import ( AppendTokenDataset, - data_utils, Dictionary, IdDataset, MonolingualDataset, @@ -22,16 +22,77 @@ PrependTokenDataset, StripTokenDataset, TokenBlockDataset, - TransformEosDataset, TruncatedDictionary, + data_utils, ) +from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass from fairseq.tasks import FairseqTask, register_task +from omegaconf import II +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) logger = logging.getLogger(__name__) +@dataclass +class LanguageModelingConfig(FairseqDataclass): + # TODO common var add to parent + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' 
+ }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + output_dictionary_size: int = field( + default=-1, metadata={"help": "limit the size of output dictionary"} + ) + self_target: bool = field(default=False, metadata={"help": "include self target"}) + future_target: bool = field( + default=False, metadata={"help": "include future target"} + ) + past_target: bool = field(default=False, metadata={"help": "include past target"}) + add_bos_token: bool = field( + default=False, metadata={"help": "prepend beginning of sentence token (<s>)"} + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + # TODO common vars below add to parent + seed: int = II("params.common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "params.dataset.dataset_impl" + ) + data_buffer_size: int = II("params.dataset.data_buffer_size") + tpu: bool = II("params.common.tpu") + + @register_task("language_modeling") class LanguageModelingTask(FairseqTask): """ @@ -242,23 +303,28 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): else self.source_dictionary.eos() ), ) - tgt_dataset = AppendTokenDataset( - dataset, - token=self.source_dictionary.pad() - ) + tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad()) return NestedDictionaryDataset( { "id": IdDataset(), "net_input": { - "src_tokens": PadDataset(src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False), + "src_tokens": PadDataset( + 
src_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ), "src_lengths": NumelDataset(src_dataset, reduce=False), }, - "target": PadDataset(tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False), + "target": PadDataset( + tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False + ), }, sizes=[np.array(src_lengths)], ) - def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): with torch.no_grad(): # Generation will always be conditioned on bos_token if getattr(self.args, "add_bos_token", False): @@ -267,7 +333,9 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, constrai bos_token = self.source_dictionary.eos() if constraints is not None: - raise NotImplementedError("Constrained decoding with the language_modeling task is not supported") + raise NotImplementedError( + "Constrained decoding with the language_modeling task is not supported" + ) # SequenceGenerator doesn't use src_tokens directly, we need to # pass the `prefix_tokens` argument instead @@ -277,7 +345,7 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, constrai prefix_tokens = prefix_tokens[:, 1:] return generator.generate( - models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token, + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token ) @property From 53f135765c018c5241ed41ab7c1bc19be257fa90 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 24 Sep 2020 23:26:54 -0700 Subject: [PATCH 167/707] fix tensor layout issues going from 1.4 to 1.5/6 (#1294) Summary: this fixes decoding with wav2letter decoder during training. it is an issue with pytorch 1.5+ where cpu() no longer implies contiguous() and so data_ptr points to non-contiguous layouts. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1294 Reviewed By: pipibjc Differential Revision: D23925604 Pulled By: alexeib fbshipit-source-id: 68d5ce7cf45c008bace3dc4a995cec6b46567b3d --- fairseq/criterions/ctc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index cbf712c69d..7398bf8117 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -117,7 +117,7 @@ def forward(self, model, sample, reduce=True): import editdistance with torch.no_grad(): - lprobs_t = lprobs.transpose(0, 1).float().cpu() + lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu() c_err = 0 c_len = 0 From 3b7d85c91f0afaa8b78a3bcb9b216a2ff38c1b01 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 25 Sep 2020 08:28:10 -0700 Subject: [PATCH 168/707] Transformer with integrated pointer-generator network (#2529) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This pull request implements a variant of the Transformer model that uses an attention distribution for pointing to input words. The attention distribution over the input words is interpolated with the normal output distribution over the vocabulary words, as in [See et al. (2017)](https://arxiv.org/abs/1704.04368). This allows the model to generate words that appear in the input, even if they don't appear in the vocabulary, helping especially with small vocabularies. The mechanism for copying out-of-vocabulary words from the input has been implemented differently to See et al. In their [implementation](https://github.com/abisee/pointer-generator) they convey the word identities through the model in order to be able to produce out-of-vocabulary words. We wanted to minimize changes to the Fairseq code base and took a different approach, which I'll describe below. The entire implementation is contained in one file (plus there's one new test). 
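For reviewers, the interpolation described above can be sketched roughly as follows. This is only an illustration under stated assumptions, not the code in `transformer_pg.py`: the function name `mix_distributions`, the gate name `p_gen` (borrowed from See et al.) and the plain-list representation are all hypothetical.

```python
def mix_distributions(p_gen, vocab_probs, attention, source_ids):
    """Blend a vocabulary distribution with a copy distribution.

    p_gen       -- assumed generation gate in [0, 1] (name from See et al.)
    vocab_probs -- probabilities over the vocabulary, summing to 1
    attention   -- attention weights over source positions, summing to 1
    source_ids  -- vocabulary index of the token at each source position
    """
    # Scale the ordinary output distribution by the generation gate.
    mixed = [p_gen * p for p in vocab_probs]
    # Add the remaining probability mass to the indices of the source tokens.
    for weight, token_id in zip(attention, source_ids):
        mixed[token_id] += (1.0 - p_gen) * weight
    return mixed


if __name__ == "__main__":
    # Both source positions hold the token with vocabulary index 2.
    print(mix_distributions(0.5, [0.5, 0.3, 0.2], [0.75, 0.25], [2, 2]))
```

Because both inputs are probability distributions and the gate weights sum to one, the result is again a distribution; tokens that occur in the source gain mass in proportion to the attention they receive.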
Copying out-of-vocabulary words is possible by pre-processing the input and post-processing the output. The user may add special words to the end of the vocabulary that can be used in place of `<unk>` tokens to identify different input positions (e.g. `<unk-0>`, `<unk-1>`, `<unk-2>`, ...). The number of these special words is given to the model with the `--source-position-markers` argument; the model simply maps all of these to the same word embedding as `<unk>`. With simple post-processing the user may retrieve the word at position N in the original text and use it in place of `<unk-N>`. I didn't find a good place to document this usage of this model, so let me know if you think I should improve documentation somewhere. This feature has not yet been discussed via a GitHub issue, but I'll open a new issue for discussion. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2529 Reviewed By: ngoyal2707 Differential Revision: D23398430 Pulled By: myleott fbshipit-source-id: f2f26c8ce8802ae6cf95515637660348ff3fc457 --- examples/pointer_generator/README.md | 82 +++ examples/pointer_generator/README.xsum.md | 180 +++++++ examples/pointer_generator/postprocess.py | 96 ++++ examples/pointer_generator/preprocess.py | 102 ++++ examples/pointer_generator/src/__init__.py | 6 + .../pointer_generator/src/transformer_pg.py | 470 ++++++++++++++++++ tests/test_binaries.py | 18 + tests/utils.py | 18 + 8 files changed, 972 insertions(+) create mode 100644 examples/pointer_generator/README.md create mode 100644 examples/pointer_generator/README.xsum.md create mode 100755 examples/pointer_generator/postprocess.py create mode 100755 examples/pointer_generator/preprocess.py create mode 100644 examples/pointer_generator/src/__init__.py create mode 100644 examples/pointer_generator/src/transformer_pg.py diff --git a/examples/pointer_generator/README.md b/examples/pointer_generator/README.md new file mode 100644 index 0000000000..6096570825 --- /dev/null +++ b/examples/pointer_generator/README.md @@ -0,0 +1,82 @@ +# 
Transformer with Pointer-Generator Network + +This page describes the `transformer_pointer_generator` model that incorporates +a pointing mechanism in the Transformer model that facilitates copying of input +words to the output. This architecture is described in [Enarvi et al. (2020)](https://www.aclweb.org/anthology/2020.nlpmc-1.4/). + +## Background + +The pointer-generator network was introduced in [See et al. (2017)](https://arxiv.org/abs/1704.04368) +for RNN encoder-decoder attention models. A similar mechanism can be +incorporated in a Transformer model by reusing one of the many attention +distributions for pointing. The attention distribution over the input words is +interpolated with the normal output distribution over the vocabulary words. This +allows the model to generate words that appear in the input, even if they don't +appear in the vocabulary, helping especially with small vocabularies. + +## Implementation + +The mechanism for copying out-of-vocabulary words from the input has been +implemented differently to See et al. In their [implementation](https://github.com/abisee/pointer-generator) +they convey the word identities through the model in order to be able to produce +words that appear in the input sequence but not in the vocabulary. A different +approach was taken in the Fairseq implementation to keep it self-contained in +the model file, avoiding any changes to the rest of the code base. Copying +out-of-vocabulary words is possible by pre-processing the input and +post-processing the output. This is described in detail in the next section. + +## Usage + +The training and evaluation procedure is outlined below. You can also find a +more detailed example for the XSum dataset on [this page](README.xsum.md). + +##### 1. 
Create a vocabulary and extend it with source position markers + +The pointing mechanism is especially helpful with small vocabularies, if we are +able to recover the identities of any out-of-vocabulary words that are copied +from the input. For this purpose, the model allows extending the vocabulary with +special tokens that can be used in place of `<unk>` tokens to identify different +input positions. For example, the user may add `<unk-0>`, `<unk-1>`, `<unk-2>`, +etc. to the end of the vocabulary, after the normal words. Below is an example +of how to create a vocabulary of 10000 most common words and add 1000 input +position markers. + +```bash +vocab_size=10000 +position_markers=1000 +export LC_ALL=C +cat train.src train.tgt | + tr -s '[:space:]' '\n' | + sort | + uniq -c | + sort -k1,1bnr -k2 | + head -n "$((vocab_size - 4))" | + awk '{ print $2 " " $1 }' >dict.pg.txt +python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt +``` + +##### 2. Preprocess the text data + +The idea is that any `<unk>` tokens in the text are replaced with `<unk-0>` if +it appears in the first input position, `<unk-1>` if it appears in the second +input position, and so on. This can be achieved using the `preprocess.py` script +that is provided in this directory. + +##### 3. Train a model + +The number of these special tokens is given to the model with the +`--source-position-markers` argument; the model simply maps all of these to the +same word embedding as `<unk>`. + +The attention distribution that is used for pointing is selected using the +`--alignment-heads` and `--alignment-layer` command-line arguments in the same +way as with the `transformer_align` model. + +##### 4. Generate text and postprocess it + +When using the model to generate text, you want to preprocess the input text in +the same way that training data was processed, replacing out-of-vocabulary words +with `<unk-N>` tokens. If any of these tokens are copied to the output, the +actual words can be retrieved from the unprocessed input text. 
Any `<unk-N>` +token should be replaced with the word at position N in the original input +sequence. This can be achieved using the `postprocess.py` script. diff --git a/examples/pointer_generator/README.xsum.md b/examples/pointer_generator/README.xsum.md new file mode 100644 index 0000000000..ab288afc0c --- /dev/null +++ b/examples/pointer_generator/README.xsum.md @@ -0,0 +1,180 @@ +## Training a pointer-generator model on the Extreme Summarization dataset + +##### 1. Download the Extreme Summarization data and preprocess it + +Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to obtain +the original Extreme Summarization dataset. You should have six files, +{train,validation,test}.{document,summary}. + +##### 2. Create a vocabulary and extend it with source position markers + +```bash +vocab_size=10000 +position_markers=1000 +export LC_ALL=C +cat train.document train.summary | + tr -s '[:space:]' '\n' | + sort | + uniq -c | + sort -k1,1bnr -k2 | + head -n "$((vocab_size - 4))" | + awk '{ print $2 " " $1 }' >dict.pg.txt +python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt +``` + +This creates the file dict.pg.txt that contains the 10k most frequent words, +followed by 1k source position markers: + +``` +the 4954867 +. 4157552 +, 3439668 +to 2212159 +a 1916857 +of 1916820 +and 1823350 +... +<unk-0> 0 +<unk-1> 0 +<unk-2> 0 +<unk-3> 0 +<unk-4> 0 +... +``` + +##### 2. Preprocess the text data + +```bash +./preprocess.py --source train.document --target train.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out train.pg.src --target-out train.pg.tgt +./preprocess.py --source validation.document --target validation.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out valid.pg.src --target-out valid.pg.tgt +./preprocess.py --source test.document --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out test.pg.src +``` + +The data should now contain `<unk-N>` tokens in place of out-of-vocabulary words. + +##### 3. 
Binarize the dataset: + +```bash +fairseq-preprocess \ + --source-lang src \ + --target-lang tgt \ + --trainpref train.pg \ + --validpref valid.pg \ + --destdir bin \ + --workers 60 \ + --srcdict dict.pg.txt \ + --joined-dictionary +``` + +##### 3. Train a model + +```bash +total_updates=20000 +warmup_updates=500 +lr=0.001 +max_tokens=4096 +update_freq=4 +pointer_layer=-2 + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train bin \ + --user-dir examples/pointer_generator/src \ + --max-tokens "$max_tokens" \ + --task translation \ + --source-lang src --target-lang tgt \ + --truncate-source \ + --layernorm-embedding \ + --share-all-embeddings \ + --encoder-normalize-before \ + --decoder-normalize-before \ + --required-batch-size-multiple 1 \ + --arch transformer_pointer_generator \ + --alignment-layer "$pointer_layer" \ + --alignment-heads 1 \ + --source-position-markers 1000 \ + --criterion label_smoothed_cross_entropy \ + --label-smoothing 0.1 \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \ + --clip-norm 0.1 \ + --lr-scheduler inverse_sqrt --lr "$lr" --max-update "$total_updates" --warmup-updates "$warmup_updates" \ + --update-freq "$update_freq" \ + --skip-invalid-size-inputs-valid-test +``` + +Above we specify that our dictionary contains 1000 source position markers, and +that we want to use one attention head from the penultimate decoder layer for +pointing. It should run in 5.5 hours on one node with eight 32GB V100 GPUs. 
The +logged messages confirm that dictionary indices above 10000 will be mapped to +the `<unk>` embedding: + +``` +2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [src] dictionary: 11000 types +2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [tgt] dictionary: 11000 types +2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.src +2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.tgt +2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | bin valid src-tgt 11332 examples +2020-09-24 20:43:53 | INFO | fairseq.models.transformer_pg | dictionary indices from 10000 to 10999 will be mapped to 3 +``` + +##### 4. Summarize the test sequences + +```bash +batch_size=32 +beam_size=6 +max_length=60 +length_penalty=1.0 + +fairseq-interactive bin \ + --user-dir examples/pointer_generator/src \ + --batch-size "$batch_size" \ + --task translation \ + --source-lang src --target-lang tgt \ + --path checkpoints/checkpoint_last.pt \ + --input test.pg.src \ + --buffer-size 200 \ + --max-len-a 0 \ + --max-len-b "$max_length" \ + --lenpen "$length_penalty" \ + --beam "$beam_size" \ + --skip-invalid-size-inputs-valid-test | + tee generate.out +grep ^H generate.out | cut -f 3- >generate.hyp +``` + +Now you should have the generated sequences in `generate.hyp`. They contain +`<unk-N>` tokens that the model has copied from the source sequence. In order to +retrieve the original words, we need the unprocessed source sequences from +`test.document`. 
+ +```bash +./postprocess.py \ + --source <(awk 'NF<1024' test.document) \ + --target generate.hyp \ + --target-out generate.hyp.processed +``` + +Now you'll find the final sequences in `generate.hyp.processed`, with +`<unk-N>` replaced with the original word from the source sequence. + +##### An example of a summarized sequence + +The original source document in `test.document`: + +> de roon moved to teesside in june 2016 for an initial # 8.8 m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page . + +The preprocessed source document in `test.pg.src`: + +> de \<unk-1\> moved to \<unk-4\> in june 2016 for an initial # \<unk-12\> m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page . + +The generated summary in `generate.hyp`: + +> middlesbrough striker \<unk\> de \<unk-1\> has joined spanish side \<unk\> on a season-long loan . + +The generated summary after postprocessing in `generate.hyp.processed`: + +> middlesbrough striker \<unk\> de roon has joined spanish side \<unk\> on a season-long loan . 
diff --git a/examples/pointer_generator/postprocess.py b/examples/pointer_generator/postprocess.py new file mode 100755 index 0000000000..a01434b5ce --- /dev/null +++ b/examples/pointer_generator/postprocess.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import re +import argparse + + +class OOVIndexError(IndexError): + def __init__(self, pos, source_seq, target_seq): + super(OOVIndexError, self).__init__( + "A <unk-N> tag in the target sequence refers to a position that is " + "outside the source sequence. Most likely there was a mismatch in " + "provided source and target sequences. Otherwise this would mean that " + "the pointing mechanism somehow attended to a position that is past " + "the actual sequence end." + ) + self.source_pos = pos + self.source_seq = source_seq + self.target_seq = target_seq + + +def replace_oovs(source_in, target_in, target_out): + """Replaces <unk-N> tokens in the target text with the corresponding word in + the source text. + """ + + oov_re = re.compile("^<unk-([0-9]+)>$") + + for source_seq, target_seq in zip(source_in, target_in): + target_seq_out = [] + + pos_to_word = source_seq.strip().split() + for token in target_seq.strip().split(): + m = oov_re.match(token) + if m: + pos = int(m.group(1)) + if pos >= len(pos_to_word): + raise OOVIndexError(pos, source_seq, target_seq) + token_out = pos_to_word[pos] + else: + token_out = token + target_seq_out.append(token_out) + target_out.write(" ".join(target_seq_out) + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Replaces <unk-N> tokens in target sequences with words from " + "the corresponding position in the source sequence."
+ ) + parser.add_argument( + "--source", type=str, help="text file with source sequences", required=True + ) + parser.add_argument( + "--target", type=str, help="text file with target sequences", required=True + ) + parser.add_argument( + "--target-out", + type=str, + help="where to write target sequences without <unk-N> " "entries", + required=True, + ) + args = parser.parse_args() + + target_in = ( + open(args.target, "r", encoding="utf-8") if args.target is not None else None + ) + target_out = ( + open(args.target_out, "w", encoding="utf-8") + if args.target_out is not None + else None + ) + with open(args.source, "r", encoding="utf-8") as source_in, open( + args.target, "r", encoding="utf-8" + ) as target_in, open(args.target_out, "w", encoding="utf-8") as target_out: + replace_oovs(source_in, target_in, target_out) + + +if __name__ == "__main__": + try: + main() + except OOVIndexError as e: + print(e, file=sys.stderr) + print("Source sequence:", e.source_seq.strip(), file=sys.stderr) + print("Target sequence:", e.target_seq.strip(), file=sys.stderr) + print( + "Source sequence length:", + len(e.source_seq.strip().split()), + file=sys.stderr, + ) + print("The offending tag points to:", e.source_pos) + sys.exit(2) diff --git a/examples/pointer_generator/preprocess.py b/examples/pointer_generator/preprocess.py new file mode 100755 index 0000000000..4b7a5ab9c5 --- /dev/null +++ b/examples/pointer_generator/preprocess.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from itertools import zip_longest + + +def replace_oovs(source_in, target_in, vocabulary, source_out, target_out): + """Replaces out-of-vocabulary words in source and target text with <unk-N>, + where N is the position of the word in the source sequence. 
+ """ + + def format_unk(pos): + return "<unk-{}>".format(pos) + + if target_in is None: + target_in = [] + + for seq_num, (source_seq, target_seq) in enumerate( + zip_longest(source_in, target_in) + ): + source_seq_out = [] + target_seq_out = [] + + word_to_pos = dict() + for position, token in enumerate(source_seq.strip().split()): + if token in vocabulary: + token_out = token + else: + if token in word_to_pos: + oov_pos = word_to_pos[token] + else: + word_to_pos[token] = position + oov_pos = position + token_out = format_unk(oov_pos) + source_seq_out.append(token_out) + source_out.write(" ".join(source_seq_out) + "\n") + + if target_seq is not None: + for token in target_seq.strip().split(): + if token in word_to_pos: + token_out = format_unk(word_to_pos[token]) + else: + token_out = token + target_seq_out.append(token_out) + if target_out is not None: + target_out.write(" ".join(target_seq_out) + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Replaces out-of-vocabulary words in both source and target " + "sequences with tokens that indicate the position of the word " + "in the source sequence."
+ ) + parser.add_argument( + "--source", type=str, help="text file with source sequences", required=True + ) + parser.add_argument( + "--target", type=str, help="text file with target sequences", default=None + ) + parser.add_argument("--vocab", type=str, help="vocabulary file", required=True) + parser.add_argument( + "--source-out", + type=str, + help="where to write source sequences with <unk-N> entries", + required=True, + ) + parser.add_argument( + "--target-out", + type=str, + help="where to write target sequences with <unk-N> entries", + default=None, + ) + args = parser.parse_args() + + with open(args.vocab, encoding="utf-8") as vocab: + vocabulary = vocab.read().splitlines() + + target_in = ( + open(args.target, "r", encoding="utf-8") if args.target is not None else None + ) + target_out = ( + open(args.target_out, "w", encoding="utf-8") + if args.target_out is not None + else None + ) + with open(args.source, "r", encoding="utf-8") as source_in, open( + args.source_out, "w", encoding="utf-8" + ) as source_out: + replace_oovs(source_in, target_in, vocabulary, source_out, target_out) + if target_in is not None: + target_in.close() + if target_out is not None: + target_out.close() + + +if __name__ == "__main__": + main() diff --git a/examples/pointer_generator/src/__init__.py b/examples/pointer_generator/src/__init__.py new file mode 100644 index 0000000000..c361ff6bd6 --- /dev/null +++ b/examples/pointer_generator/src/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import transformer_pg # noqa diff --git a/examples/pointer_generator/src/transformer_pg.py b/examples/pointer_generator/src/transformer_pg.py new file mode 100644 index 0000000000..af933b3495 --- /dev/null +++ b/examples/pointer_generator/src/transformer_pg.py @@ -0,0 +1,470 @@ +# Copyright (c) Facebook, Inc. and its affiliates.
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from fairseq import utils, metrics +from fairseq.models import register_model, register_model_architecture +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.transformer import ( + TransformerModel, + TransformerDecoder, + TransformerEncoder, + base_architecture, + DEFAULT_MAX_SOURCE_POSITIONS, + DEFAULT_MAX_TARGET_POSITIONS, +) + +from torch import Tensor + + +logger = logging.getLogger(__name__) + + +@register_model("transformer_pointer_generator") +class TransformerPointerGeneratorModel(TransformerModel): + """ + Transformer model from `"Attention Is All You Need" (Vaswani et al, 2017) + <https://arxiv.org/abs/1706.03762>`_, augmented with a pointer-generator + network from `"Get To The Point: Summarization with Pointer-Generator + Networks" (See et al, 2017) <https://arxiv.org/abs/1704.04368>`_. + + Args: + encoder (TransformerPointerGeneratorEncoder): the encoder + decoder (TransformerPointerGeneratorDecoder): the decoder + + The Transformer pointer-generator model provides the following named + architectures and command-line arguments: + + ..
argparse:: + :ref: fairseq.models.transformer_pointer_generator_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + TransformerModel.add_args(parser) + parser.add_argument('--alignment-heads', type=int, metavar='N', + help='number of attention heads to be used for ' + 'pointing') + parser.add_argument('--alignment-layer', type=int, metavar='I', + help='layer number to be used for pointing (0 ' + 'corresponding to the bottommost layer)') + parser.add_argument('--source-position-markers', type=int, metavar='N', + help='dictionary includes N additional items that ' + 'represent an OOV token at a particular input ' + 'position') + parser.add_argument('--force-generation', type=float, metavar='P', + default=None, + help='set the vocabulary distribution weight to P, ' + 'instead of predicting it from the input (1.0 ' + 'corresponding to generation, 0.0 to pointing)') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_source_positions", None) is None: + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + if getattr(args, "source_position_markers", None) is None: + args.source_position_markers = args.max_source_positions + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + if src_dict != tgt_dict: + raise ValueError("Pointer-generator requires a joined dictionary") + + def build_embedding(dictionary, embed_dim, path=None): + # The dictionary may include additional items that can be used in + # 
place of the normal OOV token and that all map to the same + # embedding. Using a different token for each input position allows + # one to restore the word identities from the original source text. + num_embeddings = len(dictionary) - args.source_position_markers + padding_idx = dictionary.pad() + unk_idx = dictionary.unk() + logger.info( + "dictionary indices from {0} to {1} will be mapped to {2}".format( + num_embeddings, len(dictionary) - 1, unk_idx + ) + ) + emb = Embedding(num_embeddings, embed_dim, padding_idx, unk_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + if args.share_all_embeddings: + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = build_embedding( + tgt_dict, args.decoder_embed_dim, args.decoder_embed_path + ) + + encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return cls(args, encoder, decoder) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerPointerGeneratorEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerPointerGeneratorDecoder(args, tgt_dict, embed_tokens) + + +class 
TransformerPointerGeneratorEncoder(TransformerEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. The pointer-generator variant adds + the source tokens to the encoder output as these are otherwise not passed + to the decoder. + """ + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + Runs the `forward()` method of the parent Transformer class. Then adds + the source tokens into the encoder output tuple. + + While it might be more elegant that the model would pass the source + tokens to the `forward()` method of the decoder too, this would require + changes to `SequenceGenerator`. + + Args: + src_tokens (torch.LongTensor): tokens in the source language of + shape `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + + Returns: + namedtuple: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + - **src_tokens** (Tensor): input token ids of shape + `(batch, src_len)` + """ + encoder_out = super().forward(src_tokens, src_lengths, **kwargs) + return EncoderOut( + encoder_out=encoder_out.encoder_out, # T x B x C + encoder_padding_mask=encoder_out.encoder_padding_mask, # B x T + encoder_embedding=encoder_out.encoder_embedding, # B x T x C + encoder_states=encoder_out.encoder_states, # List[T x B x C] + src_tokens=src_tokens, # B x T + src_lengths=None, + ) + + +class TransformerPointerGeneratorDecoder(TransformerDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. 
Each layer + is a :class:`TransformerDecoderLayer`. The pointer-generator variant mixes + the output probabilities with an attention distribution in the output layer. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False) + + # In the pointer-generator model these arguments define the decoder + # layer and the number of attention heads that will be averaged to + # create the alignment for pointing. + self.alignment_heads = args.alignment_heads + self.alignment_layer = args.alignment_layer + + input_embed_dim = embed_tokens.embedding_dim + + # Generation probabilities / interpolation coefficients are predicted + # from the current decoder input embedding and the decoder output, which + # is the size of output_embed_dim. + p_gen_input_size = input_embed_dim + self.output_embed_dim + self.project_p_gens = nn.Linear(p_gen_input_size, 1) + nn.init.zeros_(self.project_p_gens.bias) + + # The dictionary may include a separate entry for an OOV token in each + # input position, so that their identity can be restored from the + # original source text. 
+ self.num_types = len(dictionary) + self.num_oov_types = args.source_position_markers + self.num_embeddings = self.num_types - self.num_oov_types + self.force_p_gen = args.force_generation + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + alignment_layer: Optional[int] = 0, + alignment_heads: Optional[int] = 1, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (EncoderOut, optional): output from the encoder, used + for encoder-side attention + incremental_state (dict, optional): dictionary used for storing + state during :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False) + alignment_layer (int, optional): 0-based index of the layer to be + used for pointing (default: 0) + alignment_heads (int, optional): number of attention heads to be + used for pointing (default: 1) + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + # The normal Transformer model doesn't pass the alignment_layer and + # alignment_heads parameters correctly. We use our local variables. + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + alignment_layer=self.alignment_layer, + alignment_heads=self.alignment_heads, + ) + if not features_only: + # Embedding the tokens again for generation probability prediction, + # so that we don't have to reimplement the whole extract_features() + # method. 
+ if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + prev_output_embed = self.embed_tokens(prev_output_tokens) + prev_output_embed *= self.embed_scale + predictors = torch.cat((prev_output_embed, x), 2) + p_gens = self.project_p_gens(predictors) + p_gens = torch.sigmoid(p_gens) + x = self.output_layer(x, extra["attn"][0], encoder_out.src_tokens, p_gens) + return x, extra + + def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): + """ + Project features to the vocabulary size and mix with the attention + distributions. + """ + if self.force_p_gen is not None: + p_gens = self.force_p_gen + + # project back to size of vocabulary + logits = super().output_layer(features, **kwargs) + + batch_size = logits.shape[0] + output_length = logits.shape[1] + assert logits.shape[2] == self.num_embeddings + assert src_tokens.shape[0] == batch_size + src_length = src_tokens.shape[1] + + # The final output distribution will be a mixture of the normal output + # distribution (softmax of logits) and attention weights. + gen_dists = super().get_normalized_probs( + (logits, None), log_probs=False, sample=None + ) + gen_dists = torch.mul(gen_dists, p_gens) + padding_size = (batch_size, output_length, self.num_oov_types) + padding = gen_dists.new_zeros(padding_size) + gen_dists = torch.cat((gen_dists, padding), 2) + assert gen_dists.shape[2] == self.num_types + + # Scatter attention distributions to distributions over the extended + # vocabulary in a tensor of shape [batch_size, output_length, + # vocab_size]. Each attention weight will be written into a location + # that is for other dimensions the same as in the index tensor, but for + # the third dimension it's the value of the index tensor (the token ID). 
+ attn = torch.mul(attn, 1 - p_gens) + index = src_tokens[:, None, :] + index = index.expand(batch_size, output_length, src_length) + attn_dists_size = (batch_size, output_length, self.num_types) + attn_dists = attn.new_zeros(attn_dists_size) + attn_dists.scatter_add_(2, index, attn) + + # Final distributions, [batch_size, output_length, num_types]. + return gen_dists + attn_dists + + def get_normalized_probs(self, net_output, log_probs, sample): + """ + Get normalized probabilities (or log probs) from a net's output. + Pointer-generator network output is already normalized. + """ + probs = net_output[0] + # Make sure the probabilities are greater than zero when returning log + # probabilities. + return probs.clamp(1e-10, 1.0).log() if log_probs else probs + + +class Embedding(nn.Embedding): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. This subclass differs from the standard PyTorch Embedding class by + allowing additional vocabulary entries that will be mapped to the unknown token + embedding. + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int): Pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + unk_idx (int): Maps all token indices that are greater than or equal to + num_embeddings to this index. + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + .. 
note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + """ + __constants__ = ["unk_idx"] + + def __init__(self, num_embeddings, embedding_dim, padding_idx, unk_idx): + super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) + self.unk_idx = unk_idx + nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(self.weight[padding_idx], 0) + + def forward(self, input): + input = torch.where( + input >= self.num_embeddings, torch.ones_like(input) * self.unk_idx, input + ) + return super().forward(input) + + +@register_model_architecture( + "transformer_pointer_generator", "transformer_pointer_generator" +) +def transformer_pointer_generator(args): + args.alignment_heads = getattr(args, "alignment_heads", 1) + args.alignment_layer = getattr(args, "alignment_layer", -1) + base_architecture(args) + if args.alignment_layer < 0: + args.alignment_layer = args.decoder_layers + args.alignment_layer + + +@register_model_architecture( + "transformer_pointer_generator", "transformer_pointer_generator_iwslt_de_en" +) +def transformer_pointer_generator_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) 
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + transformer_pointer_generator(args) + + +@register_model_architecture( + "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de" +) +def transformer_pointer_generator_wmt_en_de(args): + transformer_pointer_generator(args) + + +# Transformer pointer-generator with the base Transformer parameters as used in +# the "Attention Is All You Need" paper (Vaswani et al., 2017) +@register_model_architecture( + "transformer_pointer_generator", + "transformer_pointer_generator_vaswani_wmt_en_de_big", +) +def transformer_pointer_generator_vaswani_wmt_en_de_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + transformer_pointer_generator(args) + + +@register_model_architecture( + "transformer_pointer_generator", + "transformer_pointer_generator_vaswani_wmt_en_fr_big", +) +def transformer_pointer_generator_vaswani_wmt_en_fr_big(args): + args.dropout = getattr(args, "dropout", 0.1) + transformer_pointer_generator_vaswani_wmt_en_de_big(args) + + +@register_model_architecture( + "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de_big" +) +def transformer_pointer_generator_wmt_en_de_big(args): + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + 
transformer_pointer_generator_vaswani_wmt_en_de_big(args) + + +# default parameters used in tensor2tensor implementation +@register_model_architecture( + "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de_big_t2t" +) +def transformer_pointer_generator_wmt_en_de_big_t2t(args): + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + transformer_pointer_generator_vaswani_wmt_en_de_big(args) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 16887aacbd..e489fd304a 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -21,6 +21,7 @@ create_dummy_data, preprocess_lm_data, preprocess_translation_data, + preprocess_summarization_data, train_translation_model, generate_main, ) @@ -264,6 +265,23 @@ def test_transformer_cross_self_attention(self): ], run_validation=True) generate_main(data_dir, extra_flags=[]) + def test_transformer_pointer_generator(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_transformer_pointer_generator') as data_dir: + create_dummy_data(data_dir) + preprocess_summarization_data(data_dir) + train_translation_model(data_dir, 'transformer_pointer_generator', [ + '--user-dir', 'examples/pointer_generator/src', + '--encoder-layers', '2', + '--decoder-layers', '2', + '--encoder-embed-dim', '8', + '--decoder-embed-dim', '8', + '--alignment-layer', '-1', + '--alignment-heads', '1', + '--source-position-markers', '0', + ], run_validation=True) + generate_main(data_dir) + def test_lightconv(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_lightconv') as data_dir: diff --git a/tests/utils.py b/tests/utils.py index ef546fa58a..f265c13f85 100644 --- a/tests/utils.py +++ 
b/tests/utils.py @@ -196,6 +196,24 @@ def preprocess_translation_data(data_dir, extra_flags=None): preprocess.main(preprocess_args) +def preprocess_summarization_data(data_dir, extra_flags=None): + preprocess_parser = options.get_preprocessing_parser() + preprocess_args = preprocess_parser.parse_args( + [ + '--source-lang', 'in', + '--target-lang', 'out', + '--trainpref', os.path.join(data_dir, 'train'), + '--validpref', os.path.join(data_dir, 'valid'), + '--testpref', os.path.join(data_dir, 'test'), + '--thresholdtgt', '0', + '--thresholdsrc', '0', + '--joined-dictionary', + '--destdir', data_dir, + ] + (extra_flags or []), + ) + preprocess.main(preprocess_args) + + def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False, lang_flags=None, extra_valid_flags=None): if lang_flags is None: From 8dde7de8a22d6e59c4101fe0de618f888c33ba81 Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Fri, 25 Sep 2020 13:55:55 -0700 Subject: [PATCH 169/707] Fix dummy batch issues (#1293) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes issues when `sample.keys` changes across subsets. Without this diff, dummy batches are reused across epochs, which causes a hang if subsets have different keys. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1293 Reviewed By: myleott Differential Revision: D23924558 Pulled By: joshim5 fbshipit-source-id: d57cb1f545f649ec4dac62a43b088e637797c90f --- fairseq/tasks/fairseq_task.py | 4 ++++ fairseq/trainer.py | 12 ++++++++++++ fairseq_cli/train.py | 1 + 3 files changed, 17 insertions(+) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 8da07bf8bb..d27c38d305 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -426,6 +426,10 @@ def begin_epoch(self, epoch, model): """Hook function called before the start of each epoch.""" pass + def begin_valid_epoch(self, epoch, model): + """Hook function called before the start of each validation epoch.""" + pass + def aggregate_logging_outputs(self, logging_outputs, criterion): """[deprecated] Aggregate logging outputs from data parallel training.""" utils.deprecation_warning( diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 60fa161d02..583d37c259 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -409,12 +409,24 @@ def begin_epoch(self, epoch): # task specific setup per epoch self.task.begin_epoch(epoch, self.get_model()) + # reset dummy batch + self._dummy_batch = 'DUMMY' + if self.tpu: import torch_xla.core.xla_model as xm xm.rendezvous('begin_epoch') # wait for all workers xm.mark_step() + def begin_valid_epoch(self, epoch): + """Called at the beginning of each validation epoch.""" + + # task specific setup per validation epoch + self.task.begin_valid_epoch(epoch, self.get_model()) + + # reset dummy batch + self._dummy_batch = 'DUMMY' + @metrics.aggregate("train") def train_step(self, samples, raise_oom=False): """Do forward, backward and parameter update.""" diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index f7dd527166..46d98628aa 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -294,6 +294,7 @@ def validate(args, trainer,
task, epoch_itr, subsets): # set fixed seed for every validation utils.set_torch_seed(args.fixed_validation_seed) + trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] for subset in subsets: logger.info('begin validation on "{}" subset'.format(subset)) From 94a1b924f3adec25c8c508ac112410d02b400d1e Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Sat, 26 Sep 2020 15:50:04 -0700 Subject: [PATCH 170/707] hydra fairseq 6 - add_args from dataclass except migrated model Summary: hydra fairseq 6 - add_args from dataclass except migrated model Reviewed By: myleott Differential Revision: D23416669 fbshipit-source-id: 223f773384dab95d5a90095379c9b88e4a12754d --- fairseq/criterions/adaptive_loss.py | 11 ++-- fairseq/criterions/cross_entropy.py | 51 ++++++++++++------- fairseq/models/transformer_lm.py | 1 + fairseq/optim/adam.py | 19 +------ fairseq/optim/bmuf.py | 36 +------------ .../optim/lr_scheduler/cosine_lr_scheduler.py | 17 +------ .../inverse_square_root_schedule.py | 9 +--- fairseq/optim/nag.py | 9 +--- fairseq/tasks/language_modeling.py | 39 +++----------- 9 files changed, 57 insertions(+), 135 deletions(-) diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 4e6506337e..7bc41d6000 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -10,16 +10,14 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES -from fairseq.dataclass.utils import FairseqDataclass +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from omegaconf import II @dataclass class AdaptiveLossConfig(FairseqDataclass): sentence_avg: bool = II("params.optimization.sentence_avg") - ddp_backend: DDP_BACKEND_CHOICES = II( - "params.distributed_training.ddp_backend" - ) + ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") 
@register_criterion("adaptive_loss") @@ -32,6 +30,11 @@ def __init__(self, task, sentence_avg): super().__init__(task) self.sentence_avg = sentence_avg + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser. Optionally register config store""" + gen_parser_from_dataclass(parser, AdaptiveLossConfig()) + @classmethod def build_criterion(cls, args, task): if getattr(args, "ddp_backend", None) == "c10d": diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index 7b690dcff2..08d64eced9 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -4,14 +4,13 @@ # LICENSE file in the root directory of this source tree. import math +from dataclasses import dataclass import torch.nn.functional as F - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES -from dataclasses import dataclass -from fairseq.dataclass.utils import FairseqDataclass +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from omegaconf import II @@ -21,13 +20,17 @@ class CrossEntropyCriterionConfig(FairseqDataclass): ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") -@register_criterion('cross_entropy') +@register_criterion("cross_entropy") class CrossEntropyCriterion(FairseqCriterion): - def __init__(self, task, sentence_avg): super().__init__(task) self.sentence_avg = sentence_avg + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser. Optionally register config store""" + gen_parser_from_dataclass(parser, CrossEntropyCriterionConfig()) + def forward(self, model, sample, reduce=True): """Compute the loss for the given sample.
@@ -36,14 +39,16 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample['net_input']) + net_output = model(**sample["net_input"]) loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens'] + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) logging_output = { - 'loss': loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample['target'].size(0), - 'sample_size': sample_size, + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, } return loss, sample_size, logging_output @@ -55,23 +60,31 @@ def compute_loss(self, model, net_output, sample, reduce=True): lprobs, target, ignore_index=self.padding_idx, - reduction='sum' if reduce else 'none', + reduction="sum" if reduce else "none", ) return loss, loss @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" - loss_sum = sum(log.get('loss', 0) for log in logging_outputs) - ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) - sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) - metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) if sample_size != ntokens: - metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3) - metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg)) + 
metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) else: - metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg)) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 88718dd5af..fd7adb3c15 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + from dataclasses import dataclass, field from typing import Optional diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index 1afec99be6..b33a5d89e9 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -11,7 +11,7 @@ import torch import torch.distributed as dist import torch.optim -from fairseq.dataclass.utils import FairseqDataclass +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from fairseq.optim import FairseqOptimizer, register_optimizer from fairseq.optim.fused_adam import get_fused_adam_class from omegaconf import II @@ -67,22 +67,7 @@ def __init__(self, args, params): @staticmethod def add_args(parser): """Add optimizer-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', - help='betas for Adam optimizer') - parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', - help='epsilon for Adam optimizer') - parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', - help='weight decay') - # Maintain backward compatibility with old checkpoints that have stored - # optimizer state as fairseq.optim.adam.Adam. 
- parser.add_argument( - "--use-old-adam", - action='store_true', - default=False, - help="Use fairseq.optim.adam.Adam", - ) - # fmt: on + gen_parser_from_dataclass(parser, FairseqAdamConfig()) @property def optimizer_config(self): diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py index bcdeeee45b..5d98aa2f84 100644 --- a/fairseq/optim/bmuf.py +++ b/fairseq/optim/bmuf.py @@ -7,7 +7,7 @@ import torch import torch.distributed as dist -from fairseq.dataclass.utils import FairseqDataclass +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from fairseq.optim.fairseq_optimizer import FairseqOptimizer from omegaconf import II @@ -69,39 +69,7 @@ def __init__(self, args, optimizer): @staticmethod def add_args(parser): """Add optimizer-specific arguments to the parser.""" - parser.add_argument( - "--block-lr", default=1, type=float, help="block learning rate for bmuf" - ) - parser.add_argument( - "--block-momentum", - default=0.875, - type=float, - help="block momentum for bmuf", - ) - parser.add_argument( - "--global-sync-iter", - default=50, - type=int, - help="Iteration for syncing global model", - ) - parser.add_argument( - "--warmup-iterations", - default=500, - type=int, - help="warmup iterations for model to broadcast", - ) - parser.add_argument( - "--use-nbm", - default=False, - action="store_true", - help="Specify whether you want to use classical BM / Nesterov BM", - ) - parser.add_argument( - "--average-sync", - default=False, - action="store_true", - help="Specify whether you want to average the local momentum after each sync", - ) + gen_parser_from_dataclass(parser, FairseqBMUFConfig()) @property def optimizer(self): diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index aeee95b84e..bd133ef091 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -7,7 +7,7 @@ from dataclasses import 
dataclass, field from typing import List -from fairseq.dataclass.utils import FairseqDataclass +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from omegaconf import II from . import FairseqLRScheduler, register_lr_scheduler @@ -108,20 +108,7 @@ def __init__(self, args, optimizer): @staticmethod def add_args(parser): """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', - help='warmup the learning rate linearly for the first N updates') - parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', - help='initial learning rate during warmup phase; default is args.lr') - parser.add_argument('--max-lr', type=float, metavar='LR', - help='max learning rate, must be more than args.lr') - parser.add_argument('--t-mult', default=1, type=float, metavar='LR', - help='factor to grow the length of each period') - parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR', - help='initial number of updates per period') - parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='shrink factor for annealing') - # fmt: on + gen_parser_from_dataclass(parser, CosineConfig()) def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index 1f59d4c83e..388ac216bc 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import List -from fairseq.dataclass.utils import FairseqDataclass +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from omegaconf import II from . 
import FairseqLRScheduler, register_lr_scheduler @@ -72,12 +72,7 @@ def __init__(self, args, optimizer): @staticmethod def add_args(parser): """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', - help='warmup the learning rate linearly for the first N updates') - parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', - help='initial learning rate during warmup phase; default is args.lr') - # fmt: on + gen_parser_from_dataclass(parser, InverseSquareRootScheduleConfig()) def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 7806e9311e..1050071f51 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -7,7 +7,7 @@ from typing import List import torch -from fairseq.dataclass.utils import FairseqDataclass +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from omegaconf import II from torch.optim.optimizer import Optimizer, required @@ -31,12 +31,7 @@ def __init__(self, args, params): @staticmethod def add_args(parser): """Add optimizer-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--momentum', default=0.99, type=float, metavar='M', - help='momentum factor') - parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', - help='weight decay') - # fmt: on + gen_parser_from_dataclass(parser, FairseqNAGConfig()) @property def optimizer_config(self): diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 0c26866b27..190bc27cf2 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -27,7 +27,11 @@ ) from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.dataclass.utils import ChoiceEnum, 
FairseqDataclass +from fairseq.dataclass.utils import ( + ChoiceEnum, + FairseqDataclass, + gen_parser_from_dataclass, +) from fairseq.tasks import FairseqTask, register_task from omegaconf import II @@ -125,37 +129,8 @@ class LanguageModelingTask(FairseqTask): @staticmethod def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - parser.add_argument('data', help='path to data directory') - parser.add_argument('--sample-break-mode', default='none', - choices=['none', 'complete', 'complete_doc', 'eos'], - help='If omitted or "none", fills each sample with tokens-per-sample ' - 'tokens. If set to "complete", splits samples only at the end ' - 'of sentence, but may include multiple sentences per sample. ' - '"complete_doc" is similar but respects doc boundaries. ' - 'If set to "eos", includes only one sentence per sample.') - parser.add_argument('--tokens-per-sample', default=1024, type=int, - help='max number of tokens per sample for LM dataset') - parser.add_argument('--output-dictionary-size', default=-1, type=int, - help='limit the size of output dictionary') - parser.add_argument('--self-target', action='store_true', - help='include self target') - parser.add_argument('--future-target', action='store_true', - help='include future target') - parser.add_argument('--past-target', action='store_true', - help='include past target') - parser.add_argument('--add-bos-token', action='store_true', - help='prepend beginning of sentence token ()') - parser.add_argument('--max-target-positions', type=int, metavar='N', - help='max number of tokens in the target sequence') - parser.add_argument('--shorten-method', default='none', - choices=['none', 'truncate', 'random_crop'], - help='if not none, shorten sequences that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-list', default='', - help='comma-separated list of dataset splits to apply shortening to, ' - 'e.g., "train,valid" (default: all dataset splits)') - # fmt: on 
+ """Add task-specific arguments to the parser. Optionally register config store.""" + gen_parser_from_dataclass(parser, LanguageModelingConfig()) def __init__(self, args, dictionary, output_dictionary=None, targets=None): super().__init__(args) From a524832d1d6883de90bd0c6bc5fd039d6f87a000 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 28 Sep 2020 15:31:18 -0700 Subject: [PATCH 171/707] Publish Linformer to public fairseq Summary: Initial open source release for Linformer Reviewed By: madian9 Differential Revision: D22771263 fbshipit-source-id: bf08c64c5ecb899db9da00b79d09f6308347c915 --- README.md | 4 + examples/linformer/README.md | 22 + examples/linformer/src/__init__.py | 6 + examples/linformer/src/models/__init__.py | 0 .../linformer/src/models/linformer_roberta.py | 131 +++++ examples/linformer/src/modules/__init__.py | 0 .../src/modules/linformer_sentence_encoder.py | 167 +++++++ .../linformer_sentence_encoder_layer.py | 83 ++++ .../src/modules/multihead_linear_attention.py | 451 ++++++++++++++++++ tests/test_binaries.py | 62 ++- 10 files changed, 924 insertions(+), 2 deletions(-) create mode 100644 examples/linformer/README.md create mode 100644 examples/linformer/src/__init__.py create mode 100644 examples/linformer/src/models/__init__.py create mode 100644 examples/linformer/src/models/linformer_roberta.py create mode 100644 examples/linformer/src/modules/__init__.py create mode 100644 examples/linformer/src/modules/linformer_sentence_encoder.py create mode 100644 examples/linformer/src/modules/linformer_sentence_encoder_layer.py create mode 100644 examples/linformer/src/modules/multihead_linear_attention.py diff --git a/README.md b/README.md index 3743571b91..d6094880da 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,8 @@ We provide reference implementations of various sequence modeling papers: - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - [Unsupervised Quality Estimation for
Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) + - [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) + - [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) - **Non-autoregressive Transformers** - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -52,6 +54,8 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +- September 2020: [Added Linformer code](examples/linformer/README.md) +- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) - August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) - August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) - July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) diff --git a/examples/linformer/README.md b/examples/linformer/README.md new file mode 100644 index 0000000000..e5c11e052d --- /dev/null +++ b/examples/linformer/README.md @@ -0,0 +1,22 @@ +# Linformer: Self-Attention with Linear Complexity (Wang et al., 2020) + +This example contains code to train Linformer models as described in our paper +[Linformer: Self-Attention with Linear Complexity](https://arxiv.org/abs/2006.04768). + +## Training a new Linformer RoBERTa model + +You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md), +but replace the architecture with `--arch linformer_roberta_base` in your training command. 
+ +## Citation + +If you use our work, please cite: + +```bibtex +@article{wang2020linformer, + title={Linformer: Self-Attention with Linear Complexity}, + author={Wang, Sinong and Li, Belinda and Khabsa, Madian and Fang, Han and Ma, Hao}, + journal={arXiv preprint arXiv:2006.04768}, + year={2020} +} +``` diff --git a/examples/linformer/src/__init__.py b/examples/linformer/src/__init__.py new file mode 100644 index 0000000000..1c52f135ea --- /dev/null +++ b/examples/linformer/src/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .models import linformer_roberta # noqa diff --git a/examples/linformer/src/models/__init__.py b/examples/linformer/src/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/linformer/src/models/linformer_roberta.py b/examples/linformer/src/models/linformer_roberta.py new file mode 100644 index 0000000000..722f5a4b9e --- /dev/null +++ b/examples/linformer/src/models/linformer_roberta.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+""" +Linformer: Self-Attention with Linear Complexity +""" + +import logging + +from fairseq.models import ( + register_model, + register_model_architecture, +) +from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder + +from fairseq.models.roberta import ( + RobertaModel, + RobertaEncoder, +) + + +logger = logging.getLogger(__name__) + + +@register_model('linformer_roberta') +class LinformerModel(RobertaModel): + + @staticmethod + def add_args(parser): + RobertaModel.add_args(parser) + + # add args for Linformer + parser.add_argument('--compressed', type=int, + help='compressed ratio of sequence length') + parser.add_argument('--shared-kv-compressed', type=int, + help='share compressed matrix between k and v, in each layer') + parser.add_argument('--shared-layer-kv-compressed', type=int, + help='share compressed matrix between k and v and across all layers') + parser.add_argument('--freeze-compress', type=int, + help='freeze the parameters in compressed layer') + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present + base_architecture(args) + + if not hasattr(args, 'max_positions'): + args.max_positions = args.tokens_per_sample + + encoder = LinformerEncoder(args, task.source_dictionary) + return cls(args, encoder) + + +class LinformerEncoder(RobertaEncoder): + """Linformer encoder.""" + + def __init__(self, args, dictionary): + super().__init__(args, dictionary) + + self.sentence_encoder = LinformerSentenceEncoder( + padding_idx=dictionary.pad(), + vocab_size=len(dictionary), + num_encoder_layers=args.encoder_layers, + embedding_dim=args.encoder_embed_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + layerdrop=args.encoder_layerdrop, + max_seq_len=args.max_positions, + num_segments=0, + 
encoder_normalize_before=True, + apply_bert_init=True, + activation_fn=args.activation_fn, + q_noise=args.quant_noise_pq, + qn_block_size=args.quant_noise_pq_block_size, + compressed=args.compressed, + shared_kv_compressed=args.shared_kv_compressed, + shared_layer_kv_compressed=args.shared_layer_kv_compressed, + freeze_compress=args.freeze_compress, + ) + + +@register_model_architecture('linformer_roberta', 'linformer_roberta') +def base_architecture(args): + args.encoder_layers = getattr(args, 'encoder_layers', 12) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) + + args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') + + args.dropout = getattr(args, 'dropout', 0.1) + args.attention_dropout = getattr(args, 'attention_dropout', 0.1) + args.activation_dropout = getattr(args, 'activation_dropout', 0.0) + args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) + args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None) + args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) + args.compressed = getattr(args, 'compressed', 4) + args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0) + args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0) + args.freeze_compress = getattr(args, 'freeze_compress', 0) + + +@register_model_architecture('linformer_roberta', 'linformer_roberta_base') +def linformer_roberta_base_architecture(args): + base_architecture(args) + + +@register_model_architecture('linformer_roberta', 'linformer_roberta_large') +def linformer_roberta_large_architecture(args): + args.encoder_layers = getattr(args, 'encoder_layers', 24) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) + args.encoder_ffn_embed_dim = getattr(args, 
'encoder_ffn_embed_dim', 4096) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) + + args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') + + args.dropout = getattr(args, 'dropout', 0.1) + args.attention_dropout = getattr(args, 'attention_dropout', 0.1) + args.activation_dropout = getattr(args, 'activation_dropout', 0.0) + args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) + args.compressed = getattr(args, 'compressed', 4) + args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0) + args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0) diff --git a/examples/linformer/src/modules/__init__.py b/examples/linformer/src/modules/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/linformer/src/modules/linformer_sentence_encoder.py b/examples/linformer/src/modules/linformer_sentence_encoder.py new file mode 100644 index 0000000000..e3d170023d --- /dev/null +++ b/examples/linformer/src/modules/linformer_sentence_encoder.py @@ -0,0 +1,167 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch.nn as nn + +from fairseq.modules import TransformerSentenceEncoder +from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer + + +class LinformerSentenceEncoder(TransformerSentenceEncoder): + """ + Implementation for a Bi-directional Linformer based Sentence Encoder used + in BERT/XLM style pre-trained models. + + This first computes the token embedding using the token embedding matrix, + position embeddings (if specified) and segment embeddings + (if specified). 
After applying the specified number of + LinformerEncoderLayers, it outputs all the internal states of the + encoder as well as the final representation associated with the first + token (usually CLS token). + + Input: + - tokens: B x T matrix representing sentences + - segment_labels: B x T matrix representing segment label for tokens + + Output: + - a tuple of the following: + - a list of internal model states used to compute the + predictions where each tensor has shape T x B x C + - sentence representation associated with first input token + in format B x C. + """ + + def __init__( + self, + padding_idx: int, + vocab_size: int, + num_encoder_layers: int = 6, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + layerdrop: float = 0.0, + max_seq_len: int = 256, + num_segments: int = 2, + use_position_embeddings: bool = True, + offset_positions_by_padding: bool = True, + encoder_normalize_before: bool = False, + apply_bert_init: bool = False, + activation_fn: str = "relu", + learned_pos_embedding: bool = True, + embed_scale: float = None, + freeze_embeddings: bool = False, + n_trans_layers_to_freeze: int = 0, + export: bool = False, + traceable: bool = False, + q_noise: float = 0.0, + qn_block_size: int = 8, + compressed: int = 4, + shared_kv_compressed: int = 0, + shared_layer_kv_compressed: int = 0, + freeze_compress: int = 0, + ) -> None: + + # Initialize linformer parameters + self.compressed = compressed + self.shared_kv_compressed = shared_kv_compressed + self.shared_layer_kv_compressed = shared_layer_kv_compressed + self.compress_layer = None + self.freeze_compress = freeze_compress + + super().__init__( + padding_idx=padding_idx, + vocab_size=vocab_size, + num_encoder_layers=num_encoder_layers, + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, 
+ attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + layerdrop=layerdrop, + max_seq_len=max_seq_len, + num_segments=num_segments, + use_position_embeddings=use_position_embeddings, + offset_positions_by_padding=offset_positions_by_padding, + encoder_normalize_before=encoder_normalize_before, + apply_bert_init=apply_bert_init, + activation_fn=activation_fn, + learned_pos_embedding=learned_pos_embedding, + embed_scale=embed_scale, + freeze_embeddings=freeze_embeddings, + n_trans_layers_to_freeze=n_trans_layers_to_freeze, + export=export, + traceable=traceable, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + def build_transformer_sentence_encoder_layer( + self, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + q_noise, + qn_block_size, + ): + if self.shared_layer_kv_compressed == 1: + compress_layer = nn.Linear(self.max_seq_len, self.max_seq_len // self.compressed) + # initialize parameters for compressed layer + nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2)) + if self.freeze_compress == 1: + compress_layer.weight.requires_grad = False + self.compress_layer = compress_layer + + return LinformerSentenceEncoderLayer( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + compressed=self.compressed, + max_seq_len=self.max_seq_len, + shared_kv_compressed=self.shared_kv_compressed, + shared_compress_layer=( + None if self.shared_layer_kv_compressed == 0 + else self.compress_layer + ), + freeze_compress=self.freeze_compress, + ) + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "."
if name != "" else "" + items_to_add = {} + keys_to_remove = [] + + # update key name for shared layer in new version of code + for k in state_dict.keys(): + if k.startswith(prefix + "compress_layer"): + if self.shared_layer_kv_compressed: + for layer_idx in range(len(self.layers)): + new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format( + layer_idx, k[len(prefix + 'compress_layer.'):], + ) + items_to_add[new_k] = state_dict[k] + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/examples/linformer/src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/src/modules/linformer_sentence_encoder_layer.py new file mode 100644 index 0000000000..e0a6047ce8 --- /dev/null +++ b/examples/linformer/src/modules/linformer_sentence_encoder_layer.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +from fairseq.modules import TransformerSentenceEncoderLayer +from .multihead_linear_attention import MultiheadLinearAttention + + +class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): + """ + Implements a Linformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = 'relu', + export: bool = False, + q_noise: float = 0.0, + qn_block_size: int = 8, + init_fn: Callable = None, + compressed: int = 1, + max_seq_len: int = 256, + shared_kv_compressed: int = 0, + shared_compress_layer: any = None, + freeze_compress: int = 0, + ) -> None: + + # Initialize linformer parameters + self.compressed = compressed + self.max_seq_len = max_seq_len + self.shared_kv_compressed = shared_kv_compressed + self.freeze_compress = freeze_compress + + def init_fn(): + # This needs to be set after nn.Module.__init__ is called + self.shared_compress_layer = shared_compress_layer + + super().__init__( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + init_fn=init_fn, + ) + + def build_self_attention( + self, + embed_dim, + num_attention_heads, + dropout, + self_attention, + q_noise, + qn_block_size, + ): + return MultiheadLinearAttention( + embed_dim, + num_attention_heads, + dropout=dropout, + self_attention=True, + q_noise=q_noise, + qn_block_size=qn_block_size, + compressed=self.compressed, + max_seq_len=self.max_seq_len, + shared_kv_compressed=self.shared_kv_compressed, + shared_compress_layer=self.shared_compress_layer, + freeze_compress=self.freeze_compress, + ) diff --git a/examples/linformer/src/modules/multihead_linear_attention.py b/examples/linformer/src/modules/multihead_linear_attention.py new file mode 100644 index 0000000000..472cd4e3ea --- /dev/null +++ b/examples/linformer/src/modules/multihead_linear_attention.py @@ -0,0 +1,451 @@ +# Copyright 
(c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from fairseq import utils +from torch import Tensor, nn +from torch.nn import Parameter +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.quant_noise import quant_noise + + +@with_incremental_state +class MultiheadLinearAttention(nn.Module): + """Multi-headed linformer attention. + + Projects the key and values down to the compressed dimension, before computing self-attention. + + See "Linformer: Self-Attention with Linear Complexity" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + compressed=1, + max_seq_len=256, + shared_kv_compressed=0, + shared_compress_layer=None, + freeze_compress=0, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.v_proj = quant_noise(nn.Linear(self.vdim, 
embed_dim, bias=bias), q_noise, qn_block_size) + self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + # used for compress sequence to subsequence + if shared_compress_layer is None: + self.compress_seq_len = max_seq_len // compressed + self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False) + if shared_kv_compressed == 0: + self.compress_v = nn.Linear(max_seq_len, self.compress_seq_len, bias=False) + self.layerwise_sharing = False + else: + self.compress_k = shared_compress_layer + if shared_kv_compressed == 0: + self.compress_v = shared_compress_layer + self.layerwise_sharing = True + self.shared_kv_compressed = shared_kv_compressed + + self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + if freeze_compress == 1: + self.compress_k.weight.requires_grad = False + if shared_kv_compressed == 0: + self.compress_v.weight.requires_grad = False + + self.onnx_trace = False + self.tpu = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def prepare_for_tpu_(self, **kwargs): + self.tpu = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + if not self.layerwise_sharing: # otherwise, we already initialize the parameters + nn.init.xavier_uniform_(self.compress_k.weight, gain=1/math.sqrt(2)) + if self.shared_kv_compressed == 0: + nn.init.xavier_uniform_(self.compress_v.weight, gain=1/math.sqrt(2)) + 
else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + if not self.layerwise_sharing: # otherwise, we already initialize the parameters + nn.init.xavier_uniform_(self.compress_k.weight) + if self.shared_kv_compressed == 0: + nn.init.xavier_uniform_(self.compress_v.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: True). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
+ """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + + k_input = query.permute(1, 2, 0).contiguous() # B * C * T + k_input = F.linear(k_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous() + k = self.k_proj(k_input) + + v_input = query.permute(1, 2, 0).contiguous() # B * C * T + if self.shared_kv_compressed == 0: + v_input = F.linear(v_input, self.compress_v.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous() + if self.shared_kv_compressed == 1: # use shared kv compressed linear layer + v_input = F.linear(v_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous() + v = self.v_proj(v_input) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + 
key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadLinearAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + src_len = k.size(1) + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = 
torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = MultiheadLinearAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout( + attn_weights, + p=self.dropout, + training=self.training, + ) + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, 
seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + elif key_padding_mask is not None: + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0): + break + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + 
return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim:2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim:2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim:] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/tests/test_binaries.py b/tests/test_binaries.py index e489fd304a..6e2b5004bf 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -280,7 +280,10 @@ def test_transformer_pointer_generator(self): '--alignment-heads', '1', '--source-position-markers', '0', ], run_validation=True) - generate_main(data_dir) + generate_main( + data_dir, + extra_flags=['--user-dir', 'examples/pointer_generator/src'], + ) def test_lightconv(self): with contextlib.redirect_stdout(StringIO()): @@ -589,7 +592,7 @@ def test_roberta_masked_lm(self): with tempfile.TemporaryDirectory("test_roberta_mlm") as 
data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) - train_masked_lm(data_dir, "roberta_base") + train_masked_lm(data_dir, "roberta_base", extra_flags=["--encoder-layers", "2"]) def test_roberta_sentence_prediction(self): num_classes = 3 @@ -616,6 +619,60 @@ def test_roberta_regression_multiple(self): preprocess_lm_data(os.path.join(data_dir, 'input0')) train_roberta_head(data_dir, "roberta_base", num_classes=num_classes, extra_flags=['--regression-target']) + def test_linformer_roberta_masked_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_linformer_roberta_mlm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_masked_lm( + data_dir, + "linformer_roberta_base", + extra_flags=[ + "--user-dir", "examples/linformer/src", + "--encoder-layers", "2", + ], + ) + + def test_linformer_roberta_sentence_prediction(self): + num_classes = 3 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_linformer_roberta_head") as data_dir: + create_dummy_roberta_head_data(data_dir, num_classes=num_classes) + preprocess_lm_data(os.path.join(data_dir, 'input0')) + preprocess_lm_data(os.path.join(data_dir, 'label')) + train_roberta_head( + data_dir, + "linformer_roberta_base", + num_classes=num_classes, + extra_flags=["--user-dir", "examples/linformer/src"], + ) + + def test_linformer_roberta_regression_single(self): + num_classes = 1 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_linformer_roberta_regression_single") as data_dir: + create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True) + preprocess_lm_data(os.path.join(data_dir, 'input0')) + train_roberta_head( + data_dir, + "linformer_roberta_base", + num_classes=num_classes, + extra_flags=["--regression-target", "--user-dir", "examples/linformer/src"], + ) + + def test_linformer_roberta_regression_multiple(self): + num_classes = 3 + with 
contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_linformer_roberta_regression_multiple") as data_dir: + create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True) + preprocess_lm_data(os.path.join(data_dir, 'input0')) + train_roberta_head( + data_dir, + "linformer_roberta_base", + num_classes=num_classes, + extra_flags=["--regression-target", "--user-dir", "examples/linformer/src"], + ) + def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_mlm") as data_dir: @@ -836,6 +893,7 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): '--task', 'sentence_prediction', data_dir, '--arch', arch, + '--encoder-layers', '2', '--num-classes', str(num_classes), '--optimizer', 'adam', '--lr', '0.0001', From caea771afafbcb3471f9007ca1cd46a4d3d8c869 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 29 Sep 2020 07:26:18 -0700 Subject: [PATCH 172/707] Fix tests (#2670) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2670 Reviewed By: ngoyal2707 Differential Revision: D23982491 Pulled By: myleott fbshipit-source-id: 629b791d6c05dd67b63dcc2da0313c6799f777f8 --- fairseq/utils.py | 13 ++++++++++--- tests/test_binaries.py | 26 ++++++++++++++++---------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/fairseq/utils.py b/fairseq/utils.py index 888e4d95e4..8f6bda393f 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -403,9 +403,16 @@ def import_user_module(args): module_path = fairseq_rel_path module_parent, module_name = os.path.split(module_path) - if module_name not in sys.modules: - sys.path.insert(0, module_parent) - importlib.import_module(module_name) + if module_name in sys.modules: + module_bak = sys.modules[module_name] + del sys.modules[module_name] + else: + module_bak = None + sys.path.insert(0, module_parent) + 
importlib.import_module(module_name) + sys.modules['fairseq_user_dir'] = sys.modules[module_name] + if module_bak is not None and module_name != 'fairseq_user_dir': + sys.modules[module_name] = module_bak def softmax(x, dim: int, onnx_trace: bool = False): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 6e2b5004bf..24d1b6cc03 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -270,16 +270,22 @@ def test_transformer_pointer_generator(self): with tempfile.TemporaryDirectory('test_transformer_pointer_generator') as data_dir: create_dummy_data(data_dir) preprocess_summarization_data(data_dir) - train_translation_model(data_dir, 'transformer_pointer_generator', [ - '--user-dir', 'examples/pointer_generator/src', - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--alignment-layer', '-1', - '--alignment-heads', '1', - '--source-position-markers', '0', - ], run_validation=True) + train_translation_model( + data_dir, + 'transformer_pointer_generator', + extra_flags=[ + '--user-dir', 'examples/pointer_generator/src', + '--encoder-layers', '2', + '--decoder-layers', '2', + '--encoder-embed-dim', '8', + '--decoder-embed-dim', '8', + '--alignment-layer', '-1', + '--alignment-heads', '1', + '--source-position-markers', '0', + ], + run_validation=True, + extra_valid_flags=['--user-dir', 'examples/pointer_generator/src'], + ) generate_main( data_dir, extra_flags=['--user-dir', 'examples/pointer_generator/src'], From df1b3c6a8bb391e21c2c930b9050dec8523a37cc Mon Sep 17 00:00:00 2001 From: Shruti Bhosale Date: Tue, 29 Sep 2020 09:07:22 -0700 Subject: [PATCH 173/707] Allow splitting of embeddings for more even distribution of optimizer states across DDP nodes when using DeepSpeed Zero (#1295) Summary: * Embedding parameters take up a lot of optimizer state memory when the vocab size is huge (e.g. >100K) and the embedding dimension is huge (e.g. 
4-6K) * This causes the DDP worker carrying the optimizer states of the embedding parameter to be more overloaded than other DDP workers. * To avoid this, this PR adds the ability to divide embedding parameters into chunks (maintaining functional parity with having a single embedding parameter) so that different DDP workers can be assigned different embedding parameter chunks for OSS and one DDP worker doesn't have to hold optimizer states for a huge embedding parameter. Testing details in this PR: https://github.com/fairinternal/fairseq-py/pull/1226 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1295 Reviewed By: myleott Differential Revision: D23980434 Pulled By: shruti-bh fbshipit-source-id: a8731b4016aff3f944b0706327694457348d0979 --- .../pipeline_parallel_transformer/layers.py | 33 ++++++++++-------- .../pipeline_parallel_transformer/model.py | 34 ++++++++++++++----- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py index 70551ca900..e11f491486 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py @@ -186,6 +186,7 @@ class TransformerDecoderOutputLayer(nn.Module): def __init__(self, args, embed_tokens, dictionary): super().__init__() self.share_input_output_embed = args.share_decoder_input_output_embed + self.embed_tokens = embed_tokens self.output_embed_dim = args.decoder_output_dim embed_dim = args.decoder_embed_dim @@ -203,20 +204,9 @@ def __init__(self, args, embed_tokens, dictionary): factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) - elif self.share_input_output_embed: - self.output_projection = nn.Linear( - embed_tokens.weight.shape[1], - embed_tokens.weight.shape[0], - bias=False, - ) - self.output_projection.weight = embed_tokens.weight elif not 
self.share_input_output_embed: - self.output_projection = nn.Linear( - self.output_embed_dim, len(dictionary), bias=False - ) - nn.init.normal_( - self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 - ) + self.embed_tokens = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim)) + nn.init.normal_(self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5) if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False): self.layer_norm = LayerNorm(embed_dim) @@ -245,7 +235,22 @@ def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary - return self.output_projection(features) + if self.share_input_output_embed: + if isinstance(self.embed_tokens, nn.ModuleList): + output = None + for i, emb in enumerate(self.embed_tokens): + sidx = i * emb.embedding_dim + eidx = (i + 1) * emb.embedding_dim + if output is None: + output = F.linear(features[:, :, sidx:eidx], emb.weight) + else: + output += F.linear(features[:, :, sidx:eidx], emb.weight) + + return output + else: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_tokens) else: return features diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index ca1c2698fb..37fa877eaf 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -129,6 +129,11 @@ def add_args(parser): 'Must be used with adaptive_loss criterion'), parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', help='sets adaptive softmax dropout for the tail projections') + parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1, + help='Number of embedding layer chunks (enables more even distribution' + 'of 
optimizer states across data parallel nodes ' + 'when using optimizer state sharding and ' + 'a big embedding vocabulary)') # fmt: on @classmethod @@ -145,17 +150,27 @@ def build_model_base(cls, args, task): src_dict, tgt_dict = task.source_dictionary, task.target_dictionary - def build_embedding(dictionary, embed_dim, path=None): + def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1): + assert embed_dim % num_embed_chunks == 0, \ + f"Embedding dimension = {embed_dim} should be " + \ + f"divisible by the number of embedding chunks = {num_embed_chunks}" + assert path is None or num_embed_chunks == 1, \ + "Loading embedding from a path with number of embedding chunks > 1" + \ + " is not yet supported" num_embeddings = len(dictionary) padding_idx = dictionary.pad() - - emb = Embedding(num_embeddings, embed_dim, padding_idx) # if provided, load from preloaded dictionaries if path: + emb = Embedding(num_embeddings, embed_dim, padding_idx) embed_dict = utils.parse_embedding(path) utils.load_embedding(embed_dict, dictionary, emb) + else: + embed_chunk_dim = embed_dim // num_embed_chunks + emb = nn.ModuleList() + for i in range(num_embed_chunks): + emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx)) return emb - + num_embed_chunks = args.num_embedding_chunks if args.share_all_embeddings: if src_dict != tgt_dict: raise ValueError('--share-all-embeddings requires a joined dictionary') @@ -166,16 +181,19 @@ def build_embedding(dictionary, embed_dim, path=None): args.decoder_embed_path != args.encoder_embed_path): raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path') encoder_embed_tokens = build_embedding( - src_dict, args.encoder_embed_dim, args.encoder_embed_path, + src_dict, args.encoder_embed_dim, args.encoder_embed_path, num_embed_chunks, ) decoder_embed_tokens = encoder_embed_tokens args.share_decoder_input_output_embed = True else: + assert args.share_decoder_input_output_embed or num_embed_chunks == 1, \ 
+ "Not sharing decoder I/O embeddings is not yet supported with number of " + \ + "embedding chunks > 1" encoder_embed_tokens = build_embedding( - src_dict, args.encoder_embed_dim, args.encoder_embed_path + src_dict, args.encoder_embed_dim, args.encoder_embed_path, num_embed_chunks, ) decoder_embed_tokens = build_embedding( - tgt_dict, args.decoder_embed_dim, args.decoder_embed_path + tgt_dict, args.decoder_embed_dim, args.decoder_embed_path, num_embed_chunks, ) encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) @@ -238,7 +256,7 @@ def get_normalized_probs(self, net_output, log_probs, sample=None): def max_decoder_positions(self): """Maximum length supported by the decoder.""" - return self.decoder.max_positions() + return self.decoder_max_positions def load_state_dict(self, state_dict, strict=True, args=None): """Copies parameters and buffers from *state_dict* into this module and From 73ad5d4abe9442454ace141e1743df610a9aecae Mon Sep 17 00:00:00 2001 From: Shruti Bhosale Date: Tue, 29 Sep 2020 14:28:30 -0700 Subject: [PATCH 174/707] Pass long string arguments as files (#1296) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1296 Reviewed By: huihuifan Differential Revision: D23991462 Pulled By: shruti-bh fbshipit-source-id: 00dd0de22414b20c587e45b9e9108e9946808c77 --- .../multilingual/multilingual_data_manager.py | 4 +++- fairseq/options.py | 17 +++++++++++++++++ fairseq/tasks/translation_multi_simple_epoch.py | 5 +++-- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 6240cf76d5..42cdad6041 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -33,7 +33,7 @@ ) from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat from fairseq.file_io import PathManager -from fairseq.options import csv_str_list, eval_str_dict +from fairseq.options import csv_str_list, eval_str_dict, FileContentsAction logger = logging.getLogger(__name__) @@ -79,6 +79,7 @@ def add_args(parser): "data", help="colon separated path to data directories list, \ will be iterated upon during epochs in round-robin manner", + action=FileContentsAction, ) parser.add_argument( "--langs", @@ -95,6 +96,7 @@ def add_args(parser): "languages which can appear in lang-pairs; " "note that the ordering determines language token IDs; " "--langs and --lang-dict are two exclusive options", + action=FileContentsAction, ) parser.add_argument( "--lang-tok-style", diff --git a/fairseq/options.py b/fairseq/options.py index 4add9dd5fa..5945b054ca 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -4,16 +4,33 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import os import sys from typing import Callable, List, Optional import torch # this import is for backward compatibility +from fairseq.file_io import PathManager from fairseq.utils import csv_str_list, eval_str_list, eval_str_dict, eval_bool # noqa from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl +class FileContentsAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(FileContentsAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + if PathManager.isfile(values): + with open(values) as f: + argument = f.read().strip() + else: + argument = values + setattr(namespace, self.dest, argument) + + def get_preprocessing_parser(default_task="translation"): parser = get_parser("Preprocessing", default_task) add_preprocess_args(parser) diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 2858d0ad7b..10aaeaa12c 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -19,7 +19,7 @@ from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.multilingual.sampling_method import SamplingMethod from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager - +from fairseq.options import FileContentsAction ### def get_time_gap(s, e): @@ -62,7 +62,8 @@ def add_args(parser): parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='inference target language') parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', - help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr') + help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr', + action=FileContentsAction) 
parser.add_argument('--keep-inference-langtok', action='store_true', help='keep language tokens in inference output (e.g. for analysis or debugging)') From 82c116d20073c9c54ace4555fecfbcd782fb46a0 Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Tue, 29 Sep 2020 15:48:46 -0700 Subject: [PATCH 175/707] split parallel transformer lm base arch Summary: move base_arch to parallel transformer lm so it does not depend on regular transformer lm Reviewed By: myleott Differential Revision: D23417282 fbshipit-source-id: 32e2d7294b4ec0d52598d3da829e80922cb4b576 --- .../model_parallel/models/transformer_lm.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 81bc93bc0a..598807147e 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -7,7 +7,6 @@ from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer_lm import ( - base_lm_architecture, TransformerLanguageModel, ) from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder @@ -64,6 +63,55 @@ def _vocab_init(tensor, **kwargs): return embed_tokens +def base_lm_architecture(args): + # backward compatibility for older model checkpoints + if hasattr(args, 'no_tie_adaptive_proj'): + # previous models defined --no-tie-adaptive-proj, so use the existence of + # that option to determine if this is an "old" model checkpoint + args.no_decoder_final_norm = True # old models always set this to True + if args.no_tie_adaptive_proj is False: + args.tie_adaptive_proj = True + if hasattr(args, 'decoder_final_norm'): + args.no_decoder_final_norm = not args.decoder_final_norm + + args.activation_fn = getattr(args, 'activation_fn', 'relu') + args.dropout = getattr(args, 'dropout', 0.1) + args.attention_dropout = getattr(args, 'attention_dropout', 0.0) + 
args.activation_dropout = getattr(args, 'activation_dropout', 0.0) + args.relu_dropout = getattr(args, 'relu_dropout', 0.0) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) + args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048) + args.decoder_layers = getattr(args, 'decoder_layers', 6) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) + # Model training is not stable without this + args.decoder_normalize_before = True + args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', False) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) + args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) + args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4) + args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) + args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) + args.character_embeddings = getattr(args, 'character_embeddings', False) + args.character_filters = getattr(args, 'character_filters', '[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]') + args.character_embedding_dim = getattr(args, 'character_embedding_dim', 4) + args.char_embedder_highway_layers = getattr(args, 'char_embedder_highway_layers', 2) + args.adaptive_input = getattr(args, 'adaptive_input', False) + args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4) + args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None) + args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False) + args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False) + args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) + args.decoder_layerdrop = 
getattr(args, 'decoder_layerdrop', 0.0) + args.decoder_layers_to_keep = getattr(args, 'decoder_layers_to_keep', None) + args.layernorm_embedding = getattr(args, 'layernorm_embedding', False) + args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) + args.quant_noise_pq = getattr(args, 'quant_noise_pq', 0.0) + args.quant_noise_pq_block_size = getattr(args, 'quant_noise_pq_block_size', 8) + args.quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0.0) + args.add_bos_token = getattr(args, 'add_bos_token', False) + @register_model_architecture('model_parallel_transformer_lm', 'transformer_lm_megatron') def transformer_lm_megatron(args): args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 3072) From e4a5427ef4ffad63fa265acbade098bb963a814b Mon Sep 17 00:00:00 2001 From: Monideep De Date: Tue, 29 Sep 2020 17:45:05 -0700 Subject: [PATCH 176/707] Updated link to wav2letter repository (#2663) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # 2662. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/2663 Reviewed By: alexeib Differential Revision: D24001964 Pulled By: myleott fbshipit-source-id: 6a264c9cf53d77bb0062c41ec5ff03c5552f3e55 --- examples/wav2vec/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index d849dde85d..e6db6c6796 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -125,7 +125,7 @@ For example, for 10 hours, we see in the paper that timestep mask prob should be Evaluating a CTC model with a language model requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings) to be installed. -Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/models/sota/2019). +Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). Be sure to upper-case the language model vocab after downloading it. Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). 
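The README text in the patch above asks readers to upper-case the language model vocab after downloading it. A minimal sketch of that step follows. It assumes the vocab is a plain-text UTF-8 file, which the README does not state. The helper name `uppercase_vocab` and the file names in the usage note are hypothetical and are not part of fairseq or wav2letter.

```python
from pathlib import Path


def uppercase_vocab(src_path, dst_path, encoding="utf-8"):
    """Write an upper-cased copy of a plain-text vocab file.

    Hypothetical helper: the file format (plain text, UTF-8) is an
    assumption, not something the README specifies.
    """
    # Read the whole file, upper-case it, and write it to a new path so
    # the downloaded original stays untouched.
    text = Path(src_path).read_text(encoding=encoding)
    Path(dst_path).write_text(text.upper(), encoding=encoding)
    return dst_path
```

For example, `uppercase_vocab("lm_vocab.txt", "lm_vocab.upper.txt")` would leave the download in place and produce an upper-cased copy next to it. Writing to a separate path keeps the step repeatable if the conversion ever needs to be redone.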
From 356973065100c43798a416ce0235a4582e5cb48d Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Wed, 30 Sep 2020 22:31:01 -0700 Subject: [PATCH 177/707] Fix bug in subsample dataset (shuffle was never assigned) (#1298) Summary: Fix bug in subsample dataset (shuffle was never assigned) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1298 Reviewed By: myleott Differential Revision: D23940719 Pulled By: joshim5 fbshipit-source-id: 8127d42ecf9e359312df104f8ee90bda25589023 --- fairseq/data/subsample_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/data/subsample_dataset.py b/fairseq/data/subsample_dataset.py index e395674a55..7eca9d4cb3 100644 --- a/fairseq/data/subsample_dataset.py +++ b/fairseq/data/subsample_dataset.py @@ -21,13 +21,14 @@ class SubsampleDataset(BaseWrapperDataset): size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive) """ - def __init__(self, dataset, size_ratio): + def __init__(self, dataset, size_ratio, shuffle=False): super().__init__(dataset) assert size_ratio < 1 self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int) self.indices = np.random.choice( list(range(len(self.dataset))), self.actual_size, replace=False ) + self.shuffle = shuffle logger.info( "subsampled dataset from {} to {} (ratio={})".format( len(self.dataset), self.actual_size, size_ratio From 805d00aa7c50c6fc17bbcc30b10255777bf03607 Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Thu, 1 Oct 2020 06:26:36 -0700 Subject: [PATCH 178/707] support multiple data parallel groups for ZeRO optimizer state sharding (#1326) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? 
- [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/fairinternal/fairseq-py/issues/1299. Previously, optimizer state sharding was causing incorrect results when used with model parallel. This resolves the issue. In order to use this with Fairseq, you'll also need to use the version of [Fairscale from the PR I opened there](https://github.com/facebookresearch/fairscale/pull/121). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1326 Reviewed By: myleott Differential Revision: D24042124 Pulled By: joshim5 fbshipit-source-id: cdc8be2a00096cf7ae4e4918915431b5d493aad3 --- fairseq/optim/shard.py | 5 +++-- fairseq/trainer.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index 4f35dbda47..8c508f41f2 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -11,7 +11,7 @@ _has_fairscale = False -def shard_(args, optimizer): +def shard_(args, optimizer, group): if not _has_fairscale: raise ImportError( '\n\nPlease install the fairscale package:' @@ -30,4 +30,5 @@ def __getattr__(self, name): torch_optimizer = optimizer.optimizer optim_cls = type(torch_optimizer) - optimizer.optimizer = FairseqOSS(torch_optimizer.param_groups, optim_cls, **optimizer.optimizer_config) + + optimizer.optimizer = FairseqOSS(torch_optimizer.param_groups, optim_cls, group=group, **optimizer.optimizer_config) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 583d37c259..66f1bcbcb9 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -229,7 +229,7 @@ def _build_optimizer(self): "Please use --fp16-no-flatten-grads" ) else: - optim.shard_(self.args, self._optimizer) + optim.shard_(self.args, self._optimizer, 
self.data_parallel_process_group) # We should initialize the learning rate scheduler immediately after # building the optimizer, so that the initial learning rate is set. @@ -567,6 +567,7 @@ def maybe_no_sync(): with torch.autograd.profiler.record_function("optimizer"): # take an optimization step self.optimizer.step() + except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails @@ -641,7 +642,6 @@ def maybe_no_sync(): ) metrics.log_stop_time("train_wall") - return logging_output @metrics.aggregate("valid") From 7d2a3e10a9436c2e3a006a94cc3229e2920ff71c Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Thu, 1 Oct 2020 06:28:48 -0700 Subject: [PATCH 179/707] Prevent failing when there is no valid set (#1328) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? If the valid set is small or non-existent, Fairseq will fail because `valid_losses` will not have been defined anywhere. This fixes the issue. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1328 Reviewed By: myleott Differential Revision: D24042313 Pulled By: joshim5 fbshipit-source-id: ef45ade8e2d5f853364a05e03f566968e9aaa089 --- fairseq_cli/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 46d98628aa..a2a7763488 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -197,6 +197,7 @@ def train(args, trainer, task, epoch_itr): trainer.begin_epoch(epoch_itr.epoch) + valid_losses = [None] valid_subsets = args.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() From c049749c7a7c08cca9e4663c85bd3961f4b260f8 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 1 Oct 2020 12:35:54 -0700 Subject: [PATCH 180/707] Fix full-context alignment with transformer_align model (#2675) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Fixes https://github.com/pytorch/fairseq/issues/2673. # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2673 (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/2675 Reviewed By: ngoyal2707 Differential Revision: D24001793 Pulled By: myleott fbshipit-source-id: 6b4e9270e5f5a31ba1b65ae2ae717019108af913 --- fairseq/models/transformer.py | 4 ++++ fairseq/models/transformer_align.py | 2 +- fairseq/sequence_generator.py | 2 +- tests/test_binaries.py | 24 +++++++++++++++++++++++- 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index a55c47c155..ae0ba5aad0 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -641,6 +641,7 @@ def forward( encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, + full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, @@ -656,6 +657,8 @@ def forward( :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). Returns: tuple: @@ -666,6 +669,7 @@ def forward( prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, + full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) diff --git a/fairseq/models/transformer_align.py b/fairseq/models/transformer_align.py index 4195ff3982..c80cc4341c 100644 --- a/fairseq/models/transformer_align.py +++ b/fairseq/models/transformer_align.py @@ -32,7 +32,7 @@ def add_args(parser): help='Number of cross attention heads per layer to supervised with alignments') parser.add_argument('--alignment-layer', type=int, metavar='D', help='Layer number which has to be supervised. 
0 corresponding to the bottommost layer.') - parser.add_argument('--full-context-alignment', type=bool, metavar='D', + parser.add_argument('--full-context-alignment', action='store_true', help='Whether or not alignment is supervised conditioned on the full target context.') # fmt: on diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 6cfcc90baf..965594cd6e 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -952,7 +952,7 @@ def forward_align(self, src_tokens, src_lengths, prev_output_tokens): avg_attn = None for model in self.models: decoder_out = model(src_tokens, src_lengths, prev_output_tokens) - attn = decoder_out[1]["attn"] + attn = decoder_out[1]["attn"][0] if avg_attn is None: avg_attn = attn else: diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 24d1b6cc03..e6259092eb 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -446,7 +446,29 @@ def test_alignment(self): '--decoder-embed-dim', '8', '--load-alignments', '--alignment-layer', '1', - '--criterion', 'label_smoothed_cross_entropy_with_alignment' + '--criterion', 'label_smoothed_cross_entropy_with_alignment', + ], + run_validation=True, + ) + generate_main(data_dir) + + def test_alignment_full_context(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_alignment') as data_dir: + create_dummy_data(data_dir, alignment=True) + preprocess_translation_data(data_dir, ['--align-suffix', 'align']) + train_translation_model( + data_dir, + 'transformer_align', + [ + '--encoder-layers', '2', + '--decoder-layers', '2', + '--encoder-embed-dim', '8', + '--decoder-embed-dim', '8', + '--load-alignments', + '--alignment-layer', '1', + '--criterion', 'label_smoothed_cross_entropy_with_alignment', + '--full-context-alignment', ], run_validation=True, ) From 65d87b0605b2a930397a99ea11083d7b55f03277 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 1 Oct 2020 15:54:15 -0700 Subject: [PATCH 
181/707] Fixes for TPUs (#1324) Summary: - support training on a single TPU core - fix clip grad norm logic - log memory usage - fix --memory-efficient-bf16 - print XLA compilation warnings on every device Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1324 Reviewed By: ngoyal2707 Differential Revision: D24024226 Pulled By: myleott fbshipit-source-id: 5ae178e663e69923776196da3c0e3217efdecb61 --- fairseq/distributed_utils.py | 2 +- .../modules/multihead_attention.py | 16 ++++++++++--- fairseq/optim/fp16_optimizer.py | 4 +++- fairseq/trainer.py | 24 +++++++++++++------ fairseq/utils.py | 5 +++- 5 files changed, 38 insertions(+), 13 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index b25232f386..3611a3d6ce 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -241,7 +241,7 @@ def call_main(args, main, **kwargs): ) else: distributed_main(args.device_id, main, args, kwargs) - elif getattr(args, "tpu", False): + elif getattr(args, "tpu", False) and args.distributed_world_size > 1: import torch_xla.distributed.xla_multiprocessing as xmp torch.multiprocessing.set_sharing_strategy("file_system") xmp.spawn( diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py index e92a3f6a71..97da7db35b 100644 --- a/fairseq/model_parallel/modules/multihead_attention.py +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -83,6 +83,11 @@ def __init__( self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, gather_output=False) self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True) + self.tpu = False + + def prepare_for_tpu_(self, **kwargs): + self.tpu = True + def forward( self, query, @@ -220,9 +225,14 @@ def forward( if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads_partition, tgt_len, src_len) - 
attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") - ) + if not self.tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf')) + attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.view(bsz * self.num_heads_partition, tgt_len, src_len) attn_weights_float = utils.softmax( diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 593519eb7f..e621a6f114 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -363,7 +363,7 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): # detect overflow and adjust loss scale self.scaler.check_overflow(grad_norm_cpu) - else: + elif max_norm > 0.0: clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) self._multiply_factor *= clip_coef @@ -386,6 +386,8 @@ def zero_grad(self): self.wrapped_optimizer.zero_grad() if self.scaler is not None: self._multiply_factor = 1. / float(self.scaler.loss_scale) + else: + self._multiply_factor = 1. 
class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 66f1bcbcb9..2dc44367ab 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -607,6 +607,17 @@ def maybe_no_sync(): # this causes wps to be misreported when log_interval > 1 logging_output = {} if self.get_num_updates() % self.args.log_interval == 0: + # log memory usage + mem_info = xm.get_memory_info(self.device) + gb_free = mem_info['kb_free'] / 1024 / 1024 + gb_total = mem_info['kb_total'] / 1024 / 1024 + metrics.log_scalar( + 'gb_free', gb_free, priority=1500, round=1, weight=0, + ) + metrics.log_scalar( + 'gb_total', gb_total, priority=1600, round=1, weight=0, + ) + logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) @@ -1017,19 +1028,18 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): del logging_output[key_to_delete] return logging_output - def _check_xla_compilation(self, message=None): + def _check_xla_compilation(self): import torch_xla.debug.metrics as met compile_stats = met.metric_data("CompileTime") if compile_stats is None: return num_xla_compiles = compile_stats[0] if num_xla_compiles > self._num_xla_compiles: - if message is None: - message = ( - "too many of these can lead to slow training, " - "but we expect a few in the beginning" - ) - logging.info("NOTE: XLA compilation detected; {}".format(message)) + logger.warning( + "XLA compilation detected on device #{}; too many of these can lead " + "to slow training, but we expect a few in the beginning" + .format(self.args.distributed_rank) + ) self._num_xla_compiles = num_xla_compiles diff --git a/fairseq/utils.py b/fairseq/utils.py index 8f6bda393f..d2271020db 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -305,13 +305,16 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if multi_tensor_l2norm_available: total_norm = 
multi_tensor_total_norm(grads) else: - device = torch.device('cpu') if torch.cuda.is_available(): warnings.warn( "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " "you may get better performance by installing NVIDIA's apex library" ) device = torch.cuda.current_device() + elif grads[0].device.type == 'xla': + device = grads[0].device + else: + device = torch.device('cpu') total_norm = torch.norm( torch.stack([torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]) ) From 0557ed8b0df90fe671bcb745f384ef7fd0386ab3 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Thu, 1 Oct 2020 16:05:29 -0700 Subject: [PATCH 182/707] Enable FileContentsAction to handle manifold file Summary: The latest FileContentsAction can not handle manifold files Reviewed By: shruti-bh Differential Revision: D24059296 fbshipit-source-id: 05cfa227b55d297498c03c33347a24e1914064b4 --- fairseq/data/multilingual/multilingual_data_manager.py | 2 +- fairseq/options.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 42cdad6041..0d135bb8f9 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -87,6 +87,7 @@ def add_args(parser): type=csv_str_list, help="a list of languages comma sperated languages which can appear in lang-pairs; " "note that the ordering determines language token IDs", + action=FileContentsAction, ) parser.add_argument( "--lang-dict", @@ -96,7 +97,6 @@ def add_args(parser): "languages which can appear in lang-pairs; " "note that the ordering determines language token IDs; " "--langs and --lang-dict are two exclusive options", - action=FileContentsAction, ) parser.add_argument( "--lang-tok-style", diff --git a/fairseq/options.py b/fairseq/options.py index 5945b054ca..beee239b96 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -24,7 +24,7 
@@ def __init__(self, option_strings, dest, nargs=None, **kwargs): def __call__(self, parser, namespace, values, option_string=None): if PathManager.isfile(values): - with open(values) as f: + with PathManager.open(values) as f: argument = f.read().strip() else: argument = values From f902a363abc578906f29239f995cacce5e93a807 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 2 Oct 2020 10:48:15 -0700 Subject: [PATCH 183/707] Small fixes (#1325) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1325 Reviewed By: ngoyal2707 Differential Revision: D24024198 Pulled By: myleott fbshipit-source-id: c3b776970d625eff21a26bf7c86cd28ef9e9d2ef --- fairseq/benchmark/dummy_masked_lm.py | 1 - fairseq/model_parallel/models/roberta/model.py | 3 +++ fairseq/model_parallel/models/transformer.py | 1 + fairseq/model_parallel/models/transformer_lm.py | 3 +++ fairseq/model_parallel/modules/multihead_attention.py | 2 +- .../model_parallel/modules/transformer_sentence_encoder.py | 6 +----- fairseq/optim/adam.py | 2 +- fairseq/optim/fairseq_optimizer.py | 2 ++ fairseq/utils.py | 1 + hubconf.py | 2 ++ setup.py | 2 +- 11 files changed, 16 insertions(+), 9 deletions(-) diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index 3b0bdc51f5..81398945f3 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -30,7 +30,6 @@ def add_args(parser): def __init__(self, args, dictionary): super().__init__(args) self.dictionary = dictionary - self.seed = args.seed # add mask token self.mask_idx = dictionary.add_symbol('') diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py index 6ba097b14d..ed49fbb338 100644 --- a/fairseq/model_parallel/models/roberta/model.py +++ b/fairseq/model_parallel/models/roberta/model.py @@ -66,6 +66,9 @@ def build_model(cls, args, task): # make sure all arguments are present base_architecture(args) + 
task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + if not hasattr(args, 'max_positions'): args.max_positions = args.tokens_per_sample diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py index f5756ad898..3ba539319f 100644 --- a/fairseq/model_parallel/models/transformer.py +++ b/fairseq/model_parallel/models/transformer.py @@ -50,6 +50,7 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): '\n\n git submodule update --init ' 'fairseq/model_parallel/megatron' ) + dictionary.pad_to_multiple_(args.model_parallel_size * 8) num_embeddings = len(dictionary) padding_idx = dictionary.pad() diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 598807147e..492dad653c 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -36,6 +36,9 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_lm_architecture(args) + task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + if args.decoder_layers_to_keep: args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py index 97da7db35b..f55a712b01 100644 --- a/fairseq/model_parallel/modules/multihead_attention.py +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -60,7 +60,7 @@ def __init__( self.num_heads_partition = num_heads // self.model_parallel_size assert ( self.num_heads_partition * self.model_parallel_size == num_heads - ), "Number of heads must be divisble by model parallel size" + ), "Number of heads must be divisible by model parallel size" self.dropout_module = FairseqDropout( 
dropout, module_name=self.__class__.__name__ diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py index 101eca7bd4..a2a6eb81fa 100644 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder.py +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder.py @@ -20,11 +20,7 @@ ) try: - from fairseq.model_parallel.megatron.mpu import ( - copy_to_model_parallel_region, - gather_from_model_parallel_region, - VocabParallelEmbedding, - ) + from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index b33a5d89e9..81f1d15bd1 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -167,7 +167,7 @@ def step(self, closure=None): raise RuntimeError( "Adam does not support sparse gradients, please consider SparseAdam instead" ) - amsgrad = group["amsgrad"] + amsgrad = group.get("amsgrad", False) p_data_fp32 = p.data if p.data.dtype in {torch.float16, torch.bfloat16}: diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 18c26a3a39..07ac45f60d 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -108,6 +108,8 @@ def step(self, closure=None, scale=1.): if self.supports_step_with_scale: self.optimizer.step(closure, scale=scale) else: + if scale != 1.: + self.multiply_grads(1. 
/ scale) self.optimizer.step(closure) def zero_grad(self): diff --git a/fairseq/utils.py b/fairseq/utils.py index d2271020db..af0c587583 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -288,6 +288,7 @@ def multi_tensor_total_norm(grads, chunk_size=2048*32) -> torch.Tensor: return total_norm +@torch.no_grad() def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if isinstance(params, torch.Tensor): params = [params] diff --git a/hubconf.py b/hubconf.py index 2a0c9c097e..99d9cabbd4 100644 --- a/hubconf.py +++ b/hubconf.py @@ -11,6 +11,8 @@ dependencies = [ + 'dataclasses', + 'hydra-core', 'numpy', 'regex', 'requests', diff --git a/setup.py b/setup.py index fa59acef78..215276925f 100644 --- a/setup.py +++ b/setup.py @@ -136,9 +136,9 @@ def include_dirs(self, dirs): install_requires=[ 'cffi', 'cython', - 'hydra-core', 'dataclasses', 'editdistance', + 'hydra-core', 'numpy', 'regex', 'sacrebleu', From 7c392f7d0ea54485c303fd8349ea34b446acaada Mon Sep 17 00:00:00 2001 From: Dmitriy Genzel Date: Fri, 2 Oct 2020 12:35:58 -0700 Subject: [PATCH 184/707] Provide proper diagnostic for an empty batch Summary: Dummy batch logic as introduced in D23924558 (https://github.com/pytorch/fairseq/commit/8dde7de8a22d6e59c4101fe0de618f888c33ba81) does not provide a proper diagnostic message if the first batch given is empty. This change makes sure that the assertion in lines 799-804 is actually triggered in this case. Otherwise this change is a noop. 
Reviewed By: joshim5 Differential Revision: D24081932 fbshipit-source-id: 45fc49fb8a5f9a49f858683f97b5b27455f9a8d5 --- fairseq/trainer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 2dc44367ab..18e691ffd4 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -423,16 +423,13 @@ def begin_valid_epoch(self, epoch): # task specific setup per validation epoch self.task.begin_valid_epoch(epoch, self.get_model()) - + # reset dummy batch self._dummy_batch = 'DUMMY' @metrics.aggregate("train") def train_step(self, samples, raise_oom=False): """Do forward, backward and parameter update.""" - if self._dummy_batch == "DUMMY": - self._dummy_batch = samples[0] - self._set_seed() self.model.train() self.criterion.train() @@ -450,6 +447,8 @@ def train_step(self, samples, raise_oom=False): sample = self._prepare_sample(self._dummy_batch) is_dummy_batch = True else: + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample is_dummy_batch = False def maybe_no_sync(): @@ -658,8 +657,6 @@ def maybe_no_sync(): @metrics.aggregate("valid") def valid_step(self, sample, raise_oom=False): """Do forward pass in evaluation mode.""" - if self._dummy_batch == "DUMMY": - self._dummy_batch = sample if self.tpu: import torch_xla.core.xla_model as xm xm.rendezvous('valid_step') # wait for all workers @@ -674,6 +671,8 @@ def valid_step(self, sample, raise_oom=False): sample = self._prepare_sample(self._dummy_batch) is_dummy_batch = True else: + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample is_dummy_batch = False try: From 7c292af66f61b1125854218519bf81d494e5b11e Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 2 Oct 2020 19:00:29 -0700 Subject: [PATCH 185/707] Fix hub (#2687) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2687 Reviewed By: alexeib Differential Revision: D24095130 Pulled By: myleott fbshipit-source-id: 7d371bccb550ec68b2b9b39dfa4c0718356508d6 --- 
fairseq/optim/fp16_optimizer.py | 8 +++++-- hubconf.py | 39 +++++++++++++++++++++++---------- tests/gpu/test_binaries_gpu.py | 4 ++-- tests/test_fp16_optimizer.py | 5 ++++- 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index e621a6f114..edb4f536ea 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -34,13 +34,17 @@ def has_flat_params(self): def build_fp32_params(cls, args, params, flatten=True): # create FP32 copy of parameters and grads if flatten: + is_pipeline_parallel = ( + getattr(args, 'pipeline_model_parallel', False) + and getattr(args, 'distributed_no_spawn', False) + ) total_param_size = sum(p.data.numel() for p in params) devices = [torch.cuda.current_device()] - if args.pipeline_model_parallel and args.distributed_no_spawn: + if is_pipeline_parallel: devices = list(set(args.pipeline_devices)) fp32_params = {} for device in devices: - if args.pipeline_model_parallel and args.distributed_no_spawn: + if is_pipeline_parallel: device_param_size = sum(p.data.numel() for p in params if p.device.index == device) device_params = [p for p in params if p.device.index == device] else: diff --git a/hubconf.py b/hubconf.py index 99d9cabbd4..c63fa8ae89 100644 --- a/hubconf.py +++ b/hubconf.py @@ -4,15 +4,12 @@ # LICENSE file in the root directory of this source tree. import functools - -from fairseq.hub_utils import BPEHubInterface as bpe # noqa -from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa -from fairseq.models import MODEL_REGISTRY +import importlib dependencies = [ 'dataclasses', - 'hydra-core', + 'hydra', 'numpy', 'regex', 'requests', @@ -20,31 +17,51 @@ ] +# Check for required dependencies and raise a RuntimeError if any are missing. +missing_deps = [] +for dep in dependencies: + try: + importlib.import_module(dep) + except ImportError: + # Hack: the hydra package is provided under the "hydra-core" name in + # pypi. 
We don't want the user mistakenly calling `pip install hydra` + # since that will install an unrelated package. + if dep == 'hydra': + dep = 'hydra-core' + missing_deps.append(dep) +if len(missing_deps) > 0: + raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps))) + + # torch.hub doesn't build Cython components, so if they are not found then try # to build them here try: - import fairseq.data.token_block_utils_fast -except (ImportError, ModuleNotFoundError): + import fairseq.data.token_block_utils_fast # noqa +except ImportError: try: - import cython + import cython # noqa import os from setuptools import sandbox sandbox.run_setup( os.path.join(os.path.dirname(__file__), 'setup.py'), ['build_ext', '--inplace'], ) - except (ImportError, ModuleNotFoundError): + except ImportError: print( 'Unable to build Cython components. Please make sure Cython is ' 'installed if the torch.hub model you are loading depends on it.' ) +from fairseq.hub_utils import BPEHubInterface as bpe # noqa +from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa +from fairseq.models import MODEL_REGISTRY # noqa + + +# automatically expose models defined in FairseqModel::hub_models for _model_type, _cls in MODEL_REGISTRY.items(): for model_name in _cls.hub_models().keys(): globals()[model_name] = functools.partial( _cls.from_pretrained, model_name, ) - # to simplify the interface we only expose named models - # globals()[_model_type] = _cls.from_pretrained diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index b65b545a4e..e3fadef9f2 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -64,9 +64,9 @@ def test_transformer_fp16(self): "--decoder-layers", "2", "--encoder-embed-dim", - "8", + "64", "--decoder-embed-dim", - "8", + "64", "--fp16", ], run_validation=True, diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py index ae7b797ec8..bca341af1a 100644 --- 
a/tests/test_fp16_optimizer.py +++ b/tests/test_fp16_optimizer.py @@ -63,7 +63,10 @@ def test_mixed_precision(self): optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params) self.run_iter(model, params, optimizer) - self.assertTrue(torch.all(optimizer.fp32_params.eq(torch.tensor([3.1000, 5.1000], device='cuda:0', requires_grad=True)))) + self.assertTrue(all( + torch.all(fp32_params.eq(torch.tensor([3.1000, 5.1000], device='cuda:0', requires_grad=True))) + for fp32_params in optimizer.fp32_params.values() + )) def test_memory_efficient(self): model = copy.deepcopy(self.model) From 5e82514d687289a73a6dec33b555217acd97cb0d Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Fri, 2 Oct 2020 21:23:15 -0700 Subject: [PATCH 186/707] update registries (#1330) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1330 construct args from dataclasses Reviewed By: myleott Differential Revision: D23973591 fbshipit-source-id: 6d7a4b80c2a815bcabd6f955513e5cd8f5cf5ab4 --- .../eval/agents/__init__.py | 2 +- .../eval/scorers/__init__.py | 3 +- .../modules/__init__.py | 3 +- fairseq/criterions/__init__.py | 31 +- fairseq/criterions/adaptive_loss.py | 11 +- fairseq/criterions/cross_entropy.py | 11 +- fairseq/criterions/ctc.py | 16 +- fairseq/criterions/fairseq_criterion.py | 9 +- fairseq/data/encoders/__init__.py | 4 +- .../multilingual/multilingual_data_manager.py | 25 +- fairseq/dataclass/__init__.py | 9 + fairseq/dataclass/constants.py | 13 + fairseq/dataclass/data_class.py | 83 +++-- fairseq/dataclass/utils.py | 6 +- fairseq/distributed_utils.py | 4 + fairseq/models/__init__.py | 112 ++++-- fairseq/models/composite_encoder.py | 2 +- fairseq/models/distributed_fairseq_model.py | 1 - fairseq/models/fairseq_model.py | 27 +- fairseq/models/transformer_lm.py | 84 +---- fairseq/optim/__init__.py | 43 ++- fairseq/optim/adam.py | 9 +- fairseq/optim/bmuf.py | 3 +- fairseq/optim/fairseq_optimizer.py | 34 +- fairseq/optim/lr_scheduler/__init__.py | 32 
+- .../optim/lr_scheduler/cosine_lr_scheduler.py | 9 +- .../lr_scheduler/fairseq_lr_scheduler.py | 23 +- .../inverse_square_root_schedule.py | 9 +- fairseq/optim/nag.py | 9 +- fairseq/options.py | 348 +++--------------- fairseq/registry.py | 79 ++-- fairseq/scoring/__init__.py | 2 +- fairseq/tasks/__init__.py | 56 ++- fairseq/tasks/fairseq_task.py | 77 ++-- fairseq/tasks/language_modeling.py | 13 +- .../tasks/translation_multi_simple_epoch.py | 2 +- fairseq/utils.py | 79 ++-- 37 files changed, 578 insertions(+), 705 deletions(-) create mode 100644 fairseq/dataclass/constants.py diff --git a/examples/simultaneous_translation/eval/agents/__init__.py b/examples/simultaneous_translation/eval/agents/__init__.py index d49426344c..1c23fc1ad9 100644 --- a/examples/simultaneous_translation/eval/agents/__init__.py +++ b/examples/simultaneous_translation/eval/agents/__init__.py @@ -7,7 +7,7 @@ import os from fairseq import registry -build_agent, register_agent, MONOTONIC_AGENT = registry.setup_registry('--agent-type') +build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry('--agent-type') DEFAULT_EOS = '' diff --git a/examples/simultaneous_translation/eval/scorers/__init__.py b/examples/simultaneous_translation/eval/scorers/__init__.py index 6d4ab3acc7..c7fbb5495d 100644 --- a/examples/simultaneous_translation/eval/scorers/__init__.py +++ b/examples/simultaneous_translation/eval/scorers/__init__.py @@ -9,7 +9,8 @@ ( build_scorer, register_scorer, - SCORER_REGISTRIES + SCORER_REGISTRIES, + _ ) = registry.setup_registry('--scorer-type') for file in os.listdir(os.path.dirname(__file__)): diff --git a/examples/simultaneous_translation/modules/__init__.py b/examples/simultaneous_translation/modules/__init__.py index 4f311b1ee6..8fd9d379a5 100644 --- a/examples/simultaneous_translation/modules/__init__.py +++ b/examples/simultaneous_translation/modules/__init__.py @@ -10,7 +10,8 @@ ( build_monotonic_attention, register_monotonic_attention, - 
MONOTONIC_ATTENTION_REGISTRY + MONOTONIC_ATTENTION_REGISTRY, + _ ) = registry.setup_registry('--simul-type') for file in os.listdir(os.path.dirname(__file__)): diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index b3663d6394..30edb2f312 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -5,22 +5,33 @@ import importlib import os +from argparse import Namespace +from typing import Union from fairseq import registry -from fairseq.criterions.fairseq_criterion import FairseqCriterion, LegacyFairseqCriterion - +from fairseq.criterions.fairseq_criterion import ( # noqa + FairseqCriterion, + LegacyFairseqCriterion, +) +from omegaconf import DictConfig -CRITERION_DATACLASS_REGISTRY = {} -build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry( - '--criterion', - base_class=FairseqCriterion, - default='cross_entropy', +( + build_criterion_, + register_criterion, + CRITERION_REGISTRY, + CRITERION_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--criterion", base_class=FairseqCriterion, default="cross_entropy" ) +def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task): + return build_criterion_(criterion_cfg, task) + + # automatically import any Python files in the criterions/ directory for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('fairseq.criterions.' + module) + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.criterions." 
+ file_name) diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 7bc41d6000..74ba37c321 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -9,8 +9,8 @@ import torch.nn.functional as F from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.constants import DDP_BACKEND_CHOICES from omegaconf import II @@ -20,7 +20,7 @@ class AdaptiveLossConfig(FairseqDataclass): ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") -@register_criterion("adaptive_loss") +@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig) class AdaptiveLoss(FairseqCriterion): """This is an implementation of the loss function accompanying the adaptive softmax approximation for graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" @@ -30,11 +30,6 @@ def __init__(self, task, sentence_avg): super().__init__(task) self.sentence_avg = sentence_avg - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser. 
optionaly register config store""" - gen_parser_from_dataclass(parser, AdaptiveLossConfig()) - @classmethod def build_criterion(cls, args, task): if getattr(args, "ddp_backend", None) == "c10d": diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index 08d64eced9..91b58545ed 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -9,28 +9,21 @@ import torch.nn.functional as F from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass from omegaconf import II @dataclass class CrossEntropyCriterionConfig(FairseqDataclass): sentence_avg: bool = II("params.optimization.sentence_avg") - ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") -@register_criterion("cross_entropy") +@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig) class CrossEntropyCriterion(FairseqCriterion): def __init__(self, task, sentence_avg): super().__init__(task) self.sentence_avg = sentence_avg - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser. optionaly register config store""" - gen_parser_from_dataclass(parser, CrossEntropyCriterionConfig()) - def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 7398bf8117..4f93b3cbfd 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -4,14 +4,14 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
-from argparse import Namespace import math +from argparse import Namespace import torch import torch.nn.functional as F from fairseq import metrics, utils -from fairseq.data.data_utils import post_process from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.data.data_utils import post_process from fairseq.logging.meters import safe_round @@ -219,20 +219,26 @@ def reduce_metrics(logging_outputs) -> None: if c_total > 0: metrics.log_derived( "uer", - lambda meters: safe_round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3) + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) if meters["_c_total"].sum > 0 else float("nan"), ) if w_total > 0: metrics.log_derived( "wer", - lambda meters: safe_round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) if meters["_w_total"].sum > 0 else float("nan"), ) metrics.log_derived( "raw_wer", - lambda meters: safe_round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) if meters["_w_total"].sum > 0 else float("nan"), ) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index 9873574d47..0239a548a9 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -9,6 +9,7 @@ from torch.nn.modules.loss import _Loss from fairseq import metrics, utils +from fairseq.dataclass.utils import gen_parser_from_dataclass class FairseqCriterion(_Loss): @@ -20,10 +21,12 @@ def __init__(self, task): tgt_dict = task.target_dictionary self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 - @staticmethod - def add_args(parser): + @classmethod + def add_args(cls, parser): """Add criterion-specific arguments to the parser.""" - pass + dc = getattr(cls, '__dataclass', 
None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) @classmethod def build_criterion(cls, args, task): diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py index c0909b6697..d796496b86 100644 --- a/fairseq/data/encoders/__init__.py +++ b/fairseq/data/encoders/__init__.py @@ -10,13 +10,13 @@ from fairseq import registry -build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry( +build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( '--tokenizer', default=None, ) -build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry( +build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( '--bpe', default=None, ) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 0d135bb8f9..1eedab3dce 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -33,7 +33,7 @@ ) from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat from fairseq.file_io import PathManager -from fairseq.options import csv_str_list, eval_str_dict, FileContentsAction +from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict logger = logging.getLogger(__name__) @@ -788,9 +788,9 @@ def _get_shard_num_dict(cls, split, paths): for path in paths: files = PathManager.ls(path) for f in files: - if f.startswith(split) and f.endswith('.idx'): + if f.startswith(split) and f.endswith(".idx"): # idx files of the form "{split}.{src}-{tgt}.{lang}.idx" - direction = f.split('.')[-3] + direction = f.split(".")[-3] shards[direction] += 1 # each direction has two '.idx' files # one for source language and one for target language, so: @@ -816,15 +816,16 @@ def get_split_num_data_shards(self, split): if "mono_" in data_category: # monolingual data requires tgt only assert src is None or src == tgt, ( - f"error: src={src}, " 
"tgt={tgt} for data_category={data_category}" + f"error: src={src}, " + "tgt={tgt} for data_category={data_category}" ) num_shards_dict[key] = shards_dict[tgt] else: if f"{src}-{tgt}" in shards_dict: num_shards_dict[key] = shards_dict[f"{src}-{tgt}"] - elif f'{tgt}-{src}' in shards_dict: + elif f"{tgt}-{src}" in shards_dict: # follow the fairseq tradition to use reversed direction data if it is not available - num_shards_dict[key] = shards_dict[f'{tgt}-{src}'] + num_shards_dict[key] = shards_dict[f"{tgt}-{src}"] self._num_shards_dict[split] = num_shards_dict logger.info(f"[{split}] num of shards: {num_shards_dict}") return num_shards_dict @@ -893,7 +894,9 @@ def get_split_data_param_list(self, split, epoch, shard_epoch=None): ) return param_list - def get_train_dataset_sizes(self, data_param_list, datasets, epoch, shard_epoch=None): + def get_train_dataset_sizes( + self, data_param_list, datasets, epoch, shard_epoch=None + ): num_shards = [ self.get_split_num_data_shards(param["split"])[param["key"]] for param in data_param_list @@ -921,8 +924,12 @@ def get_train_dataset_sizes(self, data_param_list, datasets, epoch, shard_epoch= ) return [s for _, s in data_sizes] - def get_train_sampling_ratios(self, data_param_list, datasets, epoch=1, shard_epoch=None): - data_sizes = self.get_train_dataset_sizes(data_param_list, datasets, epoch, shard_epoch) + def get_train_sampling_ratios( + self, data_param_list, datasets, epoch=1, shard_epoch=None + ): + data_sizes = self.get_train_dataset_sizes( + data_param_list, datasets, epoch, shard_epoch + ) sampling_func = self.sampling_method.sampling_method_selector() sample_ratios = sampling_func(data_sizes) if sampling_func is not None else None return sample_ratios diff --git a/fairseq/dataclass/__init__.py b/fairseq/dataclass/__init__.py index e69de29bb2..32870814d5 100644 --- a/fairseq/dataclass/__init__.py +++ b/fairseq/dataclass/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .utils import ChoiceEnum, FairseqDataclass + + +__all__ = ["FairseqDataclass", "ChoiceEnum"] diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py new file mode 100644 index 0000000000..21b36450f9 --- /dev/null +++ b/fairseq/dataclass/constants.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.dataclass.utils import ChoiceEnum + + +LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) +DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) +DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) +ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) +PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index 19e22f4c58..bbd15d1204 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -3,25 +3,28 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import sys +from argparse import Namespace from dataclasses import dataclass, field -from typing import Any, List, Optional, Dict, Type, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type + import torch -from fairseq.data.indexed_dataset import get_available_dataset_impl -from fairseq.dataclass.utils import FairseqDataclass, ChoiceEnum -import sys -from fairseq.tasks import TASK_DATACLASS_REGISTRY -from fairseq.models import ARCH_DATACLASS_REGISTRY from fairseq.criterions import CRITERION_DATACLASS_REGISTRY +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass.constants import ( + DDP_BACKEND_CHOICES, + DISTRIBUTED_WRAPPER_CHOICES, + LOG_FORMAT_CHOICES, + PIPELINE_CHECKPOINT_CHOICES, + ZERO_SHARDING_CHOICES, +) +from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass +from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY from fairseq.optim.bmuf import FairseqBMUFConfig from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY +from fairseq.tasks import TASK_DATACLASS_REGISTRY from hydra.core.config_store import ConfigStore -from argparse import Namespace - - -LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) -DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) -DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) @dataclass @@ -213,6 +216,36 @@ class DistributedTrainingParams(FairseqDataclass): "and gossip across different nodes" }, ) + pipeline_model_parallel: bool = field( + default=False, + metadata={"help": "if set, use pipeline model parallelism across GPUs"}, + ) + pipeline_balance: str = field( + default=None, + metadata={ + "help": "partition the model into N_K pieces, where each piece " + "contains N_i layers. 
The sum(args.pipeline_balance) " + "should equal the total number of layers in the model" + }, + ) + pipeline_devices: str = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-balance argument" + }, + ) + pipeline_chunks: int = field( + default=0, metadata={"help": "microbatch count for pipeline model parallelism"} + ) + pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( + default="never", + metadata={"help": "checkpointing mode for pipeline model parallelism"}, + ) + zero_sharding: ZERO_SHARDING_CHOICES = field( + default="none", metadata={"help": "ZeRO sharding"} + ) @dataclass @@ -236,6 +269,9 @@ class DatasetParams(FairseqDataclass): required_batch_size_multiple: int = field( default=8, metadata={"help": "batch size will be a multiplier of this value"} ) + required_seq_len_multiple: int = field( + default=1, metadata={"help": "maximum sequence length in batch will be a multiplier of this value"} + ) dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = field( default=None, metadata={"help": "output dataset implementation"} ) @@ -544,7 +580,7 @@ def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: ) register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") - register_module_dataclass(cs, ARCH_DATACLASS_REGISTRY, "model") + register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") @@ -680,15 +716,18 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: else: deletes.append("lr_scheduler") + no_dc = True if hasattr(args, "arch"): - if args.arch in ARCH_DATACLASS_REGISTRY: - 
overrides.append("model={}".format(args.arch)) - overrides.append("model._name={}".format(args.arch)) - # override model params with those exist in args - overrides.extend( - _override_attr("model", ARCH_DATACLASS_REGISTRY[args.arch], args) - ) - else: - deletes.append("model") + if args.arch in ARCH_MODEL_REGISTRY: + m_cls = ARCH_MODEL_REGISTRY[args.arch] + dc = getattr(m_cls, "__dataclass", None) + if dc is not None: + overrides.append("model={}".format(args.arch)) + overrides.append("model._name={}".format(args.arch)) + # override model params with those exist in args + overrides.extend(_override_attr("model", dc, args)) + no_dc = False + if no_dc: + deletes.append("model") return overrides, deletes diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index b910d8353d..093ecd8f6b 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -3,10 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from dataclasses import dataclass, MISSING -from typing import Any, List, Optional, Dict -from enum import Enum from argparse import ArgumentParser +from dataclasses import MISSING, dataclass +from enum import Enum +from typing import Any, Dict, List, Optional def eval_str_list(x, x_type=float): diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 3611a3d6ce..886a30dfe3 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -36,6 +36,10 @@ def infer_init_method(args, force_distributed=False): raise ValueError('--pipeline-balance is currently required for pipeline model parallelism') if args.pipeline_devices is None: raise ValueError('--pipeline-devices is currently required for pipeline model parallelism') + + args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int) + args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) + gpus_per_node = torch.cuda.device_count() num_pipeline_devices = len(set(args.pipeline_devices)) assert gpus_per_node >= num_pipeline_devices and gpus_per_node % num_pipeline_devices == 0, ( diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 332ac822a7..e441e7cd7d 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -6,51 +6,58 @@ import argparse import importlib import os +from argparse import Namespace +from typing import Union +import fairseq +from fairseq.dataclass import FairseqDataclass +from omegaconf import DictConfig, OmegaConf + +from .composite_encoder import CompositeEncoder +from .distributed_fairseq_model import DistributedFairseqModel from .fairseq_decoder import FairseqDecoder from .fairseq_encoder import FairseqEncoder from .fairseq_incremental_decoder import FairseqIncrementalDecoder from .fairseq_model import ( BaseFairseqModel, - FairseqEncoderModel, FairseqEncoderDecoderModel, + FairseqEncoderModel, FairseqLanguageModel, FairseqModel, FairseqMultiModel, ) -from .composite_encoder import 
CompositeEncoder -from .distributed_fairseq_model import DistributedFairseqModel - MODEL_REGISTRY = {} +MODEL_DATACLASS_REGISTRY = {} ARCH_MODEL_REGISTRY = {} ARCH_MODEL_NAME_REGISTRY = {} ARCH_MODEL_INV_REGISTRY = {} ARCH_CONFIG_REGISTRY = {} -ARCH_DATACLASS_REGISTRY = {} __all__ = [ - 'BaseFairseqModel', - 'CompositeEncoder', - 'DistributedFairseqModel', - 'FairseqDecoder', - 'FairseqEncoder', - 'FairseqEncoderDecoderModel', - 'FairseqEncoderModel', - 'FairseqIncrementalDecoder', - 'FairseqLanguageModel', - 'FairseqModel', - 'FairseqMultiModel', + "BaseFairseqModel", + "CompositeEncoder", + "DistributedFairseqModel", + "FairseqDecoder", + "FairseqEncoder", + "FairseqEncoderDecoderModel", + "FairseqEncoderModel", + "FairseqIncrementalDecoder", + "FairseqLanguageModel", + "FairseqModel", + "FairseqMultiModel", ] -def build_model(args, task): - return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task) +def build_model(model_cfg: Union[DictConfig, Namespace], task): + if isinstance(model_cfg, DictConfig): + return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task) + return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task) -def register_model(name): +def register_model(name, dataclass=None): """ New model types can be added to fairseq with the :func:`register_model` function decorator. 
@@ -72,10 +79,19 @@ class LSTM(FairseqEncoderDecoderModel): def register_model_cls(cls): if name in MODEL_REGISTRY: - raise ValueError('Cannot register duplicate model ({})'.format(name)) + raise ValueError("Cannot register duplicate model ({})".format(name)) if not issubclass(cls, BaseFairseqModel): - raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__)) + raise ValueError( + "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__) + ) MODEL_REGISTRY[name] = cls + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + MODEL_DATACLASS_REGISTRY[name] = dataclass return cls return register_model_cls @@ -106,16 +122,42 @@ def lstm_luong_wmt_en_de(args): arch_name (str): the name of the model architecture (``--arch``) """ + def arch_override_from_yaml(args, arch): + root_dir = os.path.dirname(os.path.dirname(fairseq.__file__)) + yaml_path = os.path.join(root_dir, "config/model/{}.yaml".format(arch)) + if not os.path.exists(yaml_path): + raise RuntimeError(f"yaml file {yaml_path} does not exist!") + arch_cfg = OmegaConf.load(yaml_path) + for k, v in arch_cfg.items(): + setattr(args, k, getattr(args, k, v)) + def register_model_arch_fn(fn): if model_name not in MODEL_REGISTRY: - raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name)) + raise ValueError( + "Cannot register model architecture for unknown model type ({})".format( + model_name + ) + ) if arch_name in ARCH_MODEL_REGISTRY: - raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name)) + raise ValueError( + "Cannot register duplicate model architecture ({})".format(arch_name) + ) if not callable(fn): - raise ValueError('Model architecture must be callable ({})'.format(arch_name)) + raise ValueError( + "Model architecture must be callable 
({})".format(arch_name) + ) ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] + ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) - ARCH_CONFIG_REGISTRY[arch_name] = fn + if type(fn) is type and issubclass(fn, BaseFairseqModel): + # for model classes migrated with hydra + # in this case, we are using this decorator directly on model class since + # we do not need arch overriding functions. + ARCH_CONFIG_REGISTRY[arch_name] = lambda args: arch_override_from_yaml( + args, arch=arch_name + ) + else: + ARCH_CONFIG_REGISTRY[arch_name] = fn return fn return register_model_arch_fn @@ -126,18 +168,20 @@ def register_model_arch_fn(fn): for file in os.listdir(models_dir): path = os.path.join(models_dir, file) if ( - not file.startswith('_') - and not file.startswith('.') - and (file.endswith('.py') or os.path.isdir(path)) + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) ): - model_name = file[:file.find('.py')] if file.endswith('.py') else file - module = importlib.import_module('fairseq.models.' + model_name) + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.models." 
+ model_name) # extra `model_parser` for sphinx if model_name in MODEL_REGISTRY: parser = argparse.ArgumentParser(add_help=False) - group_archs = parser.add_argument_group('Named architectures') - group_archs.add_argument('--arch', choices=ARCH_MODEL_INV_REGISTRY[model_name]) - group_args = parser.add_argument_group('Additional command-line arguments') + group_archs = parser.add_argument_group("Named architectures") + group_archs.add_argument( + "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] + ) + group_args = parser.add_argument_group("Additional command-line arguments") MODEL_REGISTRY[model_name].add_args(group_args) - globals()[model_name + '_parser'] = parser + globals()[model_name + "_parser"] = parser diff --git a/fairseq/models/composite_encoder.py b/fairseq/models/composite_encoder.py index afef248cdc..60d1473f5f 100644 --- a/fairseq/models/composite_encoder.py +++ b/fairseq/models/composite_encoder.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from fairseq.models import FairseqEncoder +from .fairseq_encoder import FairseqEncoder class CompositeEncoder(FairseqEncoder): diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index dd74bf1f13..4fe02b20dd 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -8,7 +8,6 @@ import torch.nn as nn from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel -from fairseq.models import BaseFairseqModel _GOSSIP_DISABLED = False diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 5cf6cba118..9d777f02fa 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -15,6 +15,7 @@ from fairseq import utils from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary +from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.models import FairseqDecoder, FairseqEncoder from torch import Tensor @@ -29,10 +30,12 @@ def __init__(self): super().__init__() self._is_generation_fast = False - @staticmethod - def add_args(parser): + @classmethod + def add_args(cls, parser): """Add model-specific arguments to the parser.""" - pass + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) @classmethod def build_model(cls, args, task): @@ -123,22 +126,22 @@ def set_num_updates(self, num_updates): """State from trainer to pass along to model at every update.""" def _apply(m): - if hasattr(m, 'set_num_updates') and m != self: + if hasattr(m, "set_num_updates") and m != self: m.set_num_updates(num_updates) + self.apply(_apply) def prepare_for_inference_(self, args): """Prepare model for inference.""" kwargs = {} - kwargs['beamable_mm_beam_size'] = ( - None if getattr(args, 'no_beamable_mm', False) - else getattr(args, 'beam', 5) + kwargs["beamable_mm_beam_size"] = ( + None if getattr(args, "no_beamable_mm", False) 
else getattr(args, "beam", 5) ) - kwargs['need_attn'] = getattr(args, 'print_alignment', False) - if hasattr(args, 'retain_dropout'): - kwargs['retain_dropout'] = args.retain_dropout - kwargs['retain_dropout_modules'] = getattr( - args, 'retain_dropout_modules', None + kwargs["need_attn"] = getattr(args, "print_alignment", False) + if hasattr(args, "retain_dropout"): + kwargs["retain_dropout"] = args.retain_dropout + kwargs["retain_dropout_modules"] = getattr( + args, "retain_dropout_modules", None ) self.make_generation_fast_(**kwargs) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index fd7adb3c15..22b17f06ee 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -8,7 +8,7 @@ from typing import Optional from fairseq import options, utils -from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass +from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.models import ( FairseqLanguageModel, register_model, @@ -159,11 +159,10 @@ class TransformerLanguageModelConfig(FairseqDataclass): add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") - # TODO common var add to parent tpu: bool = II("params.common.tpu") -@register_model("transformer_lm") +@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) class TransformerLanguageModel(FairseqLanguageModel): @classmethod def hub_models(cls): @@ -187,85 +186,6 @@ def moses_fastbpe(path): def __init__(self, decoder): super().__init__(decoder) - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--activation-fn', - choices=utils.get_available_activation_fns(), - help='activation function to use') - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', 
type=float, metavar='D', - help='dropout probability for attention weights') - parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', - help='dropout probability after activation in FFN.') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-output-dim', type=int, metavar='N', - help='decoder output dimension') - parser.add_argument('--decoder-input-dim', type=int, metavar='N', - help='decoder input dimension') - parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', - help='decoder embedding dimension for FFN') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='num decoder layers') - parser.add_argument('--decoder-attention-heads', type=int, metavar='N', - help='num decoder attention heads') - parser.add_argument('--decoder-normalize-before', action='store_true', - help='apply layernorm before each decoder block') - parser.add_argument('--no-decoder-final-norm', action='store_true', - help='don\'t add an extra layernorm after the last decoder block') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. 
' - 'Must be used with adaptive_loss criterion') - parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', - help='sets adaptive softmax dropout for the tail projections') - parser.add_argument('--adaptive-softmax-factor', type=float, metavar='N', - help='adaptive input factor') - parser.add_argument('--no-token-positional-embeddings', action='store_true', - help='if set, disables positional embeddings (outside self attention)') - parser.add_argument('--share-decoder-input-output-embed', action='store_true', - help='share decoder input and output embeddings') - parser.add_argument('--character-embeddings', action='store_true', - help='if set, uses character embedding convolutions to produce token embeddings') - parser.add_argument('--character-filters', type=str, metavar='LIST', - default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]', - help='size of character embeddings') - parser.add_argument('--character-embedding-dim', default=4, type=int, metavar='N', - help='size of character embeddings') - parser.add_argument('--char-embedder-highway-layers', default=2, type=int, metavar='N', - help='number of highway layers for character token embeddder') - parser.add_argument('--adaptive-input', action='store_true', - help='if set, uses adaptive input') - parser.add_argument('--adaptive-input-factor', type=float, metavar='N', - help='adaptive input factor') - parser.add_argument('--adaptive-input-cutoff', metavar='EXPR', - help='comma separated list of adaptive input cutoff points.') - parser.add_argument('--tie-adaptive-weights', action='store_true', - help='if set, ties the weights of adaptive softmax and adaptive input') - parser.add_argument('--tie-adaptive-proj', action='store_true', - help='if set, ties the projection weights of adaptive softmax and adaptive input') - parser.add_argument('--decoder-learned-pos', action='store_true', - help='use learned positional embeddings in the decoder') - 
parser.add_argument('--layernorm-embedding', action='store_true', - help='add layernorm to embedding') - parser.add_argument('--no-scale-embedding', action='store_true', - help='if True, dont scale embeddings') - # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) - parser.add_argument('--decoder-layerdrop', type=float, metavar='D', - help='LayerDrop probability for decoder') - parser.add_argument('--decoder-layers-to-keep', - help='which layers to *keep* when pruning as a comma-separated list') - # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) - parser.add_argument('--quant-noise-pq', type=float, metavar='D', - help='iterative PQ quantization noise at training time') - parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', - help='block size of quantization noise at training time') - parser.add_argument('--quant-noise-scalar', type=float, metavar='D', - help='scalar quantization noise and scalar quantization at training time') - # fmt: on - @classmethod def build_model(cls, args, task): """Build a new model instance.""" diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 773492775a..e2c3a3ceff 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -5,40 +5,47 @@ import importlib import os +from argparse import Namespace +from typing import Union from fairseq import registry -from fairseq.optim.fairseq_optimizer import FairseqOptimizer, LegacyFairseqOptimizer # noqa -from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer from fairseq.optim.bmuf import FairseqBMUF # noqa +from fairseq.optim.fairseq_optimizer import ( # noqa + FairseqOptimizer, + LegacyFairseqOptimizer, +) +from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer from fairseq.optim.shard import shard_ +from omegaconf import DictConfig -OPTIMIZER_DATACLASS_REGISTRY = {} - __all__ = [ - 
'FairseqOptimizer', - 'FP16Optimizer', - 'MemoryEfficientFP16Optimizer', - 'shard_', + "FairseqOptimizer", + "FP16Optimizer", + "MemoryEfficientFP16Optimizer", + "shard_", ] -_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( - '--optimizer', - base_class=FairseqOptimizer, - required=True, -) +( + _build_optimizer, + register_optimizer, + OPTIMIZER_REGISTRY, + OPTIMIZER_DATACLASS_REGISTRY, +) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True) -def build_optimizer(args, params, *extra_args, **extra_kwargs): +def build_optimizer( + optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs +): if all(isinstance(p, dict) for p in params): params = [t for p in params for t in p.values()] params = list(filter(lambda p: p.requires_grad, params)) - return _build_optimizer(args, params, *extra_args, **extra_kwargs) + return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) # automatically import any Python files in the optim/ directory for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('fairseq.optim.' + module) + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim." 
+ file_name) diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index 81f1d15bd1..f678a9f56c 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -11,7 +11,7 @@ import torch import torch.distributed as dist import torch.optim -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from fairseq.optim.fused_adam import get_fused_adam_class from omegaconf import II @@ -37,7 +37,7 @@ class FairseqAdamConfig(FairseqDataclass): lr: List[float] = II("params.optimization.lr") -@register_optimizer("adam") +@register_optimizer("adam", dataclass=FairseqAdamConfig) class FairseqAdam(FairseqOptimizer): """Adam optimizer for fairseq. @@ -64,11 +64,6 @@ def __init__(self, args, params): else: self._optimizer = Adam(params, **self.optimizer_config) - @staticmethod - def add_args(parser): - """Add optimizer-specific arguments to the parser.""" - gen_parser_from_dataclass(parser, FairseqAdamConfig()) - @property def optimizer_config(self): """ diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py index 5d98aa2f84..3312f81103 100644 --- a/fairseq/optim/bmuf.py +++ b/fairseq/optim/bmuf.py @@ -7,7 +7,8 @@ import torch import torch.distributed as dist -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.fairseq_optimizer import FairseqOptimizer from omegaconf import II diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 07ac45f60d..b602e51818 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -4,37 +4,38 @@ # LICENSE file in the root directory of this source tree. 
import torch - from fairseq import utils +from fairseq.dataclass.utils import gen_parser_from_dataclass class FairseqOptimizer(object): - def __init__(self, args): super().__init__() self.args = args - @staticmethod - def add_args(parser): + @classmethod + def add_args(cls, parser): """Add optimizer-specific arguments to the parser.""" - pass + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) @property def optimizer(self): """Return a torch.optim.optimizer.Optimizer instance.""" - if not hasattr(self, '_optimizer'): + if not hasattr(self, "_optimizer"): raise NotImplementedError if not isinstance(self._optimizer, torch.optim.Optimizer): - raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") return self._optimizer @optimizer.setter def optimizer(self, optimizer): """Reset optimizer instance.""" - if not hasattr(self, '_optimizer'): + if not hasattr(self, "_optimizer"): raise NotImplementedError if not isinstance(self._optimizer, torch.optim.Optimizer): - raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") self._optimizer = optimizer @property @@ -51,7 +52,7 @@ def optimizer_config(self): def params(self): """Return an iterable of the parameters held by the optimizer.""" for param_group in self.param_groups: - for p in param_group['params']: + for p in param_group["params"]: yield p @property @@ -63,12 +64,12 @@ def __getstate__(self): def get_lr(self): """Return the current learning rate.""" - return self.param_groups[0]['lr'] + return self.param_groups[0]["lr"] def set_lr(self, lr): """Set the learning rate.""" for param_group in self.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr def state_dict(self): """Return the optimizer's state dict.""" @@ -103,7 +104,7 @@ def clip_grad_norm(self, max_norm, 
aggregate_norm_fn=None): """Clips gradient norm.""" return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) - def step(self, closure=None, scale=1.): + def step(self, closure=None, scale=1.0): """Performs a single optimization step.""" if self.supports_step_with_scale: self.optimizer.step(closure, scale=scale) @@ -120,13 +121,13 @@ def zero_grad(self): @property def supports_memory_efficient_fp16(self): - if hasattr(self.optimizer, 'supports_memory_efficient_fp16'): + if hasattr(self.optimizer, "supports_memory_efficient_fp16"): return self.optimizer.supports_memory_efficient_fp16 return False @property def supports_step_with_scale(self): - if hasattr(self.optimizer, 'supports_step_with_scale'): + if hasattr(self.optimizer, "supports_step_with_scale"): return self.optimizer.supports_step_with_scale return False @@ -136,7 +137,7 @@ def supports_flat_params(self): Whether the optimizer supports collapsing of the model parameters/gradients into a single contiguous Tensor. """ - if hasattr(self.optimizer, 'supports_flat_params'): + if hasattr(self.optimizer, "supports_flat_params"): return self.optimizer.supports_flat_params return False @@ -145,6 +146,5 @@ def average_params(self): class LegacyFairseqOptimizer(FairseqOptimizer): - def __init__(self, args): self.args = args diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index fe84bc6004..85773aab39 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -5,21 +5,33 @@ import importlib import os +from argparse import Namespace +from typing import Union from fairseq import registry -from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler, LegacyFairseqLRScheduler # noqa - +from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa + FairseqLRScheduler, + LegacyFairseqLRScheduler, +) +from omegaconf import DictConfig -LR_SCHEDULER_DATACLASS_REGISTRY = {} -build_lr_scheduler, 
register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( - '--lr-scheduler', - base_class=FairseqLRScheduler, - default='fixed', +( + build_lr_scheduler_, + register_lr_scheduler, + LR_SCHEDULER_REGISTRY, + LR_SCHEDULER_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed" ) + +def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer): + return build_lr_scheduler_(lr_scheduler_cfg, optimizer) + + # automatically import any Python files in the optim/lr_scheduler/ directory for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('fairseq.optim.lr_scheduler.' + module) + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim.lr_scheduler." + file_name) diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index bd133ef091..98d557504f 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from typing import List -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass from omegaconf import II from . import FairseqLRScheduler, register_lr_scheduler @@ -42,7 +42,7 @@ class CosineConfig(FairseqDataclass): max_update: int = II("params.optimization.max_update") -@register_lr_scheduler("cosine") +@register_lr_scheduler("cosine", dataclass=CosineConfig) class CosineSchedule(FairseqLRScheduler): """Assign LR based on a cyclical schedule that follows the cosine function. 
@@ -105,11 +105,6 @@ def __init__(self, args, optimizer): self.lr = args.warmup_init_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - gen_parser_from_dataclass(parser, CosineConfig()) - def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" super().step(epoch, val_loss) diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index 5569de3db8..8fde0713aa 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -3,32 +3,36 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .. import FairseqOptimizer from argparse import Namespace +from fairseq.dataclass.utils import gen_parser_from_dataclass -class FairseqLRScheduler(object): +from .. import FairseqOptimizer + +class FairseqLRScheduler(object): def __init__(self, args, optimizer): super().__init__() if not isinstance(optimizer, FairseqOptimizer): - raise ValueError('optimizer must be an instance of FairseqOptimizer') + raise ValueError("optimizer must be an instance of FairseqOptimizer") self.args = args self.optimizer = optimizer self.best = None - @staticmethod - def add_args(parser): + @classmethod + def add_args(cls, parser): """Add arguments to the parser for this LR scheduler.""" - pass + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) def state_dict(self): """Return the LR scheduler state dict.""" - return {'best': self.best} + return {"best": self.best} def load_state_dict(self, state_dict): """Load an LR scheduler state dict.""" - self.best = state_dict['best'] + self.best = state_dict["best"] def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" @@ -44,10 +48,9 @@ def 
step_update(self, num_updates): class LegacyFairseqLRScheduler(FairseqLRScheduler): - def __init__(self, args: Namespace, optimizer): if not isinstance(optimizer, FairseqOptimizer): - raise ValueError('optimizer must be an instance of FairseqOptimizer') + raise ValueError("optimizer must be an instance of FairseqOptimizer") self.args = args self.optimizer = optimizer self.best = None diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index 388ac216bc..d27261ad48 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import List -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass from omegaconf import II from . import FairseqLRScheduler, register_lr_scheduler @@ -28,7 +28,7 @@ class InverseSquareRootScheduleConfig(FairseqDataclass): lr: List[float] = II("params.optimization.lr") -@register_lr_scheduler("inverse_sqrt") +@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig) class InverseSquareRootSchedule(FairseqLRScheduler): """Decay the LR based on the inverse square root of the update number. 
@@ -69,11 +69,6 @@ def __init__(self, args, optimizer): self.lr = args.warmup_init_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - gen_parser_from_dataclass(parser, InverseSquareRootScheduleConfig()) - def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" super().step(epoch, val_loss) diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 1050071f51..58d2f3560f 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -7,7 +7,7 @@ from typing import List import torch -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass from omegaconf import II from torch.optim.optimizer import Optimizer, required @@ -22,17 +22,12 @@ class FairseqNAGConfig(FairseqDataclass): lr: List[float] = II("params.optimization.lr") -@register_optimizer("nag") +@register_optimizer("nag", dataclass=FairseqNAGConfig) class FairseqNAG(FairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = NAG(params, **self.optimizer_config) - @staticmethod - def add_args(parser): - """Add optimizer-specific arguments to the parser.""" - gen_parser_from_dataclass(parser, FairseqNAGConfig()) - @property def optimizer_config(self): """ diff --git a/fairseq/options.py b/fairseq/options.py index beee239b96..fd7c12fbd7 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -4,31 +4,24 @@ # LICENSE file in the root directory of this source tree. 
import argparse -import os -import sys from typing import Callable, List, Optional import torch -# this import is for backward compatibility -from fairseq.file_io import PathManager -from fairseq.utils import csv_str_list, eval_str_list, eval_str_dict, eval_bool # noqa from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass.data_class import ( + CheckpointParams, + CommonEvalParams, + CommonParams, + DatasetParams, + DistributedTrainingParams, + EvalLMParams, + OptimizationParams, +) +from fairseq.dataclass.utils import gen_parser_from_dataclass - -class FileContentsAction(argparse.Action): - def __init__(self, option_strings, dest, nargs=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(FileContentsAction, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - if PathManager.isfile(values): - with PathManager.open(values) as f: - argument = f.read().strip() - else: - argument = values - setattr(namespace, self.dest, argument) +# this import is for backward compatibility +from fairseq.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list # noqa def get_preprocessing_parser(default_task="translation"): @@ -74,7 +67,7 @@ def get_validation_parser(default_task=None): add_dataset_args(parser, train=True) add_distributed_training_args(parser, default_world_size=1) group = parser.add_argument_group("Evaluation") - add_common_eval_args(group) + gen_parser_from_dataclass(group, CommonEvalParams()) return parser @@ -112,7 +105,7 @@ def parse_args_and_arch( **{k: v for k, v in vars(args).items() if v is not None} ) - from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY # Before creating the true parser, we need to import optional user module # in order to eagerly import custom tasks, optimizers, 
architectures, etc. @@ -138,7 +131,12 @@ def parse_args_and_arch( # arguments or which have default values. argument_default=argparse.SUPPRESS, ) - ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) + if args.arch in ARCH_MODEL_REGISTRY: + ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) + elif args.arch in MODEL_REGISTRY: + MODEL_REGISTRY[args.arch].add_args(model_specific_group) + else: + raise RuntimeError() # Add *-specific args to parser. from fairseq.registry import REGISTRIES @@ -169,7 +167,6 @@ def parse_args_and_arch( else: args = parser.parse_args(input_args) extra = None - # Post-process args. if hasattr(args, "max_sentences_valid") and args.max_sentences_valid is None: args.max_sentences_valid = args.max_sentences @@ -193,7 +190,7 @@ def parse_args_and_arch( args.no_seed_provided = False # Apply architecture configuration. - if hasattr(args, "arch"): + if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY: ARCH_CONFIG_REGISTRY[args.arch](args) if parse_known: @@ -211,65 +208,27 @@ def get_parser(desc, default_task="translation"): utils.import_user_module(usr_args) parser = argparse.ArgumentParser(allow_abbrev=False) - # fmt: off - parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') - parser.add_argument('--log-interval', type=int, default=100, metavar='N', - help='log progress every N batches (when progress bar is disabled)') - parser.add_argument('--log-format', default=None, help='log format to use', - choices=['json', 'none', 'simple', 'tqdm']) - parser.add_argument('--tensorboard-logdir', metavar='DIR', default='', - help='path to save logs for tensorboard, should match --logdir ' - 'of running tensorboard (default: no tensorboard logging)') - parser.add_argument('--seed', default=None, type=int, metavar='N', - help='pseudo random number generator seed') - parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA') - parser.add_argument('--tpu', 
action='store_true', help='use TPU instead of CUDA') - parser.add_argument('--bf16', action='store_true', help='use bfloat16; implies --tpu') - parser.add_argument('--fp16', action='store_true', help='use FP16') - parser.add_argument('--memory-efficient-bf16', action='store_true', - help='use a memory-efficient version of BF16 training; implies --bf16') - parser.add_argument('--memory-efficient-fp16', action='store_true', - help='use a memory-efficient version of FP16 training; implies --fp16') - parser.add_argument('--fp16-no-flatten-grads', action='store_true', - help='don\'t flatten FP16 grads tensor') - parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int, - help='default FP16 loss scale') - parser.add_argument('--fp16-scale-window', type=int, - help='number of updates before increasing loss scale') - parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float, - help='pct of updates that can overflow before decreasing the loss scale') - parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D', - help='minimum FP16 loss scale, after which training is stopped') - parser.add_argument('--threshold-loss-scale', type=float, - help='threshold FP16 loss scale from below') - parser.add_argument('--user-dir', default=None, - help='path to a python module containing custom extensions (tasks and/or architectures)') - parser.add_argument('--empty-cache-freq', default=0, type=int, - help='how often to clear the PyTorch CUDA cache (0 to disable)') - parser.add_argument('--all-gather-list-size', default=16384, type=int, - help='number of bytes reserved for gathering stats from workers') - parser.add_argument('--model-parallel-size', type=int, metavar='N', - default=1, - help='total number of GPUs to parallelize model over') - parser.add_argument('--checkpoint-suffix', default='', - help='suffix to add to the checkpoint file name') - parser.add_argument('--quantization-config-path', default=None, - help='path to quantization config 
file') - parser.add_argument('--profile', action='store_true', help='enable autograd profiler emit_nvtx') + gen_parser_from_dataclass(parser, CommonParams()) from fairseq.registry import REGISTRIES + for registry_name, REGISTRY in REGISTRIES.items(): parser.add_argument( - '--' + registry_name.replace('_', '-'), - default=REGISTRY['default'], - choices=REGISTRY['registry'].keys(), + "--" + registry_name.replace("_", "-"), + default=REGISTRY["default"], + choices=REGISTRY["registry"].keys(), ) # Task definitions can be found under fairseq/tasks/ from fairseq.tasks import TASK_REGISTRY - parser.add_argument('--task', metavar='TASK', default=default_task, - choices=TASK_REGISTRY.keys(), - help='task') + + parser.add_argument( + "--task", + metavar="TASK", + default=default_task, + choices=TASK_REGISTRY.keys(), + help="task", + ) # fmt: on return parser @@ -321,261 +280,46 @@ def add_preprocess_args(parser): def add_dataset_args(parser, train=False, gen=False): - group = parser.add_argument_group("Dataset and data loading") - # fmt: off - group.add_argument('--num-workers', default=1, type=int, metavar='N', - help='how many subprocesses to use for data loading') - group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', - help='ignore too long or too short lines in valid and test set') - group.add_argument('--max-tokens', type=int, metavar='N', - help='maximum number of tokens in a batch') - group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N', - help='maximum number of sentences in a batch') - group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N', - help='batch size will either be less than this value, ' - 'or a multiple of this value') - group.add_argument('--required-seq-len-multiple', default=1, type=int, metavar='N', - help='maximum sequence length in batch will be a multiplier of this value') - parser.add_argument('--dataset-impl', metavar='FORMAT', - 
choices=get_available_dataset_impl(), - help='output dataset implementation') - group.add_argument('--data-buffer-size', default=10, type=int, metavar='N', - help='number of batches to preload') - if train: - group.add_argument('--train-subset', default='train', metavar='SPLIT', - help='data subset to use for training (e.g. train, valid, test)') - group.add_argument('--valid-subset', default='valid', metavar='SPLIT', - help='comma separated list of data subsets to use for validation' - ' (e.g. train, valid, test)') - group.add_argument('--validate-interval', type=int, default=1, metavar='N', - help='validate every N epochs') - group.add_argument('--validate-interval-updates', type=int, default=0, metavar='N', - help='validate every N updates') - group.add_argument('--validate-after-updates', type=int, default=0, metavar='N', - help='dont validate until reaching this many updates') - group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N', - help='specified random seed for validation') - group.add_argument('--disable-validation', action='store_true', - help='disable validation') - group.add_argument('--max-tokens-valid', type=int, metavar='N', - help='maximum number of tokens in a validation batch' - ' (defaults to --max-tokens)') - group.add_argument('--max-sentences-valid', type=int, metavar='N', - help='maximum number of sentences in a validation batch' - ' (defaults to --max-sentences)') - group.add_argument('--curriculum', default=0, type=int, metavar='N', - help='don\'t shuffle batches for first N epochs') - if gen: - group.add_argument('--gen-subset', default='test', metavar='SPLIT', - help='data subset to generate (train, valid, test)') - group.add_argument('--num-shards', default=1, type=int, metavar='N', - help='shard generation over N shards') - group.add_argument('--shard-id', default=0, type=int, metavar='ID', - help='id of the shard to generate (id < num_shards)') + group = parser.add_argument_group("dataset_data_loading") + 
gen_parser_from_dataclass(group, DatasetParams()) # fmt: on return group def add_distributed_training_args(parser, default_world_size=None): - group = parser.add_argument_group("Distributed training") - # fmt: off + group = parser.add_argument_group("distributed_training") if default_world_size is None: default_world_size = max(1, torch.cuda.device_count()) - group.add_argument('--distributed-world-size', type=int, metavar='N', - default=default_world_size, - help='total number of GPUs across all nodes (default: all visible GPUs)') - group.add_argument('--distributed-rank', default=0, type=int, - help='rank of the current worker') - group.add_argument('--distributed-backend', default='nccl', type=str, - help='distributed backend') - group.add_argument('--distributed-init-method', default=None, type=str, - help='typically tcp://hostname:port that will be used to ' - 'establish initial connetion') - group.add_argument('--distributed-port', default=-1, type=int, - help='port number (not required if using --distributed-init-method)') - group.add_argument('--device-id', '--local_rank', default=0, type=int, - help='which GPU to use (usually configured automatically)') - group.add_argument('--distributed-no-spawn', action='store_true', - help='do not spawn multiple processes even if multiple GPUs are visible') - group.add_argument('--distributed-num-procs', default=None, type=int, - help='number of processes to spawn (usually configured automatically)') - # "c10d" is PyTorch's DDP implementation and provides the fastest - # training. "no_c10d" is a more robust, but slightly slower DDP - # implementation. Try this if you get warning messages about - # inconsistent gradients between workers, or if some of your model - # parameters are not always used. 
- group.add_argument('--ddp-backend', default='c10d', type=str, - choices=['c10d', 'no_c10d'], - help='DistributedDataParallel backend') - group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB', - help='bucket size for reduction') - group.add_argument('--fix-batches-to-gpus', action='store_true', - help='don\'t shuffle batches between GPUs; this reduces overall ' - 'randomness and may affect precision but avoids the cost of ' - 're-reading the data') - group.add_argument('--find-unused-parameters', default=False, action='store_true', - help='disable unused parameter detection (not applicable to ' - 'no_c10d ddp-backend') - group.add_argument('--fast-stat-sync', default=False, action='store_true', - help='[deprecated] this is now defined per Criterion') - group.add_argument('--broadcast-buffers', default=False, action='store_true', - help='Copy non-trainable parameters between GPUs, such as ' - 'batchnorm population statistics') - - group.add_argument('--distributed-wrapper', default='DDP', type=str, - choices=['DDP', 'SlowMo'], - help='DistributedDataParallel backend') - # Add arguments for SlowMo - these will be used when SlowMo is enabled via above - group.add_argument('--slowmo-momentum', default=None, type=float, - help='SlowMo momentum term; by default use 0.0 for 16 GPUs, ' - '0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs') - group.add_argument('--slowmo-algorithm', default='LocalSGD', choices=['LocalSGD', 'SGP'], - help='whether to use LocalSGD or SGP') - group.add_argument('--localsgd-frequency', default=3, type=int, - help='Local SGD allreduce frequency') - group.add_argument('--nprocs-per-node', type=int, metavar='N', - default=max(1, torch.cuda.device_count()), - help='number of GPUs in each node. An allreduce operation across GPUs in ' - 'a node is very fast. 
Hence, we do allreduce across GPUs in a node, ' - 'and gossip across different nodes') - # Pipeline Parallel Arguments - group.add_argument('--pipeline-model-parallel', default=False, action='store_true', - help='if set, use pipeline model parallelism across GPUs') - group.add_argument('--pipeline-balance', metavar='N1,N2,...,N_K', - type=lambda x: eval_str_list(x, type=int), - help='partition the model into N_K pieces, where each piece ' - 'contains N_i layers. The sum(args.pipeline_balance) ' - 'should equal the total number of layers in the model') - group.add_argument('--pipeline-devices', metavar='N1,N2,...,N_K', - type=lambda x: eval_str_list(x, type=int), - help='a list of device indices indicating which device to place ' - 'each of the N_K partitions. The length of this list should ' - 'equal the length of the --pipeline-balance argument') - group.add_argument('--pipeline-chunks', type=int, metavar='N', - help='microbatch count for pipeline model parallelism') - group.add_argument('--pipeline-checkpoint', type=str, metavar='STR', - choices=['always', 'never', 'except_last'], - default='never', - help='checkpointing mode for pipeline model parallelism') - # Add argument for ZeRO sharding of OptimizerState(os), gradients(g) and parameters(p) - group.add_argument('--zero-sharding', default='none', type=str, - choices=['none', 'os'], - help='ZeRO sharding') - # fmt: on + gen_parser_from_dataclass( + group, DistributedTrainingParams(distributed_world_size=default_world_size) + ) return group def add_optimization_args(parser): - group = parser.add_argument_group("Optimization") + group = parser.add_argument_group("optimization") # fmt: off - group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N', - help='force stop training at specified epoch') - group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N', - help='force stop training at specified update') - group.add_argument('--stop-time-hours', default=0, type=float, 
metavar='N', - help='force stop training after specified cumulative time (if >0)') - group.add_argument('--clip-norm', default=0.0, type=float, metavar='NORM', - help='clip threshold of gradients') - group.add_argument('--sentence-avg', action='store_true', - help='normalize gradients by the number of sentences in a batch' - ' (default is to normalize by number of tokens)') - group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K', - type=lambda uf: eval_str_list(uf, type=int), - help='update parameters every N_i batches, when in epoch i') - group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list, - metavar='LR_1,LR_2,...,LR_N', - help='learning rate for the first N epochs; all epochs >N using LR_N' - ' (note: this may be interpreted differently depending on --lr-scheduler)') - group.add_argument('--min-lr', default=-1, type=float, metavar='LR', - help='stop training when the learning rate reaches this minimum') - group.add_argument('--use-bmuf', default=False, action='store_true', - help='specify global optimizer for syncing models on different GPUs/shards') + gen_parser_from_dataclass(group, OptimizationParams()) # fmt: on return group def add_checkpoint_args(parser): - group = parser.add_argument_group("Checkpointing") + group = parser.add_argument_group("checkpoint") # fmt: off - group.add_argument('--save-dir', metavar='DIR', default='checkpoints', - help='path to save checkpoints') - group.add_argument('--restore-file', default='checkpoint_last.pt', - help='filename from which to load checkpoint ' - '(default: /checkpoint_last.pt') - group.add_argument('--finetune-from-model', default=None, type=str, - help='finetune from a pretrained model; ' - 'note that meters and lr scheduler will be reset') - group.add_argument('--reset-dataloader', action='store_true', - help='if set, does not reload dataloader state from the checkpoint') - group.add_argument('--reset-lr-scheduler', action='store_true', - help='if set, does not 
load lr scheduler state from the checkpoint') - group.add_argument('--reset-meters', action='store_true', - help='if set, does not load meters from the checkpoint') - group.add_argument('--reset-optimizer', action='store_true', - help='if set, does not load optimizer state from the checkpoint') - group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT', - help='a dictionary used to override optimizer args when loading a checkpoint') - group.add_argument('--save-interval', type=int, default=1, metavar='N', - help='save a checkpoint every N epochs') - group.add_argument('--save-interval-updates', type=int, default=0, metavar='N', - help='save a checkpoint (and validate) every N updates') - group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N', - help='keep the last N checkpoints saved with --save-interval-updates') - group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N', - help='keep last N epoch checkpoints') - group.add_argument('--keep-best-checkpoints', type=int, default=-1, metavar='N', - help='keep best N checkpoints based on scores') - group.add_argument('--no-save', action='store_true', - help='don\'t save models or checkpoints') - group.add_argument('--no-epoch-checkpoints', action='store_true', - help='only store last and best checkpoints') - group.add_argument('--no-last-checkpoints', action='store_true', - help='don\'t store last checkpoints') - group.add_argument('--no-save-optimizer-state', action='store_true', - help='don\'t save optimizer-state as part of checkpoint') - group.add_argument('--best-checkpoint-metric', type=str, default='loss', - help='metric to use for saving "best" checkpoints') - group.add_argument('--maximize-best-checkpoint-metric', action='store_true', - help='select the largest metric value for saving "best" checkpoints') - group.add_argument('--patience', type=int, default=-1, metavar='N', - help=('early stop training if valid performance doesn\'t ' - 
'improve for N consecutive validation runs; note ' - 'that this is influenced by --validate-interval')) + gen_parser_from_dataclass(group, CheckpointParams()) # fmt: on return group def add_common_eval_args(group): - # fmt: off - group.add_argument('--path', metavar='FILE', - help='path(s) to model file(s), colon separated') - group.add_argument('--remove-bpe', '--post-process', nargs='?', const='@@ ', default=None, - help='remove BPE tokens before scoring (can be set to sentencepiece)') - group.add_argument('--quiet', action='store_true', - help='only print final scores') - group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', - help='a dictionary used to override model args at generation ' - 'that were used during model training') - group.add_argument('--results-path', metavar='RESDIR', type=str, default=None, - help='path to save eval results (optional)"') - # fmt: on + gen_parser_from_dataclass(group, CommonEvalParams()) def add_eval_lm_args(parser): group = parser.add_argument_group("LM Evaluation") add_common_eval_args(group) - # fmt: off - group.add_argument('--output-word-probs', action='store_true', - help='if set, outputs words and their predicted log probabilities to standard output') - group.add_argument('--output-word-stats', action='store_true', - help='if set, outputs word statistics such as word count, average probability, etc') - group.add_argument('--context-window', default=0, type=int, metavar='N', - help='ensures that every evaluated token has access to a context of at least this size,' - ' if possible') - group.add_argument('--softmax-batch', default=sys.maxsize, type=int, metavar='N', - help='if BxT is more than this, will batch the softmax over vocab to this amount of tokens' - ' in order to fit into GPU memory') - # fmt: on + gen_parser_from_dataclass(group, EvalLMParams()) def add_generation_args(parser): diff --git a/fairseq/registry.py b/fairseq/registry.py index 3859872420..382dec22a8 100644 --- 
a/fairseq/registry.py +++ b/fairseq/registry.py @@ -4,72 +4,95 @@ # LICENSE file in the root directory of this source tree. import argparse +from argparse import Namespace +from typing import Union + +from fairseq.dataclass import FairseqDataclass +from omegaconf import DictConfig REGISTRIES = {} -def setup_registry( - registry_name: str, - base_class=None, - default=None, - required=False, -): - assert registry_name.startswith('--') - registry_name = registry_name[2:].replace('-', '_') +def setup_registry(registry_name: str, base_class=None, default=None, required=False): + assert registry_name.startswith("--") + registry_name = registry_name[2:].replace("-", "_") REGISTRY = {} REGISTRY_CLASS_NAMES = set() + DATACLASS_REGISTRY = {} # maintain a registry of all registries if registry_name in REGISTRIES: return # registry already exists - REGISTRIES[registry_name] = { - 'registry': REGISTRY, - 'default': default, - } + REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default} + + def build_x(args: Union[DictConfig, Namespace], *extra_args, **extra_kwargs): + if isinstance(args, DictConfig): + if getattr(args, "_name", None) is not None: + choice = args._name + elif hasattr(args, registry_name): + choice = getattr(args, registry_name) + else: + raise RuntimeError( + f"Neither _name nor {registry_name} in args, args = {args}" + ) + else: + choice = getattr(args, registry_name, None) - def build_x(args, *extra_args, **extra_kwargs): - choice = getattr(args, registry_name, None) if choice is None: if required: - raise ValueError('--{} is required!'.format(registry_name)) + raise ValueError("--{} is required!".format(registry_name)) return None cls = REGISTRY[choice] - if hasattr(cls, 'build_' + registry_name): - builder = getattr(cls, 'build_' + registry_name) + if hasattr(cls, "build_" + registry_name): + builder = getattr(cls, "build_" + registry_name) else: builder = cls - set_defaults(args, cls) + if isinstance(args, Namespace): + set_defaults(args, cls) return 
builder(args, *extra_args, **extra_kwargs) - def register_x(name): - + def register_x(name, dataclass=None): def register_x_cls(cls): if name in REGISTRY: - raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name)) + raise ValueError( + "Cannot register duplicate {} ({})".format(registry_name, name) + ) if cls.__name__ in REGISTRY_CLASS_NAMES: raise ValueError( - 'Cannot register {} with duplicate class name ({})'.format( - registry_name, cls.__name__, + "Cannot register {} with duplicate class name ({})".format( + registry_name, cls.__name__ ) ) if base_class is not None and not issubclass(cls, base_class): - raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__)) + raise ValueError( + "{} must extend {}".format(cls.__name__, base_class.__name__) + ) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass REGISTRY[name] = cls + DATACLASS_REGISTRY[name] = cls.__dataclass REGISTRY_CLASS_NAMES.add(cls.__name__) return cls return register_x_cls - return build_x, register_x, REGISTRY + return build_x, register_x, REGISTRY, DATACLASS_REGISTRY -def set_defaults(args, cls): +def set_defaults(args: Namespace, cls): """Helper to set default arguments based on *add_args*.""" - if not hasattr(cls, 'add_args'): + if not hasattr(cls, "add_args"): return - parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False) + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) cls.add_args(parser) # copied from argparse.py: defaults = argparse.Namespace() diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index c17aad368a..c86d6b6c23 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -10,7 +10,7 @@ from fairseq import registry -_build_scoring, register_scoring, SCORING_REGISTRY = 
registry.setup_registry( +_build_scoring, register_scoring, SCORING_REGISTRY, _ = registry.setup_registry( "--scoring", default="bleu" ) diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 0b8df065a9..eda2ed34b7 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -6,21 +6,28 @@ import argparse import importlib import os +from argparse import Namespace +from typing import Union -from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa +from fairseq.dataclass import FairseqDataclass +from omegaconf import DictConfig + +from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa # register dataclass +TASK_DATACLASS_REGISTRY = {} TASK_REGISTRY = {} TASK_CLASS_NAMES = set() -TASK_DATACLASS_REGISTRY = {} -def setup_task(args, **kwargs): - return TASK_REGISTRY[args.task].setup_task(args, **kwargs) +def setup_task(task_cfg: Union[DictConfig, Namespace], **kwargs): + if isinstance(task_cfg, DictConfig): + return TASK_REGISTRY[task_cfg._name].setup_task(task_cfg, **kwargs) + return TASK_REGISTRY[task_cfg.task].setup_task(task_cfg, **kwargs) -def register_task(name): +def register_task(name, dataclass=None): """ New tasks can be added to fairseq with the :func:`~fairseq.tasks.register_task` function decorator. @@ -36,21 +43,34 @@ class ClassificationTask(FairseqTask): All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` interface. 
- Please see the - Args: name (str): the name of the task """ def register_task_cls(cls): if name in TASK_REGISTRY: - raise ValueError('Cannot register duplicate task ({})'.format(name)) + raise ValueError("Cannot register duplicate task ({})".format(name)) if not issubclass(cls, FairseqTask): - raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__)) + raise ValueError( + "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__) + ) if cls.__name__ in TASK_CLASS_NAMES: - raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__)) + raise ValueError( + "Cannot register task with duplicate class name ({})".format( + cls.__name__ + ) + ) TASK_REGISTRY[name] = cls TASK_CLASS_NAMES.add(cls.__name__) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + TASK_DATACLASS_REGISTRY[name] = dataclass + return cls return register_task_cls @@ -65,21 +85,21 @@ def get_task(name): for file in os.listdir(tasks_dir): path = os.path.join(tasks_dir, file) if ( - not file.startswith('_') - and not file.startswith('.') - and (file.endswith('.py') or os.path.isdir(path)) + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) ): - task_name = file[:file.find('.py')] if file.endswith('.py') else file - importlib.import_module('fairseq.tasks.' + task_name) + task_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.tasks." 
+ task_name) # expose `task_parser` for sphinx if task_name in TASK_REGISTRY: parser = argparse.ArgumentParser(add_help=False) - group_task = parser.add_argument_group('Task name') + group_task = parser.add_argument_group("Task name") # fmt: off group_task.add_argument('--task', metavar=task_name, help='Enable this task with: ``--task=' + task_name + '``') # fmt: on - group_args = parser.add_argument_group('Additional command-line arguments') + group_args = parser.add_argument_group("Additional command-line arguments") TASK_REGISTRY[task_name].add_args(group_args) - globals()[task_name + '_parser'] = parser + globals()[task_name + "_parser"] = parser diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index d27c38d305..2aa6c8ff28 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -6,13 +6,13 @@ import logging import os import warnings - +from argparse import Namespace import torch - from fairseq import metrics, search, tokenizer, utils -from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary -from argparse import Namespace +from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators +from fairseq.dataclass.utils import gen_parser_from_dataclass + logger = logging.getLogger(__name__) @@ -23,10 +23,12 @@ class FairseqTask(object): Datasets, initializing the Model/Criterion and calculating the loss. 
""" - @staticmethod - def add_args(parser): + @classmethod + def add_args(cls, parser): """Add task-specific arguments to the parser.""" - pass + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) @staticmethod def logging_outputs_can_be_summed(criterion) -> bool: @@ -85,7 +87,7 @@ def setup_task(cls, args, **kwargs): return cls(args, **kwargs) def has_sharded_data(self, split): - return (os.pathsep in getattr(self.args, 'data', '')) + return os.pathsep in getattr(self.args, "data", "") def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split. @@ -114,11 +116,7 @@ def dataset(self, split): return self.datasets[split] def filter_indices_by_size( - self, - indices, - dataset, - max_positions=None, - ignore_invalid_inputs=False, + self, indices, dataset, max_positions=None, ignore_invalid_inputs=False ): """ Filter examples that are too large @@ -136,14 +134,18 @@ def filter_indices_by_size( indices, ignored = dataset.filter_indices_by_size(indices, max_positions) if len(ignored) > 0: if not ignore_invalid_inputs: - raise Exception(( - 'Size of sample #{} is invalid (={}) since max_positions={}, ' - 'skip this example with --skip-invalid-size-inputs-valid-test' - ).format(ignored[0], dataset.size(ignored[0]), max_positions)) - logger.warning(( - '{} samples have invalid sizes and will be skipped, ' - 'max_positions={}, first few sample ids={}' - ).format(len(ignored), max_positions, ignored[:10])) + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + logger.warning( + ( + "{} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) return indices def can_reuse_epoch_itr(self, dataset): @@ -151,7 +153,7 @@ def 
can_reuse_epoch_itr(self, dataset): # hasn't disabled it. We default to ``False`` here, although in practice # this will be ``True`` for most datasets that inherit from # ``FairseqDataset`` due to the base implementation there. - return getattr(dataset, 'can_reuse_epoch_itr_across_epochs', False) + return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False) def get_batch_iterator( self, @@ -204,12 +206,11 @@ def get_batch_iterator( ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split """ - can_reuse_epoch_itr = ( - not disable_iterator_cache - and self.can_reuse_epoch_itr(dataset) + can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr( + dataset ) if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: - logger.debug('reusing EpochBatchIterator for epoch {}'.format(epoch)) + logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch)) return self.dataset_to_epoch_iter[dataset] assert isinstance(dataset, FairseqDataset) @@ -265,8 +266,9 @@ def build_model(self, args): a :class:`~fairseq.models.BaseFairseqModel` instance """ from fairseq import models, quantization_utils + model = models.build_model(args, self) - if getattr(args, 'tpu', False): + if getattr(args, "tpu", False): model.prepare_for_tpu_() model = quantization_utils.quantize_model_scalar(model, args) return model @@ -287,8 +289,7 @@ def build_criterion(self, args): return criterions.build_criterion(args, self) def build_generator( - self, models, args, - seq_gen_cls=None, extra_gen_cls_kwargs=None + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None ): if getattr(args, "score_reference", False): from fairseq.sequence_scorer import SequenceScorer @@ -352,7 +353,9 @@ def build_generator( self.target_dictionary, diversity_rate ) elif constrained: - search_strategy = search.LexicallyConstrainedBeamSearch(self.target_dictionary, args.constraints) + search_strategy = search.LexicallyConstrainedBeamSearch( + 
self.target_dictionary, args.constraints + ) else: search_strategy = search.BeamSearch(self.target_dictionary) @@ -418,9 +421,13 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = criterion(model, sample) return loss, sample_size, logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): with torch.no_grad(): - return generator.generate(models, sample, prefix_tokens=prefix_tokens, constraints=constraints) + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, constraints=constraints + ) def begin_epoch(self, epoch, model): """Hook function called before the start of each epoch.""" @@ -494,7 +501,6 @@ def target_dictionary(self): class LegacyFairseqTask(FairseqTask): - def __init__(self, args: Namespace): self.args = args self.datasets = {} @@ -510,7 +516,7 @@ def setup_task(cls, args: Namespace, **kwargs): return cls(args, **kwargs) def has_sharded_data(self, split): - return (os.pathsep in getattr(self.args, 'data', '')) + return os.pathsep in getattr(self.args, "data", "") def build_model(self, args: Namespace): """ @@ -524,8 +530,9 @@ def build_model(self, args: Namespace): a :class:`~fairseq.models.BaseFairseqModel` instance """ from fairseq import models, quantization_utils + model = models.build_model(args, self) - if getattr(args, 'tpu', False): + if getattr(args, "tpu", False): model.prepare_for_tpu_() model = quantization_utils.quantize_model_scalar(model, args) return model diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 190bc27cf2..5477c28aa9 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -27,11 +27,7 @@ ) from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.dataclass.utils 
import ( - ChoiceEnum, - FairseqDataclass, - gen_parser_from_dataclass, -) +from fairseq.dataclass import FairseqDataclass, ChoiceEnum from fairseq.tasks import FairseqTask, register_task from omegaconf import II @@ -97,7 +93,7 @@ class LanguageModelingConfig(FairseqDataclass): tpu: bool = II("params.common.tpu") -@register_task("language_modeling") +@register_task("language_modeling", dataclass=LanguageModelingConfig) class LanguageModelingTask(FairseqTask): """ Train a language model. @@ -127,11 +123,6 @@ class LanguageModelingTask(FairseqTask): :prog: """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser. optionaly register config store""" - gen_parser_from_dataclass(parser, LanguageModelingConfig()) - def __init__(self, args, dictionary, output_dictionary=None, targets=None): super().__init__(args) self.dictionary = dictionary diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 10aaeaa12c..d9c0fa985b 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -19,7 +19,7 @@ from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.multilingual.sampling_method import SamplingMethod from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager -from fairseq.options import FileContentsAction +from fairseq.utils import FileContentsAction ### def get_time_gap(s, e): diff --git a/fairseq/utils.py b/fairseq/utils.py index af0c587583..4de258d9a2 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -3,29 +3,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import argparse import contextlib import copy import importlib.util import logging -import math import os import sys import warnings -from collections import defaultdict from itertools import accumulate from typing import Callable, Dict, List, Optional -import numpy as np import torch import torch.nn.functional as F from fairseq.data import iterators +from fairseq.file_io import PathManager from fairseq.logging.meters import safe_round from fairseq.modules import gelu, gelu_accurate from fairseq.modules.multihead_attention import MultiheadAttention from torch import Tensor + try: from amp_C import multi_tensor_l2norm + multi_tensor_l2norm_available = True except ImportError: multi_tensor_l2norm_available = False @@ -37,8 +38,27 @@ MANIFOLD_PATH_SEP = "|" +class FileContentsAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(FileContentsAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + if PathManager.isfile(values): + with PathManager.open(values) as f: + argument = f.read().strip() + else: + argument = values + setattr(namespace, self.dest, argument) + + def split_paths(paths: str) -> List[str]: - return paths.split(os.pathsep) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP) + return ( + paths.split(os.pathsep) + if "://" not in paths + else paths.split(MANIFOLD_PATH_SEP) + ) def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): @@ -54,7 +74,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): def apply_to_sample(f, sample): - if hasattr(sample, '__len__') and len(sample) == 0: + if hasattr(sample, "__len__") and len(sample) == 0: return {} def _apply(x): @@ -189,9 +209,17 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk): def post_process_prediction( - hypo_tokens, src_str, alignment, align_dict, 
tgt_dict, remove_bpe=None, extra_symbols_to_ignore=None + hypo_tokens, + src_str, + alignment, + align_dict, + tgt_dict, + remove_bpe=None, + extra_symbols_to_ignore=None, ): - hypo_str = tgt_dict.string(hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore) + hypo_str = tgt_dict.string( + hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore + ) if align_dict is not None: hypo_str = replace_unk( hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string() @@ -264,7 +292,7 @@ def item(tensor): return tensor -def multi_tensor_total_norm(grads, chunk_size=2048*32) -> torch.Tensor: +def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: per_device_grads = {} norms = [] for grad in grads: @@ -280,7 +308,9 @@ def multi_tensor_total_norm(grads, chunk_size=2048*32) -> torch.Tensor: # TODO(msb) return has_inf has_inf = torch.zeros((1, 1), dtype=torch.int, device=device) with torch.cuda.device(device): - norm = multi_tensor_l2norm(chunk_size, has_inf, [cur_device_grads], False) + norm = multi_tensor_l2norm( + chunk_size, has_inf, [cur_device_grads], False + ) norms.append(norm[0].to(torch.cuda.current_device())) else: norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads] @@ -296,9 +326,9 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)] if len(grads) == 0: if len(params) > 0: - return params[0].new_tensor(0.) + return params[0].new_tensor(0.0) else: - return torch.tensor(0.) 
+ return torch.tensor(0.0) if len(grads) == 1: total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) @@ -312,12 +342,14 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: "you may get better performance by installing NVIDIA's apex library" ) device = torch.cuda.current_device() - elif grads[0].device.type == 'xla': + elif grads[0].device.type == "xla": device = grads[0].device else: - device = torch.device('cpu') + device = torch.device("cpu") total_norm = torch.norm( - torch.stack([torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]) + torch.stack( + [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] + ) ) if aggregate_norm_fn is not None: @@ -414,8 +446,8 @@ def import_user_module(args): module_bak = None sys.path.insert(0, module_parent) importlib.import_module(module_name) - sys.modules['fairseq_user_dir'] = sys.modules[module_name] - if module_bak is not None and module_name != 'fairseq_user_dir': + sys.modules["fairseq_user_dir"] = sys.modules[module_name] + if module_bak is not None and module_name != "fairseq_user_dir": sys.modules[module_name] = module_bak @@ -435,11 +467,11 @@ def log_softmax(x, dim: int, onnx_trace: bool = False): def get_perplexity(loss, round=2, base=2): if loss is None: - return 0. 
+ return 0.0 try: return safe_round(base ** loss, round) except OverflowError: - return float('inf') + return float("inf") def deprecation_warning(message, stacklevel=3): @@ -544,8 +576,12 @@ def get_token_to_word_mapping(tokens, exclude_list): def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): - tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1) - src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1) + tgt_valid = ( + ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + src_invalid = ( + ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) alignment = [] @@ -575,6 +611,7 @@ def new_arange(x, *size): def get_tpu_device(args): import torch_xla.core.xla_model as xm + return xm.xla_device() @@ -622,7 +659,7 @@ def pretty_print_cuda_env_list(cuda_env_list): def csv_str_list(x): - return x.split(',') + return x.split(",") def eval_str_list(x, type=float): From e3c4282551e819853952284681e9ed60398c5c4a Mon Sep 17 00:00:00 2001 From: alexeib Date: Mon, 5 Oct 2020 19:07:38 -0700 Subject: [PATCH 187/707] remove max_sentences from args, use batch_size instead (#1333) Summary: now that we are moving to using dataclasses to define fairseq configuration, having aliases for options is no longer practical. 
this pr removes "max-sentences" argument while keeping its alias "batch-size", which is more appropriate Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1333 Reviewed By: shruti-bh Differential Revision: D24121305 Pulled By: alexeib fbshipit-source-id: 34343cea54c8f2c8b059c38ef9f29b66e76df9fb --- config/params/eval_lm_params.yaml | 5 ++--- config/params/training_params.yaml | 5 ++--- examples/bart/README.glue.md | 6 +++--- examples/byte_level_bpe/README.md | 2 +- examples/language_model/README.md | 2 +- examples/mbart/README.md | 2 +- examples/megatron_11b/README.md | 4 ++-- examples/multilingual/README.md | 2 +- examples/multilingual/multilingual_fairseq_gen.sh | 2 +- examples/noisychannel/rerank_generate.py | 2 +- examples/quant_noise/README.md | 6 +++--- examples/roberta/README.custom_classification.md | 6 +++--- examples/roberta/README.glue.md | 8 ++++---- examples/roberta/README.pretraining.md | 2 +- examples/roberta/README.race.md | 6 +++--- examples/roberta/commonsense_qa/README.md | 4 ++-- examples/roberta/wsc/README.md | 4 ++-- examples/simultaneous_translation/docs/baseline.md | 4 ++-- examples/speech_recognition/infer.py | 4 ++-- fairseq/benchmark/dummy_lm.py | 4 ++-- fairseq/benchmark/dummy_masked_lm.py | 4 ++-- fairseq/benchmark/dummy_mt.py | 4 ++-- fairseq/dataclass/data_class.py | 11 ++++------- fairseq/hub_utils.py | 2 +- fairseq/models/fairseq_model.py | 8 ++++++++ fairseq/options.py | 4 ++-- fairseq/trainer.py | 4 ++-- fairseq_cli/eval_lm.py | 2 +- fairseq_cli/generate.py | 4 ++-- fairseq_cli/interactive.py | 10 +++++----- fairseq_cli/train.py | 6 +++--- fairseq_cli/validate.py | 6 +++--- tests/test_binaries.py | 4 ++-- tests/utils.py | 2 +- 34 files changed, 77 insertions(+), 74 deletions(-) diff --git a/config/params/eval_lm_params.yaml b/config/params/eval_lm_params.yaml index 4a0259bca6..6f27055d64 100644 --- a/config/params/eval_lm_params.yaml +++ b/config/params/eval_lm_params.yaml @@ -42,8 +42,7 @@ dataset: 
num_workers: 1 skip_invalid_size_inputs_valid_test: false max_tokens: null - max_sentences: null - batch_size: ${params.dataset.max_sentences} + batch_size: ${params.dataset.batch_size} required_batch_size_multiple: 8 dataset_impl: null data_buffer_size: 10 @@ -57,7 +56,7 @@ dataset: num_shards: 1 shard_id: 0 max_tokens_valid: ${params.dataset.max_tokens} - max_sentences_valid: ${params.dataset.max_sentences} + batch_size_valid: ${params.dataset.batch_size} optimization: max_epoch: 0 max_update: 0 diff --git a/config/params/training_params.yaml b/config/params/training_params.yaml index 3d52a82ac4..2ce94f9290 100644 --- a/config/params/training_params.yaml +++ b/config/params/training_params.yaml @@ -42,8 +42,7 @@ dataset: num_workers: 1 skip_invalid_size_inputs_valid_test: false max_tokens: null - max_sentences: null - batch_size: ${params.dataset.max_sentences} + batch_size: ${params.dataset.batch_size} required_batch_size_multiple: 8 dataset_impl: null data_buffer_size: 10 @@ -57,7 +56,7 @@ dataset: num_shards: 1 shard_id: 0 max_tokens_valid: ${params.dataset.max_tokens} - max_sentences_valid: ${params.dataset.max_sentences} + batch_size_valid: ${params.dataset.batch_size} optimization: max_epoch: 0 max_update: 0 diff --git a/examples/bart/README.glue.md b/examples/bart/README.glue.md index 2948ff25ea..a010934e1e 100644 --- a/examples/bart/README.glue.md +++ b/examples/bart/README.glue.md @@ -26,7 +26,7 @@ BART_PATH=/path/to/bart/model.pt CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \ --restore-file $BART_PATH \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-tokens 4400 \ --task sentence_prediction \ --add-prev-output-tokens \ @@ -63,9 +63,9 @@ For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` **Note:** -a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=32/64/128` depending on the task. 
+a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task. -b) Above cmd-args and hyperparams are tested on Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`. +b) Above cmd-args and hyperparams are tested on Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--batch-size`. ### Inference on GLUE task After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet: diff --git a/examples/byte_level_bpe/README.md b/examples/byte_level_bpe/README.md index d8c4cb6747..657092660e 100644 --- a/examples/byte_level_bpe/README.md +++ b/examples/byte_level_bpe/README.md @@ -29,7 +29,7 @@ fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_le --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \ - --max-sentences 100 --max-update 100000 --update-freq 2 + --batch-size 100 --max-update 100000 --update-freq 2 ``` ## Generation diff --git a/examples/language_model/README.md b/examples/language_model/README.md index 3d5c3862bb..dc84d8c761 100644 --- a/examples/language_model/README.md +++ b/examples/language_model/README.md @@ -99,7 +99,7 @@ number of GPUs. 
```bash fairseq-eval-lm data-bin/wikitext-103 \ --path checkpoints/transformer_wiki103/checkpoint_best.pt \ - --max-sentences 2 \ + --batch-size 2 \ --tokens-per-sample 512 \ --context-window 400 # | Evaluated 245569 tokens in 56.1s (4379.02 tokens/s) diff --git a/examples/mbart/README.md b/examples/mbart/README.md index f22d43dba4..510edeff64 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -102,7 +102,7 @@ fairseq-generate path_2_data \ -t ro_RO -s en_XX \ --bpe 'sentencepiece' --sentencepiece-model $model_dir/sentence.bpe.model \ --sacrebleu --remove-bpe 'sentencepiece' \ - --max-sentences 32 --langs $langs > en_ro + --batch-size 32 --langs $langs > en_ro cat en_ro | grep -P "^H" |sort -V |cut -f 3- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.hyp cat en_ro | grep -P "^T" |sort -V |cut -f 2- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.ref diff --git a/examples/megatron_11b/README.md b/examples/megatron_11b/README.md index d6b6cc0774..945c96c91e 100644 --- a/examples/megatron_11b/README.md +++ b/examples/megatron_11b/README.md @@ -64,7 +64,7 @@ fairseq-train \ --lr-scheduler inverse_sqrt --lr 0.00015 \ --warmup-updates 3000 --weight-decay 0.01 \ --dropout 0.1 --attention-dropout 0.1 \ - --max-sentences 2 \ + --batch-size 2 \ --max-update 300000; ``` @@ -139,7 +139,7 @@ fairseq-eval-lm \ --path megatron_11b/model.pt \ --task language_modeling \ --gen-subset test \ - --max-sentences 8 \ + --batch-size 8 \ --criterion cross_entropy \ --context-window 992 \ --distributed-world-size 8 \ diff --git a/examples/multilingual/README.md b/examples/multilingual/README.md index 392476f3b0..3559c244e2 100644 --- a/examples/multilingual/README.md +++ b/examples/multilingual/README.md @@ -90,7 +90,7 @@ fairseq-generate $path_2_data \ --source-lang $source_lang \ --target-lang $target_lang --sacrebleu --remove-bpe 'sentencepiece'\ - --max-sentences 32 \ + --batch-size 32 \ --encoder-langtok "src" \ --decoder-langtok \ --lang-dict "$lang_list" \ diff 
--git a/examples/multilingual/multilingual_fairseq_gen.sh b/examples/multilingual/multilingual_fairseq_gen.sh index a7487975e4..8c2c7703b2 100644 --- a/examples/multilingual/multilingual_fairseq_gen.sh +++ b/examples/multilingual/multilingual_fairseq_gen.sh @@ -14,7 +14,7 @@ fairseq-generate "$path_2_data" \ --source-lang "$source_lang" \ --target-lang "$target_lang" \ --sacrebleu --remove-bpe 'sentencepiece'\ - --max-sentences 32 \ + --batch-size 32 \ --encoder-langtok "src" \ --decoder-langtok \ --lang-dict "$lang_list" \ diff --git a/examples/noisychannel/rerank_generate.py b/examples/noisychannel/rerank_generate.py index 48f954d360..d2da6eacf9 100644 --- a/examples/noisychannel/rerank_generate.py +++ b/examples/noisychannel/rerank_generate.py @@ -99,7 +99,7 @@ def gen_and_reprocess_nbest(args): "--nbest", str(args.num_rescore), "--batch-size", str(args.batch_size), "--beam", str(args.num_rescore), - "--max-sentences", str(args.num_rescore), + "--batch-size", str(args.num_rescore), "--gen-subset", args.gen_subset, "--source-lang", args.source_lang, "--target-lang", args.target_lang] diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index bcf0c4c827..057ea620ab 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -151,7 +151,7 @@ fairseq-train $DATA_DIR \ --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \ --dropout 0.1 --attention-dropout 0.1 \ --weight-decay 0.01 \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --update-freq $UPDATE_FREQ --max-update $TOTAL_UPDATES \ --save-dir checkpoint/roberta \ --ddp-backend no_c10d --encoder-layerdrop 0.2 \ @@ -172,7 +172,7 @@ ROBERTA_PATH=/path/to/roberta_quantnoise/model.pt fairseq-train /path/to/rte/data/ \ --restore-file $ROBERTA_PATH \ --max-positions 512 \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-tokens 4400 \ --task sentence_prediction \ --reset-optimizer --reset-dataloader --reset-meters \ 
@@ -242,7 +242,7 @@ fairseq-train --task sentence_prediction /path/to/data/ \ --restore-file $ROBERTA_PATH \ --save-dir checkpoints/roberta_finetuned \ --max-positions 512 \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-tokens 4400 \ --init-token 0 --separator-token 2 \ --arch roberta_large \ diff --git a/examples/roberta/README.custom_classification.md b/examples/roberta/README.custom_classification.md index 72e490ddc7..7254bb7d17 100644 --- a/examples/roberta/README.custom_classification.md +++ b/examples/roberta/README.custom_classification.md @@ -106,7 +106,7 @@ ROBERTA_PATH=/path/to/roberta.large/model.pt CUDA_VISIBLE_DEVICES=0 fairseq-train IMDB-bin/ \ --restore-file $ROBERTA_PATH \ --max-positions 512 \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-tokens 4400 \ --task sentence_prediction \ --reset-optimizer --reset-dataloader --reset-meters \ @@ -129,10 +129,10 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train IMDB-bin/ \ ``` The above command will finetune RoBERTa-large with an effective batch-size of 32 -sentences (`--max-sentences=8 --update-freq=4`). The expected +sentences (`--batch-size=8 --update-freq=4`). The expected `best-validation-accuracy` after 10 epochs is ~96.5%. -If you run out of GPU memory, try decreasing `--max-sentences` and increase +If you run out of GPU memory, try decreasing `--batch-size` and increase `--update-freq` to compensate. 
diff --git a/examples/roberta/README.glue.md b/examples/roberta/README.glue.md index db20360e2c..77015d2e2f 100644 --- a/examples/roberta/README.glue.md +++ b/examples/roberta/README.glue.md @@ -27,7 +27,7 @@ ROBERTA_PATH=/path/to/roberta/model.pt CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \ --restore-file $ROBERTA_PATH \ --max-positions 512 \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-tokens 4400 \ --task sentence_prediction \ --reset-optimizer --reset-dataloader --reset-meters \ @@ -52,7 +52,7 @@ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B ---|---|---|---|---|---|---|---|--- `--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1 `--lr` | 1e-5 | 1e-5 | 1e-5 | 2e-5 | 1e-5 | 1e-5 | 1e-5 | 2e-5 -`--max-sentences` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16 +`--batch-size` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16 `--total-num-update` | 123873 | 33112 | 113272 | 2036 | 20935 | 2296 | 5336 | 3598 `--warmup-updates` | 7432 | 1986 | 28318 | 122 | 1256 | 137 | 320 | 214 @@ -60,9 +60,9 @@ For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` **Note:** -a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=16/32` depending on the task. +a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=16/32` depending on the task. -b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`. +b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--batch-size`. 
c) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search. diff --git a/examples/roberta/README.pretraining.md b/examples/roberta/README.pretraining.md index b841631d3e..8b6e10c08c 100644 --- a/examples/roberta/README.pretraining.md +++ b/examples/roberta/README.pretraining.md @@ -64,7 +64,7 @@ fairseq-train --fp16 $DATA_DIR \ --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm 0.0 \ --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \ --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ - --max-sentences $MAX_SENTENCES --update-freq $UPDATE_FREQ \ + --batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ \ --max-update $TOTAL_UPDATES --log-format simple --log-interval 1 ``` diff --git a/examples/roberta/README.race.md b/examples/roberta/README.race.md index c2d1acaba6..527a0bce14 100644 --- a/examples/roberta/README.race.md +++ b/examples/roberta/README.race.md @@ -36,7 +36,7 @@ CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=no_c10d \ --clip-norm 0.0 \ --lr-scheduler fixed --lr $LR \ --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --required-batch-size-multiple 1 \ --update-freq $UPDATE_FREQ \ --max-epoch $MAX_EPOCH @@ -46,7 +46,7 @@ CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=no_c10d \ a) As contexts in RACE are relatively long, we are using smaller batch size per GPU while increasing update-freq to achieve larger effective batch size. -b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`. 
+b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--batch-size`. c) The setting in above command is based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search. @@ -61,7 +61,7 @@ fairseq-validate \ $DATA_DIR \ --valid-subset $TEST_SPLIT \ --path $MODEL_PATH \ - --max-sentences 1 \ + --batch-size 1 \ --task sentence_ranking \ --criterion sentence_ranking \ --save-predictions $PREDS_OUT diff --git a/examples/roberta/commonsense_qa/README.md b/examples/roberta/commonsense_qa/README.md index 7302794805..4f371f8b30 100644 --- a/examples/roberta/commonsense_qa/README.md +++ b/examples/roberta/commonsense_qa/README.md @@ -53,14 +53,14 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train --fp16 --ddp-backend=no_c10d \ --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 --clip-norm 0.0 \ --lr-scheduler polynomial_decay --lr $LR \ --warmup-updates $WARMUP_UPDATES --total-num-update $MAX_UPDATES \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-update $MAX_UPDATES \ --log-format simple --log-interval 25 \ --seed $SEED ``` The above command assumes training on 1 GPU with 32GB of RAM. For GPUs with -less memory, decrease `--max-sentences` and increase `--update-freq` +less memory, decrease `--batch-size` and increase `--update-freq` accordingly to compensate. 
### 3) Evaluate diff --git a/examples/roberta/wsc/README.md b/examples/roberta/wsc/README.md index 0d3f62a07f..d40da6a5fd 100644 --- a/examples/roberta/wsc/README.md +++ b/examples/roberta/wsc/README.md @@ -59,7 +59,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \ --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \ --lr-scheduler polynomial_decay --lr $LR \ --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-update $TOTAL_NUM_UPDATES \ --log-format simple --log-interval 100 \ --seed $SEED @@ -119,7 +119,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \ --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \ --lr-scheduler polynomial_decay --lr $LR \ --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \ - --max-sentences $MAX_SENTENCES \ + --batch-size $MAX_SENTENCES \ --max-update $TOTAL_NUM_UPDATES \ --log-format simple --log-interval 100 ``` diff --git a/examples/simultaneous_translation/docs/baseline.md b/examples/simultaneous_translation/docs/baseline.md index 9aeb44a364..d9bf1a1117 100644 --- a/examples/simultaneous_translation/docs/baseline.md +++ b/examples/simultaneous_translation/docs/baseline.md @@ -84,7 +84,7 @@ CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ --max-epoch 100 \ --lr 0.001 \ --clip-norm 5.0 \ - --max-sentences 128 \ + --batch-size 128 \ --log-format json \ --log-interval 10 \ --criterion cross_entropy_acc \ @@ -127,7 +127,7 @@ CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ --max-epoch 100 \ --lr 0.001 \ --clip-norm 5.0 \ - --max-sentences 128 \ + --batch-size 128 \ --log-format json \ --log-interval 10 \ --criterion cross_entropy_acc \ diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 19f2c2ed03..950490d5f5 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ 
-95,7 +95,7 @@ def get_dataset_itr(args, task, models): return task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, - max_sentences=args.max_sentences, + max_sentences=args.batch_size, max_positions=(sys.maxsize, sys.maxsize), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, @@ -234,7 +234,7 @@ def generate(self, models, sample, **unused): def main(args, task=None, model_state=None): check_args(args) - if args.max_tokens is None and args.max_sentences is None: + if args.max_tokens is None and args.batch_size is None: args.max_tokens = 4000000 logger.info(args) diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index f33a1adcf6..3c400e9d7f 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -53,8 +53,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - if self.args.max_sentences is not None: - bsz = self.args.max_sentences + if self.args.batch_size is not None: + bsz = self.args.batch_size else: bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) self.datasets[split] = DummyDataset( diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index 81398945f3..621265d452 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -61,8 +61,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - if self.args.max_sentences is not None: - bsz = self.args.max_sentences + if self.args.batch_size is not None: + bsz = self.args.batch_size else: bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) self.datasets[split] = DummyDataset( diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py index 0371b3e754..2f8d65d5be 100644 --- 
a/fairseq/benchmark/dummy_mt.py +++ b/fairseq/benchmark/dummy_mt.py @@ -55,8 +55,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): split (str): name of the split (e.g., train, valid, test) """ item_size = max(self.args.src_len, self.args.tgt_len) - if self.args.max_sentences is not None: - bsz = self.args.max_sentences + if self.args.batch_size is not None: + bsz = self.args.batch_size else: bsz = max(1, self.args.max_tokens // item_size) tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index bbd15d1204..e6997b69cd 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -260,11 +260,8 @@ class DatasetParams(FairseqDataclass): max_tokens: Optional[int] = field( default=None, metadata={"help": "maximum number of tokens in a batch"} ) - max_sentences: Optional[int] = field( - default=None, metadata={"help": "maximum number of sentences in a batch"} - ) batch_size: Optional[int] = field( - default=None, metadata={"help": "maximum number of sentences in a batch"} + default=None, metadata={"help": "number of examples in a batch"} ) required_batch_size_multiple: int = field( default=8, metadata={"help": "batch size will be a multiplier of this value"} @@ -311,11 +308,11 @@ class DatasetParams(FairseqDataclass): " (defaults to --max-tokens)" }, ) - max_sentences_valid: Optional[int] = field( + batch_size_valid: Optional[int] = field( default=None, metadata={ - "help": "maximum number of sentences in a validation batch" - " (defaults to --max-sentences)" + "help": "batch size of the validation batch" + " (defaults to --batch-size)" }, ) curriculum: int = field( diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 4e499e141d..b56135abf3 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -230,7 +230,7 @@ def _build_batches( batch_iterator = self.task.get_batch_iterator( 
dataset=self.task.build_dataset_for_inference(tokens, lengths), max_tokens=self.args.max_tokens, - max_sentences=self.args.max_sentences, + max_sentences=self.args.batch_size, max_positions=self.max_positions, ignore_invalid_inputs=skip_invalid_size_inputs, disable_iterator_cache=True, diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 9d777f02fa..4a2b6b66b9 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -219,6 +219,11 @@ def apply_prepare_for_tpu_(module): self.apply(apply_prepare_for_tpu_) + @classmethod + def upgrade_args(cls, args): + if hasattr(args, 'max_sentences') and not hasattr(args, 'batch_size'): + args.batch_size = args.max_sentences + @classmethod def from_pretrained( cls, @@ -257,6 +262,9 @@ def from_pretrained( archive_map=cls.hub_models(), **kwargs, ) + + cls.upgrade_args(x["args"]) + logger.info(x["args"]) return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"]) diff --git a/fairseq/options.py b/fairseq/options.py index fd7c12fbd7..4d3a9766c8 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -168,8 +168,8 @@ def parse_args_and_arch( args = parser.parse_args(input_args) extra = None # Post-process args. 
- if hasattr(args, "max_sentences_valid") and args.max_sentences_valid is None: - args.max_sentences_valid = args.max_sentences + if hasattr(args, "batch_size_valid") and args.batch_size_valid is None: + args.batch_size_valid = args.batch_size if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None: args.max_tokens_valid = args.max_tokens if getattr(args, "memory_efficient_fp16", False): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 18e691ffd4..dbac04e045 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -358,7 +358,7 @@ def get_train_iterator( return self.task.get_batch_iterator( dataset=self.task.dataset(self.args.train_subset), max_tokens=self.args.max_tokens, - max_sentences=self.args.max_sentences, + max_sentences=self.args.batch_size, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), @@ -384,7 +384,7 @@ def get_valid_iterator( return self.task.get_batch_iterator( dataset=self.task.dataset(subset), max_tokens=self.args.max_tokens_valid, - max_sentences=self.args.max_sentences_valid, + max_sentences=self.args.batch_size_valid, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index c5dd7fe4ce..66217f64fc 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -117,7 +117,7 @@ def main(parsed_args, **unused_kwargs): itr = task.get_batch_iterator( dataset=dataset, max_tokens=args.max_tokens or 36000, - max_sentences=args.max_sentences, + max_sentences=args.batch_size, max_positions=utils.resolve_max_positions(*[ model.max_positions() for model in models ]), diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 90d270f460..786f699432 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -56,7 +56,7 @@ def _main(args, output_file): utils.import_user_module(args) - if args.max_tokens is None and args.max_sentences is None: + if 
args.max_tokens is None and args.batch_size is None: args.max_tokens = 12000 logger.info(args) @@ -103,7 +103,7 @@ def _main(args, output_file): itr = task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, - max_sentences=args.max_sentences, + max_sentences=args.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models] diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index f8ee0197dd..dd6249bd90 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -86,7 +86,7 @@ def encode_fn_target(x): itr = task.get_batch_iterator( dataset=task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor), max_tokens=args.max_tokens, - max_sentences=args.max_sentences, + max_sentences=args.batch_size, max_positions=max_positions, ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test ).next_epoch_itr(shuffle=False) @@ -112,13 +112,13 @@ def main(args): if args.buffer_size < 1: args.buffer_size = 1 - if args.max_tokens is None and args.max_sentences is None: - args.max_sentences = 1 + if args.max_tokens is None and args.batch_size is None: + args.batch_size = 1 assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' - assert not args.max_sentences or args.max_sentences <= args.buffer_size, \ - '--max-sentences/--batch-size cannot be larger than --buffer-size' + assert not args.batch_size or args.batch_size <= args.buffer_size, \ + '--batch-size cannot be larger than --buffer-size' logger.info(args) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index a2a7763488..cd3a93b13e 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -43,8 +43,8 @@ def main(args): utils.import_user_module(args) assert ( - args.max_tokens is not None or args.max_sentences is not None - ), "Must specify batch size either with --max-tokens or --max-sentences" + 
args.max_tokens is not None or args.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() @@ -101,7 +101,7 @@ def main(args): ) logger.info( "max tokens per GPU = {} and max sentences per GPU = {}".format( - args.max_tokens, args.max_sentences + args.max_tokens, args.batch_size ) ) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 510560d968..717a776c8f 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -28,8 +28,8 @@ def main(args, override_args=None): utils.import_user_module(args) - assert args.max_tokens is not None or args.max_sentences is not None, \ - 'Must specify batch size either with --max-tokens or --max-sentences' + assert args.max_tokens is not None or args.batch_size is not None, \ + 'Must specify batch size either with --max-tokens or --batch-size' use_fp16 = args.fp16 use_cuda = torch.cuda.is_available() and not args.cpu @@ -77,7 +77,7 @@ def main(args, override_args=None): itr = task.get_batch_iterator( dataset=dataset, max_tokens=args.max_tokens, - max_sentences=args.max_sentences, + max_sentences=args.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[m.max_positions() for m in models], diff --git a/tests/test_binaries.py b/tests/test_binaries.py index e6259092eb..c0c4abb6ed 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -901,7 +901,7 @@ def train_masked_lm(data_dir, arch, extra_flags=None): '--optimizer', 'adam', '--lr', '0.0001', '--criterion', 'masked_lm', - '--max-sentences', '500', + '--batch-size', '500', '--save-dir', data_dir, '--max-epoch', '1', '--no-progress-bar', @@ -928,7 +928,7 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): '--criterion', 'sentence_prediction', '--max-tokens', '500', '--max-positions', '500', - '--max-sentences', '500', + '--batch-size', '500', '--save-dir', data_dir, '--max-epoch', '1', '--no-progress-bar', diff --git a/tests/utils.py 
b/tests/utils.py index f265c13f85..e8528292e4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -282,7 +282,7 @@ def generate_main(data_dir, extra_flags=None): # evaluate model interactively generate_args.buffer_size = 0 generate_args.input = '-' - generate_args.max_sentences = None + generate_args.batch_size = None orig_stdin = sys.stdin sys.stdin = StringIO('h e l l o\n') interactive.main(generate_args) From e056de1fb641425974b0941704e63b9b42fe8a1f Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Tue, 6 Oct 2020 13:13:42 -0700 Subject: [PATCH 188/707] Rework dummy batch logic (#1331) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? This PR reworks the dummy batch logic in fairseq. It does away with `"DUMMY"` strings and instead has `trainer.get_train_iterator` and `trainer.get_valid_iterator` set up the dummy batch by taking the first sample in the dataset. See conversation here where the issue was reported and discussed: https://fb.workplace.com/groups/fairseq/permalink/1241915899501646/ ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1331 Reviewed By: dgenzel2 Differential Revision: D24093421 Pulled By: joshim5 fbshipit-source-id: 71e747d90496f5158d07f0c2db87b6a6a974ef4f --- fairseq/data/iterators.py | 12 ++++++++++++ fairseq/trainer.py | 28 +++++++++++----------------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 196eb30887..fee4d8e9cb 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -250,6 +250,18 @@ def frozen_batches(self): if self._frozen_batches is None: self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) return self._frozen_batches + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." 
+ ) + + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) def __len__(self): return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index dbac04e045..5cf199770e 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -81,7 +81,7 @@ def __init__(self, args, task, model, criterion, quantizer=None): ) _set_module_by_path(self._model, path, ref) - self._dummy_batch = "DUMMY" # indicates we don't have a dummy batch at first + self._dummy_batch = None # indicates we don't have a dummy batch at first self._lr_scheduler = None self._num_updates = 0 self._num_xla_compiles = 0 # for TPUs @@ -355,7 +355,7 @@ def get_train_iterator( combine=combine, data_selector=data_selector, ) - return self.task.get_batch_iterator( + batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(self.args.train_subset), max_tokens=self.args.max_tokens, max_sentences=self.args.batch_size, @@ -374,6 +374,8 @@ def get_train_iterator( data_buffer_size=self.args.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator def get_valid_iterator( self, @@ -381,7 +383,7 @@ def get_valid_iterator( disable_iterator_cache=False, ): """Return an EpochBatchIterator over given validation subset for a given epoch.""" - return self.task.get_batch_iterator( + batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(subset), max_tokens=self.args.max_tokens_valid, max_sentences=self.args.batch_size_valid, @@ -398,6 +400,8 @@ def get_valid_iterator( data_buffer_size=self.args.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator def begin_epoch(self, epoch): """Called at the beginning of each epoch.""" @@ -423,9 +427,9 @@ def begin_valid_epoch(self, epoch): # task specific setup per validation epoch 
self.task.begin_valid_epoch(epoch, self.get_model()) - - # reset dummy batch - self._dummy_batch = 'DUMMY' + + def reset_dummy_batch(self, batch): + self._dummy_batch = batch @metrics.aggregate("train") def train_step(self, samples, raise_oom=False): @@ -447,8 +451,6 @@ def train_step(self, samples, raise_oom=False): sample = self._prepare_sample(self._dummy_batch) is_dummy_batch = True else: - if self._dummy_batch == "DUMMY": - self._dummy_batch = sample is_dummy_batch = False def maybe_no_sync(): @@ -671,8 +673,6 @@ def valid_step(self, sample, raise_oom=False): sample = self._prepare_sample(self._dummy_batch) is_dummy_batch = True else: - if self._dummy_batch == "DUMMY": - self._dummy_batch = sample is_dummy_batch = False try: @@ -807,13 +807,7 @@ def _local_cumulative_training_time(self): return time.time() - self._start_time + self._previous_training_time def _prepare_sample(self, sample): - if sample == "DUMMY": - raise Exception( - "Trying to use an uninitialized 'dummy' batch. This usually indicates " - "that the total number of batches is smaller than the number of " - "participating GPUs. Try reducing the batch size or using fewer GPUs." - ) - + if sample is None or len(sample) == 0: return None From bcc81f6d5291c3996c8b2472282458dead46343f Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Tue, 6 Oct 2020 16:59:20 -0700 Subject: [PATCH 189/707] replace max-sentences with batch-size for dependencies Summary: this fixes some regressions introduced by D24121305 (https://github.com/pytorch/fairseq/commit/e3c4282551e819853952284681e9ed60398c5c4a). fairseq configuration is changing from command line to dataclasses (via hydra eventually) which no longer supports option aliases. one such alias is --max-sentences / --batch-size, and D24121305 (https://github.com/pytorch/fairseq/commit/e3c4282551e819853952284681e9ed60398c5c4a) removed --max-sentences as --batch-size is more appropriate (fairseq is not just an nlp framework dealing with sentences). 
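For illustration only, an option alias of the kind mentioned above can be sketched with plain argparse. The parser below is a hypothetical stand-in, not fairseq's actual option definition: two option strings are mapped onto one destination, which is the behaviour a dataclass-based config cannot express directly.

```python
import argparse


def build_parser():
    # Hypothetical parser: both spellings write to the same attribute.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        "--max-sentences",  # legacy alias for the same option
        dest="batch_size",
        type=int,
        default=None,
        help="number of examples in a batch",
    )
    return parser


new_style = build_parser().parse_args(["--batch-size", "8"])
old_style = build_parser().parse_args(["--max-sentences", "8"])
```

Once the alias is dropped, only the first spelling parses, so any caller still passing `--max-sentences` (or reading `args.max_sentences`) has to be updated.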
unfortunately it seems some existing flows broke and this diff attempts to fix this Differential Revision: D24142488 fbshipit-source-id: 075180ea10a9d706a3f8d64b978d66dfd83c3d2b --- fairseq/options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/options.py b/fairseq/options.py index 4d3a9766c8..e1df860fbe 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -168,7 +168,7 @@ def parse_args_and_arch( args = parser.parse_args(input_args) extra = None # Post-process args. - if hasattr(args, "batch_size_valid") and args.batch_size_valid is None: + if (hasattr(args, "batch_size_valid") and args.batch_size_valid is None) or not hasattr(args, "batch_size_valid"): args.batch_size_valid = args.batch_size if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None: args.max_tokens_valid = args.max_tokens From 63aca3ff6cf1bcf87edb72ab35f48ddc106bb9a6 Mon Sep 17 00:00:00 2001 From: Dmitriy Genzel Date: Tue, 6 Oct 2020 22:02:10 -0700 Subject: [PATCH 190/707] Fix a bad merge from D24093421 Summary: Fix a bad merge from D24093421 (https://github.com/pytorch/fairseq/commit/e056de1fb641425974b0941704e63b9b42fe8a1f) Also linted Reviewed By: joshim5 Differential Revision: D24154657 fbshipit-source-id: 98afbd9f7ea4200756f07f3271b66fe6477dafff --- fairseq/data/iterators.py | 2 +- fairseq/trainer.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index fee4d8e9cb..729a3141d1 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -250,7 +250,7 @@ def frozen_batches(self): if self._frozen_batches is None: self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) return self._frozen_batches - + @property def first_batch(self): if len(self.frozen_batches) == 0: diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5cf199770e..14bba0c662 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -414,7 +414,7 @@ def 
begin_epoch(self, epoch): self.task.begin_epoch(epoch, self.get_model()) # reset dummy batch - self._dummy_batch = 'DUMMY' + self._dummy_batch = None if self.tpu: import torch_xla.core.xla_model as xm @@ -427,7 +427,7 @@ def begin_valid_epoch(self, epoch): # task specific setup per validation epoch self.task.begin_valid_epoch(epoch, self.get_model()) - + def reset_dummy_batch(self, batch): self._dummy_batch = batch @@ -807,7 +807,6 @@ def _local_cumulative_training_time(self): return time.time() - self._start_time + self._previous_training_time def _prepare_sample(self, sample): - if sample is None or len(sample) == 0: return None From b880744c1e6b7c526ab8a2c22161b105778e807e Mon Sep 17 00:00:00 2001 From: Dmitriy Genzel Date: Wed, 7 Oct 2020 09:17:06 -0700 Subject: [PATCH 191/707] Fix a bad merge from D24154657 - second try Summary: This removes a line that reset the dummy batch in begin_epoch. In actuality the batch is reset in get_{train,valid}_iterator, and this line was resetting it to None unnecessarily. 
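A minimal sketch of the intended flow, assuming simplified stand-in classes (`TinyBatchIterator` and `TinyTrainer` are hypothetical names, not fairseq classes): the dummy batch is set once when the iterator is built, `begin_epoch` leaves it alone, and an empty sample falls back to it.

```python
class TinyBatchIterator:
    """Hypothetical stand-in for an epoch batch iterator."""

    def __init__(self, dataset, batches):
        self.dataset = dataset
        self.frozen_batches = tuple(batches)

    @property
    def first_batch(self):
        if len(self.frozen_batches) == 0:
            raise Exception("The dataset is empty.")
        # Collate the first batch of indices into actual samples.
        return [self.dataset[i] for i in self.frozen_batches[0]]


class TinyTrainer:
    """Hypothetical stand-in for the trainer's dummy-batch handling."""

    def __init__(self):
        self._dummy_batch = None  # no dummy batch until an iterator is built

    def reset_dummy_batch(self, batch):
        self._dummy_batch = batch

    def get_train_iterator(self, dataset, batches):
        iterator = TinyBatchIterator(dataset, batches)
        # The dummy batch is (re)set here, from the first real batch.
        self.reset_dummy_batch(iterator.first_batch)
        return iterator

    def begin_epoch(self, epoch):
        # Deliberately does not touch self._dummy_batch.
        pass

    def train_step(self, sample):
        if sample is None or len(sample) == 0:
            # A worker with no data substitutes the dummy batch.
            return self._dummy_batch, True
        return sample, False


trainer = TinyTrainer()
trainer.get_train_iterator(["a", "b", "c"], [[0, 1], [2]])
trainer.begin_epoch(1)
```

In this sketch, resetting `_dummy_batch` to `None` inside `begin_epoch` would discard the batch that `get_train_iterator` had just stored.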
Reviewed By: joshim5, jhcross Differential Revision: D24157057 fbshipit-source-id: 59ac68327094ceff70f66d7b471fa810997fe84e --- fairseq/trainer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 14bba0c662..f9ad1b2f95 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -413,9 +413,6 @@ def begin_epoch(self, epoch): # task specific setup per epoch self.task.begin_epoch(epoch, self.get_model()) - # reset dummy batch - self._dummy_batch = None - if self.tpu: import torch_xla.core.xla_model as xm From 5379461e613263911050a860b79accdf4d75fd37 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 7 Oct 2020 16:38:49 -0700 Subject: [PATCH 192/707] lm rescoring attempt (#1242) Summary: CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/private/home/abaevski/fairseq-py-master python fairseq_cli/generate.py /checkpoint/henryzhou7/dataset/libri/960h/raw3/decoder --task audio_pretraining --seed 1 --nbest 1 --gen-subset dev_other --max-tokens 600000 --path ~/models/wav2vec2/vox_960h_seq2seq_10kwp.pt --labels 10k --remove-bpe 'wordpiece' --quiet --beam 50 --temperature 1 --scoring wer --lm-path /checkpoint/henryzhou7/wp_lm/transformer_raw3_adam_cosine2node/lr_1e-4_updatefreq_8/checkpoint_best.pt --lm-weight 1 results: no lm: 4.30577896347444 lm (1.5): 24.691650853889943 lm (1): 10.884539582804846 lm (0.5): 4.894205665744457 lm (0.25): 4.012853671917862 lm (0.1): 4.087637055489084 lm (0.05): 4.194788887144875 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1242 Reviewed By: kahne Differential Revision: D23277386 Pulled By: alexeib fbshipit-source-id: 062f483bd45ddd2dd5ff24a8a35cc1c4f34ce6ab --- .../tasks/speech_recognition.py | 2 +- fairseq/options.py | 5 +++ fairseq/sequence_generator.py | 18 ++++++++ .../tasks/translation_from_pretrained_bart.py | 2 +- fairseq/tasks/translation_lev.py | 2 +- fairseq_cli/generate.py | 41 ++++++++++++++++--- 6 files changed, 62 insertions(+), 8 deletions(-) diff --git 
a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py index dde0b12577..1181c9aef5 100644 --- a/examples/speech_recognition/tasks/speech_recognition.py +++ b/examples/speech_recognition/tasks/speech_recognition.py @@ -113,7 +113,7 @@ def load_dataset(self, split, combine=False, **kwargs): data_json_path = os.path.join(self.args.data, "{}.json".format(split)) self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict) - def build_generator(self, models, args): + def build_generator(self, models, args, **unused): w2l_decoder = getattr(args, "w2l_decoder", None) if w2l_decoder == "viterbi": from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder diff --git a/fairseq/options.py b/fairseq/options.py index e1df860fbe..31ed28a80e 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -380,6 +380,11 @@ def add_generation_args(parser): help='if set, uses attention feedback to compute and print alignment to source tokens') group.add_argument('--print-step', action='store_true') + group.add_argument('--lm-path', default=None, type=str, metavar='PATH', + help='path to lm checkpoint for lm fusion') + group.add_argument('--lm-weight', default=0.0, type=float, metavar='N', + help='weight for lm probs for lm fusion') + # arguments for iterative refinement generator group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N', help='if > 0.0, it penalized early-stopping in decoding.') diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 965594cd6e..ff45c7dfb7 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -33,6 +33,8 @@ def __init__( search_strategy=None, eos=None, symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0 ): """Generates translations of a given source sentence. 
@@ -94,6 +96,11 @@ def __init__( self.model.eval() + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + def cuda(self): self.model.cuda() return self @@ -292,6 +299,15 @@ def _generate( incremental_states, self.temperature, ) + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) lprobs[:, self.pad] = -math.inf # never select pad @@ -820,9 +836,11 @@ def forward_decoder( avg_attn = attn else: avg_attn.add_(attn) + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( self.models_size ) + if avg_attn is not None: avg_attn.div_(self.models_size) return avg_probs, avg_attn diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index b3c9f8e440..4d574ffc82 100644 --- a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -84,7 +84,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): append_source_id=True ) - def build_generator(self, models, args): + def build_generator(self, models, args, **unused): if getattr(args, 'score_reference', False): from fairseq.sequence_scorer import SequenceScorer return SequenceScorer( diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index be362a1881..18ac0ca385 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -128,7 +128,7 @@ def _full_mask(target_tokens): else: raise NotImplementedError - def build_generator(self, models, args): + def build_generator(self, models, args, **unused): # add models input to match the API for SequenceGenerator from fairseq.iterative_refinement_generator import IterativeRefinementGenerator 
return IterativeRefinementGenerator( diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 786f699432..0cf09feaee 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -7,6 +7,8 @@ Translate pre-processed data with a trained model. """ +import ast +from itertools import chain import logging import math import os @@ -78,17 +80,39 @@ def _main(args, output_file): src_dict = None tgt_dict = task.target_dictionary + overrides = ast.literal_eval(args.model_overrides) + # Load ensemble logger.info('loading model(s) from {}'.format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(args.path), - arg_overrides=eval(args.model_overrides), + arg_overrides=overrides, task=task, suffix=getattr(args, "checkpoint_suffix", ""), ) + if args.lm_path is not None: + overrides['data'] = args.data + + try: + lms, _ = checkpoint_utils.load_model_ensemble( + [args.lm_path], + arg_overrides=overrides, + task=None, + ) + except: + logger.warning(f"Failed to load language model! 
Please make sure that the language model dict is the same " + f"as target dict and is located in the data dir ({args.data})") + raise + + assert len(lms) == 1 + else: + lms = [None] + # Optimize ensemble for generation - for model in models: + for model in chain(models, lms): + if model is None: + continue model.prepare_for_inference_(args) if args.fp16: model.half() @@ -124,7 +148,12 @@ def _main(args, output_file): # Initialize generator gen_timer = StopwatchMeter() - generator = task.build_generator(models, args) + + extra_gen_cls_kwargs = { + 'lm_model': lms[0], + 'lm_weight': args.lm_weight + } + generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs) # Handle tokenization and BPE tokenizer = encoders.build_tokenizer(args) @@ -269,9 +298,11 @@ def decode_fn(x): if has_target: if args.bpe and not args.sacrebleu: if args.remove_bpe: - logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization") + logger.warning( + "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization") else: - logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization") + logger.warning( + "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. 
Use --sacrebleu for standard 13a BLEU tokenization") # use print to be consistent with other main outputs: S-, H-, T-, D- and so on print( 'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()), From fcb75729c01c112b3a58539777260d352eb4cd5d Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Thu, 8 Oct 2020 17:36:06 -0700 Subject: [PATCH 193/707] don't store DDP model in fairseq nan detector Summary: I ran into an error with nan detector recently: f221506814 ``` torch.nn.modules.module.ModuleAttributeError: 'AddJoinerSparse' object has no attribute '_parameters' ``` full log: P144224349 It seems to fail when we do a deepcopy of the joiner in our implementation of transducer transformer: https://fburl.com/diffusion/78jkfk2z Very interestingly, it seems to be copying a DistributedDataParallel object! This seemed really weird, since a user module shouldn't really contain a reference to a DistributedDataParallel object. After investigation this seems to be because of the backward hooks that `NanDetector` adds to the module. The backward hooks reference NanDetector, which references `model`, which is the `DistributedDataParallel` object. 
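The reference chain can be reproduced without PyTorch. The sketch below uses hypothetical stand-in classes (`FakeModule`, `FakeDDP`, `FakeDetector`) and relies only on the fact that `copy.deepcopy` of a bound method also deep-copies the object the method is bound to:

```python
import copy


class FakeModule:
    """Stand-in for a user module that keeps a list of hooks."""

    def __init__(self):
        self.hooks = []


class FakeDDP:
    """Stand-in for the DistributedDataParallel wrapper."""

    def __init__(self, module):
        self.module = module


class FakeDetector:
    """Stand-in for a detector that stores the wrapped model."""

    def __init__(self, model):
        self.model = model  # reference to the wrapper
        # Registering a bound method makes the module reference the detector.
        model.module.hooks.append(self.hook)

    def hook(self, *args):
        pass


inner = FakeModule()
wrapped = FakeDDP(inner)
detector = FakeDetector(wrapped)

# Copying only the inner module follows hooks -> detector -> wrapper.
clone = copy.deepcopy(inner)
copied_detector = clone.hooks[0].__self__
```

Here `copied_detector.model` is a fresh `FakeDDP` instance, so copying the inner module silently copied the wrapper as well.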
The fix is then to not store a reference to the `DistributedDataParallel` in `NanDetector` Reviewed By: zhengwy888 Differential Revision: D24058995 fbshipit-source-id: 48209339243d8b23b078274b780e850335839e89 --- fairseq/nan_detector.py | 13 +++++++------ fairseq/trainer.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py index df4e28ec89..0d7d8d7d79 100644 --- a/fairseq/nan_detector.py +++ b/fairseq/nan_detector.py @@ -19,7 +19,7 @@ def __init__(self, model, forward=True, backward=True): self.fhooks = [] self.forward = forward self.backward = backward - self.model = model + self.named_parameters = list(model.named_parameters()) self.reset() for name, mod in model.named_modules(): @@ -33,11 +33,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback): # Dump out all model gnorms to enable better debugging norm = {} gradients = {} - for name, param in self.model.named_parameters(): - grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32) - norm[name] = grad_norm.item() - if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): - gradients[name] = param.grad.data + for name, param in self.named_parameters: + if param.grad is not None: + grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32) + norm[name] = grad_norm.item() + if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): + gradients[name] = param.grad.data if len(gradients) > 0: logger.info("Detected nan/inf grad norm, dumping norms...") logger.info(f"norms: {norm}") diff --git a/fairseq/trainer.py b/fairseq/trainer.py index f9ad1b2f95..d3724f22b1 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -569,7 +569,7 @@ def maybe_no_sync(): except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails - with NanDetector(self.model): + with NanDetector(self.get_model()): self.task.train_step( sample, self.model, self.criterion, self.optimizer, 
self.get_num_updates(), ignore_grad=False From bf06ca7cab3b0ef9f572453a76aa35a42feab1f6 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Fri, 9 Oct 2020 13:33:33 -0700 Subject: [PATCH 194/707] Improve dictionary & checkpoint reading w/ local caching Reviewed By: myleott Differential Revision: D24148700 fbshipit-source-id: 666300639243688939e137be748f7b76fc3c21a6 --- fairseq/checkpoint_utils.py | 2 +- fairseq/data/dictionary.py | 2 +- fairseq/data/multilingual/multilingual_data_manager.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 3b9e6bfd27..b0109caa83 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -200,7 +200,7 @@ def load_checkpoint(args, trainer, **passthrough_args): def load_checkpoint_to_cpu(path, arg_overrides=None): """Loads a checkpoint to CPU (with upgrading for backward compatibility).""" - with PathManager.open(path, "rb") as f: + with open(PathManager.get_local_path(path), "rb") as f: state = torch.load( f, map_location=lambda s, l: default_restore_location(s, "cpu") ) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index d4b88024b0..3d11f93137 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -221,7 +221,7 @@ def add_from_file(self, f): """ if isinstance(f, str): try: - with PathManager.open(f, "r", encoding="utf-8") as fd: + with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd: self.add_from_file(fd) except FileNotFoundError as fnfe: raise fnfe diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 1eedab3dce..2d2a984fcc 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -260,7 +260,7 @@ def load_langs(cls, args, **kwargs): langs = sorted(langs) logger.info(f"inferred language list: {langs}") elif args.lang_dict: - with 
PathManager.open(args.lang_dict, "r", encoding="utf-8") as f: + with open(PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8") as f: langs = [lang.strip() for lang in f.readlines() if lang.strip()] logger.info( f"loaded language list from {args.lang_dict} as they are ordered in file" From aa39ab1b4568479bf9a1360cfcdd4f4fce5f1838 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 9 Oct 2020 17:08:49 -0700 Subject: [PATCH 195/707] Adapt prod fairseq eval to latest vocab changes Summary: # Facebook: With changes in D23653256, we don't have thrift vocab anymore. This diff changes the fairseq eval accordingly. Reviewed By: chtran Differential Revision: D23831452 fbshipit-source-id: 5e4f39140c9f25d99324fb7eded42c7fca439d3f --- .../data/multilingual/multilingual_data_manager.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 2d2a984fcc..16a4d98690 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -787,14 +787,15 @@ def _get_shard_num_dict(cls, split, paths): shards = defaultdict(int) for path in paths: files = PathManager.ls(path) + directions = set() for f in files: if f.startswith(split) and f.endswith(".idx"): # idx files of the form "{split}.{src}-{tgt}.{lang}.idx" - direction = f.split(".")[-3] - shards[direction] += 1 - # each direction has two '.idx' files - # one for source language and one for target language, so: - return {k: v // 2 for k, v in shards.items()} + direction = f.split('.')[-3] + directions.add(direction) + for direction in directions: + shards[direction] += 1 + return shards def get_split_num_data_shards(self, split): if split in self._num_shards_dict: From 60442af216d551e4afc9d4fab1c056c1051725cc Mon Sep 17 00:00:00 2001 From: alexeib Date: Sun, 11 Oct 2020 19:30:50 -0700 Subject: [PATCH 196/707] add 
support for "const" argparse converts for now (#1338) Summary: Fixes issue #2705 Re: [pytorch/fairseq] The registries update forces the "--remove-bpe" option to require an argument (#2705) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1338 Reviewed By: myleott Differential Revision: D24242258 Pulled By: alexeib fbshipit-source-id: 0eafcae8de3476c4237b1a32bad203dd9e940cc3 --- fairseq/dataclass/data_class.py | 3 ++- fairseq/dataclass/utils.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index e6997b69cd..d402e07106 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -481,7 +481,8 @@ class CommonEvalParams(FairseqDataclass): remove_bpe: Optional[str] = field( default=None, metadata={ - "help": "remove BPE tokens before scoring (can be set to sentencepiece)" + "help": "remove BPE tokens before scoring (can be set to sentencepiece)", + "argparse_const": "@@ ", }, ) quiet: bool = field(default=False, metadata={"help": "only print final scores"}) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 093ecd8f6b..9ab235d16d 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -93,6 +93,9 @@ def _get_type(self, attribute_name: str) -> Any: def _get_help(self, attribute_name: str) -> Any: return self._get_meta(attribute_name, "help") + def _get_argparse_const(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_const") + def _get_choices(self, attribute_name: str) -> Any: return self._get_meta(attribute_name, "choices") @@ -139,6 +142,7 @@ def get_kwargs_from_dc( field_choices = None field_help = dataclass_instance._get_help(k) + field_const = dataclass_instance._get_argparse_const(k) kwargs = {} if isinstance(field_default, str) and field_default.startswith("${"): kwargs["default"] = field_default @@ -160,9 +164,7 @@ def get_kwargs_from_dc( 
raise NotImplementedError() if field_default is not MISSING: kwargs["default"] = ",".join(map(str, field_default)) - elif (isinstance(inter_type, type) and issubclass(inter_type, Enum)) or ( - "Enum" in str(inter_type) - ): + elif (isinstance(inter_type, type) and issubclass(inter_type, Enum)) or "Enum" in str(inter_type): kwargs["type"] = str if field_default is not MISSING: if isinstance(field_default, Enum): @@ -180,6 +182,9 @@ def get_kwargs_from_dc( kwargs["default"] = field_default kwargs["help"] = field_help + if field_const is not None: + kwargs["const"] = field_const + kwargs["nargs"] = '?' return kwargs for k in dataclass_instance._get_all_attributes(): From 3ab136e41ef3eb91f55d997377d3a2c18e1e1438 Mon Sep 17 00:00:00 2001 From: Shruti Bhosale Date: Mon, 12 Oct 2020 10:54:19 -0700 Subject: [PATCH 197/707] Support generation with huge pipeline parallel Transformer models (#1297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ## What is this PR about? 
* Support loading sharded checkpoints (loading and saving checkpoints without sharding for 30B models runs OOM)
* Ability to provide a fixed dictionary for multilingual machine translation to match the dictionary used for training the checkpoint
* Support generation with PipelineParallelTransformer models

## Testing
```
python fairseq_cli/generate.py \
    /large_exp/angelafan/cc100_multilingual/eval/binarized_spm_128k_dict/ted \
    --batch-size 1 \
    --path /large_exp/angelafan/checkpoints/shru/mmmt_5b_checkpoints_for_master/checkpoint4.pt \
    -s en -t es --remove-bpe 'sentencepiece' --beam 5 \
    --task translation_multi_simple_epoch \
    --lang-pairs /private/home/angelafan/mmt/lang100_bt_pairs.txt \
    --decoder-langtok --encoder-langtok src --gen-subset valid --fp16 \
    --dataset-impl mmap \
    --distributed-world-size 1 --distributed-no-spawn \
    --pipeline-model-parallel \
    --pipeline-chunks 1 \
    --pipeline-encoder-balance '[26]' \
    --pipeline-encoder-devices '[0]' \
    --pipeline-decoder-balance '[26]' \
    --pipeline-decoder-devices '[0]' \
    --fixed-dictionary /private/home/shru/projects/fairseq-py-kakaoval-mmt-save-redist-gen/dict.100langs.txt
2020-09-23 23:18:27 | WARNING | fairseq.data.multilingual.multilingual_data_manager | External language dictionary is not provided; use lang-pairs to infer the set of supported languages. The language ordering is not stable which might cause misalignment in pretraining and finetuning.
2020-09-23 23:18:27 | INFO | fairseq.data.multilingual.multilingual_data_manager | inferred language list: ['af', 'am', 'ar', 'arbt', 'ast', 'az', 'azbt', 'ba', 'be', 'bebt', 'bg', 'bgbt', 'bn', 'bnbt', 'br', 'bs', 'ca', 'ceb', 'cs', 'csbt', 'cy', 'da', 'de', 'debt', 'el', 'en', 'es', 'esbt', 'et', 'etbt', 'fa', 'fabt', 'ff', 'fi', 'fr', 'frbt', 'fy', 'ga', 'gd', 'gl', 'gu', 'ha', 'he', 'hebt', 'hi', 'hibt', 'hr', 'ht', 'hu', 'hubt', 'hy', 'hybt', 'id', 'ig', 'ilo', 'is', 'it', 'ja', 'jabt', 'jv', 'ka', 'kabt', 'kk', 'km', 'kn', 'ko', 'kobt', 'lb', 'lg', 'ln', 'lo', 'lt', 'lv', 'mg', 'mk', 'mkbt', 'ml', 'mn', 'mr', 'mrbt', 'ms', 'msbt', 'my', 'ne', 'nebt', 'nl', 'no', 'ns', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'ro', 'robt', 'ru', 'sd', 'si', 'sk', 'sl', 'so', 'sq', 'sr', 'srbt', 'ss', 'su', 'sv', 'sw', 'ta', 'th', 'tl', 'tn', 'tr', 'trbt', 'uk', 'ur', 'uz', 'vi', 'vibt', 'wo', 'xh', 'yi', 'yo', 'zh', 'zhbt', 'zu']
2020-09-23 23:18:27 | INFO | fairseq.data.multilingual.multilingual_data_manager | [en] dictionary: 128112 types
2020-09-23 23:18:27 | INFO | fairseq.data.multilingual.multilingual_data_manager | [es] dictionary: 128112 types
2020-09-23 23:18:27 | INFO | fairseq.tasks.translation_multi_simple_epoch | loading data for valid epoch=1/None
2020-09-23 23:18:27 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: N/A
2020-09-23 23:18:27 | INFO | fairseq.data.multilingual.multilingual_data_manager | langtoks settings: {'main': ('src', 'tgt')}
2020-09-23 23:18:27 | INFO | fairseq.data.multilingual.multilingual_data_manager | [valid] num of shards: {'main:en-es': 1}
2020-09-23 23:18:27 | INFO | fairseq.data.multilingual.multilingual_data_manager | main:en-es src_langtok: 128022; tgt_langtok: 128023
2020-09-23 23:18:27 | INFO | fairseq.data.data_utils | loaded 4231 examples from: /large_exp/angelafan/cc100_multilingual/eval/binarized_spm_128k_dict/ted/valid.en-es.en
2020-09-23 23:18:27 | INFO | fairseq.data.data_utils | loaded 4231 examples from:
/large_exp/angelafan/cc100_multilingual/eval/binarized_spm_128k_dict/ted/valid.en-es.es
2020-09-23 23:18:27 | INFO | fairseq.data.multilingual.multilingual_data_manager | /large_exp/angelafan/cc100_multilingual/eval/binarized_spm_128k_dict/ted valid en-es 4231 examples
2020-09-23 23:18:27 | INFO | fairseq_cli.generate | loading model(s) from /large_exp/angelafan/checkpoints/mmmt_100_langs_new_dict_5b_model/checkpoint4.pt
balance=[29,22,1], devices=[0,1,0], chunks=8, checkpoint=always
2020-09-23 23:20:30 | INFO | fairseq.tasks.translation_multi_simple_epoch | start batch sampler: mem usage: N/A
2020-09-23 23:20:30 | INFO | fairseq.tasks.translation_multi_simple_epoch | [valid] batch_sampler order indices time: 0:00:00.004368
2020-09-23 23:20:30 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: N/A
2020-09-23 23:20:30 | INFO | fairseq.tasks.translation_multi_simple_epoch | [valid] batch_sampler filter_by_size time: 0:00:00.057953
2020-09-23 23:20:30 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: N/A
2020-09-23 23:20:31 | INFO | fairseq.tasks.translation_multi_simple_epoch | [valid] batch_sampler batch_by_size time: 0:00:01.158644
2020-09-23 23:20:31 | INFO | fairseq.tasks.translation_multi_simple_epoch | [valid] per epoch batch_sampler set-up time: 0:00:01.223290
2020-09-23 23:20:31 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: N/A
S-2521 __en__ No.
T-2521 No.
H-2521 -2.4406352043151855 y no.
D-2521 -2.4406352043151855 y no.
P-2521 -5.2677 -4.4289 -0.1922 -2.1578 -0.1565
S-2261 __en__ Why?
T-2261 ¿Por qué?
H-2261 -1.7077901363372803 ¿Y por qué?
D-2261 -1.7077901363372803 ¿Y por qué?
P-2261 -5.2733 -0.7970 -3.7643 -0.5474 -0.3339 -1.0629 -0.1757 ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1297 Reviewed By: myleott, msbaines Differential Revision: D23991647 Pulled By: shruti-bh fbshipit-source-id: 81ea1af64cbe75c8b53e050d2c2339dcb70fe1eb --- fairseq/checkpoint_utils.py | 40 +-- .../multilingual/multilingual_data_manager.py | 29 ++- fairseq/dataclass/data_class.py | 40 +++ fairseq/distributed_utils.py | 20 +- .../pipeline_parallel_transformer/model.py | 240 ++++++++++-------- fairseq_cli/eval_lm.py | 6 +- fairseq_cli/generate.py | 6 +- fairseq_cli/interactive.py | 8 +- 8 files changed, 251 insertions(+), 138 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index b0109caa83..60ab3190c7 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -213,7 +213,7 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): return state -def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix=''): +def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1): """Loads an ensemble of models. 
Args: @@ -222,29 +222,37 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, s were used during model training task (fairseq.tasks.FairseqTask, optional): task to use for loading """ + assert not (strict and num_shards > 1), \ + "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble, args, _task = load_model_ensemble_and_task( - filenames, arg_overrides, task, strict, suffix, + filenames, arg_overrides, task, strict, suffix, num_shards, ) return ensemble, args -def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix=''): +def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1): from fairseq import tasks - + assert not (strict and num_shards > 1), \ + "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble = [] for filename in filenames: - filename = filename.replace(".pt", suffix + ".pt") - if not PathManager.exists(filename): - raise IOError("Model file not found: {}".format(filename)) - state = load_checkpoint_to_cpu(filename, arg_overrides) - - args = state["args"] - if task is None: - task = tasks.setup_task(args) - - # build model for ensemble - model = task.build_model(args) - model.load_state_dict(state["model"], strict=strict, args=args) + orig_filename = filename + for shard_idx in range(num_shards): + if num_shards == 1: + filename = filename.replace(".pt", suffix + ".pt") + else: + filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if not PathManager.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + state = load_checkpoint_to_cpu(filename, arg_overrides) + if shard_idx == 0: + args = state["args"] + if task is None: + task = tasks.setup_task(args) + + # build model for ensemble + model = task.build_model(args) + model.load_state_dict(state["model"], strict=strict, args=args) ensemble.append(model) return ensemble, args, task diff --git 
a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 16a4d98690..f0f93f25e1 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -195,6 +195,12 @@ def add_args(parser): type=lambda uf: eval_str_dict(uf, type=str), default=None, ) + parser.add_argument( + "--fixed-dictionary", + help='Fixed dictionary to use with model path', + default=None, + type=str, + ) parser.add_argument( "--langtoks-specs", help='a list of comma separated data types that a set of language tokens to be specialized for, \ @@ -346,16 +352,19 @@ def check_langs(langs, pairs): paths = utils.split_paths(args.data) assert len(paths) > 0 for lang in langs_to_load_dicts: - dicts[lang] = load_dictionary( - os.path.join(paths[0], "dict.{}.txt".format(lang)) - ) - augment_dictionary( - dictionary=dicts[lang], - language_list=language_list, - lang_tok_style=args.lang_tok_style, - langtoks_specs=args.langtoks_specs, - extra_data=args.extra_data, - ) + if args.fixed_dictionary is not None: + dicts[lang] = load_dictionary(args.fixed_dictionary) + else: + dicts[lang] = load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(lang)) + ) + augment_dictionary( + dictionary=dicts[lang], + language_list=language_list, + lang_tok_style=args.lang_tok_style, + langtoks_specs=args.langtoks_specs, + extra_data=args.extra_data, + ) if len(dicts) > 0: assert dicts[lang].pad() == dicts[langs_to_load_dicts[0]].pad() assert dicts[lang].eos() == dicts[langs_to_load_dicts[0]].eos() diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index d402e07106..0685c968d5 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -112,6 +112,14 @@ class CommonParams(FairseqDataclass): checkpoint_suffix: str = field( default="", metadata={"help": "suffix to add to the checkpoint file name"} ) + checkpoint_shard_count: int = field( + 
default=1, metadata={ + "help": "Number of shards containing the checkpoint - " + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + } + ) quantization_config_path: Optional[str] = field( default=None, metadata={"help": "path to quantization config file"} ) @@ -239,6 +247,38 @@ class DistributedTrainingParams(FairseqDataclass): pipeline_chunks: int = field( default=0, metadata={"help": "microbatch count for pipeline model parallelism"} ) + pipeline_encoder_balance: str = field( + default=None, + metadata={ + "help": "partition the pipeline parallel encoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_encoder_balance) " + "should equal the total number of encoder layers in the model" + }, + ) + pipeline_encoder_devices: str = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-encoder-balance argument" + }, + ) + pipeline_decoder_balance: str = field( + default=None, + metadata={ + "help": "partition the pipeline parallel decoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_decoder_balance) " + "should equal the total number of decoder layers in the model" + }, + ) + pipeline_decoder_devices: str = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. 
The length of this list should " + "equal the length of the --pipeline-decoder-balance argument" + }, + ) pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( default="never", metadata={"help": "checkpointing mode for pipeline model parallelism"}, diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 886a30dfe3..d67604f54e 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -32,16 +32,26 @@ def infer_init_method(args, force_distributed=False): return if args.pipeline_model_parallel: - if args.pipeline_balance is None: + balance_exists = args.pipeline_balance is not None or \ + args.pipeline_encoder_balance is not None or \ + args.pipeline_decoder_balance is not None + devices_exist = args.pipeline_devices is not None or \ + args.pipeline_encoder_devices is not None or \ + args.pipeline_decoder_devices is not None + if not balance_exists: raise ValueError('--pipeline-balance is currently required for pipeline model parallelism') - if args.pipeline_devices is None: + if not devices_exist: raise ValueError('--pipeline-devices is currently required for pipeline model parallelism') args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int) - args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) - + if args.pipeline_devices is not None: + args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) + num_pipeline_devices = len(set(args.pipeline_devices)) + else: + args.pipeline_encoder_devices = utils.eval_str_list(args.pipeline_encoder_devices, type=int) + args.pipeline_decoder_devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int) + num_pipeline_devices = len(set(args.pipeline_encoder_devices + args.pipeline_decoder_devices)) gpus_per_node = torch.cuda.device_count() - num_pipeline_devices = len(set(args.pipeline_devices)) assert gpus_per_node >= num_pipeline_devices and gpus_per_node % num_pipeline_devices == 0, ( 'the number of unique 
device IDs in --pipeline-devices must evenly divide ' 'the number of GPUs per node (multi-node pipelining is not yet supported)' diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index 37fa877eaf..65a087a3fb 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging + from fairseq import utils from fairseq.models import ( BaseFairseqModel, @@ -17,6 +19,7 @@ transformer_wmt_en_de_big, ) from fairseq.modules import SinusoidalPositionalEmbedding +from fairseq.models.fairseq_encoder import EncoderOut from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import ( Embedding, TransformerEncoderLayer, @@ -30,6 +33,8 @@ import torch.nn as nn import torch.nn.functional as F +logger = logging.getLogger(__name__) + DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -45,17 +50,19 @@ def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): super().__init__() assert isinstance(encoder, FairseqEncoder) assert isinstance(decoder, FairseqDecoder) - module_list = nn.Sequential( - encoder.embedding_layer, - *list(encoder.encoder_layers), - encoder.final_layer_norm, - decoder.embedding_layer, - *list(decoder.decoder_layers), - decoder.decoder_output_layer, - ) + encoder_module_list = \ + [encoder.embedding_layer] + \ + list(encoder.encoder_layers) + \ + [encoder.final_layer_norm] + self.num_encoder_modules = len(encoder_module_list) + decoder_module_list = [decoder.embedding_layer] + \ + list(decoder.decoder_layers) + \ + [decoder.decoder_output_layer] + self.num_decoder_modules = len(decoder_module_list) + module_list = encoder_module_list + decoder_module_list self.devices 
= devices self.model = Pipe( - module_list, + nn.Sequential(*module_list), balance=balance, devices=devices, chunks=chunks, @@ -70,11 +77,39 @@ def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): 'max_target_positions' ) self.adaptive_softmax = getattr(decoder, 'adaptive_softmax', None) + # Note: To be populated during inference + self.encoder = None + self.decoder = None def forward(self, src_tokens, src_lengths, prev_output_tokens): - input_lst = [src_tokens, src_lengths, prev_output_tokens] - input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) - return self.model(input) + if self.training: + input_lst = [src_tokens, src_lengths, prev_output_tokens] + input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) + return self.model(input) + else: + assert self.encoder is not None and self.decoder is not None, \ + "encoder and decoder need to be initialized by " + \ + "calling the `prepare_for_inference_()` method" + encoder_output_tuple = self.encoder(input) + return self.decoder(encoder_output_tuple) + + def prepare_for_inference_(self, args): + if self.encoder is not None and self.decoder is not None: + logger.info("Encoder and Decoder already initialized") + return + encoder_module_list = [] + decoder_module_list = [] + module_count = 0 + for partition in self.model.partitions: + for module in partition: + if module_count < self.num_encoder_modules: + encoder_module_list.append(module) + else: + decoder_module_list.append(module) + module_count += 1 + self.model = None + self.encoder = TransformerEncoder(args, None, None, encoder_module_list) + self.decoder = TransformerDecoder(args, None, None, decoder_module_list=decoder_module_list) @staticmethod def add_args(parser): @@ -214,8 +249,8 @@ def build_model(cls, args, task): return PipelineParallelTransformerModel( encoder=encoder, decoder=decoder, - balance=args.pipeline_balance, - devices=args.pipeline_devices, + 
balance=utils.eval_str_list(args.pipeline_balance, type=int), + devices=utils.eval_str_list(args.pipeline_devices, type=int), chunks=args.pipeline_chunks, checkpoint=args.pipeline_checkpoint, ) @@ -248,7 +283,9 @@ def get_normalized_probs(self, net_output, log_probs, sample=None): out = self.adaptive_softmax.get_log_prob(net_output, target=target) return out.exp_() if not log_probs else out - logits = net_output + # A Pipe() module returns a tuple of tensors as the output. + # In this case, the tuple has one element - the output tensor of logits + logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0] if log_probs: return utils.log_softmax(logits, dim=-1, onnx_trace=False) else: @@ -299,7 +336,7 @@ def convert_to_pipeline_parallel_state_dict(self, state_dict): 'final_layer_norm.weight', 'final_layer_norm.bias' ] for pid, partition in enumerate(self.model.partitions): - print(f"Begin Partition {pid}") + logger.info(f"Begin Partition {pid}") for mid, module in enumerate(partition): # fmt: off if isinstance(module, TransformerEncoderEmbedding): @@ -337,24 +374,44 @@ class TransformerEncoder(FairseqEncoder): embed_tokens (torch.nn.Embedding): input embedding """ - def __init__(self, args, dictionary, embed_tokens): + def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) - self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) - layers = [ - TransformerEncoderLayer(args) for i in range(args.encoder_layers) - ] - # Note: layer drop not supported yet - # Note: layer wise attention not supported yet - self.encoder_layers = nn.Sequential(*layers) - if isinstance(embed_tokens, nn.ModuleList): - emb_dim = sum(e.embedding_dim for e in embed_tokens) + try: + from fairscale.nn import Pipe + except ImportError: + raise ImportError('Please install fairscale with: pip install fairscale') + if encoder_module_list is None: + embedding_layer = 
TransformerEncoderEmbedding(args, embed_tokens) + layers = [ + TransformerEncoderLayer(args) for i in range(args.encoder_layers) + ] + if isinstance(embed_tokens, nn.ModuleList): + emb_dim = sum(e.embedding_dim for e in embed_tokens) + else: + emb_dim = embed_tokens.embedding_dim + final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim) + encoder_module_list = [embedding_layer] + layers + [final_layer_norm] + self.use_pipeline = (getattr(args, "pipeline_encoder_balance", None) is not None) + if self.use_pipeline: + encoder_balance = utils.eval_str_list(args.pipeline_encoder_balance, type=int) + encoder_devices = utils.eval_str_list(args.pipeline_encoder_devices, type=int) + assert sum(encoder_balance) == len(encoder_module_list), \ + f"Sum of encoder_balance={encoder_balance} is not equal " + \ + f"to num_encoder_modules={len(encoder_module_list)}" + self.model = Pipe( + module=nn.Sequential(*encoder_module_list), + balance=encoder_balance, + devices=encoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) else: - emb_dim = embed_tokens.embedding_dim - self.final_layer_norm = \ - TransformerEncoderLayerNorm(args, emb_dim) + self.embedding_layer = encoder_module_list[0] + self.encoder_layers = nn.Sequential(*encoder_module_list[1:-1]) + self.final_layer_norm = encoder_module_list[-1] - def forward(self, src_tokens, src_lengths, prev_output_tokens): + def forward(self, src_tokens, src_lengths): """ Args: input_tuple( @@ -362,7 +419,6 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` - prev_output_tokens ) Returns: @@ -377,10 +433,20 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): Only populated if *return_all_hiddens* is True. 
) """ - input_tuple = (src_tokens, src_lengths, prev_output_tokens) - encoder_embed_output_tuple = self.embedding_layer(input_tuple) - encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) - return self.final_layer_norm(encoder_layers_output) + dummy_prev_output_tokens = torch.zeros(1, dtype=src_tokens.dtype, device=src_tokens.device) + input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + encoder_out = self.model(input_tuple) + else: + encoder_embed_output_tuple = self.embedding_layer(input_tuple) + encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) + encoder_out = self.final_layer_norm(encoder_layers_output) + # first element is the encoder output + # second element is the encoder padding mask + # the remaining elements of EncoderOut are not computed by + # the PipelineParallelTransformer + return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None) def reorder_encoder_out(self, encoder_out, new_order): """ @@ -417,34 +483,6 @@ def max_positions(self): return min(self.embedding_layer.max_source_positions, self.embedding_layer.embed_positions.max_positions) - def buffered_future_mask(self, tensor): - dim = tensor.size(0) - if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: - self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) - if self._future_mask.size(0) < dim: - self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) - return self._future_mask[:dim, :dim] - - def upgrade_state_dict_named(self, state_dict, name): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" - if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): - weights_key = '{}.embed_positions.weights'.format(name) - if weights_key in state_dict: - print('deleting 
{0}'.format(weights_key)) - del state_dict[weights_key] - state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) - for i in range(len(self.layers)): - # update layer norms - self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) - - version_key = '{}.version'.format(name) - if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: - # earlier checkpoints did not normalize after the stack of layers - self.layer_norm = None - self.normalize = False - state_dict[version_key] = torch.Tensor([1]) - return state_dict - class TransformerDecoder(FairseqDecoder): """ @@ -459,16 +497,39 @@ class TransformerDecoder(FairseqDecoder): (default: False). """ - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, decoder_module_list=None): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) - self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) - layers = [ - TransformerDecoderLayer(args, no_encoder_attn) - for _ in range(args.decoder_layers) - ] - self.decoder_layers = nn.Sequential(*layers) - self.decoder_output_layer = TransformerDecoderOutputLayer(args, embed_tokens, dictionary) + try: + from fairscale.nn import Pipe + except ImportError: + raise ImportError('Please install fairscale with: pip install fairscale') + if decoder_module_list is None: + embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) + layers = [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + decoder_output_layer = TransformerDecoderOutputLayer(args, embed_tokens, dictionary) + decoder_module_list = [embedding_layer] + layers + [decoder_output_layer] + self.use_pipeline = (getattr(args, "pipeline_decoder_balance", None) is not None) + if self.use_pipeline: + decoder_balance = utils.eval_str_list(args.pipeline_decoder_balance, type=int) + 
decoder_devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int) + assert sum(decoder_balance) == len(decoder_module_list), \ + f"Sum of decoder_balance={decoder_balance} is not equal " + \ + f"to num_decoder_modules={len(decoder_module_list)}" + self.model = Pipe( + module=nn.Sequential(*decoder_module_list), + balance=decoder_balance, + devices=decoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.embedding_layer = decoder_module_list[0] + self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1]) + self.decoder_output_layer = decoder_module_list[-1] def forward(self, prev_output_tokens, encoder_out=None,): """ @@ -487,35 +548,14 @@ def forward(self, prev_output_tokens, encoder_out=None,): - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ - input = (prev_output_tokens, encoder_out) - embed_layer_output = self.embedding_layer(input) - state = self.decoder_layers(embed_layer_output) - return self.decoder_output_layer(state) - - def extract_features(self, prev_output_tokens, encoder_out=None,): - """ - Similar to *forward* but only return features. - - Includes several features from "Jointly Learning to Align and - Translate with Transformer Models" (Garg et al., EMNLP 2019). - - Args: - full_context_alignment (bool, optional): don't apply - auto-regressive mask to self-attention (default: False). - alignment_layer (int, optional): return mean alignment over - heads at this layer (default: last layer). - alignment_heads (int, optional): only average alignment over - this many heads (default: all heads). 
- - Returns: - tuple: - - the decoder's features of shape `(batch, tgt_len, embed_dim)` - - a dictionary with any model-specific outputs - """ - input = (prev_output_tokens, encoder_out) - embed_layer_output = self.embedding_layer(input) - state = self.decoder_layers(embed_layer_output) - return self.decoder_output_layer(state, apply_final_proj=False) + input_tuple = (encoder_out.encoder_out, encoder_out.encoder_padding_mask, prev_output_tokens) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + return (self.model(input_tuple), ) + else: + embed_layer_output = self.embedding_layer(input_tuple) + state = self.decoder_layers(embed_layer_output) + return (self.decoder_output_layer(state), ) def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 66217f64fc..64c83673e6 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -77,6 +77,8 @@ def main(parsed_args, **unused_kwargs): arg_overrides=eval(parsed_args.model_overrides), task=task, suffix=getattr(parsed_args, "checkpoint_suffix", ""), + strict=(parsed_args.checkpoint_shard_count == 1), + num_shards=parsed_args.checkpoint_shard_count, ) for arg in vars(parsed_args).keys(): @@ -104,11 +106,11 @@ def main(parsed_args, **unused_kwargs): # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) for model in models: - model.prepare_for_inference_(args) if args.fp16: model.half() - if use_cuda: + if use_cuda and not args.pipeline_model_parallel: model.cuda() + model.prepare_for_inference_(args) assert len(models) > 0 diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 0cf09feaee..15b0552c3c 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -89,6 +89,8 @@ def _main(args, output_file): arg_overrides=overrides, task=task, suffix=getattr(args, "checkpoint_suffix", ""), + 
strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, ) if args.lm_path is not None: @@ -113,11 +115,11 @@ def _main(args, output_file): for model in chain(models, lms): if model is None: continue - model.prepare_for_inference_(args) if args.fp16: model.half() - if use_cuda: + if use_cuda and not args.pipeline_model_parallel: model.cuda() + model.prepare_for_inference_(args) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index dd6249bd90..fc4b46e39d 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -22,7 +22,7 @@ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders from fairseq.token_generation_constraints import pack_constraints, unpack_constraints -from .generate import get_symbols_to_strip_from_output +from fairseq_cli.generate import get_symbols_to_strip_from_output logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', @@ -139,6 +139,8 @@ def main(args): arg_overrides=eval(args.model_overrides), task=task, suffix=getattr(args, "checkpoint_suffix", ""), + strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, ) # Set dictionaries @@ -147,11 +149,11 @@ def main(args): # Optimize ensemble for generation for model in models: - model.prepare_for_inference_(args) if args.fp16: model.half() - if use_cuda: + if use_cuda and not args.pipeline_model_parallel: model.cuda() + model.prepare_for_inference_(args) # Initialize generator generator = task.build_generator(models, args) From a9baca376616bed56e5df5115d7adf8059c0d296 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Mon, 12 Oct 2020 11:23:44 -0700 Subject: [PATCH 198/707] add supports_fetch_outside_dataloader property to avoid creating real first batches for datasets that only expect to be 
used inside dataloader workers Summary: D24093421 (https://github.com/pytorch/fairseq/commit/e056de1fb641425974b0941704e63b9b42fe8a1f) added `first_batch` to iterators in fairseq. This means that FairseqDataset objects now might need to fetch data outside the dataloader workers. This causes issues with certain datasets, in particular datasets that fetch data via everstore/memcache, since these clients open a ton of file descriptors based on how many items are fetched. Opening too many file descriptors causes forking to fail in Python multiprocessing. To fix this, let's have a property `supports_fetch_outside_dataloader` in the FairseqDataset that allows us to decide if it is safe to fetch the first batch. If it is not, we will revert to the original behavior before D24093421 (https://github.com/pytorch/fairseq/commit/e056de1fb641425974b0941704e63b9b42fe8a1f), which is to just use "DUMMY", and set this as a real batch later. Reviewed By: yqwangustc Differential Revision: D24234470 fbshipit-source-id: 7ad66a6de622ce26f59f00d00b19700fbd992921 --- fairseq/data/fairseq_dataset.py | 5 +++++ fairseq/data/iterators.py | 5 ++++- fairseq/trainer.py | 11 +++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index a4a0985210..caaef8f713 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -166,6 +166,11 @@ def filter_indices_by_size(self, indices, max_sizes): indices, ignored = data_utils._filter_by_size_dynamic(indices, self.size, max_sizes) return indices, ignored + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return True + class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): """ diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 729a3141d1..5b2fc219c4 100644 --- a/fairseq/data/iterators.py +++
b/fairseq/data/iterators.py @@ -261,7 +261,10 @@ def first_batch(self): "a larger dataset." ) - return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) + if self.dataset.supports_fetch_outside_dataloader: + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) + else: + return "DUMMY" def __len__(self): return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d3724f22b1..5d68783bfb 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -448,6 +448,8 @@ def train_step(self, samples, raise_oom=False): sample = self._prepare_sample(self._dummy_batch) is_dummy_batch = True else: + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample is_dummy_batch = False def maybe_no_sync(): @@ -670,6 +672,8 @@ def valid_step(self, sample, raise_oom=False): sample = self._prepare_sample(self._dummy_batch) is_dummy_batch = True else: + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample is_dummy_batch = False try: @@ -804,6 +808,13 @@ def _local_cumulative_training_time(self): return time.time() - self._start_time + self._previous_training_time def _prepare_sample(self, sample): + if sample == "DUMMY": + raise Exception( + "Trying to use an uninitialized 'dummy' batch. This usually indicates " + "that the total number of batches is smaller than the number of " + "participating GPUs. Try reducing the batch size or using fewer GPUs." + ) + if sample is None or len(sample) == 0: return None From e0d5d8e669528be579d7aa4749fbcfe5cacdce90 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 12 Oct 2020 14:01:46 -0700 Subject: [PATCH 199/707] refactor build_generator to reduce code duplication. (#2716) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2693 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2716 Reviewed By: kahne Differential Revision: D24243384 Pulled By: myleott fbshipit-source-id: cdf4fb3b97d87dd8dbb0ea7cdb5f286277892d81 --- examples/speech_recognition/infer.py | 20 ++----------------- .../tasks/speech_recognition.py | 4 ++++ 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 950490d5f5..b27cf5add5 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -244,6 +244,7 @@ def main(args, task=None, model_state=None): # Load dataset splits task = tasks.setup_task(args) task.load_dataset(args.gen_subset) + logger.info( "| {} {} {} examples".format( args.data, args.gen_subset, len(task.dataset(args.gen_subset)) @@ -281,24 +282,7 @@ def main(args, task=None, model_state=None): # Initialize generator gen_timer = StopwatchMeter() - def build_generator(args): - w2l_decoder = getattr(args, "w2l_decoder", None) - if w2l_decoder == "viterbi": - from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder - - return W2lViterbiDecoder(args, task.target_dictionary) - elif w2l_decoder == "kenlm": - from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder - - return W2lKenLMDecoder(args, task.target_dictionary) - elif w2l_decoder == "fairseqlm": - from examples.speech_recognition.w2l_decoder import
W2lFairseqLMDecoder - - return W2lFairseqLMDecoder(args, task.target_dictionary) - else: - return super().build_generator(args) - - generator = build_generator(args) + generator = task.build_generator(models, args) if args.load_emissions: generator = ExistingEmissionsDecoder( diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py index 1181c9aef5..769ce4ff54 100644 --- a/examples/speech_recognition/tasks/speech_recognition.py +++ b/examples/speech_recognition/tasks/speech_recognition.py @@ -123,6 +123,10 @@ def build_generator(self, models, args, **unused): from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder return W2lKenLMDecoder(args, self.target_dictionary) + elif w2l_decoder == "fairseqlm": + from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(args, self.target_dictionary) else: return super().build_generator(models, args) From fc1c38aa1c70e1d1ef45a6af335e3c6571ba436d Mon Sep 17 00:00:00 2001 From: Sharvil Nanavati Date: Mon, 12 Oct 2020 15:19:50 -0700 Subject: [PATCH 200/707] Fix broken links to Wav2Vec2 Large checkpoints (#2657) Summary: Typo fixes. 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2657 Reviewed By: alexeib Differential Revision: D24237518 Pulled By: myleott fbshipit-source-id: d101a8b8cc9c8d725eb63265c85be92f6b2c5a6c --- examples/wav2vec/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index e6db6c6796..518d8f86cb 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -10,10 +10,10 @@ Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [d Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) -Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) -Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) -Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) -Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https//dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) +Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) +Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) +Wav2Vec 2.0 Large | 100 hours | 
[Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) +Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) Wav2Vec 2.0 Large (LV-60) | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox.pt) Wav2Vec 2.0 Large (LV-60) | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m.pt) Wav2Vec 2.0 Large (LV-60) | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h.pt) From f910ea9d4cf9c9964ec307dde3144622c4b61e62 Mon Sep 17 00:00:00 2001 From: Helen Craig Date: Tue, 13 Oct 2020 08:06:03 -0700 Subject: [PATCH 201/707] Update sequence_scorer.py (#2715) Summary: Changing so attention is returned for joint alignment example. 
related to this issue: https://github.com/pytorch/fairseq/issues/2695 And this one: https://github.com/pytorch/fairseq/issues/2634 Pull Request resolved: https://github.com/pytorch/fairseq/pull/2715 Reviewed By: pipibjc Differential Revision: D24237512 Pulled By: myleott fbshipit-source-id: 2b2be8002ab20b89fd6a8ef6e9d2b74063c5c7c8 --- fairseq/sequence_scorer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py index 343c29acc2..c8ded1930c 100644 --- a/fairseq/sequence_scorer.py +++ b/fairseq/sequence_scorer.py @@ -87,8 +87,11 @@ def gather_target_probs(probs, target): avg_probs = probs else: avg_probs.add_(probs) - if attn is not None and torch.is_tensor(attn): - attn = attn.data + if attn is not None: + if torch.is_tensor(attn): + attn = attn.data + else: + attn = attn[0] if avg_attn is None: avg_attn = attn else: From c005362349075fb5952ece139481232cc49e2286 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 13 Oct 2020 08:07:20 -0700 Subject: [PATCH 202/707] Misc fixes (#1341) Summary: - Fix all_reduce_dict for non-homogenous sized tensors (fixes #2707) - Add extra docs for tutorial-specific code (fixes #2554) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1341 Reviewed By: pipibjc Differential Revision: D24264129 Pulled By: myleott fbshipit-source-id: 58619bb34afe51ea956abe8e6b41505f35417c09 --- fairseq/distributed_utils.py | 6 ++++-- fairseq/models/fairseq_model.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index d67604f54e..ab5aad1425 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -384,9 +384,11 @@ def all_reduce_dict( def _all_reduce_dict(data: OrderedDict): if len(data) == 0: return data - buf = torch.stack(list(data.values())).to(device=device) + buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device) all_reduce(buf, 
group=group) - return {k: buf[i] for i, k in enumerate(data)} + split_buf = torch.split(buf, [t.numel() for t in data.values()]) + reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())] + return OrderedDict(zip(data.keys(), reduced_data)) cpu_data = _all_reduce_dict(cpu_data) device_data = _all_reduce_dict(device_data) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 4a2b6b66b9..facb7d011b 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -69,6 +69,8 @@ def get_normalized_probs_scriptable( if hasattr(self, "decoder"): return self.decoder.get_normalized_probs(net_output, log_probs, sample) elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) logits = net_output.float() if log_probs: return F.log_softmax(logits, dim=-1) From d6cdc2f47b74e3126df748c0da02be43d7356a07 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Tue, 13 Oct 2020 08:07:57 -0700 Subject: [PATCH 203/707] bump fairseq version in preparation for a release. (#2717) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/1948 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding. This may be something to discuss - it seems like there's a lot of confusion about what features are supported when in fairseq.
Hopefully versioning will allow for more discrete cuts. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2717 Reviewed By: pipibjc Differential Revision: D24244678 Pulled By: myleott fbshipit-source-id: 4d4c7bd13387c43fb11c64d7e62985d212b5a02a --- fairseq/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/__init__.py b/fairseq/__init__.py index f7d7793349..dfa4ef7898 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. __all__ = ['pdb'] -__version__ = '0.9.0' +__version__ = '1.0.0a0' import sys From 85f097141d83d6aac378838b6c0c8f2a0f77154f Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Tue, 13 Oct 2020 12:32:58 -0700 Subject: [PATCH 204/707] Remove FileContentsAction for --langs Summary: # Facebook: Revert changes made in D24059296 (https://github.com/pytorch/fairseq/commit/0557ed8b0df90fe671bcb745f384ef7fd0386ab3) since it breaks normal --langs usage. To specify languages in a file, use --lang-dict instead Also update integration test params so it can catch this Reviewed By: tangyuq Differential Revision: D24224622 fbshipit-source-id: 292eeb86e02528128ced09f8165045be9c847c19 --- fairseq/data/multilingual/multilingual_data_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index f0f93f25e1..806e4c360d 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -87,7 +87,6 @@ def add_args(parser): type=csv_str_list, help="a list of languages comma sperated languages which can appear in lang-pairs; " "note that the ordering determines language token IDs", - action=FileContentsAction, ) parser.add_argument( "--lang-dict", From 086fe1c5d1317caad090b2ff60f965d2dfa130f7 Mon Sep 17 00:00:00 2001 From: Nicola De Cao Date: Wed, 14 Oct 2020 08:28:12 -0700 
Subject: [PATCH 205/707] adding search.PrefixConstrainedBeamSearch (#2646) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? This adds a new decoding strategy, `search.PrefixConstrainedBeamSearch`, that restricts the vocabulary available for the next generated token given a prefix (that is, the tokens previously generated during beam search). An end user only has to pass the optional argument `prefix_allowed_tokens_fn` to `.generate` or `.sample` to activate `PrefixConstrainedBeamSearch`. `prefix_allowed_tokens_fn(batch_id, tokens)` is a callback function that, given the `batch_id` and `tokens`, returns the list of allowed tokens for the next generation step. ## Did you have fun? YES!
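The `prefix_allowed_tokens_fn(batch_id, tokens)` callback described under "What does this PR do?" could look roughly like the sketch below. This is an illustration only, not code from the PR: the token ids, the `ALLOWED_SEQUENCES` table, and the `mask_lprobs` helper are hypothetical names. It also assumes `tokens` holds only the tokens generated so far, whereas the generator in this patch passes `tokens[:, : step + 1]`, which may include an initial token.

```python
import math

# Hypothetical toy setup (not from fairseq): each batch item may only
# produce one of a fixed set of token-id sequences.
EOS = 2
ALLOWED_SEQUENCES = {
    0: [[5, 6, EOS], [5, 7, EOS]],
    1: [[8, EOS]],
}


def prefix_allowed_tokens_fn(batch_id, tokens):
    """Return the token ids that may follow the given prefix."""
    prefix = [int(t) for t in tokens]
    allowed = set()
    for seq in ALLOWED_SEQUENCES[batch_id]:
        # A sequence contributes its next token only if it starts with
        # the prefix and has at least one token left.
        if seq[: len(prefix)] == prefix and len(seq) > len(prefix):
            allowed.add(seq[len(prefix)])
    return sorted(allowed)


def mask_lprobs(lprobs, batch_id, tokens):
    """Illustrative helper: set disallowed vocabulary entries to -inf."""
    allowed = set(prefix_allowed_tokens_fn(batch_id, tokens))
    return [lp if i in allowed else -math.inf for i, lp in enumerate(lprobs)]
```

In the patch itself the same effect is obtained with tensors: `apply_mask` builds a mask filled with `-math.inf`, sets the allowed positions to 0, and `step` adds that mask to `lprobs`. Note that if the callback returns an empty list, every entry for that hypothesis stays at `-inf`.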
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2646 Reviewed By: fabiopetroni Differential Revision: D24006805 Pulled By: myleott fbshipit-source-id: 40b1a866c6ea9f936272db27e2a020b18dbf8164 --- fairseq/search.py | 209 ++++++++++++++++++++++++++++------ fairseq/sequence_generator.py | 12 ++ fairseq/tasks/fairseq_task.py | 5 + 3 files changed, 189 insertions(+), 37 deletions(-) diff --git a/fairseq/search.py b/fairseq/search.py index ecb4764a82..2c21b66bbd 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -10,7 +10,11 @@ import torch.nn as nn from torch import Tensor -from fairseq.token_generation_constraints import ConstraintState, UnorderedConstraintState, OrderedConstraintState +from fairseq.token_generation_constraints import ( + ConstraintState, + UnorderedConstraintState, + OrderedConstraintState, +) class Search(nn.Module): @@ -22,8 +26,11 @@ def __init__(self, tgt_dict): self.vocab_size = len(tgt_dict) self.src_lengths = torch.tensor(-1) self.supports_constraints = False + self.stop_on_max_len = False - def step(self, step, lprobs, scores): + def step( + self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None + ): """Take a single search step.
Args: @@ -32,6 +39,12 @@ def step(self, step, lprobs, scores): the model's log-probabilities over the vocabulary at the current step scores: (bsz x input_beam_size x step) the historical model scores of each hypothesis up to this point + prev_output_tokens: (bsz x step) + the previously generated output tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the original indices Return: A tuple of (scores, indices, beams) where: scores: (bsz x output_beam_size) @@ -94,7 +107,14 @@ def __init__(self, tgt_dict): self.constraint_states = None @torch.jit.export - def step(self, step: int, lprobs, scores: Optional[Tensor]): + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): bsz, beam_size, vocab_size = lprobs.size() if step == 0: @@ -125,6 +145,69 @@ def step(self, step: int, lprobs, scores: Optional[Tensor]): return scores_buf, indices_buf, beams_buf +class PrefixConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, prefix_allowed_tokens_fn): + super().__init__(tgt_dict) + self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self.stop_on_max_len = True + + @torch.jit.export + def apply_mask(self, x, prev_output_tokens, original_batch_idxs): + beam_size = x.shape[0] // original_batch_idxs.shape[0] + original_batch_idxs = ( + original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist() + ) + + mask = torch.full_like(x, -math.inf) + for sent_i, (sent, batch_i) in enumerate( + zip(prev_output_tokens, original_batch_idxs) + ): + mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 + + return mask + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Tensor, + prev_output_tokens: Tensor, + original_batch_idxs: Tensor, + ): + bsz, beam_size, vocab_size = lprobs.size()
+ + lprobs += self.apply_mask( + lprobs.view(bsz * beam_size, 1, vocab_size), + prev_output_tokens, + original_batch_idxs, + ).view(bsz, beam_size, vocab_size) + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + return scores_buf, indices_buf, beams_buf + + class LexicallyConstrainedBeamSearch(Search): """Implements lexically constrained beam search as described in @@ -143,6 +226,7 @@ class LexicallyConstrainedBeamSearch(Search): constraints have been generated and using this information to shape the beam for each input sentence. 
""" + def __init__(self, tgt_dict, representation): super().__init__(tgt_dict) self.representation = representation @@ -163,17 +247,28 @@ def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): @torch.jit.export def prune_sentences(self, batch_idxs: Tensor): - self.constraint_states = [self.constraint_states[i] for i in batch_idxs.tolist()] + self.constraint_states = [ + self.constraint_states[i] for i in batch_idxs.tolist() + ] @torch.jit.export def update_constraints(self, active_hypos: Tensor): if self.constraint_states: batch_size = active_hypos.size(0) for sentid in range(batch_size): - self.constraint_states[sentid] = [self.constraint_states[sentid][i] for i in active_hypos[sentid]] + self.constraint_states[sentid] = [ + self.constraint_states[sentid][i] for i in active_hypos[sentid] + ] @torch.jit.export - def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]): + def step( + self, + step: int, + lprobs: Tensor, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): """ A constrained step builds a large candidates list from the following: - the top 2 * {beam_size} items over the whole beam @@ -222,7 +317,9 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]): not_finished_indices.append(index) not_finished_indices = torch.tensor(not_finished_indices) if not_finished_indices.numel() > 0: - lprobs.view(batch_size * beam_size, -1)[not_finished_indices, self.eos] = -math.inf + lprobs.view(batch_size * beam_size, -1)[ + not_finished_indices, self.eos + ] = -math.inf if step == 0: # at the first step all hypotheses are equally likely, so use @@ -265,13 +362,15 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]): new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() for sentno, states in enumerate(constraint_states): 
- scores, indices, beams, new_states = self.step_sentence(step, - sentno, - lprobs[sentno], - constraint_states[sentno], - beams_buf[sentno].clone(), - indices_buf[sentno].clone(), - scores_buf[sentno].clone()) + scores, indices, beams, new_states = self.step_sentence( + step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone(), + ) new_scores_buf[sentno] = scores new_indices_buf[sentno] = indices new_beams_buf[sentno] = beams @@ -280,14 +379,16 @@ def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]): return new_scores_buf, new_indices_buf, new_beams_buf @torch.jit.export - def step_sentence(self, - step: int, - sentno: int, - lprobs: Tensor, - constraint_states: List[List[ConstraintState]], - beams_buf: Tensor, - indices_buf: Tensor, - scores_buf: Tensor): + def step_sentence( + self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor, + ): """Does per-sentence processing. Adds all constraints for each hypothesis to the list of candidates; then removes duplicates, sorts, and dynamically stripes across the banks. 
All tensor inputs @@ -300,7 +401,11 @@ def step_sentence(self, next_tokens = torch.tensor(list(state.next_tokens()), device=device).long() if next_tokens.numel() != 0: indices_buf = torch.cat((indices_buf, next_tokens)) - next_beams = torch.tensor(beamno, device=device).repeat(next_tokens.size(0)).long() + next_beams = ( + torch.tensor(beamno, device=device) + .repeat(next_tokens.size(0)) + .long() + ) beams_buf = torch.cat((beams_buf, next_beams)) next_values = lprobs[beamno].take(next_tokens.view(-1)) scores_buf = torch.cat((scores_buf, next_values)) @@ -320,8 +425,10 @@ def step_sentence(self, # Compute the new states for all candidates cands_size = indices_buf.size(0) - constraint_states = [constraint_states[beams_buf[i]].advance(indices_buf[i]) - for i in range(cands_size)] + constraint_states = [ + constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size) + ] banks = torch.tensor([state.bank for state in constraint_states], device=device) @@ -357,7 +464,7 @@ def roll(t): # This is then shifted by 1. We can then easily identify # duplicates and create a mask that identifies unique # extensions. - uniques_mask = (beams_buf * (self.vocab_size + 1) + indices_buf) + uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf uniques_mask = roll(uniques_mask) != uniques_mask # Use the mask to pare down the data structures @@ -410,9 +517,9 @@ def roll(t): constraint_states = [constraint_states[i] for i in sort_indices] # STEP 8: Truncate to the candidates size! 
- scores_buf = scores_buf[:self.num_cands] - indices_buf = indices_buf[:self.num_cands] - beams_buf = beams_buf[:self.num_cands] + scores_buf = scores_buf[: self.num_cands] + indices_buf = indices_buf[: self.num_cands] + beams_buf = beams_buf[: self.num_cands] return scores_buf, indices_buf, beams_buf, constraint_states @@ -427,7 +534,14 @@ def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): self.beam = BeamSearch(tgt_dict) self.needs_src_lengths = True - def step(self, step: int, lprobs, scores): + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): min_lens = self.min_len_a * self.src_lengths + self.min_len_b max_lens = self.max_len_a * self.src_lengths + self.max_len_b lprobs[step < min_lens, :, self.eos] = -math.inf @@ -452,7 +566,14 @@ def __init__(self, tgt_dict, num_groups, diversity_strength): self.beam = BeamSearch(tgt_dict) @torch.jit.export - def step(self, step: int, lprobs, scores): + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): bsz, beam_size, vocab_size = lprobs.size() if beam_size % self.num_groups != 0: raise ValueError( @@ -553,7 +674,14 @@ def _sample_topp(self, lprobs): return trimed_probs, truncated_indices @torch.jit.export - def step(self, step: int, lprobs, scores): + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): bsz, beam_size, vocab_size = lprobs.size() if step == 0: @@ -576,7 +704,9 @@ def step(self, step: int, lprobs, scores): # sample if step == 0: indices_buf = torch.multinomial( - probs.view(bsz, -1), beam_size, replacement=True, + probs.view(bsz, -1), + beam_size, + replacement=True, ).view(bsz, beam_size) else: indices_buf = torch.multinomial( @@ -590,9 +720,7 @@ def step(self, step: int, lprobs, scores): 
probs = probs.expand(bsz, beam_size, -1) # gather scores - scores_buf = torch.gather( - probs, dim=2, index=indices_buf.unsqueeze(-1) - ) + scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1)) scores_buf = scores_buf.log_().view(bsz, -1) # remap indices if using top-k or top-P sampling @@ -635,7 +763,14 @@ def __init__(self, tgt_dict, diversity_rate): self.diversity_rate = diversity_rate self.beam = BeamSearch(tgt_dict) - def step(self, step: int, lprobs, scores): + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): bsz, beam_size, vocab_size = lprobs.size() k = min( # Take the best 2 x beam_size predictions. We'll choose the first diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index ff45c7dfb7..7ce797746f 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -276,6 +276,13 @@ def _generate( reorder_state: Optional[Tensor] = None batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + for step in range(max_len + 1): # one extra step for EOS marker # reorder decoder internal states based on the prev choice of beams # print(f'step: {step}') @@ -288,6 +295,7 @@ def _generate( reorder_state.view(-1, beam_size).add_( corr.unsqueeze(-1) * beam_size ) + original_batch_idxs = original_batch_idxs[batch_idxs] self.model.reorder_incremental_state(incremental_states, reorder_state) encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state @@ -358,6 +366,8 @@ def _generate( step, lprobs.view(bsz, -1, self.vocab_size), scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, ) # cand_bbsz_idx contains beam indices for the top candidate @@ -401,6 +411,8 @@ def 
_generate( assert num_remaining_sent >= 0 if num_remaining_sent == 0: break + if self.search.stop_on_max_len and step >= max_len: + break assert step < max_len # Remove finalized sentences (ones for which {beam_size} diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 2aa6c8ff28..1ce4ab1921 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -313,6 +313,7 @@ def build_generator( match_source_len = getattr(args, "match_source_len", False) diversity_rate = getattr(args, "diversity_rate", -1) constrained = getattr(args, "constraints", False) + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) if ( sum( int(cond) @@ -356,6 +357,10 @@ def build_generator( search_strategy = search.LexicallyConstrainedBeamSearch( self.target_dictionary, args.constraints ) + elif prefix_allowed_tokens_fn: + search_strategy = search.PrefixConstrainedBeamSearch( + self.target_dictionary, prefix_allowed_tokens_fn + ) else: search_strategy = search.BeamSearch(self.target_dictionary) From 5e831033069b52b09905e0bf8ba104d016e04efd Mon Sep 17 00:00:00 2001 From: Vasiliy Alekseev Date: Wed, 14 Oct 2020 09:30:29 -0700 Subject: [PATCH 206/707] Fix apply_sparse_mask signature in MultiheadAttention (#2587) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2574 (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2587 Reviewed By: myleott Differential Revision: D23607694 Pulled By: lematt1991 fbshipit-source-id: b8fd27cf9a4fc4287f333a4422ad43fa93128615 --- examples/speech_recognition/models/vggtransformer.py | 6 +++--- fairseq/modules/multihead_attention.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/speech_recognition/models/vggtransformer.py b/examples/speech_recognition/models/vggtransformer.py index 9191a3ffe2..3a5db5bb16 100644 --- a/examples/speech_recognition/models/vggtransformer.py +++ b/examples/speech_recognition/models/vggtransformer.py @@ -394,9 +394,9 @@ def validate_transformer_config(self, transformer_config): input_dim, num_heads = config[:2] if input_dim % num_heads != 0: msg = ( - "ERROR in transformer config {}:".format(config) + "ERROR in transformer config {}: ".format(config) + "input dimension {} ".format(input_dim) - + "not dividable by number of heads".format(num_heads) + + "not dividable by number of heads {}".format(num_heads) ) raise ValueError(msg) @@ -459,7 +459,7 @@ def parse_transformer_sampling(self, transformer_sampling, num_layers): if len(transformer_sampling) != num_layers: raise ValueError( "transformer_sampling {} does not match with the number " - + "of layers {}".format(transformer_sampling, num_layers) + "of layers {}".format(transformer_sampling, num_layers) ) for layer, value in enumerate(transformer_sampling): diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index e33dd450ee..90b635af2b 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -317,7 +317,7 @@ def forward( ) attn_weights = torch.bmm(q, k.transpose(1, 2)) - attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) assert
list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] @@ -442,7 +442,7 @@ def _set_input_buffer( ): return self.set_incremental_state(incremental_state, "attn_state", buffer) - def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int): + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): return attn_weights def upgrade_state_dict_named(self, state_dict, name): From 4948d890a4701170c5a84b62cfd310e08af39273 Mon Sep 17 00:00:00 2001 From: Louis Martin Date: Wed, 14 Oct 2020 09:46:03 -0700 Subject: [PATCH 207/707] Fix link in CamemBERT readme (#2722) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fix link in CamemBERT readme ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2722 Reviewed By: louismartin Differential Revision: D24307327 Pulled By: myleott fbshipit-source-id: c3c29a19de06a8062fa7f7212ad6df0d549ad25f --- examples/camembert/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/camembert/README.md b/examples/camembert/README.md index 9e41c57823..5ef4fe3f15 100644 --- a/examples/camembert/README.md +++ b/examples/camembert/README.md @@ -30,7 +30,7 @@ camembert.eval() # disable dropout (or leave in train mode to finetune) ##### Load CamemBERT (for PyTorch 1.0 or custom models): ```python # Download camembert model -wget https://dl.fbaipublicfiles.com/fairseq/models/camembert.tar.gz +wget https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz tar -xzvf camembert.tar.gz # Load the model in fairseq From c4d322ad9d3e6907e54976289d95c3d9b571a5c3 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 14 Oct 2020 10:30:28 -0700 Subject: [PATCH 208/707] Fix library usage of --user-dir (primarily affects tests) (#1346) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1346 Reviewed By: xianxl Differential Revision: D24306363 Pulled By: myleott fbshipit-source-id: 90c4b59031f04b925ad12a13a96d9225ab0a09b4 --- fairseq/models/huggingface/hf_gpt2.py | 18 +++++------------- fairseq/utils.py | 26 ++++++++++++++------------ 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/fairseq/models/huggingface/hf_gpt2.py b/fairseq/models/huggingface/hf_gpt2.py index 6a03406ef6..e81954ff65 100644 --- a/fairseq/models/huggingface/hf_gpt2.py +++ b/fairseq/models/huggingface/hf_gpt2.py @@ -20,10 +20,10 @@ # Prepend the transformers submodule to the path, so that # it's prioritized over other installations. This allows # making local changes in the submodule.
- sys.path.insert( - 0, os.path.join(os.path.dirname(__file__), 'transformers', 'src') - ) - from transformers import AutoModel, GPT2Config, GPT2LMHeadModel + hf_path = os.path.join(os.path.dirname(__file__), 'transformers', 'src') + sys.path.insert(0, hf_path) + from transformers import GPT2Config, GPT2LMHeadModel + sys.path.remove(hf_path) has_hf = True except ImportError: has_hf = False @@ -78,15 +78,7 @@ class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder): def __init__(self, args, task): super().__init__(task.target_dictionary) - try: - # Prepend the transformers submodule to the path, so that - # it's prioritized over other installations. This allows - # making local changes in the submodule. - sys.path.insert( - 0, os.path.join(os.path.dirname(__file__), 'transformers', 'src') - ) - from transformers import GPT2Config, GPT2LMHeadModel - except ImportError: + if not has_hf: raise ImportError( '\n\nPlease install huggingface/transformers with:' '\n\n pip install transformers' diff --git a/fairseq/utils.py b/fairseq/utils.py index 4de258d9a2..1a18bf5e6c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -6,10 +6,11 @@ import argparse import contextlib import copy -import importlib.util +import importlib import logging import os import sys +import tempfile import warnings from itertools import accumulate from typing import Callable, Dict, List, Optional @@ -437,18 +438,19 @@ def import_user_module(args): ) if os.path.exists(fairseq_rel_path): module_path = fairseq_rel_path - module_parent, module_name = os.path.split(module_path) - if module_name in sys.modules: - module_bak = sys.modules[module_name] - del sys.modules[module_name] - else: - module_bak = None - sys.path.insert(0, module_parent) - importlib.import_module(module_name) - sys.modules["fairseq_user_dir"] = sys.modules[module_name] - if module_bak is not None and module_name != "fairseq_user_dir": - sys.modules[module_name] = module_bak + # We want to import the module under a unique name so 
that it doesn't + # collide with existing modules. At the same time we don't want to + # import the module multiple times. The solution is to create a + # temporary directory and symlink the user_dir under a new name, which is + # a deterministic hash of the original module_path. + with tempfile.TemporaryDirectory() as tmpdirname: + unique_mod_name = 'fairseq_user_dir_{}'.format(hash(module_path) % 100000) + os.symlink(module_path, os.path.join(tmpdirname, unique_mod_name)) + + sys.path.insert(0, tmpdirname) + importlib.import_module(unique_mod_name) + sys.path.remove(tmpdirname) def softmax(x, dim: int, onnx_trace: bool = False): From a2d0be4989c0be5c2f08358e85d3c568029fd6dd Mon Sep 17 00:00:00 2001 From: Chau Tran Date: Wed, 14 Oct 2020 10:31:15 -0700 Subject: [PATCH 209/707] Add CRISS README and code to fairseq (#1344) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [N] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [Y] Did you make sure to update the docs? - [N/A] Did you write any new necessary tests? ## What does this PR do? Add code to reproduce results from Cross-lingual Retrieval for Iterative Self-supervised Training. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1344 Test Plan: Imported from GitHub, without a `Test Plan:` line. 
See https://github.com/fairinternal/fairseq-py/tree/criss_pr/examples/criss Reviewed By: myleott Differential Revision: D24268469 Pulled By: chtran fbshipit-source-id: d4dd36b22bde3c364ce6e935bd39baf8f96e0735 --- README.md | 3 + examples/criss/README.md | 51 +++++ .../download_and_preprocess_flores_test.sh | 64 ++++++ .../criss/download_and_preprocess_tatoeba.sh | 37 +++ examples/criss/mining/mine.py | 214 ++++++++++++++++++ examples/criss/mining/mine_example.sh | 103 +++++++++ examples/criss/save_encoder.py | 188 +++++++++++++++ .../sentence_retrieval/encoder_analysis.py | 91 ++++++++ .../sentence_retrieval_tatoeba.sh | 59 +++++ examples/criss/unsupervised_mt/eval.sh | 37 +++ 10 files changed, 847 insertions(+) create mode 100644 examples/criss/README.md create mode 100644 examples/criss/download_and_preprocess_flores_test.sh create mode 100644 examples/criss/download_and_preprocess_tatoeba.sh create mode 100644 examples/criss/mining/mine.py create mode 100644 examples/criss/mining/mine_example.sh create mode 100644 examples/criss/save_encoder.py create mode 100644 examples/criss/sentence_retrieval/encoder_analysis.py create mode 100644 examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh create mode 100644 examples/criss/unsupervised_mt/eval.sh diff --git a/README.md b/README.md index d6094880da..151d4f0507 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ We provide reference implementations of various sequence modeling papers: - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) - [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + - [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) - 
**Non-autoregressive Transformers** - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -54,6 +55,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +- October 2020: [Added CRISS models and code](examples/criss/README.md) - September 2020: [Added Linformer code](examples/linformer/README.md) - September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) - August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) @@ -145,6 +147,7 @@ as well as example training and evaluation commands. - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available We also have more detailed READMEs to reproduce results from specific papers: +- [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) - [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) diff --git a/examples/criss/README.md b/examples/criss/README.md new file mode 100644 index 0000000000..a534056254 --- /dev/null +++ b/examples/criss/README.md @@ -0,0 +1,51 @@ +# Cross-lingual Retrieval for Iterative Self-Supervised Training + +https://arxiv.org/pdf/2006.09526.pdf + +## Introduction + +CRISS is a multilingual sequence-to-sequence pretraining method where mining and training processes are applied iteratively, improving cross-lingual alignment and translation ability at the same time. + +## Unsupervised Machine Translation +##### 1.
Download and decompress CRISS checkpoints +``` +cd examples/criss +wget https://dl.fbaipublicfiles.com/fairseq/models/criss/criss_checkpoints.tar.gz +tar -xf criss_checkpoints.tar.gz +``` +##### 2. Download and preprocess Flores test dataset +``` +bash download_and_preprocess_flores_test.sh +``` + +##### 3. Run Evaluation on Sinhala-English +``` +bash unsupervised_mt/eval.sh +``` + +## Sentence Retrieval +##### 1. Download and preprocess Tatoeba dataset +``` +bash download_and_preprocess_tatoeba.sh +``` + +##### 2. Run Sentence Retrieval on Tatoeba Kazakh-English +``` +bash sentence_retrieval/sentence_retrieval_tatoeba.sh +``` + +## Mining +##### 1. Mine pseudo-parallel data +``` +bash mining/mine_example.sh +``` + +## Citation +```bibtex +@article{tran2020cross, + title={Cross-lingual retrieval for iterative self-supervised training}, + author={Tran, Chau and Tang, Yuqing and Li, Xian and Gu, Jiatao}, + journal={arXiv preprint arXiv:2006.09526}, + year={2020} +} +``` diff --git a/examples/criss/download_and_preprocess_flores_test.sh b/examples/criss/download_and_preprocess_flores_test.sh new file mode 100644 index 0000000000..ed4b390fbd --- /dev/null +++ b/examples/criss/download_and_preprocess_flores_test.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +SPM_ENCODE=flores/scripts/spm_encode.py +DATA=data_tmp +SPM_MODEL=criss_checkpoints/sentence.bpe.model +DICT=criss_checkpoints/dict.txt + +download_data() { + CORPORA=$1 + URL=$2 + + if [ -f $CORPORA ]; then + echo "$CORPORA already exists, skipping download" + else + echo "Downloading $URL" + wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA + if [ -f $CORPORA ]; then + echo "$URL successfully downloaded." + else + echo "$URL not successfully downloaded."
+ rm -f $CORPORA + fi + fi +} + +if [[ -d flores ]]; then + echo "flores already cloned" +else + git clone https://github.com/facebookresearch/flores +fi + +mkdir -p $DATA +download_data $DATA/wikipedia_en_ne_si_test_sets.tgz "https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz" +pushd $DATA +pwd +tar -vxf wikipedia_en_ne_si_test_sets.tgz +popd + + +for lang in ne_NP si_LK; do + datadir=$DATA/${lang}-en_XX-flores + rm -rf $datadir + mkdir -p $datadir + TEST_PREFIX=$DATA/wikipedia_en_ne_si_test_sets/wikipedia.test + python $SPM_ENCODE \ + --model ${SPM_MODEL} \ + --output_format=piece \ + --inputs ${TEST_PREFIX}.${lang:0:2}-en.${lang:0:2} ${TEST_PREFIX}.${lang:0:2}-en.en \ + --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX + + # binarize data + fairseq-preprocess \ + --source-lang ${lang} --target-lang en_XX \ + --testpref $datadir/test.bpe.${lang}-en_XX \ + --destdir $datadir \ + --srcdict ${DICT} \ + --joined-dictionary \ + --workers 4 +done diff --git a/examples/criss/download_and_preprocess_tatoeba.sh b/examples/criss/download_and_preprocess_tatoeba.sh new file mode 100644 index 0000000000..4579d65aba --- /dev/null +++ b/examples/criss/download_and_preprocess_tatoeba.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree.
+ +SPM_ENCODE=flores/scripts/spm_encode.py +DATA=data_tmp +SPM_MODEL=criss_checkpoints/sentence.bpe.model +DICT=criss_checkpoints/dict.txt + +git clone https://github.com/facebookresearch/LASER +mkdir -p data_tmp +declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn") +for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do + lang_tatoeba=${lang_tatoeba_map[$lang]} + echo $lang_tatoeba + datadir=$DATA/${lang}-en_XX-tatoeba + rm -rf $datadir + mkdir -p $datadir + TEST_PREFIX=LASER/data/tatoeba/v1/tatoeba + python $SPM_ENCODE \ + --model ${SPM_MODEL} \ + --output_format=piece \ + --inputs ${TEST_PREFIX}.${lang_tatoeba}-eng.${lang_tatoeba} ${TEST_PREFIX}.${lang_tatoeba}-eng.eng \ + --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX + + # binarize data + fairseq-preprocess \ + --source-lang ${lang} --target-lang en_XX \ + --testpref $datadir/test.bpe.${lang}-en_XX \ + --destdir $datadir \ + --srcdict ${DICT} \ + --joined-dictionary \ + --workers 4 +done diff --git a/examples/criss/mining/mine.py b/examples/criss/mining/mine.py new file mode 100644 index 0000000000..a902a4ab64 --- /dev/null +++ b/examples/criss/mining/mine.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import faiss +import numpy as np +import glob +import argparse +from subprocess import check_call + + +GB = 1024*1024*1024 + + +def call(cmd): + print(cmd) + check_call(cmd, shell=True) + + +def get_batches(directory, lang, prefix='all_avg_pool'): + print(f"Finding in {directory}/{prefix}.{lang}*") + files = glob.glob(f'{directory}/{prefix}.{lang}*') + emb_files = [] + txt_files = [] + for emb_fi in files: + emb_files.append(emb_fi) + txt_fi = emb_fi.replace(prefix, 'sentences') + txt_files.append(txt_fi) + return emb_files, txt_files + + +def load_batch(emb_file, dim): + embeddings = np.fromfile(emb_file, dtype=np.float32) + num_rows = int(embeddings.shape[0] / dim) + embeddings = embeddings.reshape((num_rows, dim)) + faiss.normalize_L2(embeddings) + return embeddings + + +def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'): + sims = [] + inds = [] + xfrom = 0 + xto = 0 + for x_batch_f in x_batches_f: + yfrom = 0 + yto = 0 + x_batch = load_batch(x_batch_f, dim) + xto = xfrom + x_batch.shape[0] + bsims, binds = [], [] + for y_batch_f in y_batches_f: + y_batch = load_batch(y_batch_f, dim) + neighbor_size = min(k, y_batch.shape[0]) + yto = yfrom + y_batch.shape[0] + print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto)) + idx = faiss.IndexFlatIP(dim) + idx = faiss.index_cpu_to_all_gpus(idx) + idx.add(y_batch) + bsim, bind = idx.search(x_batch, neighbor_size) + + bsims.append(bsim) + binds.append(bind + yfrom) + yfrom += y_batch.shape[0] + del idx + del y_batch + bsims = np.concatenate(bsims, axis=1) + binds = np.concatenate(binds, axis=1) + aux = np.argsort(-bsims, axis=1) + sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32) + ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64) + for i in range(x_batch.shape[0]): + for j in range(k): + sim_batch[i, j] = bsims[i, aux[i, j]] + ind_batch[i, j] = binds[i, aux[i, j]] + sims.append(sim_batch) + inds.append(ind_batch) + xfrom += x_batch.shape[0] + del x_batch + sim = 
np.concatenate(sims, axis=0) + ind = np.concatenate(inds, axis=0) + return sim, ind + + +def score(sim, fwd_mean, bwd_mean, margin): + return margin(sim, (fwd_mean + bwd_mean) / 2) + + +def score_candidates(sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False): + print(' - scoring {:d} candidates'.format(sim_mat.shape[0])) + scores = np.zeros(candidate_inds.shape) + for i in range(scores.shape[0]): + for j in range(scores.shape[1]): + k = int(candidate_inds[i, j]) + scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin) + return scores + + +def load_text(files): + all_sentences = [] + for fi in files: + with open(fi) as sentence_fi: + for line in sentence_fi: + all_sentences.append(line.strip()) + print(f"Read {len(all_sentences)} sentences") + return all_sentences + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Mine bitext') + parser.add_argument('--src-lang', help='Source language') + parser.add_argument('--tgt-lang', help='Target language') + parser.add_argument('--dict-path', help='Path to dictionary file', default='dict.txt') + parser.add_argument('--spm-path', help='Path to SPM model file', default='sentence.bpe.model') + parser.add_argument('--dim', type=int, default=1024, + help='Embedding dimension') + parser.add_argument('--mem', type=int, default=5, + help='Memory in GB') + parser.add_argument('--src-dir', help='Source directory') + parser.add_argument('--tgt-dir', help='Target directory') + parser.add_argument('--output', help='Output path') + parser.add_argument('--neighborhood', type=int, default=4, + help='Number of nearest neighbors (k) retrieved per sentence') + parser.add_argument('--threshold', type=float, default=1.06, + help='Threshold on mined bitext') + parser.add_argument('--valid-size', type=int, default=2000, + help='Number of sentences used for validation set') + parser.add_argument('--min-count', type=int, default=50000, + help='Min num sentences used for each language') + args = parser.parse_args() + +
x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang) + y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang) + margin = lambda a, b: a / b + y2x_sim, y2x_ind = knnGPU_sharded( + y_batches_f, x_batches_f, + args.dim, + args.neighborhood, + direction='y2x') + x2y_sim, x2y_ind = knnGPU_sharded( + x_batches_f, y_batches_f, + args.dim, + args.neighborhood, + direction='x2y') + + x2y_mean = x2y_sim.mean(axis=1) + y2x_mean = y2x_sim.mean(axis=1) + fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin) + bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin) + fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)] + bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)] + indices = np.stack((np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)), + np.concatenate((fwd_best, np.arange(y2x_ind.shape[0])))), axis=1) + scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1))) + + x_sentences = load_text(x_sents_f) + y_sentences = load_text(y_sents_f) + + threshold = args.threshold + min_count = args.min_count + seen_src, seen_trg = set(), set() + directory = args.output + call(f"mkdir -p {directory}") + src_out = open( + f'{directory}/all.{args.src_lang}', + mode='w', + encoding='utf-8', + errors='surrogateescape') + tgt_out = open( + f'{directory}/all.{args.tgt_lang}', + mode='w', + encoding='utf-8', + errors='surrogateescape') + scores_out = open( + f'{directory}/all.scores', + mode='w', + encoding='utf-8', + errors='surrogateescape') + count = 0 + for i in np.argsort(-scores): + src_ind, trg_ind = indices[i] + if src_ind not in seen_src and trg_ind not in seen_trg: + seen_src.add(src_ind) + seen_trg.add(trg_ind) + if scores[i] > threshold or count < min_count: + if x_sentences[src_ind]: + print(scores[i], file=scores_out) + print(x_sentences[src_ind], file=src_out) + print(y_sentences[trg_ind], file=tgt_out) + count += 1 + else: + print(f"Ignoring 
sentence: {x_sentences[src_ind]}") + src_out.close() + tgt_out.close() + scores_out.close() + + print(f"Found {count} pairs for threshold={threshold}") + with open(f'{directory}/all.{args.src_lang}') as all_s, \ + open(f'{directory}/all.{args.tgt_lang}') as all_t, \ + open(f'{directory}/valid.{args.src_lang}', 'w') as valid_s, \ + open(f'{directory}/valid.{args.tgt_lang}', 'w') as valid_t, \ + open(f'{directory}/train.{args.src_lang}', 'w') as train_s, \ + open(f'{directory}/train.{args.tgt_lang}', 'w') as train_t: + count = 0 + for s_line, t_line in zip(all_s, all_t): + s_line = s_line.split('\t')[1] + t_line = t_line.split('\t')[1] + if count >= args.valid_size: + train_s.write(s_line) + train_t.write(t_line) + else: + valid_s.write(s_line) + valid_t.write(t_line) + count += 1 diff --git a/examples/criss/mining/mine_example.sh b/examples/criss/mining/mine_example.sh new file mode 100644 index 0000000000..92b5291338 --- /dev/null +++ b/examples/criss/mining/mine_example.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +source_lang=kk_KZ +target_lang=en_XX +MODEL=criss_checkpoints/criss.2nd.pt +SPM=criss_checkpoints/sentence.bpe.model +SPLIT=test +LANG_DICT=criss_checkpoints/lang_dict.txt +SPM_ENCODE=flores/scripts/spm_encode.py +SAVE_ENCODER=save_encoder.py +ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL +DICT=criss_checkpoints/dict.txt +THRESHOLD=1.02 +MIN_COUNT=500 + +DATA_DIR=data_tmp +SAVE_DIR=mining/${source_lang}_${target_lang}_mined +ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang} +INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba + +mkdir -p $ENCODER_SAVE_DIR/${target_lang} +mkdir -p $ENCODER_SAVE_DIR/${source_lang} +mkdir -p $SAVE_DIR + +## Save encoder outputs + +# Save encoder outputs for source sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --task translation_multi_simple_epoch \ + --lang-pairs ${source_lang}-${target_lang} \ + --lang-dict ${LANG_DICT} \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + -s ${source_lang} -t ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang} + +## Save encoder outputs for target sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --lang-pairs ${source_lang}-${target_lang} \ + --lang-dict ${LANG_DICT} \ + --task translation_multi_simple_epoch \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + -t ${source_lang} -s ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang} + +## Mining +python mining/mine.py \ + --src-lang ${source_lang} \ + --tgt-lang ${target_lang} \ + --dim 1024 \ + --mem 10 \ + --neighborhood 4 \ + --src-dir ${ENCODER_SAVE_DIR}/${source_lang} \ + --tgt-dir ${ENCODER_SAVE_DIR}/${target_lang} \ + --output $SAVE_DIR \ + --threshold ${THRESHOLD} \ + --min-count ${MIN_COUNT} \ + --valid-size 100 
\ + --dict-path ${DICT} \ + --spm-path ${SPM} \ + + +## Process and binarize mined data +python $SPM_ENCODE \ + --model ${SPM} \ + --output_format=piece \ + --inputs mining/${source_lang}_${target_lang}_mined/train.${source_lang} mining/${source_lang}_${target_lang}_mined/train.${target_lang} \ + --outputs mining/${source_lang}_${target_lang}_mined/train.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/train.bpe.${target_lang} + +python $SPM_ENCODE \ + --model ${SPM} \ + --output_format=piece \ + --inputs mining/${source_lang}_${target_lang}_mined/valid.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.${target_lang} \ + --outputs mining/${source_lang}_${target_lang}_mined/valid.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.bpe.${target_lang} + + +fairseq-preprocess \ + --source-lang ${source_lang} \ + --target-lang ${target_lang} \ + --trainpref mining/${source_lang}_${target_lang}_mined/train.bpe \ + --validpref mining/${source_lang}_${target_lang}_mined/valid.bpe \ + --destdir mining/${source_lang}_${target_lang}_mined \ + --srcdict ${DICT} \ + --joined-dictionary \ + --workers 8 diff --git a/examples/criss/save_encoder.py b/examples/criss/save_encoder.py new file mode 100644 index 0000000000..8132bbf0fa --- /dev/null +++ b/examples/criss/save_encoder.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Translate pre-processed data with a trained model. 
+""" + +import torch + +from fairseq import checkpoint_utils, options, progress_bar, tasks, utils +from fairseq.sequence_generator import EnsembleModel +import numpy as np + + +def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False): + model = EnsembleModel(models) + + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample['net_input'].items() + if k != 'prev_output_tokens' + } + + # compute the encoder output for each beam + encoder_outs = model.forward_encoder(encoder_input) + np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32) + encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(np.float32) + encoder_mask = np.expand_dims(encoder_mask.T, axis=2) + if has_langtok: + encoder_mask = encoder_mask[1:, :, :] + np_encoder_outs = np_encoder_outs[1:, :, :] + masked_encoder_outs = encoder_mask * np_encoder_outs + avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0) + return avg_pool + + +def main(args): + assert args.path is not None, '--path required for generation!'
+ assert not args.sampling or args.nbest == args.beam, \ + '--sampling requires --nbest to be equal to --beam' + assert args.replace_unk is None or args.raw_text, \ + '--replace-unk requires a raw text dataset (--raw-text)' + + args.beam=1 + utils.import_user_module(args) + + if args.max_tokens is None: + args.max_tokens = 12000 + print(args) + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load dataset splits + task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) + + # Set dictionaries + try: + src_dict = getattr(task, 'source_dictionary', None) + except NotImplementedError: + src_dict = None + tgt_dict = task.target_dictionary + + # Load ensemble + print('| loading model(s) from {}'.format(args.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(':'), + arg_overrides=eval(args.model_overrides), + task=task, + ) + + # Optimize ensemble for generation + for model in models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(args.replace_unk) + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_positions=utils.resolve_max_positions( + task.max_positions(), + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + ).next_epoch_itr(shuffle=False) + + num_sentences = 0 + source_sentences = [] + shard_id = 0 + all_avg_pool = None + encoder_has_langtok = ( + hasattr(task.args, 'encoder_langtok') + and task.args.encoder_langtok is 
not None + and hasattr(task.args, 'lang_tok_replacing_bos_eos') + and not task.args.lang_tok_replacing_bos_eos + ) + with progress_bar.build_progress_bar(args, itr) as t: + for sample in t: + if sample is None: + print("Skipping None") + continue + sample = utils.move_to_cuda(sample) if use_cuda else sample + if 'net_input' not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample['target'][:, :args.prefix_size] + + with torch.no_grad(): + avg_pool = get_avg_pool( + models, sample, prefix_tokens, src_dict, + args.remove_bpe, + has_langtok=encoder_has_langtok) + if all_avg_pool is not None: + all_avg_pool = np.concatenate((all_avg_pool, avg_pool)) + else: + all_avg_pool = avg_pool + + if not isinstance(sample['id'], list): + sample_ids = sample['id'].tolist() + else: + sample_ids = sample['id'] + for i, sample_id in enumerate(sample_ids): + # Remove padding + src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) + + # Either retrieve the original sentences or regenerate them from tokens. 
+ if align_dict is not None: + src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) + else: + if src_dict is not None: + src_str = src_dict.string(src_tokens, args.remove_bpe) + else: + src_str = "" + + if not args.quiet: + if src_dict is not None: + print('S-{}\t{}'.format(sample_id, src_str)) + + source_sentences.append(f"{sample_id}\t{src_str}") + + num_sentences += sample['nsentences'] + if all_avg_pool.shape[0] >= 1000000: + with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}', + 'w') as avg_pool_file: + all_avg_pool.tofile(avg_pool_file) + with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file: + sentence_file.writelines(f'{line}\n' for line in source_sentences) + all_avg_pool = None + source_sentences = [] + shard_id += 1 + + if all_avg_pool is not None: + with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}', + 'w') as avg_pool_file: + all_avg_pool.tofile(avg_pool_file) + with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file: + sentence_file.writelines(f'{line}\n' for line in source_sentences) + return None + + +def cli_main(): + parser = options.get_generation_parser() + parser.add_argument('--encoder-save-dir', default='', type=str, metavar='N', + help='directory to save encoder outputs') + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == '__main__': + cli_main() diff --git a/examples/criss/sentence_retrieval/encoder_analysis.py b/examples/criss/sentence_retrieval/encoder_analysis.py new file mode 100644 index 0000000000..c0d74af23a --- /dev/null +++ b/examples/criss/sentence_retrieval/encoder_analysis.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import numpy as np +import argparse +import glob + + +DIM = 1024 + + +def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False): + target_ids = [tid for tid in target_embs] + source_mat = np.stack(list(source_embs.values()), axis=0) + normalized_source_mat = source_mat / np.linalg.norm(source_mat, axis=1, keepdims=True) + target_mat = np.stack(list(target_embs.values()), axis=0) + normalized_target_mat = target_mat / np.linalg.norm(target_mat, axis=1, keepdims=True) + sim_mat = normalized_source_mat.dot(normalized_target_mat.T) + if return_sim_mat: + return sim_mat + neighbors_map = {} + for i, sentence_id in enumerate(source_embs): + idx = np.argsort(sim_mat[i, :])[::-1][:k] + neighbors_map[sentence_id] = [target_ids[tid] for tid in idx] + return neighbors_map + + +def load_embeddings(directory, LANGS): + sentence_embeddings = {} + sentence_texts = {} + for lang in LANGS: + sentence_embeddings[lang] = {} + sentence_texts[lang] = {} + lang_dir = f"{directory}/{lang}" + embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*") + for embed_file in embedding_files: + shard_id = embed_file.split('.')[-1] + embeddings = np.fromfile(embed_file, dtype=np.float32) + num_rows = embeddings.shape[0] // DIM + embeddings = embeddings.reshape((num_rows, DIM)) + + with open(f'{lang_dir}/sentences.{lang}.{shard_id}') as sentence_file: + for idx, line in enumerate(sentence_file): + sentence_id, sentence = line.strip().split('\t') + sentence_texts[lang][sentence_id] = sentence + sentence_embeddings[lang][sentence_id] = embeddings[idx, :] + + return sentence_embeddings, sentence_texts + + +def compute_accuracy(directory, LANGS): + sentence_embeddings, sentence_texts = load_embeddings(directory, LANGS) + + top_1_accuracy = {} + + top1_str = " ".join(LANGS) + '\n' + for source_lang in LANGS: + top_1_accuracy[source_lang] = {} + top1_str += f"{source_lang} " + for target_lang in LANGS: + top1 = 0 + top5 = 0 + neighbors_map = compute_dist( + sentence_embeddings[source_lang], +
sentence_embeddings[target_lang]) + for sentence_id, neighbors in neighbors_map.items(): + if sentence_id == neighbors[0]: + top1 += 1 + if sentence_id in neighbors[:5]: + top5 += 1 + n = len(sentence_embeddings[target_lang]) + top1_str += f"{top1/n} " + top1_str += "\n" + + print(top1_str) + print(top1_str, file=open(f"{directory}/accuracy", 'w')) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Analyze encoder outputs') + parser.add_argument('directory', + help='Source language corpus' + ) + parser.add_argument('--langs', + help='List of langs' + ) + args = parser.parse_args() + langs = args.langs.split(',') + compute_accuracy(args.directory, langs) diff --git a/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh b/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh new file mode 100644 index 0000000000..0428d8bef9 --- /dev/null +++ b/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +source_lang=kk_KZ +target_lang=en_XX +MODEL=criss_checkpoints/criss.3rd.pt +SPM=criss_checkpoints/sentence.bpe.model +SPLIT=test +LANG_DICT=criss_checkpoints/lang_dict.txt +ENCODER_ANALYSIS=sentence_retrieval/encoder_analysis.py +SAVE_ENCODER=save_encoder.py +ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL + + + +DATA_DIR=data_tmp +INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba +ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang} +mkdir -p $ENCODER_SAVE_DIR/${target_lang} +mkdir -p $ENCODER_SAVE_DIR/${source_lang} + +# Save encoder outputs for source sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --task translation_multi_simple_epoch \ + --lang-dict ${LANG_DICT} \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + --lang-pairs ${source_lang}-${target_lang} \ + -s ${source_lang} -t ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang} + +# Save encoder outputs for target sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --lang-dict ${LANG_DICT} \ + --task translation_multi_simple_epoch \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + --lang-pairs ${target_lang}-${source_lang} \ + -t ${source_lang} -s ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang} + +# Analyze sentence retrieval accuracy +python $ENCODER_ANALYSIS --langs "${source_lang},${target_lang}" ${ENCODER_SAVE_DIR} diff --git a/examples/criss/unsupervised_mt/eval.sh b/examples/criss/unsupervised_mt/eval.sh new file mode 100644 index 0000000000..03b773ed5a --- /dev/null +++ b/examples/criss/unsupervised_mt/eval.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +SRC=si_LK +TGT=en_XX +MODEL=criss_checkpoints/criss.3rd.pt + +MULTIBLEU=mosesdecoder/scripts/generic/multi-bleu.perl +MOSES=mosesdecoder +REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl +NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl +TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl +GEN_TMP_DIR=gen_tmp +LANG_DICT=criss_checkpoints/lang_dict.txt + +if [ ! -d "mosesdecoder" ]; then + git clone https://github.com/moses-smt/mosesdecoder +fi +mkdir -p $GEN_TMP_DIR +fairseq-generate data_tmp/${SRC}-${TGT}-flores \ + --task translation_multi_simple_epoch \ + --max-tokens 2000 \ + --path ${MODEL} \ + --skip-invalid-size-inputs-valid-test \ + --beam 5 --lenpen 1.0 --gen-subset test \ + --remove-bpe=sentencepiece \ + --source-lang ${SRC} --target-lang ${TGT} \ + --decoder-langtok --lang-pairs 
'en_XX-ar_AR,en_XX-de_DE,en_XX-es_XX,en_XX-fr_XX,en_XX-hi_IN,en_XX-it_IT,en_XX-ja_XX,en_XX-ko_KR,en_XX-nl_XX,en_XX-ru_RU,en_XX-zh_CN,en_XX-tr_TR,en_XX-vi_VN,en_XX-ro_RO,en_XX-my_MM,en_XX-ne_NP,en_XX-si_LK,en_XX-cs_CZ,en_XX-lt_LT,en_XX-kk_KZ,en_XX-gu_IN,en_XX-fi_FI,en_XX-et_EE,en_XX-lv_LV,ar_AR-en_XX,cs_CZ-en_XX,de_DE-en_XX,es_XX-en_XX,et_EE-en_XX,fi_FI-en_XX,fr_XX-en_XX,gu_IN-en_XX,hi_IN-en_XX,it_IT-en_XX,ja_XX-en_XX,kk_KZ-en_XX,ko_KR-en_XX,lt_LT-en_XX,lv_LV-en_XX,my_MM-en_XX,ne_NP-en_XX,nl_XX-en_XX,ro_RO-en_XX,ru_RU-en_XX,si_LK-en_XX,tr_TR-en_XX,vi_VN-en_XX,zh_CN-en_XX,ar_AR-es_XX,es_XX-ar_AR,ar_AR-hi_IN,hi_IN-ar_AR,ar_AR-zh_CN,zh_CN-ar_AR,cs_CZ-es_XX,es_XX-cs_CZ,cs_CZ-hi_IN,hi_IN-cs_CZ,cs_CZ-zh_CN,zh_CN-cs_CZ,de_DE-es_XX,es_XX-de_DE,de_DE-hi_IN,hi_IN-de_DE,de_DE-zh_CN,zh_CN-de_DE,es_XX-hi_IN,hi_IN-es_XX,es_XX-zh_CN,zh_CN-es_XX,et_EE-es_XX,es_XX-et_EE,et_EE-hi_IN,hi_IN-et_EE,et_EE-zh_CN,zh_CN-et_EE,fi_FI-es_XX,es_XX-fi_FI,fi_FI-hi_IN,hi_IN-fi_FI,fi_FI-zh_CN,zh_CN-fi_FI,fr_XX-es_XX,es_XX-fr_XX,fr_XX-hi_IN,hi_IN-fr_XX,fr_XX-zh_CN,zh_CN-fr_XX,gu_IN-es_XX,es_XX-gu_IN,gu_IN-hi_IN,hi_IN-gu_IN,gu_IN-zh_CN,zh_CN-gu_IN,hi_IN-zh_CN,zh_CN-hi_IN,it_IT-es_XX,es_XX-it_IT,it_IT-hi_IN,hi_IN-it_IT,it_IT-zh_CN,zh_CN-it_IT,ja_XX-es_XX,es_XX-ja_XX,ja_XX-hi_IN,hi_IN-ja_XX,ja_XX-zh_CN,zh_CN-ja_XX,kk_KZ-es_XX,es_XX-kk_KZ,kk_KZ-hi_IN,hi_IN-kk_KZ,kk_KZ-zh_CN,zh_CN-kk_KZ,ko_KR-es_XX,es_XX-ko_KR,ko_KR-hi_IN,hi_IN-ko_KR,ko_KR-zh_CN,zh_CN-ko_KR,lt_LT-es_XX,es_XX-lt_LT,lt_LT-hi_IN,hi_IN-lt_LT,lt_LT-zh_CN,zh_CN-lt_LT,lv_LV-es_XX,es_XX-lv_LV,lv_LV-hi_IN,hi_IN-lv_LV,lv_LV-zh_CN,zh_CN-lv_LV,my_MM-es_XX,es_XX-my_MM,my_MM-hi_IN,hi_IN-my_MM,my_MM-zh_CN,zh_CN-my_MM,ne_NP-es_XX,es_XX-ne_NP,ne_NP-hi_IN,hi_IN-ne_NP,ne_NP-zh_CN,zh_CN-ne_NP,nl_XX-es_XX,es_XX-nl_XX,nl_XX-hi_IN,hi_IN-nl_XX,nl_XX-zh_CN,zh_CN-nl_XX,ro_RO-es_XX,es_XX-ro_RO,ro_RO-hi_IN,hi_IN-ro_RO,ro_RO-zh_CN,zh_CN-ro_RO,ru_RU-es_XX,es_XX-ru_RU,ru_RU-hi_IN,hi_IN-ru_RU,ru_RU-zh_CN,zh_CN-ru_RU,si_LK-es_XX,es_XX-si_LK,si_LK-hi_IN,hi_IN-si_LK,si_LK-z
h_CN,zh_CN-si_LK,tr_TR-es_XX,es_XX-tr_TR,tr_TR-hi_IN,hi_IN-tr_TR,tr_TR-zh_CN,zh_CN-tr_TR,vi_VN-es_XX,es_XX-vi_VN,vi_VN-hi_IN,hi_IN-vi_VN,vi_VN-zh_CN,zh_CN-vi_VN' \ + --lang-dict ${LANG_DICT} --lang-tok-style 'mbart' --sampling-method 'temperature' --sampling-temperature '1.0' > $GEN_TMP_DIR/${SRC}_${TGT}.gen +cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^T-" | cut -f2 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.hyp +cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^H-" | cut -f3 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.ref +${MULTIBLEU} $GEN_TMP_DIR/${SRC}_${TGT}.ref < $GEN_TMP_DIR/${SRC}_${TGT}.hyp From 1d1c145387e37bb10800190f8d31b144b4e7e182 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Wed, 14 Oct 2020 12:27:45 -0700 Subject: [PATCH 210/707] speech-to-text OSS Summary: Imported from https://github.com/fairinternal/fairseq-py/pull/1284. Updated according to PR comments. Main changes: * New task: `fairseq.tasks.speech_to_text` * Multilingual support: multiple train sub-splits, temperature-based sampling, language ID tokens * New dataset: `fairseq.data.audio.speech_to_text_dataset` * Added accuracy metrics and BOS prefix removal to label smoothed cross entropy * New models: Transformer (`fairseq.models.speech_to_text.s2t_transformer`) and BLSTM (`fairseq.models.speech_to_text.berard`) * Extended scorers: * Added a base scorer class: `fairseq.scorers.BaseScorer` (the parent class for all scorers except the BLEU scorer in CPP) * Added an evaluation tokenizer: `fairseq.scorers.eval_tokenizer` which leverages sacreBLEU's built-in tokenizers and allows character-level tokenization as well as punctuation removal (for WER scoring). 
* Added chrF scorer: `fairseq.scorers.chrf` * Online Mel-filter bank speech feature extraction (via CPP-based pyKaldi or Python-based TorchAudio): `fairseq.data.audio.audio_utils` * Online speech feature transforms: `fairseq.data.audio.feature_transforms.*` * Fixed the subsampled sequence lengths in VGGTransformer (`examples.speech_recognition.models.vggtransformer`) * Examples under `examples/speech_to_text`: * LibriSpeech (ASR): better results than VGGTransformer with smaller Transformer-based models * MuST-C (ST): comparable to [SOTA results](https://arxiv.org/pdf/2004.10234.pdf) but with less tricks Reviewed By: jmp84 Differential Revision: D24065273 fbshipit-source-id: 5f842ca9c826f92d4af660705611885fe440a9ab --- .../models/vggtransformer.py | 9 +- examples/speech_to_text/README.md | 216 +++++++ examples/speech_to_text/data_utils.py | 218 +++++++ examples/speech_to_text/prep_covost_data.py | 232 +++++++ .../speech_to_text/prep_librispeech_data.py | 96 +++ examples/speech_to_text/prep_mustc_data.py | 172 ++++++ .../label_smoothed_cross_entropy.py | 58 +- fairseq/data/audio/audio_utils.py | 81 +++ .../data/audio/feature_transforms/__init__.py | 77 +++ .../audio/feature_transforms/global_cmvn.py | 24 + .../audio/feature_transforms/specaugment.py | 126 ++++ .../feature_transforms/utterance_cmvn.py | 38 ++ fairseq/data/audio/speech_to_text_dataset.py | 478 ++++++++++++++ fairseq/data/data_utils.py | 11 + fairseq/models/speech_to_text/__init__.py | 7 + fairseq/models/speech_to_text/berard.py | 581 ++++++++++++++++++ .../models/speech_to_text/s2t_transformer.py | 394 ++++++++++++ fairseq/tasks/fairseq_task.py | 11 +- fairseq/tasks/speech_to_text.py | 120 ++++ fairseq_cli/generate.py | 5 +- setup.py | 2 +- tests/test_label_smoothing.py | 1 + 22 files changed, 2941 insertions(+), 16 deletions(-) create mode 100644 examples/speech_to_text/README.md create mode 100644 examples/speech_to_text/data_utils.py create mode 100644 examples/speech_to_text/prep_covost_data.py 
create mode 100644 examples/speech_to_text/prep_librispeech_data.py create mode 100644 examples/speech_to_text/prep_mustc_data.py create mode 100644 fairseq/data/audio/audio_utils.py create mode 100644 fairseq/data/audio/feature_transforms/__init__.py create mode 100644 fairseq/data/audio/feature_transforms/global_cmvn.py create mode 100644 fairseq/data/audio/feature_transforms/specaugment.py create mode 100644 fairseq/data/audio/feature_transforms/utterance_cmvn.py create mode 100644 fairseq/data/audio/speech_to_text_dataset.py create mode 100644 fairseq/models/speech_to_text/__init__.py create mode 100644 fairseq/models/speech_to_text/berard.py create mode 100644 fairseq/models/speech_to_text/s2t_transformer.py create mode 100644 fairseq/tasks/speech_to_text.py diff --git a/examples/speech_recognition/models/vggtransformer.py b/examples/speech_recognition/models/vggtransformer.py index 3a5db5bb16..e9a45ac73e 100644 --- a/examples/speech_recognition/models/vggtransformer.py +++ b/examples/speech_recognition/models/vggtransformer.py @@ -251,6 +251,7 @@ def __init__( self.conv_layers = nn.ModuleList() self.in_channels = in_channels self.input_dim = input_feat_per_channel + self.pooling_kernel_sizes = [] if vggblock_config is not None: for _, config in enumerate(vggblock_config): @@ -272,6 +273,7 @@ def __init__( layer_norm=layer_norm, ) ) + self.pooling_kernel_sizes.append(pooling_kernel_size) in_channels = out_channels input_feat_per_channel = self.conv_layers[-1].output_dim @@ -336,9 +338,9 @@ def forward(self, src_tokens, src_lengths, **kwargs): x = x.transpose(1, 2).transpose(0, 1) x = x.contiguous().view(output_seq_len, bsz, -1) - subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) - # TODO: shouldn't subsampling_factor determined in advance ? 
-        input_lengths = (src_lengths.float() / subsampling_factor).ceil().long()
+        input_lengths = src_lengths.clone()
+        for s in self.pooling_kernel_sizes:
+            input_lengths = (input_lengths.float() / s).ceil().long()
 
         encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
             input_lengths, batch_first=True
@@ -346,6 +348,7 @@ def forward(self, src_tokens, src_lengths, **kwargs):
         if not encoder_padding_mask.any():
             encoder_padding_mask = None
 
+        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
         attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
 
         transformer_layer_idx = 0
diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md
new file mode 100644
index 0000000000..62fd2700b8
--- /dev/null
+++ b/examples/speech_to_text/README.md
@@ -0,0 +1,216 @@
+# Speech-to-Text (S2T) Modeling
+
+## Data Preparation
+S2T modeling data consists of source speech features, target text and other optional information
+(source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
+to store this information. Each data field is represented by a column in the TSV file.
+
+Unlike text token embeddings, speech features (e.g. log mel-filter banks) are usually fixed
+during model training and can be pre-computed. The manifest file contains the path to
+either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
+features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
+into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.
+
+Fairseq S2T also employs a YAML file for data-related configurations: tokenizer type and dictionary path
+for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
+temperature-based resampling, etc.
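As a rough illustration of the ZIP packing described above, the sketch below shows how a member of an uncompressed (stored) ZIP can be read back with only a byte offset and a length. It uses only the Python standard library. The function names (`stored_member_offsets`, `read_slice`, `demo`) and the file names are made up for this example and are not part of fairseq. Unlike `get_zip_manifest` in this patch, which assumes a 30-byte header followed by the file name, the sketch reads the name and extra-field lengths from the local file header.

```python
import os
import struct
import tempfile
import zipfile


def stored_member_offsets(zip_path):
    """Map each member name to (data_offset, size) in an uncompressed ZIP."""
    offsets = {}
    with zipfile.ZipFile(zip_path, "r") as zf, open(zip_path, "rb") as raw:
        for info in zf.infolist():
            # Byte-range access only works when members are stored as-is.
            assert info.compress_type == zipfile.ZIP_STORED
            # Local file header: 30 fixed bytes; the name length and the
            # extra-field length sit at byte offsets 26 and 28.
            raw.seek(info.header_offset + 26)
            name_len, extra_len = struct.unpack("<HH", raw.read(4))
            data_offset = info.header_offset + 30 + name_len + extra_len
            offsets[info.filename] = (data_offset, info.file_size)
    return offsets


def read_slice(zip_path, offset, size):
    """Read `size` bytes at `offset` without going through zipfile."""
    with open(zip_path, "rb") as f:
        f.seek(offset)
        return f.read(size)


def demo():
    """Build a tiny stored ZIP and read one member back by offset."""
    payload = b"\x93NUMPY-demo-bytes"  # placeholder bytes, not a real .npy
    with tempfile.TemporaryDirectory() as tmp:
        zip_path = os.path.join(tmp, "fbank80.zip")
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
            zf.writestr("utt_0001.npy", payload)
            zf.writestr("utt_0002.npy", b"other")
        offset, size = stored_member_offsets(zip_path)["utt_0001.npy"]
        return read_slice(zip_path, offset, size)
```

With offsets like these, a manifest entry of the form `archive.zip:offset:size` is enough to fetch one utterance with a single seek and read.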
+
+## Model Training & Evaluation
+Fairseq S2T uses the unified `fairseq-train`/`fairseq-generate` interface for model training and evaluation.
+It requires arguments `--task speech_to_text` and `--arch `.
+
+
+## Example 1: Speech Recognition (ASR) on LibriSpeech
+
+#### Data preparation
+Download and preprocess LibriSpeech data with
+```bash
+python examples/speech_to_text/prep_librispeech_data.py \
+  --output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000
+```
+where `LS_ROOT` is the root path for downloaded data as well as generated manifest and feature files.
+
+#### Training
+```bash
+fairseq-train ${LS_ROOT} --train-subset train --valid-subset dev --save-dir ${SAVE_DIR} --num-workers 4 \
+  --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy --max-update 300000 \
+  --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
+  --clip-norm 10.0 --seed 1 --update-freq 8
+```
+where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as an example.
+You may switch to `s2t_transformer_m` (71M) or `s2t_transformer_l` (268M) for better performance. We set
+`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
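The `--update-freq` advice above comes down to keeping the product of GPU count and gradient-accumulation steps constant. The following is a minimal sketch of that arithmetic, assuming `--max-tokens` is a per-GPU cap on batch size. The helper names `update_freq_for` and `max_tokens_per_update` are hypothetical and not fairseq functions.

```python
def update_freq_for(num_gpus: int, simulated_gpus: int = 8) -> int:
    """Return the --update-freq that keeps num_gpus * update_freq constant."""
    if num_gpus < 1 or simulated_gpus % num_gpus != 0:
        raise ValueError("num_gpus must be a positive divisor of simulated_gpus")
    return simulated_gpus // num_gpus


def max_tokens_per_update(max_tokens: int, num_gpus: int, update_freq: int) -> int:
    """Upper bound on tokens contributing to one optimizer step."""
    return max_tokens * num_gpus * update_freq


# The README setting is 1 GPU with --update-freq 8 and --max-tokens 40000.
for n in (1, 2, 4, 8):
    uf = update_freq_for(n)
    print(n, uf, max_tokens_per_update(40000, n, uf))
```

Under this assumption, moving from 1 GPU to 4 GPUs would mean lowering `--update-freq` from 8 to 2 so that the per-update token budget stays the same.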
+
+#### Inference & Evaluation
+Average the last 10 checkpoints and evaluate on the 4 splits
+(`dev-clean`, `dev-other`, `test-clean` and `test-other`):
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+  --inputs ${SAVE_DIR} --num-epoch-checkpoints 10 --output "${SAVE_DIR}/${CHECKPOINT_FILENAME}"
+for SUBSET in dev-clean dev-other test-clean test-other; do
+  fairseq-generate ${LS_ROOT} --gen-subset ${SUBSET} --task speech_to_text \
+    --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring wer
+done
+```
+
+#### Result
+
+| --arch | Params | dev-clean | dev-other | test-clean | test-other |
+|---|---|---|---|---|---|
+| s2t_transformer_s | 30M | 4.1 | 9.3 | 4.4 | 9.2 |
+| s2t_transformer_sp | 35M | 3.9 | 9.3 | 4.3 | 8.8 |
+| s2t_transformer_m | 71M | 3.5 | 8.1 | 3.7 | 8.1 |
+| s2t_transformer_mp | 84M | 3.3 | 7.8 | 3.7 | 8.2 |
+| s2t_transformer_l | 268M | 3.3 | 7.7 | 3.5 | 7.8 |
+| s2t_transformer_lp | 318M | 3.1 | 7.5 | 3.4 | 7.6 |
+
+
+## Example 2: Speech Translation (ST) on MuST-C
+
+#### Data Preparation
+[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path `MUSTC_ROOT`, then preprocess it with
+```bash
+python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} \
+  --asr-vocab-type unigram --asr-vocab-size 5000 \
+  --st-vocab-type unigram --st-vocab-size 8000
+```
+The generated manifest and feature files will be available under `MUSTC_ROOT`.
+
+#### ASR
+###### Training
+```bash
+fairseq-train ${MUSTC_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
+  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
+  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
+  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
+```
+where `ASR_SAVE_DIR` is the checkpoint root path.
We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
+You may want to update it accordingly when using more than 1 GPU.
+
+###### Inference & Evaluation
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+  --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_asr --task speech_to_text \
+  --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
+  --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
+```
+###### Result
+| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
+|---|---|---|---|---|---|---|---|---|---|
+| s2t_transformer_s | 31M | 18.2 | 17.6 | 17.7 | 17.2 | 17.9 | 19.1 | 18.1 | 17.7 |
+
+#### ST
+###### Training
+```bash
+fairseq-train ${MUSTC_ROOT} --train-subset train_st --valid-subset dev_st --save-dir ${ST_SAVE_DIR} \
+  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
+  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
+  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
+  --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
+```
+where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by ASR for faster training and better
+performance: `--load-pretrained-encoder-from `. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
+You may want to update it accordingly when using more than 1 GPU.
+
+###### Inference & Evaluation
+Average the last 10 checkpoints and evaluate on the `tst-COMMON` split:
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+  --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_st --task speech_to_text \
+  --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
+```
+
+###### Result
+| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
+|---|---|---|---|---|---|---|---|---|---|
+| s2t_transformer_s | 31M | 22.7 | 27.3 | 27.2 | 32.9 | 22.7 | 28.1 | 21.9 | 15.3 |
+
+
+## Example 3: ST on CoVoST
+#### Data Preparation
+Download and preprocess CoVoST data with
+```bash
+# En ASR
+python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
+  --vocab-type char --src-lang en
+# ST
+python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
+  --vocab-type char --src-lang fr --tgt-lang en
+```
+where `COVOST_ROOT` is the root path for downloaded data as well as generated manifest and feature files.
+
+#### ASR
+###### Training
+```bash
+fairseq-train ${COVOST_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
+  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
+  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
+  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
+```
+where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
+You may want to update it accordingly when using more than 1 GPU.
+
+###### Inference & Evaluation
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+  --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${COVOST_ROOT} --gen-subset test_asr_en --task speech_to_text \
+  --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
+  --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
+```
+###### Result
+| --arch | Params | En |
+|---|---|---|
+| s2t_transformer_s | 31M | 25.6 |
+
+#### ST
+###### Training
+```bash
+fairseq-train ${COVOST_ROOT} --train-subset train_st_fr_en --valid-subset dev_st_fr_en --save-dir ${ST_SAVE_DIR} \
+  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
+  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
+  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
+  --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
+```
+where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by En ASR for faster training and better
+performance: `--load-pretrained-encoder-from `. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
+You may want to update it accordingly when using more than 1 GPU.
+
+###### Inference & Evaluation
+Average the last 10 checkpoints and evaluate on the test split:
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+  --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${COVOST_ROOT} --gen-subset test_st_fr_en --task speech_to_text \
+  --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
+```
+
+###### Result
+| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et |
+|---|---|---|---|---|---|---|---|---|---|
+| s2t_transformer_s | 31M | 26.3 | 17.1 | 23.0 | 18.8 | 16.3 | 21.8 | 13.1 | 13.2 |
+
+## Citation
+Please cite as:
+```
+@inproceedings{wang2020fairseqs2t,
+  title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
+  author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
+  booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
+  year = {2020},
+}
+
+@inproceedings{ott2019fairseq,
+  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+  year = {2019},
+}
+```
diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py
new file mode 100644
index 0000000000..1983f70c10
--- /dev/null
+++ b/examples/speech_to_text/data_utils.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from multiprocessing import cpu_count
+import os
+import os.path as op
+from glob import glob
+import zipfile
+import csv
+from functools import reduce
+from typing import Dict, Any, List
+from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
+
+import sentencepiece as sp
+from tqdm import tqdm
+import numpy as np
+
+from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
+
+UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3
+BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0
+EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2
+PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1
+
+
+def gen_vocab(
+        input_path: str, output_path_prefix: str, model_type='bpe',
+        vocab_size=1000,
+):
+    # Train SentencePiece Model
+    arguments = [
+        f'--input={input_path}',
+        f'--model_prefix={output_path_prefix}',
+        f'--model_type={model_type}',
+        f'--vocab_size={vocab_size}',
+        '--character_coverage=1.0',
+        f'--num_threads={cpu_count()}',
+        f'--unk_id={UNK_TOKEN_ID}',
+        f'--bos_id={BOS_TOKEN_ID}',
+        f'--eos_id={EOS_TOKEN_ID}',
+        f'--pad_id={PAD_TOKEN_ID}'
+    ]
+    sp.SentencePieceTrainer.Train(' '.join(arguments))
+    # Export fairseq dictionary
+    spm = sp.SentencePieceProcessor()
+    spm.Load(output_path_prefix + '.model')
+    vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
+    assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \
+        vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \
+        vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \
+        vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
+    vocab = {
+        i: s for i, s in vocab.items()
+        if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
+    }
+    with open(output_path_prefix + '.txt', 'w') as f_out:
+        for _, s in sorted(vocab.items(), key=lambda x: x[0]):
+            f_out.write(f'{s} 1\n')
+
+
+def extract_fbank_features(waveform, sample_rate, output_path=None,
+                           n_mel_bins=80, apply_utterance_cmvn=True,
+                           overwrite=False):
+    if output_path is not None and op.exists(output_path) and not overwrite:
+        return
+
+    _waveform = waveform * (2 ** 15)  # Kaldi compliance: 16-bit signed integers
+
_waveform = _waveform.squeeze().numpy() + + features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins) + if features is None: + features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins) + if features is None: + raise ImportError('Please install pyKaldi or torchaudio to enable ' + 'online filterbank feature extraction') + + if apply_utterance_cmvn: + cmvn = UtteranceCMVN(norm_means=True, norm_vars=True) + features = cmvn(features) + if output_path is not None: + np.save(output_path, features) + else: + return features + + +def create_zip(data_root, zip_path): + cwd = os.path.abspath(os.curdir) + os.chdir(data_root) + with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f: + for filename in tqdm(glob('*.npy')): + f.write(filename) + os.chdir(cwd) + + +def is_npy_data(data: bytes) -> bool: + return data[0] == 147 and data[1] == 78 + + +def get_zip_manifest(zip_root, zip_filename): + zip_path = op.join(zip_root, zip_filename) + with zipfile.ZipFile(zip_path, mode='r') as f: + info = f.infolist() + manifest = {} + for i in tqdm(info): + utt_id = op.splitext(i.filename)[0] + offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size + manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}' + with open(zip_path, 'rb') as f: + f.seek(offset) + data = f.read(file_size) + assert len(data) > 1 and is_npy_data(data) + return manifest + + +def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml', + specaugment_policy='lb'): + assert specaugment_policy in {'lb', 'ld'} + data_root = op.abspath(data_root) + writer = S2TDataConfigWriter(op.join(data_root, yaml_filename)) + writer.set_audio_root(op.abspath(data_root)) + writer.set_vocab_filename(spm_filename.replace(".model", ".txt")) + writer.set_input_channels(1) + writer.set_input_feat_per_channel(80) + if specaugment_policy == 'lb': + writer.set_specaugment_lb_policy() + else: + writer.set_specaugment_ld_policy() + writer.set_bpe_tokenizer( + {'bpe': 'sentencepiece', + 
'sentencepiece_model': op.join(data_root, spm_filename)} + ) + writer.set_feature_transforms('_train', ['specaugment']) + writer.flush() + + +def save_df_to_tsv(dataframe, path): + dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8", + escapechar='\\', quoting=csv.QUOTE_NONE) + + +def filter_manifest_df(df, is_train_split=False, extra_filters=None, + min_n_frames=5, max_n_frames=3000): + filters = { + 'no speech': df['audio'] == '', + f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames, + 'empty sentence': df['tgt_text'] == '', + } + if is_train_split: + filters[f'long speech (>{max_n_frames} frames)'] = \ + df['n_frames'] > max_n_frames + if extra_filters is not None: + filters.update(extra_filters) + invalid = reduce(lambda x, y: x | y, filters.values()) + valid = ~invalid + print( + '| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) + + f', total {invalid.sum()} filtered, {valid.sum()} remained.' + ) + return df[valid] + + +class S2TDataConfigWriter(object): + DEFAULT_VOCAB_FILENAME = 'dict.txt' + DEFAULT_INPUT_FEAT_PER_CHANNEL = 80 + DEFAULT_INPUT_CHANNELS = 1 + + def __init__(self, yaml_path): + try: + import yaml + except ImportError: + print('Please install PyYAML to load YAML files for S2T data config') + self.yaml = yaml + self.yaml_path = yaml_path + self.config = {} + + def flush(self): + with open(self.yaml_path, 'w') as f: + self.yaml.dump(self.config, f) + + def set_audio_root(self, audio_root=''): + self.config['audio_root'] = audio_root + + def set_vocab_filename(self, vocab_filename='dict.txt'): + self.config['vocab_filename'] = vocab_filename + + def set_specaugment(self, time_wrap_w: int, freq_mask_n: int, + freq_mask_f: int, time_mask_n: int, time_mask_t: int, + time_mask_p: float): + self.config['specaugment'] = { + 'time_wrap_W': time_wrap_w, 'freq_mask_N': freq_mask_n, + 'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n, + 'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p, + } 
+ + def set_specaugment_lb_policy(self): + self.set_specaugment(time_wrap_w=0, freq_mask_n=1, freq_mask_f=27, + time_mask_n=1, time_mask_t=100, time_mask_p=1.0) + + def set_specaugment_ld_policy(self): + self.set_specaugment(time_wrap_w=0, freq_mask_n=2, freq_mask_f=27, + time_mask_n=2, time_mask_t=100, time_mask_p=1.0) + + def set_input_channels(self, input_channels=1): + self.config['input_channels'] = input_channels + + def set_input_feat_per_channel(self, input_feat_per_channel=80): + self.config['input_feat_per_channel'] = input_feat_per_channel + + def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): + self.config['bpe_tokenizer'] = bpe_tokenizer + + def set_feature_transforms(self, split, transforms: List[str]): + if 'transforms' not in self.config: + self.config['transforms'] = {} + self.config['transforms'][split] = transforms diff --git a/examples/speech_to_text/prep_covost_data.py b/examples/speech_to_text/prep_covost_data.py new file mode 100644 index 0000000000..a70e24e04d --- /dev/null +++ b/examples/speech_to_text/prep_covost_data.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +from tempfile import NamedTemporaryFile +import os +import os.path as op +import shutil +from typing import Tuple, Optional +import csv + +from torchaudio.datasets.utils import download_url, extract_archive +from tqdm import tqdm +import pandas as pd +from torch.utils.data import Dataset +import torchaudio +from torch import Tensor + +from examples.speech_to_text.data_utils import ( + gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, + extract_fbank_features, gen_config_yaml, filter_manifest_df +) + +log = logging.getLogger(__name__) + + +MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] + + +class CoVoST(Dataset): + """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost). + + Args: + root (str): root path to the dataset and generated manifests/features + source_language (str): source (audio) language + target_language (str, optional): target (text) language, + None for no translation (default: None) + version (int, optional): CoVoST version. (default: 2) + download (bool, optional): Whether to download the dataset if it is not + found at root path. (default: ``False``). 
+ """ + + CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \ + "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz" + COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \ + "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz" + + VERSIONS = {2} + SPLITS = ['train', 'dev', 'test'] + + CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"} + + XX_EN_LANGUAGES = { + 1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn', + 'zh-CN'], + 2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn', + 'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy'] + } + EN_XX_LANGUAGES = { + 1: [], + 2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et', + 'id', + 'ar', 'ta', 'lv', 'ja'] + } + + def __init__( + self, root: str, split: str, source_language: str, + target_language: Optional[str] = None, version: int = 2, + download: bool = False + ) -> None: + assert version in self.VERSIONS and split in self.SPLITS + assert source_language is not None + self.no_translation = (target_language is None) + if not self.no_translation: + assert 'en' in {source_language, target_language} + if source_language == 'en': + assert target_language in self.EN_XX_LANGUAGES[version] + else: + assert source_language in self.XX_EN_LANGUAGES[version] + else: + # Hack here so that we can get "split" column from CoVoST TSV. + # Note that we use CoVoST train split for ASR which is an extension + # to Common Voice train split. 
+ target_language = 'de' if source_language == 'en' else 'en' + + self.root = os.path.join(root, 'raw') + os.makedirs(self.root, exist_ok=True) + + cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version], + lang=source_language) + cv_archive = os.path.join(self.root, os.path.basename(cv_url)) + if download: + if not os.path.isfile(cv_archive): + download_url(cv_url, self.root, hash_value=None) + extract_archive(cv_archive) + + covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language, + tgt_lang=target_language) + covost_archive = os.path.join(self.root, os.path.basename(covost_url)) + if download: + if not os.path.isfile(covost_archive): + download_url(covost_url, self.root, hash_value=None) + extract_archive(covost_archive) + + cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv')) + covost_tsv = self.load_from_tsv( + os.path.join(self.root, + os.path.basename(covost_url).replace('.tar.gz', '')) + ) + df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']], + right=covost_tsv[['path', 'translation', 'split']], + how='inner', on='path') + if split == 'train': + df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')] + else: + df = df[df['split'] == split] + self.data = df.to_dict(orient='index').items() + self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])] + + @classmethod + def load_from_tsv(cls, path: str): + return pd.read_csv( + path, sep='\t', header=0, encoding='utf-8', escapechar='\\', + quoting=csv.QUOTE_NONE, na_filter=False + ) + + def __getitem__( + self, n: int + ) -> Tuple[Tensor, int, str, str, Optional[str], str, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + tuple: ``(waveform, sample_rate, sentence, translation, speaker_id, + sample_id)`` + """ + data = self.data[n] + path = os.path.join(self.root, 'clips', data['path']) + waveform, sample_rate = torchaudio.load(path) + sentence = data['sentence'] + translation = None if self.no_translation else data['translation'] + speaker_id = data['client_id'] + _id = data['path'].replace('.mp3', '') + return waveform, sample_rate, sentence, translation, speaker_id, _id + + def __len__(self) -> int: + return len(self.data) + + +def process(args): + root = op.join(args.data_root, args.src_lang) + os.makedirs(root, exist_ok=True) + # Extract features + feature_root = op.join(root, 'fbank80') + os.makedirs(feature_root, exist_ok=True) + for split in CoVoST.SPLITS: + print(f'Fetching split {split}...') + dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, + download=True) + print('Extracting log mel filter bank features...') + for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): + extract_fbank_features(waveform, sample_rate, + op.join(feature_root, f'{utt_id}.npy')) + # Pack features into ZIP + zip_filename = 'fbank80.zip' + zip_path = op.join(root, zip_filename) + print('ZIPing features...') + create_zip(feature_root, zip_path) + print('Fetching ZIP manifest...') + zip_manifest = get_zip_manifest(args.data_root, + f'{args.src_lang}/{zip_filename}') + # Generate TSV manifest + print('Generating manifest...') + train_text = [] + task = f'asr_{args.src_lang}' + if args.tgt_lang is not None: + task = f'st_{args.src_lang}_{args.tgt_lang}' + for split in CoVoST.SPLITS: + manifest = {c: [] for c in MANIFEST_COLUMNS} + dataset = CoVoST(root, split, args.src_lang, args.tgt_lang) + for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): + manifest['id'].append(utt_id) + manifest['audio'].append(zip_manifest[utt_id]) + duration_ms = int(wav.size(1) / sr * 1000) + 
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) + manifest['tgt_text'].append( + src_utt if args.tgt_lang is None else tgt_utt + ) + manifest['speaker'].append(speaker_id) + is_train_split = split.startswith('train') + if is_train_split: + train_text.extend(manifest['tgt_text']) + df = pd.DataFrame.from_dict(manifest) + df = filter_manifest_df(df, is_train_split=is_train_split) + save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv')) + # Generate vocab + vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size) + spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}' + with NamedTemporaryFile(mode='w') as f: + for t in train_text: + f.write(t + '\n') + gen_vocab(f.name, op.join(root, spm_filename_prefix), + args.vocab_type, args.vocab_size) + # Generate config YAML + gen_config_yaml(root, spm_filename_prefix + '.model', + yaml_filename=f'config_{task}.yaml', + specaugment_policy='lb') + # Clean up + shutil.rmtree(feature_root) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data-root', '-d', required=True, type=str) + parser.add_argument('--vocab-type', default='unigram', required=True, + type=str, choices=['bpe', 'unigram', 'char']), + parser.add_argument('--vocab-size', default=1000, type=int) + parser.add_argument('--src-lang', '-s', required=True, type=str) + parser.add_argument('--tgt-lang', '-t', type=str) + args = parser.parse_args() + + process(args) + + +if __name__ == '__main__': + main() diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py new file mode 100644 index 0000000000..4f003ec505 --- /dev/null +++ b/examples/speech_to_text/prep_librispeech_data.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +from tempfile import NamedTemporaryFile +import os +import shutil +import os.path as op + +from tqdm import tqdm +from torchaudio.datasets import LIBRISPEECH +import pandas as pd + +from examples.speech_to_text.data_utils import ( + gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, + extract_fbank_features, gen_config_yaml +) + +log = logging.getLogger(__name__) + +SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean', + 'dev-other', 'test-clean', 'test-other'] + +MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] + + +def process(args): + os.makedirs(args.output_root, exist_ok=True) + # Extract features + feature_root = op.join(args.output_root, 'fbank80') + os.makedirs(feature_root, exist_ok=True) + for split in SPLITS: + print(f'Fetching split {split}...') + dataset = LIBRISPEECH(args.output_root, url=split, download=True) + print('Extracting log mel filter bank features...') + for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset): + sample_id = f'{spk_id}-{chapter_id}-{utt_id}' + extract_fbank_features(wav, sample_rate, + op.join(feature_root, f'{sample_id}.npy')) + # Pack features into ZIP + zip_filename = 'fbank80.zip' + zip_path = op.join(args.output_root, zip_filename) + print('ZIPing features...') + create_zip(feature_root, zip_path) + print('Fetching ZIP manifest...') + zip_manifest = get_zip_manifest(args.output_root, zip_filename) + # Generate TSV manifest + print('Generating manifest...') + train_text = [] + for split in SPLITS: + manifest = {c: [] for c in MANIFEST_COLUMNS} + dataset = LIBRISPEECH(args.output_root, url=split) + for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset): + sample_id = f'{spk_id}-{chapter_id}-{utt_id}' + manifest['id'].append(sample_id) + manifest['audio'].append(zip_manifest[sample_id]) + duration_ms = int(wav.size(1) / sample_rate * 1000) + manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) + 
manifest['tgt_text'].append(utt) + manifest['speaker'].append(spk_id) + save_df_to_tsv(pd.DataFrame.from_dict(manifest), + op.join(args.output_root, f'{split}.tsv')) + if split.startswith('train'): + train_text.extend(manifest['tgt_text']) + # Generate vocab + vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size) + spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}' + with NamedTemporaryFile(mode='w') as f: + for t in train_text: + f.write(t + '\n') + gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix), + args.vocab_type, args.vocab_size) + # Generate config YAML + gen_config_yaml(args.output_root, spm_filename_prefix + '.model', + specaugment_policy='ld') + # Clean up + shutil.rmtree(feature_root) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--output-root', '-o', required=True, type=str) + parser.add_argument('--vocab-type', default='unigram', required=True, + type=str, choices=['bpe', 'unigram', 'char']), + parser.add_argument('--vocab-size', default=10000, type=int) + args = parser.parse_args() + + process(args) + + +if __name__ == '__main__': + main() diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py new file mode 100644 index 0000000000..6c0a9b7132 --- /dev/null +++ b/examples/speech_to_text/prep_mustc_data.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +from tempfile import NamedTemporaryFile +import os +import os.path as op +import shutil +from typing import Tuple +from itertools import groupby + +from tqdm import tqdm +import pandas as pd +from torch.utils.data import Dataset +import torchaudio +from torch import Tensor + +from examples.speech_to_text.data_utils import ( + gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, + extract_fbank_features, gen_config_yaml, filter_manifest_df +) + +log = logging.getLogger(__name__) + + +MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] +TASKS = ['asr', 'st'] + + +class MUSTC(Dataset): + """ + Create a Dataset for MuST-C. Each item is a tuple of the form: + waveform, sample_rate, source utterance, target utterance, speaker_id, + utterance_id + """ + SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE'] + LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru'] + + def __init__(self, root: str, lang: str, split: str) -> None: + assert split in self.SPLITS and lang in self.LANGUAGES + _root = op.join(root, f'en-{lang}', 'data', split) + wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt') + assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root) + # Load audio segments + try: + import yaml + except ImportError: + print('Please install PyYAML to load YAML files for ' + 'the MuST-C dataset') + with open(op.join(txt_root, f'{split}.yaml')) as f: + segments = yaml.load(f, Loader=yaml.BaseLoader) + # Load source and target utterances + for _lang in ['en', lang]: + with open(op.join(txt_root, f'{split}.{_lang}')) as f: + utterances = [r.strip() for r in f] + assert len(segments) == len(utterances) + for i, u in enumerate(utterances): + segments[i][_lang] = u + # Gather info + self.data = [] + for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']): + wav_path = op.join(wav_root, wav_filename) + sample_rate = torchaudio.info(wav_path)[0].rate + seg_group = sorted(_seg_group, 
key=lambda x: x['offset']) + for i, segment in enumerate(seg_group): + offset = int(float(segment['offset']) * sample_rate) + n_frames = int(float(segment['duration']) * sample_rate) + _id = f'{op.splitext(wav_filename)[0]}_{i}' + self.data.append( + (wav_path, offset, n_frames, sample_rate, segment['en'], + segment[lang], segment['speaker_id'], _id) + ) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]: + wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \ + self.data[n] + waveform, _ = torchaudio.load(wav_path, offset=offset, + num_frames=n_frames) + return waveform, sr, src_utt, tgt_utt, spk_id, utt_id + + def __len__(self) -> int: + return len(self.data) + + +def process(args): + for lang in MUSTC.LANGUAGES: + cur_root = op.join(args.data_root, f'en-{lang}') + if not op.isdir(cur_root): + print(f'{cur_root} does not exist. Skipped.') + continue + # Extract features + feature_root = op.join(cur_root, 'fbank80') + os.makedirs(feature_root, exist_ok=True) + for split in MUSTC.SPLITS: + print(f'Fetching split {split}...') + dataset = MUSTC(args.data_root, lang, split) + print('Extracting log mel filter bank features...') + for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): + extract_fbank_features(waveform, sample_rate, + op.join(feature_root, f'{utt_id}.npy')) + # Pack features into ZIP + zip_filename = 'fbank80.zip' + zip_path = op.join(cur_root, zip_filename) + print('ZIPing features...') + create_zip(feature_root, zip_path) + print('Fetching ZIP manifest...') + zip_manifest = get_zip_manifest(args.data_root, + f'en-{lang}/{zip_filename}') + # Generate TSV manifest + print('Generating manifest...') + train_text = {task: [] for task in TASKS} + for split in MUSTC.SPLITS: + is_train_split = split.startswith('train') + manifest = {c: [] for c in MANIFEST_COLUMNS} + text = {task: [] for task in TASKS} + dataset = MUSTC(args.data_root, lang, split) + for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in 
tqdm(dataset): + manifest['id'].append(utt_id) + manifest['audio'].append(zip_manifest[utt_id]) + duration_ms = int(wav.size(1) / sr * 1000) + manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) + text['asr'].append(src_utt) + text['st'].append(tgt_utt) + manifest['speaker'].append(speaker_id) + if is_train_split: + for task in TASKS: + train_text[task].extend(text[task]) + for task in TASKS: + manifest['tgt_text'] = text[task] + df = pd.DataFrame.from_dict(manifest) + df = filter_manifest_df(df, is_train_split=is_train_split) + save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv')) + # Generate vocab + for task in TASKS: + vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size + if task == 'st': + vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size + vocab_size_str = '' if vocab_type == 'char' else str(vocab_size) + spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}' + with NamedTemporaryFile(mode='w') as f: + for t in train_text[task]: + f.write(t + '\n') + gen_vocab(f.name, op.join(cur_root, spm_filename_prefix), + vocab_type, vocab_size) + # Generate config YAML + gen_config_yaml(cur_root, spm_filename_prefix + '.model', + yaml_filename=f'config_{task}.yaml', + specaugment_policy='lb') + # Clean up + shutil.rmtree(feature_root) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data-root', '-d', required=True, type=str) + parser.add_argument('--asr-vocab-type', default='unigram', required=True, + type=str, choices=['bpe', 'unigram', 'char']), + parser.add_argument('--st-vocab-type', default='unigram', required=True, + type=str, choices=['bpe', 'unigram', 'char']), + parser.add_argument('--asr-vocab-size', default=5000, type=int) + parser.add_argument('--st-vocab-size', default=8000, type=int) + args = parser.parse_args() + + process(args) + + +if __name__ == '__main__': + main() diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py 
b/fairseq/criterions/label_smoothed_cross_entropy.py index d010c3d03d..931a8f76d5 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -5,6 +5,8 @@ import math +import torch + from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion @@ -31,11 +33,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T @register_criterion('label_smoothed_cross_entropy') class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): - - def __init__(self, task, sentence_avg, label_smoothing): + def __init__(self, task, sentence_avg, label_smoothing, + ignore_prefix_size=0, report_accuracy=False): super().__init__(task) self.sentence_avg = sentence_avg self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy @staticmethod def add_args(parser): @@ -43,6 +47,10 @@ def add_args(parser): # fmt: off parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', help='epsilon for label smoothing, 0 means no label smoothing') + parser.add_argument('--report-accuracy', action='store_true', + help='report accuracy metric') + parser.add_argument('--ignore-prefix-size', default=0, type=int, + help='Ignore first N tokens') # fmt: on def forward(self, model, sample, reduce=True): @@ -63,19 +71,41 @@ def forward(self, model, sample, reduce=True): 'nsentences': sample['target'].size(0), 'sample_size': sample_size, } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output['n_correct'] = utils.item(n_correct.data) + logging_output['total'] = utils.item(total.data) return loss, sample_size, logging_output - def compute_loss(self, model, net_output, sample, reduce=True): + def get_lprobs_and_target(self, model, net_output, sample): lprobs = model.get_normalized_probs(net_output, log_probs=True) - lprobs = lprobs.view(-1, 
lprobs.size(-1)) - target = model.get_targets(sample, net_output).view(-1, 1) + target = model.get_targets(sample, net_output) + if self.ignore_prefix_size > 0: + if getattr(lprobs, "batch_first", False): + lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous() + target = target[:, self.ignore_prefix_size:].contiguous() + else: + lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous() + target = target[self.ignore_prefix_size:, :].contiguous() + return lprobs.view(-1, lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) loss, nll_loss = label_smoothed_nll_loss( lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, ) return loss, nll_loss - @staticmethod - def reduce_metrics(logging_outputs) -> None: + def compute_accuracy(self, model, net_output, sample): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + mask = target.ne(self.padding_idx) + n_correct = torch.sum( + lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))) + total = torch.sum(mask) + return n_correct, total + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" loss_sum = sum(log.get('loss', 0) for log in logging_outputs) nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs) @@ -86,6 +116,20 @@ def reduce_metrics(logging_outputs) -> None: metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3) metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg)) + total = utils.item(sum(log.get('total', 0) for log in logging_outputs)) + if total > 0: + metrics.log_scalar('total', total) + n_correct = utils.item( + sum(log.get('n_correct', 0) for log in logging_outputs) + ) + metrics.log_scalar('n_correct', n_correct) + metrics.log_derived( + 'accuracy', + lambda meters: round( + 
meters['n_correct'].sum * 100.0 / meters['total'].sum, 3 + ) if meters['total'].sum > 0 else float('nan'), + ) + @staticmethod def logging_outputs_can_be_summed() -> bool: """ diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py new file mode 100644 index 0000000000..3731721953 --- /dev/null +++ b/fairseq/data/audio/audio_utils.py @@ -0,0 +1,81 @@ +import os.path as op +from typing import Union, BinaryIO, Optional, Tuple + +import numpy as np + + +def get_waveform( + path_or_fp: Union[str, BinaryIO], normalization=True +) -> Tuple[np.ndarray, int]: + """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. + + Args: + path_or_fp (str or BinaryIO): the path or file-like object + normalization (bool): Normalize values to [-1, 1] (Default: True) + """ + if isinstance(path_or_fp, str): + ext = op.splitext(op.basename(path_or_fp))[1] + if ext not in {'.flac', '.wav'}: + raise ValueError(f'Unsupported audio format: {ext}') + + try: + import soundfile as sf + except ImportError: + raise ImportError('Please install soundfile to load WAV/FLAC file') + + waveform, sample_rate = sf.read(path_or_fp, dtype='float32') + if not normalization: + waveform *= 2 ** 15 # denormalized to 16-bit signed integers + return waveform, sample_rate + + +def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: + """Get mel-filter bank features via PyKaldi.""" + try: + from kaldi.feat.mel import MelBanksOptions + from kaldi.feat.fbank import FbankOptions, Fbank + from kaldi.feat.window import FrameExtractionOptions + from kaldi.matrix import Vector + + mel_opts = MelBanksOptions() + mel_opts.num_bins = n_bins + frame_opts = FrameExtractionOptions() + frame_opts.samp_freq = sample_rate + opts = FbankOptions() + opts.mel_opts = mel_opts + opts.frame_opts = frame_opts + fbank = Fbank(opts=opts) + features = fbank.compute(Vector(waveform), 1.0).numpy() + return features + except ImportError: + return None + + +def 
_get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: + """Get mel-filter bank features via TorchAudio.""" + try: + import torch + import torchaudio.compliance.kaldi as ta_kaldi + waveform = torch.from_numpy(waveform).unsqueeze(0) + features = ta_kaldi.fbank(waveform, num_mel_bins=n_bins, + sample_frequency=sample_rate) + return features.numpy() + except ImportError: + return None + + +def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: + """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi + (faster CPP implementation) to TorchAudio (Python implementation). Note that + Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the + waveform should not be normalized.""" + sound, sample_rate = get_waveform(path_or_fp, normalization=False) + + features = _get_kaldi_fbank(sound, sample_rate, n_bins) + if features is None: + features = _get_torchaudio_fbank(sound, sample_rate, n_bins) + if features is None: + raise ImportError('Please install pyKaldi or torchaudio to enable ' + 'online filterbank feature extraction') + + return features diff --git a/fairseq/data/audio/feature_transforms/__init__.py b/fairseq/data/audio/feature_transforms/__init__.py new file mode 100644 index 0000000000..399956a33b --- /dev/null +++ b/fairseq/data/audio/feature_transforms/__init__.py @@ -0,0 +1,77 @@ +import importlib +import os +from typing import Optional, Dict +from abc import ABC, abstractmethod + + +class AudioFeatureTransform(ABC): + @classmethod + @abstractmethod + def from_config_dict(cls, config: Optional[Dict] = None): + pass + + +AUDIO_FEATURE_TRANSFORM_REGISTRY = {} +AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set() + + +def register_audio_feature_transform(name): + def register_audio_feature_transform_cls(cls): + if name in AUDIO_FEATURE_TRANSFORM_REGISTRY: + raise ValueError(f'Cannot register duplicate transform ({name})') + if not issubclass(cls, AudioFeatureTransform): + raise 
ValueError(f'Transform ({name}: {cls.__name__}) must extend ' + 'AudioFeatureTransform') + if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES: + raise ValueError( + f'Cannot register audio feature transform with duplicate ' + f'class name ({cls.__name__})' + ) + AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls + AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__) + return cls + + return register_audio_feature_transform_cls + + +def get_audio_feature_transform(name): + return AUDIO_FEATURE_TRANSFORM_REGISTRY[name] + + +transforms_dir = os.path.dirname(__file__) +for file in os.listdir(transforms_dir): + path = os.path.join(transforms_dir, file) + if ( + not file.startswith('_') + and not file.startswith('.') + and (file.endswith('.py') or os.path.isdir(path)) + ): + name = file[:file.find('.py')] if file.endswith('.py') else file + importlib.import_module('fairseq.data.audio.feature_transforms.' + name) + + +class CompositeAudioFeatureTransform(AudioFeatureTransform): + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + _transforms = _config.get('transforms') + if _transforms is None: + return None + transforms = [ + get_audio_feature_transform(_t).from_config_dict(_config.get(_t)) + for _t in _transforms + ] + return CompositeAudioFeatureTransform(transforms) + + def __init__(self, transforms): + self.transforms = [t for t in transforms if t is not None] + + def __call__(self, x): + for t in self.transforms: + x = t(x) + return x + + def __repr__(self): + format_string = [self.__class__.__name__ + '('] + \ + [f" {t.__repr__()}" for t in self.transforms] + [')'] + return '\n'.join(format_string) diff --git a/fairseq/data/audio/feature_transforms/global_cmvn.py b/fairseq/data/audio/feature_transforms/global_cmvn.py new file mode 100644 index 0000000000..f9c92a66b1 --- /dev/null +++ b/fairseq/data/audio/feature_transforms/global_cmvn.py @@ -0,0 +1,24 @@ +import numpy as np +from fairseq.data.audio.feature_transforms 
import ( + AudioFeatureTransform, register_audio_feature_transform +) + + +@register_audio_feature_transform('global_cmvn') +class GlobalCMVN(AudioFeatureTransform): + """Global CMVN (cepstral mean and variance normalization). The global mean + and variance need to be pre-computed and stored in NumPy format (.npz).""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return GlobalCMVN(_config.get('stats_npz_path')) + + def __init__(self, stats_npz_path): + stats = np.load(stats_npz_path) + self.mean, self.std = stats['mean'], stats['std'] + + def __call__(self, x): + x = np.subtract(x, self.mean) + x = np.divide(x, self.std) + return x diff --git a/fairseq/data/audio/feature_transforms/specaugment.py b/fairseq/data/audio/feature_transforms/specaugment.py new file mode 100644 index 0000000000..e4c36bde3c --- /dev/null +++ b/fairseq/data/audio/feature_transforms/specaugment.py @@ -0,0 +1,126 @@ +import math +import numbers +from typing import Optional + +import numpy as np + +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, register_audio_feature_transform +) + + +@register_audio_feature_transform('specaugment') +class SpecAugmentTransform(AudioFeatureTransform): + """SpecAugment (https://arxiv.org/abs/1904.08779)""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return SpecAugmentTransform( + _config.get('time_warp_W', 0), + _config.get('freq_mask_N', 0), + _config.get('freq_mask_F', 0), + _config.get('time_mask_N', 0), + _config.get('time_mask_T', 0), + _config.get('time_mask_p', 0.0), + _config.get('mask_value', None), + ) + + def __init__( + self, + time_warp_w: int = 0, + freq_mask_n: int = 0, + freq_mask_f: int = 0, + time_mask_n: int = 0, + time_mask_t: int = 0, + time_mask_p: float = 0.0, + mask_value: Optional[float] = 0.0, + ): + # Sanity checks + assert mask_value is None or isinstance( + mask_value, numbers.Number 
+ ), f"mask_value (type: {type(mask_value)}) must be None or a number" + if freq_mask_n > 0: + assert ( + freq_mask_f > 0 + ), f"freq_mask_F ({freq_mask_f}) " \ + f"must be larger than 0 when doing freq masking." + if time_mask_n > 0: + assert ( + time_mask_t > 0 + ), f"time_mask_T ({time_mask_t}) must be larger than 0 when " \ + f"doing time masking." + + self.time_warp_w = time_warp_w + self.freq_mask_n = freq_mask_n + self.freq_mask_f = freq_mask_f + self.time_mask_n = time_mask_n + self.time_mask_t = time_mask_t + self.time_mask_p = time_mask_p + self.mask_value = mask_value + + def __repr__(self): + return self.__class__.__name__ + '(' + ', '.join( + [f'time_warp_w={self.time_warp_w}', + f'freq_mask_n={self.freq_mask_n}', + f'freq_mask_f={self.freq_mask_f}', + f'time_mask_n={self.time_mask_n}', + f'time_mask_t={self.time_mask_t}', + f'time_mask_p={self.time_mask_p}'] + ) + ')' + + def __call__(self, spectrogram): + assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." + + distorted = spectrogram.copy() # make a copy of input spectrogram. + num_frames = spectrogram.shape[0] # or 'tau' in the paper. + num_freqs = spectrogram.shape[1] # or 'miu' in the paper. + mask_value = self.mask_value + + if mask_value is None: # if no value was specified, use local mean. 
+ mask_value = spectrogram.mean() + + if num_frames == 0: + return spectrogram + + if num_freqs < self.freq_mask_f: + return spectrogram + + if self.time_warp_w > 0: + if 2 * self.time_warp_w < num_frames: + import cv2 + w0 = np.random.randint( + self.time_warp_w, num_frames - self.time_warp_w + ) + w = np.random.randint(0, self.time_warp_w) + upper, lower = distorted[:w0, :], distorted[w0:, :] + upper = cv2.resize( + upper, dsize=(num_freqs, w0 + w), + interpolation=cv2.INTER_LINEAR + ) + lower = cv2.resize( + lower, + dsize=(num_freqs, num_frames - w0 - w), + interpolation=cv2.INTER_LINEAR, + ) + distorted = np.concatenate((upper, lower), axis=0) + + for _i in range(self.freq_mask_n): + f = np.random.randint(0, self.freq_mask_f) + f0 = np.random.randint(0, num_freqs - f) + if f != 0: + distorted[:, f0: f0 + f] = mask_value + + max_time_mask_t = min( + self.time_mask_t, math.floor(num_frames * self.time_mask_p) + ) + if max_time_mask_t < 1: + return distorted + + for _i in range(self.time_mask_n): + t = np.random.randint(0, max_time_mask_t) + t0 = np.random.randint(0, num_frames - t) + if t != 0: + distorted[t0: t0 + t, :] = mask_value + + return distorted diff --git a/fairseq/data/audio/feature_transforms/utterance_cmvn.py b/fairseq/data/audio/feature_transforms/utterance_cmvn.py new file mode 100644 index 0000000000..cbedd360d0 --- /dev/null +++ b/fairseq/data/audio/feature_transforms/utterance_cmvn.py @@ -0,0 +1,38 @@ +import numpy as np + +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, register_audio_feature_transform +) + + +@register_audio_feature_transform('utterance_cmvn') +class UtteranceCMVN(AudioFeatureTransform): + """Utterance-level CMVN (cepstral mean and variance normalization)""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return UtteranceCMVN( + _config.get('norm_means', True), + _config.get('norm_vars', True), + ) + + def __init__(self, norm_means=True, 
norm_vars=True): + self.norm_means, self.norm_vars = norm_means, norm_vars + + def __repr__(self): + return self.__class__.__name__ + \ + f'(norm_means={self.norm_means}, norm_vars={self.norm_vars})' + + def __call__(self, x): + mean = x.mean(axis=0) + square_sums = (x ** 2).sum(axis=0) + + if self.norm_means: + x = np.subtract(x, mean) + if self.norm_vars: + var = square_sums / x.shape[0] - mean ** 2 + std = np.sqrt(np.maximum(var, 1e-10)) + x = np.divide(x, std) + + return x diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py new file mode 100644 index 0000000000..df360b2c74 --- /dev/null +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -0,0 +1,478 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import re +from typing import List, Tuple, Optional, Dict +import os.path as op +import csv +import io + +import numpy as np +import torch +from fairseq.data import (FairseqDataset, Dictionary, ResamplingDataset, + ConcatDataset, data_utils as fairseq_data_utils) +from fairseq.data.audio.audio_utils import get_fbank, get_waveform +from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform + +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO +) +logger = logging.getLogger(__name__) + + +class S2TDataConfig(object): + """Wrapper class for data config YAML""" + def __init__(self, yaml_path): + try: + import yaml + except ImportError: + print('Please install PyYAML to load YAML files for ' + 'S2T data config') + self.config = {} + if op.isfile(yaml_path): + try: + with open(yaml_path) as f: + self.config = yaml.load(f, Loader=yaml.FullLoader) + except Exception as e: + logger.info(f'Failed to load config from {yaml_path}: {e}') + else: + logger.info(f'Cannot 
find {yaml_path}') + + @property + def vocab_filename(self): + """fairseq vocabulary file under data root""" + return self.config.get('vocab_filename', 'dict.txt') + + @property + def shuffle(self) -> bool: + """Shuffle dataset samples before batching""" + return self.config.get('shuffle', False) + + @property + def pre_tokenizer(self) -> Dict: + """Pre-tokenizer to apply before subword tokenization. Returning + a dictionary with `tokenizer` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get('pre_tokenizer', {'tokenizer': None}) + + @property + def bpe_tokenizer(self) -> Dict: + """Subword tokenizer to apply after pre-tokenization. Returning + a dictionary with `bpe` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get('bpe_tokenizer', None) + + @property + def prepend_tgt_lang_tag(self) -> bool: + """Prepend target lang ID token as the target BOS (e.g. for to-many + multilingual setting). During inference, this requires `--prefix-size 1` + to force BOS to be lang ID token.""" + return self.config.get('prepend_tgt_lang_tag', False) + + @property + def input_feat_per_channel(self): + """The dimension of input features (per audio channel)""" + return self.config.get('input_feat_per_channel', 80) + + @property + def input_channels(self): + """The number of channels in the input audio""" + return self.config.get('input_channels', 1) + + @property + def sampling_alpha(self): + """Hyper-parameter alpha = 1/T for temperature-based resampling. + (alpha = 1 for no resampling)""" + return self.config.get('sampling_alpha', 1.) 
+ + @property + def use_audio_input(self): + """Needed by the dataset loader to see if the model requires + raw audio as inputs.""" + return self.config.get('use_audio_input', False) + + @property + def audio_root(self): + """Audio paths in the manifest TSV can be relative and this provides + the root path. Set this to empty string when using absolute paths.""" + return self.config.get('audio_root', '') + + def get_feature_transforms(self, split, is_train): + """Split-specific feature transforms. Allowing train set wildcard `_train`, + evaluation set wildcard `_eval` and general wildcard `*` for matching.""" + from copy import deepcopy + cfg = deepcopy(self.config) + _cur = cfg.get('transforms', {}) + cur = _cur.get(split) + cur = _cur.get('_train') if cur is None and is_train else cur + cur = _cur.get('_eval') if cur is None and not is_train else cur + cur = _cur.get('*') if cur is None else cur + cfg['transforms'] = cur + return cfg + + +def is_npy_data(data: bytes) -> bool: + return data[0] == 147 and data[1] == 78 + + +def is_flac_or_wav_data(data: bytes) -> bool: + is_flac = (data[0] == 102 and data[1] == 76) + is_wav = (data[0] == 82 and data[1] == 73) + return is_flac or is_wav + + +def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes: + with open(file_path, 'rb') as f: + f.seek(offset) + data = f.read(file_size) + return data + + +def get_features_from_npy_or_audio(path): + ext = op.splitext(op.basename(path))[1] + if ext not in {'.npy', '.flac', '.wav'}: + raise ValueError(f'Unsupported file format for "{path}"') + return np.load(path) if ext == '.npy' else get_fbank(path) + + +def get_features_or_waveform_from_uncompressed_zip( + path, byte_offset, byte_size, need_waveform=False +): + assert path.endswith('.zip') + data = read_from_uncompressed_zip(path, byte_offset, byte_size) + f = io.BytesIO(data) + if is_npy_data(data): + features_or_waveform = np.load(f) + elif is_flac_or_wav_data(data): + features_or_waveform = get_waveform(f)[0] if 
need_waveform else get_fbank(f) + else: + raise ValueError(f'Unknown file format for "{path}"') + return features_or_waveform + + +def get_features_or_waveform(path: str, need_waveform=False): + """Get speech features from .npy file or waveform from .wav/.flac file. + The file may be inside an uncompressed ZIP file and is accessed via byte + offset and length. + + Args: + path (str): File path in the format of "<.npy/.wav/.flac path>" or + "<zip path>:<byte offset>:<byte length>". + need_waveform (bool): return waveform instead of features. + + Returns: + features_or_waveform (numpy.ndarray): speech features or waveform. + """ + _path, *extra = path.split(':') + if not op.exists(_path): + raise FileNotFoundError(f'File not found: {_path}') + + if len(extra) == 0: + if need_waveform: + return get_waveform(_path)[0] + return get_features_from_npy_or_audio(_path) + elif len(extra) == 2: + extra = [int(i) for i in extra] + features_or_waveform = get_features_or_waveform_from_uncompressed_zip( + _path, extra[0], extra[1], need_waveform=need_waveform + ) + else: + raise ValueError(f'Invalid path: {path}') + + return features_or_waveform + + +def _collate_frames(frames: List[torch.Tensor], + is_audio_input: bool = False) -> torch.Tensor: + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim.
Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + if is_audio_input: + out = frames[0].new_zeros((len(frames), max_len)) + else: + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, : v.size(0)] = v + return out + + +class SpeechToTextDataset(FairseqDataset): + LANG_TAG_TEMPLATE = '<lang:{}>' + + def __init__( + self, + split: str, + is_train_split: bool, + data_cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + ): + self.split, self.is_train_split = split, is_train_split + self.data_cfg = data_cfg + self.audio_paths, self.n_frames = audio_paths, n_frames + self.n_samples = len(audio_paths) + assert len(n_frames) == self.n_samples > 0 + assert src_texts is None or len(src_texts) == self.n_samples + assert tgt_texts is None or len(tgt_texts) == self.n_samples + assert speakers is None or len(speakers) == self.n_samples + assert src_langs is None or len(src_langs) == self.n_samples + assert tgt_langs is None or len(tgt_langs) == self.n_samples + assert ids is None or len(ids) == self.n_samples + assert (tgt_dict is None and tgt_texts is None) or \ + (tgt_dict is not None and tgt_texts is not None) + self.tgt_dict = tgt_dict + self.src_texts, self.tgt_texts = src_texts, tgt_texts + self.src_langs, self.tgt_langs = src_langs, tgt_langs + # check the language tags only after self.tgt_langs has been assigned + self.check_tgt_lang_tag() + self.ids = ids + self.shuffle = data_cfg.shuffle if is_train_split else False + + self.feature_transforms =
CompositeAudioFeatureTransform.from_config_dict( + self.data_cfg.get_feature_transforms(split, is_train_split) + ) + + self.pre_tokenizer = pre_tokenizer + self.bpe_tokenizer = bpe_tokenizer + + logger.info(self.__repr__()) + + def __repr__(self): + return self.__class__.__name__ + \ + f'(split="{self.split}", n_samples={self.n_samples}, ' \ + f'prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, ' \ + f'shuffle={self.shuffle}, transforms={self.feature_transforms})' + + @classmethod + def is_lang_tag(cls, token): + pattern = cls.LANG_TAG_TEMPLATE.replace('{}', '(.*)') + return re.match(pattern, token) + + def check_tgt_lang_tag(self): + if self.data_cfg.prepend_tgt_lang_tag: + assert self.tgt_langs is not None and self.tgt_dict is not None + tgt_lang_tags = [self.LANG_TAG_TEMPLATE.format(t) + for t in set(self.tgt_langs)] + assert all(t in self.tgt_dict for t in tgt_lang_tags) + + def tokenize_text(self, text: str): + if self.pre_tokenizer is not None: + text = self.pre_tokenizer.encode(text) + if self.bpe_tokenizer is not None: + text = self.bpe_tokenizer.encode(text) + return text + + def __getitem__( + self, index: int + ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]: + source = get_features_or_waveform( + self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input + ) + if self.feature_transforms is not None: + assert not self.data_cfg.use_audio_input + source = self.feature_transforms(source) + source = torch.from_numpy(source).float() + + target = None + if self.tgt_texts is not None: + tokenized = self.tokenize_text(self.tgt_texts[index]) + target = self.tgt_dict.encode_line( + tokenized, add_if_not_exist=False, append_eos=True + ).long() + if self.data_cfg.prepend_tgt_lang_tag: + lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index]) + lang_tag_idx = self.tgt_dict.index(lang_tag) + target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0) + return index, source, target + + def __len__(self): + return self.n_samples + 
+ def collater( + self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]] + ) -> Dict: + if len(samples) == 0: + return {} + indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long) + frames = _collate_frames([s for _, s, _ in samples], + self.data_cfg.use_audio_input) + # sort samples by descending number of frames + n_frames = torch.tensor( + [s.size(0) for _, s, _ in samples], dtype=torch.long + ) + n_frames, order = n_frames.sort(descending=True) + indices = indices.index_select(0, order) + frames = frames.index_select(0, order) + + target, target_lengths = None, None + prev_output_tokens = None + ntokens = None + if self.tgt_texts is not None: + target = fairseq_data_utils.collate_tokens( + [t for _, _, t in samples], self.tgt_dict.pad(), + self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=False + ) + target = target.index_select(0, order) + target_lengths = torch.tensor( + [t.size(0) for _, _, t in samples], dtype=torch.long + ).index_select(0, order) + prev_output_tokens = fairseq_data_utils.collate_tokens( + [t for _, _, t in samples], self.tgt_dict.pad(), + self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=True + ) + prev_output_tokens = prev_output_tokens.index_select(0, order) + ntokens = sum(t.size(0) for _, _, t in samples) + + out = { + "id": indices, + "net_input": { + "src_tokens": frames, + "src_lengths": n_frames, + "prev_output_tokens": prev_output_tokens, + }, + "target": target, + "target_lengths": target_lengths, + "ntokens": ntokens, + "nsentences": len(samples), + } + return out + + def num_tokens(self, index): + return self.n_frames[index] + + def size(self, index): + t_len = 0 + if self.tgt_texts is not None: + tokenized = self.tokenize_text(self.tgt_texts[index]) + t_len = len(tokenized.split(' ')) + return self.n_frames[index], t_len + + @property + def sizes(self): + return np.array(self.n_frames) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True + + def ordered_indices(self): + 
if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + # sort by descending number of frames; ties keep the original/random order + order.append([-n for n in self.n_frames]) + return np.lexsort(order) + + def prefetch(self, indices): + raise NotImplementedError + + +class SpeechToTextDatasetCreator(object): + # mandatory columns + KEY_ID, KEY_AUDIO, KEY_N_FRAMES = 'id', 'audio', 'n_frames' + KEY_TGT_TEXT = 'tgt_text' + # optional columns + KEY_SPEAKER, KEY_SRC_TEXT = 'speaker', 'src_text' + KEY_SRC_LANG, KEY_TGT_LANG = 'src_lang', 'tgt_lang' + # default values + DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = '' + + @classmethod + def _from_list(cls, split_name: str, is_train_split, + samples: List[List[Dict]], data_cfg: S2TDataConfig, tgt_dict, + pre_tokenizer, bpe_tokenizer) -> SpeechToTextDataset: + audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], [] + speakers, src_langs, tgt_langs = [], [], [] + for s in samples: + ids.extend([ss[cls.KEY_ID] for ss in s]) + audio_paths.extend([op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) + for ss in s]) + n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s]) + tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s]) + src_texts.extend([ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) + for ss in s]) + speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) + for ss in s]) + src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) + for ss in s]) + tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) + for ss in s]) + return SpeechToTextDataset( + split_name, is_train_split, data_cfg, audio_paths, n_frames, + src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict, + pre_tokenizer, bpe_tokenizer + ) + + @classmethod + def _get_size_ratios(cls, ids: List[str], sizes: List[int], + alpha: float = 1.): + """Size ratios for temperature-based sampling + (https://arxiv.org/abs/1907.05019)""" + _sizes = np.array(sizes) + prob = _sizes / _sizes.sum() +
smoothed_prob = prob ** alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + size_ratio = (smoothed_prob * _sizes.sum()) / _sizes + + o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)}) + logger.info(f"original sampling probability: {o_str}") + p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)}) + logger.info(f"balanced sampling probability: {p_str}") + sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)}) + logger.info(f"balanced sampling size ratio: {sr_str}") + return size_ratio.tolist() + + @classmethod + def from_tsv(cls, root: str, data_cfg: S2TDataConfig, splits: str, tgt_dict, + pre_tokenizer, bpe_tokenizer, is_train_split: bool, epoch: int, + seed: int) -> SpeechToTextDataset: + samples = [] + _splits = splits.split(',') + for split in _splits: + tsv_path = op.join(root, f'{split}.tsv') + if not op.isfile(tsv_path): + raise FileNotFoundError(f"Dataset not found: {tsv_path}") + with open(tsv_path) as f: + reader = csv.DictReader( + f, delimiter='\t', quotechar=None, doublequote=False, + lineterminator='\n', quoting=csv.QUOTE_NONE + ) + samples.append([dict(e) for e in reader]) + assert len(samples) > 0 + + datasets = [cls._from_list(name, is_train_split, [s], data_cfg, tgt_dict, + pre_tokenizer, bpe_tokenizer) + for name, s in zip(_splits, samples)] + + if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.: + # temperature-based sampling + size_ratios = cls._get_size_ratios( + _splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha + ) + datasets = [ + ResamplingDataset(d, size_ratio=r, seed=seed, epoch=epoch, + replace=(r >= 1.)) + for d, r in zip(datasets, size_ratios) + ] + return ConcatDataset(datasets) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 224169c366..a8c480c5b1 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -446,3 +446,14 @@ def get_mem_usage(): return f'used={psutil.virtual_memory().used / mb}Mb; 
avail={psutil.virtual_memory().available / mb}Mb' except ImportError: return 'N/A' + + +def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor: + return ~lengths_to_padding_mask(lens) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py new file mode 100644 index 0000000000..351c16fee5 --- /dev/null +++ b/fairseq/models/speech_to_text/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .berard import * # noqa +from .s2t_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/berard.py b/fairseq/models/speech_to_text/berard.py new file mode 100644 index 0000000000..f5ae46eeb2 --- /dev/null +++ b/fairseq/models/speech_to_text/berard.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 + +from ast import literal_eval +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import checkpoint_utils, utils +from fairseq.models import ( + FairseqEncoder, + FairseqIncrementalDecoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.data.data_utils import lengths_to_padding_mask + + +@register_model("s2t_berard") +class BerardModel(FairseqEncoderDecoderModel): + """Implementation of a model similar to https://arxiv.org/abs/1802.04200 + + Paper title: End-to-End Automatic Speech Translation of Audiobooks + An implementation is available in tensorflow at + https://github.com/eske/seq2seq + Relevant files in this implementation are the config + 
(https://github.com/eske/seq2seq/blob/master/config/LibriSpeech/AST.yaml) + and the model code + (https://github.com/eske/seq2seq/blob/master/translate/models.py). + The encoder and decoder try to be close to the original implementation. + The attention is an MLP as in Bahdanau et al. + (https://arxiv.org/abs/1409.0473). + There is no state initialization by averaging the encoder outputs. + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + parser.add_argument("--input-layers", type=str, metavar="EXPR", + help="List of linear layer dimensions. These " + "layers are applied to the input features and " + "are followed by tanh and possibly dropout.") + parser.add_argument( + "--dropout", type=float, metavar="D", + help="Dropout probability to use in the encoder/decoder. " + "Note that this parameter controls dropout in various places; " + "there is no fine-grained control of dropout for embeddings " + "vs. LSTM layers, for example." + ) + parser.add_argument("--in-channels", type=int, metavar="N", + help="Number of encoder input channels. " + "Typical value is 1.") + parser.add_argument("--conv-layers", type=str, metavar="EXPR", + help="List of conv layers " + "(format: (channels, kernel, stride)).") + parser.add_argument("--num-blstm-layers", type=int, metavar="N", + help="Number of encoder bi-LSTM layers.") + parser.add_argument("--lstm-size", type=int, metavar="N", + help="LSTM hidden size.") + parser.add_argument( + "--decoder-embed-dim", type=int, metavar="N", + help="Embedding dimension of the decoder target tokens."
+ ) + parser.add_argument("--decoder-hidden-dim", type=int, metavar="N", + help="Decoder LSTM hidden dimension.") + parser.add_argument("--decoder-num-layers", type=int, metavar="N", + help="Number of decoder LSTM layers.") + parser.add_argument("--attention-dim", type=int, metavar="N", + help="Hidden layer dimension in MLP attention.") + parser.add_argument( + "--output-layer-dim", type=int, metavar="N", + help="Hidden layer dim for linear layer prior to output projection." + ) + parser.add_argument( + "--load-pretrained-encoder-from", type=str, metavar="STR", + help="model to take encoder weights from (for initialization)" + ) + parser.add_argument( + "--load-pretrained-decoder-from", type=str, metavar="STR", + help="model to take decoder weights from (for initialization)" + ) + + @classmethod + def build_encoder(cls, args, task): + encoder = BerardEncoder( + input_layers=literal_eval(args.input_layers), + conv_layers=literal_eval(args.conv_layers), + in_channels=args.input_channels, + input_feat_per_channel=args.input_feat_per_channel, + num_blstm_layers=args.num_blstm_layers, + lstm_size=args.lstm_size, + dropout=args.dropout, + ) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + @classmethod + def build_decoder(cls, args, task): + decoder = LSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + num_layers=args.decoder_num_layers, + hidden_size=args.decoder_hidden_dim, + dropout=args.dropout, + encoder_output_dim=2 * args.lstm_size, # bidirectional + attention_dim=args.attention_dim, + output_layer_dim=args.output_layer_dim, + ) + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + @classmethod + def 
build_model(cls, args, task): + """Build a new model instance.""" + encoder = cls.build_encoder(args, task) + decoder = cls.build_decoder(args, task) + + return cls(encoder, decoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + # lprobs is a (B, T, D) tensor + lprobs.batch_first = True + return lprobs + + +class BerardEncoder(FairseqEncoder): + def __init__( + self, + input_layers: List[int], + conv_layers: List[Tuple[int]], + in_channels: int, + input_feat_per_channel: int, + num_blstm_layers: int, + lstm_size: int, + dropout: float, + ): + """ + Args: + input_layers: list of linear layer dimensions. These layers are + applied to the input features and are followed by tanh and + possibly dropout. + conv_layers: list of conv2d layer configurations. A configuration is + a tuple (out_channels, conv_kernel_size, stride). + in_channels: number of input channels. + input_feat_per_channel: number of input features per channel. These + are speech features, typically 40 or 80. + num_blstm_layers: number of bidirectional LSTM layers. + lstm_size: size of the LSTM hidden (and cell) size. + dropout: dropout probability. Dropout can be applied after the + linear layers and LSTM layers but not to the convolutional + layers. 
+ """ + super().__init__(None) + + self.input_layers = nn.ModuleList() + in_features = input_feat_per_channel + for out_features in input_layers: + if dropout > 0: + self.input_layers.append( + nn.Sequential( + nn.Linear(in_features, out_features), + nn.Dropout(p=dropout) + ) + ) + else: + self.input_layers.append(nn.Linear(in_features, out_features)) + in_features = out_features + + self.in_channels = in_channels + self.input_dim = input_feat_per_channel + self.conv_kernel_sizes_and_strides = [] + self.conv_layers = nn.ModuleList() + lstm_input_dim = input_layers[-1] + for conv_layer in conv_layers: + out_channels, conv_kernel_size, conv_stride = conv_layer + self.conv_layers.append( + nn.Conv2d( + in_channels, + out_channels, + conv_kernel_size, + stride=conv_stride, + padding=conv_kernel_size // 2, + ) + ) + self.conv_kernel_sizes_and_strides.append( + (conv_kernel_size, conv_stride) + ) + in_channels = out_channels + lstm_input_dim //= conv_stride + + lstm_input_dim *= conv_layers[-1][0] + self.lstm_size = lstm_size + self.num_blstm_layers = num_blstm_layers + self.lstm = nn.LSTM( + input_size=lstm_input_dim, + hidden_size=lstm_size, + num_layers=num_blstm_layers, + dropout=dropout, + bidirectional=True, + ) + self.output_dim = 2 * lstm_size # bidirectional + if dropout > 0: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = None + + def forward(self, src_tokens, src_lengths=None, **kwargs): + """ + Args + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + bsz, max_seq_len, _ = src_tokens.size() + # (B, C, T, feat) + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + + for input_layer in self.input_layers: + x = input_layer(x) + x = torch.tanh(x) + + for conv_layer in self.conv_layers: + x = conv_layer(x) + + bsz, _, output_seq_len, _ = x.size() + + # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> + # (T, B, 
C * feat) + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, + bsz, -1) + + input_lengths = src_lengths.clone() + for k, s in self.conv_kernel_sizes_and_strides: + p = k // 2 + input_lengths = (input_lengths.float() + 2 * p - k) / s + 1 + input_lengths = input_lengths.floor().long() + + packed_x = nn.utils.rnn.pack_padded_sequence(x, input_lengths) + + h0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_() + c0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_() + packed_outs, _ = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_outs) + if self.dropout is not None: + x = self.dropout(x) + + encoder_padding_mask = lengths_to_padding_mask(output_lengths).to( + src_tokens.device).t() + + return { + "encoder_out": x, # (T, B, C) + "encoder_padding_mask": encoder_padding_mask, # (T, B) + } + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out + + +class MLPAttention(nn.Module): + """The original attention from Bahdanau et al. (2014) + + https://arxiv.org/abs/1409.0473, based on a Multi-Layer Perceptron.
+ The attention score between position i in the encoder and position j in the + decoder is: alpha_ij = V_a * tanh(W_ae * enc_i + W_ad * dec_j + b_a) + """ + + def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim): + super().__init__() + + self.context_dim = context_dim + self.attention_dim = attention_dim + # W_ae and b_a + self.encoder_proj = nn.Linear(context_dim, self.attention_dim, + bias=True) + # W_ad + self.decoder_proj = nn.Linear( + decoder_hidden_state_dim, self.attention_dim, bias=False + ) + # V_a + self.to_scores = nn.Linear(self.attention_dim, 1, bias=False) + + def forward(self, decoder_state, source_hids, encoder_padding_mask): + """The expected input dimensions are: + decoder_state: bsz x decoder_hidden_state_dim + source_hids: src_len x bsz x context_dim + encoder_padding_mask: src_len x bsz + """ + src_len, bsz, _ = source_hids.size() + # (src_len*bsz) x context_dim (to feed through linear) + flat_source_hids = source_hids.view(-1, self.context_dim) + # (src_len*bsz) x attention_dim + encoder_component = self.encoder_proj(flat_source_hids) + # src_len x bsz x attention_dim + encoder_component = encoder_component.view(src_len, bsz, + self.attention_dim) + # 1 x bsz x attention_dim + decoder_component = self.decoder_proj(decoder_state).unsqueeze(0) + # Sum with broadcasting and apply the nonlinearity + # src_len x bsz x attention_dim + hidden_att = torch.tanh( + (decoder_component + encoder_component).view(-1, self.attention_dim) + ) + # Project onto the reals to get attention scores (src_len x bsz) + attn_scores = self.to_scores(hidden_att).view(src_len, bsz) + + # Mask + softmax (src_len x bsz) + if encoder_padding_mask is not None: + attn_scores = ( + attn_scores.float() + .masked_fill_(encoder_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back + # src_len x bsz + normalized_masked_attn_scores = F.softmax(attn_scores, dim=0) + + # Sum weighted sources (bsz x context_dim) +
attn_weighted_context = ( + source_hids * normalized_masked_attn_scores.unsqueeze(2) + ).sum(dim=0) + + return attn_weighted_context, normalized_masked_attn_scores + + +class LSTMDecoder(FairseqIncrementalDecoder): + def __init__( + self, + dictionary, + embed_dim, + num_layers, + hidden_size, + dropout, + encoder_output_dim, + attention_dim, + output_layer_dim, + ): + """ + Args: + dictionary: target text dictionary. + embed_dim: embedding dimension for target tokens. + num_layers: number of LSTM layers. + hidden_size: hidden size for LSTM layers. + dropout: dropout probability. Dropout can be applied to the + embeddings, the LSTM layers, and the context vector. + encoder_output_dim: encoder output dimension (hidden size of + encoder LSTM). + attention_dim: attention dimension for MLP attention. + output_layer_dim: size of the linear layer prior to output + projection. + """ + super().__init__(dictionary) + self.num_layers = num_layers + self.hidden_size = hidden_size + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + self.embed_tokens = nn.Embedding(num_embeddings, embed_dim, padding_idx) + if dropout > 0: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = None + + self.layers = nn.ModuleList() + for layer_id in range(num_layers): + input_size = embed_dim if layer_id == 0 else encoder_output_dim + self.layers.append( + nn.LSTMCell(input_size=input_size, hidden_size=hidden_size) + ) + + self.context_dim = encoder_output_dim + self.attention = MLPAttention( + decoder_hidden_state_dim=hidden_size, + context_dim=encoder_output_dim, + attention_dim=attention_dim, + ) + + self.deep_output_layer = nn.Linear( + hidden_size + encoder_output_dim + embed_dim, output_layer_dim + ) + self.output_projection = nn.Linear(output_layer_dim, num_embeddings) + + def forward(self, prev_output_tokens, encoder_out=None, + incremental_state=None, **kwargs): + encoder_padding_mask = encoder_out["encoder_padding_mask"] + encoder_outs = 
encoder_out["encoder_out"] + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + bsz, seqlen = prev_output_tokens.size() + + srclen = encoder_outs.size(0) + + # embed tokens + embeddings = self.embed_tokens(prev_output_tokens) + x = embeddings + if self.dropout is not None: + x = self.dropout(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # initialize previous states (or get from cache during incremental + # generation) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is not None: + prev_hiddens, prev_cells = cached_state + else: + prev_hiddens = [ + encoder_out["encoder_out"].mean(dim=0) + ] * self.num_layers + prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers + + attn_scores = x.new_zeros(bsz, srclen) + attention_outs = [] + outs = [] + for j in range(seqlen): + input = x[j, :, :] + attention_out = None + for i, layer in enumerate(self.layers): + # the previous state is one layer below except for the bottom + # layer where the previous state is the state emitted by the + # top layer + hidden, cell = layer( + input, + ( + prev_hiddens[(i - 1) % self.num_layers], + prev_cells[(i - 1) % self.num_layers], + ), + ) + if self.dropout is not None: + hidden = self.dropout(hidden) + prev_hiddens[i] = hidden + prev_cells[i] = cell + if attention_out is None: + attention_out, attn_scores = self.attention( + hidden, encoder_outs, encoder_padding_mask + ) + if self.dropout is not None: + attention_out = self.dropout(attention_out) + attention_outs.append(attention_out) + input = attention_out + + # collect the output of the top layer + outs.append(hidden) + + # cache previous states (no-op except during incremental generation) + utils.set_incremental_state( + self, incremental_state, "cached_state", (prev_hiddens, prev_cells) + ) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) + 
attention_outs_concat = torch.cat(attention_outs, dim=0).view( + seqlen, bsz, self.context_dim + ) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + attention_outs_concat = attention_outs_concat.transpose(0, 1) + + # concat LSTM output, attention output and embedding + # before output projection + x = torch.cat((x, attention_outs_concat, embeddings), dim=2) + x = self.deep_output_layer(x) + x = torch.tanh(x) + if self.dropout is not None: + x = self.dropout(x) + # project back to size of vocabulary + x = self.output_projection(x) + + # to return the full attn_scores tensor, we need to fix the decoder + # to account for subsampling input frames + # return x, attn_scores + return x, None + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is None: + return + + def reorder_state(state): + if isinstance(state, list): + return [reorder_state(state_i) for state_i in state] + return state.index_select(0, new_order) + + new_state = tuple(map(reorder_state, cached_state)) + utils.set_incremental_state( + self, incremental_state, "cached_state", new_state + ) + + +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard") +def berard(args): + """The original version: "End-to-End Automatic Speech Translation of + Audiobooks" (https://arxiv.org/abs/1802.04200) + """ + args.input_layers = getattr(args, "input_layers", "[256, 128]") + args.conv_layers = getattr(args, "conv_layers", "[(16, 3, 2), (16, 3, 2)]") + args.num_blstm_layers = getattr(args, "num_blstm_layers", 3) + args.lstm_size = getattr(args, "lstm_size", 256) + args.dropout = getattr(args, "dropout", 0.2) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128) + args.decoder_num_layers = getattr(args, "decoder_num_layers", 2) + args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 
512) + args.attention_dim = getattr(args, "attention_dim", 512) + args.output_layer_dim = getattr(args, "output_layer_dim", 128) + args.load_pretrained_encoder_from = getattr( + args, "load_pretrained_encoder_from", None + ) + args.load_pretrained_decoder_from = getattr( + args, "load_pretrained_decoder_from", None + ) + + +@register_model_architecture(model_name="s2t_berard", + arch_name="s2t_berard_256_3_3") +def berard_256_3_3(args): + """Used in + * "Harnessing Indirect Training Data for End-to-End Automatic Speech + Translation: Tricks of the Trade" (https://arxiv.org/abs/1909.06515) + * "CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus" + (https://arxiv.org/pdf/2002.01320.pdf) + * "Self-Supervised Representations Improve End-to-End Speech Translation" + (https://arxiv.org/abs/2006.12124) + """ + args.decoder_num_layers = getattr(args, "decoder_num_layers", 3) + berard(args) + + +@register_model_architecture(model_name="s2t_berard", + arch_name="s2t_berard_512_3_2") +def berard_512_3_2(args): + args.num_blstm_layers = getattr(args, "num_blstm_layers", 3) + args.lstm_size = getattr(args, "lstm_size", 512) + args.dropout = getattr(args, "dropout", 0.3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_num_layers = getattr(args, "decoder_num_layers", 2) + args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 1024) + args.attention_dim = getattr(args, "attention_dim", 512) + args.output_layer_dim = getattr(args, "output_layer_dim", 256) + berard(args) + + +@register_model_architecture(model_name="s2t_berard", + arch_name="s2t_berard_512_5_3") +def berard_512_5_3(args): + args.num_blstm_layers = getattr(args, "num_blstm_layers", 5) + args.lstm_size = getattr(args, "lstm_size", 512) + args.dropout = getattr(args, "dropout", 0.3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_num_layers = getattr(args, "decoder_num_layers", 3) + args.decoder_hidden_dim = getattr(args, 
"decoder_hidden_dim", 1024) + args.attention_dim = getattr(args, "attention_dim", 512) + args.output_layer_dim = getattr(args, "output_layer_dim", 256) + berard(args) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py new file mode 100644 index 0000000000..3492f691f7 --- /dev/null +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 + +import logging +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from fairseq import utils, checkpoint_utils +from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, + register_model, register_model_architecture) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import (PositionalEmbedding, TransformerEncoderLayer, + FairseqDropout, LayerNorm) +from torch import Tensor + + +logger = logging.getLogger(__name__) + + +class Conv1dSubsampler(nn.Module): + """Convolutional subsampler: a stack of 1D convolution (along temporal + dimension) followed by non-linear activation via gated linear units + (https://arxiv.org/abs/1911.08460) + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + def __init__(self, in_channels: int, mid_channels: int, out_channels: int, + kernel_sizes: List[int] = (3, 3)): + super(Conv1dSubsampler, self).__init__() + self.n_layers = len(kernel_sizes) + self.conv_layers = nn.ModuleList( + nn.Conv1d( + in_channels if i == 0 else mid_channels // 2, + mid_channels if i < self.n_layers - 1 else out_channels * 2, + k, stride=2, padding=k // 2 + ) + for i, k in enumerate(kernel_sizes) + ) + + def 
get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for _ in range(self.n_layers): + out = ((out.float() - 1) / 2 + 1).floor().long() + return out + + def forward(self, src_tokens, src_lengths): + bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D) + x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T + for conv in self.conv_layers: + x = conv(x) + x = nn.functional.glu(x, dim=1) + _, _, out_seq_len = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D) + return x, self.get_out_seq_lens_tensor(src_lengths) + + +@register_model("s2t_transformer") +class S2TTransformerModel(FairseqEncoderDecoderModel): + """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for + speech-to-text tasks. The Transformer encoder/decoder remains the same. + A trainable input subsampler is prepended to the Transformer encoder to + project inputs into the encoder dimension as well as downsample input + sequence for computational efficiency.""" + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # input + parser.add_argument("--conv-kernel-sizes", type=str, metavar="N", + help="kernel sizes of Conv1d subsampling layers") + parser.add_argument("--conv-channels", type=int, metavar="N", + help="# of channels in Conv1d subsampling layers") + # Transformer + parser.add_argument("--activation-fn", type=str, default='relu', + choices=utils.get_available_activation_fns(), + help="activation function to use") + parser.add_argument("--dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--attention-dropout", type=float, metavar="D", + help="dropout probability for attention weights") + parser.add_argument("--activation-dropout", "--relu-dropout", + type=float, metavar="D", + help="dropout probability after activation in FFN.") + 
parser.add_argument("--encoder-embed-dim", type=int, metavar="N", + help="encoder embedding dimension") + parser.add_argument("--encoder-ffn-embed-dim", type=int, metavar="N", + help="encoder embedding dimension for FFN") + parser.add_argument("--encoder-layers", type=int, metavar="N", + help="num encoder layers") + parser.add_argument("--encoder-attention-heads", type=int, metavar="N", + help="num encoder attention heads") + parser.add_argument("--encoder-normalize-before", action="store_true", + help="apply layernorm before each encoder block") + parser.add_argument("--decoder-embed-dim", type=int, metavar="N", + help="decoder embedding dimension") + parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N", + help="decoder embedding dimension for FFN") + parser.add_argument("--decoder-layers", type=int, metavar="N", + help="num decoder layers") + parser.add_argument("--decoder-attention-heads", type=int, metavar="N", + help="num decoder attention heads") + parser.add_argument("--decoder-normalize-before", action="store_true", + help="apply layernorm before each decoder block") + parser.add_argument("--layernorm-embedding", action="store_true", + help="add layernorm to embedding") + parser.add_argument("--no-scale-embedding", action="store_true", + help="if True, dont scale embeddings") + parser.add_argument( + "--load-pretrained-encoder-from", type=str, metavar="STR", + help="model to take encoder weights from (for initialization)" + ) + + @classmethod + def build_encoder(cls, args): + encoder = S2TTransformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + logger.info(f'loaded pretrained encoder from: ' + f'{args.load_pretrained_encoder_from}') + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + return TransformerDecoderScriptable(args, 
task.target_dictionary, + embed_tokens) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + decoder_embed_tokens = build_embedding(task.target_dictionary, + args.decoder_embed_dim) + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + return cls(encoder, decoder) + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + lprobs.batch_first = True + return lprobs + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + """ + The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overrites the forward method definition without **kwargs. 
+ """ + encoder_out = self.encoder(src_tokens=src_tokens, + src_lengths=src_lengths) + decoder_out = self.decoder(prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out) + return decoder_out + + +class S2TTransformerEncoder(FairseqEncoder): + """Speech-to-text Transformer encoder that consists of input subsampler and + Transformer encoder.""" + + def __init__(self, args): + super().__init__(None) + + self.dropout_module = FairseqDropout( + p=args.dropout, module_name=self.__class__.__name__ + ) + self.embed_scale = math.sqrt(args.encoder_embed_dim) + if args.no_scale_embedding: + self.embed_scale = 1.0 + self.padding_idx = 1 + + self.subsample = Conv1dSubsampler( + args.input_feat_per_channel * args.input_channels, + args.conv_channels, args.encoder_embed_dim, + [int(k) for k in args.conv_kernel_sizes.split(',')] + ) + + self.embed_positions = PositionalEmbedding( + args.max_source_positions, args.encoder_embed_dim, + self.padding_idx + ) + + self.transformer_layers = nn.ModuleList( + [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def forward(self, src_tokens, src_lengths): + x, input_lengths = self.subsample(src_tokens, src_lengths) + x = self.embed_scale * x + + encoder_padding_mask = lengths_to_padding_mask(input_lengths) + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + x += positions + x = self.dropout_module(x) + + for layer in self.transformer_layers: + x = layer(x, encoder_padding_mask) + + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return EncoderOut( + encoder_out=x, encoder_padding_mask=encoder_padding_mask, + encoder_embedding=None, encoder_states=None, src_tokens=None, + src_lengths=None + ) + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + 
""" + Since encoder_padding_mask and encoder_embedding are both of type + Optional[Tensor] in EncoderOut, they need to be copied as local + variables for Torchscript Optional refinement + """ + + encoder_padding_mask: Optional[Tensor] = \ + encoder_out.encoder_padding_mask + encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding + + new_encoder_out = ( + encoder_out.encoder_out + if encoder_out.encoder_out is None + else encoder_out.encoder_out.index_select(1, new_order) + ) + + new_encoder_padding_mask = ( + encoder_padding_mask + if encoder_padding_mask is None + else encoder_padding_mask.index_select(0, new_order) + ) + + new_encoder_embedding = ( + encoder_embedding + if encoder_embedding is None + else encoder_embedding.index_select(0, new_order) + ) + + encoder_states = encoder_out.encoder_states + if encoder_states is not None: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return EncoderOut( + encoder_out=new_encoder_out, # T x B x C + encoder_padding_mask=new_encoder_padding_mask, # B x T + encoder_embedding=new_encoder_embedding, # B x T x C + encoder_states=encoder_states, # List[T x B x C] + src_tokens=None, + src_lengths=None, + ) + + +class TransformerDecoderScriptable(TransformerDecoder): + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + # call scriptable method from parent class + x, _ = self.extract_features_scriptable( + prev_output_tokens, encoder_out, incremental_state, + full_context_alignment, alignment_layer, alignment_heads, + ) + return x, None + + +@register_model_architecture(model_name="s2t_transformer", + arch_name="s2t_transformer") +def base_architecture(args): + # Convolutional subsampler + 
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", '5,5') + args.conv_channels = getattr(args, "conv_channels", 1024) + # Transformer + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", + True) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", + args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", + args.encoder_ffn_embed_dim) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", + True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", args.dropout) + args.activation_dropout = getattr(args, "activation_dropout", args.dropout) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", + None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_output_dim = getattr(args, "decoder_output_dim", + args.decoder_embed_dim) + args.decoder_input_dim = getattr(args, "decoder_input_dim", + args.decoder_embed_dim) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + 
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_s") +def s2t_transformer_s(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + base_architecture(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_sp") +def s2t_transformer_sp(args): + args.encoder_layers = getattr(args, "encoder_layers", 16) + s2t_transformer_s(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_m") +def s2t_transformer_m(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.dropout = getattr(args, "dropout", 0.15) + base_architecture(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_mp") +def s2t_transformer_mp(args): + args.encoder_layers = getattr(args, "encoder_layers", 16) + s2t_transformer_m(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_l") +def s2t_transformer_l(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", + 1024 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.2) + base_architecture(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_lp") +def s2t_transformer_lp(args): + 
args.encoder_layers = getattr(args, "encoder_layers", 16) + s2t_transformer_l(args) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 1ce4ab1921..a8bfaa532d 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -10,10 +10,9 @@ import torch from fairseq import metrics, search, tokenizer, utils -from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators +from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators, encoders from fairseq.dataclass.utils import gen_parser_from_dataclass - logger = logging.getLogger(__name__) @@ -504,6 +503,14 @@ def target_dictionary(self): for this task).""" raise NotImplementedError + def build_tokenizer(self, args): + """Build the pre-tokenizer for this task.""" + return encoders.build_tokenizer(args) + + def build_bpe(self, args): + """Build the tokenizer for this task.""" + return encoders.build_bpe(args) + class LegacyFairseqTask(FairseqTask): def __init__(self, args: Namespace): diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py new file mode 100644 index 0000000000..b17ad22602 --- /dev/null +++ b/fairseq/tasks/speech_to_text.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from argparse import Namespace +import os.path as op + +from fairseq.data import encoders, Dictionary +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig +) +from fairseq.tasks import FairseqTask, register_task + +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + ) +logger = logging.getLogger(__name__) + + +@register_task('speech_to_text') +class SpeechToTextTask(FairseqTask): + @staticmethod + def add_args(parser): + parser.add_argument('data', help='manifest root path') + parser.add_argument( + '--config-yaml', type=str, default='config.yaml', + help='Configuration YAML filename (under manifest root)' + ) + parser.add_argument('--max-source-positions', default=6000, type=int, + metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, + metavar='N', + help='max number of tokens in the target sequence') + + def __init__(self, args, tgt_dict): + super().__init__(args) + self.tgt_dict = tgt_dict + self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml)) + + @classmethod + def setup_task(cls, args, **kwargs): + data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml)) + dict_path = op.join(args.data, data_cfg.vocab_filename) + if not op.isfile(dict_path): + raise FileNotFoundError(f'Dict not found: {dict_path}') + tgt_dict = Dictionary.load(dict_path) + logger.info(f'dictionary size ({data_cfg.vocab_filename}): ' + f'{len(tgt_dict):,}') + + if getattr(args, 'train_subset', None) is not None: + if not all(s.startswith('train') for s in args.train_subset.split(',')): + raise ValueError('Train splits should be named like "train*".') + return cls(args, tgt_dict) + + def build_criterion(self, args): + from fairseq import criterions + if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size 
!= 1: + raise ValueError('Please set "--ignore-prefix-size 1" since ' + 'target language ID token is prepended as BOS.') + return criterions.build_criterion(args, self) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + is_train_split = split.startswith('train') + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = SpeechToTextDatasetCreator.from_tsv( + self.args.data, self.data_cfg, split, self.tgt_dict, + pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split, + epoch=epoch, seed=self.args.seed + ) + + @property + def target_dictionary(self): + return self.tgt_dict + + @property + def source_dictionary(self): + return None + + def max_positions(self): + return self.args.max_source_positions, self.args.max_target_positions + + def build_model(self, args): + args.input_feat_per_channel = self.data_cfg.input_feat_per_channel + args.input_channels = self.data_cfg.input_channels + return super(SpeechToTextTask, self).build_model(args) + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, + ): + if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1: + raise ValueError('Please set "--prefix-size 1" since ' + 'target language ID token is prepended as BOS.') + lang_token_ids = { + i for s, i in self.tgt_dict.indices.items() + if SpeechToTextDataset.is_lang_tag(s) + } + extra_gen_cls_kwargs = {'symbols_to_strip_from_output': lang_token_ids} + return super().build_generator( + models, args, seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + + def build_tokenizer(self, args): + logger.info(f'pre-tokenizer: {self.data_cfg.pre_tokenizer}') + return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer)) + + def build_bpe(self, args): + logger.info(f'tokenizer: {self.data_cfg.bpe_tokenizer}') + return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer)) + + @classmethod + def 
build_dataset_for_inference(cls, audio_paths, n_frames): + return SpeechToTextDataset('interactive', False, {}, audio_paths, + n_frames) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 15b0552c3c..0064b88a95 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -21,7 +21,6 @@ from fairseq import checkpoint_utils, options, scoring, tasks, utils from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter -from fairseq.data import encoders def main(args): @@ -158,8 +157,8 @@ def _main(args, output_file): generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs) # Handle tokenization and BPE - tokenizer = encoders.build_tokenizer(args) - bpe = encoders.build_bpe(args) + tokenizer = task.build_tokenizer(args) + bpe = task.build_bpe(args) def decode_fn(x): if bpe is not None: diff --git a/setup.py b/setup.py index 215276925f..21e05d8da6 100644 --- a/setup.py +++ b/setup.py @@ -141,7 +141,7 @@ def include_dirs(self, dirs): 'hydra-core', 'numpy', 'regex', - 'sacrebleu', + 'sacrebleu>=1.4.12', 'torch', 'tqdm', ], diff --git a/tests/test_label_smoothing.py b/tests/test_label_smoothing.py index 8432d3c7bf..94e5ccf1f3 100644 --- a/tests/test_label_smoothing.py +++ b/tests/test_label_smoothing.py @@ -38,6 +38,7 @@ def setUp(self): # build model self.args = argparse.Namespace() self.args.sentence_avg = False + self.args.report_accuracy = False self.args.probs = torch.FloatTensor([ # pad eos unk w1 w2 w3 [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05], From 6d3712ed91cc2341f3b70090c34e5f9bf5f839ab Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Wed, 14 Oct 2020 12:27:45 -0700 Subject: [PATCH 211/707] Update scorer with evaluation-time tokenizer and chrF Summary: * Added evaluation-time tokenizer (using sacreBLEU's built-in tokenizers): `fairseq.scoring.tokenizer` * Added chrF scorer: `fairseq.scoring.chrf` * Updated sacreBLEU scorer with evaluation-time tokenizer * Updated 
WER scorer with evaluation-time tokenizer: There are cases where we train ASR models without pre-tokenization or punctuation removal. The tokenization/normalization is done at evaluation time before scoring. Reviewed By: myleott Differential Revision: D24219634 fbshipit-source-id: ecde21cb19206b96efff7606e101d476d5687888 --- fairseq/scoring/__init__.py | 29 ++++++++++++++++-- fairseq/scoring/bleu.py | 43 +++++++++++++++++--------- fairseq/scoring/chrf.py | 26 ++++++++++++++++ fairseq/scoring/tokenizer.py | 59 ++++++++++++++++++++++++++++++++++++ fairseq/scoring/wer.py | 47 ++++++++++++++++++++++------ 5 files changed, 176 insertions(+), 28 deletions(-) create mode 100644 fairseq/scoring/chrf.py create mode 100644 fairseq/scoring/tokenizer.py diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index c86d6b6c23..4468f2ad21 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -6,11 +6,35 @@ import importlib import os +from abc import ABC, abstractmethod from fairseq import registry -_build_scoring, register_scoring, SCORING_REGISTRY, _ = registry.setup_registry( +class BaseScorer(ABC): + def __init__(self, args): + self.args = args + self.ref = [] + self.pred = [] + + @staticmethod + def add_args(parser): + pass + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + @abstractmethod + def score(self) -> float: + pass + + @abstractmethod + def result_string(self) -> str: + pass + + +_build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( "--scoring", default="bleu" ) @@ -26,8 +50,7 @@ def build_scorer(args, tgt_dict): if args.scoring == "bleu": from fairseq.scoring import bleu return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) - else: - return _build_scoring(args) + return _build_scorer(args) # automatically import any Python files in the current directory diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index 15275d94c9..a45d44b003 100644 
--- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -9,7 +9,8 @@ import torch -from fairseq.scoring import register_scoring +from fairseq.scoring import BaseScorer, register_scorer +from fairseq.scoring.tokenizer import EvaluationTokenizer class BleuStat(ctypes.Structure): @@ -27,23 +28,33 @@ class BleuStat(ctypes.Structure): ] -@register_scoring("sacrebleu") -class SacrebleuScorer(object): - def __init__(self, *unused): +@register_scorer("sacrebleu") +class SacrebleuScorer(BaseScorer): + def __init__(self, args): + super(SacrebleuScorer, self).__init__(args) import sacrebleu - self.sacrebleu = sacrebleu - self.reset() + self.tokenizer = EvaluationTokenizer( + tokenizer_type=self.args.sacrebleu_tokenizer, + lowercase=self.args.sacrebleu_lowercase, + character_tokenization=self.args.sacrebleu_char_level + ) - def reset(self, one_init=False): - if one_init: - raise NotImplementedError - self.ref = [] - self.sys = [] + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--sacrebleu-tokenizer', type=str, default='13a', + choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, + help='tokenizer') + parser.add_argument('--sacrebleu-lowercase', type=str, default=False, + help='apply lowercasing') + parser.add_argument('--sacrebleu-char-level', action='store_true', + help='evaluate at character level') + # fmt: on def add_string(self, ref, pred): - self.ref.append(ref) - self.sys.append(pred) + self.ref.append(self.tokenizer.tokenize(ref)) + self.pred.append(self.tokenizer.tokenize(pred)) def score(self, order=4): return self.result_string(order).score @@ -51,10 +62,12 @@ def score(self, order=4): def result_string(self, order=4): if order != 4: raise NotImplementedError - return self.sacrebleu.corpus_bleu(self.sys, [self.ref]).format() + # tokenization and lowercasing are performed by self.tokenizer instead. 
+ return self.sacrebleu.corpus_bleu(self.pred, [self.ref], + tokenize='none').format() -@register_scoring("bleu") +@register_scorer("bleu") class Scorer(object): def __init__(self, pad, eos, unk): self.stat = BleuStat() diff --git a/fairseq/scoring/chrf.py b/fairseq/scoring/chrf.py new file mode 100644 index 0000000000..b932a43604 --- /dev/null +++ b/fairseq/scoring/chrf.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.scoring import BaseScorer, register_scorer + + +@register_scorer('chrf') +class ChrFScorer(BaseScorer): + def __init__(self, args): + super(ChrFScorer, self).__init__(args) + import sacrebleu + self.sacrebleu = sacrebleu + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + return self.result_string(order).score + + def result_string(self, order=4): + if order != 4: + raise NotImplementedError + return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format() diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py new file mode 100644 index 0000000000..c9d5218e1e --- /dev/null +++ b/fairseq/scoring/tokenizer.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unicodedata + + +class EvaluationTokenizer(object): + """A generic evaluation-time tokenizer, which leverages built-in tokenizers + in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides + lowercasing, punctuation removal and character tokenization, which are + applied after sacreBLEU tokenization. + + Args: + tokenizer_type (str): the type of sacreBLEU tokenizer to apply. + lowercase (bool): lowercase the text. 
+ punctuation_removal (bool): remove punctuation (based on unicode + category) from text. + character_tokenization (bool): tokenize the text to characters. + """ + SPACE = chr(32) + SPACE_ESCAPE = chr(9601) + ALL_TOKENIZER_TYPES = ['none', '13a', 'intl', 'zh', 'ja-mecab'] + + def __init__(self, tokenizer_type: str = '13a', lowercase: bool = False, + punctuation_removal: bool = False, + character_tokenization: bool = False): + from sacrebleu.tokenizers import TOKENIZERS + + assert tokenizer_type in self.ALL_TOKENIZER_TYPES + self.lowercase = lowercase + self.punctuation_removal = punctuation_removal + self.character_tokenization = character_tokenization + self.tokenizer = TOKENIZERS[tokenizer_type] + + @classmethod + def remove_punctuation(cls, sent: str): + """Remove punctuation based on Unicode category.""" + return cls.SPACE.join( + t for t in sent.split(cls.SPACE) + if not all(unicodedata.category(c)[0] == 'P' for c in t) + ) + + def tokenize(self, sent: str): + tokenized = self.tokenizer()(sent) + + if self.punctuation_removal: + tokenized = self.remove_punctuation(tokenized) + + if self.character_tokenization: + tokenized = self.SPACE.join( + list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)) + ) + + if self.lowercase: + tokenized = tokenized.lower() + + return tokenized diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 3aee5f69db..61c5fd950e 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -3,29 +3,56 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from fairseq.scoring import register_scoring +from fairseq.scoring import register_scorer, BaseScorer +from fairseq.scoring.tokenizer import EvaluationTokenizer -@register_scoring("wer") -class WerScorer(object): - def __init__(self, *unused): +@register_scorer("wer") +class WerScorer(BaseScorer): + def __init__(self, args): + super().__init__(args) self.reset() + try: + import editdistance as ed + except ImportError: + raise ImportError('Please install editdistance to use WER scorer') + self.ed = ed + self.tokenizer = EvaluationTokenizer( + tokenizer_type=self.args.wer_tokenizer, + lowercase=self.args.wer_lowercase, + punctuation_removal=self.args.wer_remove_punct, + character_tokenization=self.args.wer_char_level, + ) + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--wer-tokenizer', type=str, default='none', + choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, + help='sacreBLEU tokenizer to use for evaluation') + parser.add_argument('--wer-remove-punct', action='store_true', + help='remove punctuation') + parser.add_argument('--wer-char-level', action='store_true', + help='evaluate at character level') + parser.add_argument('--wer-lowercase', action='store_true', + help='lowercasing') + # fmt: on def reset(self): self.distance = 0 self.ref_length = 0 def add_string(self, ref, pred): - import editdistance - ref_items = ref.split() - pred_items = pred.split() - self.distance += editdistance.eval(ref_items, pred_items) + ref_items = self.tokenizer.tokenize(ref).split() + pred_items = self.tokenizer.tokenize(pred).split() + self.distance += self.ed.eval(ref_items, pred_items) self.ref_length += len(ref_items) def result_string(self): - return f"WER: {self.score()}" + return f"WER: {self.score():.2f}" def score(self): return ( - 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 + 100.0 * self.distance / self.ref_length if self.ref_length > 0 + else 0 ) From 3c118ad1bad66137c95e0a0d970e8b57b6848065 Mon Sep 17 00:00:00 
2001 From: Sergey Edunov Date: Wed, 14 Oct 2020 12:55:40 -0700 Subject: [PATCH 212/707] Adding tok.sh and installation scripts for the dependecies (#1339) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Adding tok.sh needed to evaluate performance of multilingual models. Aside tok.sh added installation script "install_dependecies.sh" that will install all the needed dependencies except Arabic. Arabic requires downloading separate installation packages and signing licensing agreements, so it can't be automated. # Before submitting - [X] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [X] Did you make sure to update the docs? - [X] Did you write any new necessary tests? ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1339 Reviewed By: shruti-bh Differential Revision: D24311526 Pulled By: edunov fbshipit-source-id: fe9d46b0c7d7dc090e03f504e048b0c6eb616df2 --- examples/m2m_100/README.md | 18 ++++ examples/m2m_100/install_dependecies.sh | 78 +++++++++++++++++ examples/m2m_100/tok.sh | 83 +++++++++++++++++++ examples/m2m_100/tokenizers/seg_ja.sh | 11 +++ examples/m2m_100/tokenizers/seg_ko.sh | 12 +++ .../m2m_100/tokenizers/thirdparty/.gitignore | 12 +++ examples/m2m_100/tokenizers/tokenize_indic.py | 21 +++++ examples/m2m_100/tokenizers/tokenize_thai.py | 12 +++ examples/m2m_100/tokenizers/tokenize_zh.py | 12 +++ examples/m2m_100/tokenizers/tokenizer_ar.sh | 27 ++++++ 10 files changed, 286 insertions(+) create mode 100644 examples/m2m_100/README.md create mode 100755 examples/m2m_100/install_dependecies.sh create mode 100755 examples/m2m_100/tok.sh create mode 100755 examples/m2m_100/tokenizers/seg_ja.sh create mode 100755 examples/m2m_100/tokenizers/seg_ko.sh create mode 100644 examples/m2m_100/tokenizers/thirdparty/.gitignore create mode 100644 examples/m2m_100/tokenizers/tokenize_indic.py create mode 100644 examples/m2m_100/tokenizers/tokenize_thai.py create mode 100644 examples/m2m_100/tokenizers/tokenize_zh.py create mode 100755 examples/m2m_100/tokenizers/tokenizer_ar.sh diff --git a/examples/m2m_100/README.md b/examples/m2m_100/README.md new file mode 100644 index 0000000000..d2892cb2d6 --- /dev/null +++ b/examples/m2m_100/README.md @@ -0,0 +1,18 @@ +# MMMT Tokenizer + +We apply different tokenization strategies for different languages following the existing literature. Here we provide tok.sh a tokenizer that can be used to reproduce our results. + +To reproduce the results, follow these steps: + +``` +tgt_lang=... +reference_translation=... 
+cat generation_output | grep -P "^H" |sort -V |cut -f 3- |sh tok.sh $tgt_lang > hyp +cat $reference_translation |sh tok.sh $tgt_lang > ref +sacrebleu -tok 'none' ref < hyp +``` + +# Installation + +Tools needed for all the languages except Arabic can be installed by running install_dependecies.sh. +If you want to evaluate Arabic models, please follow the instructions provided here: http://alt.qcri.org/tools/arabic-normalizer/ to install the required tools. diff --git a/examples/m2m_100/install_dependecies.sh b/examples/m2m_100/install_dependecies.sh new file mode 100755 index 0000000000..82a1054745 --- /dev/null +++ b/examples/m2m_100/install_dependecies.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +CWD=`pwd` +INSTALL_PATH=$CWD/tokenizers/thirdparty + +MOSES=$INSTALL_PATH/mosesdecoder +if [ ! -d $MOSES ]; then + echo 'Cloning Moses github repository (for tokenization scripts)...' + git clone https://github.com/moses-smt/mosesdecoder.git $MOSES + cd $MOSES + # To deal with differences in handling ' vs " + git checkout 03578921cc1a03402 + cd - +fi + +WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts +if [ ! -d $WMT16_SCRIPTS ]; then + echo 'Cloning Romanian tokenization scripts' + git clone https://github.com/rsennrich/wmt16-scripts.git $WMT16_SCRIPTS +fi + +KYTEA=$INSTALL_PATH/kytea +if [ ! -f $KYTEA/bin/kytea ]; then + git clone https://github.com/neubig/kytea.git $KYTEA + cd $KYTEA + autoreconf -i + ./configure --prefix=`pwd` + make + make install + cd .. 
+ curl -LO https://bitbucket.org/eunjeon/mecab-ko-dic/downloads/mecab-ko-dic-2.1.1-20180720.tar.gz + tar zxfv mecab-ko-dic-2.1.1-20180720.tar.gz + cd mecab-ko-dic-2.1.1-20180720/ + ./autogen.sh + ./configure --prefix=`pwd` --with-dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic --with-mecab-config=$MECAB/bin/mecab-config + make + sh -c 'echo "dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic" > $MECAB/etc/mecabrc' + make install + cd $CWD +fi + +INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources +if [ ! -d $INDIC_RESOURCES_PATH ]; then + echo 'Cloning indic_nlp_resources' + git clone https://github.com/anoopkunchukuttan/indic_nlp_resources.git $INDIC_RESOURCES_PATH +fi + + +if [ ! -f $INSTALL_PATH/seg_my.py ]; then + cd $INSTALL_PATH + wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip + unzip wat2020.my-en.zip + # switch to python3 + cat wat2020.my-en/myseg.py |sed 's/^sys.std/###sys.std/g' | sed 's/### sys/sys/g' | sed 's/unichr/chr/g' > seg_my.py + cd $CWD +fi + + +pip install pythainlp sacrebleu indic-nlp-library + diff --git a/examples/m2m_100/tok.sh b/examples/m2m_100/tok.sh new file mode 100755 index 0000000000..ba2ec5a2f3 --- /dev/null +++ b/examples/m2m_100/tok.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# Copyright (c) 2019-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +set -e + +TOKENIZERS_SCRIPTS=tokenizers +INSTALL_PATH=$TOKENIZERS_SCRIPTS/thirdparty + +N_THREADS=8 + +lg=$1 + +MOSES=$INSTALL_PATH/mosesdecoder +REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl +NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl +TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl + +# special tokenization for Romanian +WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts + +NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py +REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py + +# Burmese +MY_SEGMENT=$INSTALL_PATH/seg_my.py + +# Arabic +AR_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenizer_ar.sh + +# Korean +KO_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ko.sh + +# Japanese +JA_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ja.sh + +# Indic +IN_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_indic.py +INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources + +# Thai +THAI_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_thai.py + +# Chinese +CHINESE_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_zh.py + +# Chinese +if [ "$lg" = "zh" ]; then + cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $CHINESE_TOKENIZER +# Thai +elif [ "$lg" = "th" ]; then + cat - | python $THAI_TOKENIZER +# Japanese +elif [ "$lg" = "ja" ]; then + cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | ${JA_SEGMENT} +# Korean +elif [ "$lg" = "ko" ]; then + cat - | $REM_NON_PRINT_CHAR | ${KO_SEGMENT} +# Romanian +elif [ "$lg" = "ro" ]; then + cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -no-escape -threads $N_THREADS -l $lg +# Burmese +elif [ "$lg" = "my" ]; then + cat - | python ${MY_SEGMENT} +# Arabic +elif [ "$lg" = "ar" ]; then + cat - | ${AR_TOKENIZER} +# Indic +elif [ "$lg" = "ne" ]; then + cat - | python ${IN_TOKENIZER} $lg +elif [ "$lg" = "si" ]; then + cat - | 
python ${IN_TOKENIZER} $lg +elif [ "$lg" = "hi" ]; then + cat - | python ${IN_TOKENIZER} $lg +# other languages +else + cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg +fi diff --git a/examples/m2m_100/tokenizers/seg_ja.sh b/examples/m2m_100/tokenizers/seg_ja.sh new file mode 100755 index 0000000000..be6f5ca5fe --- /dev/null +++ b/examples/m2m_100/tokenizers/seg_ja.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +SCRIPT=`realpath $0` +KYTEA=`dirname $SCRIPT`/thirdparty/kytea +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$KYTEA/lib:/usr/local/lib +export PATH=$PATH:"$KYTEA/bin" + +cat - | tr -d "[:blank:]" | kytea -notags diff --git a/examples/m2m_100/tokenizers/seg_ko.sh b/examples/m2m_100/tokenizers/seg_ko.sh new file mode 100755 index 0000000000..c523d92634 --- /dev/null +++ b/examples/m2m_100/tokenizers/seg_ko.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+SCRIPT=`realpath $0` +MECAB=`dirname $SCRIPT`/thirdparty/mecab-0.996-ko-0.9.2 + +export PATH=$PATH:"$MECAB/bin":"$MECAB/lib" +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$MECAB/lib" + +cat - | mecab -O wakati diff --git a/examples/m2m_100/tokenizers/thirdparty/.gitignore b/examples/m2m_100/tokenizers/thirdparty/.gitignore new file mode 100644 index 0000000000..19eb6a9dd7 --- /dev/null +++ b/examples/m2m_100/tokenizers/thirdparty/.gitignore @@ -0,0 +1,12 @@ +seg_my.py +indic_nlp_library/ +indic_nlp_resources/ +kytea/ +mecab-0.996-ko-0.9.2.tar.gz +mecab-0.996-ko-0.9.2/ +mosesdecoder/ +wat2020.my-en.zip +wat2020.my-en/ +wmt16-scripts/ +mecab-ko-dic-2.1.1-20180720/ +mecab-ko-dic-2.1.1-20180720.tar.gz \ No newline at end of file diff --git a/examples/m2m_100/tokenizers/tokenize_indic.py b/examples/m2m_100/tokenizers/tokenize_indic.py new file mode 100644 index 0000000000..c1303b3d15 --- /dev/null +++ b/examples/m2m_100/tokenizers/tokenize_indic.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Use: echo {text} | python tokenize_indic.py {language} + +import sys + +from indicnlp.tokenize.indic_tokenize import trivial_tokenize +from indicnlp.normalize.indic_normalize import IndicNormalizerFactory + +factory=IndicNormalizerFactory() +normalizer=factory.get_normalizer(sys.argv[1],remove_nuktas=False,nasals_mode='do_nothing') + +for line in sys.stdin: + normalized_line=normalizer.normalize(line.strip()) + tokenized_line=' '.join(trivial_tokenize(normalized_line, sys.argv[1])) + print(tokenized_line) + diff --git a/examples/m2m_100/tokenizers/tokenize_thai.py b/examples/m2m_100/tokenizers/tokenize_thai.py new file mode 100644 index 0000000000..7c7b7ebfaa --- /dev/null +++ b/examples/m2m_100/tokenizers/tokenize_thai.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. 
and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +from pythainlp import word_tokenize + +for line in sys.stdin: + print(" ".join(word_tokenize(line.strip()))) diff --git a/examples/m2m_100/tokenizers/tokenize_zh.py b/examples/m2m_100/tokenizers/tokenize_zh.py new file mode 100644 index 0000000000..531a7fb49b --- /dev/null +++ b/examples/m2m_100/tokenizers/tokenize_zh.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import fileinput +import sacrebleu + +for line in fileinput.input(): + print(sacrebleu.tokenize_zh(line)) diff --git a/examples/m2m_100/tokenizers/tokenizer_ar.sh b/examples/m2m_100/tokenizers/tokenizer_ar.sh new file mode 100755 index 0000000000..ad35d7adf2 --- /dev/null +++ b/examples/m2m_100/tokenizers/tokenizer_ar.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env sh +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# Please follow the instructions here http://alt.qcri.org/tools/arabic-normalizer/ +# to install tools needed for Arabic + +echo "Please install Arabic tools: http://alt.qcri.org/tools/arabic-normalizer/" +echo "Then update environment variables in tokenizer_ar.sh" +exit 1 + +SVMTOOL=... +GOMOSESGO=... +QCRI_ARABIC_NORMALIZER=... 
+ +export PERL5LIB="$SVMTOOL/lib":"$GOMOSESGO/bin/MADA-3.2":$PERL5LIB + + +tempfile=$(mktemp) +cat - > $tempfile + +cd $QCRI_ARABIC_NORMALIZER + +bash qcri_normalizer_mada3.2_aramorph1.2.1.sh $tempfile +cat $tempfile.mada_norm-aramorph.europarl_tok From 3544f5f24eb52f3a7c5f2dba78462ca08d52c1f0 Mon Sep 17 00:00:00 2001 From: Sergey Edunov Date: Wed, 14 Oct 2020 14:18:39 -0700 Subject: [PATCH 213/707] Releasing single pre-finetuning models (#1347) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1347 Reviewed By: michaelauli, shruti-bh Differential Revision: D24315287 Pulled By: edunov fbshipit-source-id: d94955866b5424ab9c6a78982140e2bd7d1b279b --- examples/wmt19/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/wmt19/README.md b/examples/wmt19/README.md index 3c59851264..5c90d0e6c4 100644 --- a/examples/wmt19/README.md +++ b/examples/wmt19/README.md @@ -14,6 +14,15 @@ Model | Description | Download `transformer_lm.wmt19.de` | De Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz) `transformer_lm.wmt19.ru` | Ru Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz) +## Pre-trained single models before finetuning + +Model | Description | Download +---|---|--- +`transformer.wmt19.en-de` | En->De Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.ffn8192.tar.gz) +`transformer.wmt19.de-en` | De->En Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.ffn8192.tar.gz) +`transformer.wmt19.en-ru` | En->Ru Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ffn8192.tar.gz) +`transformer.wmt19.ru-en` | Ru->En Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ffn8192.tar.gz) + ## Example usage (torch.hub) #### Requirements From 573c2f4b60a50dc7c4ff17084b753c05452381f9 Mon Sep 17 00:00:00 2001 From: Xian Li Date: Thu, 15 Oct 2020 09:23:54 -0700 Subject: [PATCH 214/707] Opensource code for Deep Transformer with Latent Depth (#2703) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Opensource code for Deep Transformer with Latent Depth (https://arxiv.org/pdf/2009.13102.pdf). New features and design choices made: - New feature: allow non-residual block to be weighted by sample z (generated per batch) instead of `x = residual + x`. - Design choice: move `x = residual + x` in transformer_layer.py into a function where the subclass (with latent depth) could overwrite it to `x = residual + z*x`. - New feature: allow TransformerEncoder or TransformerDecoder to have additional logits parameters which will generate the samples z. - Design choice: added subclass LatentTransformerEncoder and LatentTransformerDecoder, which has additional attributes for the logits parameters, and instantiate the corresponding LatentTransformerEncoderLayer and LatentTransformerDecoderLayer. - New feature: allow multilingual_translation task to train with latent depth (results in the paper). - Design choice: - added additional arguments in the multilingual_translation task. - added option for multilingual_transformer to use LatentTransformerEncoder and LatentTransformerDecoder besides standard TransformerEncoder. - added option in multilingual_translation task's `train_step` to generate the samples z and compute the KL (and sparsity) loss per batch. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2703 Reviewed By: myleott Differential Revision: D24155059 Pulled By: xianxl fbshipit-source-id: f3e41639429f9664ec5565839709aa857a643668 --- README.md | 2 + examples/latent_depth/README.md | 77 +++++++++ examples/latent_depth/src/__init__.py | 9 + examples/latent_depth/src/loss/__init__.py | 0 .../latent_depth/src/loss/latent_depth.py | 86 ++++++++++ examples/latent_depth/src/models/__init__.py | 0 .../models/latent_multilingual_transformer.py | 60 +++++++ .../src/models/latent_transformer.py | 130 +++++++++++++++ examples/latent_depth/src/modules/__init__.py | 0 .../latent_depth/src/modules/latent_layers.py | 73 ++++++++ .../multilingual_translation_latent_depth.py | 156 ++++++++++++++++++ fairseq/models/multilingual_transformer.py | 11 +- fairseq/modules/transformer_layer.py | 16 +- fairseq/tasks/multilingual_translation.py | 18 +- tests/test_binaries.py | 46 ++++++ 15 files changed, 672 insertions(+), 12 deletions(-) create mode 100644 examples/latent_depth/README.md create mode 100644 examples/latent_depth/src/__init__.py create mode 100644 examples/latent_depth/src/loss/__init__.py create mode 100644 examples/latent_depth/src/loss/latent_depth.py create mode 100644 examples/latent_depth/src/models/__init__.py create mode 100644 examples/latent_depth/src/models/latent_multilingual_transformer.py create mode 100644 examples/latent_depth/src/models/latent_transformer.py create mode 100644 examples/latent_depth/src/modules/__init__.py create mode 100644 examples/latent_depth/src/modules/latent_layers.py create mode 100644 examples/latent_depth/src/multilingual_translation_latent_depth.py diff --git a/README.md b/README.md index 151d4f0507..997d8833ea 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ We provide reference implementations of various sequence modeling papers: - [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence 
Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) - [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) - [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + - [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) - **Non-autoregressive Transformers** - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -55,6 +56,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +- October 2020: [Deep Transformers with Latent Depth code released](examples/latent_depth/README.md) - October 2020: [Added CRISS models and code](examples/criss/README.md) - September 2020: [Added Linformer code](examples/linformer/README.md) - September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) diff --git a/examples/latent_depth/README.md b/examples/latent_depth/README.md new file mode 100644 index 0000000000..3faf21bf89 --- /dev/null +++ b/examples/latent_depth/README.md @@ -0,0 +1,77 @@ +# Deep Transformers with Latent Depth (Li et al., 2020) + +[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102). + +## Introduction + +We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair. + +## Training a multilingual model with latent depth + +Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. 
We use the same preprocessed (numericalized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) the authors provide. +```bash +lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur" +databin_dir= + +python fairseq_cli/train.py ${databin_dir} \ + --user-dir examples/latent_depth/src \ + --lang-pairs "${lang_pairs_str}" \ + --arch multilingual_transformer_iwslt_de_en \ + --task multilingual_translation_latent_depth \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --share-encoders \ + --share-decoders \ + --decoder-langtok \ + --share-decoder-input-output-embed \ + --dropout 0.3 --attention-dropout 0.3 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \ + --max-tokens 4096 --update-freq 1 \ + --lr 0.0015 \ + --clip-norm 1.0 \ + --seed 2 \ + --ddp-backend=no_c10d \ + --encoder-layers 12 \ + --decoder-layers 24 \ + --decoder-latent-layer \ + --sparsity-weight 0.1 \ + --anneal-updates 5000 \ + --soft-update 500 \ + --target-layers 12 \ + --share-weight 0.1 +``` +## Inference command + +```bash +lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur" +databin_dir= +model_path= +src_lang= +tgt_lang= +gen_data= + +python fairseq_cli/generate.py ${databin_dir} \ + --path ${model_path} \ + --task multilingual_translation_latent_depth \ + --decoder-latent-layer \ + --lang-pairs "${lang_pairs_str}" \ + -s ${src_lang} -t ${tgt_lang} \ + --gen-subset $gen_data \ + --scoring sacrebleu \ + --remove-bpe 'sentencepiece' \ + --lenpen 1.0 \ + --beam 5 \ + --decoder-langtok \ + --max-tokens 4096 +``` + + +## Citation +```bibtex +@article{li2020deep, + title={Deep 
Transformers with Latent Depth}, + author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang}, + journal={arXiv preprint arXiv:2009.13102}, + year={2020} +} +``` diff --git a/examples/latent_depth/src/__init__.py b/examples/latent_depth/src/__init__.py new file mode 100644 index 0000000000..8a86fa5817 --- /dev/null +++ b/examples/latent_depth/src/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .models import latent_multilingual_transformer # noqa +from .modules import latent_layers # noqa +from .loss import latent_depth # noqa +from . import multilingual_translation_latent_depth # noqa diff --git a/examples/latent_depth/src/loss/__init__.py b/examples/latent_depth/src/loss/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/latent_depth/src/loss/latent_depth.py b/examples/latent_depth/src/loss/latent_depth.py new file mode 100644 index 0000000000..f647c758ee --- /dev/null +++ b/examples/latent_depth/src/loss/latent_depth.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import math +from torch.nn.modules.loss import _Loss + + +class LatentLayersKLLoss(_Loss): + def __init__(self, args): + super().__init__() + self.args = args + + def forward(self, layer_samples, lang_idx, update_num, sample_size): + prior = self.args.prior + samples = layer_samples[lang_idx] + eps = 1e-7 + if prior == "uniform": + # uniform prior + kl_loss = (samples * ( + torch.log(samples + eps) - math.log(0.5) + )).sum(-1) + elif prior == "agged_posterior": + # aggregated posterior + y_t = torch.stack([x.detach() for x in layer_samples], dim=0) + agged_q = torch.sum(y_t, dim=0) + row_norm = agged_q.sum(-1) + normed_agg_q = agged_q / row_norm + kl_loss = (samples * ( + torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1) + else: + raise NotImplementedError("The specified prior is not implemented.") + + # normalized by number of layers + kl_loss /= layer_samples[0].size()[0] + kl_weight = min( + self.args.sparsity_weight, + (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates + ) + kl_loss *= kl_weight * sample_size + return kl_loss + + +class LatentLayersSparsityLoss(_Loss): + def __init__(self, args): + super().__init__() + self.args = args + + def is_valid(self, update_num): + if self.args.target_layers <= 0: + return False + return update_num > (self.args.soft_update + self.args.anneal_updates) + + def forward(self, layer_samples_list, update_num, sample_size): + batch_loss = 0 + share_loss = 0 + global_sparsity_loss = 0 + layer_samples = torch.stack(layer_samples_list, dim=0) + if ((self.args.target_layers > 0 or self.args.share_weight > 0) and + update_num > (self.args.soft_update + self.args.anneal_updates)): + # anneal sparsity weight + if update_num < (self.args.anneal_updates + self.args.soft_update): + weight_anneal = 0 + elif update_num < (2 * self.args.anneal_updates + self.args.soft_update): + weight_anneal = ( + (update_num - self.args.soft_update - self.args.anneal_updates) + * 
self.args.share_weight / self.args.anneal_updates + ) + else: + weight_anneal = 1 + # compute ratio among languages + layer_utilization = torch.sum(layer_samples, dim=0) + layer_utilization /= layer_samples.size()[0] + if self.args.share_weight > 0: + # encouraging sharing across languages + share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0) + batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss + if self.args.target_layers > 0: + # computed expected number of layers selected + expeted_layers = sum(layer_utilization) + # compute l2 loss wrt target number of layers + global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2 + batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss + return batch_loss diff --git a/examples/latent_depth/src/models/__init__.py b/examples/latent_depth/src/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/latent_depth/src/models/latent_multilingual_transformer.py b/examples/latent_depth/src/models/latent_multilingual_transformer.py new file mode 100644 index 0000000000..97573cbd75 --- /dev/null +++ b/examples/latent_depth/src/models/latent_multilingual_transformer.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from fairseq.models import ( + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + base_architecture, + TransformerEncoder, + TransformerDecoder, +) +from fairseq.models.multilingual_transformer import MultilingualTransformerModel + +from .latent_transformer import ( + LatentTransformerEncoder, + LatentTransformerDecoder, +) + + +@register_model('latent_multilingual_transformer') +class LatentMultilingualTransformerModel(MultilingualTransformerModel): + """A variant of the standard multilingual Transformer model whose encoder and/or + decoder supports latent depth, as in "Deep Transformers with Latent Depth" + (https://arxiv.org/abs/2009.13102). + """ + @classmethod + def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): + if is_encoder: + if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer: + return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs)) + else: + return TransformerEncoder(args, lang_dict, embed_tokens) + else: + if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer: + return LatentTransformerDecoder( + args, lang_dict, embed_tokens, num_logits=len(langs) + ) + else: + return TransformerDecoder(args, lang_dict, embed_tokens) + + +@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer') +def latent_multilingual_architecture(args): + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4) + args.encoder_layers = getattr(args, 'encoder_layers', 12) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4) + args.decoder_layers = getattr(args, 'decoder_layers', 24) + 
args.share_encoders = getattr(args, 'share_encoders', True) + args.share_decoders = getattr(args, 'share_decoders', True) + args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True) + args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True) + + base_architecture(args) diff --git a/examples/latent_depth/src/models/latent_transformer.py b/examples/latent_depth/src/models/latent_transformer.py new file mode 100644 index 0000000000..5d47340f58 --- /dev/null +++ b/examples/latent_depth/src/models/latent_transformer.py @@ -0,0 +1,130 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional + +import torch.nn as nn +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.transformer import TransformerEncoder, TransformerDecoder +from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer +from ..modules.latent_layers import LayerSelect +from torch import Tensor + + +class LatentTransformerEncoder(TransformerEncoder): + """Latent depth (https://arxiv.org/abs/2009.13102) implemented in + TransformerEncoder. 
+ """ + def __init__(self, args, dictionary, embed_tokens, num_logits=1): + self.num_logits = num_logits + self.num_layers = args.encoder_layers + super().__init__(args, dictionary, embed_tokens) + self.layer_select = LayerSelect(self.num_layers, self.num_logits, args) + self.lang_idx = None + self.layers = nn.ModuleList([ + self._build_encoder_layer(args, idx) + for idx in range(args.encoder_layers) + ]) + + def set_lang_idx(self, lang_idx): + self.lang_idx = lang_idx + + def _build_encoder_layer(self, args, idx=None): + return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select) + + def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): + self.layer_select.sample(self.lang_idx) + return super().forward(src_tokens, src_lengths, return_all_hiddens) + + +class LatentTransformerEncoderLayer(TransformerEncoderLayer): + """Encoder layer with each (non_residual) block weighted by Bernoulli + or Gumbel-Sigmoid samples. + + Args: + args (argparse.Namespace): parsed command-line arguments from standard + TransformerEncoderLayer. + idx (int): layer index (used to retrieve samples). + layer_select (LayerSelect, optional): instance of LayerSelect module with logits + parameters and sampling method. + """ + def __init__(self, args, idx, layer_select=None): + super().__init__(args) + self.idx = idx + self.layer_select = layer_select + + def residual_connection(self, x, residual): + return residual + x * self.layer_select(self.idx) + + +class LatentTransformerDecoder(TransformerDecoder): + """Latent depth (https://arxiv.org/abs/2009.13102) implemented in + TransformerDecoder. 
+ """ + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1): + self.num_logits = num_logits + self.num_layers = args.decoder_layers + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + self.layer_select = LayerSelect(self.num_layers, self.num_logits, args) + self.lang_idx = None + self.layers = nn.ModuleList([ + self._build_decoder_layer(args, no_encoder_attn, idx) + for idx in range(args.decoder_layers) + ]) + + def set_lang_idx(self, lang_idx): + self.lang_idx = lang_idx + + def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None): + return LatentTransformerDecoderLayer(args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + self.layer_select.sample(self.lang_idx) + return super().forward( + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + features_only=features_only, + alignment_layer=alignment_layer, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + ) + + +class LatentTransformerDecoderLayer(TransformerDecoderLayer): + """Decoder layer with each (non_residual) block weighted by Bernoulli + or Gumbel-Sigmoid samples. + + Args: + args (argparse.Namespace): parsed command-line arguments from standard + TransformerDecoderLayer. + idx (int): layer index (used to retrieve samples). + layer_select (LayerSelect, optional): instance of LayerSelect module with logits + parameters and sampling method. + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). 
+ + """ + def __init__( + self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn) + self.idx = idx + self.layer_select = layer_select + + def residual_connection(self, x, residual): + return residual + x * self.layer_select(self.idx) diff --git a/examples/latent_depth/src/modules/__init__.py b/examples/latent_depth/src/modules/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/latent_depth/src/modules/latent_layers.py b/examples/latent_depth/src/modules/latent_layers.py new file mode 100644 index 0000000000..e772ac3237 --- /dev/null +++ b/examples/latent_depth/src/modules/latent_layers.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +class LayerSelect(nn.Module): + """Compute samples (from a Gumbel-Sigmoid distribution) which are used as + either (soft) weighting or (hard) selection of residual connections. 
+ https://arxiv.org/abs/2009.13102 + """ + def __init__(self, num_layers, num_logits, args): + super(LayerSelect, self).__init__() + self.args = args + self.layer_logits = torch.nn.Parameter( + torch.Tensor(num_logits, num_layers), + requires_grad=True, + ) + self.hard_select = not (hasattr(args, "soft_select") and args.soft_select) + self.tau = getattr(args, "sampling_tau", 5) + self.detach_grad = False + self.layer_samples = [None] * num_logits + + @staticmethod + def add_args(parser): + parser.add_argument( + '--soft-select', + action='store_true', + help='use soft samples in training and inference' + ) + parser.add_argument('--sampling-tau', type=float, help='sampling temperature') + + def sample(self, logit_idx): + """ To leverage the efficiency of distributed training, samples for all + layers are computed at once for each logit_idx. Logits are parameters + learnt independently of each other. + + Args: + logit_idx: The index of logit parameters used for sampling. + """ + assert logit_idx is not None + self.samples = self._gumbel_sigmoid( + self.layer_logits[logit_idx, :].detach() if self.detach_grad else self.layer_logits[logit_idx, :], + dim=-1, + tau=self.tau, + hard=self.hard_select, + ) + self.layer_samples[logit_idx] = self.samples + + def forward(self, i): + sample = self.samples[i] + return sample + + def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5): + # ~Gumbel(0,1) + gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + # Difference of two gumbels because we apply a sigmoid + gumbels1 = (logits + gumbels1 - gumbels2) / tau + y_soft = gumbels1.sigmoid() + if hard: + # Straight through. 
+ y_hard = torch.zeros_like( + logits, memory_format=torch.legacy_contiguous_format + ).masked_fill(y_soft > threshold, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + # Reparametrization trick. + ret = y_soft + return ret diff --git a/examples/latent_depth/src/multilingual_translation_latent_depth.py b/examples/latent_depth/src/multilingual_translation_latent_depth.py new file mode 100644 index 0000000000..1a19f8f8f9 --- /dev/null +++ b/examples/latent_depth/src/multilingual_translation_latent_depth.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.tasks import register_task +from fairseq.tasks.multilingual_translation import MultilingualTranslationTask +from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss + + +@register_task('multilingual_translation_latent_depth') +class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask): + """A task for multiple translation with latent depth. + + See `"Deep Transformer with Latent Depth" + (Li et al., 2020) `_. 
+ """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + MultilingualTranslationTask.add_args(parser) + parser.add_argument('--encoder-latent-layer', action='store_true', help='latent layer selection in encoder') + parser.add_argument('--decoder-latent-layer', action='store_true', help='latent layer selection in decoder') + parser.add_argument('--target-layers', default=-1, type=int, + help='number of effective layers to learn; -1 means no constraint') + parser.add_argument('--sparsity-weight', default=0.0, type=float, + help='weight for sparsity loss') + parser.add_argument('--share-weight', default=0.0, type=float, + help='weight for sharing loss') + parser.add_argument('--soft-update', default=1, type=int, + help='number of updates with soft sampling') + parser.add_argument('--anneal-updates', default=1, type=int, + help='number of updates to anneal the KL loss weight') + parser.add_argument('--prior', default="uniform", type=str, + help='prior used for computing KL loss') + # fmt: on + + def __init__(self, args, dicts, training): + super().__init__(args, dicts, training) + self.src_langs, self.tgt_langs = zip(*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]) + if self.training and self.encoder_latent_layer: + assert self.args.share_encoders + if self.training and self.decoder_latent_layer: + assert self.args.share_decoders + if training or self.encoder_latent_layer or self.decoder_latent_layer: + self.lang_pairs = args.lang_pairs + else: + self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)] + self.eval_lang_pairs = self.lang_pairs + self.model_lang_pairs = self.lang_pairs + if self.training and (self.encoder_latent_layer or self.decoder_latent_layer): + self.kl_loss = LatentLayersKLLoss(self.args) + self.sparsity_loss = LatentLayersSparsityLoss(self.args) + + def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, 
ignore_grad): + src, tgt = lang_pair.split("-") + if self.encoder_latent_layer: + src_lang_idx = self.src_lang_idx_dict[src] + model.models[lang_pair].encoder.set_lang_idx(src_lang_idx) + model.models[lang_pair].encoder.layer_select.hard_select = update_num > self.args.soft_update + if self.decoder_latent_layer: + tgt_lang_idx = self.tgt_lang_idx_dict[tgt] + model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx) + model.models[lang_pair].decoder.layer_select.hard_select = update_num > self.args.soft_update + + loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + if self.encoder_latent_layer: + none_samples = sum( + 1 if x is None else 0 for x in model.models[lang_pair].encoder.layer_select.layer_samples + ) + if none_samples == 0 or self.args.prior != "agged_posterior": + loss += self.kl_loss( + model.models[lang_pair].encoder.layer_select.layer_samples, + src_lang_idx, + update_num, + sample_size + ) + if self.decoder_latent_layer: + none_samples = sum( + 1 if x is None else 0 for x in model.models[lang_pair].decoder.layer_select.layer_samples + ) + if none_samples == 0 or self.args.prior != "agged_posterior": + loss += self.kl_loss( + model.models[lang_pair].decoder.layer_select.layer_samples, + tgt_lang_idx, + update_num, + sample_size + ) + if ignore_grad: + loss *= 0 + + if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num): + # need to retain the graph if sparsity loss needs to be added + loss.backward(retain_graph=True) + else: + optimizer.backward(loss) + + return loss, sample_size, logging_output + + def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): + agg_loss, agg_sample_size, agg_logging_output = super().train_step( + sample, model, criterion, optimizer, update_num, ignore_grad) + # compute auxiliary loss from layer sparsity, based on all samples from all languages + if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num): + 
sparsity_loss = 0 + if self.encoder_latent_layer: + sparsity_loss += self.sparsity_loss( + next(iter(model.models.values())).encoder.layer_select.layer_samples, update_num, agg_sample_size) + if self.decoder_latent_layer: + sparsity_loss += self.sparsity_loss( + next(iter(model.models.values())).decoder.layer_select.layer_samples, update_num, agg_sample_size) + if sparsity_loss > 0: + optimizer.backward(sparsity_loss) + return agg_loss, agg_sample_size, agg_logging_output + + def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample): + src, tgt = lang_pair.split("-") + if self.encoder_latent_layer: + src_lang_idx = self.src_lang_idx_dict[src] + model.models[lang_pair].encoder.set_lang_idx(src_lang_idx) + if self.decoder_latent_layer: + tgt_lang_idx = self.tgt_lang_idx_dict[tgt] + model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx) + loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + return loss, sample_size, logging_output + + def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): + if self.encoder_latent_layer or self.decoder_latent_layer: + for model in models: + if self.encoder_latent_layer: + assert model.encoder.layer_select is not None + src_lang_idx = self.src_lang_idx_dict[self.args.source_lang] + model.encoder.set_lang_idx(src_lang_idx) + if self.decoder_latent_layer: + assert model.decoder.layer_select is not None + tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang] + model.decoder.set_lang_idx(tgt_lang_idx) + return super().inference_step(generator, models, sample, prefix_tokens, constraints) + + @property + def encoder_latent_layer(self): + return hasattr(self.args, "encoder_latent_layer") and self.args.encoder_latent_layer + + @property + def decoder_latent_layer(self): + return hasattr(self.args, "decoder_latent_layer") and self.args.decoder_latent_layer + + @property + def src_lang_idx_dict(self): + return {lang: lang_idx for lang_idx, lang in 
enumerate(self.src_langs)} + + @property + def tgt_lang_idx_dict(self): + return {lang: lang_idx for lang_idx, lang in enumerate(self.tgt_langs)} diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py index 2f6a837805..91a413753c 100644 --- a/fairseq/models/multilingual_transformer.py +++ b/fairseq/models/multilingual_transformer.py @@ -136,7 +136,8 @@ def get_encoder(lang): encoder_embed_tokens = build_embedding( task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path ) - lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens) + lang_encoders[lang] = cls._get_module_class( + True, args, task.dicts[lang], encoder_embed_tokens, src_langs) return lang_encoders[lang] def get_decoder(lang): @@ -147,7 +148,8 @@ def get_decoder(lang): decoder_embed_tokens = build_embedding( task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path ) - lang_decoders[lang] = TransformerDecoder(args, task.dicts[lang], decoder_embed_tokens) + lang_decoders[lang] = cls._get_module_class( + False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs) return lang_decoders[lang] # shared encoders/decoders (if applicable) @@ -164,6 +166,11 @@ def get_decoder(lang): return MultilingualTransformerModel(encoders, decoders) + @classmethod + def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): + module_class = TransformerEncoder if is_encoder else TransformerDecoder + return module_class(args, lang_dict, embed_tokens) + def load_state_dict(self, state_dict, strict=True, args=None): state_dict_subset = state_dict.copy() for k, _ in state_dict.items(): diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index a803f581a5..9965f2f26c 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -72,6 +72,9 @@ def build_self_attention(self, embed_dim, args): qn_block_size=self.quant_noise_block_size, ) + def 
residual_connection(self, x, residual): + return residual + x + def upgrade_state_dict_named(self, state_dict, name): """ Rename layer norm states from `...layer_norms.0.weight` to @@ -121,7 +124,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): attn_mask=attn_mask, ) x = self.dropout_module(x) - x = residual + x + x = self.residual_connection(x, residual) if not self.normalize_before: x = self.self_attn_layer_norm(x) @@ -133,7 +136,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): x = self.activation_dropout_module(x) x = self.fc2(x) x = self.dropout_module(x) - x = residual + x + x = self.residual_connection(x, residual) if not self.normalize_before: x = self.final_layer_norm(x) return x @@ -243,6 +246,9 @@ def build_encoder_attention(self, embed_dim, args): def prepare_for_onnx_export_(self): self.onnx_trace = True + def residual_connection(self, x, residual): + return residual + x + def forward( self, x, @@ -320,7 +326,7 @@ def forward( attn_mask=self_attn_mask, ) x = self.dropout_module(x) - x = residual + x + x = self.residual_connection(x, residual) if not self.normalize_before: x = self.self_attn_layer_norm(x) @@ -350,7 +356,7 @@ def forward( need_head_weights=need_head_weights, ) x = self.dropout_module(x) - x = residual + x + x = self.residual_connection(x, residual) if not self.normalize_before: x = self.encoder_attn_layer_norm(x) @@ -362,7 +368,7 @@ def forward( x = self.activation_dropout_module(x) x = self.fc2(x) x = self.dropout_module(x) - x = residual + x + x = self.residual_connection(x, residual) if not self.normalize_before: x = self.final_layer_norm(x) if self.onnx_trace and incremental_state is not None: diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 7c7e18ec87..161eb436ec 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -264,6 +264,13 @@ def check_args(): raise 
ValueError('MultilingualTranslationTask requires a FairseqMultiModel architecture') return model + def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad): + loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + if ignore_grad: + loss *= 0 + optimizer.backward(loss) + return loss, sample_size, logging_output + def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): model.train() from collections import defaultdict @@ -285,10 +292,8 @@ def maybe_no_sync(): else: return contextlib.ExitStack() # dummy contextmanager with maybe_no_sync(): - loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) - if ignore_grad: - loss *= 0 - optimizer.backward(loss) + loss, sample_size, logging_output = self._per_lang_pair_train_loss( + lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad) agg_loss += loss.detach().item() # TODO make summing of the sample sizes configurable agg_sample_size += sample_size @@ -297,6 +302,9 @@ def maybe_no_sync(): agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k] return agg_loss, agg_sample_size, agg_logging_output + def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample): + return criterion(model.models[lang_pair], sample[lang_pair]) + def valid_step(self, sample, model, criterion): model.eval() with torch.no_grad(): @@ -305,7 +313,7 @@ def valid_step(self, sample, model, criterion): for lang_pair in self.eval_lang_pairs: if lang_pair not in sample or sample[lang_pair] is None or len(sample[lang_pair]) == 0: continue - loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + loss, sample_size, logging_output = self._per_lang_pair_valid_loss(lang_pair, model, criterion, sample) agg_loss += loss.data.item() # TODO make summing of the sample sizes configurable agg_sample_size += sample_size diff --git 
a/tests/test_binaries.py b/tests/test_binaries.py index c0c4abb6ed..53554e8d28 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -207,6 +207,52 @@ def test_multilingual_transformer(self): ] + enc_ltok_flag + dec_ltok_flag, ) + def test_multilingual_translation_latent_depth(self): + # test with latent depth in encoder, decoder, or both + encoder_latent_layer = [[], ['--encoder-latent-layer']] + decoder_latent_layer = [[], ['--decoder-latent-layer']] + with contextlib.redirect_stdout(StringIO()): + for i in range(len(encoder_latent_layer)): + for j in range(len(decoder_latent_layer)): + if i == 0 and j == 0: + continue + enc_ll_flag = encoder_latent_layer[i] + dec_ll_flag = decoder_latent_layer[j] + with tempfile.TemporaryDirectory(f'test_multilingual_translation_latent_depth_{i}_{j}') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data( + data_dir, + extra_flags=['--joined-dictionary'] + ) + train_translation_model( + data_dir, + arch='latent_multilingual_transformer', + task='multilingual_translation_latent_depth', + extra_flags=[ + '--user-dir', 'examples/latent_depth/src', + '--encoder-layers', '2', + '--decoder-layers', '2', + '--encoder-embed-dim', '8', + '--decoder-embed-dim', '8', + '--share-encoders', + '--share-decoders', + '--sparsity-weight', '0.1', + ] + enc_ll_flag + dec_ll_flag, + lang_flags=['--lang-pairs', 'in-out,out-in'], + run_validation=True, + extra_valid_flags=['--user-dir', 'examples/latent_depth/src'] + enc_ll_flag + dec_ll_flag, + ) + generate_main( + data_dir, + extra_flags=[ + '--user-dir', 'examples/latent_depth/src', + '--task', 'multilingual_translation_latent_depth', + '--lang-pairs', 'in-out,out-in', + '--source-lang', 'in', + '--target-lang', 'out', + ] + enc_ll_flag + dec_ll_flag, + ) + def test_translation_multi_simple_epoch(self): # test with all combinations of encoder/decoder lang tokens encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']] From 
05a5232d04a6e5eccf0e1392b17b4908e5035d44 Mon Sep 17 00:00:00 2001 From: Xian Li Date: Thu, 15 Oct 2020 14:05:28 -0700 Subject: [PATCH 215/707] fix README.md (#2735) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2735 Reviewed By: myleott Differential Revision: D24343492 Pulled By: xianxl fbshipit-source-id: c61c717756307036f9d89de5a8ded66784f1acf7 --- examples/latent_depth/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/latent_depth/README.md b/examples/latent_depth/README.md index 3faf21bf89..bc78ca8055 100644 --- a/examples/latent_depth/README.md +++ b/examples/latent_depth/README.md @@ -1,6 +1,6 @@ # Deep Transformers with Latent Depth (Li et al., 2020) -[https://arxiv.org/abs/2009.13102] (https://arxiv.org/abs/2009.13102). +[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102). ## Introduction @@ -8,12 +8,12 @@ We present a probabilistic framework to automatically learn which layer(s) to us ## Training a multilingual model with latent depth -Below is an example of training with latent depth in decoder for one-to-many (O2M) related languages. 
We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)] (https://github.com/cindyxinyiwang/multiDDS), which could be generated by [the script] (https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) the author provided. +Below is an example of training with latent depth in decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which could be generated by [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) the author provided. ```bash lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur" databin_dir= -python fairseq_cli/train.py ${databin_dir} \ +fairseq-train ${databin_dir} \ --user-dir, examples/latent_depth/src \ --lang-pairs "${lang_pairs_str}" \ --arch multilingual_transformer_iwslt_de_en \ @@ -50,7 +50,7 @@ src_lang= tgt_lang= gen_data= -python fairseq_cli/generate.py ${databin_dir} \ +fairseq-generate ${databin_dir} \ --path ${model_path} \ --task multilingual_translation_latent_depth \ --decoder-latent-layer \ From 698820b2bb3e128c3d43c834525d1210769974ad Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Thu, 15 Oct 2020 16:15:25 -0700 Subject: [PATCH 216/707] Enable shard epoch id recovery from checkpoints Summary: shard_epoch is not recovered when epoch id is loaded from checkpoints. This diff fixed it. 
Reviewed By: chtran Differential Revision: D24323687 fbshipit-source-id: a3ee84e8eef7ea75b62c6b0c3870d0cc80ad8f78 --- fairseq/data/multilingual/multilingual_data_manager.py | 10 ++++++++++ fairseq/tasks/translation_multi_simple_epoch.py | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 806e4c360d..7ce269a4df 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -8,6 +8,7 @@ import logging import os from collections import OrderedDict, defaultdict +import math from fairseq import utils from fairseq.data import ( @@ -287,6 +288,15 @@ def _shared_collater(self): not self.args.lang_tok_replacing_bos_eos ) + def estimate_global_pass_epoch(self, epoch): + if self.args.virtual_epoch_size is None or self.args.virtual_data_size is None: + return None + # one epoch more for remaining data in each shard + virtual_epochs_per_shard = math.ceil(self.args.virtual_data_size / self.args.virtual_epoch_size) + # note that fairseq epoch / shard_epoch starts from 1 + shard_epoch = (epoch - 1) // virtual_epochs_per_shard + 1 + return shard_epoch + @classmethod def prepare(cls, load_dictionary, args, **kargs): args.left_pad_source = utils.eval_bool(args.left_pad_source) diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index d9c0fa985b..960b82e1e8 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -128,7 +128,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): # also this avoid always loading from beginning of the data return else: - shard_epoch = None + # estimate the shard epoch from virtual data size and virtual epoch size + shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch) logger.info(f'loading data for {split} 
epoch={epoch}/{shard_epoch}') logger.info(f"mem usage: {data_utils.get_mem_usage()}") if split in self.datasets: From f2fa07106c4cb8faa70615a63bb31a141c1e3828 Mon Sep 17 00:00:00 2001 From: Armen Aghajanyan Date: Fri, 16 Oct 2020 14:30:35 -0700 Subject: [PATCH 217/707] RXF OS Implementation (#2455) Summary: ## What does this PR do? Implements R3F and R4F coming from Facebook Research: https://arxiv.org/abs/2008.03156 This code was used to generate all the results from the paper excluding probing results. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2455 Reviewed By: myleott Differential Revision: D23444863 Pulled By: AkshatSh fbshipit-source-id: b724a6d6cc9cebfdb4bd219828afbb5679f2259b --- README.md | 3 + examples/rxf/README.md | 52 ++++++ examples/rxf/__init__.py | 6 + examples/rxf/src/__init__.py | 6 + .../src/label_smoothed_cross_entropy_r3f.py | 157 ++++++++++++++++ examples/rxf/src/sentence_prediction_r3f.py | 170 ++++++++++++++++++ fairseq/models/bart/model.py | 24 ++- fairseq/models/roberta/model.py | 39 +++- fairseq/models/transformer.py | 20 ++- .../modules/transformer_sentence_encoder.py | 6 +- tests/test_binaries.py | 20 +++ 11 files changed, 483 insertions(+), 20 deletions(-) create mode 100644 examples/rxf/README.md create mode 100644 examples/rxf/__init__.py create mode 100644 examples/rxf/src/__init__.py create mode 100644 examples/rxf/src/label_smoothed_cross_entropy_r3f.py create mode 100644 examples/rxf/src/sentence_prediction_r3f.py diff --git a/README.md b/README.md index 997d8833ea..56ec16cdab 100644 --- a/README.md +++ b/README.md @@ -51,11 +51,14 @@ We provide reference implementations of various sequence modeling papers: - Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 
2019) - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +- **Finetuning** + - [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)

### What's New: +- October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) - October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) - October 2020: [Added CRISS models and code](examples/criss/README.md) - September 2020: [Added Linformer code](examples/linformer/README.md) diff --git a/examples/rxf/README.md b/examples/rxf/README.md new file mode 100644 index 0000000000..a09de63d33 --- /dev/null +++ b/examples/rxf/README.md @@ -0,0 +1,52 @@ +[Better Fine-Tuning by Reducing Representational Collapse](https://arxiv.org/abs/2008.03156) +===================== +This repo contains the code to replicate all experiments from the _Better Fine-Tuning by Reducing Representational Collapse_ paper excluding the probing results. + +The R3F sentence prediction criterion is registered as `sentence_prediction_r3f` while the label smoothing version of it is implemented as `label_smoothed_cross_entropy_r3f`. The R4F version of the sentence prediction criterion can be achieved by applying spectral norm to the classification head via the `--spectral-norm-classification-head` parameter. + +## Hyper-parameters +Our methods introduce 3 new hyper-parameters; `--eps` which sets the standard deviation or range of the distribution we're sampling from, `--r3f-lambda` which controls the combining of logistic loss and noisy KL loss and `--noise-type` which controls which parametric distribution we use ('normal', 'uniform'). + +For example to run R3F on RTE from GLUE + +``` +TOTAL_NUM_UPDATES=3120 +WARMUP_UPDATES=187 +LR=1e-05 +NUM_CLASSES=2 +MAX_SENTENCES=8 # Batch size. 
+ROBERTA_PATH=/path/to/roberta/model.pt + +CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin \ + --restore-file $ROBERTA_PATH \ + --max-positions 512 \ + --max-sentences $MAX_SENTENCES \ + --max-tokens 4400 \ + --task sentence_prediction \ + --reset-optimizer --reset-dataloader --reset-meters \ + --required-batch-size-multiple 1 \ + --init-token 0 --separator-token 2 \ + --arch roberta_large \ + --criterion sentence_prediction_r3f \ + --num-classes $NUM_CLASSES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ + --clip-norm 0.0 \ + --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ + --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ + --max-epoch 10 \ + --find-unused-parameters \ + --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ + --noise-type uniform --r3f-lambda 0.7 \ + --user-dir examples/rxf; +``` + +## Citation +```bibtex +@article{aghajanyan2020better, + title={Better Fine-Tuning by Reducing Representational Collapse}, + author={Aghajanyan, Armen and Shrivastava, Akshat and Gupta, Anchit and Goyal, Naman and Zettlemoyer, Luke and Gupta, Sonal}, + journal={arXiv preprint arXiv:2008.03156}, + year={2020} +} +``` diff --git a/examples/rxf/__init__.py b/examples/rxf/__init__.py new file mode 100644 index 0000000000..63453f9333 --- /dev/null +++ b/examples/rxf/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import src # noqa diff --git a/examples/rxf/src/__init__.py b/examples/rxf/src/__init__.py new file mode 100644 index 0000000000..306e232d6f --- /dev/null +++ b/examples/rxf/src/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import label_smoothed_cross_entropy_r3f, sentence_prediction_r3f # noqa diff --git a/examples/rxf/src/label_smoothed_cross_entropy_r3f.py b/examples/rxf/src/label_smoothed_cross_entropy_r3f.py new file mode 100644 index 0000000000..079db13e61 --- /dev/null +++ b/examples/rxf/src/label_smoothed_cross_entropy_r3f.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss + + +@register_criterion("label_smoothed_cross_entropy_r3f") +class LabelSmoothedCrossEntropyR3FCriterion(FairseqCriterion): + def __init__( + self, task, sentence_avg, label_smoothing, eps, r3f_lambda, noise_type + ): + super().__init__(task) + self.sentence_avg = sentence_avg + self.label_smoothing = label_smoothing + self.eps = eps + self.r3f_lambda = r3f_lambda + self.noise_type = noise_type + if self.noise_type in {"normal"}: + self.noise_sampler = torch.distributions.normal.Normal( + loc=0.0, scale=self.eps + ) + elif self.noise_type == "uniform": + self.noise_sampler = torch.distributions.uniform.Uniform( + low=-self.eps, high=self.eps + ) + else: + raise Exception(f"unrecognized noise type {self.noise_type}") + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', + help='epsilon for label smoothing, 0 means no label smoothing') + parser.add_argument('--eps', type=float, default=1e-5, + help='noise eps') + 
parser.add_argument('--r3f-lambda', type=float, default=1.0, + help='lambda for combining logistic loss and noisy KL loss') + parser.add_argument('--noise-type', type=str, default='normal', + choices=['normal', 'uniform'], + help='type of noises') + # fmt: on + + def _get_symm_kl(self, noised_logits, input_logits): + return ( + F.kl_div( + F.log_softmax(noised_logits, dim=-1, dtype=torch.float32), + F.softmax(input_logits, dim=-1, dtype=torch.float32), + None, + None, + "sum", + ) + + F.kl_div( + F.log_softmax(input_logits, dim=-1, dtype=torch.float32), + F.softmax(noised_logits, dim=-1, dtype=torch.float32), + None, + None, + "sum", + ) + ) / noised_logits.size(0) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + token_embeddings = model.encoder.embed_tokens(sample["net_input"]["src_tokens"]) + input_logits, extra = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss( + model, (input_logits, extra), sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + + if model.training: + noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to( + token_embeddings + ) + noised_embeddings = token_embeddings.clone() + noise + + noised_logits, _ = model( + **sample["net_input"], token_embeddings=noised_embeddings + ) + symm_kl = self._get_symm_kl(noised_logits, input_logits) + + if model.training: + symm_kl = symm_kl * sample_size + loss = loss + self.r3f_lambda * symm_kl + + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + + if model.training: + logging_output.update( + symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data + ) 
+ + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = model.get_targets(sample, net_output).view(-1, 1) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.label_smoothing, + ignore_index=self.padding_idx, + reduce=reduce, + ) + return loss, nll_loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs) + + metrics.log_scalar("symm_kl", symm_kl_sum / sample_size, sample_size, round=3) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improve distributed training speed. + """ + return True diff --git a/examples/rxf/src/sentence_prediction_r3f.py b/examples/rxf/src/sentence_prediction_r3f.py new file mode 100644 index 0000000000..62dd63390c --- /dev/null +++ b/examples/rxf/src/sentence_prediction_r3f.py @@ -0,0 +1,170 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("sentence_prediction_r3f") +class SentencePredictionR3F(FairseqCriterion): + def __init__( + self, + task, + eps, + r3f_lambda, + noise_type, + classification_head_name, + regression_target, + ): + super().__init__(task) + self.eps = eps + self.r3f_lambda = r3f_lambda + self.noise_type = noise_type + self.classification_head_name = classification_head_name + self.regression_target = regression_target + if self.noise_type in {"normal"}: + self.noise_sampler = torch.distributions.normal.Normal( + loc=0.0, scale=self.eps + ) + elif self.noise_type == "uniform": + self.noise_sampler = torch.distributions.uniform.Uniform( + low=-self.eps, high=self.eps + ) + else: + raise Exception(f"unrecognized noise type {self.noise_type}") + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--eps', type=float, default=1e-5, + help='noise eps') + parser.add_argument('--r3f-lambda', type=float, default=1.0, + help='lambda for combining logistic loss and noisy KL loss') + parser.add_argument('--noise-type', type=str, default='uniform', + choices=['normal', 'uniform'], + help='type of noises for RXF methods') + parser.add_argument('--classification-head-name', + default='sentence_classification_head', + help='name of the classification head to use') + # fmt: on + + def _get_symm_kl(self, noised_logits, input_logits): + return ( + F.kl_div( + F.log_softmax(noised_logits, dim=-1, dtype=torch.float32), + F.softmax(input_logits, dim=-1, dtype=torch.float32), + None, + None, + "sum", + ) + + F.kl_div( + F.log_softmax(input_logits, dim=-1, dtype=torch.float32), + F.softmax(noised_logits, dim=-1, dtype=torch.float32), + None, + None, + "sum", + ) + ) / noised_logits.size(0) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + assert ( + hasattr(model, "classification_heads") + and self.classification_head_name in model.classification_heads + ), "model must provide sentence classification head for --criterion=sentence_prediction" + + token_embeddings = model.encoder.sentence_encoder.embed_tokens( + sample["net_input"]["src_tokens"] + ) + input_logits, _ = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.classification_head_name, + token_embeddings=token_embeddings, + ) + if model.training and self.noise_sampler: + noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to( + token_embeddings + ) + noised_embeddings = token_embeddings.detach().clone() + noise + + noised_logits, _ = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.classification_head_name, + token_embeddings=noised_embeddings, + ) + symm_kl = self._get_symm_kl(noised_logits, input_logits) + else: + symm_kl = 0 + + targets = model.get_targets(sample, [input_logits]).view(-1) + sample_size = targets.numel() + + if not self.regression_target: + loss = F.nll_loss( + F.log_softmax(input_logits, dim=-1, dtype=torch.float32), + targets, + reduction="sum", + ) + if model.training: + symm_kl = symm_kl * sample_size + loss = loss + self.r3f_lambda * symm_kl + else: + logits = input_logits.squeeze().float() + targets = targets.float() + loss = F.mse_loss(logits, targets, reduction="sum") + + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, + } + + if not self.regression_target: + preds = input_logits.max(dim=1)[1] + logging_output.update(ncorrect=(preds == targets).sum().item()) + + if model.training and self.noise_sampler: + logging_output.update( + 
symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data + ) + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + agg_output = { + "loss": loss_sum / sample_size / math.log(2), + "symm_kl": symm_kl_sum / sample_size, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + + if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + agg_output.update(accuracy=ncorrect / nsentences) + + if sample_size != ntokens: + agg_output["nll_loss"] = loss_sum / ntokens / math.log(2) + return agg_output diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 62c495cb64..90e79e4651 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -59,6 +59,11 @@ def add_args(parser): choices=utils.get_available_activation_fns(), help='activation function to use for pooler layer' ) + parser.add_argument( + '--spectral-norm-classification-head', + action='store_true', + help='Apply spectral normalization on the classification head' + ) @property def supported_targets(self): @@ -66,7 +71,8 @@ def supported_targets(self): def forward( self, src_tokens, src_lengths, prev_output_tokens, - features_only=False, classification_head_name=None, **kwargs + features_only=False, classification_head_name=None, + token_embeddings=None, **kwargs ): if classification_head_name is not None: features_only = True @@ -74,6 +80,7 @@ def forward( encoder_out = self.encoder( src_tokens, 
src_lengths=src_lengths, + token_embeddings=token_embeddings, **kwargs, ) x, extra = self.decoder( @@ -127,11 +134,12 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, * ) ) self.classification_heads[name] = BARTClassificationHead( - self.args.encoder_embed_dim, - inner_dim or self.args.encoder_embed_dim, - num_classes, - self.args.pooler_activation_fn, - self.args.pooler_dropout, + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + do_spectral_norm=self.args.spectral_norm_classification_head ) def upgrade_state_dict_named(self, state_dict, name): @@ -240,6 +248,7 @@ def __init__( num_classes, activation_fn, pooler_dropout, + do_spectral_norm=False ): super().__init__() self.dense = nn.Linear(input_dim, inner_dim) @@ -247,6 +256,9 @@ def __init__( self.dropout = nn.Dropout(p=pooler_dropout) self.out_proj = nn.Linear(inner_dim, num_classes) + if do_spectral_norm: + self.out_proj = torch.nn.utils.spectral_norm(self.out_proj) + def forward(self, features, **kwargs): x = features x = self.dropout(x) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index e9008076b3..0917927e34 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -98,6 +98,8 @@ def add_args(parser): help='scalar quantization noise and scalar quantization at training time') parser.add_argument('--untie-weights-roberta', action='store_true', help='Untie weights between embeddings and classifiers in RoBERTa') + parser.add_argument('--spectral-norm-classification-head', action='store_true', default=False, + help='Apply spectral normalization on the classification head') @classmethod def build_model(cls, args, task): @@ -143,13 +145,14 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, * ) ) self.classification_heads[name] = 
RobertaClassificationHead( - self.args.encoder_embed_dim, - inner_dim or self.args.encoder_embed_dim, - num_classes, - self.args.pooler_activation_fn, - self.args.pooler_dropout, - self.args.quant_noise_pq, - self.args.quant_noise_pq_block_size, + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + q_noise=self.args.quant_noise_pq, + qn_block_size=self.args.quant_noise_pq_block_size, + do_spectral_norm=self.args.spectral_norm_classification_head, ) @property @@ -260,7 +263,17 @@ def forward(self, features, masked_tokens=None, **kwargs): class RobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" - def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout, q_noise=0, qn_block_size=8): + def __init__( + self, + input_dim, + inner_dim, + num_classes, + activation_fn, + pooler_dropout, + q_noise=0, + qn_block_size=8, + do_spectral_norm=False, + ): super().__init__() self.dense = nn.Linear(input_dim, inner_dim) self.activation_fn = utils.get_activation_fn(activation_fn) @@ -268,6 +281,11 @@ def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_drop self.out_proj = apply_quant_noise_( nn.Linear(inner_dim, num_classes), q_noise, qn_block_size ) + if do_spectral_norm: + if q_noise != 0: + raise NotImplementedError( + "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported") + self.out_proj = torch.nn.utils.spectral_norm(self.out_proj) def forward(self, features, **kwargs): x = features[:, 0, :] # take token (equiv. 
to [CLS]) @@ -343,10 +361,11 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas x = self.output_layer(x, masked_tokens=masked_tokens) return x, extra - def extract_features(self, src_tokens, return_all_hiddens=False, **unused): + def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs): inner_states, _ = self.sentence_encoder( src_tokens, last_state_only=not return_all_hiddens, + token_embeddings=kwargs.get('token_embeddings', None), ) features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C return features, {'inner_states': inner_states if return_all_hiddens else None} @@ -375,6 +394,8 @@ def base_architecture(args): args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None) args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) + args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) + args.spectral_norm_classification_head = getattr(args, 'spectral_norm_classification_head', False) @register_model_architecture('roberta', 'roberta_base') diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index ae0ba5aad0..ca1c6aaf5c 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -361,9 +361,13 @@ def __init__(self, args, dictionary, embed_tokens): def build_encoder_layer(self, args): return TransformerEncoderLayer(args) - def forward_embedding(self, src_tokens): + def forward_embedding( + self, src_tokens, token_embedding: Optional[torch.Tensor] = None + ): # embed tokens and positions - x = embed = self.embed_scale * self.embed_tokens(src_tokens) + if token_embedding is None: + token_embedding = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * token_embedding if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: @@ -373,7 +377,13 @@ def forward_embedding(self, src_tokens): x = 
self.quant_noise(x) return x, embed - def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): + def forward( + self, + src_tokens, + src_lengths, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): """ Args: src_tokens (LongTensor): tokens in the source language of shape @@ -382,6 +392,8 @@ def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings Returns: namedtuple: @@ -395,7 +407,7 @@ def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ - x, encoder_embedding = self.forward_embedding(src_tokens) + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # B x T x C -> T x B x C x = x.transpose(0, 1) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 9562430dfa..74cd1d0664 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -225,6 +225,7 @@ def forward( segment_labels: torch.Tensor = None, last_state_only: bool = False, positions: Optional[torch.Tensor] = None, + token_embeddings: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # compute padding mask. 
This is needed for multi-head attention @@ -232,7 +233,10 @@ def forward( if not self.traceable and not self.tpu and not padding_mask.any(): padding_mask = None - x = self.embed_tokens(tokens) + if token_embeddings is not None: + x = token_embeddings + else: + x = self.embed_tokens(tokens) if self.embed_scale is not None: x = x * self.embed_scale diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 53554e8d28..d684488abc 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -809,6 +809,26 @@ def test_pretrained_masked_lm_for_translation_sinusoidal_pos_emb(self): def test_pretrained_masked_lm_for_translation_encoder_only(self): self._test_pretrained_masked_lm_for_translation(True, True) + def test_r4f_roberta(self): + num_classes = 3 + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_r4f_roberta_head") as data_dir: + create_dummy_roberta_head_data(data_dir, num_classes=num_classes) + preprocess_lm_data(os.path.join(data_dir, 'input0')) + preprocess_lm_data(os.path.join(data_dir, 'label')) + train_roberta_head( + data_dir, + "roberta_base", + num_classes=num_classes, + extra_flags=[ + "--user-dir", + "examples/rxf/src", + "--criterion", + 'sentence_prediction_r3f', + '--spectral-norm-classification-head', + ], + ) + def train_legacy_masked_language_model(data_dir, arch, extra_args=()): train_parser = options.get_training_parser() From 2d900bf30814d64035dc78012f9cc7b4fc063ea1 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 16 Oct 2020 17:35:01 -0700 Subject: [PATCH 218/707] Fix tests (#1352) Summary: We need to keep `--num-workers=0` during tests Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1352 Reviewed By: alexeib Differential Revision: D24375411 Pulled By: myleott fbshipit-source-id: 9975ed5405f3b19b4dd0877ca15ee3081b185942 --- tests/gpu/test_binaries_gpu.py | 6 +++--- tests/test_binaries.py | 11 +++++++++-- tests/utils.py | 4 +++- 3 files changed, 15 insertions(+), 6 
deletions(-) diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index e3fadef9f2..2ac60a0934 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -139,7 +139,7 @@ def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=Fa "--ddp-backend", "no_c10d", "--num-workers", - 0, + "0", ] + (extra_flags or []), ) @@ -177,7 +177,7 @@ def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=Fa "--ddp-backend", "no_c10d", "--num-workers", - 0, + "0", "--quant-noise-scalar", "0.5", ] @@ -215,7 +215,7 @@ def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=Fa "--ddp-backend", "no_c10d", "--num-workers", - 0, + "0", "--restore-file", os.path.join(data_dir, "checkpoint_last.pt"), "--reset-optimizer", diff --git a/tests/test_binaries.py b/tests/test_binaries.py index d684488abc..aa5a6c69d1 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -888,6 +888,8 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()): "1", "--dataset-impl", "raw", + "--num-workers", + "0", ] + list(extra_args), ) train.main(train_args) @@ -973,7 +975,7 @@ def train_masked_lm(data_dir, arch, extra_flags=None): '--no-progress-bar', '--distributed-world-size', '1', '--ddp-backend', 'no_c10d', - '--num-workers', 0, + '--num-workers', '0', ] + (extra_flags or []), ) train.main(train_args) @@ -1000,7 +1002,7 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): '--no-progress-bar', '--distributed-world-size', '1', '--ddp-backend', 'no_c10d', - '--num-workers', 0, + '--num-workers', '0', ] + (extra_flags or []), ) train.main(train_args) @@ -1025,6 +1027,7 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) '--no-progress-bar', '--distributed-world-size', '1', '--ddp-backend', 'no_c10d', + '--num-workers', '0', ] + (extra_flags or []), ) train.main(train_args) @@ -1041,6 +1044,7 @@ def 
train_language_model(data_dir, arch, extra_flags=None, run_validation=False) '--valid-subset', 'valid', '--max-tokens', '500', '--no-progress-bar', + '--num-workers', '0', ] ) validate.main(validate_args) @@ -1054,6 +1058,7 @@ def eval_lm_main(data_dir): data_dir, '--path', os.path.join(data_dir, 'checkpoint_last.pt'), '--no-progress-bar', + '--num-workers', '0', ], ) eval_lm.main(eval_lm_args) @@ -1117,6 +1122,8 @@ def train_masked_language_model(data_dir, arch, extra_args=()): "1", "--dataset-impl", "raw", + "--num-workers", + "0", ] + list(extra_args), ) train.main(train_args) diff --git a/tests/utils.py b/tests/utils.py index e8528292e4..44a35fdccf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -235,7 +235,7 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation' '--max-epoch', '1', '--no-progress-bar', '--distributed-world-size', '1', - '--num-workers', 0, + '--num-workers', '0', ] + lang_flags + (extra_flags or []), ) train.main(train_args) @@ -252,6 +252,7 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation' '--valid-subset', 'valid', '--max-tokens', '500', '--no-progress-bar', + '--num-workers', '0', ] + lang_flags + (extra_valid_flags or []) ) validate.main(validate_args) @@ -273,6 +274,7 @@ def generate_main(data_dir, extra_flags=None): '--max-len-b', '5', '--gen-subset', 'valid', '--no-progress-bar', + '--num-workers', '0', ] + (extra_flags or []), ) From bc0474d96e2201d3fac390d6896b4cfca5c0f561 Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 17 Oct 2020 13:37:15 -0700 Subject: [PATCH 219/707] fix infer.py (#1354) Summary: this fixes infer.py that was broken by #2716 / D24243384 (https://github.com/pytorch/fairseq/commit/e0d5d8e669528be579d7aa4749fbcfe5cacdce90) we should prob add some tests for infer.py. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1354 Reviewed By: myleott Differential Revision: D24381516 Pulled By: alexeib fbshipit-source-id: b49e6bed7d239a55b8536d13c75fd5287330b1b1 --- examples/speech_recognition/infer.py | 37 ++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index b27cf5add5..fe5f607d1a 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -38,13 +38,16 @@ def add_asr_eval_argument(parser): help="wfstlm on dictonary\ output units", ) - parser.add_argument( - "--lm-weight", - "--lm_weight", - type=float, - default=0.2, - help="weight for lm while interpolating with neural score", - ) + try: + parser.add_argument( + "--lm-weight", + "--lm_weight", + type=float, + default=0.2, + help="weight for lm while interpolating with neural score", + ) + except: + pass parser.add_argument( "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" ) @@ -282,7 +285,25 @@ def main(args, task=None, model_state=None): # Initialize generator gen_timer = StopwatchMeter() - generator = task.build_generator(models, args) + def build_generator(args): + w2l_decoder = getattr(args, "w2l_decoder", None) + if w2l_decoder == "viterbi": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, task.target_dictionary) + elif w2l_decoder == "kenlm": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + return W2lKenLMDecoder(args, task.target_dictionary) + elif w2l_decoder == "fairseqlm": + from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(args, task.target_dictionary) + else: + print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment') + + # please do not touch this unless you test both generate.py and infer.py 
with audio_pretraining task + generator = build_generator(args) if args.load_emissions: generator = ExistingEmissionsDecoder( From 7a3f20d0fad62d696aed17d801b840d5c25cc4f5 Mon Sep 17 00:00:00 2001 From: phantomcoder1996 Date: Sun, 18 Oct 2020 10:45:54 -0700 Subject: [PATCH 220/707] fix a bug that caused the label generation script for Librispeech not to work on windows (#2745) Summary: ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2744 Pull Request resolved: https://github.com/pytorch/fairseq/pull/2745 Reviewed By: alexeib Differential Revision: D24381143 Pulled By: myleott fbshipit-source-id: 61690f30e988a9d477e6d5c927c49b10652925c7 --- examples/wav2vec/libri_labels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wav2vec/libri_labels.py b/examples/wav2vec/libri_labels.py index 812528732f..3fa1ec4c8b 100644 --- a/examples/wav2vec/libri_labels.py +++ b/examples/wav2vec/libri_labels.py @@ -33,7 +33,7 @@ def main(): line = line.strip() dir = os.path.dirname(line) if dir not in transcriptions: - parts = dir.split("/") + parts = dir.split(os.path.sep) trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt" path = os.path.join(root, dir, trans_path) assert os.path.exists(path) From 5695cdfb2caf85b3670d6bda7c0bc31666138263 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sun, 18 Oct 2020 13:04:09 -0700 Subject: [PATCH 221/707] Disable isort on several files Summary: isort introduces some import/circular dependency issues. Ideally we'll fix those in the future, but for now just disable isort on many of the `__init__.py` files so that we can apply black+isort across the repo. 
Reviewed By: alexeib Differential Revision: D24377771 fbshipit-source-id: 9a16343a13f162582722b4147959caea29682bbe --- examples/noisychannel/rerank_options.py | 2 +- fairseq/__init__.py | 1 + fairseq/criterions/__init__.py | 1 + fairseq/criterions/wav2vec_criterion.py | 1 + fairseq/data/__init__.py | 1 + fairseq/model_parallel/modules/__init__.py | 1 + fairseq/models/__init__.py | 1 + fairseq/models/lstm_lm.py | 1 + fairseq/models/nat/__init__.py | 6 ++++++ fairseq/modules/__init__.py | 1 + fairseq/optim/__init__.py | 1 + fairseq/optim/lr_scheduler/__init__.py | 1 + fairseq/tasks/__init__.py | 1 + fairseq/tasks/translation_lev.py | 1 + 14 files changed, 19 insertions(+), 1 deletion(-) diff --git a/examples/noisychannel/rerank_options.py b/examples/noisychannel/rerank_options.py index 55c57051ff..a425fb295b 100644 --- a/examples/noisychannel/rerank_options.py +++ b/examples/noisychannel/rerank_options.py @@ -103,7 +103,7 @@ def add_reranking_args(parser): help='the fraction of the source prefix to use in rescoring (in terms of words wo bpe)') group.add_argument('--normalize', action='store_true', help='whether to normalize by src and target len') - + # fmt: on return group diff --git a/fairseq/__init__.py b/fairseq/__init__.py index dfa4ef7898..a4244c8a3a 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" __all__ = ['pdb'] __version__ = '1.0.0a0' diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index 30edb2f312..a7eb5f6f3c 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+"""isort:skip_file""" import importlib import os diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 85403cb428..cc743524d2 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -32,6 +32,7 @@ def add_args(parser): help='weights for additional loss terms (not first one)') parser.add_argument('--log-keys', type=str, default=None, help='output keys to log') + # fmt: on def forward(self, model, sample, reduce=True, log_pred=False): """Compute the loss for the given sample. diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index d195e59493..785a0aa643 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" from .dictionary import Dictionary, TruncatedDictionary diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py index 5c9431f92b..26401dcc7c 100644 --- a/fairseq/model_parallel/modules/__init__.py +++ b/fairseq/model_parallel/modules/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" from .multihead_attention import ModelParallelMultiheadAttention from .transformer_layer import ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index e441e7cd7d..7ff9442711 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+"""isort:skip_file""" import argparse import importlib diff --git a/fairseq/models/lstm_lm.py b/fairseq/models/lstm_lm.py index 82bd02f6f7..1a39b95289 100644 --- a/fairseq/models/lstm_lm.py +++ b/fairseq/models/lstm_lm.py @@ -51,6 +51,7 @@ def add_args(parser): parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true', help='share decoder input and output embeddings') + # fmt: on @classmethod def build_model(cls, args, task): diff --git a/fairseq/models/nat/__init__.py b/fairseq/models/nat/__init__.py index b6ca06acb9..05fe822487 100644 --- a/fairseq/models/nat/__init__.py +++ b/fairseq/models/nat/__init__.py @@ -1,3 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + from .fairseq_nat_model import * from .nonautoregressive_transformer import * from .nat_crf_transformer import * diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index d526d4a92e..52432e0de4 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" from .adaptive_input import AdaptiveInput from .adaptive_softmax import AdaptiveSoftmax diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index e2c3a3ceff..94eb2c7ee9 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+"""isort:skip_file""" import importlib import os diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index 85773aab39..7b72c25784 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" import importlib import os diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index eda2ed34b7..e0abce253c 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" import argparse import importlib diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index 18ac0ca385..3af9bb2532 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -30,6 +30,7 @@ def add_args(parser): '--noise', default='random_delete', choices=['random_delete', 'random_mask', 'no_noise', 'full_mask']) + # fmt: on def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
From a48f235636557b8d3bc4922a6fa90f3a0fa57955 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sun, 18 Oct 2020 18:13:29 -0700 Subject: [PATCH 222/707] Apply black+isort (#1357) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357 Reviewed By: alexeib Differential Revision: D24377772 fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c --- docs/conf.py | 55 +- examples/__init__.py | 2 +- examples/backtranslation/deduplicate_lines.py | 10 +- examples/backtranslation/extract_bt_data.py | 57 +- examples/byte_level_bpe/get_bitext.py | 216 +-- examples/byte_level_bpe/gru_transformer.py | 11 +- examples/constrained_decoding/normalize.py | 7 +- examples/constrained_decoding/tok.py | 11 +- examples/criss/mining/mine.py | 161 +- examples/criss/save_encoder.py | 107 +- .../sentence_retrieval/encoder_analysis.py | 37 +- examples/latent_depth/src/__init__.py | 8 +- .../latent_depth/src/loss/latent_depth.py | 39 +- .../models/latent_multilingual_transformer.py | 53 +- .../src/models/latent_transformer.py | 44 +- .../latent_depth/src/modules/latent_layers.py | 31 +- .../multilingual_translation_latent_depth.py | 78 +- .../linformer/src/models/linformer_roberta.py | 119 +- .../src/modules/linformer_sentence_encoder.py | 12 +- .../linformer_sentence_encoder_layer.py | 3 +- .../src/modules/multihead_linear_attention.py | 78 +- examples/m2m_100/tokenizers/tokenize_indic.py | 14 +- examples/m2m_100/tokenizers/tokenize_thai.py | 1 + examples/m2m_100/tokenizers/tokenize_zh.py | 2 + examples/megatron_11b/detok.py | 16 +- examples/noisychannel/rerank.py | 277 +++- examples/noisychannel/rerank_generate.py | 374 +++-- examples/noisychannel/rerank_options.py | 57 +- examples/noisychannel/rerank_score_bw.py | 188 ++- examples/noisychannel/rerank_score_lm.py | 57 +- examples/noisychannel/rerank_tune.py | 61 +- examples/noisychannel/rerank_utils.py | 557 ++++--- examples/paraphraser/paraphrase.py | 61 +- examples/pointer_generator/postprocess.py | 8 +- 
examples/pointer_generator/preprocess.py | 4 +- .../pointer_generator/src/transformer_pg.py | 10 +- .../commonsense_qa/commonsense_qa_task.py | 100 +- .../roberta/multiprocessing_bpe_encoder.py | 21 +- examples/roberta/preprocess_RACE.py | 31 +- examples/roberta/wsc/wsc_criterion.py | 105 +- examples/roberta/wsc/wsc_task.py | 158 +- examples/roberta/wsc/wsc_utils.py | 129 +- examples/simultaneous_translation/__init__.py | 2 +- .../criterions/__init__.py | 1 + ...moothed_cross_entropy_latency_augmented.py | 15 +- .../eval/agents/__init__.py | 14 +- .../eval/agents/agent.py | 18 +- .../eval/agents/simul_trans_agent.py | 40 +- .../eval/agents/simul_trans_text_agent.py | 35 +- .../eval/agents/word_splitter.py | 6 +- .../simultaneous_translation/eval/client.py | 38 +- .../eval/eval_latency.py | 26 +- .../simultaneous_translation/eval/evaluate.py | 55 +- .../eval/scorers/__init__.py | 18 +- .../eval/scorers/scorer.py | 57 +- .../eval/scorers/text_scorer.py | 8 +- .../simultaneous_translation/eval/server.py | 33 +- .../models/__init__.py | 9 +- .../models/transformer_monotonic_attention.py | 156 +- .../modules/__init__.py | 14 +- .../modules/monotonic_multihead_attention.py | 206 +-- .../modules/monotonic_transformer_layer.py | 27 +- .../utils/__init__.py | 6 +- .../utils/functions.py | 18 +- .../simultaneous_translation/utils/latency.py | 113 +- examples/speech_recognition/__init__.py | 2 +- .../speech_recognition/criterions/ASG_loss.py | 2 +- examples/speech_recognition/data/__init__.py | 3 +- .../speech_recognition/data/asr_dataset.py | 27 +- examples/speech_recognition/data/collaters.py | 10 +- .../datasets/asr_prep_json.py | 81 +- examples/speech_recognition/infer.py | 84 +- .../speech_recognition/models/__init__.py | 7 +- .../models/vggtransformer.py | 30 +- .../models/w2l_conv_glu_enc.py | 1 - examples/speech_recognition/tasks/__init__.py | 7 +- .../tasks/speech_recognition.py | 22 +- examples/speech_recognition/w2l_decoder.py | 10 +- 
examples/speech_to_text/data_utils.py | 216 +-- examples/speech_to_text/prep_covost_data.py | 250 +-- .../speech_to_text/prep_librispeech_data.py | 109 +- examples/speech_to_text/prep_mustc_data.py | 168 ++- examples/translation_moe/score.py | 95 +- examples/translation_moe/src/logsumexp_moe.py | 2 +- .../src/mean_pool_gating_network.py | 8 +- .../translation_moe/src/translation_moe.py | 91 +- .../aggregate_scores.py | 29 +- .../unsupervised_quality_estimation/meteor.py | 76 +- .../repeat_lines.py | 14 +- examples/wav2vec/vq-wav2vec_featurize.py | 38 +- examples/wav2vec/wav2vec_featurize.py | 97 +- examples/wav2vec/wav2vec_manifest.py | 50 +- fairseq/__init__.py | 11 +- fairseq/benchmark/__init__.py | 7 +- fairseq/benchmark/dummy_lm.py | 41 +- fairseq/benchmark/dummy_masked_lm.py | 43 +- fairseq/benchmark/dummy_model.py | 55 +- fairseq/benchmark/dummy_mt.py | 37 +- fairseq/binarizer.py | 3 +- fairseq/checkpoint_utils.py | 90 +- fairseq/criterions/composite_loss.py | 23 +- fairseq/criterions/fairseq_criterion.py | 35 +- .../label_smoothed_cross_entropy.py | 92 +- ...l_smoothed_cross_entropy_with_alignment.py | 84 +- fairseq/criterions/legacy_masked_lm.py | 101 +- fairseq/criterions/masked_lm.py | 31 +- fairseq/criterions/nat_loss.py | 61 +- fairseq/criterions/sentence_prediction.py | 50 +- fairseq/criterions/sentence_ranking.py | 58 +- fairseq/criterions/wav2vec_criterion.py | 81 +- fairseq/data/__init__.py | 119 +- fairseq/data/add_target_dataset.py | 26 +- fairseq/data/append_token_dataset.py | 1 - fairseq/data/audio/audio_utils.py | 24 +- .../data/audio/feature_transforms/__init__.py | 35 +- .../audio/feature_transforms/global_cmvn.py | 9 +- .../audio/feature_transforms/specaugment.py | 71 +- .../feature_transforms/utterance_cmvn.py | 16 +- fairseq/data/audio/raw_audio_dataset.py | 11 +- fairseq/data/audio/speech_to_text_dataset.py | 283 ++-- fairseq/data/backtranslation_dataset.py | 26 +- fairseq/data/base_wrapper_dataset.py | 7 +- 
fairseq/data/bucket_pad_length_dataset.py | 3 +- fairseq/data/colorize_dataset.py | 1 + fairseq/data/concat_dataset.py | 20 +- fairseq/data/concat_sentences_dataset.py | 14 +- fairseq/data/data_utils.py | 162 +- fairseq/data/denoising_dataset.py | 127 +- fairseq/data/dictionary.py | 3 +- fairseq/data/encoders/__init__.py | 10 +- fairseq/data/encoders/byte_bpe.py | 17 +- fairseq/data/encoders/byte_utils.py | 18 +- fairseq/data/encoders/bytes.py | 12 +- fairseq/data/encoders/characters.py | 5 +- fairseq/data/encoders/fastbpe.py | 10 +- fairseq/data/encoders/gpt2_bpe.py | 22 +- fairseq/data/encoders/gpt2_bpe_utils.py | 57 +- fairseq/data/encoders/hf_bert_bpe.py | 20 +- fairseq/data/encoders/hf_byte_bpe.py | 19 +- fairseq/data/encoders/moses_tokenizer.py | 16 +- fairseq/data/encoders/nltk_tokenizer.py | 8 +- fairseq/data/encoders/sentencepiece_bpe.py | 16 +- fairseq/data/encoders/space_tokenizer.py | 5 +- fairseq/data/encoders/subword_nmt_bpe.py | 26 +- fairseq/data/encoders/utils.py | 10 +- fairseq/data/fairseq_dataset.py | 31 +- fairseq/data/id_dataset.py | 1 - fairseq/data/indexed_dataset.py | 174 ++- fairseq/data/iterators.py | 92 +- fairseq/data/language_pair_dataset.py | 193 ++- fairseq/data/legacy/__init__.py | 11 +- fairseq/data/legacy/block_pair_dataset.py | 1 - fairseq/data/legacy/masked_lm_dataset.py | 117 +- fairseq/data/legacy/masked_lm_dictionary.py | 22 +- fairseq/data/list_dataset.py | 1 - fairseq/data/lm_context_window_dataset.py | 23 +- fairseq/data/lru_cache_dataset.py | 1 - fairseq/data/mask_tokens_dataset.py | 17 +- fairseq/data/monolingual_dataset.py | 100 +- .../multilingual/multilingual_data_manager.py | 14 +- .../multilingual/sampled_multi_dataset.py | 199 ++- .../sampled_multi_epoch_dataset.py | 70 +- fairseq/data/multilingual/sampling_method.py | 34 +- fairseq/data/nested_dictionary_dataset.py | 23 +- fairseq/data/noising.py | 66 +- fairseq/data/num_samples_dataset.py | 1 - fairseq/data/numel_dataset.py | 1 - 
fairseq/data/offset_tokens_dataset.py | 1 - fairseq/data/pad_dataset.py | 3 - fairseq/data/plasma_utils.py | 25 +- fairseq/data/prepend_token_dataset.py | 1 - fairseq/data/raw_label_dataset.py | 1 - fairseq/data/replace_dataset.py | 12 +- fairseq/data/resampling_dataset.py | 3 +- fairseq/data/roll_dataset.py | 1 - fairseq/data/round_robin_zip_datasets.py | 37 +- fairseq/data/shorten_dataset.py | 20 +- fairseq/data/sort_dataset.py | 1 - fairseq/data/strip_token_dataset.py | 1 - fairseq/data/subsample_dataset.py | 8 +- fairseq/data/token_block_dataset.py | 12 +- fairseq/data/transform_eos_dataset.py | 33 +- .../data/transform_eos_lang_pair_dataset.py | 42 +- fairseq/dataclass/data_class.py | 19 +- fairseq/dataclass/utils.py | 6 +- fairseq/distributed_utils.py | 194 ++- fairseq/file_utils.py | 96 +- fairseq/hub_utils.py | 94 +- fairseq/incremental_decoding_utils.py | 7 +- fairseq/iterative_refinement_generator.py | 130 +- fairseq/legacy_distributed_data_parallel.py | 23 +- fairseq/logging/meters.py | 65 +- fairseq/logging/metrics.py | 11 +- fairseq/logging/progress_bar.py | 88 +- fairseq/model_parallel/__init__.py | 2 +- fairseq/model_parallel/criterions/__init__.py | 6 +- .../vocab_parallel_cross_entropy.py | 53 +- fairseq/model_parallel/megatron_trainer.py | 13 +- fairseq/model_parallel/models/__init__.py | 10 +- .../pipeline_parallel_transformer/layers.py | 172 ++- .../pipeline_parallel_transformer/model.py | 337 +++-- .../model_parallel/models/roberta/model.py | 105 +- fairseq/model_parallel/models/transformer.py | 37 +- .../model_parallel/models/transformer_lm.py | 168 ++- fairseq/model_parallel/modules/__init__.py | 19 +- .../modules/multihead_attention.py | 92 +- .../modules/transformer_layer.py | 13 +- .../modules/transformer_sentence_encoder.py | 9 +- .../transformer_sentence_encoder_layer.py | 14 +- fairseq/models/bart/hub_interface.py | 81 +- fairseq/models/bart/model.py | 277 ++-- fairseq/models/composite_encoder.py | 4 +- 
fairseq/models/distributed_fairseq_model.py | 29 +- fairseq/models/fairseq_encoder.py | 11 +- fairseq/models/fairseq_incremental_decoder.py | 24 +- fairseq/models/fairseq_model.py | 2 +- fairseq/models/fconv.py | 322 ++-- fairseq/models/fconv_lm.py | 125 +- fairseq/models/fconv_self_att.py | 295 ++-- fairseq/models/huggingface/__init__.py | 10 +- fairseq/models/huggingface/hf_gpt2.py | 70 +- fairseq/models/lightconv.py | 700 ++++++--- fairseq/models/lightconv_lm.py | 364 +++-- fairseq/models/lstm.py | 300 ++-- fairseq/models/lstm_lm.py | 69 +- fairseq/models/masked_lm.py | 289 ++-- fairseq/models/model_utils.py | 8 +- fairseq/models/multilingual_transformer.py | 139 +- fairseq/models/nat/cmlm_transformer.py | 28 +- fairseq/models/nat/fairseq_nat_model.py | 42 +- fairseq/models/nat/insertion_transformer.py | 26 +- ...iterative_nonautoregressive_transformer.py | 59 +- fairseq/models/nat/levenshtein_transformer.py | 146 +- fairseq/models/nat/levenshtein_utils.py | 69 +- fairseq/models/nat/nat_crf_transformer.py | 62 +- .../models/nat/nonautoregressive_ensembles.py | 73 +- .../nat/nonautoregressive_transformer.py | 78 +- fairseq/models/roberta/alignment_utils.py | 25 +- fairseq/models/roberta/hub_interface.py | 121 +- fairseq/models/roberta/model.py | 332 ++-- fairseq/models/roberta/model_camembert.py | 31 +- fairseq/models/roberta/model_xlmr.py | 19 +- fairseq/models/speech_to_text/__init__.py | 2 +- fairseq/models/speech_to_text/berard.py | 149 +- .../models/speech_to_text/s2t_transformer.py | 261 ++-- fairseq/models/transformer.py | 9 +- fairseq/models/transformer_align.py | 2 +- .../models/transformer_from_pretrained_xlm.py | 13 +- fairseq/models/wav2vec/wav2vec.py | 2 +- fairseq/models/wav2vec/wav2vec2.py | 6 +- fairseq/models/wav2vec/wav2vec2_asr.py | 18 +- fairseq/modules/__init__.py | 72 +- fairseq/modules/adaptive_input.py | 20 +- fairseq/modules/adaptive_softmax.py | 110 +- fairseq/modules/beamable_mm.py | 10 +- fairseq/modules/character_token_embedder.py | 
80 +- fairseq/modules/conv_tbc.py | 20 +- fairseq/modules/cross_entropy.py | 30 +- .../downsampled_multihead_attention.py | 122 +- fairseq/modules/dynamic_convolution.py | 181 ++- fairseq/modules/dynamic_crf_layer.py | 45 +- .../dynamicconv_layer/cuda_function_gen.py | 4 +- .../dynamicconv_layer/dynamicconv_layer.py | 109 +- fairseq/modules/dynamicconv_layer/setup.py | 16 +- fairseq/modules/fairseq_dropout.py | 9 +- fairseq/modules/gumbel_vector_quantizer.py | 7 +- fairseq/modules/kmeans_vector_quantizer.py | 31 +- fairseq/modules/layer_norm.py | 1 + .../lightconv_layer/cuda_function_gen.py | 4 +- .../lightconv_layer/lightconv_layer.py | 60 +- fairseq/modules/lightconv_layer/setup.py | 21 +- fairseq/modules/lightweight_convolution.py | 164 +- fairseq/modules/linearized_convolution.py | 28 +- fairseq/modules/multihead_attention.py | 35 +- fairseq/modules/positional_embedding.py | 13 +- fairseq/modules/quant_noise.py | 31 +- fairseq/modules/quantization/pq/em.py | 2 +- .../quantization/pq/modules/__init__.py | 2 +- .../modules/quantization/pq/modules/qemb.py | 48 +- fairseq/modules/quantization/pq/utils.py | 18 +- .../quantization/scalar/modules/__init__.py | 4 +- .../quantization/scalar/modules/qact.py | 14 +- .../quantization/scalar/modules/qconv.py | 7 +- .../quantization/scalar/modules/qemb.py | 47 +- .../quantization/scalar/modules/qlinear.py | 7 +- fairseq/modules/quantization/scalar/ops.py | 4 +- fairseq/modules/quantization/scalar/utils.py | 22 +- fairseq/modules/sparse_multihead_attention.py | 76 +- .../sparse_transformer_sentence_encoder.py | 31 +- ...arse_transformer_sentence_encoder_layer.py | 12 +- fairseq/modules/transformer_layer.py | 50 +- .../modules/transformer_sentence_encoder.py | 36 +- .../transformer_sentence_encoder_layer.py | 27 +- fairseq/modules/unfold.py | 8 +- fairseq/nan_detector.py | 6 +- fairseq/optim/adadelta.py | 12 +- fairseq/optim/adafactor.py | 145 +- fairseq/optim/adagrad.py | 8 +- fairseq/optim/adamax.py | 68 +- 
fairseq/optim/dynamic_loss_scaler.py | 27 +- fairseq/optim/fairseq_optimizer.py | 4 +- fairseq/optim/fp16_optimizer.py | 157 +- fairseq/optim/fused_adam.py | 212 +-- fairseq/optim/fused_lamb.py | 15 +- fairseq/optim/lr_scheduler/fixed_schedule.py | 18 +- .../lr_scheduler/polynomial_decay_schedule.py | 36 +- .../lr_scheduler/reduce_lr_on_plateau.py | 29 +- .../lr_scheduler/tri_stage_lr_scheduler.py | 12 +- .../lr_scheduler/triangular_lr_scheduler.py | 10 +- fairseq/optim/sgd.py | 10 +- fairseq/optim/shard.py | 17 +- fairseq/options.py | 4 +- fairseq/pdb.py | 2 +- fairseq/quantization_utils.py | 46 +- fairseq/scoring/__init__.py | 1 + fairseq/scoring/bleu.py | 13 +- fairseq/scoring/chrf.py | 3 +- fairseq/scoring/tokenizer.py | 18 +- fairseq/scoring/wer.py | 9 +- fairseq/search.py | 7 +- fairseq/sequence_generator.py | 68 +- fairseq/sequence_scorer.py | 65 +- fairseq/tasks/audio_pretraining.py | 17 +- fairseq/tasks/cross_lingual_lm.py | 111 +- fairseq/tasks/denoising.py | 145 +- fairseq/tasks/fairseq_task.py | 5 +- fairseq/tasks/language_modeling.py | 2 +- fairseq/tasks/legacy_masked_lm.py | 64 +- fairseq/tasks/masked_lm.py | 148 +- fairseq/tasks/multilingual_denoising.py | 131 +- fairseq/tasks/multilingual_masked_lm.py | 165 +- fairseq/tasks/multilingual_translation.py | 216 ++- fairseq/tasks/semisupervised_translation.py | 234 ++- fairseq/tasks/sentence_prediction.py | 125 +- fairseq/tasks/sentence_ranking.py | 101 +- fairseq/tasks/speech_to_text.py | 106 +- fairseq/tasks/translation.py | 229 +-- .../tasks/translation_from_pretrained_bart.py | 73 +- fairseq/tasks/translation_lev.py | 105 +- .../tasks/translation_multi_simple_epoch.py | 180 ++- fairseq/token_generation_constraints.py | 28 +- fairseq/tokenizer.py | 1 + fairseq/trainer.py | 210 ++- fairseq/utils.py | 2 +- fairseq_cli/eval_lm.py | 119 +- fairseq_cli/generate.py | 217 ++- fairseq_cli/interactive.py | 152 +- fairseq_cli/preprocess.py | 146 +- fairseq_cli/score.py | 28 +- fairseq_cli/validate.py | 30 +- 
hubconf.py | 38 +- scripts/average_checkpoints.py | 50 +- scripts/build_sym_alignment.py | 48 +- scripts/compare_namespaces.py | 22 +- scripts/constraints/extract.py | 31 +- scripts/constraints/validate.py | 1 + scripts/count_docs.py | 10 +- scripts/read_binarized.py | 11 +- scripts/rm_pt.py | 79 +- scripts/shard_docs.py | 14 +- scripts/split_train_valid_docs.py | 21 +- scripts/spm_decode.py | 10 +- scripts/spm_encode.py | 58 +- setup.py | 148 +- tests/speech_recognition/asr_test_base.py | 19 +- .../speech_recognition/test_cross_entropy.py | 5 +- tests/speech_recognition/test_data_utils.py | 53 +- tests/test_average_checkpoints.py | 69 +- tests/test_backtranslation_dataset.py | 29 +- tests/test_binaries.py | 1338 +++++++++++------ tests/test_bmuf.py | 5 +- tests/test_character_token_embedder.py | 14 +- tests/test_concat_dataset.py | 34 +- tests/test_constraints.py | 125 +- tests/test_convtbc.py | 20 +- tests/test_dictionary.py | 82 +- tests/test_file_io.py | 12 +- tests/test_fp16_optimizer.py | 50 +- tests/test_inference_dropout.py | 8 +- tests/test_iterators.py | 5 +- tests/test_label_smoothing.py | 67 +- tests/test_lstm_jitable.py | 14 +- tests/test_memory_efficient_fp16.py | 8 +- tests/test_metrics.py | 55 +- tests/test_multihead_attention.py | 9 +- tests/test_noising.py | 8 +- tests/test_reproducibility.py | 113 +- tests/test_resampling_dataset.py | 1 - tests/test_sequence_generator.py | 274 ++-- tests/test_sequence_scorer.py | 93 +- tests/test_sparse_multihead_attention.py | 126 +- tests/test_token_block_dataset.py | 19 +- tests/test_train.py | 87 +- tests/test_utils.py | 80 +- tests/utils.py | 373 +++-- train.py | 2 +- 396 files changed, 15455 insertions(+), 9847 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d6d150c1f0..52971a27e7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,10 +20,11 @@ import os import sys + # source code directory, relative to this file, for sphinx-autobuild -sys.path.insert(0, os.path.abspath('..')) 
+sys.path.insert(0, os.path.abspath("..")) -source_suffix = ['.rst'] +source_suffix = [".rst"] # -- General configuration ------------------------------------------------ @@ -35,34 +36,34 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinxarg.ext', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinxarg.ext", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'fairseq' -copyright = '2019, Facebook AI Research (FAIR)' -author = 'Facebook AI Research (FAIR)' +project = "fairseq" +copyright = "2019, Facebook AI Research (FAIR)" +author = "Facebook AI Research (FAIR)" -github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/' +github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '0.9.0' +version = "0.9.0" # The full version, including alpha/beta/rc tags. -release = '0.9.0' +release = "0.9.0" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -74,11 +75,11 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. 
-pygments_style = 'sphinx' -highlight_language = 'python' +pygments_style = "sphinx" +highlight_language = "python" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -89,7 +90,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -100,11 +101,11 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_context = { - 'css_files': [ - '_static/theme_overrides.css', # override wide tables in RTD theme + "css_files": [ + "_static/theme_overrides.css", # override wide tables in RTD theme ], } @@ -113,7 +114,7 @@ # # This is required for the alabaster theme # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars -#html_sidebars = { +# html_sidebars = { # '**': [ # 'about.html', # 'navigation.html', @@ -121,12 +122,12 @@ # 'searchbox.html', # 'donate.html', # ] -#} +# } # Example configuration for intersphinx: refer to the Python standard library. 
intersphinx_mapping = { - 'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'python': ('https://docs.python.org/', None), - 'torch': ('https://pytorch.org/docs/master/', None), + "numpy": ("http://docs.scipy.org/doc/numpy/", None), + "python": ("https://docs.python.org/", None), + "torch": ("https://pytorch.org/docs/master/", None), } diff --git a/examples/__init__.py b/examples/__init__.py index 9369be1b77..9a6b08a75b 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -3,6 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -__version__ = '0.9.0' +__version__ = "0.9.0" import examples.noisychannel # noqa diff --git a/examples/backtranslation/deduplicate_lines.py b/examples/backtranslation/deduplicate_lines.py index 35a407e556..50e458328c 100644 --- a/examples/backtranslation/deduplicate_lines.py +++ b/examples/backtranslation/deduplicate_lines.py @@ -7,8 +7,8 @@ import argparse import fileinput import hashlib -from multiprocessing import Pool import sys +from multiprocessing import Pool def get_hashes_and_lines(raw_line): @@ -18,12 +18,12 @@ def get_hashes_and_lines(raw_line): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--workers', type=int, default=10) - parser.add_argument('files', nargs='*', help='input files') + parser.add_argument("--workers", type=int, default=10) + parser.add_argument("files", nargs="*", help="input files") args = parser.parse_args() seen = set() - with fileinput.input(args.files, mode='rb') as h: + with fileinput.input(args.files, mode="rb") as h: pool = Pool(args.workers) results = pool.imap_unordered(get_hashes_and_lines, h, 1000) for i, (hash, raw_line) in enumerate(results): @@ -37,5 +37,5 @@ def main(): print(file=sys.stderr, flush=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/backtranslation/extract_bt_data.py b/examples/backtranslation/extract_bt_data.py index 
26a46942c8..e766391e87 100644 --- a/examples/backtranslation/extract_bt_data.py +++ b/examples/backtranslation/extract_bt_data.py @@ -11,26 +11,38 @@ def main(): - parser = argparse.ArgumentParser(description=( - 'Extract back-translations from the stdout of fairseq-generate. ' - 'If there are multiply hypotheses for a source, we only keep the first one. ' - )) - parser.add_argument('--output', required=True, help='output prefix') - parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)') - parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)') - parser.add_argument('--minlen', type=int, help='min length filter') - parser.add_argument('--maxlen', type=int, help='max length filter') - parser.add_argument('--ratio', type=float, help='ratio filter') - parser.add_argument('files', nargs='*', help='input files') + parser = argparse.ArgumentParser( + description=( + "Extract back-translations from the stdout of fairseq-generate. " + "If there are multiply hypotheses for a source, we only keep the first one. 
" + ) + ) + parser.add_argument("--output", required=True, help="output prefix") + parser.add_argument( + "--srclang", required=True, help="source language (extracted from H-* lines)" + ) + parser.add_argument( + "--tgtlang", required=True, help="target language (extracted from S-* lines)" + ) + parser.add_argument("--minlen", type=int, help="min length filter") + parser.add_argument("--maxlen", type=int, help="max length filter") + parser.add_argument("--ratio", type=float, help="ratio filter") + parser.add_argument("files", nargs="*", help="input files") args = parser.parse_args() def validate(src, tgt): - srclen = len(src.split(' ')) if src != '' else 0 - tgtlen = len(tgt.split(' ')) if tgt != '' else 0 + srclen = len(src.split(" ")) if src != "" else 0 + tgtlen = len(tgt.split(" ")) if tgt != "" else 0 if ( (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen)) - or (args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen)) - or (args.ratio is not None and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)) + or ( + args.maxlen is not None + and (srclen > args.maxlen or tgtlen > args.maxlen) + ) + or ( + args.ratio is not None + and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio) + ) ): return False return True @@ -41,19 +53,20 @@ def safe_index(toks, index, default): except IndexError: return default - with open(args.output + '.' + args.srclang, 'w') as src_h, \ - open(args.output + '.' + args.tgtlang, 'w') as tgt_h: + with open(args.output + "." + args.srclang, "w") as src_h, open( + args.output + "." 
+ args.tgtlang, "w" + ) as tgt_h: for line in tqdm(fileinput.input(args.files)): - if line.startswith('S-'): - tgt = safe_index(line.rstrip().split('\t'), 1, '') - elif line.startswith('H-'): + if line.startswith("S-"): + tgt = safe_index(line.rstrip().split("\t"), 1, "") + elif line.startswith("H-"): if tgt is not None: - src = safe_index(line.rstrip().split('\t'), 2, '') + src = safe_index(line.rstrip().split("\t"), 2, "") if validate(src, tgt): print(src, file=src_h) print(tgt, file=tgt_h) tgt = None -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/byte_level_bpe/get_bitext.py b/examples/byte_level_bpe/get_bitext.py index 7770ea667b..6ac1eeec1e 100644 --- a/examples/byte_level_bpe/get_bitext.py +++ b/examples/byte_level_bpe/get_bitext.py @@ -4,203 +4,251 @@ # LICENSE file in the root directory of this source tree. -import os.path as op import argparse import os -from multiprocessing import cpu_count +import os.path as op from collections import namedtuple -from typing import Optional, List +from multiprocessing import cpu_count +from typing import List, Optional import sentencepiece as sp - -from fairseq.data.encoders.moses_tokenizer import MosesTokenizer -from fairseq.data.encoders.byte_utils import byte_encode -from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE -from fairseq.data.encoders.characters import Characters from fairseq.data.encoders.byte_bpe import ByteBPE +from fairseq.data.encoders.byte_utils import byte_encode from fairseq.data.encoders.bytes import Bytes +from fairseq.data.encoders.characters import Characters +from fairseq.data.encoders.moses_tokenizer import MosesTokenizer +from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE -SPLITS = ['train', 'valid', 'test'] +SPLITS = ["train", "valid", "test"] def _convert_xml(in_path: str, out_path: str): - with open(in_path) as f, open(out_path, 'w') as f_o: + with open(in_path) as f, open(out_path, "w") as f_o: for s in f: ss = 
s.strip() - if not ss.startswith('<seg'): + if not ss.startswith("<seg"): continue - ss = ss.replace('</seg>', '').split('">') + ss = ss.replace("</seg>", "").split('">') assert len(ss) == 2 - f_o.write(ss[1].strip() + '\n') + f_o.write(ss[1].strip() + "\n") def _convert_train(in_path: str, out_path: str): - with open(in_path) as f, open(out_path, 'w') as f_o: + with open(in_path) as f, open(out_path, "w") as f_o: for s in f: ss = s.strip() - if ss.startswith('<'): + if ss.startswith("<"): continue - f_o.write(ss.strip() + '\n') + f_o.write(ss.strip() + "\n") def _get_bytes(in_path: str, out_path: str): - with open(in_path) as f, open(out_path, 'w') as f_o: + with open(in_path) as f, open(out_path, "w") as f_o: for s in f: - f_o.write(Bytes.encode(s.strip()) + '\n') + f_o.write(Bytes.encode(s.strip()) + "\n") def _get_chars(in_path: str, out_path: str): - with open(in_path) as f, open(out_path, 'w') as f_o: + with open(in_path) as f, open(out_path, "w") as f_o: for s in f: - f_o.write(Characters.encode(s.strip()) + '\n') + f_o.write(Characters.encode(s.strip()) + "\n") def pretokenize(in_path: str, out_path: str, src: str, tgt: str): - Args = namedtuple('Args', ['moses_source_lang', 'moses_target_lang', - 'moses_no_dash_splits', 'moses_no_escape']) - args = Args(moses_source_lang=src, moses_target_lang=tgt, - moses_no_dash_splits=False, moses_no_escape=False) + Args = namedtuple( + "Args", + [ + "moses_source_lang", + "moses_target_lang", + "moses_no_dash_splits", + "moses_no_escape", + ], + ) + args = Args( + moses_source_lang=src, + moses_target_lang=tgt, + moses_no_dash_splits=False, + moses_no_escape=False, + ) pretokenizer = MosesTokenizer(args) - with open(in_path) as f, open(out_path, 'w') as f_o: + with open(in_path) as f, open(out_path, "w") as f_o: for s in f: - f_o.write(pretokenizer.encode(s.strip()) + '\n') + f_o.write(pretokenizer.encode(s.strip()) + "\n") def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str): - with open(out_path, 'w') as f_o: + with open(out_path, "w") as f_o: for lang in [src, tgt]: -
with open(f'{in_path_prefix}.{lang}') as f: + with open(f"{in_path_prefix}.{lang}") as f: for s in f: - f_o.write(byte_encode(s.strip()) + '\n') + f_o.write(byte_encode(s.strip()) + "\n") def _get_bpe(in_path: str, model_prefix: str, vocab_size: int): arguments = [ - f'--input={in_path}', f'--model_prefix={model_prefix}', - f'--model_type=bpe', f'--vocab_size={vocab_size}', - '--character_coverage=1.0', '--normalization_rule_name=identity', - f'--num_threads={cpu_count()}' + f"--input={in_path}", + f"--model_prefix={model_prefix}", + f"--model_type=bpe", + f"--vocab_size={vocab_size}", + "--character_coverage=1.0", + "--normalization_rule_name=identity", + f"--num_threads={cpu_count()}", ] - sp.SentencePieceTrainer.Train(' '.join(arguments)) + sp.SentencePieceTrainer.Train(" ".join(arguments)) def _apply_bbpe(model_path: str, in_path: str, out_path: str): - Args = namedtuple('Args', ['sentencepiece_model_path']) + Args = namedtuple("Args", ["sentencepiece_model_path"]) args = Args(sentencepiece_model_path=model_path) tokenizer = ByteBPE(args) - with open(in_path) as f, open(out_path, 'w') as f_o: + with open(in_path) as f, open(out_path, "w") as f_o: for s in f: - f_o.write(tokenizer.encode(s.strip()) + '\n') + f_o.write(tokenizer.encode(s.strip()) + "\n") def _apply_bpe(model_path: str, in_path: str, out_path: str): - Args = namedtuple('Args', ['sentencepiece_model']) + Args = namedtuple("Args", ["sentencepiece_model"]) args = Args(sentencepiece_model=model_path) tokenizer = SentencepieceBPE(args) - with open(in_path) as f, open(out_path, 'w') as f_o: + with open(in_path) as f, open(out_path, "w") as f_o: for s in f: - f_o.write(tokenizer.encode(s.strip()) + '\n') + f_o.write(tokenizer.encode(s.strip()) + "\n") def _concat_files(in_paths: List[str], out_path: str): - with open(out_path, 'w') as f_o: + with open(out_path, "w") as f_o: for p in in_paths: with open(p) as f: for r in f: f_o.write(r) -def preprocess_iwslt17(root: str, src: str, tgt: str, bpe_size: 
Optional[int], - need_chars: bool, bbpe_size: Optional[int], - need_bytes: bool): +def preprocess_iwslt17( + root: str, + src: str, + tgt: str, + bpe_size: Optional[int], + need_chars: bool, + bbpe_size: Optional[int], + need_bytes: bool, +): # extract bitext - in_root = op.join(root, f'{src}-{tgt}') + in_root = op.join(root, f"{src}-{tgt}") for lang in [src, tgt]: _convert_train( - op.join(in_root, f'train.tags.{src}-{tgt}.{lang}'), - op.join(root, f'train.{lang}') + op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"), + op.join(root, f"train.{lang}"), ) _convert_xml( - op.join(in_root, f'IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml'), - op.join(root, f'valid.{lang}') + op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"), + op.join(root, f"valid.{lang}"), ) _convert_xml( - op.join(in_root, f'IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml'), - op.join(root, f'test.{lang}') + op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"), + op.join(root, f"test.{lang}"), ) # pre-tokenize for lang in [src, tgt]: for split in SPLITS: - pretokenize(op.join(root, f'{split}.{lang}'), - op.join(root, f'{split}.moses.{lang}'), src, tgt) + pretokenize( + op.join(root, f"{split}.{lang}"), + op.join(root, f"{split}.moses.{lang}"), + src, + tgt, + ) # tokenize with BPE vocabulary if bpe_size is not None: # learn vocabulary - concated_train_path = op.join(root, 'train.all') + concated_train_path = op.join(root, "train.all") _concat_files( - [op.join(root, 'train.moses.fr'), op.join(root, 'train.moses.en')], - concated_train_path + [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")], + concated_train_path, ) - bpe_model_prefix = op.join(root, f'spm_bpe{bpe_size}') + bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}") _get_bpe(concated_train_path, bpe_model_prefix, bpe_size) os.remove(concated_train_path) # apply for lang in [src, tgt]: for split in SPLITS: _apply_bpe( - bpe_model_prefix + '.model', - op.join(root, f'{split}.moses.{lang}'), - op.join(root, 
f'{split}.moses.bpe{bpe_size}.{lang}') + bpe_model_prefix + ".model", + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"), ) # tokenize with bytes vocabulary if need_bytes: for lang in [src, tgt]: for split in SPLITS: - _get_bytes(op.join(root, f'{split}.moses.{lang}'), - op.join(root, f'{split}.moses.bytes.{lang}')) + _get_bytes( + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.bytes.{lang}"), + ) # tokenize with characters vocabulary if need_chars: for lang in [src, tgt]: for split in SPLITS: - _get_chars(op.join(root, f'{split}.moses.{lang}'), - op.join(root, f'{split}.moses.chars.{lang}')) + _get_chars( + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.chars.{lang}"), + ) # tokenize with byte-level BPE vocabulary if bbpe_size is not None: # learn vocabulary - bchar_path = op.join(root, 'train.bchar') - _convert_to_bchar(op.join(root, 'train.moses'), src, tgt, bchar_path) - bbpe_model_prefix = op.join(root, f'spm_bbpe{bbpe_size}') + bchar_path = op.join(root, "train.bchar") + _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path) + bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}") _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size) os.remove(bchar_path) # apply for lang in [src, tgt]: for split in SPLITS: _apply_bbpe( - bbpe_model_prefix + '.model', - op.join(root, f'{split}.moses.{lang}'), - op.join(root, f'{split}.moses.bbpe{bbpe_size}.{lang}') + bbpe_model_prefix + ".model", + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"), ) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--root', type=str, default='data') - parser.add_argument('--bpe-vocab', default=None, type=int, - help='Generate tokenized bitext with BPE of size K.' - 'Default to None (disabled).') - parser.add_argument('--bbpe-vocab', default=None, type=int, - help='Generate tokenized bitext with BBPE of size K.' 
- 'Default to None (disabled).') - parser.add_argument('--byte-vocab', action='store_true', - help='Generate tokenized bitext with bytes vocabulary') - parser.add_argument('--char-vocab', action='store_true', - help='Generate tokenized bitext with chars vocabulary') + parser.add_argument("--root", type=str, default="data") + parser.add_argument( + "--bpe-vocab", + default=None, + type=int, + help="Generate tokenized bitext with BPE of size K." + "Default to None (disabled).", + ) + parser.add_argument( + "--bbpe-vocab", + default=None, + type=int, + help="Generate tokenized bitext with BBPE of size K." + "Default to None (disabled).", + ) + parser.add_argument( + "--byte-vocab", + action="store_true", + help="Generate tokenized bitext with bytes vocabulary", + ) + parser.add_argument( + "--char-vocab", + action="store_true", + help="Generate tokenized bitext with chars vocabulary", + ) args = parser.parse_args() - preprocess_iwslt17(args.root, 'fr', 'en', args.bpe_vocab, args.char_vocab, - args.bbpe_vocab, args.byte_vocab) + preprocess_iwslt17( + args.root, + "fr", + "en", + args.bpe_vocab, + args.char_vocab, + args.bbpe_vocab, + args.byte_vocab, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/byte_level_bpe/gru_transformer.py b/examples/byte_level_bpe/gru_transformer.py index 7ba8e4084f..d4efa93a4d 100644 --- a/examples/byte_level_bpe/gru_transformer.py +++ b/examples/byte_level_bpe/gru_transformer.py @@ -11,7 +11,7 @@ import torch.nn as nn import torch.nn.functional as F from fairseq.models import register_model, register_model_architecture -from fairseq.models.transformer import TransformerModel, TransformerEncoder +from fairseq.models.transformer import TransformerEncoder, TransformerModel @register_model("gru_transformer") @@ -24,9 +24,12 @@ def build_encoder(cls, args, src_dict, embed_tokens): class GRUTransformerEncoder(TransformerEncoder): def __init__(self, args, dictionary, embed_tokens): super().__init__(args, 
dictionary, embed_tokens) - self.emb_ctx = nn.GRU(input_size=embed_tokens.embedding_dim, - hidden_size=embed_tokens.embedding_dim // 2, - num_layers=1, bidirectional=True) + self.emb_ctx = nn.GRU( + input_size=embed_tokens.embedding_dim, + hidden_size=embed_tokens.embedding_dim // 2, + num_layers=1, + bidirectional=True, + ) def forward_embedding(self, src_tokens): # embed tokens and positions diff --git a/examples/constrained_decoding/normalize.py b/examples/constrained_decoding/normalize.py index 2a7ae03102..4ae2b5111b 100755 --- a/examples/constrained_decoding/normalize.py +++ b/examples/constrained_decoding/normalize.py @@ -16,11 +16,12 @@ def main(args): print(normalizer.normalize(line.rstrip()), flush=True) -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() - parser.add_argument('--lang', '-l', default='en') - parser.add_argument('--penn', '-p', action='store_true') + parser.add_argument("--lang", "-l", default="en") + parser.add_argument("--penn", "-p", action="store_true") args = parser.parse_args() main(args) diff --git a/examples/constrained_decoding/tok.py b/examples/constrained_decoding/tok.py index 9215a66538..b1f888a8c0 100755 --- a/examples/constrained_decoding/tok.py +++ b/examples/constrained_decoding/tok.py @@ -6,12 +6,14 @@ # LICENSE file in the root directory of this source tree. 
import sys + import sacremoses def main(args): """Tokenizes, preserving tabs""" mt = sacremoses.MosesTokenizer(lang=args.lang) + def tok(s): return mt.tokenize(s, return_str=True) @@ -20,12 +22,13 @@ def tok(s): print(*parts, sep="\t", flush=True) -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() - parser.add_argument('--lang', '-l', default='en') - parser.add_argument('--penn', '-p', action='store_true') - parser.add_argument('--fields', '-f', help="fields to tokenize") + parser.add_argument("--lang", "-l", default="en") + parser.add_argument("--penn", "-p", action="store_true") + parser.add_argument("--fields", "-f", help="fields to tokenize") args = parser.parse_args() main(args) diff --git a/examples/criss/mining/mine.py b/examples/criss/mining/mine.py index a902a4ab64..c86f73ae87 100644 --- a/examples/criss/mining/mine.py +++ b/examples/criss/mining/mine.py @@ -3,14 +3,15 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import faiss -import numpy as np -import glob import argparse +import glob from subprocess import check_call +import faiss +import numpy as np + -GB = 1024*1024*1024 +GB = 1024 * 1024 * 1024 def call(cmd): @@ -18,14 +19,14 @@ def call(cmd): check_call(cmd, shell=True) -def get_batches(directory, lang, prefix='all_avg_pool'): +def get_batches(directory, lang, prefix="all_avg_pool"): print(f"Finding in {directory}/{prefix}.{lang}*") - files = glob.glob(f'{directory}/{prefix}.{lang}*') + files = glob.glob(f"{directory}/{prefix}.{lang}*") emb_files = [] txt_files = [] for emb_fi in files: emb_files.append(emb_fi) - txt_fi = emb_fi.replace(prefix, 'sentences') + txt_fi = emb_fi.replace(prefix, "sentences") txt_files.append(txt_fi) return emb_files, txt_files @@ -38,7 +39,7 @@ def load_batch(emb_file, dim): return embeddings -def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'): +def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"): sims = [] inds = [] xfrom = 0 @@ -53,7 +54,7 @@ def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'): y_batch = load_batch(y_batch_f, dim) neighbor_size = min(k, y_batch.shape[0]) yto = yfrom + y_batch.shape[0] - print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto)) + print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto)) idx = faiss.IndexFlatIP(dim) idx = faiss.index_cpu_to_all_gpus(idx) idx.add(y_batch) @@ -86,8 +87,10 @@ def score(sim, fwd_mean, bwd_mean, margin): return margin(sim, (fwd_mean + bwd_mean) / 2) -def score_candidates(sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False): - print(' - scoring {:d} candidates'.format(sim_mat.shape[0])) +def score_candidates( + sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False +): + print(" - scoring {:d} candidates".format(sim_mat.shape[0])) scores = np.zeros(candidate_inds.shape) for i in range(scores.shape[0]): for j in range(scores.shape[1]): @@ -106,42 +109,50 @@ def load_text(files): return all_sentences -if 
__name__ == '__main__': - parser = argparse.ArgumentParser(description='Mine bitext') - parser.add_argument('--src-lang', help='Source language') - parser.add_argument('--tgt-lang', help='Target language') - parser.add_argument('--dict-path', help='Path to dictionary file', default='dict.txt') - parser.add_argument('--spm-path', help='Path to SPM model file', default='sentence.bpe.model') - parser.add_argument('--dim', type=int, default=1024, - help='Embedding dimension') - parser.add_argument('--mem', type=int, default=5, - help='Memory in GB') - parser.add_argument('--src-dir', help='Source directory') - parser.add_argument('--tgt-dir', help='Target directory') - parser.add_argument('--output', help='Output path') - parser.add_argument('--neighborhood', type=int, default=4, - help='Embedding dimension') - parser.add_argument('--threshold', type=float, default=1.06, - help='Threshold on mined bitext') - parser.add_argument('--valid-size', type=int, default=2000, - help='Number of sentences used for validation set') - parser.add_argument('--min-count', type=int, default=50000, - help='Min num sentences used for each language') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mine bitext") + parser.add_argument("--src-lang", help="Source language") + parser.add_argument("--tgt-lang", help="Target language") + parser.add_argument( + "--dict-path", help="Path to dictionary file", default="dict.txt" + ) + parser.add_argument( + "--spm-path", help="Path to SPM model file", default="sentence.bpe.model" + ) + parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension") + parser.add_argument("--mem", type=int, default=5, help="Memory in GB") + parser.add_argument("--src-dir", help="Source directory") + parser.add_argument("--tgt-dir", help="Target directory") + parser.add_argument("--output", help="Output path") + parser.add_argument( + "--neighborhood", type=int, default=4, help="Embedding dimension" + ) + 
parser.add_argument( + "--threshold", type=float, default=1.06, help="Threshold on mined bitext" + ) + parser.add_argument( + "--valid-size", + type=int, + default=2000, + help="Number of sentences used for validation set", + ) + parser.add_argument( + "--min-count", + type=int, + default=50000, + help="Min num sentences used for each language", + ) args = parser.parse_args() x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang) y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang) margin = lambda a, b: a / b y2x_sim, y2x_ind = knnGPU_sharded( - y_batches_f, x_batches_f, - args.dim, - args.neighborhood, - direction='y2x') + y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x" + ) x2y_sim, x2y_ind = knnGPU_sharded( - x_batches_f, y_batches_f, - args.dim, - args.neighborhood, - direction='x2y') + x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y" + ) x2y_mean = x2y_sim.mean(axis=1) y2x_mean = y2x_sim.mean(axis=1) @@ -149,8 +160,13 @@ def load_text(files): bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin) fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)] bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)] - indices = np.stack((np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)), - np.concatenate((fwd_best, np.arange(y2x_ind.shape[0])))), axis=1) + indices = np.stack( + ( + np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)), + np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))), + ), + axis=1, + ) scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1))) x_sentences = load_text(x_sents_f) @@ -162,20 +178,20 @@ def load_text(files): directory = args.output call(f"mkdir -p {directory}") src_out = open( - f'{directory}/all.{args.src_lang}', - mode='w', - encoding='utf-8', - errors='surrogateescape') + f"{directory}/all.{args.src_lang}", + mode="w", + encoding="utf-8", + errors="surrogateescape", + ) tgt_out = 
open( - f'{directory}/all.{args.tgt_lang}', - mode='w', - encoding='utf-8', - errors='surrogateescape') + f"{directory}/all.{args.tgt_lang}", + mode="w", + encoding="utf-8", + errors="surrogateescape", + ) scores_out = open( - f'{directory}/all.scores', - mode='w', - encoding='utf-8', - errors='surrogateescape') + f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape" + ) count = 0 for i in np.argsort(-scores): src_ind, trg_ind = indices[i] @@ -195,20 +211,23 @@ def load_text(files): scores_out.close() print(f"Found {count} pairs for threshold={threshold}") - with open(f'{directory}/all.{args.src_lang}') as all_s, \ - open(f'{directory}/all.{args.tgt_lang}') as all_t, \ - open(f'{directory}/valid.{args.src_lang}', 'w') as valid_s, \ - open(f'{directory}/valid.{args.tgt_lang}', 'w') as valid_t, \ - open(f'{directory}/train.{args.src_lang}', 'w') as train_s, \ - open(f'{directory}/train.{args.tgt_lang}', 'w') as train_t: - count = 0 - for s_line, t_line in zip(all_s, all_t): - s_line = s_line.split('\t')[1] - t_line = t_line.split('\t')[1] - if count >= args.valid_size: - train_s.write(s_line) - train_t.write(t_line) - else: - valid_s.write(s_line) - valid_t.write(t_line) - count += 1 + with open(f"{directory}/all.{args.src_lang}") as all_s, open( + f"{directory}/all.{args.tgt_lang}" + ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open( + f"{directory}/valid.{args.tgt_lang}", "w" + ) as valid_t, open( + f"{directory}/train.{args.src_lang}", "w" + ) as train_s, open( + f"{directory}/train.{args.tgt_lang}", "w" + ) as train_t: + count = 0 + for s_line, t_line in zip(all_s, all_t): + s_line = s_line.split("\t")[1] + t_line = t_line.split("\t")[1] + if count >= args.valid_size: + train_s.write(s_line) + train_t.write(t_line) + else: + valid_s.write(s_line) + valid_t.write(t_line) + count += 1 diff --git a/examples/criss/save_encoder.py b/examples/criss/save_encoder.py index 8132bbf0fa..4d0f17f0f2 100644 --- 
a/examples/criss/save_encoder.py +++ b/examples/criss/save_encoder.py @@ -7,27 +7,29 @@ Translate pre-processed data with a trained model. """ +import numpy as np import torch - from fairseq import checkpoint_utils, options, progress_bar, tasks, utils from fairseq.sequence_generator import EnsembleModel -import numpy as np -def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False): +def get_avg_pool( + models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False +): model = EnsembleModel(models) # model.forward normally channels prev_output_tokens into the decoder # separately, but SequenceGenerator directly calls model.encoder encoder_input = { - k: v for k, v in sample['net_input'].items() - if k != 'prev_output_tokens' + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" } # compute the encoder output for each beam encoder_outs = model.forward_encoder(encoder_input) np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32) - encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(np.float32) + encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype( + np.float32 + ) encoder_mask = np.expand_dims(encoder_mask.T, axis=2) if has_langtok: encoder_mask = encoder_mask[1:, :, :] @@ -38,13 +40,15 @@ def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langto def main(args): - assert args.path is not None, '--path required for generation!' - assert not args.sampling or args.nbest == args.beam, \ - '--sampling requires --nbest to be equal to --beam' - assert args.replace_unk is None or args.raw_text, \ - '--replace-unk requires a raw text dataset (--raw-text)' - - args.beam=1 + assert args.path is not None, "--path required for generation!" 
+ assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + args.replace_unk is None or args.raw_text + ), "--replace-unk requires a raw text dataset (--raw-text)" + + args.beam = 1 utils.import_user_module(args) if args.max_tokens is None: @@ -58,15 +62,15 @@ def main(args): # Set dictionaries try: - src_dict = getattr(task, 'source_dictionary', None) + src_dict = getattr(task, "source_dictionary", None) except NotImplementedError: src_dict = None tgt_dict = task.target_dictionary # Load ensemble - print('| loading model(s) from {}'.format(args.path)) + print("| loading model(s) from {}".format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - args.path.split(':'), + args.path.split(":"), arg_overrides=eval(args.model_overrides), task=task, ) @@ -105,9 +109,9 @@ def main(args): shard_id = 0 all_avg_pool = None encoder_has_langtok = ( - hasattr(task.args, 'encoder_langtok') + hasattr(task.args, "encoder_langtok") and task.args.encoder_langtok is not None - and hasattr(task.args, 'lang_tok_replacing_bos_eos') + and hasattr(task.args, "lang_tok_replacing_bos_eos") and not task.args.lang_tok_replacing_bos_eos ) with progress_bar.build_progress_bar(args, itr) as t: @@ -116,34 +120,42 @@ def main(args): print("Skipping None") continue sample = utils.move_to_cuda(sample) if use_cuda else sample - if 'net_input' not in sample: + if "net_input" not in sample: continue prefix_tokens = None if args.prefix_size > 0: - prefix_tokens = sample['target'][:, :args.prefix_size] + prefix_tokens = sample["target"][:, : args.prefix_size] with torch.no_grad(): avg_pool = get_avg_pool( - models, sample, prefix_tokens, src_dict, - args.remove_bpe, - has_langtok=encoder_has_langtok) + models, + sample, + prefix_tokens, + src_dict, + args.remove_bpe, + has_langtok=encoder_has_langtok, + ) if all_avg_pool is not None: all_avg_pool = np.concatenate((all_avg_pool, avg_pool)) else: all_avg_pool = 
avg_pool - if not isinstance(sample['id'], list): - sample_ids = sample['id'].tolist() + if not isinstance(sample["id"], list): + sample_ids = sample["id"].tolist() else: - sample_ids = sample['id'] + sample_ids = sample["id"] for i, sample_id in enumerate(sample_ids): # Remove padding - src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) + src_tokens = utils.strip_pad( + sample["net_input"]["src_tokens"][i, :], tgt_dict.pad() + ) # Either retrieve the original sentences or regenerate them from tokens. if align_dict is not None: - src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) + src_str = task.dataset(args.gen_subset).src.get_original_text( + sample_id + ) else: if src_dict is not None: src_str = src_dict.string(src_tokens, args.remove_bpe) @@ -152,37 +164,50 @@ def main(args): if not args.quiet: if src_dict is not None: - print('S-{}\t{}'.format(sample_id, src_str)) + print("S-{}\t{}".format(sample_id, src_str)) source_sentences.append(f"{sample_id}\t{src_str}") - num_sentences += sample['nsentences'] + num_sentences += sample["nsentences"] if all_avg_pool.shape[0] >= 1000000: - with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}', - 'w') as avg_pool_file: + with open( + f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", + "w", + ) as avg_pool_file: all_avg_pool.tofile(avg_pool_file) - with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file: - sentence_file.writelines(f'{line}\n' for line in source_sentences) + with open( + f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", + "w", + ) as sentence_file: + sentence_file.writelines(f"{line}\n" for line in source_sentences) all_avg_pool = None source_sentences = [] shard_id += 1 if all_avg_pool is not None: - with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}', - 'w') as avg_pool_file: + with open( + 
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w" + ) as avg_pool_file: all_avg_pool.tofile(avg_pool_file) - with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file: - sentence_file.writelines(f'{line}\n' for line in source_sentences) + with open( + f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w" + ) as sentence_file: + sentence_file.writelines(f"{line}\n" for line in source_sentences) return None def cli_main(): parser = options.get_generation_parser() - parser.add_argument('--encoder-save-dir', default='', type=str, metavar='N', - help='directory to save encoder outputs') + parser.add_argument( + "--encoder-save-dir", + default="", + type=str, + metavar="N", + help="directory to save encoder outputs", + ) args = options.parse_args_and_arch(parser) main(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/examples/criss/sentence_retrieval/encoder_analysis.py b/examples/criss/sentence_retrieval/encoder_analysis.py index c0d74af23a..b41bfbe387 100644 --- a/examples/criss/sentence_retrieval/encoder_analysis.py +++ b/examples/criss/sentence_retrieval/encoder_analysis.py @@ -3,10 +3,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import numpy as np import argparse import glob +import numpy as np + DIM = 1024 @@ -14,9 +15,13 @@ def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False): target_ids = [tid for tid in target_embs] source_mat = np.stack(source_embs.values(), axis=0) - normalized_source_mat = source_mat / np.linalg.norm(source_mat, axis=1, keepdims=True) + normalized_source_mat = source_mat / np.linalg.norm( + source_mat, axis=1, keepdims=True + ) target_mat = np.stack(target_embs.values(), axis=0) - normalized_target_mat = target_mat / np.linalg.norm(target_mat, axis=1, keepdims=True) + normalized_target_mat = target_mat / np.linalg.norm( + target_mat, axis=1, keepdims=True + ) sim_mat = normalized_source_mat.dot(normalized_target_mat.T) if return_sim_mat: return sim_mat @@ -36,14 +41,14 @@ def load_embeddings(directory, LANGS): lang_dir = f"{directory}/{lang}" embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*") for embed_file in embedding_files: - shard_id = embed_file.split('.')[-1] + shard_id = embed_file.split(".")[-1] embeddings = np.fromfile(embed_file, dtype=np.float32) num_rows = embeddings.shape[0] // DIM embeddings = embeddings.reshape((num_rows, DIM)) - with open(f'{lang_dir}/sentences.{lang}.{shard_id}') as sentence_file: + with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file: for idx, line in enumerate(sentence_file): - sentence_id, sentence = line.strip().split('\t') + sentence_id, sentence = line.strip().split("\t") sentence_texts[lang][sentence_id] = sentence sentence_embeddings[lang][sentence_id] = embeddings[idx, :] @@ -55,7 +60,7 @@ def compute_accuracy(directory, LANGS): top_1_accuracy = {} - top1_str = " ".join(LANGS) + '\n' + top1_str = " ".join(LANGS) + "\n" for source_lang in LANGS: top_1_accuracy[source_lang] = {} top1_str += f"{source_lang} " @@ -63,8 +68,8 @@ def compute_accuracy(directory, LANGS): top1 = 0 top5 = 0 neighbors_map = compute_dist( - sentence_embeddings[source_lang], - 
sentence_embeddings[target_lang]) + sentence_embeddings[source_lang], sentence_embeddings[target_lang] + ) for sentence_id, neighbors in neighbors_map.items(): if sentence_id == neighbors[0]: top1 += 1 @@ -75,17 +80,13 @@ def compute_accuracy(directory, LANGS): top1_str += "\n" print(top1_str) - print(top1_str, file=open(f"{directory}/accuracy", 'w')) + print(top1_str, file=open(f"{directory}/accuracy", "w")) if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Analyze encoder outputs') - parser.add_argument('directory', - help='Source language corpus' - ) - parser.add_argument('--langs', - help='List of langs' - ) + parser = argparse.ArgumentParser(description="Analyze encoder outputs") + parser.add_argument("directory", help="Source language corpus") + parser.add_argument("--langs", help="List of langs") args = parser.parse_args() - langs = args.langs.split(',') + langs = args.langs.split(",") compute_accuracy(args.directory, langs) diff --git a/examples/latent_depth/src/__init__.py b/examples/latent_depth/src/__init__.py index 8a86fa5817..c5fa76039f 100644 --- a/examples/latent_depth/src/__init__.py +++ b/examples/latent_depth/src/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .models import latent_multilingual_transformer # noqa -from .modules import latent_layers # noqa -from .loss import latent_depth # noqa -from . import multilingual_translation_latent_depth # noqa +from . 
import multilingual_translation_latent_depth # noqa +from .loss import latent_depth # noqa +from .models import latent_multilingual_transformer # noqa +from .modules import latent_layers # noqa diff --git a/examples/latent_depth/src/loss/latent_depth.py b/examples/latent_depth/src/loss/latent_depth.py index f647c758ee..a3b9535eca 100644 --- a/examples/latent_depth/src/loss/latent_depth.py +++ b/examples/latent_depth/src/loss/latent_depth.py @@ -3,8 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch import math + +import torch from torch.nn.modules.loss import _Loss @@ -19,17 +20,16 @@ def forward(self, layer_samples, lang_idx, update_num, sample_size): eps = 1e-7 if prior == "uniform": # uniform prior - kl_loss = (samples * ( - torch.log(samples + eps) - math.log(0.5) - )).sum(-1) + kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1) elif prior == "agged_posterior": # aggregated posterior y_t = torch.stack([x.detach() for x in layer_samples], dim=0) agged_q = torch.sum(y_t, dim=0) row_norm = agged_q.sum(-1) normed_agg_q = agged_q / row_norm - kl_loss = (samples * ( - torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1) + kl_loss = ( + samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps)) + ).sum(-1) else: raise NotImplementedError("The specified prior is not implemented.") @@ -37,7 +37,9 @@ def forward(self, layer_samples, lang_idx, update_num, sample_size): kl_loss /= layer_samples[0].size()[0] kl_weight = min( self.args.sparsity_weight, - (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates + (update_num - self.args.soft_update) + * self.args.sparsity_weight + / self.args.anneal_updates, ) kl_loss *= kl_weight * sample_size return kl_loss @@ -58,15 +60,17 @@ def forward(self, layer_samples_list, update_num, sample_size): share_loss = 0 global_sparsity_loss = 0 layer_samples = 
torch.stack(layer_samples_list, dim=0) - if ((self.args.target_layers > 0 or self.args.share_weight > 0) and - update_num > (self.args.soft_update + self.args.anneal_updates)): + if ( + self.args.target_layers > 0 or self.args.share_weight > 0 + ) and update_num > (self.args.soft_update + self.args.anneal_updates): # anneal sparsity weight if update_num < (self.args.anneal_updates + self.args.soft_update): weight_anneal = 0 elif update_num < (2 * self.args.anneal_updates + self.args.soft_update): weight_anneal = ( (update_num - self.args.soft_update - self.args.anneal_updates) - * self.args.share_weight / self.args.anneal_updates + * self.args.share_weight + / self.args.anneal_updates ) else: weight_anneal = 1 @@ -75,12 +79,21 @@ def forward(self, layer_samples_list, update_num, sample_size): layer_utilization /= layer_samples.size()[0] if self.args.share_weight > 0: # encouraging sharing across languages - share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0) - batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss + share_loss = sum( + -1.0 * v * math.log(v) for v in layer_utilization if v > 0 + ) + batch_loss += ( + weight_anneal * self.args.share_weight * sample_size * share_loss + ) if self.args.target_layers > 0: # computed expected number of layers selected expeted_layers = sum(layer_utilization) # compute l2 loss wrt target number of layers global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2 - batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss + batch_loss += ( + weight_anneal + * self.args.share_weight + * sample_size + * global_sparsity_loss + ) return batch_loss diff --git a/examples/latent_depth/src/models/latent_multilingual_transformer.py b/examples/latent_depth/src/models/latent_multilingual_transformer.py index 97573cbd75..9e075fcc47 100644 --- a/examples/latent_depth/src/models/latent_multilingual_transformer.py +++ 
b/examples/latent_depth/src/models/latent_multilingual_transformer.py @@ -3,34 +3,31 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq.models import ( - register_model, - register_model_architecture, -) +from fairseq.models import register_model, register_model_architecture +from fairseq.models.multilingual_transformer import MultilingualTransformerModel from fairseq.models.transformer import ( - base_architecture, - TransformerEncoder, TransformerDecoder, + TransformerEncoder, + base_architecture, ) -from fairseq.models.multilingual_transformer import MultilingualTransformerModel -from .latent_transformer import ( - LatentTransformerEncoder, - LatentTransformerDecoder, -) +from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder -@register_model('latent_multilingual_transformer') +@register_model("latent_multilingual_transformer") class LatentMultilingualTransformerModel(MultilingualTransformerModel): """A variant of standard multilingual Transformer models which encoder and/or - decoders supports latent depth, as is in "Deep Transformer with Latent Depth" + decoders supports latent depth, as is in "Deep Transformer with Latent Depth" (https://arxiv.org/abs/2009.13102). 
""" + @classmethod def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): if is_encoder: if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer: - return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs)) + return LatentTransformerEncoder( + args, lang_dict, embed_tokens, num_logits=len(langs) + ) else: return TransformerEncoder(args, lang_dict, embed_tokens) else: @@ -42,19 +39,21 @@ def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): return TransformerDecoder(args, lang_dict, embed_tokens) -@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer') +@register_model_architecture( + "latent_multilingual_transformer", "latent_multilingual_transformer" +) def latent_multilingual_architecture(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4) - args.encoder_layers = getattr(args, 'encoder_layers', 12) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4) - args.decoder_layers = getattr(args, 'decoder_layers', 24) - args.share_encoders = getattr(args, 'share_encoders', True) - args.share_decoders = getattr(args, 'share_decoders', True) - args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True) - args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.decoder_embed_dim = 
getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.share_encoders = getattr(args, "share_encoders", True) + args.share_decoders = getattr(args, "share_decoders", True) + args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True) + args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True) base_architecture(args) diff --git a/examples/latent_depth/src/models/latent_transformer.py b/examples/latent_depth/src/models/latent_transformer.py index 5d47340f58..db30239bff 100644 --- a/examples/latent_depth/src/models/latent_transformer.py +++ b/examples/latent_depth/src/models/latent_transformer.py @@ -7,26 +7,27 @@ import torch.nn as nn from fairseq.models.fairseq_encoder import EncoderOut -from fairseq.models.transformer import TransformerEncoder, TransformerDecoder -from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer -from ..modules.latent_layers import LayerSelect +from fairseq.models.transformer import TransformerDecoder, TransformerEncoder +from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer from torch import Tensor +from ..modules.latent_layers import LayerSelect + class LatentTransformerEncoder(TransformerEncoder): """Latent depth (https://arxiv.org/abs/2009.13102) implemented in TransformerEncoder. 
""" + def __init__(self, args, dictionary, embed_tokens, num_logits=1): self.num_logits = num_logits self.num_layers = args.encoder_layers super().__init__(args, dictionary, embed_tokens) self.layer_select = LayerSelect(self.num_layers, self.num_logits, args) self.lang_idx = None - self.layers = nn.ModuleList([ - self._build_encoder_layer(args, idx) - for idx in range(args.encoder_layers) - ]) + self.layers = nn.ModuleList( + [self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)] + ) def set_lang_idx(self, lang_idx): self.lang_idx = lang_idx @@ -50,6 +51,7 @@ class LatentTransformerEncoderLayer(TransformerEncoderLayer): layer_select (LayerSelect, optional): instance of LayerSelect module with logits parameters and sampling method. """ + def __init__(self, args, idx, layer_select=None): super().__init__(args) self.idx = idx @@ -63,7 +65,10 @@ class LatentTransformerDecoder(TransformerDecoder): """Latent depth (https://arxiv.org/abs/2009.13102) implemented in TransformerDecoder. 
""" - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1): + + def __init__( + self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1 + ): self.num_logits = num_logits self.num_layers = args.decoder_layers super().__init__( @@ -71,16 +76,20 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_lo ) self.layer_select = LayerSelect(self.num_layers, self.num_logits, args) self.lang_idx = None - self.layers = nn.ModuleList([ - self._build_decoder_layer(args, no_encoder_attn, idx) - for idx in range(args.decoder_layers) - ]) + self.layers = nn.ModuleList( + [ + self._build_decoder_layer(args, no_encoder_attn, idx) + for idx in range(args.decoder_layers) + ] + ) def set_lang_idx(self, lang_idx): self.lang_idx = lang_idx def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None): - return LatentTransformerDecoderLayer(args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn) + return LatentTransformerDecoderLayer( + args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn + ) def forward( self, @@ -119,8 +128,15 @@ class LatentTransformerDecoderLayer(TransformerDecoderLayer): (default: False). """ + def __init__( - self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + self, + args, + idx, + layer_select=None, + no_encoder_attn=False, + add_bias_kv=False, + add_zero_attn=False, ): super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn) self.idx = idx diff --git a/examples/latent_depth/src/modules/latent_layers.py b/examples/latent_depth/src/modules/latent_layers.py index e772ac3237..a2b8ab4476 100644 --- a/examples/latent_depth/src/modules/latent_layers.py +++ b/examples/latent_depth/src/modules/latent_layers.py @@ -12,6 +12,7 @@ class LayerSelect(nn.Module): either (soft) weighting or (hard) selection of residual connection. 
https://arxiv.org/abs/2009.13102 """ + def __init__(self, num_layers, num_logits, args): super(LayerSelect, self).__init__() self.args = args @@ -27,14 +28,14 @@ def __init__(self, num_layers, num_logits, args): @staticmethod def add_args(parser): parser.add_argument( - '--soft-select', - action='store_true', - help='use soft samples in training an inference' + "--soft-select", + action="store_true", + help="use soft samples in training an inference", ) - parser.add_argument('--sampling-tau', type=float, help='sampling temperature') + parser.add_argument("--sampling-tau", type=float, help="sampling temperature") def sample(self, logit_idx): - """ To leverage the efficiency of distributed training, samples for all + """To leverage the efficiency of distributed training, samples for all layers are computed at once for each logit_idx. Logits are parameters learnt independent of each other. @@ -43,7 +44,9 @@ def sample(self, logit_idx): """ assert logit_idx is not None self.samples = self._gumbel_sigmoid( - self.layer_logits[logit_idx, :].detach() if self.detach_grad else self.layer_logits[logit_idx, :], + self.layer_logits[logit_idx, :].detach() + if self.detach_grad + else self.layer_logits[logit_idx, :], dim=-1, tau=self.tau, hard=self.hard_select, @@ -54,10 +57,20 @@ def forward(self, i): sample = self.samples[i] return sample - def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5): + def _gumbel_sigmoid( + self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5 + ): # ~Gumbel(0,1) - gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() - gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + gumbels1 = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) + .exponential_() + .log() + ) + gumbels2 = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) + .exponential_() + .log() + 
) # Difference of two gumbels because we apply a sigmoid gumbels1 = (logits + gumbels1 - gumbels2) / tau y_soft = gumbels1.sigmoid() diff --git a/examples/latent_depth/src/multilingual_translation_latent_depth.py b/examples/latent_depth/src/multilingual_translation_latent_depth.py index 1a19f8f8f9..b5cd51a470 100644 --- a/examples/latent_depth/src/multilingual_translation_latent_depth.py +++ b/examples/latent_depth/src/multilingual_translation_latent_depth.py @@ -5,10 +5,11 @@ from fairseq.tasks import register_task from fairseq.tasks.multilingual_translation import MultilingualTranslationTask + from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss -@register_task('multilingual_translation_latent_depth') +@register_task("multilingual_translation_latent_depth") class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask): """A task for multiple translation with latent depth. @@ -39,7 +40,9 @@ def add_args(parser): def __init__(self, args, dicts, training): super().__init__(args, dicts, training) - self.src_langs, self.tgt_langs = zip(*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]) + self.src_langs, self.tgt_langs = zip( + *[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs] + ) if self.training and self.encoder_latent_layer: assert self.args.share_encoders if self.training and self.decoder_latent_layer: @@ -47,46 +50,56 @@ def __init__(self, args, dicts, training): if training or self.encoder_latent_layer or self.decoder_latent_layer: self.lang_pairs = args.lang_pairs else: - self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)] + self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] self.eval_lang_pairs = self.lang_pairs self.model_lang_pairs = self.lang_pairs if self.training and (self.encoder_latent_layer or self.decoder_latent_layer): self.kl_loss = LatentLayersKLLoss(self.args) self.sparsity_loss = LatentLayersSparsityLoss(self.args) - def 
_per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad): + def _per_lang_pair_train_loss( + self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad + ): src, tgt = lang_pair.split("-") if self.encoder_latent_layer: src_lang_idx = self.src_lang_idx_dict[src] model.models[lang_pair].encoder.set_lang_idx(src_lang_idx) - model.models[lang_pair].encoder.layer_select.hard_select = update_num > self.args.soft_update + model.models[lang_pair].encoder.layer_select.hard_select = ( + update_num > self.args.soft_update + ) if self.decoder_latent_layer: tgt_lang_idx = self.tgt_lang_idx_dict[tgt] model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx) - model.models[lang_pair].decoder.layer_select.hard_select = update_num > self.args.soft_update + model.models[lang_pair].decoder.layer_select.hard_select = ( + update_num > self.args.soft_update + ) - loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + loss, sample_size, logging_output = criterion( + model.models[lang_pair], sample[lang_pair] + ) if self.encoder_latent_layer: none_samples = sum( - 1 if x is None else 0 for x in model.models[lang_pair].encoder.layer_select.layer_samples + 1 if x is None else 0 + for x in model.models[lang_pair].encoder.layer_select.layer_samples ) if none_samples == 0 or self.args.prior != "agged_posterior": loss += self.kl_loss( model.models[lang_pair].encoder.layer_select.layer_samples, src_lang_idx, update_num, - sample_size + sample_size, ) if self.decoder_latent_layer: none_samples = sum( - 1 if x is None else 0 for x in model.models[lang_pair].decoder.layer_select.layer_samples + 1 if x is None else 0 + for x in model.models[lang_pair].decoder.layer_select.layer_samples ) if none_samples == 0 or self.args.prior != "agged_posterior": loss += self.kl_loss( model.models[lang_pair].decoder.layer_select.layer_samples, tgt_lang_idx, update_num, - sample_size + sample_size, ) if 
ignore_grad: loss *= 0 @@ -99,18 +112,31 @@ def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sam return loss, sample_size, logging_output - def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): agg_loss, agg_sample_size, agg_logging_output = super().train_step( - sample, model, criterion, optimizer, update_num, ignore_grad) + sample, model, criterion, optimizer, update_num, ignore_grad + ) # compute auxiliary loss from layere sparsity, based on all samples from all languages if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num): sparsity_loss = 0 if self.encoder_latent_layer: sparsity_loss += self.sparsity_loss( - next(iter(model.models.values())).encoder.layer_select.layer_samples, update_num, agg_sample_size) + next( + iter(model.models.values()) + ).encoder.layer_select.layer_samples, + update_num, + agg_sample_size, + ) if self.decoder_latent_layer: sparsity_loss += self.sparsity_loss( - next(iter(model.models.values())).decoder.layer_select.layer_samples, update_num, agg_sample_size) + next( + iter(model.models.values()) + ).decoder.layer_select.layer_samples, + update_num, + agg_sample_size, + ) if sparsity_loss > 0: optimizer.backward(sparsity_loss) return agg_loss, agg_sample_size, agg_logging_output @@ -123,10 +149,14 @@ def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample): if self.decoder_latent_layer: tgt_lang_idx = self.tgt_lang_idx_dict[tgt] model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx) - loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + loss, sample_size, logging_output = criterion( + model.models[lang_pair], sample[lang_pair] + ) return loss, sample_size, logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): + def inference_step( + self, 
generator, models, sample, prefix_tokens=None, constraints=None + ): if self.encoder_latent_layer or self.decoder_latent_layer: for model in models: if self.encoder_latent_layer: @@ -137,15 +167,23 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, constrai assert model.decoder.layer_select is not None tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang] model.decoder.set_lang_idx(tgt_lang_idx) - return super().inference_step(generator, models, sample, prefix_tokens, constraints) + return super().inference_step( + generator, models, sample, prefix_tokens, constraints + ) @property def encoder_latent_layer(self): - return hasattr(self.args, "encoder_latent_layer") and self.args.encoder_latent_layer + return ( + hasattr(self.args, "encoder_latent_layer") + and self.args.encoder_latent_layer + ) @property def decoder_latent_layer(self): - return hasattr(self.args, "decoder_latent_layer") and self.args.decoder_latent_layer + return ( + hasattr(self.args, "decoder_latent_layer") + and self.args.decoder_latent_layer + ) @property def src_lang_idx_dict(self): diff --git a/examples/linformer/src/models/linformer_roberta.py b/examples/linformer/src/models/linformer_roberta.py index 722f5a4b9e..913351f238 100644 --- a/examples/linformer/src/models/linformer_roberta.py +++ b/examples/linformer/src/models/linformer_roberta.py @@ -8,37 +8,40 @@ import logging -from fairseq.models import ( - register_model, - register_model_architecture, -) -from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder +from fairseq.models import register_model, register_model_architecture +from fairseq.models.roberta import RobertaEncoder, RobertaModel -from fairseq.models.roberta import ( - RobertaModel, - RobertaEncoder, -) +from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder logger = logging.getLogger(__name__) -@register_model('linformer_roberta') +@register_model("linformer_roberta") class LinformerModel(RobertaModel): - 
@staticmethod def add_args(parser): RobertaModel.add_args(parser) # add args for Linformer - parser.add_argument('--compressed', type=int, - help='compressed ratio of sequence length') - parser.add_argument('--shared-kv-compressed', type=int, - help='share compressed matrix between k and v, in each layer') - parser.add_argument('--shared-layer-kv-compressed', type=int, - help='share compressed matrix between k and v and across all layers') - parser.add_argument('--freeze-compress', type=int, - help='freeze the parameters in compressed layer') + parser.add_argument( + "--compressed", type=int, help="compressed ratio of sequence length" + ) + parser.add_argument( + "--shared-kv-compressed", + type=int, + help="share compressed matrix between k and v, in each layer", + ) + parser.add_argument( + "--shared-layer-kv-compressed", + type=int, + help="share compressed matrix between k and v and across all layers", + ) + parser.add_argument( + "--freeze-compress", + type=int, + help="freeze the parameters in compressed layer", + ) @classmethod def build_model(cls, args, task): @@ -47,7 +50,7 @@ def build_model(cls, args, task): # make sure all arguments are present base_architecture(args) - if not hasattr(args, 'max_positions'): + if not hasattr(args, "max_positions"): args.max_positions = args.tokens_per_sample encoder = LinformerEncoder(args, task.source_dictionary) @@ -85,47 +88,47 @@ def __init__(self, args, dictionary): ) -@register_model_architecture('linformer_roberta', 'linformer_roberta') +@register_model_architecture("linformer_roberta", "linformer_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, 'encoder_layers', 12) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) - - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.pooler_activation_fn = 
getattr(args, 'pooler_activation_fn', 'tanh') - - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_dropout = getattr(args, 'activation_dropout', 0.0) - args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) - args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None) - args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) - args.compressed = getattr(args, 'compressed', 4) - args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0) - args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0) - args.freeze_compress = getattr(args, 'freeze_compress', 0) - - -@register_model_architecture('linformer_roberta', 'linformer_roberta_base') + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.compressed = getattr(args, "compressed", 4) + args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) + args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) + args.freeze_compress = getattr(args, "freeze_compress", 0) + + +@register_model_architecture("linformer_roberta", "linformer_roberta_base") def linformer_roberta_base_architecture(args): 
base_architecture(args) -@register_model_architecture('linformer_roberta', 'linformer_roberta_large') +@register_model_architecture("linformer_roberta", "linformer_roberta_large") def linformer_roberta_large_architecture(args): - args.encoder_layers = getattr(args, 'encoder_layers', 24) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) - - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') - - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_dropout = getattr(args, 'activation_dropout', 0.0) - args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) - args.compressed = getattr(args, 'compressed', 4) - args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0) - args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0) + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.compressed = getattr(args, "compressed", 4) + args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) + args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) diff --git 
a/examples/linformer/src/modules/linformer_sentence_encoder.py b/examples/linformer/src/modules/linformer_sentence_encoder.py index e3d170023d..d6de9eeaae 100644 --- a/examples/linformer/src/modules/linformer_sentence_encoder.py +++ b/examples/linformer/src/modules/linformer_sentence_encoder.py @@ -6,8 +6,8 @@ import math import torch.nn as nn - from fairseq.modules import TransformerSentenceEncoder + from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer @@ -117,7 +117,9 @@ def build_transformer_sentence_encoder_layer( qn_block_size, ): if self.shared_layer_kv_compressed == 1: - compress_layer = nn.Linear(self.max_seq_len, self.max_seq_len // self.compressed) + compress_layer = nn.Linear( + self.max_seq_len, self.max_seq_len // self.compressed + ) # intialize parameters for compressed layer nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2)) if self.freeze_compress == 1: @@ -139,8 +141,7 @@ def build_transformer_sentence_encoder_layer( max_seq_len=self.max_seq_len, shared_kv_compressed=self.shared_kv_compressed, shared_compress_layer=( - None if self.shared_layer_kv_compressed == 0 - else self.compress_layer + None if self.shared_layer_kv_compressed == 0 else self.compress_layer ), freeze_compress=self.freeze_compress, ) @@ -156,7 +157,8 @@ def upgrade_state_dict_named(self, state_dict, name): if self.shared_layer_kv_compressed: for layer_idx in range(len(self.layers)): new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format( - layer_idx, k[len(prefix + 'compress_layer.'):], + layer_idx, + k[len(prefix + "compress_layer.") :], ) items_to_add[new_k] = state_dict[k] diff --git a/examples/linformer/src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/src/modules/linformer_sentence_encoder_layer.py index e0a6047ce8..d27c5afd09 100644 --- a/examples/linformer/src/modules/linformer_sentence_encoder_layer.py +++ b/examples/linformer/src/modules/linformer_sentence_encoder_layer.py @@ -6,6 +6,7 @@ from typing 
import Callable from fairseq.modules import TransformerSentenceEncoderLayer + from .multihead_linear_attention import MultiheadLinearAttention @@ -23,7 +24,7 @@ def __init__( dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, - activation_fn: str = 'relu', + activation_fn: str = "relu", export: bool = False, q_noise: float = 0.0, qn_block_size: int = 8, diff --git a/examples/linformer/src/modules/multihead_linear_attention.py b/examples/linformer/src/modules/multihead_linear_attention.py index 472cd4e3ea..ba2c36b1ef 100644 --- a/examples/linformer/src/modules/multihead_linear_attention.py +++ b/examples/linformer/src/modules/multihead_linear_attention.py @@ -9,10 +9,10 @@ import torch import torch.nn.functional as F from fairseq import utils -from torch import Tensor, nn -from torch.nn import Parameter from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules.quant_noise import quant_noise +from torch import Tensor, nn +from torch.nn import Parameter @with_incremental_state @@ -65,16 +65,24 @@ def __init__( "Self-attention requires query, key and " "value to be of the same size" ) - self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) - self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) - self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) # used for compress sequence to subsequence if shared_compress_layer is None: self.compress_seq_len = max_seq_len // compressed self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False) if shared_kv_compressed == 0: - 
self.compress_v = nn.Linear(max_seq_len, self.compress_seq_len, bias=False) + self.compress_v = nn.Linear( + max_seq_len, self.compress_seq_len, bias=False + ) self.layerwise_sharing = False else: self.compress_k = shared_compress_layer @@ -83,7 +91,9 @@ def __init__( self.layerwise_sharing = True self.shared_kv_compressed = shared_kv_compressed - self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) if add_bias_kv: self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) @@ -116,22 +126,28 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) - if not self.layerwise_sharing: # otherwise, we already initialize the parameters - nn.init.xavier_uniform_(self.compress_k.weight, gain=1/math.sqrt(2)) + if ( + not self.layerwise_sharing + ): # otherwise, we already initialize the parameters + nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2)) if self.shared_kv_compressed == 0: - nn.init.xavier_uniform_(self.compress_v.weight, gain=1/math.sqrt(2)) + nn.init.xavier_uniform_( + self.compress_v.weight, gain=1 / math.sqrt(2) + ) else: nn.init.xavier_uniform_(self.k_proj.weight) nn.init.xavier_uniform_(self.v_proj.weight) nn.init.xavier_uniform_(self.q_proj.weight) - if not self.layerwise_sharing: # otherwise, we already initialize the parameters + if ( + not self.layerwise_sharing + ): # otherwise, we already initialize the parameters nn.init.xavier_uniform_(self.compress_k.weight) if self.shared_kv_compressed == 0: nn.init.xavier_uniform_(self.compress_v.weight) nn.init.xavier_uniform_(self.out_proj.weight) if self.out_proj.bias is not None: - nn.init.constant_(self.out_proj.bias, 0.) 
+ nn.init.constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k) if self.bias_v is not None: @@ -189,14 +205,26 @@ def forward( q = self.q_proj(query) k_input = query.permute(1, 2, 0).contiguous() # B * C * T - k_input = F.linear(k_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous() + k_input = ( + F.linear(k_input, self.compress_k.weight[:, 0:tgt_len]) + .permute(2, 0, 1) + .contiguous() + ) k = self.k_proj(k_input) v_input = query.permute(1, 2, 0).contiguous() # B * C * T if self.shared_kv_compressed == 0: - v_input = F.linear(v_input, self.compress_v.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous() + v_input = ( + F.linear(v_input, self.compress_v.weight[:, 0:tgt_len]) + .permute(2, 0, 1) + .contiguous() + ) if self.shared_kv_compressed == 1: # use shared kv compressed linear layer - v_input = F.linear(v_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous() + v_input = ( + F.linear(v_input, self.compress_k.weight[:, 0:tgt_len]) + .permute(2, 0, 1) + .contiguous() + ) v = self.v_proj(v_input) elif self.encoder_decoder_attention: # encoder-decoder attention @@ -302,7 +330,9 @@ def forward( ) attn_weights = torch.bmm(q, k.transpose(1, 2)) - attn_weights = MultiheadLinearAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + attn_weights = MultiheadLinearAttention.apply_sparse_mask( + attn_weights, tgt_len, src_len, bsz + ) assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] @@ -385,7 +415,9 @@ def _append_prev_key_padding_mask( @torch.jit.export def reorder_incremental_state( - self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, ): """Reorder buffered internal state (for incremental generation).""" input_buffer = self._get_input_buffer(incremental_state) @@ -393,7 +425,9 @@ def reorder_incremental_state( for k in 
input_buffer.keys(): input_buffer_k = input_buffer[k] if input_buffer_k is not None: - if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0): + if self.encoder_decoder_attention and input_buffer_k.size( + 0 + ) == new_order.size(0): break input_buffer[k] = input_buffer_k.index_select(0, new_order) incremental_state = self._set_input_buffer(incremental_state, input_buffer) @@ -428,8 +462,8 @@ def upgrade_state_dict_named(self, state_dict, name): # in_proj_weight used to be q + k + v with same dimensions dim = int(state_dict[k].shape[0] / 3) items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] - items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim:2 * dim] - items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] keys_to_remove.append(k) @@ -438,9 +472,9 @@ def upgrade_state_dict_named(self, state_dict, name): dim = int(state_dict[k].shape[0] / 3) items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ - dim:2 * dim + dim : 2 * dim ] - items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim:] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] keys_to_remove.append(prefix + "in_proj_bias") diff --git a/examples/m2m_100/tokenizers/tokenize_indic.py b/examples/m2m_100/tokenizers/tokenize_indic.py index c1303b3d15..a44fad07f7 100644 --- a/examples/m2m_100/tokenizers/tokenize_indic.py +++ b/examples/m2m_100/tokenizers/tokenize_indic.py @@ -8,14 +8,16 @@ import sys -from indicnlp.tokenize.indic_tokenize import trivial_tokenize from indicnlp.normalize.indic_normalize import IndicNormalizerFactory +from indicnlp.tokenize.indic_tokenize import trivial_tokenize -factory=IndicNormalizerFactory() -normalizer=factory.get_normalizer(sys.argv[1],remove_nuktas=False,nasals_mode='do_nothing') + 
+factory = IndicNormalizerFactory() +normalizer = factory.get_normalizer( + sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing" +) for line in sys.stdin: - normalized_line=normalizer.normalize(line.strip()) - tokenized_line=' '.join(trivial_tokenize(normalized_line, sys.argv[1])) + normalized_line = normalizer.normalize(line.strip()) + tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1])) print(tokenized_line) - diff --git a/examples/m2m_100/tokenizers/tokenize_thai.py b/examples/m2m_100/tokenizers/tokenize_thai.py index 7c7b7ebfaa..9c72cb8905 100644 --- a/examples/m2m_100/tokenizers/tokenize_thai.py +++ b/examples/m2m_100/tokenizers/tokenize_thai.py @@ -8,5 +8,6 @@ from pythainlp import word_tokenize + for line in sys.stdin: print(" ".join(word_tokenize(line.strip()))) diff --git a/examples/m2m_100/tokenizers/tokenize_zh.py b/examples/m2m_100/tokenizers/tokenize_zh.py index 531a7fb49b..674b5849cb 100644 --- a/examples/m2m_100/tokenizers/tokenize_zh.py +++ b/examples/m2m_100/tokenizers/tokenize_zh.py @@ -6,7 +6,9 @@ import fileinput + import sacrebleu + for line in fileinput.input(): print(sacrebleu.tokenize_zh(line)) diff --git a/examples/megatron_11b/detok.py b/examples/megatron_11b/detok.py index a77a0b4960..49921b28a1 100644 --- a/examples/megatron_11b/detok.py +++ b/examples/megatron_11b/detok.py @@ -6,19 +6,27 @@ import argparse import fileinput + import sacremoses def main(): - parser = argparse.ArgumentParser(description='') - parser.add_argument('files', nargs='*', help='input files') + parser = argparse.ArgumentParser(description="") + parser.add_argument("files", nargs="*", help="input files") args = parser.parse_args() detok = sacremoses.MosesDetokenizer() for line in fileinput.input(args.files, openhook=fileinput.hook_compressed): - print(detok.detokenize(line.strip().split(' ')).replace(' @', '').replace('@ ', '').replace(' =', '=').replace('= ', '=').replace(' – ', '–')) + print( + detok.detokenize(line.strip().split(" ")) 
+ .replace(" @", "") + .replace("@ ", "") + .replace(" =", "=") + .replace("= ", "=") + .replace(" – ", "–") + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py index a5927a53b3..4df424e6b5 100644 --- a/examples/noisychannel/rerank.py +++ b/examples/noisychannel/rerank.py @@ -7,21 +7,22 @@ from multiprocessing import Pool import numpy as np - from fairseq import options from fairseq.data import dictionary from fairseq.scoring import bleu from . import ( rerank_generate, + rerank_options, rerank_score_bw, rerank_score_lm, - rerank_options, rerank_utils, ) -def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize): +def score_target_hypo( + args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize +): print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) @@ -61,11 +62,21 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write bitext2_score = None bitext2_backwards = None - score = rerank_utils.get_score(a, b, c, target_len, - bitext1.rescore_score[i], bitext2_score, lm_score=lm_score, - lenpen=lenpen, src_len=bitext1.source_lengths[i], - tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards, - bitext2_backwards=bitext2_backwards, normalize=normalize) + score = rerank_utils.get_score( + a, + b, + c, + target_len, + bitext1.rescore_score[i], + bitext2_score, + lm_score=lm_score, + lenpen=lenpen, + src_len=bitext1.source_lengths[i], + tgt_len=bitext1.target_lengths[i], + bitext1_backwards=bitext1.backwards, + bitext2_backwards=bitext2_backwards, + normalize=normalize, + ) if score > best_score: best_score = score @@ -88,8 +99,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write for key in range(len(gen_keys)): if args.prefix_len is None: assert 
hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], ( - "pred and rescore hypo mismatch: i: " + str(key) + ", " - + str(hypo_lst[key]) + str(gen_keys[key]) + "pred and rescore hypo mismatch: i: " + + str(key) + + ", " + + str(hypo_lst[key]) + + str(gen_keys[key]) + str(gen_output.no_bpe_hypo[key]) ) sys_tok = dict.encode_line(hypo_lst[key]) @@ -97,7 +111,9 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write scorer.add(ref_tok, sys_tok) else: - full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]) + full_hypo = rerank_utils.get_full_from_prefix( + hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]] + ) sys_tok = dict.encode_line(full_hypo) ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]]) scorer.add(ref_tok, sys_tok) @@ -107,20 +123,31 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write # recover the orinal ids from n best list generation for key in range(len(gen_output.no_bpe_target)): if args.prefix_len is None: - assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \ - "pred and rescore hypo mismatch:"+"i:"+str(key)+str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[key]) + assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], ( + "pred and rescore hypo mismatch:" + + "i:" + + str(key) + + str(hypo_lst[key]) + + str(gen_output.no_bpe_hypo[key]) + ) ordered_hypos[gen_keys[key]] = hypo_lst[key] - ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]] + ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[ + gen_keys[key] + ] else: - full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]) + full_hypo = rerank_utils.get_full_from_prefix( + hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]] + ) ordered_hypos[gen_keys[key]] = full_hypo - ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]] + ordered_targets[gen_keys[key]] = 
gen_output.no_bpe_target[ + gen_keys[key] + ] # write the hypos in the original order from nbest list generation if args.num_shards == (len(bitext1_lst)): - with open(target_outfile, 'w') as t: - with open(hypo_outfile, 'w') as h: + with open(target_outfile, "w") as t: + with open(hypo_outfile, "w") as h: for key in range(len(ordered_hypos)): t.write(ordered_targets[key]) h.write(ordered_hypos[key]) @@ -135,17 +162,38 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write def match_target_hypo(args, target_outfile, hypo_outfile): """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file""" if len(args.weight1) == 1: - res = score_target_hypo(args, args.weight1[0], args.weight2[0], - args.weight3[0], args.lenpen[0], target_outfile, - hypo_outfile, True, args.normalize) + res = score_target_hypo( + args, + args.weight1[0], + args.weight2[0], + args.weight3[0], + args.lenpen[0], + target_outfile, + hypo_outfile, + True, + args.normalize, + ) rerank_scores = [res] else: print("launching pool") with Pool(32) as p: - rerank_scores = p.starmap(score_target_hypo, - [(args, args.weight1[i], args.weight2[i], args.weight3[i], - args.lenpen[i], target_outfile, hypo_outfile, - False, args.normalize) for i in range(len(args.weight1))]) + rerank_scores = p.starmap( + score_target_hypo, + [ + ( + args, + args.weight1[i], + args.weight2[i], + args.weight3[i], + args.lenpen[i], + target_outfile, + hypo_outfile, + False, + args.normalize, + ) + for i in range(len(args.weight1)) + ], + ) if len(rerank_scores) > 1: best_index = np.argmax(rerank_scores) @@ -155,11 +203,22 @@ def match_target_hypo(args, target_outfile, hypo_outfile): print("best weight1", args.weight1[best_index]) print("best weight2", args.weight2[best_index]) print("best weight3", args.weight3[best_index]) - return args.lenpen[best_index], args.weight1[best_index], \ - args.weight2[best_index], args.weight3[best_index], best_score + return ( + 
args.lenpen[best_index], + args.weight1[best_index], + args.weight2[best_index], + args.weight3[best_index], + best_score, + ) else: - return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0] + return ( + args.lenpen[0], + args.weight1[0], + args.weight2[0], + args.weight3[0], + rerank_scores[0], + ) def load_score_files(args): @@ -175,55 +234,100 @@ def load_score_files(args): for shard_id in shard_ids: using_nbest = args.nbest_list is not None - pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ - backwards_preprocessed_dir, lm_preprocessed_dir = \ - rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, - args.gen_model_name, shard_id, args.num_shards, args.sampling, - args.prefix_len, args.target_prefix_frac, args.source_prefix_frac) - - rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None - rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None - - score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name, - target_prefix_frac=args.target_prefix_frac, - source_prefix_frac=args.source_prefix_frac, - backwards=args.backwards1) + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + rerank1_is_gen = ( + args.gen_model == args.score_model1 and args.source_prefix_frac is None + ) + rerank2_is_gen = ( + args.gen_model == args.score_model2 and args.source_prefix_frac is None + ) + + score1_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model1_name, + target_prefix_frac=args.target_prefix_frac, + 
source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1, + ) if args.score_model2 is not None: - score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name, - target_prefix_frac=args.target_prefix_frac, - source_prefix_frac=args.source_prefix_frac, - backwards=args.backwards2) + score2_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2, + ) if args.language_model is not None: - lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True) + lm_score_file = rerank_utils.rescore_file_name( + pre_gen, args.prefix_len, args.lm_name, lm_file=True + ) # get gen output - predictions_bpe_file = pre_gen+"/generate_output_bpe.txt" + predictions_bpe_file = pre_gen + "/generate_output_bpe.txt" if using_nbest: print("Using predefined n-best list from interactive.py") predictions_bpe_file = args.nbest_list - gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, - nbest=using_nbest, prefix_len=args.prefix_len, - target_prefix_frac=args.target_prefix_frac) + gen_output = rerank_utils.BitextOutputFromGen( + predictions_bpe_file, + bpe_symbol=args.remove_bpe, + nbest=using_nbest, + prefix_len=args.prefix_len, + target_prefix_frac=args.target_prefix_frac, + ) if rerank1_is_gen: bitext1 = gen_output else: - bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1, - args.remove_bpe, args.prefix_len, args.target_prefix_frac, - args.source_prefix_frac) + bitext1 = rerank_utils.BitextOutput( + score1_file, + args.backwards1, + args.right_to_left1, + args.remove_bpe, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) if args.score_model2 is not None or args.nbest_list is not None: if rerank2_is_gen: bitext2 = gen_output else: - bitext2 = 
rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2, - args.remove_bpe, args.prefix_len, args.target_prefix_frac, - args.source_prefix_frac) - - assert bitext2.source_lengths == bitext1.source_lengths, \ - "source lengths for rescoring models do not match" - assert bitext2.target_lengths == bitext1.target_lengths, \ - "target lengths for rescoring models do not match" + bitext2 = rerank_utils.BitextOutput( + score2_file, + args.backwards2, + args.right_to_left2, + args.remove_bpe, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + assert ( + bitext2.source_lengths == bitext1.source_lengths + ), "source lengths for rescoring models do not match" + assert ( + bitext2.target_lengths == bitext1.target_lengths + ), "target lengths for rescoring models do not match" else: if args.diff_bpe: assert args.score_model2 is None @@ -232,8 +336,13 @@ def load_score_files(args): bitext2 = None if args.language_model is not None: - lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len, - args.remove_bpe, args.target_prefix_frac) + lm_res1 = rerank_utils.LMOutput( + lm_score_file, + args.lm_dict, + args.prefix_len, + args.remove_bpe, + args.target_prefix_frac, + ) else: lm_res1 = None @@ -259,28 +368,46 @@ def rerank(args): shard_ids = [args.shard_id] for shard_id in shard_ids: - pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ - backwards_preprocessed_dir, lm_preprocessed_dir = \ - rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, - args.gen_model_name, shard_id, args.num_shards, args.sampling, - args.prefix_len, args.target_prefix_frac, args.source_prefix_frac) + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + shard_id, + args.num_shards, + 
args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) rerank_generate.gen_and_reprocess_nbest(args) rerank_score_bw.score_bw(args) rerank_score_lm.score_lm(args) if args.write_hypos is None: - write_targets = pre_gen+"/matched_targets" - write_hypos = pre_gen+"/matched_hypos" + write_targets = pre_gen + "/matched_targets" + write_hypos = pre_gen + "/matched_hypos" else: - write_targets = args.write_hypos+"_targets" + args.gen_subset - write_hypos = args.write_hypos+"_hypos" + args.gen_subset + write_targets = args.write_hypos + "_targets" + args.gen_subset + write_hypos = args.write_hypos + "_hypos" + args.gen_subset if args.all_shards: write_targets += "_all_shards" write_hypos += "_all_shards" - best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \ - match_target_hypo(args, write_targets, write_hypos) + ( + best_lenpen, + best_weight1, + best_weight2, + best_weight3, + best_score, + ) = match_target_hypo(args, write_targets, write_hypos) return best_lenpen, best_weight1, best_weight2, best_weight3, best_score @@ -291,5 +418,5 @@ def cli_main(): rerank(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/examples/noisychannel/rerank_generate.py b/examples/noisychannel/rerank_generate.py index d2da6eacf9..4356b3387e 100644 --- a/examples/noisychannel/rerank_generate.py +++ b/examples/noisychannel/rerank_generate.py @@ -8,9 +8,9 @@ Generate n-best translations using a trained model. 
""" -from contextlib import redirect_stdout import os import subprocess +from contextlib import redirect_stdout from fairseq import options from fairseq_cli import generate, preprocess @@ -22,8 +22,12 @@ def gen_and_reprocess_nbest(args): if args.score_dict_dir is None: args.score_dict_dir = args.data if args.prefix_len is not None: - assert args.right_to_left1 is False, "prefix length not compatible with right to left models" - assert args.right_to_left2 is False, "prefix length not compatible with right to left models" + assert ( + args.right_to_left1 is False + ), "prefix length not compatible with right to left models" + assert ( + args.right_to_left2 is False + ), "prefix length not compatible with right to left models" if args.nbest_list is not None: assert args.score_model2 is None @@ -35,27 +39,50 @@ def gen_and_reprocess_nbest(args): scorer1_src = args.source_lang scorer1_tgt = args.target_lang - store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name + store_data = ( + os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name + ) if not os.path.exists(store_data): os.makedirs(store_data) - pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ - backwards_preprocessed_dir, lm_preprocessed_dir = \ - rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, - args.gen_model_name, args.shard_id, args.num_shards, - args.sampling, args.prefix_len, args.target_prefix_frac, - args.source_prefix_frac) - assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported" - assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported" - assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \ - "target prefix frac and target prefix len incompatible" + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + 
) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + args.shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + assert not ( + args.right_to_left1 and args.backwards1 + ), "backwards right to left not supported" + assert not ( + args.right_to_left2 and args.backwards2 + ), "backwards right to left not supported" + assert not ( + args.prefix_len is not None and args.target_prefix_frac is not None + ), "target prefix frac and target prefix len incompatible" # make directory to store generation results if not os.path.exists(pre_gen): os.makedirs(pre_gen) - rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None - rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None + rerank1_is_gen = ( + args.gen_model == args.score_model1 and args.source_prefix_frac is None + ) + rerank2_is_gen = ( + args.gen_model == args.score_model2 and args.source_prefix_frac is None + ) if args.nbest_list is not None: rerank2_is_gen = True @@ -70,17 +97,25 @@ def gen_and_reprocess_nbest(args): if not os.path.exists(backwards_preprocessed_dir): os.makedirs(backwards_preprocessed_dir) - score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name, - target_prefix_frac=args.target_prefix_frac, - source_prefix_frac=args.source_prefix_frac, - backwards=args.backwards1) + score1_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1, + ) if args.score_model2 is not None: - score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name, - target_prefix_frac=args.target_prefix_frac, - source_prefix_frac=args.source_prefix_frac, - backwards=args.backwards2) + score2_file = 
rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2, + ) - predictions_bpe_file = pre_gen+"/generate_output_bpe.txt" + predictions_bpe_file = pre_gen + "/generate_output_bpe.txt" using_nbest = args.nbest_list is not None @@ -92,17 +127,29 @@ def gen_and_reprocess_nbest(args): if not os.path.isfile(predictions_bpe_file): print("STEP 1: generate predictions using the p(T|S) model with bpe") print(args.data) - param1 = [args.data, - "--path", args.gen_model, - "--shard-id", str(args.shard_id), - "--num-shards", str(args.num_shards), - "--nbest", str(args.num_rescore), - "--batch-size", str(args.batch_size), - "--beam", str(args.num_rescore), - "--batch-size", str(args.num_rescore), - "--gen-subset", args.gen_subset, - "--source-lang", args.source_lang, - "--target-lang", args.target_lang] + param1 = [ + args.data, + "--path", + args.gen_model, + "--shard-id", + str(args.shard_id), + "--num-shards", + str(args.num_shards), + "--nbest", + str(args.num_rescore), + "--batch-size", + str(args.batch_size), + "--beam", + str(args.num_rescore), + "--batch-size", + str(args.num_rescore), + "--gen-subset", + args.gen_subset, + "--source-lang", + args.source_lang, + "--target-lang", + args.target_lang, + ] if args.sampling: param1 += ["--sampling"] @@ -110,124 +157,229 @@ def gen_and_reprocess_nbest(args): input_args = options.parse_args_and_arch(gen_parser, param1) print(input_args) - with open(predictions_bpe_file, 'w') as f: + with open(predictions_bpe_file, "w") as f: with redirect_stdout(f): generate.main(input_args) - gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, - nbest=using_nbest, prefix_len=args.prefix_len, - target_prefix_frac=args.target_prefix_frac) + gen_output = rerank_utils.BitextOutputFromGen( + predictions_bpe_file, + bpe_symbol=args.remove_bpe, + 
nbest=using_nbest, + prefix_len=args.prefix_len, + target_prefix_frac=args.target_prefix_frac, + ) if args.diff_bpe: - rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo, - gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang, - pre_gen+"/target_gen_bpe."+args.target_lang, - pre_gen+"/reference_gen_bpe."+args.target_lang) + rerank_utils.write_reprocessed( + gen_output.no_bpe_source, + gen_output.no_bpe_hypo, + gen_output.no_bpe_target, + pre_gen + "/source_gen_bpe." + args.source_lang, + pre_gen + "/target_gen_bpe." + args.target_lang, + pre_gen + "/reference_gen_bpe." + args.target_lang, + ) bitext_bpe = args.rescore_bpe_code - bpe_src_param = ["-c", bitext_bpe, - "--input", pre_gen+"/source_gen_bpe."+args.source_lang, - "--output", pre_gen+"/rescore_data."+args.source_lang] - bpe_tgt_param = ["-c", bitext_bpe, - "--input", pre_gen+"/target_gen_bpe."+args.target_lang, - "--output", pre_gen+"/rescore_data."+args.target_lang] - - subprocess.call(["python", - os.path.join(os.path.dirname(__file__), - "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param, - shell=False) - - subprocess.call(["python", - os.path.join(os.path.dirname(__file__), - "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_tgt_param, - shell=False) - - if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \ - (args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen): - print("STEP 2: process the output of generate.py so we have clean text files with the translations") + bpe_src_param = [ + "-c", + bitext_bpe, + "--input", + pre_gen + "/source_gen_bpe." + args.source_lang, + "--output", + pre_gen + "/rescore_data." + args.source_lang, + ] + bpe_tgt_param = [ + "-c", + bitext_bpe, + "--input", + pre_gen + "/target_gen_bpe." + args.target_lang, + "--output", + pre_gen + "/rescore_data." 
+ args.target_lang, + ] + + subprocess.call( + [ + "python", + os.path.join( + os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py" + ), + ] + + bpe_src_param, + shell=False, + ) + + subprocess.call( + [ + "python", + os.path.join( + os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py" + ), + ] + + bpe_tgt_param, + shell=False, + ) + + if (not os.path.isfile(score1_file) and not rerank1_is_gen) or ( + args.score_model2 is not None + and not os.path.isfile(score2_file) + and not rerank2_is_gen + ): + print( + "STEP 2: process the output of generate.py so we have clean text files with the translations" + ) rescore_file = "/rescore_data" if args.prefix_len is not None: - prefix_len_rescore_file = rescore_file + "prefix"+str(args.prefix_len) + prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len) if args.target_prefix_frac is not None: - target_prefix_frac_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac) + target_prefix_frac_rescore_file = ( + rescore_file + "target_prefix_frac" + str(args.target_prefix_frac) + ) if args.source_prefix_frac is not None: - source_prefix_frac_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac) + source_prefix_frac_rescore_file = ( + rescore_file + "source_prefix_frac" + str(args.source_prefix_frac) + ) if not args.right_to_left1 or not args.right_to_left2: if not args.diff_bpe: - rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, - pre_gen+rescore_file+"."+args.source_lang, - pre_gen+rescore_file+"."+args.target_lang, - pre_gen+"/reference_file", bpe_symbol=args.remove_bpe) + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + rescore_file + "." + args.source_lang, + pre_gen + rescore_file + "." 
+ args.target_lang, + pre_gen + "/reference_file", + bpe_symbol=args.remove_bpe, + ) if args.prefix_len is not None: bw_rescore_file = prefix_len_rescore_file - rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, - pre_gen+prefix_len_rescore_file+"."+args.source_lang, - pre_gen+prefix_len_rescore_file+"."+args.target_lang, - pre_gen+"/reference_file", prefix_len=args.prefix_len, - bpe_symbol=args.remove_bpe) + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + prefix_len_rescore_file + "." + args.source_lang, + pre_gen + prefix_len_rescore_file + "." + args.target_lang, + pre_gen + "/reference_file", + prefix_len=args.prefix_len, + bpe_symbol=args.remove_bpe, + ) elif args.target_prefix_frac is not None: bw_rescore_file = target_prefix_frac_rescore_file - rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, - pre_gen+target_prefix_frac_rescore_file+"."+args.source_lang, - pre_gen+target_prefix_frac_rescore_file+"."+args.target_lang, - pre_gen+"/reference_file", bpe_symbol=args.remove_bpe, - target_prefix_frac=args.target_prefix_frac) + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + + target_prefix_frac_rescore_file + + "." + + args.source_lang, + pre_gen + + target_prefix_frac_rescore_file + + "." 
+ + args.target_lang, + pre_gen + "/reference_file", + bpe_symbol=args.remove_bpe, + target_prefix_frac=args.target_prefix_frac, + ) else: bw_rescore_file = rescore_file if args.source_prefix_frac is not None: fw_rescore_file = source_prefix_frac_rescore_file - rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, - pre_gen+source_prefix_frac_rescore_file+"."+args.source_lang, - pre_gen+source_prefix_frac_rescore_file+"."+args.target_lang, - pre_gen+"/reference_file", bpe_symbol=args.remove_bpe, - source_prefix_frac=args.source_prefix_frac) + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + + source_prefix_frac_rescore_file + + "." + + args.source_lang, + pre_gen + + source_prefix_frac_rescore_file + + "." + + args.target_lang, + pre_gen + "/reference_file", + bpe_symbol=args.remove_bpe, + source_prefix_frac=args.source_prefix_frac, + ) else: fw_rescore_file = rescore_file if args.right_to_left1 or args.right_to_left2: - rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, - pre_gen+"/right_to_left_rescore_data."+args.source_lang, - pre_gen+"/right_to_left_rescore_data."+args.target_lang, - pre_gen+"/right_to_left_reference_file", - right_to_left=True, bpe_symbol=args.remove_bpe) + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + "/right_to_left_rescore_data." + args.source_lang, + pre_gen + "/right_to_left_rescore_data." 
+ args.target_lang, + pre_gen + "/right_to_left_reference_file", + right_to_left=True, + bpe_symbol=args.remove_bpe, + ) print("STEP 3: binarize the translations") - if not args.right_to_left1 or args.score_model2 is not None and not args.right_to_left2 or not rerank1_is_gen: + if ( + not args.right_to_left1 + or args.score_model2 is not None + and not args.right_to_left2 + or not rerank1_is_gen + ): if args.backwards1 or args.backwards2: if args.backwards_score_dict_dir is not None: bw_dict = args.backwards_score_dict_dir else: bw_dict = args.score_dict_dir - bw_preprocess_param = ["--source-lang", scorer1_src, - "--target-lang", scorer1_tgt, - "--trainpref", pre_gen+bw_rescore_file, - "--srcdict", bw_dict + "/dict." + scorer1_src + ".txt", - "--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt", - "--destdir", backwards_preprocessed_dir] + bw_preprocess_param = [ + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + "--trainpref", + pre_gen + bw_rescore_file, + "--srcdict", + bw_dict + "/dict." + scorer1_src + ".txt", + "--tgtdict", + bw_dict + "/dict." + scorer1_tgt + ".txt", + "--destdir", + backwards_preprocessed_dir, + ] preprocess_parser = options.get_preprocessing_parser() input_args = preprocess_parser.parse_args(bw_preprocess_param) preprocess.main(input_args) - preprocess_param = ["--source-lang", scorer1_src, - "--target-lang", scorer1_tgt, - "--trainpref", pre_gen+fw_rescore_file, - "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt", - "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt", - "--destdir", left_to_right_preprocessed_dir] + preprocess_param = [ + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + "--trainpref", + pre_gen + fw_rescore_file, + "--srcdict", + args.score_dict_dir + "/dict." + scorer1_src + ".txt", + "--tgtdict", + args.score_dict_dir + "/dict." 
+ scorer1_tgt + ".txt", + "--destdir", + left_to_right_preprocessed_dir, + ] preprocess_parser = options.get_preprocessing_parser() input_args = preprocess_parser.parse_args(preprocess_param) preprocess.main(input_args) if args.right_to_left1 or args.right_to_left2: - preprocess_param = ["--source-lang", scorer1_src, - "--target-lang", scorer1_tgt, - "--trainpref", pre_gen+"/right_to_left_rescore_data", - "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt", - "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt", - "--destdir", right_to_left_preprocessed_dir] + preprocess_param = [ + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + "--trainpref", + pre_gen + "/right_to_left_rescore_data", + "--srcdict", + args.score_dict_dir + "/dict." + scorer1_src + ".txt", + "--tgtdict", + args.score_dict_dir + "/dict." + scorer1_tgt + ".txt", + "--destdir", + right_to_left_preprocessed_dir, + ] preprocess_parser = options.get_preprocessing_parser() input_args = preprocess_parser.parse_args(preprocess_param) preprocess.main(input_args) @@ -241,5 +393,5 @@ def cli_main(): gen_and_reprocess_nbest(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/examples/noisychannel/rerank_options.py b/examples/noisychannel/rerank_options.py index a425fb295b..ca7a2e0a61 100644 --- a/examples/noisychannel/rerank_options.py +++ b/examples/noisychannel/rerank_options.py @@ -6,14 +6,14 @@ from fairseq import options -def get_reranking_parser(default_task='translation'): - parser = options.get_parser('Generation and reranking', default_task) +def get_reranking_parser(default_task="translation"): + parser = options.get_parser("Generation and reranking", default_task) add_reranking_args(parser) return parser -def get_tuning_parser(default_task='translation'): - parser = options.get_parser('Reranking tuning', default_task) +def get_tuning_parser(default_task="translation"): + parser = options.get_parser("Reranking tuning", 
default_task) add_reranking_args(parser) add_tuning_args(parser) return parser @@ -110,17 +110,40 @@ def add_reranking_args(parser): def add_tuning_args(parser): group = parser.add_argument_group("Tuning") - group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float, - help='lower bound of search space') - group.add_argument('--upper-bound', default=[3], nargs='+', type=float, - help='upper bound of search space') - group.add_argument('--tune-param', default=['lenpen'], nargs='+', - choices=['lenpen', 'weight1', 'weight2', 'weight3'], - help='the parameter(s) to tune') - group.add_argument('--tune-subset', default='valid', choices=['valid', 'test', 'train'], - help='the subset to tune on ') - group.add_argument('--num-trials', default=1000, type=int, - help='number of trials to do for random search') - group.add_argument('--share-weights', action='store_true', - help='share weight2 and weight 3') + group.add_argument( + "--lower-bound", + default=[-0.7], + nargs="+", + type=float, + help="lower bound of search space", + ) + group.add_argument( + "--upper-bound", + default=[3], + nargs="+", + type=float, + help="upper bound of search space", + ) + group.add_argument( + "--tune-param", + default=["lenpen"], + nargs="+", + choices=["lenpen", "weight1", "weight2", "weight3"], + help="the parameter(s) to tune", + ) + group.add_argument( + "--tune-subset", + default="valid", + choices=["valid", "test", "train"], + help="the subset to tune on ", + ) + group.add_argument( + "--num-trials", + default=1000, + type=int, + help="number of trials to do for random search", + ) + group.add_argument( + "--share-weights", action="store_true", help="share weight2 and weight 3" + ) return group diff --git a/examples/noisychannel/rerank_score_bw.py b/examples/noisychannel/rerank_score_bw.py index 6a875e9fe3..895673b1cc 100644 --- a/examples/noisychannel/rerank_score_bw.py +++ b/examples/noisychannel/rerank_score_bw.py @@ -3,8 +3,8 @@ # This source code is licensed under 
the MIT license found in the # LICENSE file in the root directory of this source tree. -from contextlib import redirect_stdout import os +from contextlib import redirect_stdout from fairseq import options from fairseq_cli import generate @@ -13,82 +13,124 @@ def score_bw(args): - if args.backwards1: - scorer1_src = args.target_lang - scorer1_tgt = args.source_lang + if args.backwards1: + scorer1_src = args.target_lang + scorer1_tgt = args.source_lang + else: + scorer1_src = args.source_lang + scorer1_tgt = args.target_lang + + if args.score_model2 is not None: + if args.backwards2: + scorer2_src = args.target_lang + scorer2_tgt = args.source_lang else: - scorer1_src = args.source_lang - scorer1_tgt = args.target_lang - - if args.score_model2 is not None: - if args.backwards2: - scorer2_src = args.target_lang - scorer2_tgt = args.source_lang - else: - scorer2_src = args.source_lang - scorer2_tgt = args.target_lang - - rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None - rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None - - pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ - backwards_preprocessed_dir, lm_preprocessed_dir = \ - rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, - args.gen_model_name, args.shard_id, args.num_shards, - args.sampling, args.prefix_len, args.target_prefix_frac, - args.source_prefix_frac) - - score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name, - target_prefix_frac=args.target_prefix_frac, - source_prefix_frac=args.source_prefix_frac, - backwards=args.backwards1) - - if args.score_model2 is not None: - score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name, - target_prefix_frac=args.target_prefix_frac, - source_prefix_frac=args.source_prefix_frac, - backwards=args.backwards2) - - if args.right_to_left1: - rerank_data1 = 
right_to_left_preprocessed_dir - elif args.backwards1: - rerank_data1 = backwards_preprocessed_dir + scorer2_src = args.source_lang + scorer2_tgt = args.target_lang + + rerank1_is_gen = ( + args.gen_model == args.score_model1 and args.source_prefix_frac is None + ) + rerank2_is_gen = ( + args.gen_model == args.score_model2 and args.source_prefix_frac is None + ) + + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + args.shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + score1_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1, + ) + + if args.score_model2 is not None: + score2_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2, + ) + + if args.right_to_left1: + rerank_data1 = right_to_left_preprocessed_dir + elif args.backwards1: + rerank_data1 = backwards_preprocessed_dir + else: + rerank_data1 = left_to_right_preprocessed_dir + + gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"] + if not rerank1_is_gen and not os.path.isfile(score1_file): + print("STEP 4: score the translations for model 1") + + model_param1 = [ + "--path", + args.score_model1, + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + ] + gen_model1_param = [rerank_data1] + gen_param + model_param1 + + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, gen_model1_param) + + with open(score1_file, "w") as f: + with 
redirect_stdout(f): + generate.main(input_args) + + if ( + args.score_model2 is not None + and not os.path.isfile(score2_file) + and not rerank2_is_gen + ): + print("STEP 4: score the translations for model 2") + + if args.right_to_left2: + rerank_data2 = right_to_left_preprocessed_dir + elif args.backwards2: + rerank_data2 = backwards_preprocessed_dir else: - rerank_data1 = left_to_right_preprocessed_dir - - gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"] - if not rerank1_is_gen and not os.path.isfile(score1_file): - print("STEP 4: score the translations for model 1") - - model_param1 = ["--path", args.score_model1, "--source-lang", scorer1_src, "--target-lang", scorer1_tgt] - gen_model1_param = [rerank_data1] + gen_param + model_param1 - - gen_parser = options.get_generation_parser() - input_args = options.parse_args_and_arch(gen_parser, gen_model1_param) - - with open(score1_file, 'w') as f: - with redirect_stdout(f): - generate.main(input_args) - - if args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen: - print("STEP 4: score the translations for model 2") - - if args.right_to_left2: - rerank_data2 = right_to_left_preprocessed_dir - elif args.backwards2: - rerank_data2 = backwards_preprocessed_dir - else: - rerank_data2 = left_to_right_preprocessed_dir + rerank_data2 = left_to_right_preprocessed_dir - model_param2 = ["--path", args.score_model2, "--source-lang", scorer2_src, "--target-lang", scorer2_tgt] - gen_model2_param = [rerank_data2] + gen_param + model_param2 + model_param2 = [ + "--path", + args.score_model2, + "--source-lang", + scorer2_src, + "--target-lang", + scorer2_tgt, + ] + gen_model2_param = [rerank_data2] + gen_param + model_param2 - gen_parser = options.get_generation_parser() - input_args = options.parse_args_and_arch(gen_parser, gen_model2_param) + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, gen_model2_param) - 
with open(score2_file, 'w') as f: - with redirect_stdout(f): - generate.main(input_args) + with open(score2_file, "w") as f: + with redirect_stdout(f): + generate.main(input_args) def cli_main(): @@ -97,5 +139,5 @@ def cli_main(): score_bw(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/examples/noisychannel/rerank_score_lm.py b/examples/noisychannel/rerank_score_lm.py index 74b858e3c8..fa3aa64462 100644 --- a/examples/noisychannel/rerank_score_lm.py +++ b/examples/noisychannel/rerank_score_lm.py @@ -12,22 +12,38 @@ def score_lm(args): using_nbest = args.nbest_list is not None - pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ - backwards_preprocessed_dir, lm_preprocessed_dir = \ - rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, - args.gen_model_name, args.shard_id, args.num_shards, - args.sampling, args.prefix_len, args.target_prefix_frac, - args.source_prefix_frac) - - predictions_bpe_file = pre_gen+"/generate_output_bpe.txt" + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + args.shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + predictions_bpe_file = pre_gen + "/generate_output_bpe.txt" if using_nbest: print("Using predefined n-best list from interactive.py") predictions_bpe_file = args.nbest_list - gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest) + gen_output = rerank_utils.BitextOutputFromGen( + predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest + ) if args.language_model is not None: - lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True) + lm_score_file 
= rerank_utils.rescore_file_name( + pre_gen, args.prefix_len, args.lm_name, lm_file=True + ) if args.language_model is not None and not os.path.isfile(lm_score_file): print("STEP 4.5: language modeling for P(T)") @@ -38,10 +54,21 @@ def score_lm(args): else: bpe_status = "different" - rerank_utils.lm_scoring(lm_preprocessed_dir, bpe_status, gen_output, pre_gen, - args.lm_dict, args.lm_name, args.language_model, - args.lm_bpe_code, 128, lm_score_file, args.target_lang, - args.source_lang, prefix_len=args.prefix_len) + rerank_utils.lm_scoring( + lm_preprocessed_dir, + bpe_status, + gen_output, + pre_gen, + args.lm_dict, + args.lm_name, + args.language_model, + args.lm_bpe_code, + 128, + lm_score_file, + args.target_lang, + args.source_lang, + prefix_len=args.prefix_len, + ) def cli_main(): @@ -50,5 +77,5 @@ def cli_main(): score_lm(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/examples/noisychannel/rerank_tune.py b/examples/noisychannel/rerank_tune.py index 789096b3fa..1be71744a3 100644 --- a/examples/noisychannel/rerank_tune.py +++ b/examples/noisychannel/rerank_tune.py @@ -5,8 +5,8 @@ import argparse import random -import numpy as np +import numpy as np from fairseq import options from . 
import rerank, rerank_options @@ -14,7 +14,7 @@ def random_search(args): param_values = [] - tuneable_parameters = ['lenpen', 'weight1', 'weight2', 'weight3'] + tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"] initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3] for i, elem in enumerate(initial_params): if type(elem) is not list: @@ -33,51 +33,60 @@ def random_search(args): param_values += initial_params random.seed(args.seed) - random_params = np.array([ - [random.uniform(args.lower_bound[i], args.upper_bound[i]) for i in range(len(args.tune_param))] - for k in range(args.num_trials) - ]) - set_params = np.array([ - [initial_params[i][0] for i in range(len(tuneable_parameters))] - for k in range(args.num_trials) - ]) + random_params = np.array( + [ + [ + random.uniform(args.lower_bound[i], args.upper_bound[i]) + for i in range(len(args.tune_param)) + ] + for k in range(args.num_trials) + ] + ) + set_params = np.array( + [ + [initial_params[i][0] for i in range(len(tuneable_parameters))] + for k in range(args.num_trials) + ] + ) random_params = np.concatenate((random_params, set_params), 1) rerank_args = vars(args).copy() if args.nbest_list: - rerank_args['gen_subset'] = 'test' + rerank_args["gen_subset"] = "test" else: - rerank_args['gen_subset'] = args.tune_subset + rerank_args["gen_subset"] = args.tune_subset for k in range(len(tune_parameters)): rerank_args[tune_parameters[k]] = list(random_params[:, k]) if args.share_weights: - k = tune_parameters.index('weight2') - rerank_args['weight3'] = list(random_params[:, k]) + k = tune_parameters.index("weight2") + rerank_args["weight3"] = list(random_params[:, k]) rerank_args = argparse.Namespace(**rerank_args) - best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(rerank_args) + best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank( + rerank_args + ) rerank_args = vars(args).copy() - rerank_args['lenpen'] = [best_lenpen] - 
rerank_args['weight1'] = [best_weight1] - rerank_args['weight2'] = [best_weight2] - rerank_args['weight3'] = [best_weight3] + rerank_args["lenpen"] = [best_lenpen] + rerank_args["weight1"] = [best_weight1] + rerank_args["weight2"] = [best_weight2] + rerank_args["weight3"] = [best_weight3] # write the hypothesis from the valid set from the best trial if args.gen_subset != "valid": - rerank_args['gen_subset'] = "valid" + rerank_args["gen_subset"] = "valid" rerank_args = argparse.Namespace(**rerank_args) rerank.rerank(rerank_args) # test with the best hyperparameters on gen subset rerank_args = vars(args).copy() - rerank_args['gen_subset'] = args.gen_subset - rerank_args['lenpen'] = [best_lenpen] - rerank_args['weight1'] = [best_weight1] - rerank_args['weight2'] = [best_weight2] - rerank_args['weight3'] = [best_weight3] + rerank_args["gen_subset"] = args.gen_subset + rerank_args["lenpen"] = [best_lenpen] + rerank_args["weight1"] = [best_weight1] + rerank_args["weight2"] = [best_weight2] + rerank_args["weight3"] = [best_weight3] rerank_args = argparse.Namespace(**rerank_args) rerank.rerank(rerank_args) @@ -89,5 +98,5 @@ def cli_main(): random_search(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/examples/noisychannel/rerank_utils.py b/examples/noisychannel/rerank_utils.py index e1fcf918c5..2c6bf1b1af 100644 --- a/examples/noisychannel/rerank_utils.py +++ b/examples/noisychannel/rerank_utils.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from contextlib import redirect_stdout import math import os import re import subprocess +from contextlib import redirect_stdout from fairseq import options from fairseq_cli import eval_lm, preprocess @@ -20,7 +20,7 @@ def reprocess(fle): # per source, so the values for hypothesis_dict are lists. 
# parses output of generate.py - with open(fle, 'r') as f: + with open(fle, "r") as f: txt = f.read() """reprocess generate.py output""" @@ -45,7 +45,9 @@ def reprocess(fle): if line_type == "H": h_txt = line[j:] hypo = re.search(hp, h_txt) - assert hypo is not None, ("regular expression failed to find the hypothesis scoring") + assert ( + hypo is not None + ), "regular expression failed to find the hypothesis scoring" _, i = hypo.span() score = hypo.group() if id_num in hypothesis_dict: @@ -56,9 +58,9 @@ def reprocess(fle): score_dict[id_num] = [float(score)] elif line_type == "S": - source_dict[id_num] = (line[j:]) + source_dict[id_num] = line[j:] elif line_type == "T": - target_dict[id_num] = (line[j:]) + target_dict[id_num] = line[j:] elif line_type == "P": pos_scores = (line[j:]).split() pos_scores = [float(x) for x in pos_scores] @@ -72,7 +74,7 @@ def reprocess(fle): def reprocess_nbest(fle): """reprocess interactive.py output""" - with open(fle, 'r') as f: + with open(fle, "r") as f: txt = f.read() source_dict = {} @@ -82,7 +84,7 @@ def reprocess_nbest(fle): pos_score_dict = {} lines = txt.split("\n") - hp = re.compile(r'[-]?\d+[.]?\d+') + hp = re.compile(r"[-]?\d+[.]?\d+") j = -1 for _i, line in enumerate(lines): @@ -119,59 +121,76 @@ def reprocess_nbest(fle): return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict -def write_reprocessed(sources, hypos, targets, source_outfile, - hypo_outfile, target_outfile, right_to_left=False, - prefix_len=None, bpe_symbol=None, - target_prefix_frac=None, source_prefix_frac=None): +def write_reprocessed( + sources, + hypos, + targets, + source_outfile, + hypo_outfile, + target_outfile, + right_to_left=False, + prefix_len=None, + bpe_symbol=None, + target_prefix_frac=None, + source_prefix_frac=None, +): """writes nbest hypothesis for rescoring""" - assert not (prefix_len is not None and target_prefix_frac is not None), \ - "in writing reprocessed, only one type of prefix may be used" - assert not 
(prefix_len is not None and source_prefix_frac is not None), \ - "in writing reprocessed, only one type of prefix may be used" - assert not (target_prefix_frac is not None and source_prefix_frac is not None), \ - "in writing reprocessed, only one type of prefix may be used" - - with open(source_outfile, 'w') as source_file, \ - open(hypo_outfile, 'w') as hypo_file, \ - open(target_outfile, 'w') as target_file: + assert not ( + prefix_len is not None and target_prefix_frac is not None + ), "in writing reprocessed, only one type of prefix may be used" + assert not ( + prefix_len is not None and source_prefix_frac is not None + ), "in writing reprocessed, only one type of prefix may be used" + assert not ( + target_prefix_frac is not None and source_prefix_frac is not None + ), "in writing reprocessed, only one type of prefix may be used" + + with open(source_outfile, "w") as source_file, open( + hypo_outfile, "w" + ) as hypo_file, open(target_outfile, "w") as target_file: assert len(sources) == len(hypos), "sources and hypos list length mismatch" if right_to_left: for i in range(len(sources)): - for j in range(len(hypos[i])): - if prefix_len is None: - hypo_file.write(make_right_to_left(hypos[i][j])+"\n") - else: - raise NotImplementedError() - source_file.write(make_right_to_left(sources[i])+"\n") - target_file.write(make_right_to_left(targets[i])+"\n") + for j in range(len(hypos[i])): + if prefix_len is None: + hypo_file.write(make_right_to_left(hypos[i][j]) + "\n") + else: + raise NotImplementedError() + source_file.write(make_right_to_left(sources[i]) + "\n") + target_file.write(make_right_to_left(targets[i]) + "\n") else: for i in sorted(sources.keys()): - for j in range(len(hypos[i])): - if prefix_len is not None: - shortened = get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)+"\n" - hypo_file.write(shortened) - source_file.write(sources[i]) - target_file.write(targets[i]) - elif target_prefix_frac is not None: - num_words, shortened, num_bpe_tokens = \ - 
calc_length_from_frac(hypos[i][j], target_prefix_frac, bpe_symbol) - shortened += "\n" - hypo_file.write(shortened) - source_file.write(sources[i]) - target_file.write(targets[i]) - elif source_prefix_frac is not None: - num_words, shortened, num_bpe_tokensn = \ - calc_length_from_frac(sources[i], source_prefix_frac, bpe_symbol) - shortened += "\n" - hypo_file.write(hypos[i][j]) - source_file.write(shortened) - target_file.write(targets[i]) - else: - hypo_file.write(hypos[i][j]) - source_file.write(sources[i]) - target_file.write(targets[i]) + for j in range(len(hypos[i])): + if prefix_len is not None: + shortened = ( + get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len) + + "\n" + ) + hypo_file.write(shortened) + source_file.write(sources[i]) + target_file.write(targets[i]) + elif target_prefix_frac is not None: + num_words, shortened, num_bpe_tokens = calc_length_from_frac( + hypos[i][j], target_prefix_frac, bpe_symbol + ) + shortened += "\n" + hypo_file.write(shortened) + source_file.write(sources[i]) + target_file.write(targets[i]) + elif source_prefix_frac is not None: + num_words, shortened, num_bpe_tokensn = calc_length_from_frac( + sources[i], source_prefix_frac, bpe_symbol + ) + shortened += "\n" + hypo_file.write(hypos[i][j]) + source_file.write(shortened) + target_file.write(targets[i]) + else: + hypo_file.write(hypos[i][j]) + source_file.write(sources[i]) + target_file.write(targets[i]) def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol): @@ -207,7 +226,9 @@ def get_prefix_from_len(sentence, bpe_symbol, prefix_len): if bpe_count == 0: return sentence[:prefix_len] else: - return sentence[:prefix_len]+get_prefix_from_len(sentence[prefix_len:], bpe_symbol, bpe_count) + return sentence[:prefix_len] + get_prefix_from_len( + sentence[prefix_len:], bpe_symbol, bpe_count + ) def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len): @@ -225,9 +246,9 @@ def make_right_to_left(line): def remove_bpe(line, bpe_symbol): - line = 
line.replace("\n", '') - line = (line + ' ').replace(bpe_symbol, '').rstrip() - return line+("\n") + line = line.replace("\n", "") + line = (line + " ").replace(bpe_symbol, "").rstrip() + return line + ("\n") def remove_bpe_dict(pred_dict, bpe_symbol): @@ -242,7 +263,7 @@ def remove_bpe_dict(pred_dict, bpe_symbol): def parse_bleu_scoring(line): - p = re.compile(r'(BLEU4 = )\d+[.]\d+') + p = re.compile(r"(BLEU4 = )\d+[.]\d+") res = re.search(p, line) assert res is not None, line return float(res.group()[8:]) @@ -259,9 +280,21 @@ def get_full_from_prefix(hypo_prefix, hypos): raise Exception() -def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=None, - lenpen=None, src_len=None, tgt_len=None, bitext1_backwards=False, - bitext2_backwards=False, normalize=False): +def get_score( + a, + b, + c, + target_len, + bitext_score1, + bitext_score2=None, + lm_score=None, + lenpen=None, + src_len=None, + tgt_len=None, + bitext1_backwards=False, + bitext2_backwards=False, + normalize=False, +): if bitext1_backwards: bitext1_norm = src_len else: @@ -275,9 +308,13 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N bitext2_norm = 1 bitext_score2 = 0 if normalize: - score = a*bitext_score1/bitext1_norm + b*bitext_score2/bitext2_norm+c*lm_score/src_len + score = ( + a * bitext_score1 / bitext1_norm + + b * bitext_score2 / bitext2_norm + + c * lm_score / src_len + ) else: - score = a*bitext_score1 + b*bitext_score2+c*lm_score + score = a * bitext_score1 + b * bitext_score2 + c * lm_score if lenpen is not None: score /= (target_len) ** float(lenpen) @@ -286,8 +323,16 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N class BitextOutput(object): - def __init__(self, output_file, backwards, right_to_left, bpe_symbol, - prefix_len=None, target_prefix_frac=None, source_prefix_frac=None): + def __init__( + self, + output_file, + backwards, + right_to_left, + bpe_symbol, + prefix_len=None, + 
target_prefix_frac=None, + source_prefix_frac=None, + ): """process output from rescoring""" source, hypo, score, target, pos_score = reprocess(output_file) if backwards: @@ -296,7 +341,9 @@ def __init__(self, output_file, backwards, right_to_left, bpe_symbol, self.hypo_fracs = target_prefix_frac # remove length penalty so we can use raw scores - score, num_bpe_tokens = get_score_from_pos(pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards) + score, num_bpe_tokens = get_score_from_pos( + pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards + ) source_lengths = {} target_lengths = {} @@ -341,7 +388,9 @@ def __init__(self, output_file, backwards, right_to_left, bpe_symbol, score[i] = float(score[i][0]) pos_score[i] = pos_score[i][0] else: - assert len(hypo[i]) == 1, "expected only one hypothesis per source sentence" + assert ( + len(hypo[i]) == 1 + ), "expected only one hypothesis per source sentence" source[i] = remove_bpe(source[i], bpe_symbol) target[i] = remove_bpe(target[i], bpe_symbol) hypo[i] = remove_bpe(hypo[i][0], bpe_symbol) @@ -360,11 +409,26 @@ def __init__(self, output_file, backwards, right_to_left, bpe_symbol, class BitextOutputFromGen(object): - def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_len=None, target_prefix_frac=None): + def __init__( + self, + predictions_bpe_file, + bpe_symbol=None, + nbest=False, + prefix_len=None, + target_prefix_frac=None, + ): if nbest: - pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess_nbest(predictions_bpe_file) + ( + pred_source, + pred_hypo, + pred_score, + pred_target, + pred_pos_score, + ) = reprocess_nbest(predictions_bpe_file) else: - pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(predictions_bpe_file) + pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess( + predictions_bpe_file + ) assert len(pred_source) == len(pred_hypo) assert len(pred_source) == len(pred_score) @@ 
-372,8 +436,9 @@ def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_le assert len(pred_source) == len(pred_pos_score) # remove length penalty so we can use raw scores - pred_score, num_bpe_tokens = get_score_from_pos(pred_pos_score, prefix_len, pred_hypo, - bpe_symbol, target_prefix_frac, False) + pred_score, num_bpe_tokens = get_score_from_pos( + pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False + ) self.source = pred_source self.target = pred_target @@ -414,7 +479,9 @@ def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_le index += 1 -def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards): +def get_score_from_pos( + pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards +): score_dict = {} num_bpe_tokens_dict = {} assert prefix_len is None or hypo_frac is None @@ -423,11 +490,15 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f num_bpe_tokens_dict[key] = [] for i in range(len(pos_score_dict[key])): if prefix_len is not None and not backwards: - num_bpe_tokens = get_num_bpe_tokens_from_len(hypo_dict[key][i], bpe_symbol, prefix_len) + num_bpe_tokens = get_num_bpe_tokens_from_len( + hypo_dict[key][i], bpe_symbol, prefix_len + ) score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens])) num_bpe_tokens_dict[key].append(num_bpe_tokens) elif hypo_frac is not None: - num_words, shortened, hypo_prefix_len = calc_length_from_frac(hypo_dict[key][i], hypo_frac, bpe_symbol) + num_words, shortened, hypo_prefix_len = calc_length_from_frac( + hypo_dict[key][i], hypo_frac, bpe_symbol + ) score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len])) num_bpe_tokens_dict[key].append(hypo_prefix_len) else: @@ -437,10 +508,26 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f class LMOutput(object): - def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, 
bpe_symbol=None, target_prefix_frac=None): - lm_sentences, lm_sen_scores, lm_sen_pos_scores, lm_no_bpe_sentences, lm_bpe_tokens = \ - parse_lm(lm_score_file, prefix_len=prefix_len, - bpe_symbol=bpe_symbol, target_prefix_frac=target_prefix_frac) + def __init__( + self, + lm_score_file, + lm_dict=None, + prefix_len=None, + bpe_symbol=None, + target_prefix_frac=None, + ): + ( + lm_sentences, + lm_sen_scores, + lm_sen_pos_scores, + lm_no_bpe_sentences, + lm_bpe_tokens, + ) = parse_lm( + lm_score_file, + prefix_len=prefix_len, + bpe_symbol=bpe_symbol, + target_prefix_frac=target_prefix_frac, + ) self.sentences = lm_sentences self.score = lm_sen_scores @@ -452,7 +539,7 @@ def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None): """parse output of eval_lm""" - with open(input_file, 'r') as f: + with open(input_file, "r") as f: text = f.readlines() text = text[7:] cleaned_text = text[:-2] @@ -467,20 +554,23 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No if tokens[0].isdigit(): line_id = int(tokens[0]) scores = [float(x[1:-1]) for x in tokens[2::2]] - sentences[line_id] = " ".join(tokens[1::2][:-1])+"\n" + sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n" if bpe_symbol is not None: # exclude symbol to match output from generate.py - bpe_sen = " ".join(tokens[1::2][:-1])+"\n" + bpe_sen = " ".join(tokens[1::2][:-1]) + "\n" no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol) no_bpe_sentences[line_id] = no_bpe_sen if prefix_len is not None: - num_bpe_tokens = get_num_bpe_tokens_from_len(bpe_sen, bpe_symbol, prefix_len) + num_bpe_tokens = get_num_bpe_tokens_from_len( + bpe_sen, bpe_symbol, prefix_len + ) sen_scores[line_id] = sum(scores[:num_bpe_tokens]) num_bpe_tokens_dict[line_id] = num_bpe_tokens elif target_prefix_frac is not None: - num_words, shortened, target_prefix_len = calc_length_from_frac(bpe_sen, target_prefix_frac, - 
bpe_symbol) + num_words, shortened, target_prefix_len = calc_length_from_frac( + bpe_sen, target_prefix_frac, bpe_symbol + ) sen_scores[line_id] = sum(scores[:target_prefix_len]) num_bpe_tokens_dict[line_id] = target_prefix_len else: @@ -492,160 +582,269 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict -def get_directories(data_dir_name, num_rescore, gen_subset, - fw_name, shard_id, num_shards, - sampling=False, prefix_len=None, - target_prefix_frac=None, source_prefix_frac=None): - nbest_file_id = "nbest_" + str(num_rescore) + \ - "_subset_" + gen_subset + \ - "_fw_name_" + fw_name + \ - "_shard_" + str(shard_id) + \ - "_of_" + str(num_shards) +def get_directories( + data_dir_name, + num_rescore, + gen_subset, + fw_name, + shard_id, + num_shards, + sampling=False, + prefix_len=None, + target_prefix_frac=None, + source_prefix_frac=None, +): + nbest_file_id = ( + "nbest_" + + str(num_rescore) + + "_subset_" + + gen_subset + + "_fw_name_" + + fw_name + + "_shard_" + + str(shard_id) + + "_of_" + + str(num_shards) + ) if sampling: nbest_file_id += "_sampling" # the directory containing all information for this nbest list - pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id + pre_gen = ( + os.path.join(os.path.dirname(__file__)) + + "/rerank_data/" + + data_dir_name + + "/" + + nbest_file_id + ) # the directory to store the preprocessed nbest list, for left to right rescoring - left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed" + left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed" if source_prefix_frac is not None: - left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac) + left_to_right_preprocessed_dir = ( + left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac) + ) # the directory to store the 
preprocessed nbest list, for right to left rescoring - right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed" + right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed" # the directory to store the preprocessed nbest list, for backwards rescoring - backwards_preprocessed_dir = pre_gen+"/backwards" + backwards_preprocessed_dir = pre_gen + "/backwards" if target_prefix_frac is not None: - backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac) + backwards_preprocessed_dir = ( + backwards_preprocessed_dir + "/prefix_frac" + str(target_prefix_frac) + ) elif prefix_len is not None: - backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len) + backwards_preprocessed_dir = ( + backwards_preprocessed_dir + "/prefix_" + str(prefix_len) + ) # the directory to store the preprocessed nbest list, for rescoring with P(T) - lm_preprocessed_dir = pre_gen+"/lm_preprocessed" - - return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ - backwards_preprocessed_dir, lm_preprocessed_dir - - -def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen, - cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code, - batch_size, lm_score_file, target_lang, source_lang, prefix_len=None): + lm_preprocessed_dir = pre_gen + "/lm_preprocessed" + + return ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) + + +def lm_scoring( + preprocess_directory, + bpe_status, + gen_output, + pre_gen, + cur_lm_dict, + cur_lm_name, + cur_language_model, + cur_lm_bpe_code, + batch_size, + lm_score_file, + target_lang, + source_lang, + prefix_len=None, +): if prefix_len is not None: - assert bpe_status == "different", "bpe status must be different to use prefix len" + assert ( + bpe_status == "different" + ), "bpe status must be different to use prefix len" if bpe_status == "no bpe": # run lm on 
output without bpe - write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo, - gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de", - pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe") - - preprocess_lm_param = ["--only-source", - "--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang, - "--srcdict", cur_lm_dict, - "--destdir", preprocess_directory] + write_reprocessed( + gen_output.no_bpe_source, + gen_output.no_bpe_hypo, + gen_output.no_bpe_target, + pre_gen + "/rescore_data_no_bpe.de", + pre_gen + "/rescore_data_no_bpe.en", + pre_gen + "/reference_file_no_bpe", + ) + + preprocess_lm_param = [ + "--only-source", + "--trainpref", + pre_gen + "/rescore_data_no_bpe." + target_lang, + "--srcdict", + cur_lm_dict, + "--destdir", + preprocess_directory, + ] preprocess_parser = options.get_preprocessing_parser() input_args = preprocess_parser.parse_args(preprocess_lm_param) preprocess.main(input_args) - eval_lm_param = [preprocess_directory, - "--path", cur_language_model, - "--output-word-probs", - "--batch-size", str(batch_size), - "--max-tokens", "1024", - "--sample-break-mode", "eos", - "--gen-subset", "train"] + eval_lm_param = [ + preprocess_directory, + "--path", + cur_language_model, + "--output-word-probs", + "--batch-size", + str(batch_size), + "--max-tokens", + "1024", + "--sample-break-mode", + "eos", + "--gen-subset", + "train", + ] eval_lm_parser = options.get_eval_lm_parser() input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) - with open(lm_score_file, 'w') as f: + with open(lm_score_file, "w") as f: with redirect_stdout(f): eval_lm.main(input_args) elif bpe_status == "shared": - preprocess_lm_param = ["--only-source", - "--trainpref", pre_gen+"/rescore_data."+target_lang, - "--srcdict", cur_lm_dict, - "--destdir", preprocess_directory] - preprocess_parser = options.get_preprocessing_parser() - input_args = preprocess_parser.parse_args(preprocess_lm_param) - preprocess.main(input_args) - - 
eval_lm_param = [preprocess_directory, - "--path", cur_language_model, - "--output-word-probs", - "--batch-size", str(batch_size), - "--sample-break-mode", "eos", - "--gen-subset", "train"] - - eval_lm_parser = options.get_eval_lm_parser() - input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) - - with open(lm_score_file, 'w') as f: - with redirect_stdout(f): - eval_lm.main(input_args) + preprocess_lm_param = [ + "--only-source", + "--trainpref", + pre_gen + "/rescore_data." + target_lang, + "--srcdict", + cur_lm_dict, + "--destdir", + preprocess_directory, + ] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_lm_param) + preprocess.main(input_args) + + eval_lm_param = [ + preprocess_directory, + "--path", + cur_language_model, + "--output-word-probs", + "--batch-size", + str(batch_size), + "--sample-break-mode", + "eos", + "--gen-subset", + "train", + ] + + eval_lm_parser = options.get_eval_lm_parser() + input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) + + with open(lm_score_file, "w") as f: + with redirect_stdout(f): + eval_lm.main(input_args) elif bpe_status == "different": - rescore_file = pre_gen+"/rescore_data_no_bpe" - rescore_bpe = pre_gen+"/rescore_data_new_bpe" + rescore_file = pre_gen + "/rescore_data_no_bpe" + rescore_bpe = pre_gen + "/rescore_data_new_bpe" rescore_file += "." rescore_bpe += "." 
- write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo, - gen_output.no_bpe_target, rescore_file+source_lang, - rescore_file+target_lang, pre_gen+"/reference_file_no_bpe", - bpe_symbol=None) + write_reprocessed( + gen_output.no_bpe_source, + gen_output.no_bpe_hypo, + gen_output.no_bpe_target, + rescore_file + source_lang, + rescore_file + target_lang, + pre_gen + "/reference_file_no_bpe", + bpe_symbol=None, + ) # apply LM bpe to nbest list - bpe_src_param = ["-c", cur_lm_bpe_code, - "--input", rescore_file+target_lang, - "--output", rescore_bpe+target_lang] - subprocess.call(["python", - os.path.join(os.path.dirname(__file__), - "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param, - shell=False) + bpe_src_param = [ + "-c", + cur_lm_bpe_code, + "--input", + rescore_file + target_lang, + "--output", + rescore_bpe + target_lang, + ] + subprocess.call( + [ + "python", + os.path.join( + os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py" + ), + ] + + bpe_src_param, + shell=False, + ) # uncomment to use fastbpe instead of subword-nmt bpe # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code] # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False) preprocess_dir = preprocess_directory - preprocess_lm_param = ["--only-source", - "--trainpref", rescore_bpe+target_lang, - "--srcdict", cur_lm_dict, - "--destdir", preprocess_dir] + preprocess_lm_param = [ + "--only-source", + "--trainpref", + rescore_bpe + target_lang, + "--srcdict", + cur_lm_dict, + "--destdir", + preprocess_dir, + ] preprocess_parser = options.get_preprocessing_parser() input_args = preprocess_parser.parse_args(preprocess_lm_param) preprocess.main(input_args) - eval_lm_param = [preprocess_dir, - "--path", cur_language_model, - "--output-word-probs", - "--batch-size", str(batch_size), - "--max-tokens", "1024", - "--sample-break-mode", "eos", - "--gen-subset", "train"] + eval_lm_param = [ + 
preprocess_dir, + "--path", + cur_language_model, + "--output-word-probs", + "--batch-size", + str(batch_size), + "--max-tokens", + "1024", + "--sample-break-mode", + "eos", + "--gen-subset", + "train", + ] eval_lm_parser = options.get_eval_lm_parser() input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) - with open(lm_score_file, 'w') as f: + with open(lm_score_file, "w") as f: with redirect_stdout(f): eval_lm.main(input_args) -def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False, - target_prefix_frac=None, source_prefix_frac=None, backwards=None): +def rescore_file_name( + nbest_dir, + prefix_len, + scorer_name, + lm_file=False, + target_prefix_frac=None, + source_prefix_frac=None, + backwards=None, +): if lm_file: - score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt" + score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt" else: - score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt" + score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt" if backwards: if prefix_len is not None: - score_file += "prefix_len"+str(prefix_len) + score_file += "prefix_len" + str(prefix_len) elif target_prefix_frac is not None: - score_file += "target_prefix_frac"+str(target_prefix_frac) + score_file += "target_prefix_frac" + str(target_prefix_frac) else: if source_prefix_frac is not None: - score_file += "source_prefix_frac"+str(source_prefix_frac) + score_file += "source_prefix_frac" + str(source_prefix_frac) return score_file diff --git a/examples/paraphraser/paraphrase.py b/examples/paraphraser/paraphrase.py index 405df296b3..d3422fb3db 100644 --- a/examples/paraphraser/paraphrase.py +++ b/examples/paraphraser/paraphrase.py @@ -13,57 +13,66 @@ def main(): - parser = argparse.ArgumentParser(description='') - parser.add_argument('--en2fr', required=True, - help='path to en2fr model') - parser.add_argument('--fr2en', required=True, - help='path to fr2en mixture of experts 
model') - parser.add_argument('--user-dir', - help='path to fairseq examples/translation_moe/src directory') - parser.add_argument('--num-experts', type=int, default=10, - help='(keep at 10 unless using a different model)') - parser.add_argument('files', nargs='*', default=['-'], - help='input files to paraphrase; "-" for stdin') + parser = argparse.ArgumentParser(description="") + parser.add_argument("--en2fr", required=True, help="path to en2fr model") + parser.add_argument( + "--fr2en", required=True, help="path to fr2en mixture of experts model" + ) + parser.add_argument( + "--user-dir", help="path to fairseq examples/translation_moe/src directory" + ) + parser.add_argument( + "--num-experts", + type=int, + default=10, + help="(keep at 10 unless using a different model)", + ) + parser.add_argument( + "files", + nargs="*", + default=["-"], + help='input files to paraphrase; "-" for stdin', + ) args = parser.parse_args() if args.user_dir is None: args.user_dir = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/ - 'translation_moe', - 'src', + "translation_moe", + "src", ) if os.path.exists(args.user_dir): - logging.info('found user_dir:' + args.user_dir) + logging.info("found user_dir:" + args.user_dir) else: raise RuntimeError( - 'cannot find fairseq examples/translation_moe/src ' - '(tried looking here: {})'.format(args.user_dir) + "cannot find fairseq examples/translation_moe/src " + "(tried looking here: {})".format(args.user_dir) ) - logging.info('loading en2fr model from:' + args.en2fr) + logging.info("loading en2fr model from:" + args.en2fr) en2fr = TransformerModel.from_pretrained( model_name_or_path=args.en2fr, - tokenizer='moses', - bpe='sentencepiece', + tokenizer="moses", + bpe="sentencepiece", ).eval() - logging.info('loading fr2en model from:' + args.fr2en) + logging.info("loading fr2en model from:" + args.fr2en) fr2en = TransformerModel.from_pretrained( model_name_or_path=args.fr2en, - tokenizer='moses', - 
bpe='sentencepiece', + tokenizer="moses", + bpe="sentencepiece", user_dir=args.user_dir, - task='translation_moe', + task="translation_moe", ).eval() def gen_paraphrases(en): fr = en2fr.translate(en) return [ - fr2en.translate(fr, inference_step_args={'expert': i}) + fr2en.translate(fr, inference_step_args={"expert": i}) for i in range(args.num_experts) ] - logging.info('Type the input sentence and press return:') + logging.info("Type the input sentence and press return:") for line in fileinput.input(args.files): line = line.strip() if len(line) == 0: @@ -72,5 +81,5 @@ def gen_paraphrases(en): print(paraphrase) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/pointer_generator/postprocess.py b/examples/pointer_generator/postprocess.py index a01434b5ce..b213aed80f 100755 --- a/examples/pointer_generator/postprocess.py +++ b/examples/pointer_generator/postprocess.py @@ -4,9 +4,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import sys -import re import argparse +import re +import sys class OOVIndexError(IndexError): @@ -25,8 +25,8 @@ def __init__(self, pos, source_seq, target_seq): def replace_oovs(source_in, target_in, target_out): """Replaces <unk-N> tokens in the target text with the corresponding word in - the source text. - """ + the source text. + """ oov_re = re.compile("^<unk-([0-9]+)>$") diff --git a/examples/pointer_generator/preprocess.py b/examples/pointer_generator/preprocess.py index 4b7a5ab9c5..f72ca7d3d9 100755 --- a/examples/pointer_generator/preprocess.py +++ b/examples/pointer_generator/preprocess.py @@ -10,8 +10,8 @@ def replace_oovs(source_in, target_in, vocabulary, source_out, target_out): """Replaces out-of-vocabulary words in source and target text with <unk-N>, - where N in is the position of the word in the source sequence. - """ + where N in is the position of the word in the source sequence.
+ """ def format_unk(pos): return "<unk-{}>".format(pos) diff --git a/examples/pointer_generator/src/transformer_pg.py b/examples/pointer_generator/src/transformer_pg.py index af933b3495..079fdda581 100644 --- a/examples/pointer_generator/src/transformer_pg.py +++ b/examples/pointer_generator/src/transformer_pg.py @@ -8,19 +8,17 @@ import torch import torch.nn as nn - -from fairseq import utils, metrics +from fairseq import metrics, utils from fairseq.models import register_model, register_model_architecture from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( - TransformerModel, + DEFAULT_MAX_SOURCE_POSITIONS, + DEFAULT_MAX_TARGET_POSITIONS, TransformerDecoder, TransformerEncoder, + TransformerModel, base_architecture, - DEFAULT_MAX_SOURCE_POSITIONS, - DEFAULT_MAX_TARGET_POSITIONS, ) - from torch import Tensor diff --git a/examples/roberta/commonsense_qa/commonsense_qa_task.py b/examples/roberta/commonsense_qa/commonsense_qa_task.py index 7ed2bc36a4..216093f708 100644 --- a/examples/roberta/commonsense_qa/commonsense_qa_task.py +++ b/examples/roberta/commonsense_qa/commonsense_qa_task.py @@ -8,40 +8,44 @@ import numpy as np import torch - from fairseq.data import ( - data_utils, Dictionary, - encoders, IdDataset, ListDataset, NestedDictionaryDataset, - NumSamplesDataset, NumelDataset, + NumSamplesDataset, RawLabelDataset, RightPadDataset, SortDataset, + data_utils, + encoders, ) -from fairseq.tasks import register_task, LegacyFairseqTask +from fairseq.tasks import LegacyFairseqTask, register_task -@register_task('commonsense_qa') +@register_task("commonsense_qa") class CommonsenseQATask(LegacyFairseqTask): """Task to finetune RoBERTa for Commonsense QA.""" @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', metavar='DIR', - help='path to data directory; we load <split>.jsonl') - parser.add_argument('--init-token', type=int, default=None, - help='add token at the 
beginning of
each batch item') - parser.add_argument('--num-classes', type=int, default=5) + parser.add_argument( + "data", metavar="DIR", help="path to data directory; we load <split>.jsonl" + ) + parser.add_argument( + "--init-token", + type=int, + default=None, + help="add token at the beginning of each batch item", + ) + parser.add_argument("--num-classes", type=int, default=5) def __init__(self, args, vocab): super().__init__(args) self.vocab = vocab - self.mask = vocab.add_symbol('<mask>') + self.mask = vocab.add_symbol("<mask>") self.bpe = encoders.build_bpe(args) @@ -53,20 +57,24 @@ def load_dictionary(cls, filename): filename (str): the filename """ dictionary = Dictionary.load(filename) - dictionary.add_symbol('<mask>') + dictionary.add_symbol("<mask>") return dictionary @classmethod def setup_task(cls, args, **kwargs): - assert args.criterion == 'sentence_ranking', 'Must set --criterion=sentence_ranking' + assert ( + args.criterion == "sentence_ranking" + ), "Must set --criterion=sentence_ranking" # load data and label dictionaries - vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt')) - print('| dictionary: {} types'.format(len(vocab))) + vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt")) + print("| dictionary: {} types".format(len(vocab))) return cls(args, vocab) - def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): + def load_dataset( + self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs + ): """Load a given dataset split.
Args: @@ -77,16 +85,18 @@ def binarize(s, append_bos=False): if self.bpe is not None: s = self.bpe.encode(s) tokens = self.vocab.encode_line( - s, append_eos=True, add_if_not_exist=False, + s, + append_eos=True, + add_if_not_exist=False, ).long() if append_bos and self.args.init_token is not None: tokens = torch.cat([tokens.new([self.args.init_token]), tokens]) return tokens if data_path is None: - data_path = os.path.join(self.args.data, split + '.jsonl') + data_path = os.path.join(self.args.data, split + ".jsonl") if not os.path.exists(data_path): - raise FileNotFoundError('Cannot find data: {}'.format(data_path)) + raise FileNotFoundError("Cannot find data: {}".format(data_path)) src_tokens = [[] for i in range(self.args.num_classes)] src_lengths = [[] for i in range(self.args.num_classes)] @@ -95,20 +105,23 @@ def binarize(s, append_bos=False): with open(data_path) as h: for line in h: example = json.loads(line.strip()) - if 'answerKey' in example: - label = ord(example['answerKey']) - ord('A') + if "answerKey" in example: + label = ord(example["answerKey"]) - ord("A") labels.append(label) - question = example['question']['stem'] - assert len(example['question']['choices']) == self.args.num_classes + question = example["question"]["stem"] + assert len(example["question"]["choices"]) == self.args.num_classes # format: ` Q: Where would I not want a fox? 
A: hen house ` - question = 'Q: ' + question + question = "Q: " + question question_toks = binarize(question, append_bos=True) - for i, choice in enumerate(example['question']['choices']): - src = 'A: ' + choice['text'] + for i, choice in enumerate(example["question"]["choices"]): + src = "A: " + choice["text"] src_bin = torch.cat([question_toks, binarize(src)]) src_tokens[i].append(src_bin) src_lengths[i].append(len(src_bin)) - assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes)) + assert all( + len(src_tokens[0]) == len(src_tokens[i]) + for i in range(self.args.num_classes) + ) assert len(src_tokens[0]) == len(src_lengths[0]) assert len(labels) == 0 or len(labels) == len(src_tokens[0]) @@ -118,24 +131,26 @@ def binarize(s, append_bos=False): src_lengths[i] = ListDataset(src_lengths[i]) dataset = { - 'id': IdDataset(), - 'nsentences': NumSamplesDataset(), - 'ntokens': NumelDataset(src_tokens[0], reduce=True), + "id": IdDataset(), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens[0], reduce=True), } for i in range(self.args.num_classes): - dataset.update({ - 'net_input{}'.format(i + 1): { - 'src_tokens': RightPadDataset( - src_tokens[i], - pad_idx=self.source_dictionary.pad(), - ), - 'src_lengths': src_lengths[i], + dataset.update( + { + "net_input{}".format(i + 1): { + "src_tokens": RightPadDataset( + src_tokens[i], + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": src_lengths[i], + } } - }) + ) if len(labels) > 0: - dataset.update({'target': RawLabelDataset(labels)}) + dataset.update({"target": RawLabelDataset(labels)}) dataset = NestedDictionaryDataset( dataset, @@ -149,17 +164,18 @@ def binarize(s, append_bos=False): sort_order=[np.random.permutation(len(dataset))], ) - print('| Loaded {} with {} samples'.format(split, len(dataset))) + print("| Loaded {} with {} samples".format(split, len(dataset))) self.datasets[split] = dataset return self.datasets[split] def build_model(self, args): 
from fairseq import models + model = models.build_model(args, self) model.register_classification_head( - 'sentence_classification_head', + "sentence_classification_head", num_classes=1, ) diff --git a/examples/roberta/multiprocessing_bpe_encoder.py b/examples/roberta/multiprocessing_bpe_encoder.py index f0240c210f..43fe0451bf 100644 --- a/examples/roberta/multiprocessing_bpe_encoder.py +++ b/examples/roberta/multiprocessing_bpe_encoder.py @@ -8,7 +8,6 @@ import argparse import contextlib import sys - from collections import Counter from multiprocessing import Pool @@ -26,23 +25,23 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument( "--encoder-json", - help='path to encoder.json', + help="path to encoder.json", ) parser.add_argument( "--vocab-bpe", type=str, - help='path to vocab.bpe', + help="path to vocab.bpe", ) parser.add_argument( "--inputs", nargs="+", - default=['-'], + default=["-"], help="input files to filter/encode", ) parser.add_argument( "--outputs", nargs="+", - default=['-'], + default=["-"], help="path to save encoded outputs", ) parser.add_argument( @@ -53,18 +52,21 @@ def main(): parser.add_argument("--workers", type=int, default=20) args = parser.parse_args() - assert len(args.inputs) == len(args.outputs), \ - "number of input and output paths should match" + assert len(args.inputs) == len( + args.outputs + ), "number of input and output paths should match" with contextlib.ExitStack() as stack: inputs = [ stack.enter_context(open(input, "r", encoding="utf-8")) - if input != "-" else sys.stdin + if input != "-" + else sys.stdin for input in args.inputs ] outputs = [ stack.enter_context(open(output, "w", encoding="utf-8")) - if output != "-" else sys.stdout + if output != "-" + else sys.stdout for output in args.outputs ] @@ -87,7 +89,6 @@ def main(): class MultiprocessingEncoder(object): - def __init__(self, args): self.args = args diff --git a/examples/roberta/preprocess_RACE.py b/examples/roberta/preprocess_RACE.py index 
f6f606a389..cdd6607271 100644 --- a/examples/roberta/preprocess_RACE.py +++ b/examples/roberta/preprocess_RACE.py @@ -25,7 +25,7 @@ def get_examples(data_dir, set_type): examples = [] levels = ["middle", "high"] - set_type_c = set_type.split('-') + set_type_c = set_type.split("-") if len(set_type_c) == 2: levels = [set_type_c[1]] set_type = set_type_c[0] @@ -33,13 +33,13 @@ def get_examples(data_dir, set_type): cur_dir = os.path.join(data_dir, set_type, level) for filename in os.listdir(cur_dir): cur_path = os.path.join(cur_dir, filename) - with open(cur_path, 'r') as f: + with open(cur_path, "r") as f: cur_data = json.load(f) answers = cur_data["answers"] options = cur_data["options"] questions = cur_data["questions"] context = cur_data["article"].replace("\n", " ") - context = re.sub(r'\s+', ' ', context) + context = re.sub(r"\s+", " ", context) for i in range(len(answers)): label = ord(answers[i]) - ord("A") qa_list = [] @@ -50,7 +50,7 @@ def get_examples(data_dir, set_type): qa_cat = question.replace("_", option) else: qa_cat = " ".join([question, option]) - qa_cat = re.sub(r'\s+', ' ', qa_cat) + qa_cat = re.sub(r"\s+", " ", qa_cat) qa_list.append(qa_cat) examples.append(InputExample(context, qa_list, label)) @@ -64,11 +64,11 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument( "--input-dir", - help='input directory for downloaded RACE dataset', + help="input directory for downloaded RACE dataset", ) parser.add_argument( "--output-dir", - help='output directory for extracted data', + help="output directory for extracted data", ) args = parser.parse_args() @@ -77,17 +77,20 @@ def main(): for set_type in ["train", "dev", "test-middle", "test-high"]: examples = get_examples(args.input_dir, set_type) - qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)] - qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths] + qa_file_paths = [ + os.path.join(args.output_dir, set_type + ".input" + 
str(i + 1)) + for i in range(4) + ] + qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths] outf_context_path = os.path.join(args.output_dir, set_type + ".input0") outf_label_path = os.path.join(args.output_dir, set_type + ".label") - outf_context = open(outf_context_path, 'w') - outf_label = open(outf_label_path, 'w') + outf_context = open(outf_context_path, "w") + outf_label = open(outf_label_path, "w") for example in examples: - outf_context.write(example.paragraph + '\n') + outf_context.write(example.paragraph + "\n") for i in range(4): - qa_files[i].write(example.qa_list[i] + '\n') - outf_label.write(str(example.label) + '\n') + qa_files[i].write(example.qa_list[i] + "\n") + outf_label.write(str(example.label) + "\n") for f in qa_files: f.close() @@ -95,5 +98,5 @@ def main(): outf_context.close() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py index dd909ab20c..1a5901234b 100644 --- a/examples/roberta/wsc/wsc_criterion.py +++ b/examples/roberta/wsc/wsc_criterion.py @@ -7,19 +7,17 @@ import torch import torch.nn.functional as F - from fairseq import utils -from fairseq.data import encoders from fairseq.criterions import LegacyFairseqCriterion, register_criterion +from fairseq.data import encoders -@register_criterion('wsc') +@register_criterion("wsc") class WSCCriterion(LegacyFairseqCriterion): - def __init__(self, args, task): super().__init__(args, task) if self.args.save_predictions is not None: - self.prediction_h = open(self.args.save_predictions, 'w') + self.prediction_h = open(self.args.save_predictions, "w") else: self.prediction_h = None self.bpe = encoders.build_bpe(args) @@ -32,12 +30,16 @@ def __del__(self): @staticmethod def add_args(parser): """Add criterion-specific arguments to the parser.""" - parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0) - parser.add_argument('--wsc-margin-beta', 
type=float, metavar='B', default=0.0) - parser.add_argument('--wsc-cross-entropy', action='store_true', - help='use cross entropy formulation instead of margin loss') - parser.add_argument('--save-predictions', metavar='FILE', - help='file to save predictions to') + parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0) + parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0) + parser.add_argument( + "--wsc-cross-entropy", + action="store_true", + help="use cross entropy formulation instead of margin loss", + ) + parser.add_argument( + "--save-predictions", metavar="FILE", help="file to save predictions to" + ) def get_masked_input(self, tokens, mask): masked_tokens = tokens.clone() @@ -60,27 +62,26 @@ def get_loss(self, query_lprobs, cand_lprobs): ) else: return ( - - query_lprobs - + self.args.wsc_margin_alpha * ( - cand_lprobs - query_lprobs + self.args.wsc_margin_beta - ).clamp(min=0) + -query_lprobs + + self.args.wsc_margin_alpha + * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0) ).sum() def forward(self, model, sample, reduce=True): # compute loss and accuracy - loss, nloss = 0., 0 + loss, nloss = 0.0, 0 ncorrect, nqueries = 0, 0 - for i, label in enumerate(sample['labels']): + for i, label in enumerate(sample["labels"]): query_lprobs = self.get_lprobs( model, - sample['query_tokens'][i].unsqueeze(0), - sample['query_masks'][i].unsqueeze(0), + sample["query_tokens"][i].unsqueeze(0), + sample["query_masks"][i].unsqueeze(0), ) cand_lprobs = self.get_lprobs( model, - sample['candidate_tokens'][i], - sample['candidate_masks'][i], + sample["candidate_tokens"][i], + sample["candidate_masks"][i], ) pred = (query_lprobs >= cand_lprobs).all().item() @@ -95,72 +96,72 @@ def forward(self, model, sample, reduce=True): nloss += 1 loss += self.get_loss(query_lprobs, cand_lprobs) - id = sample['id'][i].item() + id = sample["id"][i].item() if self.prediction_h is not None: - print('{}\t{}\t{}'.format(id, 
pred, label), file=self.prediction_h) + print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h) if nloss == 0: loss = torch.tensor(0.0, requires_grad=True) sample_size = nqueries if nqueries > 0 else 1 logging_output = { - 'loss': utils.item(loss.data) if reduce else loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample['nsentences'], - 'sample_size': sample_size, - 'ncorrect': ncorrect, - 'nqueries': nqueries, + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "ncorrect": ncorrect, + "nqueries": nqueries, } return loss, sample_size, logging_output @staticmethod def aggregate_logging_outputs(logging_outputs): """Aggregate logging outputs from data parallel training.""" - loss_sum = sum(log.get('loss', 0) for log in logging_outputs) - ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) - nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) - sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) agg_output = { - 'loss': loss_sum / sample_size / math.log(2), - 'ntokens': ntokens, - 'nsentences': nsentences, - 'sample_size': sample_size, + "loss": loss_sum / sample_size / math.log(2), + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, } - ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs) - nqueries = sum(log.get('nqueries', 0) for log in logging_outputs) + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + nqueries = sum(log.get("nqueries", 0) for log in logging_outputs) if nqueries > 0: - agg_output['accuracy'] = ncorrect / float(nqueries) + 
agg_output["accuracy"] = ncorrect / float(nqueries) return agg_output -@register_criterion('winogrande') +@register_criterion("winogrande") class WinograndeCriterion(WSCCriterion): def forward(self, model, sample, reduce=True): # compute loss and accuracy query_lprobs = self.get_lprobs( model, - sample['query_tokens'], - sample['query_masks'], + sample["query_tokens"], + sample["query_masks"], ) cand_lprobs = self.get_lprobs( model, - sample['candidate_tokens'], - sample['candidate_masks'], + sample["candidate_tokens"], + sample["candidate_masks"], ) pred = query_lprobs >= cand_lprobs loss = self.get_loss(query_lprobs, cand_lprobs) - sample_size = sample['query_tokens'].size(0) + sample_size = sample["query_tokens"].size(0) ncorrect = pred.sum().item() logging_output = { - 'loss': utils.item(loss.data) if reduce else loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample['nsentences'], - 'sample_size': sample_size, - 'ncorrect': ncorrect, - 'nqueries': sample_size, + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "ncorrect": ncorrect, + "nqueries": sample_size, } return loss, sample_size, logging_output diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py index 058e3eea23..602ea737ed 100644 --- a/examples/roberta/wsc/wsc_task.py +++ b/examples/roberta/wsc/wsc_task.py @@ -10,47 +10,51 @@ import numpy as np import torch import torch.nn.functional as F - from fairseq import utils from fairseq.data import ( - data_utils, Dictionary, - encoders, IdDataset, ListDataset, NestedDictionaryDataset, - NumSamplesDataset, NumelDataset, + NumSamplesDataset, PadDataset, SortDataset, + data_utils, + encoders, ) -from fairseq.tasks import register_task, LegacyFairseqTask +from fairseq.tasks import LegacyFairseqTask, register_task from . 
import wsc_utils -@register_task('wsc') +@register_task("wsc") class WSCTask(LegacyFairseqTask): """Task to finetune RoBERTa for Winograd Schemas.""" @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', metavar='DIR', - help='path to data directory; we load <split>.jsonl') - parser.add_argument('--init-token', type=int, default=None, - help='add token at the beginning of each batch item') + parser.add_argument( + "data", metavar="DIR", help="path to data directory; we load <split>.jsonl" + ) + parser.add_argument( + "--init-token", + type=int, + default=None, + help="add token at the beginning of each batch item", + ) def __init__(self, args, vocab): super().__init__(args) self.vocab = vocab - self.mask = vocab.add_symbol('<mask>') + self.mask = vocab.add_symbol("<mask>") self.bpe = encoders.build_bpe(args) self.tokenizer = encoders.build_tokenizer(args) # hack to handle GPT-2 BPE, which includes leading spaces - if args.bpe == 'gpt2': + if args.bpe == "gpt2": self.leading_space = True self.trailing_space = False else: @@ -65,16 +69,16 @@ def load_dictionary(cls, filename): filename (str): the filename """ dictionary = Dictionary.load(filename) - dictionary.add_symbol('<mask>') + dictionary.add_symbol("<mask>") return dictionary @classmethod def setup_task(cls, args, **kwargs): - assert args.criterion == 'wsc', 'Must set --criterion=wsc' + assert args.criterion == "wsc", "Must set --criterion=wsc" # load data and label dictionaries - vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt')) - print('| dictionary: {} types'.format(len(vocab))) + vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt")) + print("| dictionary: {} types".format(len(vocab))) return cls(args, vocab) @@ -84,7 +88,9 @@ def binarize(self, s: str, append_eos: bool = False): if self.bpe is not None: s = self.bpe.encode(s) tokens = self.vocab.encode_line( - s, append_eos=append_eos, add_if_not_exist=False, + s, + append_eos=append_eos, +
add_if_not_exist=False, ).long() if self.args.init_token is not None: tokens = torch.cat([tokens.new([self.args.init_token]), tokens]) @@ -98,19 +104,21 @@ def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space) mask = torch.zeros_like(toks, dtype=torch.bool) mask_start = len(self.binarize(prefix)) mask_size = len(self.binarize(leading_space + txt)) - mask[mask_start:mask_start + mask_size] = 1 + mask[mask_start : mask_start + mask_size] = 1 return toks, mask - def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): + def load_dataset( + self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs + ): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ if data_path is None: - data_path = os.path.join(self.args.data, split + '.jsonl') + data_path = os.path.join(self.args.data, split + ".jsonl") if not os.path.exists(data_path): - raise FileNotFoundError('Cannot find data: {}'.format(data_path)) + raise FileNotFoundError("Cannot find data: {}".format(data_path)) query_tokens = [] query_masks = [] @@ -121,13 +129,15 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl labels = [] for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path): - prefix = sentence[:pronoun_span.start].text - suffix = sentence[pronoun_span.end:].text_with_ws + prefix = sentence[: pronoun_span.start].text + suffix = sentence[pronoun_span.end :].text_with_ws # spaCy spans include trailing spaces, but we need to know about # leading spaces for the GPT-2 BPE - leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else '' - trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else '' + leading_space = ( + " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else "" + ) + trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else "" # get noun phrases, 
excluding pronouns and anything overlapping with the query cand_spans = wsc_utils.filter_noun_chunks( @@ -152,7 +162,11 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl cand_toks, cand_masks = [], [] for cand_span in cand_spans: toks, mask = self.binarize_with_mask( - cand_span.text, prefix, suffix, leading_space, trailing_space, + cand_span.text, + prefix, + suffix, + leading_space, + trailing_space, ) cand_toks.append(toks) cand_masks.append(mask) @@ -176,17 +190,17 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl candidate_tokens = ListDataset(candidate_tokens, candidate_lengths) candidate_masks = ListDataset(candidate_masks, candidate_lengths) - labels = ListDataset(labels, [1]*len(labels)) + labels = ListDataset(labels, [1] * len(labels)) dataset = { - 'id': IdDataset(), - 'query_tokens': query_tokens, - 'query_masks': query_masks, - 'candidate_tokens': candidate_tokens, - 'candidate_masks': candidate_masks, - 'labels': labels, - 'nsentences': NumSamplesDataset(), - 'ntokens': NumelDataset(query_tokens, reduce=True), + "id": IdDataset(), + "query_tokens": query_tokens, + "query_masks": query_masks, + "candidate_tokens": candidate_tokens, + "candidate_masks": candidate_masks, + "labels": labels, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(query_tokens, reduce=True), } nested_dataset = NestedDictionaryDataset( @@ -210,9 +224,9 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl def build_dataset_for_inference(self, sample_json): with tempfile.NamedTemporaryFile(buffering=0) as h: - h.write((json.dumps(sample_json) + '\n').encode('utf-8')) + h.write((json.dumps(sample_json) + "\n").encode("utf-8")) dataset = self.load_dataset( - 'disambiguate_pronoun', + "disambiguate_pronoun", data_path=h.name, return_only=True, ) @@ -239,19 +253,19 @@ def get_lprobs(tokens, mask): return scores cand_lprobs = get_lprobs( - sample['candidate_tokens'][0], - 
sample['candidate_masks'][0], + sample["candidate_tokens"][0], + sample["candidate_masks"][0], ) - if sample['query_tokens'][0] is not None: + if sample["query_tokens"][0] is not None: query_lprobs = get_lprobs( - sample['query_tokens'][0].unsqueeze(0), - sample['query_masks'][0].unsqueeze(0), + sample["query_tokens"][0].unsqueeze(0), + sample["query_masks"][0].unsqueeze(0), ) return (query_lprobs >= cand_lprobs).all().item() == 1 else: best_idx = cand_lprobs.argmax().item() - full_cand = sample['candidate_tokens'][0][best_idx] - mask = sample['candidate_masks'][0][best_idx] + full_cand = sample["candidate_tokens"][0][best_idx] + mask = sample["candidate_masks"][0][best_idx] toks = full_cand[mask.bool()] return self.bpe.decode(self.source_dictionary.string(toks)).strip() @@ -264,7 +278,7 @@ def target_dictionary(self): return self.vocab -@register_task('winogrande') +@register_task("winogrande") class WinograndeTask(WSCTask): """ Task for WinoGrande dataset. Efficient implementation for Winograd schema @@ -273,24 +287,26 @@ class WinograndeTask(WSCTask): @classmethod def setup_task(cls, args, **kwargs): - assert args.criterion == 'winogrande', 'Must set --criterion=winogrande' + assert args.criterion == "winogrande", "Must set --criterion=winogrande" # load data and label dictionaries - vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt')) - print('| dictionary: {} types'.format(len(vocab))) + vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt")) + print("| dictionary: {} types".format(len(vocab))) return cls(args, vocab) - def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): + def load_dataset( + self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs + ): """Load a given dataset split. 
Args: split (str): name of the split (e.g., train, valid, test) """ if data_path is None: - data_path = os.path.join(self.args.data, split + '.jsonl') + data_path = os.path.join(self.args.data, split + ".jsonl") if not os.path.exists(data_path): - raise FileNotFoundError('Cannot find data: {}'.format(data_path)) + raise FileNotFoundError("Cannot find data: {}".format(data_path)) query_tokens = [] query_masks = [] @@ -299,19 +315,23 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl candidate_masks = [] candidate_lengths = [] - itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test')) + itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test")) for sample in itr: sentence, pronoun_span, query, cand_text = sample - prefix = sentence[:pronoun_span[0]].rstrip() - suffix = sentence[pronoun_span[1]:] + prefix = sentence[: pronoun_span[0]].rstrip() + suffix = sentence[pronoun_span[1] :] - leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else '' - trailing_space = '' + leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else "" + trailing_space = "" if query is not None: query_toks, query_mask = self.binarize_with_mask( - query, prefix, suffix, leading_space, trailing_space, + query, + prefix, + suffix, + leading_space, + trailing_space, ) query_len = len(query_toks) else: @@ -322,7 +342,11 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl query_lengths.append(query_len) cand_toks, cand_mask = self.binarize_with_mask( - cand_text, prefix, suffix, leading_space, trailing_space, + cand_text, + prefix, + suffix, + leading_space, + trailing_space, ) candidate_tokens.append(cand_toks) @@ -342,17 +366,19 @@ def get_pad_dataset_fn(tokens, length, pad_idx): query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0) candidate_lengths = np.array(candidate_lengths) - candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, 
self.vocab.pad()) + candidate_tokens = get_pad_dataset_fn( + candidate_tokens, candidate_lengths, self.vocab.pad() + ) candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0) dataset = { - 'id': IdDataset(), - 'query_tokens': query_tokens, - 'query_masks': query_masks, - 'candidate_tokens': candidate_tokens, - 'candidate_masks': candidate_masks, - 'nsentences': NumSamplesDataset(), - 'ntokens': NumelDataset(query_tokens, reduce=True), + "id": IdDataset(), + "query_tokens": query_tokens, + "query_masks": query_masks, + "candidate_tokens": candidate_tokens, + "candidate_masks": candidate_masks, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(query_tokens, reduce=True), } nested_dataset = NestedDictionaryDataset( diff --git a/examples/roberta/wsc/wsc_utils.py b/examples/roberta/wsc/wsc_utils.py index 2d4822479e..da6ba74383 100644 --- a/examples/roberta/wsc/wsc_utils.py +++ b/examples/roberta/wsc/wsc_utils.py @@ -3,48 +3,48 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from functools import lru_cache import json +from functools import lru_cache def convert_sentence_to_json(sentence): - if '_' in sentence: - prefix, rest = sentence.split('_', 1) - query, rest = rest.split('_', 1) - query_index = len(prefix.rstrip().split(' ')) + if "_" in sentence: + prefix, rest = sentence.split("_", 1) + query, rest = rest.split("_", 1) + query_index = len(prefix.rstrip().split(" ")) else: query, query_index = None, None - prefix, rest = sentence.split('[', 1) - pronoun, rest = rest.split(']', 1) - pronoun_index = len(prefix.rstrip().split(' ')) + prefix, rest = sentence.split("[", 1) + pronoun, rest = rest.split("]", 1) + pronoun_index = len(prefix.rstrip().split(" ")) - sentence = sentence.replace('_', '').replace('[', '').replace(']', '') + sentence = sentence.replace("_", "").replace("[", "").replace("]", "") return { - 'idx': 0, - 'text': sentence, - 'target': { - 'span1_index': query_index, - 'span1_text': query, - 'span2_index': pronoun_index, - 'span2_text': pronoun, + "idx": 0, + "text": sentence, + "target": { + "span1_index": query_index, + "span1_text": query, + "span2_index": pronoun_index, + "span2_text": pronoun, }, } def extended_noun_chunks(sentence): noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks} - np_start, cur_np = 0, 'NONE' + np_start, cur_np = 0, "NONE" for i, token in enumerate(sentence): - np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE' + np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE" if np_type != cur_np: - if cur_np != 'NONE': + if cur_np != "NONE": noun_chunks.add((np_start, i)) - if np_type != 'NONE': + if np_type != "NONE": np_start = i cur_np = np_type - if cur_np != 'NONE': + if cur_np != "NONE": noun_chunks.add((np_start, len(sentence))) return [sentence[s:e] for (s, e) in sorted(noun_chunks)] @@ -61,14 +61,14 @@ def find_token(sentence, start_pos): def find_span(sentence, search_text, start=0): search_text = search_text.lower() for tok in 
sentence[start:]: - remainder = sentence[tok.i:].text.lower() + remainder = sentence[tok.i :].text.lower() if remainder.startswith(search_text): len_to_consume = len(search_text) start_idx = tok.idx - for next_tok in sentence[tok.i:]: + for next_tok in sentence[tok.i :]: end_idx = next_tok.idx + len(next_tok.text) if end_idx - start_idx == len_to_consume: - span = sentence[tok.i:next_tok.i + 1] + span = sentence[tok.i : next_tok.i + 1] return span return None @@ -76,13 +76,15 @@ def find_span(sentence, search_text, start=0): @lru_cache(maxsize=1) def get_detokenizer(): from sacremoses import MosesDetokenizer - detok = MosesDetokenizer(lang='en') + + detok = MosesDetokenizer(lang="en") return detok @lru_cache(maxsize=1) def get_spacy_nlp(): import en_core_web_lg + nlp = en_core_web_lg.load() return nlp @@ -95,45 +97,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False): for line in fin: sample = json.loads(line.strip()) - if positive_only and 'label' in sample and not sample['label']: + if positive_only and "label" in sample and not sample["label"]: # only consider examples where the query is correct continue - target = sample['target'] + target = sample["target"] # clean up the query - query = target['span1_text'] + query = target["span1_text"] if query is not None: - if '\n' in query: + if "\n" in query: continue - if query.endswith('.') or query.endswith(','): + if query.endswith(".") or query.endswith(","): query = query[:-1] # split tokens - tokens = sample['text'].split(' ') + tokens = sample["text"].split(" ") def strip_pronoun(x): return x.rstrip('.,"') # find the pronoun - pronoun_idx = target['span2_index'] - pronoun = strip_pronoun(target['span2_text']) + pronoun_idx = target["span2_index"] + pronoun = strip_pronoun(target["span2_text"]) if strip_pronoun(tokens[pronoun_idx]) != pronoun: # hack: sometimes the index is misaligned if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun: pronoun_idx += 1 else: - raise 
Exception('Misaligned pronoun!') + raise Exception("Misaligned pronoun!") assert strip_pronoun(tokens[pronoun_idx]) == pronoun # split tokens before and after the pronoun before = tokens[:pronoun_idx] - after = tokens[pronoun_idx + 1:] + after = tokens[pronoun_idx + 1 :] # the GPT BPE attaches leading spaces to tokens, so we keep track # of whether we need spaces before or after the pronoun - leading_space = ' ' if pronoun_idx > 0 else '' - trailing_space = ' ' if len(after) > 0 else '' + leading_space = " " if pronoun_idx > 0 else "" + trailing_space = " " if len(after) > 0 else "" # detokenize before = detok.detokenize(before, return_str=True) @@ -142,14 +144,14 @@ def strip_pronoun(x): # hack: when the pronoun ends in a period (or comma), move the # punctuation to the "after" part - if pronoun.endswith('.') or pronoun.endswith(','): + if pronoun.endswith(".") or pronoun.endswith(","): after = pronoun[-1] + trailing_space + after pronoun = pronoun[:-1] # hack: when the "after" part begins with a comma or period, remove # the trailing space - if after.startswith('.') or after.startswith(','): - trailing_space = '' + if after.startswith(".") or after.startswith(","): + trailing_space = "" # parse sentence with spacy sentence = nlp(before + leading_space + pronoun + trailing_space + after) @@ -164,13 +166,13 @@ def strip_pronoun(x): # convert to format where pronoun is surrounded by "[]" and # query is surrounded by "_" query_span = find_span(sentence, query) - query_with_ws = '_{}_{}'.format( + query_with_ws = "_{}_{}".format( query_span.text, - (' ' if query_span.text_with_ws.endswith(' ') else '') + (" " if query_span.text_with_ws.endswith(" ") else ""), ) - pronoun_with_ws = '[{}]{}'.format( + pronoun_with_ws = "[{}]{}".format( pronoun_span.text, - (' ' if pronoun_span.text_with_ws.endswith(' ') else '') + (" " if pronoun_span.text_with_ws.endswith(" ") else ""), ) if query_span.start < pronoun_span.start: first = (query_span, query_with_ws) @@ -179,41 +181,45 
@@ def strip_pronoun(x): first = (pronoun_span, pronoun_with_ws) second = (query_span, query_with_ws) sentence = ( - sentence[:first[0].start].text_with_ws + sentence[: first[0].start].text_with_ws + first[1] - + sentence[first[0].end:second[0].start].text_with_ws + + sentence[first[0].end : second[0].start].text_with_ws + second[1] - + sentence[second[0].end:].text + + sentence[second[0].end :].text ) - yield sentence, sample.get('label', None) + yield sentence, sample.get("label", None) else: - yield sentence, pronoun_span, query, sample.get('label', None) + yield sentence, pronoun_span, query, sample.get("label", None) def winogrande_jsonl_iterator(input_fname, eval=False): with open(input_fname) as fin: for line in fin: sample = json.loads(line.strip()) - sentence, option1, option2 = sample['sentence'], sample['option1'],\ - sample['option2'] + sentence, option1, option2 = ( + sample["sentence"], + sample["option1"], + sample["option2"], + ) - pronoun_span = (sentence.index('_'), sentence.index('_') + 1) + pronoun_span = (sentence.index("_"), sentence.index("_") + 1) if eval: query, cand = option1, option2 else: - query = option1 if sample['answer'] == '1' else option2 - cand = option2 if sample['answer'] == '1' else option1 + query = option1 if sample["answer"] == "1" else option2 + cand = option2 if sample["answer"] == "1" else option1 yield sentence, pronoun_span, query, cand -def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False): +def filter_noun_chunks( + chunks, exclude_pronouns=False, exclude_query=None, exact_match=False +): if exclude_pronouns: chunks = [ - np for np in chunks if ( - np.lemma_ != '-PRON-' - and not all(tok.pos_ == 'PRON' for tok in np) - ) + np + for np in chunks + if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np)) ] if exclude_query is not None: @@ -224,9 +230,8 @@ def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact found = False for excl in 
excl_txt: if ( - (not exact_match and (lower_chunk in excl or excl in lower_chunk)) - or lower_chunk == excl - ): + not exact_match and (lower_chunk in excl or excl in lower_chunk) + ) or lower_chunk == excl: found = True break if not found: diff --git a/examples/simultaneous_translation/__init__.py b/examples/simultaneous_translation/__init__.py index e6963d6d1b..446fc86c8a 100644 --- a/examples/simultaneous_translation/__init__.py +++ b/examples/simultaneous_translation/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import criterions, models, eval # noqa +from . import criterions, eval, models # noqa diff --git a/examples/simultaneous_translation/criterions/__init__.py b/examples/simultaneous_translation/criterions/__init__.py index 84dc80ad95..08791bfff3 100644 --- a/examples/simultaneous_translation/criterions/__init__.py +++ b/examples/simultaneous_translation/criterions/__init__.py @@ -6,6 +6,7 @@ import importlib import os + for file in os.listdir(os.path.dirname(__file__)): if file.endswith(".py") and not file.startswith("_"): criterion_name = file[: file.find(".py")] diff --git a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py index d4d544ec5f..b3c8f6d53f 100644 --- a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -3,21 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from examples.simultaneous_translation.utils.latency import LatencyTraining from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import ( - LabelSmoothedCrossEntropyCriterion -) - -from examples.simultaneous_translation.utils.latency import ( - LatencyTraining + LabelSmoothedCrossEntropyCriterion, ) -@register_criterion('latency_augmented_label_smoothed_cross_entropy') +@register_criterion("latency_augmented_label_smoothed_cross_entropy") class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( LabelSmoothedCrossEntropyCriterion ): - def __init__(self, args, task): super().__init__(args, task) self.eps = args.label_smoothing @@ -40,7 +36,7 @@ def __init__(self, args, task): def add_args(parser): super( LatencyAugmentedLabelSmoothedCrossEntropyCriterion, - LatencyAugmentedLabelSmoothedCrossEntropyCriterion + LatencyAugmentedLabelSmoothedCrossEntropyCriterion, ).add_args(parser) """Add criterion-specific arguments to the parser.""" # fmt: off @@ -69,7 +65,8 @@ def compute_loss(self, model, net_output, sample, reduce=True): # Get latency loss latency_loss = self.latency_train.loss( - attn_list, source_padding_mask, target_padding_mask) + attn_list, source_padding_mask, target_padding_mask + ) loss += latency_loss diff --git a/examples/simultaneous_translation/eval/agents/__init__.py b/examples/simultaneous_translation/eval/agents/__init__.py index 1c23fc1ad9..511e7b2474 100644 --- a/examples/simultaneous_translation/eval/agents/__init__.py +++ b/examples/simultaneous_translation/eval/agents/__init__.py @@ -5,16 +5,20 @@ import importlib import os + from fairseq import registry -build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry('--agent-type') + +build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry( + "--agent-type" +) -DEFAULT_EOS = '</s>' +DEFAULT_EOS = "</s>" GET = 0 SEND = 1 for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not
file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('agents.' + module) + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("agents." + module) diff --git a/examples/simultaneous_translation/eval/agents/agent.py b/examples/simultaneous_translation/eval/agents/agent.py index 1977a24dd9..997392cf9b 100644 --- a/examples/simultaneous_translation/eval/agents/agent.py +++ b/examples/simultaneous_translation/eval/agents/agent.py @@ -3,14 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import GET, SEND, DEFAULT_EOS import time -from multiprocessing.pool import ThreadPool as Pool from functools import partial +from multiprocessing.pool import ThreadPool as Pool + +from . import DEFAULT_EOS, GET, SEND class Agent(object): "an agent needs to follow this pattern" + def __init__(self, *args, **kwargs): pass @@ -40,26 +42,26 @@ def decode(self, session, low=0, high=100000, num_thread=10): with Pool(10) as p: p.map( partial(self._decode_one, session), - [sent_id for sent_id in range(low, high + 1)] + [sent_id for sent_id in range(low, high + 1)], ) else: for sent_id in range(low, high + 1): self._decode_one(session, sent_id) - print(f'Finished {low} to {high} in {time.time() - t0}s') + print(f"Finished {low} to {high} in {time.time() - t0}s") def _decode_one(self, session, sent_id): action = {} self.reset() states = self.init_states() - while action.get('value', None) != DEFAULT_EOS: + while action.get("value", None) != DEFAULT_EOS: # take an action action = self.policy(states) - if action['key'] == GET: + if action["key"] == GET: new_states = session.get_src(sent_id, action["value"]) states = self.update_states(states, new_states) - elif action['key'] == SEND: - session.send_hypo(sent_id, action['value']) + elif action["key"] == SEND: + session.send_hypo(sent_id, action["value"]) print(" 
".join(states["tokens"]["tgt"])) diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py b/examples/simultaneous_translation/eval/agents/simul_trans_agent.py index 1b6960c5fa..071b9e89ce 100644 --- a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py +++ b/examples/simultaneous_translation/eval/agents/simul_trans_agent.py @@ -3,11 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . agent import Agent -from . import DEFAULT_EOS, GET, SEND -from fairseq import checkpoint_utils, utils, tasks -import os import json +import os + +from fairseq import checkpoint_utils, tasks, utils + +from . import DEFAULT_EOS, GET, SEND +from .agent import Agent class SimulTransAgent(Agent): @@ -51,13 +53,15 @@ def load_dictionary(self, task): raise NotImplementedError def load_model(self, args): - args.user_dir = os.path.join(os.path.dirname(__file__), '..', '..') + args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..") utils.import_user_module(args) filename = args.model_path if not os.path.exists(filename): raise IOError("Model file not found: {}".format(filename)) - state = checkpoint_utils.load_checkpoint_to_cpu(filename, json.loads(args.model_overrides)) + state = checkpoint_utils.load_checkpoint_to_cpu( + filename, json.loads(args.model_overrides) + ) saved_args = state["args"] saved_args.data = args.data_bin @@ -79,7 +83,7 @@ def init_states(self): "steps": {"src": 0, "tgt": 0}, "finished": False, "finish_read": False, - "model_states": {} + "model_states": {}, } def update_states(self, states, new_state): @@ -115,38 +119,38 @@ def finish_read(self, states): def write_action(self, states): token, index = self.model.predict_from_states(states) - if index == self.dict["tgt"].eos() or len(states["tokens"]["tgt"]) > self.max_len: + if ( + index == self.dict["tgt"].eos() + or len(states["tokens"]["tgt"]) > self.max_len + ): # Finish this 
sentence is predict EOS states["finished"] = True end_idx_last_full_word = self._target_length(states) else: states["tokens"]["tgt"] += [token] - end_idx_last_full_word = ( - self.word_splitter["tgt"] - .end_idx_last_full_word(states["tokens"]["tgt"]) + end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( + states["tokens"]["tgt"] ) self._append_indices(states, [index], "tgt") if end_idx_last_full_word > states["steps"]["tgt"]: # Only sent detokenized full words to the server word = self.word_splitter["tgt"].merge( - states["tokens"]["tgt"][ - states["steps"]["tgt"]: end_idx_last_full_word - ] + states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] ) states["steps"]["tgt"] = end_idx_last_full_word states["segments"]["tgt"] += [word] - return {'key': SEND, 'value': word} + return {"key": SEND, "value": word} else: return None def read_action(self, states): - return {'key': GET, 'value': None} + return {"key": GET, "value": None} def finish_action(self): - return {'key': SEND, 'value': DEFAULT_EOS} + return {"key": SEND, "value": DEFAULT_EOS} def reset(self): pass @@ -160,4 +164,4 @@ def _append_indices(self, states, new_indices, key): states["indices"][key] += new_indices def _target_length(self, states): - return len(states["tokens"]['tgt']) + return len(states["tokens"]["tgt"]) diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py b/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py index 65f7cbd313..7c34817bf6 100644 --- a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py +++ b/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py @@ -3,10 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . simul_trans_agent import SimulTransAgent -from . import DEFAULT_EOS, GET -from . import register_agent -from . word_splitter import SPLITTER_DICT +from . 
import DEFAULT_EOS, GET, register_agent +from .simul_trans_agent import SimulTransAgent +from .word_splitter import SPLITTER_DICT @register_agent("simul_trans_text") @@ -15,11 +14,11 @@ def build_word_splitter(self, args): self.word_splitter = {} self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type]( - getattr(args, f"src_splitter_path") - ) + getattr(args, f"src_splitter_path") + ) self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type]( - getattr(args, f"tgt_splitter_path") - ) + getattr(args, f"tgt_splitter_path") + ) def load_dictionary(self, task): self.dict = {} @@ -37,12 +36,16 @@ def update_states(self, states, new_state): tokens = self.word_splitter["src"].split(new_word) # Get indices from dictionary # You can change to you own dictionary - indices = self.dict["src"].encode_line( - tokens, - line_tokenizer=lambda x: x, - add_if_not_exist=False, - append_eos=False - ).tolist() + indices = ( + self.dict["src"] + .encode_line( + tokens, + line_tokenizer=lambda x: x, + add_if_not_exist=False, + append_eos=False, + ) + .tolist() + ) else: tokens = [new_word] indices = [self.dict["src"].eos()] @@ -61,11 +64,11 @@ def read_action(self, states): # At leat one word is read if len(states["tokens"]["src"]) == 0: - return {'key': GET, 'value': None} + return {"key": GET, "value": None} # Only request new word if there is no buffered tokens if len(states["tokens"]["src"]) <= states["steps"]["src"]: - return {'key': GET, 'value': None} + return {"key": GET, "value": None} return None diff --git a/examples/simultaneous_translation/eval/agents/word_splitter.py b/examples/simultaneous_translation/eval/agents/word_splitter.py index ea564f21ee..c3f71200a5 100644 --- a/examples/simultaneous_translation/eval/agents/word_splitter.py +++ b/examples/simultaneous_translation/eval/agents/word_splitter.py @@ -40,6 +40,7 @@ class BPEWordSplitter(object): def __init__(self, model_path): super().__init__() from subword_nmt.apply_bpe import BPE + with 
open(model_path) as f: self.model = BPE(f) @@ -48,7 +49,7 @@ def split(self, string): def end_idx_last_full_word(self, tokens): # Begin of word indices - bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != '@@'] + bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"] if len(bow_indices) < 2: return 0 @@ -63,6 +64,7 @@ class SentencePieceModelWordSplitter(object): def __init__(self, model_path): super().__init__() import sentencepiece as spm + self.model = spm.SentencePieceProcessor() self.model.Load(model_path) @@ -71,7 +73,7 @@ def split(self, string): def end_idx_last_full_word(self, tokens): # Begin of word indices - bow_indices = [i for i, t in enumerate(tokens) if t[0] == '\u2581'] + bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"] if len(bow_indices) < 2: return 0 diff --git a/examples/simultaneous_translation/eval/client.py b/examples/simultaneous_translation/eval/client.py index 5cbaa71d31..3ca4ea73b8 100644 --- a/examples/simultaneous_translation/eval/client.py +++ b/examples/simultaneous_translation/eval/client.py @@ -3,19 +3,20 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import requests from typing import Optional + +import requests from scorers import build_scorer class SimulSTEvaluationService(object): - DEFAULT_HOSTNAME = 'localhost' + DEFAULT_HOSTNAME = "localhost" DEFAULT_PORT = 12321 def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT): self.hostname = hostname self.port = port - self.base_url = f'http://{self.hostname}:{self.port}' + self.base_url = f"http://{self.hostname}:{self.port}" def __enter__(self): self.new_session() @@ -25,56 +26,53 @@ def __exit__(self, exc_type, exc_val, exc_tb): def new_session(self): # start eval session - url = f'{self.base_url}' + url = f"{self.base_url}" try: _ = requests.post(url) except Exception as e: - print(f'Failed to start an evaluation session: {e}') + print(f"Failed to start an evaluation session: {e}") - print('Evaluation session started.') + print("Evaluation session started.") return self def get_scores(self): # end eval session - url = f'{self.base_url}/result' + url = f"{self.base_url}/result" try: r = requests.get(url) - print('Scores: {}'.format(r.json())) - print('Evaluation session finished.') + print("Scores: {}".format(r.json())) + print("Evaluation session finished.") except Exception as e: - print(f'Failed to end an evaluation session: {e}') + print(f"Failed to end an evaluation session: {e}") def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str: - url = f'{self.base_url}/src' + url = f"{self.base_url}/src" params = {"sent_id": sent_id} if extra_params is not None: for key in extra_params.keys(): params[key] = extra_params[key] try: - r = requests.get( - url, - params=params - ) + r = requests.get(url, params=params) except Exception as e: - print(f'Failed to request a source segment: {e}') + print(f"Failed to request a source segment: {e}") return r.json() def send_hypo(self, sent_id: int, hypo: str) -> None: - url = f'{self.base_url}/hypo' + url = f"{self.base_url}/hypo" params = {"sent_id": sent_id} try: requests.put(url, 
params=params, data=hypo.encode("utf-8")) except Exception as e: - print(f'Failed to send a translated segment: {e}') + print(f"Failed to send a translated segment: {e}") def corpus_info(self): - url = f'{self.base_url}' + url = f"{self.base_url}" try: r = requests.get(url) except Exception as e: - print(f'Failed to request corpus information: {e}') + print(f"Failed to request corpus information: {e}") return r.json() diff --git a/examples/simultaneous_translation/eval/eval_latency.py b/examples/simultaneous_translation/eval/eval_latency.py index 12cfaa4ed1..50021de47c 100644 --- a/examples/simultaneous_translation/eval/eval_latency.py +++ b/examples/simultaneous_translation/eval/eval_latency.py @@ -3,20 +3,21 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from examples.simultaneous_translation.utils.latency import LatencyInference import argparse -import torch import json +import torch +from examples.simultaneous_translation.utils.latency import LatencyInference + LATENCY_METRICS = [ - 'differentiable_average_lagging', - 'average_lagging', - 'average_proportion', + "differentiable_average_lagging", + "average_lagging", + "average_proportion", ] -class LatencyScorer(): +class LatencyScorer: def __init__(self, start_from_zero=True): self.recorder = [] self.scores = {} @@ -26,10 +27,7 @@ def __init__(self, start_from_zero=True): def update_reorder(self, list_of_dict): self.recorder = [] for info in list_of_dict: - delays = [ - int(x) - int(not self.start_from_zero) - for x in info["delays"] - ] + delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]] delays = torch.LongTensor(delays).unsqueeze(0) src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0) @@ -59,7 +57,7 @@ def score(cls, list_of_dict, start_from_zero=True): scorer = LatencyInference() recorder = [] - with open(args.input, 'r') as f: + with open(args.input, "r") as f: for line in f: info = 
json.loads(line) @@ -74,7 +72,7 @@ def score(cls, list_of_dict, start_from_zero=True): average_results = {} for metric in LATENCY_METRICS: - average_results[metric] = sum( - [x[metric][0, 0].item() for x in recorder] - ) / len(recorder) + average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len( + recorder + ) print(f"{metric}: {average_results[metric]}") diff --git a/examples/simultaneous_translation/eval/evaluate.py b/examples/simultaneous_translation/eval/evaluate.py index 07f93e7fb0..2f7474621a 100644 --- a/examples/simultaneous_translation/eval/evaluate.py +++ b/examples/simultaneous_translation/eval/evaluate.py @@ -5,37 +5,48 @@ import argparse +from agents import build_agent from client import SimulSTEvaluationService, SimulSTLocalEvaluationService from fairseq.registry import REGISTRIES -from agents import build_agent -DEFAULT_HOSTNAME = 'localhost' + +DEFAULT_HOSTNAME = "localhost" DEFAULT_PORT = 12321 def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME, - help='server hostname') - parser.add_argument('--port', type=int, default=DEFAULT_PORT, - help='server port number') - parser.add_argument('--agent-type', default='simul_trans_text', - help='Agent type') - parser.add_argument('--scorer-type', default='text', - help='Scorer type') - parser.add_argument('--start-idx', type=int, default=0, - help='Start index of the sentence to evaluate') - parser.add_argument('--end-idx', type=int, default=float('inf'), - help='End index of the sentence to evaluate') - parser.add_argument('--scores', action="store_true", - help='Request scores from server') - parser.add_argument('--reset-server', action="store_true", - help='Reset the server') - parser.add_argument('--num-threads', type=int, default=10, - help='Number of threads used by agent') - parser.add_argument('--local', action="store_true", default=False, - help='Local evaluation') + parser.add_argument( + "--hostname", 
type=str, default=DEFAULT_HOSTNAME, help="server hostname" + ) + parser.add_argument( + "--port", type=int, default=DEFAULT_PORT, help="server port number" + ) + parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type") + parser.add_argument("--scorer-type", default="text", help="Scorer type") + parser.add_argument( + "--start-idx", + type=int, + default=0, + help="Start index of the sentence to evaluate", + ) + parser.add_argument( + "--end-idx", + type=int, + default=float("inf"), + help="End index of the sentence to evaluate", + ) + parser.add_argument( + "--scores", action="store_true", help="Request scores from server" + ) + parser.add_argument("--reset-server", action="store_true", help="Reset the server") + parser.add_argument( + "--num-threads", type=int, default=10, help="Number of threads used by agent" + ) + parser.add_argument( + "--local", action="store_true", default=False, help="Local evaluation" + ) args, _ = parser.parse_known_args() diff --git a/examples/simultaneous_translation/eval/scorers/__init__.py b/examples/simultaneous_translation/eval/scorers/__init__.py index c7fbb5495d..0a0e0a0518 100644 --- a/examples/simultaneous_translation/eval/scorers/__init__.py +++ b/examples/simultaneous_translation/eval/scorers/__init__.py @@ -5,15 +5,15 @@ import importlib import os + from fairseq import registry -( - build_scorer, - register_scorer, - SCORER_REGISTRIES, - _ -) = registry.setup_registry('--scorer-type') + + +(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry( + "--scorer-type" +) for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('scorers.' + module) + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("scorers." 
+ module) diff --git a/examples/simultaneous_translation/eval/scorers/scorer.py b/examples/simultaneous_translation/eval/scorers/scorer.py index d16f130e75..d6d3e30aef 100644 --- a/examples/simultaneous_translation/eval/scorers/scorer.py +++ b/examples/simultaneous_translation/eval/scorers/scorer.py @@ -3,16 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from vizseq.scorers.bleu import BLEUScorer -from vizseq.scorers.ter import TERScorer -from vizseq.scorers.meteor import METEORScorer -from examples.simultaneous_translation.eval.eval_latency import LatencyScorer -from collections import defaultdict import json import os +from collections import defaultdict + +from examples.simultaneous_translation.eval.eval_latency import LatencyScorer +from vizseq.scorers.bleu import BLEUScorer +from vizseq.scorers.meteor import METEORScorer +from vizseq.scorers.ter import TERScorer -DEFAULT_EOS = '</s>' +DEFAULT_EOS = "</s>" class SimulScorer(object): @@ -23,7 +24,7 @@ def __init__(self, args): self.output_files = { "text": os.path.join(args.output, "text"), "delay": os.path.join(args.output, "delay"), - "scores": os.path.join(args.output, "scores") + "scores": os.path.join(args.output, "scores"), } else: self.output_files = None @@ -52,14 +53,7 @@ def send_src(self, sent_id, *args): def recv_hyp(self, sent_id, list_of_tokens): for token in list_of_tokens: - self.translations[ - sent_id - ].append( - ( - token, - self.steps[sent_id] - ) - ) + self.translations[sent_id].append((token, self.steps[sent_id])) def reset(self): self.steps = defaultdict(int) @@ -76,8 +70,9 @@ def score(self): delays += [[t[1] for t in self.translations[i]]] bleu_score = BLEUScorer( - sent_level=False, corpus_level=True, - extra_args={'bleu_tokenizer': self.tokenizer} + sent_level=False, + corpus_level=True, + extra_args={"bleu_tokenizer": self.tokenizer}, ).score(translations, [self.data["tgt"]]) ter_score =
TERScorer(sent_level=False, corpus_level=True).score( @@ -92,16 +87,16 @@ def score(self): {"src_len": src_len, "delays": delay} for src_len, delay in zip(self.src_lengths(), delays) ], - start_from_zero=False + start_from_zero=False, ) scores = { - 'BLEU': bleu_score[0], - 'TER': ter_score[0], - 'METEOR': meteor_score[0], - 'DAL': latency_score['differentiable_average_lagging'], - 'AL': latency_score['average_lagging'], - 'AP': latency_score['average_proportion'], + "BLEU": bleu_score[0], + "TER": ter_score[0], + "METEOR": meteor_score[0], + "DAL": latency_score["differentiable_average_lagging"], + "AL": latency_score["average_lagging"], + "AP": latency_score["average_proportion"], } if self.output_files is not None: @@ -109,9 +104,9 @@ def score(self): os.makedirs(self.output_dir, exist_ok=True) self.write_results_to_file(translations, delays, scores) except BaseException as be: - print(f'Failed to write results to {self.output_dir}.') + print(f"Failed to write results to {self.output_dir}.") print(be) - print('Skip writing predictions') + print("Skip writing predictions") return scores @@ -125,12 +120,8 @@ def write_results_to_file(self, translations, delays, scores): with open(self.output_files["delay"], "w") as f: for i, delay in enumerate(delays): f.write( - json.dumps( - { - "src_len": self.src_lengths()[i], - "delays": delay - } - ) + "\n" + json.dumps({"src_len": self.src_lengths()[i], "delays": delay}) + + "\n" ) with open(self.output_files["scores"], "w") as f: @@ -163,7 +154,7 @@ def _load_wav_info_from_json(cls, file): list_to_return.append( { "path": item["input"]["path"].strip(), - "length": item["input"]["length_ms"] + "length": item["input"]["length_ms"], } ) return list_to_return diff --git a/examples/simultaneous_translation/eval/scorers/text_scorer.py b/examples/simultaneous_translation/eval/scorers/text_scorer.py index 4a5daaff21..649a2c7e5c 100644 --- a/examples/simultaneous_translation/eval/scorers/text_scorer.py +++ 
b/examples/simultaneous_translation/eval/scorers/text_scorer.py @@ -3,8 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . scorer import SimulScorer from . import register_scorer +from .scorer import SimulScorer @register_scorer("text") @@ -13,7 +13,7 @@ def __init__(self, args): super().__init__(args) self.data = { "src": self._load_text_file(args.src_file, split=True), - "tgt": self._load_text_file(args.tgt_file, split=False) + "tgt": self._load_text_file(args.tgt_file, split=False), } def send_src(self, sent_id, *args): @@ -21,7 +21,7 @@ def send_src(self, sent_id, *args): dict_to_return = { "sent_id": sent_id, "segment_id": self.steps[sent_id], - "segment": self.eos + "segment": self.eos, } # Consider EOS self.steps[sent_id] = len(self.data["src"][sent_id]) + 1 @@ -29,7 +29,7 @@ def send_src(self, sent_id, *args): dict_to_return = { "sent_id": sent_id, "segment_id": self.steps[sent_id], - "segment": self.data["src"][sent_id][self.steps[sent_id]] + "segment": self.data["src"][sent_id][self.steps[sent_id]], } self.steps[sent_id] += 1 diff --git a/examples/simultaneous_translation/eval/server.py b/examples/simultaneous_translation/eval/server.py index a108881e38..e44ceaff85 100644 --- a/examples/simultaneous_translation/eval/server.py +++ b/examples/simultaneous_translation/eval/server.py @@ -3,12 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import argparse -import sys import json -from tornado import web, ioloop +import sys + from scorers import build_scorer +from tornado import ioloop, web + -DEFAULT_HOSTNAME = 'localhost' +DEFAULT_HOSTNAME = "localhost" DEFAULT_PORT = 12321 @@ -34,10 +36,10 @@ def get(self): class SourceHandler(ScorerHandler): def get(self): - sent_id = int(self.get_argument('sent_id')) + sent_id = int(self.get_argument("sent_id")) segment_size = None if "segment_size" in self.request.arguments: - string = self.get_argument('segment_size') + string = self.get_argument("segment_size") if len(string) > 0: segment_size = int(string) @@ -48,8 +50,8 @@ def get(self): class HypothesisHandler(ScorerHandler): def put(self): - sent_id = int(self.get_argument('sent_id')) - list_of_tokens = self.request.body.decode('utf-8').strip().split() + sent_id = int(self.get_argument("sent_id")) + list_of_tokens = self.request.body.decode("utf-8").strip().split() self.scorer.recv_hyp(sent_id, list_of_tokens) @@ -67,18 +69,21 @@ def add_args(): def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False): - app = web.Application([ - (r'/result', ResultHandler, dict(scorer=scorer)), - (r'/src', SourceHandler, dict(scorer=scorer)), - (r'/hypo', HypothesisHandler, dict(scorer=scorer)), - (r'/', EvalSessionHandler, dict(scorer=scorer)), - ], debug=debug) + app = web.Application( + [ + (r"/result", ResultHandler, dict(scorer=scorer)), + (r"/src", SourceHandler, dict(scorer=scorer)), + (r"/hypo", HypothesisHandler, dict(scorer=scorer)), + (r"/", EvalSessionHandler, dict(scorer=scorer)), + ], + debug=debug, + ) app.listen(port, max_buffer_size=1024 ** 3) sys.stdout.write(f"Evaluation Server Started. 
Listening to port {port}\n") ioloop.IOLoop.current().start() -if __name__ == '__main__': +if __name__ == "__main__": args = add_args() scorer = build_scorer(args) start_server(scorer, args.hostname, args.port, args.debug) diff --git a/examples/simultaneous_translation/models/__init__.py b/examples/simultaneous_translation/models/__init__.py index 138006ed8c..083da43732 100644 --- a/examples/simultaneous_translation/models/__init__.py +++ b/examples/simultaneous_translation/models/__init__.py @@ -6,7 +6,10 @@ import importlib import os + for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - model_name = file[:file.find('.py')] - importlib.import_module('examples.simultaneous_translation.models.' + model_name) + if file.endswith(".py") and not file.startswith("_"): + model_name = file[: file.find(".py")] + importlib.import_module( + "examples.simultaneous_translation.models." + model_name + ) diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index 759f195386..ab8adf3aab 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -6,42 +6,34 @@ import torch import torch.nn as nn import torch.nn.functional as F - -from fairseq.models import ( - register_model, - register_model_architecture, +from examples.simultaneous_translation.modules.monotonic_transformer_layer import ( + TransformerMonotonicDecoderLayer, + TransformerMonotonicEncoderLayer, ) - - +from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer import ( - TransformerModel, - TransformerEncoder, TransformerDecoder, + TransformerEncoder, + TransformerModel, base_architecture, transformer_iwslt_de_en, transformer_vaswani_wmt_en_de_big, ) -from 
examples.simultaneous_translation.modules.monotonic_transformer_layer import ( - TransformerMonotonicDecoderLayer, - TransformerMonotonicEncoderLayer -) - DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 -@register_model('transformer_unidirectional') +@register_model("transformer_unidirectional") class TransformerUnidirectionalModel(TransformerModel): @classmethod def build_encoder(cls, args, src_dict, embed_tokens): return TransformerMonotonicEncoder(args, src_dict, embed_tokens) -@register_model('transformer_monotonic') +@register_model("transformer_monotonic") class TransformerMonotonicModel(TransformerModel): - @classmethod def build_encoder(cls, args, src_dict, embed_tokens): return TransformerMonotonicEncoder(args, src_dict, embed_tokens) @@ -62,26 +54,17 @@ def _indices_from_states(self, states): ) tgt_indices = tensor( - [ - [self.decoder.dictionary.eos()] - + states["indices"]["tgt"] - ] + [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]] ) else: - src_indices = states["indices"]["src"][: 1 + - states["steps"]["src"]] + src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]] tgt_indices = states["indices"]["tgt"] return src_indices, None, tgt_indices def predict_from_states(self, states): - decoder_states = self.decoder.output_layer( - states["decoder_features"] - ) - lprobs = self.get_normalized_probs( - [decoder_states[:, -1:]], - log_probs=True - ) + decoder_states = self.decoder.output_layer(states["decoder_features"]) + lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True) index = lprobs.argmax(dim=-1) @@ -90,25 +73,24 @@ def predict_from_states(self, states): return token, index[0, 0].item() def decision_from_states(self, states): - ''' + """ This funcion take states dictionary as input, and gives the agent a decision of whether read a token from server. 
Moreover, the decoder states are also calculated here so we can directly generate a target token without recompute every thing - ''' + """ self.eval() if len(states["tokens"]["src"]) == 0: return 0 - src_indices, src_lengths, tgt_indices = self._indices_from_states( - states) + src_indices, src_lengths, tgt_indices = self._indices_from_states(states) # Update encoder states if needed if ( - "encoder_states" not in states or - states["encoder_states"][0].size(1) <= states["steps"]["src"] + "encoder_states" not in states + or states["encoder_states"][0].size(1) <= states["steps"]["src"] ): encoder_out_dict = self.encoder(src_indices, src_lengths) states["encoder_states"] = encoder_out_dict @@ -136,16 +118,14 @@ def decision_from_states(self, states): class TransformerMonotonicEncoder(TransformerEncoder): - def __init__(self, args, dictionary, embed_tokens): super().__init__(args, dictionary, embed_tokens) self.dictionary = dictionary self.layers = nn.ModuleList([]) - self.layers.extend([ - TransformerMonotonicEncoderLayer(args) - for i in range(args.encoder_layers) - ]) + self.layers.extend( + [TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)] + ) class TransformerMonotonicDecoder(TransformerDecoder): @@ -166,19 +146,24 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.dictionary = dictionary self.layers = nn.ModuleList([]) - self.layers.extend([ - TransformerMonotonicDecoderLayer(args, no_encoder_attn) - for _ in range(args.decoder_layers) - ]) + self.layers.extend( + [ + TransformerMonotonicDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) def pre_attention( - self, prev_output_tokens, encoder_out_dict, - incremental_state=None + self, prev_output_tokens, encoder_out_dict, incremental_state=None ): - positions = self.embed_positions( - prev_output_tokens, - incremental_state=incremental_state, - ) if self.embed_positions is not None else None + positions = ( + 
self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) + if self.embed_positions is not None + else None + ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] @@ -216,8 +201,7 @@ def post_attention(self, x): return x def extract_features( - self, prev_output_tokens, encoder_out, - incremental_state=None, **unused + self, prev_output_tokens, encoder_out, incremental_state=None, **unused ): """ Similar to *forward* but only return features. @@ -228,14 +212,8 @@ def extract_features( - a dictionary with any model-specific outputs """ # incremental_state = None - ( - x, - encoder_outs, - encoder_padding_mask - ) = self.pre_attention( - prev_output_tokens, - encoder_out, - incremental_state + (x, encoder_outs, encoder_padding_mask) = self.pre_attention( + prev_output_tokens, encoder_out, incremental_state ) attn = None inner_states = [x] @@ -250,7 +228,8 @@ def extract_features( encoder_padding_mask=encoder_padding_mask, incremental_state=incremental_state, self_attn_mask=self.buffered_future_mask(x) - if incremental_state is None else None, + if incremental_state is None + else None, ) inner_states.append(x) @@ -261,38 +240,30 @@ def extract_features( step_list.append(curr_steps) if incremental_state.get("online", False): - p_choose = attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t()) - - new_steps = ( - curr_steps - + (p_choose < 0.5).t().type_as(curr_steps) + p_choose = ( + attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t()) ) + new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps) + if (new_steps >= incremental_state["steps"]["src"]).any(): # We need to prune the last self_attn saved_state # if model decide not to read # otherwise there will be duplicated saved_state for j in range(i + 1): - self.layers[j].prune_incremental_state( - incremental_state) + self.layers[j].prune_incremental_state(incremental_state) return x, {"action": 0} - if ( - incremental_state is 
not None - and not incremental_state.get("online", False) - ): + if incremental_state is not None and not incremental_state.get("online", False): # Here is for fast evaluation - fastest_step = torch.max( - torch.cat(step_list, dim=1), - dim=1, - keepdim=True - )[0] + 1 + fastest_step = ( + torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1 + ) if "fastest_step" in incremental_state: incremental_state["fastest_step"] = torch.cat( - [incremental_state["fastest_step"], fastest_step], - dim=1 + [incremental_state["fastest_step"], fastest_step], dim=1 ) else: incremental_state["fastest_step"] = fastest_step @@ -310,25 +281,19 @@ def extract_features( def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) if "fastest_step" in incremental_state: - incremental_state["fastest_step"] = ( - incremental_state["fastest_step"] - .index_select(0, new_order) - ) + incremental_state["fastest_step"] = incremental_state[ + "fastest_step" + ].index_select(0, new_order) -@register_model_architecture( - 'transformer_monotonic', - 'transformer_monotonic' -) +@register_model_architecture("transformer_monotonic", "transformer_monotonic") def base_monotonic_rchitecture(args): base_architecture(args) - args.encoder_unidirectional = getattr( - args, 'encoder_unidirectional', False) + args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False) @register_model_architecture( - 'transformer_monotonic', - 'transformer_monotonic_iwslt_de_en' + "transformer_monotonic", "transformer_monotonic_iwslt_de_en" ) def transformer_monotonic_iwslt_de_en(args): transformer_iwslt_de_en(args) @@ -337,24 +302,21 @@ def transformer_monotonic_iwslt_de_en(args): # parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) @register_model_architecture( - 'transformer_monotonic', - 'transformer_monotonic_vaswani_wmt_en_de_big' + "transformer_monotonic", 
"transformer_monotonic_vaswani_wmt_en_de_big" ) def transformer_monotonic_vaswani_wmt_en_de_big(args): transformer_vaswani_wmt_en_de_big(args) @register_model_architecture( - 'transformer_monotonic', - 'transformer_monotonic_vaswani_wmt_en_fr_big' + "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big" ) def transformer_monotonic_vaswani_wmt_en_fr_big(args): transformer_monotonic_vaswani_wmt_en_fr_big(args) @register_model_architecture( - 'transformer_unidirectional', - 'transformer_unidirectional_iwslt_de_en' + "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en" ) def transformer_unidirectional_iwslt_de_en(args): transformer_iwslt_de_en(args) diff --git a/examples/simultaneous_translation/modules/__init__.py b/examples/simultaneous_translation/modules/__init__.py index 8fd9d379a5..ad64774de4 100644 --- a/examples/simultaneous_translation/modules/__init__.py +++ b/examples/simultaneous_translation/modules/__init__.py @@ -7,14 +7,18 @@ import os from fairseq import registry + + ( build_monotonic_attention, register_monotonic_attention, MONOTONIC_ATTENTION_REGISTRY, - _ -) = registry.setup_registry('--simul-type') + _, +) = registry.setup_registry("--simul-type") for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - model_name = file[:file.find('.py')] - importlib.import_module('examples.simultaneous_translation.modules.' + model_name) + if file.endswith(".py") and not file.startswith("_"): + model_name = file[: file.find(".py")] + importlib.import_module( + "examples.simultaneous_translation.modules." 
+ model_name + ) diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index d508b8cfba..c09725ac9a 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -4,22 +4,19 @@ # LICENSE file in the root directory of this source tree. import math + import torch -import torch.nn.functional as F import torch.nn as nn - -from fairseq import utils - -from fairseq.modules import MultiheadAttention - +import torch.nn.functional as F from examples.simultaneous_translation.utils.functions import ( exclusive_cumprod, - lengths_to_mask + lengths_to_mask, ) - - +from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules import MultiheadAttention from fairseq.utils import convert_padding_direction + from . import register_monotonic_attention @@ -28,6 +25,7 @@ class MonotonicAttention(nn.Module): """ Abstract class of monotonic attentions """ + def __init__(self, args): self.eps = args.attention_eps self.mass_preservation = args.mass_preservation @@ -38,7 +36,8 @@ def __init__(self, args): self.energy_bias_init = args.energy_bias_init self.energy_bias = ( nn.Parameter(self.energy_bias_init * torch.ones([1])) - if args.energy_bias is True else 0 + if args.energy_bias is True + else 0 ) @staticmethod @@ -90,7 +89,7 @@ def attn_energy(self, q_proj, k_proj, key_padding_mask=None): if key_padding_mask is not None: attn_energy = attn_energy.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).bool(), - float('-inf'), + float("-inf"), ) return attn_energy @@ -131,10 +130,7 @@ def expected_alignment_train(self, p_choose, key_padding_mask): alpha_i = ( p_choose[:, i] * cumprod_1mp[:, i] - * torch.cumsum( - previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], - dim=1 - ) + * torch.cumsum(previous_attn[i][:, 0] / 
cumprod_1mp_clamp[:, i], dim=1) ).clamp(0, 1.0) previous_attn.append(alpha_i.unsqueeze(1)) @@ -170,8 +166,7 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # prev_monotonic_step: bsz, num_heads bsz = bsz_num_heads // self.num_heads prev_monotonic_step = monotonic_cache.get( - "step", - p_choose.new_zeros([bsz, self.num_heads]).long() + "step", p_choose.new_zeros([bsz, self.num_heads]).long() ) bsz, num_heads = prev_monotonic_step.size() assert num_heads == self.num_heads @@ -181,8 +176,7 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state p_choose = p_choose.view(bsz, num_heads, src_len) if key_padding_mask is not None: - src_lengths = src_len - \ - key_padding_mask.sum(dim=1, keepdim=True).long() + src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long() else: src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len @@ -197,10 +191,7 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # left_pad_source = True: step_offset = key_padding_mask.sum(dim=-1, keepdim=True) - max_steps = ( - src_lengths - 1 if self.mass_preservation - else src_lengths - ) + max_steps = src_lengths - 1 if self.mass_preservation else src_lengths # finish_read: bsz, num_heads finish_read = new_monotonic_step.eq(max_steps) @@ -210,11 +201,11 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # only choose the p at monotonic steps # p_choose_i: bsz , self.num_heads p_choose_i = ( - p_choose - .gather( + p_choose.gather( 2, - (step_offset + new_monotonic_step).unsqueeze(2) - .clamp(0, src_len - 1) + (step_offset + new_monotonic_step) + .unsqueeze(2) + .clamp(0, src_len - 1), ) ).squeeze(2) @@ -239,21 +230,17 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # alpha: bsz * num_heads, 1, src_len # new_monotonic_step: bsz, num_heads - alpha = ( - p_choose - .new_zeros([bsz * self.num_heads, src_len]) - .scatter( - 
1, - (step_offset + new_monotonic_step).view(bsz * - self.num_heads, 1).clamp(0, src_len - 1), - 1 - ) + alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter( + 1, + (step_offset + new_monotonic_step) + .view(bsz * self.num_heads, 1) + .clamp(0, src_len - 1), + 1, ) if not self.mass_preservation: alpha = alpha.masked_fill( - (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), - 0 + (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0 ) alpha = alpha.unsqueeze(1) @@ -266,8 +253,14 @@ def v_proj_output(self, value): raise NotImplementedError def forward( - self, query, key, value, - key_padding_mask=None, incremental_state=None, *args, **kwargs, + self, + query, + key, + value, + key_padding_mask=None, + incremental_state=None, + *args, + **kwargs, ): tgt_len, bsz, embed_dim = query.size() @@ -280,25 +273,24 @@ def forward( # expected alignment alpha # bsz * self.num_heads, tgt_len, src_len if incremental_state is not None: - alpha = self.expected_alignment_infer(p_choose, key_padding_mask, incremental_state) + alpha = self.expected_alignment_infer( + p_choose, key_padding_mask, incremental_state + ) else: alpha = self.expected_alignment_train(p_choose, key_padding_mask) # expected attention beta # bsz * self.num_heads, tgt_len, src_len - beta = self.expected_attention(alpha, query, key, value, key_padding_mask, incremental_state) + beta = self.expected_attention( + alpha, query, key, value, key_padding_mask, incremental_state + ) attn_weights = beta v_proj = self.v_proj_output(value) attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) - attn = ( - attn - .transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) @@ -318,26 +310,32 @@ def reorder_incremental_state(self, incremental_state, new_order): self._set_monotonic_buffer(incremental_state, input_buffer) def _get_monotonic_buffer(self, incremental_state): - 
return utils.get_incremental_state( - self, - incremental_state, - 'monotonic', - ) or {} + return ( + utils.get_incremental_state( + self, + incremental_state, + "monotonic", + ) + or {} + ) def _set_monotonic_buffer(self, incremental_state, buffer): utils.set_incremental_state( self, incremental_state, - 'monotonic', + "monotonic", buffer, ) def get_pointer(self, incremental_state): - return utils.get_incremental_state( - self, - incremental_state, - 'monotonic', - ) or {} + return ( + utils.get_incremental_state( + self, + incremental_state, + "monotonic", + ) + or {} + ) def get_fastest_pointer(self, incremental_state): return self.get_pointer(incremental_state)["step"].max(0)[0] @@ -354,23 +352,22 @@ def set_pointer(self, incremental_state, p_choose): utils.set_incremental_state( self, incremental_state, - 'monotonic', + "monotonic", {"step": buffer}, ) @register_monotonic_attention("hard_aligned") class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention): - def __init__(self, args): MultiheadAttention.__init__( self, embed_dim=args.decoder_embed_dim, num_heads=args.decoder_attention_heads, - kdim=getattr(args, 'encoder_embed_dim', None), - vdim=getattr(args, 'encoder_embed_dim', None), + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, - encoder_decoder_attention=True + encoder_decoder_attention=True, ) MonotonicAttention.__init__(self, args) @@ -395,21 +392,33 @@ def input_projections(self, query, key, value, name): bsz = query.size(1) q = self.q_in_proj[name](query) q *= self.scaling - q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + q = ( + q.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) else: q = None if key is not None: bsz = key.size(1) k = self.k_in_proj[name](key) - k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + k = ( + k.contiguous() + .view(-1, 
bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) else: k = None if value is not None: bsz = value.size(1) v = self.v_in_proj[name](value) - v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) else: v = None @@ -441,8 +450,7 @@ def p_choose(self, query, key, key_padding_mask=None): if self.training: # add noise here to encourage discretness noise = ( - torch - .normal(self.noise_mean, self.noise_var, attn_energy.size()) + torch.normal(self.noise_mean, self.noise_var, attn_energy.size()) .type_as(attn_energy) .to(attn_energy.device) ) @@ -454,9 +462,9 @@ def p_choose(self, query, key, key_padding_mask=None): return p_choose.view(-1, tgt_len, src_len) def expected_attention(self, alpha, *args): - ''' + """ For MMA-H, beta = alpha - ''' + """ return alpha def v_proj_output(self, value): @@ -479,13 +487,19 @@ def init_soft_attention(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with # the scaled initialization - nn.init.xavier_uniform_(self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)) - nn.init.xavier_uniform_(self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_( + self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2) + ) + nn.init.xavier_uniform_( + self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2) + ) else: nn.init.xavier_uniform_(self.k_in_proj["soft"].weight) nn.init.xavier_uniform_(self.q_in_proj["soft"].weight) - def expected_attention(self, alpha, query, key, value, key_padding_mask, incremental_state): + def expected_attention( + self, alpha, query, key, value, key_padding_mask, incremental_state + ): # monotonic attention, we will calculate milk here bsz_x_num_heads, tgt_len, src_len = alpha.size() bsz = int(bsz_x_num_heads / self.num_heads) @@ -507,9 +521,10 @@ def expected_attention(self, alpha, query, key, value, key_padding_mask, increme step_offset = 
key_padding_mask.sum(dim=-1, keepdim=True) monotonic_step += step_offset mask = lengths_to_mask( - monotonic_step.view(-1), soft_energy.size(2), 1).unsqueeze(1) + monotonic_step.view(-1), soft_energy.size(2), 1 + ).unsqueeze(1) - soft_energy = soft_energy.masked_fill(~ mask.bool(), float('-inf')) + soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf")) soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] exp_soft_energy = torch.exp(soft_energy) exp_soft_energy_sum = exp_soft_energy.sum(dim=2) @@ -524,14 +539,20 @@ def expected_attention(self, alpha, query, key, value, key_padding_mask, increme if key_padding_mask is not None: if key_padding_mask.any(): exp_soft_energy_cumsum = ( - exp_soft_energy_cumsum.view(-1, self.num_heads, tgt_len, src_len) - .masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps) + exp_soft_energy_cumsum.view( + -1, self.num_heads, tgt_len, src_len + ) + .masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps + ) .view(-1, tgt_len, src_len) ) inner_items = alpha / exp_soft_energy_cumsum - beta = exp_soft_energy * torch.cumsum(inner_items.flip(dims=[2]), dim=2).flip(dims=[2]) + beta = exp_soft_energy * torch.cumsum( + inner_items.flip(dims=[2]), dim=2 + ).flip(dims=[2]) beta = self.dropout_module(beta) @@ -547,7 +568,9 @@ def __init__(self, args): self.q_in_proj["soft"] = self.q_in_proj["monotonic"] self.k_in_proj["soft"] = self.k_in_proj["monotonic"] self.waitk_lagging = args.waitk_lagging - assert self.waitk_lagging > 0, f"Lagging has to been larger than 0, get {self.waitk_lagging}." + assert ( + self.waitk_lagging > 0 + ), f"Lagging has to been larger than 0, get {self.waitk_lagging}." 
@staticmethod def add_args(parser): @@ -556,10 +579,13 @@ def add_args(parser): MonotonicMultiheadAttentionWaitk, ).add_args(parser) - parser.add_argument('--waitk-lagging', type=int, required=True, - help='Wait k lagging') + parser.add_argument( + "--waitk-lagging", type=int, required=True, help="Wait k lagging" + ) - def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None): + def p_choose( + self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None + ): """ query: bsz, tgt_len key: bsz, src_len @@ -574,16 +600,22 @@ def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incrementa if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any(): # Left pad source # add -1 to the end - p_choose = p_choose.masked_fill(key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1) - p_choose = convert_padding_direction(p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True) + p_choose = p_choose.masked_fill( + key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1 + ) + p_choose = convert_padding_direction( + p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True + ) p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query) # remove -1 p_choose[p_choose.eq(-1)] = 0 # Extend to each head p_choose = ( - p_choose.contiguous().unsqueeze(1) - .expand(-1, self.num_heads, -1, -1).contiguous() + p_choose.contiguous() + .unsqueeze(1) + .expand(-1, self.num_heads, -1, -1) + .contiguous() .view(-1, tgt_len, src_len) ) diff --git a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py index a9545b2540..442b7d487d 100644 --- a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py +++ b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py @@ -3,37 +3,32 @@ # This source code is licensed under the MIT license found in the # LICENSE file in 
the root directory of this source tree. -from fairseq.modules import ( - LayerNorm, - TransformerEncoderLayer, - TransformerDecoderLayer -) +from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer from . import build_monotonic_attention class TransformerMonotonicEncoderLayer(TransformerEncoderLayer): - def forward(self, x, encoder_padding_mask): seq_len, _, _ = x.size() attn_mask = x.new_ones([seq_len, seq_len]).triu(1) - attn_mask = attn_mask.masked_fill(attn_mask.bool(), float('-inf')) + attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf")) return super().forward(x, encoder_padding_mask, attn_mask) class TransformerMonotonicDecoderLayer(TransformerDecoderLayer): - - def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): super().__init__( args, no_encoder_attn=True, add_bias_kv=add_bias_kv, - add_zero_attn=add_zero_attn + add_zero_attn=add_zero_attn, ) self.encoder_attn = build_monotonic_attention(args) self.encoder_attn_layer_norm = LayerNorm( - self.embed_dim, - export=getattr(args, 'char_inputs', False) + self.embed_dim, export=getattr(args, "char_inputs", False) ) def prune_incremental_state(self, incremental_state): @@ -46,12 +41,8 @@ def prune(module): input_buffer = {} break module._set_input_buffer(incremental_state, input_buffer) + prune(self.self_attn) def get_steps(self, incremental_state): - return ( - self.encoder_attn - ._get_monotonic_buffer( - incremental_state - ).get("step", 0) - ) + return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0) diff --git a/examples/simultaneous_translation/utils/__init__.py b/examples/simultaneous_translation/utils/__init__.py index 8e5886008f..be0ba4d99a 100644 --- a/examples/simultaneous_translation/utils/__init__.py +++ b/examples/simultaneous_translation/utils/__init__.py @@ -9,6 +9,6 @@ # automatically import any 
Python files in the criterions/ directory for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('examples.simultaneous_translation.utils.' + module) + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("examples.simultaneous_translation.utils." + module) diff --git a/examples/simultaneous_translation/utils/functions.py b/examples/simultaneous_translation/utils/functions.py index 620dd1d866..f795b5f31c 100644 --- a/examples/simultaneous_translation/utils/functions.py +++ b/examples/simultaneous_translation/utils/functions.py @@ -16,7 +16,9 @@ def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10): tensor_size = list(tensor.size()) tensor_size[dim] = 1 return_tensor = safe_cumprod( - torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), dim=dim, eps=eps + torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), + dim=dim, + eps=eps, ) if dim == 0: @@ -132,12 +134,14 @@ def moving_sum(x, start_idx: int, end_idx: int): # batch_size, 1, src_len moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1]) - moving_sum = torch.nn.functional.conv1d( - x, - moving_sum_weight, - padding=start_idx + end_idx - 1 - ).squeeze(1).t() - moving_sum = moving_sum[end_idx: -start_idx] + moving_sum = ( + torch.nn.functional.conv1d( + x, moving_sum_weight, padding=start_idx + end_idx - 1 + ) + .squeeze(1) + .t() + ) + moving_sum = moving_sum[end_idx:-start_idx] assert src_len == moving_sum.size(0) assert batch_size == moving_sum.size(1) diff --git a/examples/simultaneous_translation/utils/latency.py b/examples/simultaneous_translation/utils/latency.py index 9d09584176..5d800a5d9e 100644 --- a/examples/simultaneous_translation/utils/latency.py +++ b/examples/simultaneous_translation/utils/latency.py @@ -18,7 +18,7 @@ def prepare_latency_metric( src_lens, 
target_padding_mask=None, batch_first: bool = False, - start_from_zero: bool = True + start_from_zero: bool = True, ): assert len(delays.size()) == 2 assert len(src_lens.size()) == 2 @@ -59,11 +59,7 @@ def __call__( start_from_zero: bool = True, ): delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric( - delays, - src_lens, - target_padding_mask, - batch_first, - start_from_zero + delays, src_lens, target_padding_mask, batch_first, start_from_zero ) return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask) @@ -89,10 +85,13 @@ class AverageProportion(LatencyMetric): AP = 1 / (|x||y]) sum_i^|Y| deleys_i """ + @staticmethod def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): if target_padding_mask is not None: - AP = torch.sum(delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True) + AP = torch.sum( + delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True + ) else: AP = torch.sum(delays, dim=0, keepdim=True) @@ -116,14 +115,24 @@ class AverageLagging(LatencyMetric): gamma = |y| / |x| tau = argmin_i(delays_i = |x|) """ + @staticmethod def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): # tau = argmin_i(delays_i = |x|) tgt_len, bsz = delays.size() lagging_padding_mask = delays >= src_lens - lagging_padding_mask = torch.nn.functional.pad(lagging_padding_mask.t(), (1, 0)).t()[:-1, :] + lagging_padding_mask = torch.nn.functional.pad( + lagging_padding_mask.t(), (1, 0) + ).t()[:-1, :] gamma = tgt_lens / src_lens - lagging = delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma + lagging = ( + delays + - torch.arange(delays.size(0)) + .unsqueeze(1) + .type_as(delays) + .expand_as(delays) + / gamma + ) lagging.masked_fill_(lagging_padding_mask, 0) tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True) AL = lagging.sum(dim=0, keepdim=True) / tau @@ -149,6 +158,7 @@ class DifferentiableAverageLagging(LatencyMetric): 2. 
max(delays_i, delays'_{i-1} + 1 / gamma) """ + @staticmethod def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): tgt_len, bsz = delays.size() @@ -163,13 +173,18 @@ def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): new_delays[i] = torch.cat( [ new_delays[i - 1].unsqueeze(0) + 1 / gamma, - delays[i].unsqueeze(0) + delays[i].unsqueeze(0), ], - dim=0 + dim=0, ).max(dim=0)[0] DAL = ( - new_delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma + new_delays + - torch.arange(delays.size(0)) + .unsqueeze(1) + .type_as(delays) + .expand_as(delays) + / gamma ) if target_padding_mask is not None: DAL = DAL.masked_fill(target_padding_mask, 0) @@ -186,7 +201,7 @@ def prepare_latency_metric( src_lens, target_padding_mask=None, batch_first: bool = True, - start_from_zero: bool = True + start_from_zero: bool = True, ): assert batch_first assert len(delays.size()) == 3 @@ -256,25 +271,21 @@ def __call__(self, monotonic_step, src_lens): src_lens = src_lens - delays = ( - monotonic_step - .view(monotonic_step.size(0), -1, monotonic_step.size(-1)) - .max(dim=1)[0] - ) + delays = monotonic_step.view( + monotonic_step.size(0), -1, monotonic_step.size(-1) + ).max(dim=1)[0] - delays = ( - delays.masked_fill(delays >= src_lens, 0) - + (src_lens - 1) - .expand_as(delays) - .masked_fill(delays < src_lens, 0) - ) + delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as( + delays + ).masked_fill(delays < src_lens, 0) return_dict = {} for key, func in self.metric_calculator.items(): return_dict[key] = func( - delays.float(), src_lens.float(), + delays.float(), + src_lens.float(), target_padding_mask=None, batch_first=True, - start_from_zero=True + start_from_zero=True, ).t() return return_dict @@ -282,8 +293,13 @@ def __call__(self, monotonic_step, src_lens): class LatencyTraining(object): def __init__( - self, avg_weight, var_weight, avg_type, var_type, - stay_on_last_token, average_method, + self, + 
avg_weight, + var_weight, + avg_type, + var_type, + stay_on_last_token, + average_method, ): self.avg_weight = avg_weight self.var_weight = var_weight @@ -319,17 +335,12 @@ def expected_delays_from_attention( attention = attention.view(-1, tgt_len, src_len) if not self.stay_on_last_token: - residual_attention = \ - 1 - attention[:, :, :-1].sum(dim=2, keepdim=True) - attention = torch.cat( - [attention[:, :, :-1], residual_attention], - dim=2 - ) + residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True) + attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2) # bsz * num_heads_x_num_layers, tgt_len, src_len for MMA steps = ( - torch - .arange(1, 1 + src_len) + torch.arange(1, 1 + src_len) .unsqueeze(0) .unsqueeze(1) .expand_as(attention) @@ -355,15 +366,12 @@ def expected_delays_from_attention( src_lens = src_lens.view(-1, 1) # bsz * num_heads_num_layers, tgt_len, src_len - expected_delays = (steps * attention).sum(dim=2).view( - bsz, num_heads_x_layers, tgt_len + expected_delays = ( + (steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len) ) if target_padding_mask is not None: - expected_delays.masked_fill_( - target_padding_mask.unsqueeze(1), - 0 - ) + expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0) return expected_delays, src_lens @@ -371,8 +379,7 @@ def avg_loss(self, expected_delays, src_lens, target_padding_mask): bsz, num_heads_x_layers, tgt_len = expected_delays.size() target_padding_mask = ( - target_padding_mask - .unsqueeze(1) + target_padding_mask.unsqueeze(1) .expand_as(expected_delays) .contiguous() .view(-1, tgt_len) @@ -396,8 +403,11 @@ def avg_loss(self, expected_delays, src_lens, target_padding_mask): if self.avg_weight > 0.0: if self.avg_type in self.metric_calculator: average_delays = self.metric_calculator[self.avg_type]( - expected_delays, src_lens, target_padding_mask, - batch_first=True, start_from_zero=False + expected_delays, + src_lens, + target_padding_mask, + 
batch_first=True, + start_from_zero=False, ) else: raise RuntimeError(f"{self.avg_type} is not supported.") @@ -408,12 +418,17 @@ def avg_loss(self, expected_delays, src_lens, target_padding_mask): return 0.0 def var_loss(self, expected_delays, src_lens, target_padding_mask): - src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[:, :1] + src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[ + :, :1 + ] if self.var_weight > 0.0: if self.var_type in self.variance_calculator: variance_delays = self.variance_calculator[self.var_type]( - expected_delays, src_lens, target_padding_mask, - batch_first=True, start_from_zero=False + expected_delays, + src_lens, + target_padding_mask, + batch_first=True, + start_from_zero=False, ) else: raise RuntimeError(f"{self.var_type} is not supported.") diff --git a/examples/speech_recognition/__init__.py b/examples/speech_recognition/__init__.py index cd780902e3..0278f6a273 100644 --- a/examples/speech_recognition/__init__.py +++ b/examples/speech_recognition/__init__.py @@ -1 +1 @@ -from . import tasks, criterions, models # noqa +from . import criterions, models, tasks # noqa diff --git a/examples/speech_recognition/criterions/ASG_loss.py b/examples/speech_recognition/criterions/ASG_loss.py index 8f932bcd5b..7493654afc 100644 --- a/examples/speech_recognition/criterions/ASG_loss.py +++ b/examples/speech_recognition/criterions/ASG_loss.py @@ -6,9 +6,9 @@ # LICENSE file in the root directory of this source tree. 
import torch +from examples.speech_recognition.data.replabels import pack_replabels from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion -from examples.speech_recognition.data.replabels import pack_replabels @register_criterion("asg_loss") diff --git a/examples/speech_recognition/data/__init__.py b/examples/speech_recognition/data/__init__.py index 737a22ec3a..47bb6e24dd 100644 --- a/examples/speech_recognition/data/__init__.py +++ b/examples/speech_recognition/data/__init__.py @@ -5,6 +5,7 @@ from .asr_dataset import AsrDataset + __all__ = [ - 'AsrDataset', + "AsrDataset", ] diff --git a/examples/speech_recognition/data/asr_dataset.py b/examples/speech_recognition/data/asr_dataset.py index 47969a2853..63a6fcac85 100644 --- a/examples/speech_recognition/data/asr_dataset.py +++ b/examples/speech_recognition/data/asr_dataset.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import os + import numpy as np from fairseq.data import FairseqDataset @@ -30,16 +31,22 @@ class AsrDataset(FairseqDataset): """ def __init__( - self, aud_paths, aud_durations_ms, tgt, - tgt_dict, ids, speakers, - num_mel_bins=80, frame_length=25.0, frame_shift=10.0 + self, + aud_paths, + aud_durations_ms, + tgt, + tgt_dict, + ids, + speakers, + num_mel_bins=80, + frame_length=25.0, + frame_shift=10.0, ): assert frame_length > 0 assert frame_shift > 0 assert all(x > frame_length for x in aud_durations_ms) self.frame_sizes = [ - int(1 + (d - frame_length) / frame_shift) - for d in aud_durations_ms + int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms ] assert len(aud_paths) > 0 @@ -57,13 +64,17 @@ def __init__( self.frame_shift = frame_shift self.s2s_collater = Seq2SeqCollater( - 0, 1, pad_index=self.tgt_dict.pad(), - eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True + 0, + 1, + pad_index=self.tgt_dict.pad(), + eos_index=self.tgt_dict.eos(), + move_eos_to_beginning=True, ) def __getitem__(self, index): 
import torchaudio import torchaudio.compliance.kaldi as kaldi + tgt_item = self.tgt[index] if self.tgt is not None else None path = self.aud_paths[index] @@ -74,7 +85,7 @@ def __getitem__(self, index): sound, num_mel_bins=self.num_mel_bins, frame_length=self.frame_length, - frame_shift=self.frame_shift + frame_shift=self.frame_shift, ) output_cmvn = data_utils.apply_mv_norm(output) diff --git a/examples/speech_recognition/data/collaters.py b/examples/speech_recognition/data/collaters.py index 14740d48b7..6acfec876b 100644 --- a/examples/speech_recognition/data/collaters.py +++ b/examples/speech_recognition/data/collaters.py @@ -12,18 +12,18 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import numpy as np +import numpy as np import torch from fairseq.data import data_utils as fairseq_data_utils class Seq2SeqCollater(object): """ - Implements collate function mainly for seq2seq tasks - This expects each sample to contain feature (src_tokens) and - targets. - This collator is also used for aligned training task. + Implements collate function mainly for seq2seq tasks + This expects each sample to contain feature (src_tokens) and + targets. + This collator is also used for aligned training task. 
""" def __init__( diff --git a/examples/speech_recognition/datasets/asr_prep_json.py b/examples/speech_recognition/datasets/asr_prep_json.py index 2bab825b89..b8db8ff166 100644 --- a/examples/speech_recognition/datasets/asr_prep_json.py +++ b/examples/speech_recognition/datasets/asr_prep_json.py @@ -6,52 +6,74 @@ from __future__ import absolute_import, division, print_function, unicode_literals -from collections import namedtuple -import concurrent.futures -from itertools import chain import argparse -import os +import concurrent.futures import json -import sentencepiece as spm import multiprocessing +import os +from collections import namedtuple +from itertools import chain +import sentencepiece as spm from fairseq.data import Dictionary + MILLISECONDS_TO_SECONDS = 0.001 def process_sample(aud_path, lable, utt_id, sp, tgt_dict): import torchaudio + input = {} output = {} si, ei = torchaudio.info(aud_path) - input["length_ms"] = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS) + input["length_ms"] = int( + si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS + ) input["path"] = aud_path token = " ".join(sp.EncodeAsPieces(lable)) ids = tgt_dict.encode_line(token, append_eos=False) output["text"] = lable output["token"] = token - output["tokenid"] = ', '.join(map(str, [t.tolist() for t in ids])) + output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids])) return {utt_id: {"input": input, "output": output}} def main(): parser = argparse.ArgumentParser() - parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True, - help="input directories with audio files") - parser.add_argument("--labels", required=True, - help="aggregated input labels with format per line", - type=argparse.FileType('r', encoding='UTF-8')) - parser.add_argument("--spm-model", required=True, - help="sentencepiece model to use for encoding", - type=argparse.FileType('r', encoding='UTF-8')) - parser.add_argument("--dictionary", required=True, - help="file 
to load fairseq dictionary from", - type=argparse.FileType('r', encoding='UTF-8')) + parser.add_argument( + "--audio-dirs", + nargs="+", + default=["-"], + required=True, + help="input directories with audio files", + ) + parser.add_argument( + "--labels", + required=True, + help="aggregated input labels with format per line", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument( + "--spm-model", + required=True, + help="sentencepiece model to use for encoding", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument( + "--dictionary", + required=True, + help="file to load fairseq dictionary from", + type=argparse.FileType("r", encoding="UTF-8"), + ) parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav") - parser.add_argument("--output", required=True, type=argparse.FileType('w'), - help="path to save json output") + parser.add_argument( + "--output", + required=True, + type=argparse.FileType("w"), + help="path to save json output", + ) args = parser.parse_args() sp = spm.SentencePieceProcessor() @@ -64,15 +86,17 @@ def main(): (utt_id, label) = line.split(" ", 1) labels[utt_id] = label if len(labels) == 0: - raise Exception('No labels found in ', args.labels_path) + raise Exception("No labels found in ", args.labels_path) - Sample = namedtuple('Sample', 'aud_path utt_id') + Sample = namedtuple("Sample", "aud_path utt_id") samples = [] - for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs): + for path, _, files in chain.from_iterable( + os.walk(path) for path in args.audio_dirs + ): for f in files: if f.endswith(args.audio_format): if len(os.path.splitext(f)) != 2: - raise Exception('Expect file name. Got: ', f) + raise Exception("Expect file name. 
Got: ", f) utt_id = os.path.splitext(f)[0] if utt_id not in labels: continue @@ -81,12 +105,17 @@ def main(): utts = {} num_cpu = multiprocessing.cpu_count() with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor: - future_to_sample = {executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s for s in samples} + future_to_sample = { + executor.submit( + process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict + ): s + for s in samples + } for future in concurrent.futures.as_completed(future_to_sample): try: data = future.result() except Exception as exc: - print('generated an exception: ', exc) + print("generated an exception: ", exc) else: utts.update(data) json.dump({"utts": utts}, args.output, indent=4) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index fe5f607d1a..a197ab5a63 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -8,17 +8,17 @@ Run inference for pre-processed data with a trained model. 
""" -import editdistance import logging import math import os import sys +import editdistance import numpy as np import torch -from fairseq import checkpoint_utils, options, progress_bar, utils, tasks -from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq import checkpoint_utils, options, progress_bar, tasks, utils from fairseq.data.data_utils import post_process +from fairseq.logging.meters import StopwatchMeter, TimeMeter logging.basicConfig() @@ -52,10 +52,12 @@ def add_asr_eval_argument(parser): "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" ) parser.add_argument( - "--w2l-decoder", choices=["viterbi", "kenlm", "fairseqlm"], help="use a w2l decoder" + "--w2l-decoder", + choices=["viterbi", "kenlm", "fairseqlm"], + help="use a w2l decoder", ) parser.add_argument("--lexicon", help="lexicon for w2l decoder") - parser.add_argument("--unit-lm", action='store_true', help="if using a unit lm") + parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm") parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder") parser.add_argument("--beam-threshold", type=float, default=25.0) parser.add_argument("--beam-size-token", type=float, default=100) @@ -87,10 +89,10 @@ def check_args(args): # assert args.path is not None, "--path required for generation!" # assert args.results_path is not None, "--results_path required for generation!" 
assert ( - not args.sampling or args.nbest == args.beam + not args.sampling or args.nbest == args.beam ), "--sampling requires --nbest to be equal to --beam" assert ( - args.replace_unk is None or args.raw_text + args.replace_unk is None or args.raw_text ), "--replace-unk requires a raw text dataset (--raw-text)" @@ -110,7 +112,7 @@ def get_dataset_itr(args, task, models): def process_predictions( - args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id + args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id ): for hypo in hypos[: min(len(hypos), args.nbest)]: hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu()) @@ -122,16 +124,25 @@ def process_predictions( if res_files is not None: print( - "{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"] + "{} ({}-{})".format(hyp_pieces, speaker, id), + file=res_files["hypo.units"], + ) + print( + "{} ({}-{})".format(hyp_words, speaker, id), + file=res_files["hypo.words"], ) - print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"]) tgt_pieces = tgt_dict.string(target_tokens) tgt_words = post_process(tgt_pieces, args.remove_bpe) if res_files is not None: - print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"]) - print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]) + print( + "{} ({}-{})".format(tgt_pieces, speaker, id), + file=res_files["ref.units"], + ) + print( + "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"] + ) # only score top hypothesis if not args.quiet: logger.debug("HYPO:" + hyp_words) @@ -146,7 +157,7 @@ def process_predictions( def prepare_result_files(args): def get_res_file(file_prefix): if args.num_shards > 1: - file_prefix = f'{args.shard_id}_{file_prefix}' + file_prefix = f"{args.shard_id}_{file_prefix}" path = os.path.join( args.results_path, "{}-{}-{}.txt".format( @@ -166,15 +177,17 @@ def get_res_file(file_prefix): } -def load_models_and_criterions(filenames, 
data_path, arg_overrides=None, task=None, model_state=None): +def load_models_and_criterions( + filenames, data_path, arg_overrides=None, task=None, model_state=None +): models = [] criterions = [] if arg_overrides is None: arg_overrides = {} - arg_overrides['wer_args'] = None - arg_overrides['data'] = data_path + arg_overrides["wer_args"] = None + arg_overrides["data"] = data_path if filenames is None: assert model_state is not None @@ -205,8 +218,7 @@ def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=No def optimize_models(args, use_cuda, models): - """Optimize ensemble for generation - """ + """Optimize ensemble for generation""" for model in models: model.make_generation_fast_( beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, @@ -229,7 +241,7 @@ def generate(self, models, sample, **unused): emissions = np.stack(self.emissions[ids]) except: print([x.shape for x in self.emissions[ids]]) - raise Exception('invalid sizes') + raise Exception("invalid sizes") emissions = torch.from_numpy(emissions) return self.decoder.decode(emissions) @@ -300,7 +312,9 @@ def build_generator(args): return W2lFairseqLMDecoder(args, task.target_dictionary) else: - print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment') + print( + "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment" + ) # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task generator = build_generator(args) @@ -361,7 +375,11 @@ def build_generator(args): encoder_out = models[0](**sample["net_input"]) feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy() for i, id in enumerate(sample["id"]): - padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() if encoder_out["encoder_padding_mask"] is not None else None + padding = ( + encoder_out["encoder_padding_mask"][i].cpu().numpy() + if encoder_out["encoder_padding_mask"] is not 
None + else None + ) features[id.item()] = (feat[i], padding) continue hypos = task.inference_step(generator, models, sample, prefix_tokens) @@ -372,20 +390,31 @@ def build_generator(args): speaker = None # id = task.dataset(args.gen_subset).ids[int(sample_id)] id = sample_id - toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :] - target_tokens = ( - utils.strip_pad(toks, tgt_dict.pad()).int().cpu() + toks = ( + sample["target"][i, :] + if "target_label" not in sample + else sample["target_label"][i, :] ) + target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu() # Process top predictions errs, length = process_predictions( - args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, id + args, + hypos[i], + None, + tgt_dict, + target_tokens, + res_files, + speaker, + id, ) errs_t += errs lengths_t += length wps_meter.update(num_generated_tokens) t.log({"wps": round(wps_meter.avg)}) - num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + num_sentences += ( + sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + ) wer = None if args.dump_emissions: @@ -413,7 +442,7 @@ def build_generator(args): gen_timer.sum, num_sentences / gen_timer.sum, 1.0 / gen_timer.avg, - ) + ) ) logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam)) return task, wer @@ -424,6 +453,7 @@ def make_parser(): parser = add_asr_eval_argument(parser) return parser + def cli_main(): parser = make_parser() args = options.parse_args_and_arch(parser) diff --git a/examples/speech_recognition/models/__init__.py b/examples/speech_recognition/models/__init__.py index 66ad2b0a1f..0ad9663f11 100644 --- a/examples/speech_recognition/models/__init__.py +++ b/examples/speech_recognition/models/__init__.py @@ -1,7 +1,8 @@ import importlib import os + for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - model_name = 
file[:file.find('.py')] - importlib.import_module('examples.speech_recognition.models.' + model_name) + if file.endswith(".py") and not file.startswith("_"): + model_name = file[: file.find(".py")] + importlib.import_module("examples.speech_recognition.models." + model_name) diff --git a/examples/speech_recognition/models/vggtransformer.py b/examples/speech_recognition/models/vggtransformer.py index e9a45ac73e..97974360a4 100644 --- a/examples/speech_recognition/models/vggtransformer.py +++ b/examples/speech_recognition/models/vggtransformer.py @@ -9,18 +9,22 @@ import torch import torch.nn as nn +from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask from fairseq import utils from fairseq.models import ( FairseqEncoder, + FairseqEncoderDecoderModel, FairseqEncoderModel, FairseqIncrementalDecoder, - FairseqEncoderDecoderModel, register_model, register_model_architecture, ) -from fairseq.modules import LinearizedConvolution -from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask -from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock +from fairseq.modules import ( + LinearizedConvolution, + TransformerDecoderLayer, + TransformerEncoderLayer, + VGGBlock, +) @register_model("asr_vggtransformer") @@ -29,6 +33,7 @@ class VGGTransformerModel(FairseqEncoderDecoderModel): Transformers with convolutional context for ASR https://arxiv.org/abs/1904.11660 """ + def __init__(self, encoder, decoder): super().__init__(encoder, decoder) @@ -602,18 +607,22 @@ def __init__( self.layers = nn.ModuleList() if conv_config[-1][0] != transformer_config[0][0]: self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0])) - self.layers.append(TransformerDecoderLayer( - prepare_transformer_decoder_params(*transformer_config[0]) - )) + self.layers.append( + TransformerDecoderLayer( + prepare_transformer_decoder_params(*transformer_config[0]) + ) + ) for i in range(1, 
len(transformer_config)): if transformer_config[i - 1][0] != transformer_config[i][0]: self.layers.append( Linear(transformer_config[i - 1][0], transformer_config[i][0]) ) - self.layers.append(TransformerDecoderLayer( - prepare_transformer_decoder_params(*transformer_config[i]) - )) + self.layers.append( + TransformerDecoderLayer( + prepare_transformer_decoder_params(*transformer_config[i]) + ) + ) self.fc_out = Linear(transformer_config[-1][0], vocab_size) def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): @@ -713,6 +722,7 @@ def _transpose_if_inference(self, x, incremental_state): x = x.transpose(0, 1) return x + @register_model("asr_vggtransformer_encoder") class VGGTransformerEncoderModel(FairseqEncoderModel): def __init__(self, encoder): diff --git a/examples/speech_recognition/models/w2l_conv_glu_enc.py b/examples/speech_recognition/models/w2l_conv_glu_enc.py index 26f27553d4..655a9b0d19 100644 --- a/examples/speech_recognition/models/w2l_conv_glu_enc.py +++ b/examples/speech_recognition/models/w2l_conv_glu_enc.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq.models import ( FairseqEncoder, FairseqEncoderModel, diff --git a/examples/speech_recognition/tasks/__init__.py b/examples/speech_recognition/tasks/__init__.py index fb9e98372d..ffa5f3bd8c 100644 --- a/examples/speech_recognition/tasks/__init__.py +++ b/examples/speech_recognition/tasks/__init__.py @@ -1,7 +1,8 @@ import importlib import os + for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - task_name = file[:file.find('.py')] - importlib.import_module('examples.speech_recognition.tasks.' + task_name) + if file.endswith(".py") and not file.startswith("_"): + task_name = file[: file.find(".py")] + importlib.import_module("examples.speech_recognition.tasks." 
+ task_name) diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py index 769ce4ff54..d9f011d55f 100644 --- a/examples/speech_recognition/tasks/speech_recognition.py +++ b/examples/speech_recognition/tasks/speech_recognition.py @@ -9,10 +9,10 @@ import sys import torch -from fairseq.data import Dictionary -from fairseq.tasks import register_task, LegacyFairseqTask from examples.speech_recognition.data import AsrDataset from examples.speech_recognition.data.replabels import replabel_symbol +from fairseq.data import Dictionary +from fairseq.tasks import LegacyFairseqTask, register_task def get_asr_dataset_from_json(data_json_path, tgt_dict): @@ -78,10 +78,20 @@ def add_args(parser): parser.add_argument( "--silence-token", default="\u2581", help="token for silence (used by w2l)" ) - parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N', - help='max number of frames in the source sequence') - parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the target sequence') + parser.add_argument( + "--max-source-positions", + default=sys.maxsize, + type=int, + metavar="N", + help="max number of frames in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) def __init__(self, args, tgt_dict): super().__init__(args) diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 020aac5593..2a1d8a779d 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -9,16 +9,18 @@ Wav2letter decoders. 
""" -from collections import namedtuple, deque import gc import itertools as it -import numpy as np -import torch import os.path as osp import warnings +from collections import deque, namedtuple + +import numpy as np +import torch +from examples.speech_recognition.data.replabels import unpack_replabels from fairseq import tasks from fairseq.utils import apply_to_sample -from examples.speech_recognition.data.replabels import unpack_replabels + try: from wav2letter.common import create_word_dict, load_words diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 1983f70c10..1efeff4df1 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -4,66 +4,76 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from multiprocessing import cpu_count +import csv import os import os.path as op -from glob import glob import zipfile -import csv from functools import reduce -from typing import Dict, Any, List -from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank +from glob import glob +from multiprocessing import cpu_count +from typing import Any, Dict, List +import numpy as np import sentencepiece as sp +from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank +from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN from tqdm import tqdm -import numpy as np -from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN -UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3 -BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0 -EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2 -PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1 +UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3 +BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0 +EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2 +PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1 def gen_vocab( - input_path: str, output_path_prefix: str, model_type='bpe', - vocab_size=1000, + input_path: str, + output_path_prefix: str, +
model_type="bpe", + vocab_size=1000, ): # Train SentencePiece Model arguments = [ - f'--input={input_path}', - f'--model_prefix={output_path_prefix}', - f'--model_type={model_type}', - f'--vocab_size={vocab_size}', - '--character_coverage=1.0', - f'--num_threads={cpu_count()}', - f'--unk_id={UNK_TOKEN_ID}', - f'--bos_id={BOS_TOKEN_ID}', - f'--eos_id={EOS_TOKEN_ID}', - f'--pad_id={PAD_TOKEN_ID}' + f"--input={input_path}", + f"--model_prefix={output_path_prefix}", + f"--model_type={model_type}", + f"--vocab_size={vocab_size}", + "--character_coverage=1.0", + f"--num_threads={cpu_count()}", + f"--unk_id={UNK_TOKEN_ID}", + f"--bos_id={BOS_TOKEN_ID}", + f"--eos_id={EOS_TOKEN_ID}", + f"--pad_id={PAD_TOKEN_ID}", ] - sp.SentencePieceTrainer.Train(' '.join(arguments)) + sp.SentencePieceTrainer.Train(" ".join(arguments)) # Export fairseq dictionary spm = sp.SentencePieceProcessor() - spm.Load(output_path_prefix + '.model') + spm.Load(output_path_prefix + ".model") vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())} - assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \ - vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \ - vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \ - vocab.get(EOS_TOKEN_ID) == EOS_TOKEN + assert ( + vocab.get(UNK_TOKEN_ID) == UNK_TOKEN + and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN + and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN + and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN + ) vocab = { - i: s for i, s in vocab.items() + i: s + for i, s in vocab.items() if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN} } - with open(output_path_prefix + '.txt', 'w') as f_out: + with open(output_path_prefix + ".txt", "w") as f_out: for _, s in sorted(vocab.items(), key=lambda x: x[0]): - f_out.write(f'{s} 1\n') + f_out.write(f"{s} 1\n") -def extract_fbank_features(waveform, sample_rate, output_path=None, - n_mel_bins=80, apply_utterance_cmvn=True, - overwrite=False): +def extract_fbank_features( + waveform, + sample_rate, + output_path=None, + n_mel_bins=80, + 
apply_utterance_cmvn=True, + overwrite=False, +): if output_path is not None and op.exists(output_path) and not overwrite: return @@ -74,8 +84,10 @@ def extract_fbank_features(waveform, sample_rate, output_path=None, if features is None: features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins) if features is None: - raise ImportError('Please install pyKaldi or torchaudio to enable ' - 'online filterbank feature extraction') + raise ImportError( + "Please install pyKaldi or torchaudio to enable " + "online filterbank feature extraction" + ) if apply_utterance_cmvn: cmvn = UtteranceCMVN(norm_means=True, norm_vars=True) @@ -89,8 +101,8 @@ def extract_fbank_features(waveform, sample_rate, output_path=None, def create_zip(data_root, zip_path): cwd = os.path.abspath(os.curdir) os.chdir(data_root) - with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f: - for filename in tqdm(glob('*.npy')): + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f: + for filename in tqdm(glob("*.npy")): f.write(filename) os.chdir(cwd) @@ -101,69 +113,80 @@ def is_npy_data(data: bytes) -> bool: def get_zip_manifest(zip_root, zip_filename): zip_path = op.join(zip_root, zip_filename) - with zipfile.ZipFile(zip_path, mode='r') as f: + with zipfile.ZipFile(zip_path, mode="r") as f: info = f.infolist() manifest = {} for i in tqdm(info): utt_id = op.splitext(i.filename)[0] offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size - manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}' - with open(zip_path, 'rb') as f: + manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}" + with open(zip_path, "rb") as f: f.seek(offset) data = f.read(file_size) assert len(data) > 1 and is_npy_data(data) return manifest -def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml', - specaugment_policy='lb'): - assert specaugment_policy in {'lb', 'ld'} +def gen_config_yaml( + data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb" 
+): + assert specaugment_policy in {"lb", "ld"} data_root = op.abspath(data_root) writer = S2TDataConfigWriter(op.join(data_root, yaml_filename)) writer.set_audio_root(op.abspath(data_root)) writer.set_vocab_filename(spm_filename.replace(".model", ".txt")) writer.set_input_channels(1) writer.set_input_feat_per_channel(80) - if specaugment_policy == 'lb': + if specaugment_policy == "lb": writer.set_specaugment_lb_policy() else: writer.set_specaugment_ld_policy() writer.set_bpe_tokenizer( - {'bpe': 'sentencepiece', - 'sentencepiece_model': op.join(data_root, spm_filename)} + { + "bpe": "sentencepiece", + "sentencepiece_model": op.join(data_root, spm_filename), + } ) - writer.set_feature_transforms('_train', ['specaugment']) + writer.set_feature_transforms("_train", ["specaugment"]) writer.flush() def save_df_to_tsv(dataframe, path): - dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8", - escapechar='\\', quoting=csv.QUOTE_NONE) + dataframe.to_csv( + path, + sep="\t", + header=True, + index=False, + encoding="utf-8", + escapechar="\\", + quoting=csv.QUOTE_NONE, + ) -def filter_manifest_df(df, is_train_split=False, extra_filters=None, - min_n_frames=5, max_n_frames=3000): +def filter_manifest_df( + df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000 +): filters = { - 'no speech': df['audio'] == '', - f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames, - 'empty sentence': df['tgt_text'] == '', + "no speech": df["audio"] == "", + f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames, + "empty sentence": df["tgt_text"] == "", } if is_train_split: - filters[f'long speech (>{max_n_frames} frames)'] = \ - df['n_frames'] > max_n_frames + filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames if extra_filters is not None: filters.update(extra_filters) invalid = reduce(lambda x, y: x | y, filters.values()) valid = ~invalid print( - '| ' + ', '.join(f'{n}: 
{f.sum()}' for n, f in filters.items()) + - f', total {invalid.sum()} filtered, {valid.sum()} remained.' + "| " + + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items()) + + f", total {invalid.sum()} filtered, {valid.sum()} remained." ) return df[valid] class S2TDataConfigWriter(object): - DEFAULT_VOCAB_FILENAME = 'dict.txt' + DEFAULT_VOCAB_FILENAME = "dict.txt" DEFAULT_INPUT_FEAT_PER_CHANNEL = 80 DEFAULT_INPUT_CHANNELS = 1 @@ -171,48 +194,69 @@ def __init__(self, yaml_path): try: import yaml except ImportError: - print('Please install PyYAML to load YAML files for S2T data config') + print("Please install PyYAML to load YAML files for S2T data config") self.yaml = yaml self.yaml_path = yaml_path self.config = {} def flush(self): - with open(self.yaml_path, 'w') as f: + with open(self.yaml_path, "w") as f: self.yaml.dump(self.config, f) - def set_audio_root(self, audio_root=''): - self.config['audio_root'] = audio_root - - def set_vocab_filename(self, vocab_filename='dict.txt'): - self.config['vocab_filename'] = vocab_filename - - def set_specaugment(self, time_wrap_w: int, freq_mask_n: int, - freq_mask_f: int, time_mask_n: int, time_mask_t: int, - time_mask_p: float): - self.config['specaugment'] = { - 'time_wrap_W': time_wrap_w, 'freq_mask_N': freq_mask_n, - 'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n, - 'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p, + def set_audio_root(self, audio_root=""): + self.config["audio_root"] = audio_root + + def set_vocab_filename(self, vocab_filename="dict.txt"): + self.config["vocab_filename"] = vocab_filename + + def set_specaugment( + self, + time_wrap_w: int, + freq_mask_n: int, + freq_mask_f: int, + time_mask_n: int, + time_mask_t: int, + time_mask_p: float, + ): + self.config["specaugment"] = { + "time_wrap_W": time_wrap_w, + "freq_mask_N": freq_mask_n, + "freq_mask_F": freq_mask_f, + "time_mask_N": time_mask_n, + "time_mask_T": time_mask_t, + "time_mask_p": time_mask_p, } def 
set_specaugment_lb_policy(self): - self.set_specaugment(time_wrap_w=0, freq_mask_n=1, freq_mask_f=27, - time_mask_n=1, time_mask_t=100, time_mask_p=1.0) + self.set_specaugment( + time_wrap_w=0, + freq_mask_n=1, + freq_mask_f=27, + time_mask_n=1, + time_mask_t=100, + time_mask_p=1.0, + ) def set_specaugment_ld_policy(self): - self.set_specaugment(time_wrap_w=0, freq_mask_n=2, freq_mask_f=27, - time_mask_n=2, time_mask_t=100, time_mask_p=1.0) + self.set_specaugment( + time_wrap_w=0, + freq_mask_n=2, + freq_mask_f=27, + time_mask_n=2, + time_mask_t=100, + time_mask_p=1.0, + ) def set_input_channels(self, input_channels=1): - self.config['input_channels'] = input_channels + self.config["input_channels"] = input_channels def set_input_feat_per_channel(self, input_feat_per_channel=80): - self.config['input_feat_per_channel'] = input_feat_per_channel + self.config["input_feat_per_channel"] = input_feat_per_channel def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): - self.config['bpe_tokenizer'] = bpe_tokenizer + self.config["bpe_tokenizer"] = bpe_tokenizer def set_feature_transforms(self, split, transforms: List[str]): - if 'transforms' not in self.config: - self.config['transforms'] = {} - self.config['transforms'][split] = transforms + if "transforms" not in self.config: + self.config["transforms"] = {} + self.config["transforms"][split] = transforms diff --git a/examples/speech_to_text/prep_covost_data.py b/examples/speech_to_text/prep_covost_data.py index a70e24e04d..e8a028b446 100644 --- a/examples/speech_to_text/prep_covost_data.py +++ b/examples/speech_to_text/prep_covost_data.py @@ -5,30 +5,35 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import csv import logging -from tempfile import NamedTemporaryFile import os import os.path as op import shutil -from typing import Tuple, Optional -import csv +from tempfile import NamedTemporaryFile +from typing import Optional, Tuple -from torchaudio.datasets.utils import download_url, extract_archive -from tqdm import tqdm import pandas as pd -from torch.utils.data import Dataset import torchaudio -from torch import Tensor - from examples.speech_to_text.data_utils import ( - gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, - extract_fbank_features, gen_config_yaml, filter_manifest_df + create_zip, + extract_fbank_features, + filter_manifest_df, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + save_df_to_tsv, ) +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import download_url, extract_archive +from tqdm import tqdm + log = logging.getLogger(__name__) -MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] +MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] class CoVoST(Dataset): @@ -44,40 +49,82 @@ class CoVoST(Dataset): found at root path. (default: ``False``). 
""" - CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \ - "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz" - COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \ - "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz" + CV_URL_TEMPLATE = ( + "https://voice-prod-bundler-ee1969a6ce8178826482b88" + "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz" + ) + COVOST_URL_TEMPLATE = ( + "https://dl.fbaipublicfiles.com/covost/" + "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz" + ) VERSIONS = {2} - SPLITS = ['train', 'dev', 'test'] + SPLITS = ["train", "dev", "test"] CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"} XX_EN_LANGUAGES = { - 1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn', - 'zh-CN'], - 2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn', - 'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy'] + 1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"], + 2: [ + "fr", + "de", + "es", + "ca", + "it", + "ru", + "zh-CN", + "pt", + "fa", + "et", + "mn", + "nl", + "tr", + "ar", + "sv-SE", + "lv", + "sl", + "ta", + "ja", + "id", + "cy", + ], } EN_XX_LANGUAGES = { 1: [], - 2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et', - 'id', - 'ar', 'ta', 'lv', 'ja'] + 2: [ + "de", + "tr", + "fa", + "sv-SE", + "mn", + "zh-CN", + "cy", + "ca", + "sl", + "et", + "id", + "ar", + "ta", + "lv", + "ja", + ], } def __init__( - self, root: str, split: str, source_language: str, - target_language: Optional[str] = None, version: int = 2, - download: bool = False + self, + root: str, + split: str, + source_language: str, + target_language: Optional[str] = None, + version: int = 2, + download: bool = False, ) -> None: assert version in self.VERSIONS and split in self.SPLITS assert source_language is not None - self.no_translation = (target_language is None) + self.no_translation = target_language is None if not self.no_translation: - assert 'en' in 
{source_language, target_language} - if source_language == 'en': + assert "en" in {source_language, target_language} + if source_language == "en": assert target_language in self.EN_XX_LANGUAGES[version] else: assert source_language in self.XX_EN_LANGUAGES[version] @@ -85,51 +132,60 @@ def __init__( # Hack here so that we can get "split" column from CoVoST TSV. # Note that we use CoVoST train split for ASR which is an extension # to Common Voice train split. - target_language = 'de' if source_language == 'en' else 'en' + target_language = "de" if source_language == "en" else "en" - self.root = os.path.join(root, 'raw') + self.root = os.path.join(root, "raw") os.makedirs(self.root, exist_ok=True) - cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version], - lang=source_language) + cv_url = self.CV_URL_TEMPLATE.format( + ver=self.CV_VERSION_ID[version], lang=source_language + ) cv_archive = os.path.join(self.root, os.path.basename(cv_url)) if download: if not os.path.isfile(cv_archive): download_url(cv_url, self.root, hash_value=None) extract_archive(cv_archive) - covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language, - tgt_lang=target_language) + covost_url = self.COVOST_URL_TEMPLATE.format( + src_lang=source_language, tgt_lang=target_language + ) covost_archive = os.path.join(self.root, os.path.basename(covost_url)) if download: if not os.path.isfile(covost_archive): download_url(covost_url, self.root, hash_value=None) extract_archive(covost_archive) - cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv')) + cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv")) covost_tsv = self.load_from_tsv( - os.path.join(self.root, - os.path.basename(covost_url).replace('.tar.gz', '')) + os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", "")) + ) + df = pd.merge( + left=cv_tsv[["path", "sentence", "client_id"]], + right=covost_tsv[["path", "translation", "split"]], + how="inner", + on="path", ) - df 
= pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']], - right=covost_tsv[['path', 'translation', 'split']], - how='inner', on='path') - if split == 'train': - df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')] + if split == "train": + df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")] else: - df = df[df['split'] == split] - self.data = df.to_dict(orient='index').items() + df = df[df["split"] == split] + self.data = df.to_dict(orient="index").items() self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])] @classmethod def load_from_tsv(cls, path: str): return pd.read_csv( - path, sep='\t', header=0, encoding='utf-8', escapechar='\\', - quoting=csv.QUOTE_NONE, na_filter=False + path, + sep="\t", + header=0, + encoding="utf-8", + escapechar="\\", + quoting=csv.QUOTE_NONE, + na_filter=False, ) def __getitem__( - self, n: int + self, n: int ) -> Tuple[Tensor, int, str, str, Optional[str], str, str]: """Load the n-th sample from the dataset. @@ -141,12 +197,12 @@ def __getitem__( sample_id)`` """ data = self.data[n] - path = os.path.join(self.root, 'clips', data['path']) + path = os.path.join(self.root, "clips", data["path"]) waveform, sample_rate = torchaudio.load(path) - sentence = data['sentence'] - translation = None if self.no_translation else data['translation'] - speaker_id = data['client_id'] - _id = data['path'].replace('.mp3', '') + sentence = data["sentence"] + translation = None if self.no_translation else data["translation"] + speaker_id = data["client_id"] + _id = data["path"].replace(".mp3", "") return waveform, sample_rate, sentence, translation, speaker_id, _id def __len__(self) -> int: @@ -157,76 +213,82 @@ def process(args): root = op.join(args.data_root, args.src_lang) os.makedirs(root, exist_ok=True) # Extract features - feature_root = op.join(root, 'fbank80') + feature_root = op.join(root, "fbank80") os.makedirs(feature_root, exist_ok=True) for split in CoVoST.SPLITS: - print(f'Fetching split 
{split}...') - dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, - download=True) - print('Extracting log mel filter bank features...') + print(f"Fetching split {split}...") + dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True) + print("Extracting log mel filter bank features...") for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): - extract_fbank_features(waveform, sample_rate, - op.join(feature_root, f'{utt_id}.npy')) + extract_fbank_features( + waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy") + ) # Pack features into ZIP - zip_filename = 'fbank80.zip' + zip_filename = "fbank80.zip" zip_path = op.join(root, zip_filename) - print('ZIPing features...') + print("ZIPing features...") create_zip(feature_root, zip_path) - print('Fetching ZIP manifest...') - zip_manifest = get_zip_manifest(args.data_root, - f'{args.src_lang}/{zip_filename}') + print("Fetching ZIP manifest...") + zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}") # Generate TSV manifest - print('Generating manifest...') + print("Generating manifest...") train_text = [] - task = f'asr_{args.src_lang}' + task = f"asr_{args.src_lang}" if args.tgt_lang is not None: - task = f'st_{args.src_lang}_{args.tgt_lang}' + task = f"st_{args.src_lang}_{args.tgt_lang}" for split in CoVoST.SPLITS: manifest = {c: [] for c in MANIFEST_COLUMNS} dataset = CoVoST(root, split, args.src_lang, args.tgt_lang) for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): - manifest['id'].append(utt_id) - manifest['audio'].append(zip_manifest[utt_id]) + manifest["id"].append(utt_id) + manifest["audio"].append(zip_manifest[utt_id]) duration_ms = int(wav.size(1) / sr * 1000) - manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) - manifest['tgt_text'].append( - src_utt if args.tgt_lang is None else tgt_utt - ) - manifest['speaker'].append(speaker_id) - is_train_split = split.startswith('train') + manifest["n_frames"].append(int(1 + 
(duration_ms - 25) / 10)) + manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt) + manifest["speaker"].append(speaker_id) + is_train_split = split.startswith("train") if is_train_split: - train_text.extend(manifest['tgt_text']) + train_text.extend(manifest["tgt_text"]) df = pd.DataFrame.from_dict(manifest) df = filter_manifest_df(df, is_train_split=is_train_split) - save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv')) + save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv")) # Generate vocab - vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size) - spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}' - with NamedTemporaryFile(mode='w') as f: + vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}" + with NamedTemporaryFile(mode="w") as f: for t in train_text: - f.write(t + '\n') - gen_vocab(f.name, op.join(root, spm_filename_prefix), - args.vocab_type, args.vocab_size) + f.write(t + "\n") + gen_vocab( + f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size + ) # Generate config YAML - gen_config_yaml(root, spm_filename_prefix + '.model', - yaml_filename=f'config_{task}.yaml', - specaugment_policy='lb') + gen_config_yaml( + root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{task}.yaml", + specaugment_policy="lb", + ) # Clean up shutil.rmtree(feature_root) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--data-root', '-d', required=True, type=str) - parser.add_argument('--vocab-type', default='unigram', required=True, - type=str, choices=['bpe', 'unigram', 'char']), - parser.add_argument('--vocab-size', default=1000, type=int) - parser.add_argument('--src-lang', '-s', required=True, type=str) - parser.add_argument('--tgt-lang', '-t', type=str) + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument( + "--vocab-type", + 
default="unigram", + required=True, + type=str, + choices=["bpe", "unigram", "char"], + ), + parser.add_argument("--vocab-size", default=1000, type=int) + parser.add_argument("--src-lang", "-s", required=True, type=str) + parser.add_argument("--tgt-lang", "-t", type=str) args = parser.parse_args() process(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py index 4f003ec505..95fcec8fe3 100644 --- a/examples/speech_to_text/prep_librispeech_data.py +++ b/examples/speech_to_text/prep_librispeech_data.py @@ -6,91 +6,114 @@ import argparse import logging -from tempfile import NamedTemporaryFile import os -import shutil import os.path as op +import shutil +from tempfile import NamedTemporaryFile -from tqdm import tqdm -from torchaudio.datasets import LIBRISPEECH import pandas as pd - from examples.speech_to_text.data_utils import ( - gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, - extract_fbank_features, gen_config_yaml + create_zip, + extract_fbank_features, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + save_df_to_tsv, ) +from torchaudio.datasets import LIBRISPEECH +from tqdm import tqdm + log = logging.getLogger(__name__) -SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean', - 'dev-other', 'test-clean', 'test-other'] +SPLITS = [ + "train-clean-100", + "train-clean-360", + "train-other-500", + "dev-clean", + "dev-other", + "test-clean", + "test-other", +] -MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] +MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] def process(args): os.makedirs(args.output_root, exist_ok=True) # Extract features - feature_root = op.join(args.output_root, 'fbank80') + feature_root = op.join(args.output_root, "fbank80") os.makedirs(feature_root, exist_ok=True) for split in SPLITS: - print(f'Fetching split {split}...') + print(f"Fetching 
split {split}...") dataset = LIBRISPEECH(args.output_root, url=split, download=True) - print('Extracting log mel filter bank features...') + print("Extracting log mel filter bank features...") for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset): - sample_id = f'{spk_id}-{chapter_id}-{utt_id}' - extract_fbank_features(wav, sample_rate, - op.join(feature_root, f'{sample_id}.npy')) + sample_id = f"{spk_id}-{chapter_id}-{utt_id}" + extract_fbank_features( + wav, sample_rate, op.join(feature_root, f"{sample_id}.npy") + ) # Pack features into ZIP - zip_filename = 'fbank80.zip' + zip_filename = "fbank80.zip" zip_path = op.join(args.output_root, zip_filename) - print('ZIPing features...') + print("ZIPing features...") create_zip(feature_root, zip_path) - print('Fetching ZIP manifest...') + print("Fetching ZIP manifest...") zip_manifest = get_zip_manifest(args.output_root, zip_filename) # Generate TSV manifest - print('Generating manifest...') + print("Generating manifest...") train_text = [] for split in SPLITS: manifest = {c: [] for c in MANIFEST_COLUMNS} dataset = LIBRISPEECH(args.output_root, url=split) for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset): - sample_id = f'{spk_id}-{chapter_id}-{utt_id}' - manifest['id'].append(sample_id) - manifest['audio'].append(zip_manifest[sample_id]) + sample_id = f"{spk_id}-{chapter_id}-{utt_id}" + manifest["id"].append(sample_id) + manifest["audio"].append(zip_manifest[sample_id]) duration_ms = int(wav.size(1) / sample_rate * 1000) - manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) - manifest['tgt_text'].append(utt) - manifest['speaker'].append(spk_id) - save_df_to_tsv(pd.DataFrame.from_dict(manifest), - op.join(args.output_root, f'{split}.tsv')) - if split.startswith('train'): - train_text.extend(manifest['tgt_text']) + manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) + manifest["tgt_text"].append(utt) + manifest["speaker"].append(spk_id) + save_df_to_tsv( + 
pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv") + ) + if split.startswith("train"): + train_text.extend(manifest["tgt_text"]) # Generate vocab - vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size) - spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}' - with NamedTemporaryFile(mode='w') as f: + vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}" + with NamedTemporaryFile(mode="w") as f: for t in train_text: - f.write(t + '\n') - gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix), - args.vocab_type, args.vocab_size) + f.write(t + "\n") + gen_vocab( + f.name, + op.join(args.output_root, spm_filename_prefix), + args.vocab_type, + args.vocab_size, + ) # Generate config YAML - gen_config_yaml(args.output_root, spm_filename_prefix + '.model', - specaugment_policy='ld') + gen_config_yaml( + args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld" + ) # Clean up shutil.rmtree(feature_root) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--output-root', '-o', required=True, type=str) - parser.add_argument('--vocab-type', default='unigram', required=True, - type=str, choices=['bpe', 'unigram', 'char']), - parser.add_argument('--vocab-size', default=10000, type=int) + parser.add_argument("--output-root", "-o", required=True, type=str) + parser.add_argument( + "--vocab-type", + default="unigram", + required=True, + type=str, + choices=["bpe", "unigram", "char"], + ), + parser.add_argument("--vocab-size", default=10000, type=int) args = parser.parse_args() process(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 6c0a9b7132..5593d2e7e2 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -6,29 +6,34 @@ import argparse 
import logging -from tempfile import NamedTemporaryFile import os import os.path as op import shutil -from typing import Tuple from itertools import groupby +from tempfile import NamedTemporaryFile +from typing import Tuple -from tqdm import tqdm import pandas as pd -from torch.utils.data import Dataset import torchaudio -from torch import Tensor - from examples.speech_to_text.data_utils import ( - gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv, - extract_fbank_features, gen_config_yaml, filter_manifest_df + create_zip, + extract_fbank_features, + filter_manifest_df, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + save_df_to_tsv, ) +from torch import Tensor +from torch.utils.data import Dataset +from tqdm import tqdm + log = logging.getLogger(__name__) -MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker'] -TASKS = ['asr', 'st'] +MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] +TASKS = ["asr", "st"] class MUSTC(Dataset): @@ -37,49 +42,55 @@ class MUSTC(Dataset): waveform, sample_rate, source utterance, target utterance, speaker_id, utterance_id """ - SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE'] - LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru'] + + SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"] + LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"] def __init__(self, root: str, lang: str, split: str) -> None: assert split in self.SPLITS and lang in self.LANGUAGES - _root = op.join(root, f'en-{lang}', 'data', split) - wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt') + _root = op.join(root, f"en-{lang}", "data", split) + wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt") assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root) # Load audio segments try: import yaml except ImportError: - print('Please install PyYAML to load YAML files for ' - 'the MuST-C dataset') - with open(op.join(txt_root, f'{split}.yaml')) as f: + print("Please install 
PyYAML to load YAML files for " "the MuST-C dataset") + with open(op.join(txt_root, f"{split}.yaml")) as f: segments = yaml.load(f, Loader=yaml.BaseLoader) # Load source and target utterances - for _lang in ['en', lang]: - with open(op.join(txt_root, f'{split}.{_lang}')) as f: + for _lang in ["en", lang]: + with open(op.join(txt_root, f"{split}.{_lang}")) as f: utterances = [r.strip() for r in f] assert len(segments) == len(utterances) for i, u in enumerate(utterances): segments[i][_lang] = u # Gather info self.data = [] - for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']): + for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): wav_path = op.join(wav_root, wav_filename) sample_rate = torchaudio.info(wav_path)[0].rate - seg_group = sorted(_seg_group, key=lambda x: x['offset']) + seg_group = sorted(_seg_group, key=lambda x: x["offset"]) for i, segment in enumerate(seg_group): - offset = int(float(segment['offset']) * sample_rate) - n_frames = int(float(segment['duration']) * sample_rate) - _id = f'{op.splitext(wav_filename)[0]}_{i}' + offset = int(float(segment["offset"]) * sample_rate) + n_frames = int(float(segment["duration"]) * sample_rate) + _id = f"{op.splitext(wav_filename)[0]}_{i}" self.data.append( - (wav_path, offset, n_frames, sample_rate, segment['en'], - segment[lang], segment['speaker_id'], _id) + ( + wav_path, + offset, + n_frames, + sample_rate, + segment["en"], + segment[lang], + segment["speaker_id"], + _id, + ) ) def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]: - wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \ - self.data[n] - waveform, _ = torchaudio.load(wav_path, offset=offset, - num_frames=n_frames) + wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] + waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) return waveform, sr, src_utt, tgt_utt, spk_id, utt_id def __len__(self) -> int: @@ -88,85 +99,102 @@ def 
__len__(self) -> int: def process(args): for lang in MUSTC.LANGUAGES: - cur_root = op.join(args.data_root, f'en-{lang}') + cur_root = op.join(args.data_root, f"en-{lang}") if not op.isdir(cur_root): - print(f'{cur_root} does not exist. Skipped.') + print(f"{cur_root} does not exist. Skipped.") continue # Extract features - feature_root = op.join(cur_root, 'fbank80') + feature_root = op.join(cur_root, "fbank80") os.makedirs(feature_root, exist_ok=True) for split in MUSTC.SPLITS: - print(f'Fetching split {split}...') + print(f"Fetching split {split}...") dataset = MUSTC(args.data_root, lang, split) - print('Extracting log mel filter bank features...') + print("Extracting log mel filter bank features...") for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): - extract_fbank_features(waveform, sample_rate, - op.join(feature_root, f'{utt_id}.npy')) + extract_fbank_features( + waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy") + ) # Pack features into ZIP - zip_filename = 'fbank80.zip' + zip_filename = "fbank80.zip" zip_path = op.join(cur_root, zip_filename) - print('ZIPing features...') + print("ZIPing features...") create_zip(feature_root, zip_path) - print('Fetching ZIP manifest...') - zip_manifest = get_zip_manifest(args.data_root, - f'en-{lang}/{zip_filename}') + print("Fetching ZIP manifest...") + zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}") # Generate TSV manifest - print('Generating manifest...') + print("Generating manifest...") train_text = {task: [] for task in TASKS} for split in MUSTC.SPLITS: - is_train_split = split.startswith('train') + is_train_split = split.startswith("train") manifest = {c: [] for c in MANIFEST_COLUMNS} text = {task: [] for task in TASKS} dataset = MUSTC(args.data_root, lang, split) for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): - manifest['id'].append(utt_id) - manifest['audio'].append(zip_manifest[utt_id]) + manifest["id"].append(utt_id) + 
manifest["audio"].append(zip_manifest[utt_id]) duration_ms = int(wav.size(1) / sr * 1000) - manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10)) - text['asr'].append(src_utt) - text['st'].append(tgt_utt) - manifest['speaker'].append(speaker_id) + manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) + text["asr"].append(src_utt) + text["st"].append(tgt_utt) + manifest["speaker"].append(speaker_id) if is_train_split: for task in TASKS: train_text[task].extend(text[task]) for task in TASKS: - manifest['tgt_text'] = text[task] + manifest["tgt_text"] = text[task] df = pd.DataFrame.from_dict(manifest) df = filter_manifest_df(df, is_train_split=is_train_split) - save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv')) + save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv")) # Generate vocab for task in TASKS: vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size - if task == 'st': + if task == "st": vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size - vocab_size_str = '' if vocab_type == 'char' else str(vocab_size) - spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}' - with NamedTemporaryFile(mode='w') as f: + vocab_size_str = "" if vocab_type == "char" else str(vocab_size) + spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}" + with NamedTemporaryFile(mode="w") as f: for t in train_text[task]: - f.write(t + '\n') - gen_vocab(f.name, op.join(cur_root, spm_filename_prefix), - vocab_type, vocab_size) + f.write(t + "\n") + gen_vocab( + f.name, + op.join(cur_root, spm_filename_prefix), + vocab_type, + vocab_size, + ) # Generate config YAML - gen_config_yaml(cur_root, spm_filename_prefix + '.model', - yaml_filename=f'config_{task}.yaml', - specaugment_policy='lb') + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{task}.yaml", + specaugment_policy="lb", + ) # Clean up shutil.rmtree(feature_root) def main(): parser = argparse.ArgumentParser() - 
parser.add_argument('--data-root', '-d', required=True, type=str) - parser.add_argument('--asr-vocab-type', default='unigram', required=True, - type=str, choices=['bpe', 'unigram', 'char']), - parser.add_argument('--st-vocab-type', default='unigram', required=True, - type=str, choices=['bpe', 'unigram', 'char']), - parser.add_argument('--asr-vocab-size', default=5000, type=int) - parser.add_argument('--st-vocab-size', default=8000, type=int) + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument( + "--asr-vocab-type", + default="unigram", + required=True, + type=str, + choices=["bpe", "unigram", "char"], + ), + parser.add_argument( + "--st-vocab-type", + default="unigram", + required=True, + type=str, + choices=["bpe", "unigram", "char"], + ), + parser.add_argument("--asr-vocab-size", default=5000, type=int) + parser.add_argument("--st-vocab-size", default=8000, type=int) args = parser.parse_args() process(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/translation_moe/score.py b/examples/translation_moe/score.py index b68cc828a7..9a529a9850 100644 --- a/examples/translation_moe/score.py +++ b/examples/translation_moe/score.py @@ -12,9 +12,9 @@ """ import argparse -from itertools import chain -import sys import random +import sys +from itertools import chain import numpy as np from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu @@ -22,17 +22,21 @@ def main(): parser = argparse.ArgumentParser(sys.argv[0]) - parser.add_argument('--sys', nargs='*', default='', metavar='FILE', - help='path to system output') - parser.add_argument('--ref', default='', metavar='FILE', - help='path to references') - parser.add_argument('--output', default='', metavar='FILE', - help='print outputs into a pretty format') + parser.add_argument( + "--sys", nargs="*", default="", metavar="FILE", help="path to system output" + ) + parser.add_argument("--ref", default="", metavar="FILE", help="path to 
references") + parser.add_argument( + "--output", + default="", + metavar="FILE", + help="print outputs into a pretty format", + ) args = parser.parse_args() if args.sys: src, tgt, hypos, log_probs = load_sys(args.sys) - print('pairwise BLEU: %.2f' % pairwise(hypos)) + print("pairwise BLEU: %.2f" % pairwise(hypos)) if args.output: merge(src, tgt, hypos, log_probs, args.output) @@ -58,18 +62,18 @@ def load_sys(paths): # S: source # T: target # D: detokenized system output - if line.startswith(('S-', 'T-', 'D-')): - i = int(line[line.find('-')+1:line.find('\t')]) - if line.startswith('S-'): - src[i] = line.split('\t')[1] - if line.startswith('T-'): - tgt[i] = line.split('\t')[1] - if line.startswith('D-'): + if line.startswith(("S-", "T-", "D-")): + i = int(line[line.find("-") + 1 : line.find("\t")]) + if line.startswith("S-"): + src[i] = line.split("\t")[1] + if line.startswith("T-"): + tgt[i] = line.split("\t")[1] + if line.startswith("D-"): if i not in hypos: hypos[i] = [] log_probs[i] = [] - hypos[i].append(line.split('\t')[2]) - log_probs[i].append(float(line.split('\t')[1])) + hypos[i].append(line.split("\t")[2]) + log_probs[i].append(float(line.split("\t")[1])) return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs) @@ -79,34 +83,34 @@ def load_ref(path): src, tgt, refs = [], [], [] i = 0 while i < len(lines): - if lines[i].startswith('S-'): - src.append(lines[i].split('\t')[1].rstrip()) + if lines[i].startswith("S-"): + src.append(lines[i].split("\t")[1].rstrip()) i += 1 - elif lines[i].startswith('T-'): - tgt.append(lines[i].split('\t')[1].rstrip()) + elif lines[i].startswith("T-"): + tgt.append(lines[i].split("\t")[1].rstrip()) i += 1 else: a = [] - while i < len(lines) and lines[i].startswith('R'): - a.append(lines[i].split('\t')[1].rstrip()) + while i < len(lines) and lines[i].startswith("R"): + a.append(lines[i].split("\t")[1].rstrip()) i += 1 refs.append(a) return src, tgt, refs def merge(src, tgt, hypos, log_probs, path): - with 
open(path, 'w') as f: + with open(path, "w") as f: for s, t, hs, lps in zip(src, tgt, hypos, log_probs): - f.write(s + '\n') - f.write(t + '\n') - f.write('\n') + f.write(s + "\n") + f.write(t + "\n") + f.write("\n") for h, lp in zip(hs, lps): - f.write('\t%f\t%s\n' % (lp, h.strip())) - f.write('------------------------------------------------------\n') + f.write("\t%f\t%s\n" % (lp, h.strip())) + f.write("------------------------------------------------------\n") def corpus_bleu(sys_stream, ref_streams): - bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none') + bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none") return bleu.score @@ -116,9 +120,11 @@ def sentence_bleu(hypothesis, reference): bleu.counts[i] += 1 bleu.totals[i] += 1 bleu = compute_bleu( - bleu.counts, bleu.totals, - bleu.sys_len, bleu.ref_len, - smooth_method='exp', + bleu.counts, + bleu.totals, + bleu.sys_len, + bleu.ref_len, + smooth_method="exp", ) return bleu.score @@ -150,7 +156,7 @@ def multi_ref(refs, hypos): best = [k for k in range(len(rs)) if s[k] == s[j]] a.add(random.choice(best)) ref_cnt += len(a) - print('#refs covered: %.2f' % (ref_cnt / len(refs))) + print("#refs covered: %.2f" % (ref_cnt / len(refs))) # transpose refs and hypos refs = list(zip(*refs)) @@ -160,33 +166,32 @@ def multi_ref(refs, hypos): k = len(hypos) m = len(refs) flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)] - duplicated_refs = [ - [ref for ref in refs_i for _ in range(k)] - for refs_i in refs - ] + duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs] loo_bleus = [] for held_out_ref in range(m): - remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:] + remaining_refs = ( + duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :] + ) assert len(remaining_refs) == m - 1 loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs)) - print('average multi-reference BLEU (leave-one-out): %.2f' % 
np.mean(loo_bleus)) + print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus)) def intra_ref(refs): - print('ref pairwise BLEU: %.2f' % pairwise(refs)) + print("ref pairwise BLEU: %.2f" % pairwise(refs)) refs = list(zip(*refs)) m = len(refs) concat_h = [] concat_rest = [[] for j in range(m - 1)] for i, h in enumerate(refs): - rest = refs[:i] + refs[i+1:] + rest = refs[:i] + refs[i + 1 :] concat_h.append(h) for j in range(m - 1): concat_rest[j].extend(rest[j]) concat_h = list(chain.from_iterable(concat_h)) bleu = corpus_bleu(concat_h, concat_rest) - print('multi-reference BLEU (leave-one-out): %.2f' % bleu) + print("multi-reference BLEU (leave-one-out): %.2f" % bleu) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/translation_moe/src/logsumexp_moe.py b/examples/translation_moe/src/logsumexp_moe.py index 0379f226b0..fb299daecb 100644 --- a/examples/translation_moe/src/logsumexp_moe.py +++ b/examples/translation_moe/src/logsumexp_moe.py @@ -21,6 +21,6 @@ def forward(ctx, logp, posterior, dim=-1): @staticmethod def backward(ctx, grad_output): - posterior, = ctx.saved_tensors + (posterior,) = ctx.saved_tensors grad_logp = grad_output.unsqueeze(ctx.dim) * posterior return grad_logp, None, None diff --git a/examples/translation_moe/src/mean_pool_gating_network.py b/examples/translation_moe/src/mean_pool_gating_network.py index 25743b4e98..484b6ac912 100644 --- a/examples/translation_moe/src/mean_pool_gating_network.py +++ b/examples/translation_moe/src/mean_pool_gating_network.py @@ -26,15 +26,15 @@ def __init__(self, embed_dim, num_experts, dropout=None): def forward(self, encoder_out): if not ( - hasattr(encoder_out, 'encoder_out') - and hasattr(encoder_out, 'encoder_padding_mask') + hasattr(encoder_out, "encoder_out") + and hasattr(encoder_out, "encoder_padding_mask") and encoder_out.encoder_out.size(2) == self.embed_dim ): - raise ValueError('Unexpected format for encoder_out') + raise ValueError("Unexpected 
format for encoder_out") # mean pooling over time encoder_padding_mask = encoder_out.encoder_padding_mask # B x T - encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C + encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C if encoder_padding_mask is not None: encoder_out = encoder_out.clone() # required because of transpose above encoder_out[encoder_padding_mask] = 0 diff --git a/examples/translation_moe/src/translation_moe.py b/examples/translation_moe/src/translation_moe.py index 5455dd6681..ae458aaad3 100644 --- a/examples/translation_moe/src/translation_moe.py +++ b/examples/translation_moe/src/translation_moe.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import torch - from fairseq import metrics, utils from fairseq.tasks import register_task from fairseq.tasks.translation import TranslationTask @@ -13,7 +12,7 @@ from .mean_pool_gating_network import MeanPoolGatingNetwork -@register_task('translation_moe') +@register_task("translation_moe") class TranslationMoETask(TranslationTask): """ Translation task for Mixture of Experts (MoE) models. 
@@ -58,19 +57,19 @@ def add_args(parser): # fmt: on def __init__(self, args, src_dict, tgt_dict): - if args.method == 'sMoElp': + if args.method == "sMoElp": # soft MoE with learned prior self.uniform_prior = False self.hard_selection = False - elif args.method == 'sMoEup': + elif args.method == "sMoEup": # soft MoE with uniform prior self.uniform_prior = True self.hard_selection = False - elif args.method == 'hMoElp': + elif args.method == "hMoElp": # hard MoE with learned prior self.uniform_prior = False self.hard_selection = True - elif args.method == 'hMoEup': + elif args.method == "hMoEup": # hard MoE with uniform prior self.uniform_prior = True self.hard_selection = True @@ -78,50 +77,56 @@ def __init__(self, args, src_dict, tgt_dict): # add indicator tokens for each expert for i in range(args.num_experts): # add to both dictionaries in case we're sharing embeddings - src_dict.add_symbol('<expert_{}>'.format(i)) - tgt_dict.add_symbol('<expert_{}>'.format(i)) + src_dict.add_symbol("<expert_{}>".format(i)) + tgt_dict.add_symbol("<expert_{}>".format(i)) super().__init__(args, src_dict, tgt_dict) def build_model(self, args): from fairseq import models + model = models.build_model(args, self) - if not self.uniform_prior and not hasattr(model, 'gating_network'): + if not self.uniform_prior and not hasattr(model, "gating_network"): if self.args.mean_pool_gating_network: - if getattr(args, 'mean_pool_gating_network_encoder_dim', None): + if getattr(args, "mean_pool_gating_network_encoder_dim", None): encoder_dim = args.mean_pool_gating_network_encoder_dim - elif getattr(args, 'encoder_embed_dim', None): + elif getattr(args, "encoder_embed_dim", None): # assume that encoder_embed_dim is the encoder's output dimension encoder_dim = args.encoder_embed_dim else: - raise ValueError('Must specify --mean-pool-gating-network-encoder-dim') + raise ValueError( + "Must specify --mean-pool-gating-network-encoder-dim" + ) - if getattr(args, 'mean_pool_gating_network_dropout', None): + if getattr(args,
"mean_pool_gating_network_dropout", None): dropout = args.mean_pool_gating_network_dropout - elif getattr(args, 'dropout', None): + elif getattr(args, "dropout", None): dropout = args.dropout else: - raise ValueError('Must specify --mean-pool-gating-network-dropout') + raise ValueError("Must specify --mean-pool-gating-network-dropout") model.gating_network = MeanPoolGatingNetwork( - encoder_dim, args.num_experts, dropout, + encoder_dim, + args.num_experts, + dropout, ) else: raise ValueError( - 'translation_moe task with learned prior requires the model to ' - 'have a gating network; try using --mean-pool-gating-network' + "translation_moe task with learned prior requires the model to " + "have a gating network; try using --mean-pool-gating-network" ) return model def expert_index(self, i): - return i + self.tgt_dict.index('<expert_0>') + return i + self.tgt_dict.index("<expert_0>") def _get_loss(self, sample, model, criterion): - assert hasattr(criterion, 'compute_loss'), \ - 'translation_moe task requires the criterion to implement the compute_loss() method' + assert hasattr( + criterion, "compute_loss" + ), "translation_moe task requires the criterion to implement the compute_loss() method" k = self.args.num_experts - bsz = sample['target'].size(0) + bsz = sample["target"].size(0) def get_lprob_y(encoder_out, prev_output_tokens_k): net_output = model.decoder( @@ -134,20 +139,22 @@ def get_lprob_y(encoder_out, prev_output_tokens_k): def get_lprob_yz(winners=None): encoder_out = model.encoder( - src_tokens=sample['net_input']['src_tokens'], - src_lengths=sample['net_input']['src_lengths'], + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], ) if winners is None: lprob_y = [] for i in range(k): - prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone() + prev_output_tokens_k = sample["net_input"][ + "prev_output_tokens" + ].clone() assert not prev_output_tokens_k.requires_grad prev_output_tokens_k[:, 0] =
self.expert_index(i) lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k)) lprob_y = torch.cat(lprob_y, dim=1) # -> B x K else: - prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone() + prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone() prev_output_tokens_k[:, 0] = self.expert_index(winners) lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B @@ -177,17 +184,21 @@ def get_lprob_yz(winners=None): loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1) loss = loss.sum() - sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + sample_size = ( + sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] + ) logging_output = { - 'loss': utils.item(loss.data), - 'ntokens': sample['ntokens'], - 'nsentences': bsz, - 'sample_size': sample_size, - 'posterior': prob_z_xy.float().sum(dim=0).cpu(), + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": bsz, + "sample_size": sample_size, + "posterior": prob_z_xy.float().sum(dim=0).cpu(), } return loss, sample_size, logging_output - def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): model.train() loss, sample_size, logging_output = self._get_loss(sample, model, criterion) if ignore_grad: @@ -201,7 +212,15 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = self._get_loss(sample, model, criterion) return loss, sample_size, logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None): + def inference_step( + self, + generator, + models, + sample, + prefix_tokens=None, + expert=None, + constraints=None, + ): expert = expert or self.args.gen_expert with torch.no_grad(): return generator.generate( @@ -215,6 +234,6 @@ def inference_step(self, generator, models, sample, 
prefix_tokens=None, expert=N def reduce_metrics(self, logging_outputs, criterion): super().reduce_metrics(logging_outputs, criterion) metrics.log_scalar( - 'posterior', - sum(log['posterior'] for log in logging_outputs if 'posterior' in log) + "posterior", + sum(log["posterior"] for log in logging_outputs if "posterior" in log), ) diff --git a/examples/unsupervised_quality_estimation/aggregate_scores.py b/examples/unsupervised_quality_estimation/aggregate_scores.py index 35a6baf67d..66d50d07ff 100644 --- a/examples/unsupervised_quality_estimation/aggregate_scores.py +++ b/examples/unsupervised_quality_estimation/aggregate_scores.py @@ -4,37 +4,38 @@ # LICENSE file in the root directory of this source tree. import argparse -import numpy as np import sys +import numpy as np + aggregate_funcs = { - 'std': np.std, - 'var': np.var, - 'median': np.median, - 'mean': np.mean, - 'min': np.min, - 'max': np.max, + "std": np.std, + "var": np.var, + "median": np.median, + "mean": np.mean, + "min": np.min, + "max": np.max, } def main(): parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input_file', required=True, type=str) - parser.add_argument('-n', '--repeat_times', required=True, type=int) - parser.add_argument('-o', '--output_file', required=False) - parser.add_argument('-f', '--func', required=False, default='mean') + parser.add_argument("-i", "--input_file", required=True, type=str) + parser.add_argument("-n", "--repeat_times", required=True, type=int) + parser.add_argument("-o", "--output_file", required=False) + parser.add_argument("-f", "--func", required=False, default="mean") args = parser.parse_args() - stream = open(args.output_file, 'w') if args.output_file else sys.stdout + stream = open(args.output_file, "w") if args.output_file else sys.stdout segment_scores = [] for line in open(args.input_file): segment_scores.append(float(line.strip())) if len(segment_scores) == args.repeat_times: - 
stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores))) + stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores))) segment_scores = [] -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/unsupervised_quality_estimation/meteor.py b/examples/unsupervised_quality_estimation/meteor.py index ed4ba4ec34..4a214e794d 100644 --- a/examples/unsupervised_quality_estimation/meteor.py +++ b/examples/unsupervised_quality_estimation/meteor.py @@ -4,14 +4,13 @@ # LICENSE file in the root directory of this source tree. import argparse +import math import os -import sys import subprocess +import sys import tempfile -import math - -from itertools import combinations from collections import defaultdict +from itertools import combinations def read_translations(path, n_repeats): @@ -19,7 +18,7 @@ def read_translations(path, n_repeats): segment_translations = [] translations = defaultdict(list) for line in open(path): - segment_translations.append(' '.join(line.split())) + segment_translations.append(" ".join(line.split())) if len(segment_translations) == n_repeats: translations[segment_counter] = segment_translations segment_translations = [] @@ -30,42 +29,55 @@ def read_translations(path, n_repeats): def generate_input(translations, n_repeats): _, ref_path = tempfile.mkstemp() _, mt_path = tempfile.mkstemp() - ref_fh = open(ref_path, 'w') - mt_fh = open(mt_path, 'w') + ref_fh = open(ref_path, "w") + mt_fh = open(mt_path, "w") for segid in sorted(translations.keys()): assert len(translations[segid]) == n_repeats indexes = combinations(range(n_repeats), 2) for idx1, idx2 in indexes: - mt_fh.write(translations[segid][idx1].strip() + '\n') - ref_fh.write(translations[segid][idx2].strip() + '\n') - sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path)) + mt_fh.write(translations[segid][idx1].strip() + "\n") + ref_fh.write(translations[segid][idx2].strip() + "\n") + sys.stderr.write("\nSaved translations to 
%s and %s" % (ref_path, mt_path)) return ref_path, mt_path -def run_meteor(ref_path, mt_path, metric_path, lang='en'): +def run_meteor(ref_path, mt_path, metric_path, lang="en"): _, out_path = tempfile.mkstemp() - subprocess.call([ - 'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path, - '-p', '0.5 0.2 0.6 0.75', # default parameters, only changed alpha to give equal weight to P and R - '-norm', - '-l', lang], stdout=open(out_path, 'w')) + subprocess.call( + [ + "java", + "-Xmx2G", + "-jar", + metric_path, + mt_path, + ref_path, + "-p", + "0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R + "-norm", + "-l", + lang, + ], + stdout=open(out_path, "w"), + ) os.remove(ref_path) os.remove(mt_path) - sys.stderr.write('\nSaved Meteor output to %s' % out_path) + sys.stderr.write("\nSaved Meteor output to %s" % out_path) return out_path def read_output(meteor_output_path, n_repeats): - n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2)) + n_combinations = math.factorial(n_repeats) / ( + math.factorial(2) * math.factorial(n_repeats - 2) + ) raw_scores = [] average_scores = [] for line in open(meteor_output_path): - if not line.startswith('Segment '): + if not line.startswith("Segment "): continue - score = float(line.strip().split('\t')[1]) + score = float(line.strip().split("\t")[1]) raw_scores.append(score) if len(raw_scores) == n_combinations: - average_scores.append(sum(raw_scores)/n_combinations) + average_scores.append(sum(raw_scores) / n_combinations) raw_scores = [] os.remove(meteor_output_path) return average_scores @@ -73,25 +85,25 @@ def read_output(meteor_output_path, n_repeats): def main(): parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input') - parser.add_argument('-n', '--repeat_times', type=int) - parser.add_argument('-m', '--meteor') - parser.add_argument('-o', '--output') + parser.add_argument("-i", "--input") + parser.add_argument("-n", 
"--repeat_times", type=int) + parser.add_argument("-m", "--meteor") + parser.add_argument("-o", "--output") args = parser.parse_args() translations = read_translations(args.infile, args.repetitions) - sys.stderr.write('\nGenerating input for Meteor...') + sys.stderr.write("\nGenerating input for Meteor...") ref_path, mt_path = generate_input(translations, args.repetitions) - sys.stderr.write('\nRunning Meteor...') + sys.stderr.write("\nRunning Meteor...") out_path = run_meteor(ref_path, mt_path, args.meteor) - sys.stderr.write('\nReading output...') + sys.stderr.write("\nReading output...") scores = read_output(out_path, args.repetitions) - sys.stderr.write('\nWriting results...') - with open(args.output, 'w') as o: + sys.stderr.write("\nWriting results...") + with open(args.output, "w") as o: for scr in scores: - o.write('{}\n'.format(scr)) + o.write("{}\n".format(scr)) o.close() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/unsupervised_quality_estimation/repeat_lines.py b/examples/unsupervised_quality_estimation/repeat_lines.py index 661ca17c1b..5a04851a74 100644 --- a/examples/unsupervised_quality_estimation/repeat_lines.py +++ b/examples/unsupervised_quality_estimation/repeat_lines.py @@ -8,21 +8,21 @@ def _normalize_spaces(line): - return ' '.join(line.split()) + return " ".join(line.split()) def main(): parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input_file', required=True, type=str) - parser.add_argument('-n', '--repeat_times', required=True, type=int) - parser.add_argument('-o', '--output_file', required=False, type=str) + parser.add_argument("-i", "--input_file", required=True, type=str) + parser.add_argument("-n", "--repeat_times", required=True, type=int) + parser.add_argument("-o", "--output_file", required=False, type=str) args = parser.parse_args() - stream = open(args.output_file, 'w') if args.output_file else sys.stdout + stream = open(args.output_file, "w") if args.output_file else 
sys.stdout for line in open(args.input_file): for _ in range(args.repeat_times): - stream.write(_normalize_spaces(line) + '\n') + stream.write(_normalize_spaces(line) + "\n") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/wav2vec/vq-wav2vec_featurize.py b/examples/wav2vec/vq-wav2vec_featurize.py index 0d658c07ca..baabc1d365 100644 --- a/examples/wav2vec/vq-wav2vec_featurize.py +++ b/examples/wav2vec/vq-wav2vec_featurize.py @@ -8,30 +8,31 @@ Helper script to pre-compute embeddings for a wav2letter++ dataset """ +import argparse +import glob +import os +import os.path as osp import pprint -import glob, os, argparse +import soundfile as sf import torch +import tqdm +from fairseq.models.wav2vec.wav2vec import Wav2VecModel from torch import nn +from torch.utils.data import DataLoader + try: import tqdm except: print("Install tqdm to use --log-format=tqdm") -from fairseq.models.wav2vec.wav2vec import Wav2VecModel - -import tqdm -import soundfile as sf -from torch.utils.data import DataLoader -import os.path as osp - class FilesDataset: def __init__(self, files, labels): self.files = files if labels and osp.exists(labels): - with open(labels, 'r') as lbl_f: + with open(labels, "r") as lbl_f: self.labels = [line.rstrip() for line in lbl_f] else: self.labels = labels @@ -50,7 +51,7 @@ def __getitem__(self, index): if self.labels: if isinstance(self.labels, str): lbl_file = osp.splitext(fname)[0] + "." 
+ self.labels - with open(lbl_file, 'r') as lblf: + with open(lbl_file, "r") as lblf: lbls = lblf.readline() assert lbls is not None else: @@ -116,24 +117,24 @@ def process_splits(self): assert len(files) > 0 if self.args.shard is not None: - files = files[self.args.shard::self.args.num_shards] + files = files[self.args.shard :: self.args.num_shards] lbls = [] - with open(self.data_file(split), 'w') as srcf: + with open(self.data_file(split), "w") as srcf: for line, lbl in self.iterate(files): print(line, file=srcf) if self.args.labels: - lbls.append(lbl + '\n') + lbls.append(lbl + "\n") if self.args.labels: assert all(a is not None for a in lbls) - with open(self.lbl_file(split), 'w') as lblf: + with open(self.lbl_file(split), "w") as lblf: lblf.writelines(lbls) def iterate(self, files): data = self.load_data(files) - for samples in tqdm.tqdm(data, total=len(files)//32): + for samples in tqdm.tqdm(data, total=len(files) // 32): for wav, lbl in samples: x = wav.unsqueeze(0).float().cuda() @@ -162,7 +163,6 @@ def iterate(self, files): idx = torch.cat(result, dim=0) yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl - def lbl_file(self, name): shard_part = "" if self.args.shard is None else f".{self.args.shard}" return osp.join(self.output_dir, f"{name}.lbl{shard_part}") @@ -230,7 +230,9 @@ def __call__(self): self.process_splits() - if hasattr(self.model.feature_extractor, "vars") and (self.args.shard is None or self.args.shard == 0): + if hasattr(self.model.feature_extractor, "vars") and ( + self.args.shard is None or self.args.shard == 0 + ): vars = ( self.model.feature_extractor.vars.view( self.model.feature_extractor.banks, @@ -248,4 +250,4 @@ def __call__(self): write_data = DatasetWriter() write_data() - print("Done.") \ No newline at end of file + print("Done.") diff --git a/examples/wav2vec/wav2vec_featurize.py b/examples/wav2vec/wav2vec_featurize.py index 445a5d0213..9283930587 100644 --- a/examples/wav2vec/wav2vec_featurize.py +++ 
b/examples/wav2vec/wav2vec_featurize.py @@ -14,13 +14,12 @@ from shutil import copy import h5py -import soundfile as sf import numpy as np +import soundfile as sf import torch -from torch import nn import tqdm - from fairseq.models.wav2vec.wav2vec import Wav2VecModel +from torch import nn def read_audio(fname): @@ -33,7 +32,6 @@ def read_audio(fname): class PretrainedWav2VecModel(nn.Module): - def __init__(self, fname): super().__init__() @@ -55,32 +53,33 @@ def forward(self, x): class EmbeddingWriterConfig(argparse.ArgumentParser): - def __init__(self): super().__init__("Pre-compute embeddings for wav2letter++ datasets") kwargs = {"action": "store", "type": str, "required": True} - self.add_argument("--input", "-i", - help="Input Directory", **kwargs) - self.add_argument("--output", "-o", - help="Output Directory", **kwargs) - self.add_argument("--model", - help="Path to model checkpoint", **kwargs) - self.add_argument("--split", - help="Dataset Splits", nargs='+', **kwargs) - self.add_argument("--ext", default="wav", required=False, - help="Audio file extension") - - self.add_argument("--no-copy-labels", action="store_true", - help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.") - self.add_argument("--use-feat", action="store_true", - help="Use the feature vector ('z') instead of context vector ('c') for features") - self.add_argument("--gpu", - help="GPU to use", default=0, type=int) - - -class Prediction(): + self.add_argument("--input", "-i", help="Input Directory", **kwargs) + self.add_argument("--output", "-o", help="Output Directory", **kwargs) + self.add_argument("--model", help="Path to model checkpoint", **kwargs) + self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs) + self.add_argument( + "--ext", default="wav", required=False, help="Audio file extension" + ) + + self.add_argument( + "--no-copy-labels", + action="store_true", + help="Do not copy label files. 
Useful for large datasets, use --targetdir in wav2letter then.", + ) + self.add_argument( + "--use-feat", + action="store_true", + help="Use the feature vector ('z') instead of context vector ('c') for features", + ) + self.add_argument("--gpu", help="GPU to use", default=0, type=int) + + +class Prediction: """ Lightweight wrapper around a fairspeech embedding model """ def __init__(self, fname, gpu=0): @@ -95,7 +94,7 @@ def __call__(self, x): return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy() -class H5Writer(): +class H5Writer: """ Write features as hdf5 file in wav2letter++ compatible format """ def __init__(self, fname): @@ -112,7 +111,7 @@ def write(self, data): class EmbeddingDatasetWriter(object): - """ Given a model and a wav2letter++ dataset, pre-compute and store embeddings + """Given a model and a wav2letter++ dataset, pre-compute and store embeddings Args: input_root, str : @@ -123,13 +122,17 @@ class EmbeddingDatasetWriter(object): Dataset split """ - def __init__(self, input_root, output_root, split, - model_fname, - extension="wav", - gpu=0, - verbose=False, - use_feat=False, - ): + def __init__( + self, + input_root, + output_root, + split, + model_fname, + extension="wav", + gpu=0, + verbose=False, + use_feat=False, + ): assert os.path.exists(model_fname) @@ -143,8 +146,9 @@ def __init__(self, input_root, output_root, split, self.extension = extension self.use_feat = use_feat - assert os.path.exists(self.input_path), \ - "Input path '{}' does not exist".format(self.input_path) + assert os.path.exists(self.input_path), "Input path '{}' does not exist".format( + self.input_path + ) def _progress(self, iterable, **kwargs): if self.verbose: @@ -176,7 +180,11 @@ def get_output_path(self, fname=None): def copy_labels(self): self.require_output_path() - labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*")))) + labels = list( + filter( + lambda x: self.extension not in x, glob.glob(self.get_input_path("*")) + 
) + ) for fname in tqdm.tqdm(labels): copy(fname, self.output_path) @@ -191,10 +199,16 @@ def write_features(self): paths = self.input_fnames - fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \ - map(os.path.basename, paths)) + fnames_context = map( + lambda x: os.path.join( + self.output_path, x.replace("." + self.extension, ".h5context") + ), + map(os.path.basename, paths), + ) - for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)): + for name, target_fname in self._progress( + zip(paths, fnames_context), total=len(self) + ): wav, sr = read_audio(name) z, c = self.model(wav) feat = z if self.use_feat else c @@ -204,7 +218,8 @@ def write_features(self): def __repr__(self): return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format( - n_files=len(self), **self.__dict__) + n_files=len(self), **self.__dict__ + ) if __name__ == "__main__": diff --git a/examples/wav2vec/wav2vec_manifest.py b/examples/wav2vec/wav2vec_manifest.py index c80f9883df..1d27f58afc 100644 --- a/examples/wav2vec/wav2vec_manifest.py +++ b/examples/wav2vec/wav2vec_manifest.py @@ -10,32 +10,50 @@ import argparse import glob import os -import soundfile import random +import soundfile + def get_parser(): parser = argparse.ArgumentParser() - parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index') - parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D', - help='percentage of data to use as validation set (between 0 and 1)') - parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory') - parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for') - parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed') - parser.add_argument('--path-must-contain', default=None, type=str, 
metavar='FRAG', - help='if set, path must contain this substring for a file to be included in the manifest') + parser.add_argument( + "root", metavar="DIR", help="root directory containing flac files to index" + ) + parser.add_argument( + "--valid-percent", + default=0.01, + type=float, + metavar="D", + help="percentage of data to use as validation set (between 0 and 1)", + ) + parser.add_argument( + "--dest", default=".", type=str, metavar="DIR", help="output directory" + ) + parser.add_argument( + "--ext", default="flac", type=str, metavar="EXT", help="extension to look for" + ) + parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed") + parser.add_argument( + "--path-must-contain", + default=None, + type=str, + metavar="FRAG", + help="if set, path must contain this substring for a file to be included in the manifest", + ) return parser def main(args): - assert args.valid_percent >= 0 and args.valid_percent <= 1. + assert args.valid_percent >= 0 and args.valid_percent <= 1.0 dir_path = os.path.realpath(args.root) - search_path = os.path.join(dir_path, '**/*.' + args.ext) + search_path = os.path.join(dir_path, "**/*." 
+ args.ext) rand = random.Random(args.seed) - with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open( - os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f: + with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open( + os.path.join(args.dest, "valid.tsv"), "w" + ) as valid_f: print(dir_path, file=train_f) print(dir_path, file=valid_f) @@ -47,10 +65,12 @@ def main(args): frames = soundfile.info(fname).frames dest = train_f if rand.random() > args.valid_percent else valid_f - print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest) + print( + "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() args = parser.parse_args() main(args) diff --git a/fairseq/__init__.py b/fairseq/__init__.py index a4244c8a3a..cac3d0e43b 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -4,16 +4,17 @@ # LICENSE file in the root directory of this source tree. """isort:skip_file""" -__all__ = ['pdb'] -__version__ = '1.0.0a0' +__all__ = ["pdb"] +__version__ = "1.0.0a0" import sys # backwards compatibility to support `from fairseq.meters import AverageMeter` from fairseq.logging import meters, metrics, progress_bar # noqa -sys.modules['fairseq.meters'] = meters -sys.modules['fairseq.metrics'] = metrics -sys.modules['fairseq.progress_bar'] = progress_bar + +sys.modules["fairseq.meters"] = meters +sys.modules["fairseq.metrics"] = metrics +sys.modules["fairseq.progress_bar"] = progress_bar import fairseq.criterions # noqa import fairseq.models # noqa diff --git a/fairseq/benchmark/__init__.py b/fairseq/benchmark/__init__.py index 926f3ce739..f6584661bd 100644 --- a/fairseq/benchmark/__init__.py +++ b/fairseq/benchmark/__init__.py @@ -4,9 +4,4 @@ # LICENSE file in the root directory of this source tree. # import models/tasks to register them -from . import ( # noqa - dummy_lm, - dummy_masked_lm, - dummy_model, - dummy_mt, -) +from . 
import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index 3c400e9d7f..6429d04de3 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -7,25 +7,27 @@ import numpy as np import torch - from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import register_task, LegacyFairseqTask +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('dummy_lm') +@register_task("dummy_lm") class DummyLMTask(LegacyFairseqTask): - @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('--dict-size', default=49996, type=int) - parser.add_argument('--dataset-size', default=100000, type=int) - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments ' - 'per sample for BERT dataset') + parser.add_argument("--dict-size", default=49996, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) def __init__(self, args, dictionary): super().__init__(args) @@ -44,8 +46,8 @@ def setup_task(cls, args, **kwargs): """Setup the task. 
""" dictionary = Dictionary() for i in range(args.dict_size): - dictionary.add_symbol('word{}'.format(i)) - logger.info('dictionary: {} types'.format(len(dictionary))) + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) return cls(args, dictionary) def load_dataset(self, split, epoch=1, combine=False, **kwargs): @@ -59,16 +61,16 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) self.datasets[split] = DummyDataset( { - 'id': 1, - 'net_input': { - 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]), - 'src_lengths': torch.full( - (bsz, ), self.args.tokens_per_sample, dtype=torch.long + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.tokens_per_sample, dtype=torch.long ), }, - 'target': torch.stack([self.dummy_tgt for _ in range(bsz)]), - 'nsentences': bsz, - 'ntokens': bsz * self.args.tokens_per_sample, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.args.tokens_per_sample, }, num_items=self.args.dataset_size, item_size=self.args.tokens_per_sample, @@ -84,7 +86,6 @@ def target_dictionary(self): class DummyDataset(FairseqDataset): - def __init__(self, batch, num_items, item_size): super().__init__() self.batch = batch diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index 621265d452..ab506fe1d5 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -7,32 +7,34 @@ import numpy as np import torch - from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import register_task, LegacyFairseqTask +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('dummy_masked_lm') +@register_task("dummy_masked_lm") class 
DummyMaskedLMTask(LegacyFairseqTask): - @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('--dict-size', default=49995, type=int) - parser.add_argument('--dataset-size', default=100000, type=int) - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments ' - 'per sample for BERT dataset') + parser.add_argument("--dict-size", default=49995, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) def __init__(self, args, dictionary): super().__init__(args) self.dictionary = dictionary # add mask token - self.mask_idx = dictionary.add_symbol('<mask>') + self.mask_idx = dictionary.add_symbol("<mask>") dictionary.pad_to_multiple_(8) # often faster if divisible by 8 mask_idx = 0 @@ -52,8 +54,8 @@ def setup_task(cls, args, **kwargs): """Setup the task.
""" dictionary = Dictionary() for i in range(args.dict_size): - dictionary.add_symbol('word{}'.format(i)) - logger.info('dictionary: {} types'.format(len(dictionary))) + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) return cls(args, dictionary) def load_dataset(self, split, epoch=1, combine=False, **kwargs): @@ -67,16 +69,16 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) self.datasets[split] = DummyDataset( { - 'id': 1, - 'net_input': { - 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]), - 'src_lengths': torch.full( - (bsz, ), self.args.tokens_per_sample, dtype=torch.long + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.tokens_per_sample, dtype=torch.long ), }, - 'target': torch.stack([self.dummy_tgt for _ in range(bsz)]), - 'nsentences': bsz, - 'ntokens': bsz * self.args.tokens_per_sample, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.args.tokens_per_sample, }, num_items=self.args.dataset_size, item_size=self.args.tokens_per_sample, @@ -92,7 +94,6 @@ def target_dictionary(self): class DummyDataset(FairseqDataset): - def __init__(self, batch, num_items, item_size): super().__init__() self.batch = batch diff --git a/fairseq/benchmark/dummy_model.py b/fairseq/benchmark/dummy_model.py index 817cdb34bb..ff26e4fe65 100644 --- a/fairseq/benchmark/dummy_model.py +++ b/fairseq/benchmark/dummy_model.py @@ -5,7 +5,6 @@ import torch.nn as nn import torch.nn.functional as F - from fairseq.data import Dictionary from fairseq.models import ( FairseqDecoder, @@ -15,17 +14,16 @@ ) -@register_model('dummy_model') +@register_model("dummy_model") class DummyModel(FairseqLanguageModel): - def __init__(self, args, encoder): super().__init__(encoder) self.args = args 
@staticmethod def add_args(parser): - parser.add_argument('--num-layers', type=int, default=24) - parser.add_argument('--embed-dim', type=int, default=1024) + parser.add_argument("--num-layers", type=int, default=24) + parser.add_argument("--embed-dim", type=int, default=1024) @classmethod def build_model(cls, args, task): @@ -41,32 +39,35 @@ def forward(self, src_tokens, masked_tokens=None, **kwargs): class DummyEncoder(FairseqDecoder): - def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24): super().__init__(Dictionary()) self.embed = nn.Embedding( num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0 ) - self.layers_a = nn.ModuleList([ - nn.Sequential( - nn.LayerNorm(embed_dim), - nn.Linear(embed_dim, 3*embed_dim), # q, k, v input projection - nn.Linear(3*embed_dim, embed_dim), # skip self-attention - nn.Linear(embed_dim, embed_dim), # output projection - nn.Dropout(), - ) - for i in range(num_layers) - ]) - self.layers_b = nn.ModuleList([ - nn.Sequential( - nn.LayerNorm(embed_dim), - nn.Linear(embed_dim, 4*embed_dim), # FFN - nn.ReLU(), - nn.Linear(4*embed_dim, embed_dim), # FFN - nn.Dropout(0.1), - ) - for i in range(num_layers) - ]) + self.layers_a = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection + nn.Linear(3 * embed_dim, embed_dim), # skip self-attention + nn.Linear(embed_dim, embed_dim), # output projection + nn.Dropout(), + ) + for i in range(num_layers) + ] + ) + self.layers_b = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 4 * embed_dim), # FFN + nn.ReLU(), + nn.Linear(4 * embed_dim, embed_dim), # FFN + nn.Dropout(0.1), + ) + for i in range(num_layers) + ] + ) self.out_proj = nn.Linear(embed_dim, num_embed) def forward(self, tokens, masked_tokens=None): @@ -90,6 +91,6 @@ def get_normalized_probs(self, net_output, log_probs, sample=None): return F.softmax(logits, dim=-1) 
-@register_model_architecture('dummy_model', 'dummy_model') +@register_model_architecture("dummy_model", "dummy_model") def base_architecture(args): pass diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py index 2f8d65d5be..4ca7be93a3 100644 --- a/fairseq/benchmark/dummy_mt.py +++ b/fairseq/benchmark/dummy_mt.py @@ -7,24 +7,22 @@ import numpy as np import torch - from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import register_task, LegacyFairseqTask +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('dummy_mt') +@register_task("dummy_mt") class DummyMTTask(LegacyFairseqTask): - @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('--dict-size', default=49996, type=int) - parser.add_argument('--dataset-size', default=100000, type=int) - parser.add_argument('--src-len', default=30, type=int) - parser.add_argument('--tgt-len', default=30, type=int) + parser.add_argument("--dict-size", default=49996, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument("--src-len", default=30, type=int) + parser.add_argument("--tgt-len", default=30, type=int) def __init__(self, args, dictionary): super().__init__(args) @@ -41,8 +39,8 @@ def setup_task(cls, args, **kwargs): """Setup the task. 
""" dictionary = Dictionary() for i in range(args.dict_size): - dictionary.add_symbol('word{}'.format(i)) - logger.info('dictionary: {} types'.format(len(dictionary))) + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) args.max_source_positions = args.src_len + dictionary.pad() + 2 args.max_target_positions = args.tgt_len + dictionary.pad() + 2 @@ -62,17 +60,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) self.datasets[split] = DummyDataset( { - 'id': 1, - 'net_input': { - 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]), - 'src_lengths': torch.full( - (bsz, ), self.args.src_len, dtype=torch.long + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.src_len, dtype=torch.long ), - 'prev_output_tokens': tgt.clone(), + "prev_output_tokens": tgt.clone(), }, - 'target': tgt, - 'nsentences': bsz, - 'ntokens': bsz * self.args.tgt_len, + "target": tgt, + "nsentences": bsz, + "ntokens": bsz * self.args.tgt_len, }, num_items=self.args.dataset_size, item_size=item_size, @@ -88,7 +86,6 @@ def target_dictionary(self): class DummyDataset(FairseqDataset): - def __init__(self, batch, num_items, item_size): super().__init__() self.batch = batch diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index ec3b90f211..0255c084b5 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -6,9 +6,10 @@ import os from collections import Counter -from fairseq.tokenizer import tokenize_line import torch from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line + def safe_readline(f): pos = f.tell() diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 60ab3190c7..75e2c68ca3 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -67,12 +67,14 @@ def is_better(a, b): or 
is_better(val_loss, save_checkpoint.best) ) if val_loss is not None and args.keep_best_checkpoints > 0: - checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format( - args.best_checkpoint_metric, val_loss)] = ( - not hasattr(save_checkpoint, "best") - or is_better(val_loss, save_checkpoint.best) + checkpoint_conds[ + "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss) + ] = not hasattr(save_checkpoint, "best") or is_better( + val_loss, save_checkpoint.best ) - checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints + checkpoint_conds[ + "checkpoint_last{}.pt".format(suffix) + ] = not args.no_last_checkpoints extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} if hasattr(save_checkpoint, "best"): @@ -112,10 +114,14 @@ def is_better(a, b): if args.keep_best_checkpoints > 0: # only keep the best N checkpoints according to validation metric checkpoints = checkpoint_paths( - args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric)) + args.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( + args.best_checkpoint_metric + ), + ) if not args.maximize_best_checkpoint_metric: checkpoints = checkpoints[::-1] - for old_chk in checkpoints[args.keep_best_checkpoints:]: + for old_chk in checkpoints[args.keep_best_checkpoints :]: if os.path.lexists(old_chk): os.remove(old_chk) @@ -133,16 +139,23 @@ def load_checkpoint(args, trainer, **passthrough_args): reset_meters = args.reset_meters reset_dataloader = args.reset_dataloader - if getattr(args, 'finetune_from_model', None) is not None \ - and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader): - raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer" - " or reset_lr_scheduler or reset_meters or reset_dataloader") + if getattr(args, "finetune_from_model", None) is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters 
or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) suffix = getattr(args, "checkpoint_suffix", "") - if args.restore_file == "checkpoint_last.pt": # default value of restore_file is 'checkpoint_last.pt' - checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix)) + if ( + args.restore_file == "checkpoint_last.pt" + ): # default value of restore_file is 'checkpoint_last.pt' + checkpoint_path = os.path.join( + args.save_dir, "checkpoint_last{}.pt".format(suffix) + ) first_launch = not PathManager.exists(checkpoint_path) - if getattr(args, 'finetune_from_model', None) is not None and first_launch: + if getattr(args, "finetune_from_model", None) is not None and first_launch: # if there is no last checkpoint to restore, start the finetune from pretrained model # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
if PathManager.exists(args.finetune_from_model): @@ -151,19 +164,26 @@ def load_checkpoint(args, trainer, **passthrough_args): reset_lr_scheduler = True reset_meters = True reset_dataloader = True - logger.info(f'loading pretrained model from {checkpoint_path}: ' - 'optimizer, lr scheduler, meters, dataloader will be reset') + logger.info( + f"loading pretrained model from {checkpoint_path}: " + "optimizer, lr scheduler, meters, dataloader will be reset" + ) else: - raise ValueError(f'--funetune-from-model {args.finetune_from_model} does not exist') + raise ValueError( + f"--funetune-from-model {args.finetune_from_model} does not exist" + ) elif getattr(args, "model_parallel_size", 1) > 1: checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") else: checkpoint_path = args.restore_file - if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None): + if args.restore_file != "checkpoint_last.pt" and getattr( + args, "finetune_from_model", None + ): raise ValueError( - '--finetune-from-model and --restore-file (non-default value) ' - 'can not be specified together: ' + str(args)) + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(args) + ) extra_state = trainer.load_checkpoint( checkpoint_path, @@ -213,7 +233,9 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): return state -def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1): +def load_model_ensemble( + filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 +): """Loads an ensemble of models. 
Args: @@ -222,18 +244,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, s were used during model training task (fairseq.tasks.FairseqTask, optional): task to use for loading """ - assert not (strict and num_shards > 1), \ - "Cannot load state dict with strict=True and checkpoint shards > 1" + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble, args, _task = load_model_ensemble_and_task( - filenames, arg_overrides, task, strict, suffix, num_shards, + filenames, + arg_overrides, + task, + strict, + suffix, + num_shards, ) return ensemble, args -def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1): +def load_model_ensemble_and_task( + filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 +): from fairseq import tasks - assert not (strict and num_shards > 1), \ - "Cannot load state dict with strict=True and checkpoint shards > 1" + + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble = [] for filename in filenames: orig_filename = filename @@ -533,7 +565,9 @@ def verify_checkpoint_directory(save_dir: str) -> None: with open(temp_file_path, "w"): pass except OSError as e: - logger.warning("Unable to access checkpoint save directory: {}".format(save_dir)) + logger.warning( + "Unable to access checkpoint save directory: {}".format(save_dir) + ) raise e else: os.remove(temp_file_path) diff --git a/fairseq/criterions/composite_loss.py b/fairseq/criterions/composite_loss.py index 6671c696e9..65341c2d3b 100644 --- a/fairseq/criterions/composite_loss.py +++ b/fairseq/criterions/composite_loss.py @@ -3,13 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from torch import nn - from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion +from torch import nn -@register_criterion('composite_loss') +@register_criterion("composite_loss") class CompositeLoss(FairseqCriterion): """This is a composite loss that, given a list of model outputs and a list of targets, computes an average of losses for each output-target pair""" @@ -40,7 +39,6 @@ def build_criterion(cls, args, task): underlying_criterion = CompositeLoss.build_underlying_criterion(args, task) class FakeModel(nn.Module): - def __init__(self, model, net_out, target): super().__init__() self.model = model @@ -51,7 +49,9 @@ def forward(self, **unused): return self.net_out def get_normalized_probs(self, net_output, log_probs, sample=None): - return self.model.get_normalized_probs(net_output, log_probs, sample=sample) + return self.model.get_normalized_probs( + net_output, log_probs, sample=sample + ) def get_targets(self, *unused): return self.target @@ -61,14 +61,13 @@ def decoder(self): return self.model.decoder class _CompositeLoss(FairseqCriterion): - def __init__(self, task, underlying_criterion): super().__init__(task) self.underlying_criterion = underlying_criterion def forward(self, model, sample, reduce=True): - net_outputs = model(**sample['net_input']) - targets = sample['target'] + net_outputs = model(**sample["net_input"]) + targets = sample["target"] bsz = targets[0].size(0) loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_() @@ -77,7 +76,7 @@ def forward(self, model, sample, reduce=True): logging_output = {} for o, t in zip(net_outputs[0], targets): m = FakeModel(model, (o, net_outputs[1]), t) - sample['target'] = t + sample["target"] = t l, ss, logging_output = self.underlying_criterion(m, sample, reduce) loss += l sample_size += ss @@ -85,12 +84,14 @@ def forward(self, model, sample, reduce=True): loss.div_(len(targets)) sample_size /= len(targets) - logging_output['loss'] = utils.item(loss.data) if 
reduce else loss.data + logging_output["loss"] = utils.item(loss.data) if reduce else loss.data return loss, sample_size, logging_output @staticmethod def aggregate_logging_outputs(logging_outputs): - return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs) + return underlying_criterion.__class__.aggregate_logging_outputs( + logging_outputs + ) @staticmethod def reduce_metrics(logging_outputs) -> None: diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index 0239a548a9..ef94a86327 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -6,25 +6,23 @@ import inspect from typing import Any, Dict, List -from torch.nn.modules.loss import _Loss - from fairseq import metrics, utils from fairseq.dataclass.utils import gen_parser_from_dataclass +from torch.nn.modules.loss import _Loss class FairseqCriterion(_Loss): - def __init__(self, task): super().__init__() self.task = task - if hasattr(task, 'target_dictionary'): + if hasattr(task, "target_dictionary"): tgt_dict = task.target_dictionary self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 @classmethod def add_args(cls, parser): """Add criterion-specific arguments to the parser.""" - dc = getattr(cls, '__dataclass', None) + dc = getattr(cls, "__dataclass", None) if dc is not None: gen_parser_from_dataclass(parser, dc()) @@ -43,20 +41,20 @@ def build_criterion(cls, args, task): ): # we haven't implemented inference for these argument types, # but PRs welcome :) - raise NotImplementedError('{} not supported'.format(p.kind)) + raise NotImplementedError("{} not supported".format(p.kind)) assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} - if p.name == 'task': - init_args['task'] = task + if p.name == "task": + init_args["task"] = task elif hasattr(args, p.name): init_args[p.name] = getattr(args, p.name) elif p.default != p.empty: pass # we'll use the default value else: raise 
NotImplementedError( - 'Unable to infer Criterion arguments, please implement ' - '{}.build_criterion'.format(cls.__name__) + "Unable to infer Criterion arguments, please implement " + "{}.build_criterion".format(cls.__name__) ) return cls(**init_args) @@ -76,8 +74,8 @@ def aggregate_logging_outputs( ) -> Dict[str, Any]: """Aggregate logging outputs from data parallel training.""" utils.deprecation_warning( - 'The aggregate_logging_outputs API is deprecated. ' - 'Please use the reduce_metrics API instead.' + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." ) raise NotImplementedError @@ -85,12 +83,12 @@ def aggregate_logging_outputs( def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: """Aggregate logging outputs from data parallel training.""" utils.deprecation_warning( - 'Criterions should implement the reduce_metrics API. ' - 'Falling back to deprecated aggregate_logging_outputs API.' + "Criterions should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." ) agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs) for k, v in agg_logging_outputs.items(): - if k in {'nsentences', 'ntokens', 'sample_size'}: + if k in {"nsentences", "ntokens", "sample_size"}: continue metrics.log_scalar(k, v) @@ -105,15 +103,14 @@ def logging_outputs_can_be_summed() -> bool: class LegacyFairseqCriterion(FairseqCriterion): - def __init__(self, args, task): super().__init__(task=task) self.args = args utils.deprecation_warning( - 'Criterions should take explicit arguments instead of an ' - 'argparse.Namespace object, please update your criterion by ' - 'extending FairseqCriterion instead of LegacyFairseqCriterion.' + "Criterions should take explicit arguments instead of an " + "argparse.Namespace object, please update your criterion by " + "extending FairseqCriterion instead of LegacyFairseqCriterion." 
) @classmethod diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index 931a8f76d5..2dc7f7a47d 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -6,7 +6,6 @@ import math import torch - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion @@ -18,8 +17,8 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T smooth_loss = -lprobs.sum(dim=-1, keepdim=True) if ignore_index is not None: pad_mask = target.eq(ignore_index) - nll_loss.masked_fill_(pad_mask, 0.) - smooth_loss.masked_fill_(pad_mask, 0.) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) else: nll_loss = nll_loss.squeeze(-1) smooth_loss = smooth_loss.squeeze(-1) @@ -27,14 +26,20 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T nll_loss = nll_loss.sum() smooth_loss = smooth_loss.sum() eps_i = epsilon / lprobs.size(-1) - loss = (1. 
- epsilon) * nll_loss + eps_i * smooth_loss + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss return loss, nll_loss -@register_criterion('label_smoothed_cross_entropy') +@register_criterion("label_smoothed_cross_entropy") class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): - def __init__(self, task, sentence_avg, label_smoothing, - ignore_prefix_size=0, report_accuracy=False): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + ): super().__init__(task) self.sentence_avg = sentence_avg self.eps = label_smoothing @@ -61,20 +66,22 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample['net_input']) + net_output = model(**sample["net_input"]) loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens'] + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) logging_output = { - 'loss': loss.data, - 'nll_loss': nll_loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample['target'].size(0), - 'sample_size': sample_size, + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, } if self.report_accuracy: n_correct, total = self.compute_accuracy(model, net_output, sample) - logging_output['n_correct'] = utils.item(n_correct.data) - logging_output['total'] = utils.item(total.data) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) return loss, sample_size, logging_output def get_lprobs_and_target(self, model, net_output, sample): @@ -82,17 +89,21 @@ def get_lprobs_and_target(self, model, net_output, sample): target = model.get_targets(sample, net_output) 
         if self.ignore_prefix_size > 0:
             if getattr(lprobs, "batch_first", False):
-                lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
-                target = target[:, self.ignore_prefix_size:].contiguous()
+                lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+                target = target[:, self.ignore_prefix_size :].contiguous()
             else:
-                lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
-                target = target[self.ignore_prefix_size:, :].contiguous()
+                lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
+                target = target[self.ignore_prefix_size :, :].contiguous()
         return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
 
     def compute_loss(self, model, net_output, sample, reduce=True):
         lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
         loss, nll_loss = label_smoothed_nll_loss(
-            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
+            lprobs,
+            target,
+            self.eps,
+            ignore_index=self.padding_idx,
+            reduce=reduce,
         )
         return loss, nll_loss
 
@@ -100,34 +111,43 @@ def compute_accuracy(self, model, net_output, sample):
         lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
         mask = target.ne(self.padding_idx)
         n_correct = torch.sum(
-            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
+            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+        )
         total = torch.sum(mask)
         return n_correct, total
 
     @classmethod
     def reduce_metrics(cls, logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
-        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
-        nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
-        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
 
-        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
-        metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
-        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        metrics.log_scalar(
+            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
+        )
+        metrics.log_derived(
+            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+        )
 
-        total = utils.item(sum(log.get('total', 0) for log in logging_outputs))
+        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
         if total > 0:
-            metrics.log_scalar('total', total)
+            metrics.log_scalar("total", total)
             n_correct = utils.item(
-                sum(log.get('n_correct', 0) for log in logging_outputs)
+                sum(log.get("n_correct", 0) for log in logging_outputs)
             )
-            metrics.log_scalar('n_correct', n_correct)
+            metrics.log_scalar("n_correct", n_correct)
             metrics.log_derived(
-                'accuracy',
+                "accuracy",
                 lambda meters: round(
-                    meters['n_correct'].sum * 100.0 / meters['total'].sum, 3
-                ) if meters['total'].sum > 0 else float('nan'),
+                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
+                )
+                if meters["total"].sum > 0
+                else float("nan"),
             )
 
     @staticmethod
diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
index cfc7e008cd..73cfa05310 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
@@ -11,9 +11,10 @@
 from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
 
 
-@register_criterion('label_smoothed_cross_entropy_with_alignment')
-class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):
-
+@register_criterion("label_smoothed_cross_entropy_with_alignment")
+class LabelSmoothedCrossEntropyCriterionWithAlignment(
+    LabelSmoothedCrossEntropyCriterion
+):
     def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
         super().__init__(task, sentence_avg, label_smoothing)
         self.alignment_lambda = alignment_lambda
@@ -22,8 +23,13 @@ def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
     def add_args(parser):
         """Add criterion-specific arguments to the parser."""
         LabelSmoothedCrossEntropyCriterion.add_args(parser)
-        parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
-                            help='weight for the alignment loss')
+        parser.add_argument(
+            "--alignment-lambda",
+            default=0.05,
+            type=float,
+            metavar="D",
+            help="weight for the alignment loss",
+        )
 
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -33,41 +39,46 @@ def forward(self, model, sample, reduce=True):
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
-        net_output = model(**sample['net_input'])
+        net_output = model(**sample["net_input"])
         loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
+        sample_size = (
+            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+        )
         logging_output = {
-            'loss': utils.item(loss.data) if reduce else loss.data,
-            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
-            'ntokens': sample['ntokens'],
-            'nsentences': sample['target'].size(0),
-            'sample_size': sample_size,
+            "loss": utils.item(loss.data) if reduce else loss.data,
+            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["target"].size(0),
+            "sample_size": sample_size,
         }
 
         alignment_loss = None
 
         # Compute alignment loss only for training set and non dummy batches.
-        if 'alignments' in sample and sample['alignments'] is not None:
+        if "alignments" in sample and sample["alignments"] is not None:
             alignment_loss = self.compute_alignment_loss(sample, net_output)
 
         if alignment_loss is not None:
-            logging_output['alignment_loss'] = utils.item(alignment_loss.data)
+            logging_output["alignment_loss"] = utils.item(alignment_loss.data)
             loss += self.alignment_lambda * alignment_loss
 
         return loss, sample_size, logging_output
 
     def compute_alignment_loss(self, sample, net_output):
-        attn_prob = net_output[1]['attn'][0]
+        attn_prob = net_output[1]["attn"][0]
         bsz, tgt_sz, src_sz = attn_prob.shape
         attn = attn_prob.view(bsz * tgt_sz, src_sz)
 
-        align = sample['alignments']
-        align_weights = sample['align_weights'].float()
+        align = sample["alignments"]
+        align_weights = sample["align_weights"].float()
 
         if len(align) > 0:
             # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
             # the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
-            loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum()
+            loss = -(
+                (attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
+                * align_weights[:, None]
+            ).sum()
         else:
             return None
 
@@ -76,16 +87,33 @@ def compute_alignment_loss(self, sample, net_output):
     @staticmethod
     def reduce_metrics(logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
-        loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
-        nll_loss_sum = utils.item(sum(log.get('nll_loss', 0) for log in logging_outputs))
-        alignment_loss_sum = utils.item(sum(log.get('alignment_loss', 0) for log in logging_outputs))
-        ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
-        sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
-
-        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
-        metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
-        metrics.log_scalar('alignment_loss', alignment_loss_sum / sample_size / math.log(2), sample_size, round=3)
-        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
+        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+        nll_loss_sum = utils.item(
+            sum(log.get("nll_loss", 0) for log in logging_outputs)
+        )
+        alignment_loss_sum = utils.item(
+            sum(log.get("alignment_loss", 0) for log in logging_outputs)
+        )
+        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
+        sample_size = utils.item(
+            sum(log.get("sample_size", 0) for log in logging_outputs)
+        )
+
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        metrics.log_scalar(
+            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
+        )
+        metrics.log_scalar(
+            "alignment_loss",
+            alignment_loss_sum / sample_size / math.log(2),
+            sample_size,
+            round=3,
+        )
+        metrics.log_derived(
+            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+        )
 
     @staticmethod
     def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/legacy_masked_lm.py b/fairseq/criterions/legacy_masked_lm.py
index 3dbfdfbe46..c70608c5a1 100644
--- a/fairseq/criterions/legacy_masked_lm.py
+++ b/fairseq/criterions/legacy_masked_lm.py
@@ -7,7 +7,6 @@
 
 import torch
 import torch.nn.functional as F
-
 from fairseq import metrics, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 
@@ -18,8 +17,9 @@ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
     ignore_index is the same as the default value for F.cross_entropy in
     pytorch.
     """
-    assert logits.size(0) == targets.size(-1), \
-        "Logits and Targets tensor shapes don't match up"
+    assert logits.size(0) == targets.size(
+        -1
+    ), "Logits and Targets tensor shapes don't match up"
 
     loss = F.nll_loss(
         F.log_softmax(logits, -1, dtype=torch.float32),
@@ -30,7 +30,7 @@ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
     return loss
 
 
-@register_criterion('legacy_masked_lm_loss')
+@register_criterion("legacy_masked_lm_loss")
 class LegacyMaskedLmLoss(FairseqCriterion):
     """
     Implementation for the loss used in masked language model (MLM) training.
@@ -57,11 +57,18 @@ def __init__(self, task, masked_lm_only, nsp_loss_weight):
     def add_args(parser):
         """Args for MaskedLM Loss"""
         # Default for masked_lm_only is False so as to not break BERT training
-        parser.add_argument('--masked-lm-only', default=False,
-                            action='store_true', help='compute MLM loss only')
-        parser.add_argument('--nsp-loss-weight', default=1.0, type=float,
-                            help='weight for next sentence prediction'
-                                 ' loss (default 1)')
+        parser.add_argument(
+            "--masked-lm-only",
+            default=False,
+            action="store_true",
+            help="compute MLM loss only",
+        )
+        parser.add_argument(
+            "--nsp-loss-weight",
+            default=1.0,
+            type=float,
+            help="weight for next sentence prediction" " loss (default 1)",
+        )
 
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -74,22 +81,21 @@ def forward(self, model, sample, reduce=True):
 
         # reshape lm_logits from (N,T,C) to (N*T,C)
         lm_logits = lm_logits.view(-1, lm_logits.size(-1))
-        lm_targets = sample['lm_target'].view(-1)
-        lm_loss = compute_cross_entropy_loss(
-            lm_logits, lm_targets, self.padding_idx)
+        lm_targets = sample["lm_target"].view(-1)
+        lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
 
         # compute the number of tokens for which loss is computed. This is used
         # to normalize the loss
         ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
         loss = lm_loss / ntokens
-        nsentences = sample['nsentences']
+        nsentences = sample["nsentences"]
         # nsentences = 0
 
         # Compute sentence loss if masked_lm_only is False
         sentence_loss = None
         if not self.masked_lm_only:
-            sentence_logits = output_metadata['sentence_logits']
-            sentence_targets = sample['sentence_target'].view(-1)
+            sentence_logits = output_metadata["sentence_logits"]
+            sentence_targets = sample["sentence_target"].view(-1)
             # This needs to be recomputed due to some differences between
             # TokenBlock and BlockPair dataset. This can be resolved with a
             # refactor of BERTModel which we will do in the future.
@@ -102,7 +108,8 @@ def forward(self, model, sample, reduce=True):
             # refactor in the BERT model.
             if sentence_logits is not None:
                 sentence_loss = compute_cross_entropy_loss(
-                    sentence_logits, sentence_targets)
+                    sentence_logits, sentence_targets
+                )
 
                 loss += self.nsp_loss_weight * (sentence_loss / nsentences)
 
@@ -111,36 +118,54 @@ def forward(self, model, sample, reduce=True):
         # here sample_size is just used for logging
         sample_size = 1
         logging_output = {
-            'loss': utils.item(loss.data) if reduce else loss.data,
-            'lm_loss': utils.item(lm_loss.data) if reduce else lm_loss.data,
+            "loss": utils.item(loss.data) if reduce else loss.data,
+            "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
             # sentence loss is not always computed
-            'sentence_loss': (
-                (
-                    utils.item(sentence_loss.data) if reduce
-                    else sentence_loss.data
-                ) if sentence_loss is not None else 0.0
+            "sentence_loss": (
+                (utils.item(sentence_loss.data) if reduce else sentence_loss.data)
+                if sentence_loss is not None
+                else 0.0
             ),
-            'ntokens': ntokens,
-            'nsentences': nsentences,
-            'sample_size': sample_size,
+            "ntokens": ntokens,
+            "nsentences": nsentences,
+            "sample_size": sample_size,
         }
         return loss, sample_size, logging_output
 
     @staticmethod
     def reduce_metrics(logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
-        lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs)
-        sentence_loss_sum = sum(
-            log.get('sentence_loss', 0) for log in logging_outputs)
-        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
-        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
-        agg_loss = sum(log.get('loss', 0) for log in logging_outputs)
-
-        metrics.log_scalar('loss', agg_loss / sample_size / math.log(2) if sample_size > 0 else 0., sample_size, round=3)
-        metrics.log_scalar('lm_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3)
-        metrics.log_scalar('sentence_loss', sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0., nsentences, round=3)
-        metrics.log_scalar('nll_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3)
+        lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
+        sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
+
+        metrics.log_scalar(
+            "loss",
+            agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
+            sample_size,
+            round=3,
+        )
+        metrics.log_scalar(
+            "lm_loss",
+            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
+            ntokens,
+            round=3,
+        )
+        metrics.log_scalar(
+            "sentence_loss",
+            sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
+            nsentences,
+            round=3,
+        )
+        metrics.log_scalar(
+            "nll_loss",
+            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
+            ntokens,
+            round=3,
+        )
 
     @staticmethod
     def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py
index f62ed805f2..b04cfbff6d 100644
--- a/fairseq/criterions/masked_lm.py
+++ b/fairseq/criterions/masked_lm.py
@@ -7,12 +7,11 @@
 
 import torch
 import torch.nn.functional as F
-
 from fairseq import metrics, modules, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 
 
-@register_criterion('masked_lm')
+@register_criterion("masked_lm")
 class MaskedLmLoss(FairseqCriterion):
     """
     Implementation for the loss used in masked language model (MLM) training.
@@ -30,7 +29,7 @@ def forward(self, model, sample, reduce=True):
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
-        masked_tokens = sample['target'].ne(self.padding_idx)
+        masked_tokens = sample["target"].ne(self.padding_idx)
         sample_size = masked_tokens.int().sum()
 
         # Rare: when all tokens are masked, project all tokens.
@@ -39,7 +38,7 @@ def forward(self, model, sample, reduce=True):
         # (see github.com/pytorch/pytorch/issues/26247).
         if self.tpu:
             masked_tokens = None  # always project all tokens on TPU
-        elif masked_tokens.device == torch.device('cpu'):
+        elif masked_tokens.device == torch.device("cpu"):
             if not masked_tokens.any():
                 masked_tokens = None
         else:
@@ -49,7 +48,7 @@ def forward(self, model, sample, reduce=True):
                 masked_tokens.new([True]),
             )
 
-        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
+        logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
         targets = model.get_targets(sample, [logits])
         if masked_tokens is not None:
             targets = targets[masked_tokens]
@@ -57,26 +56,30 @@ def forward(self, model, sample, reduce=True):
         loss = modules.cross_entropy(
             logits.view(-1, logits.size(-1)),
             targets.view(-1),
-            reduction='sum',
+            reduction="sum",
             ignore_index=self.padding_idx,
         )
 
         logging_output = {
-            'loss': loss if self.tpu else loss.data,
-            'ntokens': sample['ntokens'],
-            'nsentences': sample['nsentences'],
-            'sample_size': sample_size,
+            "loss": loss if self.tpu else loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["nsentences"],
+            "sample_size": sample_size,
         }
         return loss, sample_size, logging_output
 
     @staticmethod
     def reduce_metrics(logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
-        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
-        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
 
-        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
-        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        metrics.log_derived(
+            "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+        )
 
     @staticmethod
     def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py
index 3326734d55..cdc7da861d 100644
--- a/fairseq/criterions/nat_loss.py
+++ b/fairseq/criterions/nat_loss.py
@@ -5,17 +5,15 @@
 
 import math
 
-import torch.nn.functional as F
 import torch
-from torch import Tensor
-
+import torch.nn.functional as F
 from fairseq import metrics, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
+from torch import Tensor
 
 
 @register_criterion("nat_loss")
 class LabelSmoothedDualImitationCriterion(FairseqCriterion):
-
     def __init__(self, task, label_smoothing):
         super().__init__(task)
         self.label_smoothing = label_smoothing
@@ -24,23 +22,23 @@ def __init__(self, task, label_smoothing):
     def add_args(parser):
         """Add criterion-specific arguments to the parser."""
         parser.add_argument(
-            '--label-smoothing',
-            default=0.,
+            "--label-smoothing",
+            default=0.0,
             type=float,
-            metavar='D',
-            help='epsilon for label smoothing, 0 means no label smoothing',
+            metavar="D",
+            help="epsilon for label smoothing, 0 means no label smoothing",
         )
 
     def _compute_loss(
         self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
     ):
         """
-            outputs: batch x len x d_model
-            targets: batch x len
-            masks:   batch x len
+        outputs: batch x len x d_model
+        targets: batch x len
+        masks:   batch x len
 
-            policy_logprob: if there is some policy
-                depends on the likelihood score as rewards.
+        policy_logprob: if there is some policy
+            depends on the likelihood score as rewards.
""" def mean_ds(x: Tensor, dim=None) -> Tensor: @@ -49,6 +47,7 @@ def mean_ds(x: Tensor, dim=None) -> Tensor: if dim is None else x.float().mean(dim).type_as(x) ) + if masks is not None: outputs, targets = outputs[masks], targets[masks] @@ -58,16 +57,17 @@ def mean_ds(x: Tensor, dim=None) -> Tensor: else: logits = F.log_softmax(outputs, dim=-1) if targets.dim() == 1: - losses = F.nll_loss(logits, targets.to(logits.device), reduction='none') + losses = F.nll_loss(logits, targets.to(logits.device), reduction="none") else: # soft-labels - losses = F.kl_div(logits, targets.to(logits.device), reduction='none') + losses = F.kl_div(logits, targets.to(logits.device), reduction="none") losses = losses.sum(-1) nll_loss = mean_ds(losses) if label_smoothing > 0: - loss = nll_loss * ( - 1 - label_smoothing) - mean_ds(logits) * label_smoothing + loss = ( + nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing + ) else: loss = nll_loss @@ -103,14 +103,14 @@ def forward(self, model, sample, reduce=True): outputs[obj].get("tgt"), outputs[obj].get("mask", None), outputs[obj].get("ls", 0.0), - name=obj + '-loss', - factor=outputs[obj].get("factor", 1.0) + name=obj + "-loss", + factor=outputs[obj].get("factor", 1.0), ) else: _losses = self._custom_loss( outputs[obj].get("loss"), - name=obj + '-loss', - factor=outputs[obj].get("factor", 1.0) + name=obj + "-loss", + factor=outputs[obj].get("factor", 1.0), ) losses += [_losses] @@ -118,8 +118,7 @@ def forward(self, model, sample, reduce=True): nll_loss += [_losses.get("nll_loss", 0.0)] loss = sum(l["loss"] for l in losses) - nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 \ - else loss.new_tensor(0) + nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0) # NOTE: # we don't need to use sample_size as denominator for the gradient @@ -145,13 +144,21 @@ def forward(self, model, sample, reduce=True): @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from 
data parallel training.""" - sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs)) - metrics.log_scalar('loss', loss / sample_size / math.log(2), sample_size, round=3) - metrics.log_scalar('nll_loss', nll_loss / sample_size / math.log(2), sample_size, round=3) - metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg)) + metrics.log_scalar( + "loss", loss / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) for key in logging_outputs[0]: if key[-5:] == "-loss": diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py index 4ba1317856..9519fdc56d 100644 --- a/fairseq/criterions/sentence_prediction.py +++ b/fairseq/criterions/sentence_prediction.py @@ -7,14 +7,12 @@ import torch import torch.nn.functional as F - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -@register_criterion('sentence_prediction') +@register_criterion("sentence_prediction") class SentencePredictionCriterion(FairseqCriterion): - def __init__(self, task, classification_head_name, regression_target): super().__init__(task) self.classification_head_name = classification_head_name @@ -37,12 +35,12 @@ def forward(self, model, sample, reduce=True): 3) logging outputs to display while training """ assert ( - hasattr(model, 'classification_heads') + hasattr(model, "classification_heads") and self.classification_head_name in model.classification_heads - ), 'model must provide sentence classification head for 
--criterion=sentence_prediction' + ), "model must provide sentence classification head for --criterion=sentence_prediction" logits, _ = model( - **sample['net_input'], + **sample["net_input"], features_only=True, classification_head_name=self.classification_head_name, ) @@ -51,39 +49,45 @@ def forward(self, model, sample, reduce=True): if not self.regression_target: lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) - loss = F.nll_loss(lprobs, targets, reduction='sum') + loss = F.nll_loss(lprobs, targets, reduction="sum") else: logits = logits.view(-1).float() targets = targets.float() - loss = F.mse_loss(logits, targets, reduction='sum') + loss = F.mse_loss(logits, targets, reduction="sum") logging_output = { - 'loss': loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample_size, - 'sample_size': sample_size, + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, } if not self.regression_target: preds = logits.argmax(dim=1) - logging_output['ncorrect'] = (preds == targets).sum() + logging_output["ncorrect"] = (preds == targets).sum() return loss, sample_size, logging_output @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" - loss_sum = sum(log.get('loss', 0) for log in logging_outputs) - ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) - nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) - sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) - metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), 
sample_size, round=3 + ) if sample_size != ntokens: - metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3) - - if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]: - ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs) - metrics.log_scalar('accuracy', 100.0 * ncorrect / nsentences, nsentences, round=1) + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + metrics.log_scalar( + "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1 + ) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/criterions/sentence_ranking.py b/fairseq/criterions/sentence_ranking.py index 52a0a177d8..d4c76341d4 100644 --- a/fairseq/criterions/sentence_ranking.py +++ b/fairseq/criterions/sentence_ranking.py @@ -7,19 +7,17 @@ import torch import torch.nn.functional as F - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -@register_criterion('sentence_ranking') +@register_criterion("sentence_ranking") class SentenceRankingCriterion(FairseqCriterion): - def __init__(self, task, ranking_head_name, save_predictions, num_classes): super().__init__(task) self.ranking_head_name = ranking_head_name if save_predictions is not None: - self.prediction_h = open(save_predictions, 'w') + self.prediction_h = open(save_predictions, "w") else: self.prediction_h = None self.num_classes = num_classes @@ -47,14 +45,14 @@ def forward(self, model, sample, reduce=True): 3) logging outputs to display while training """ assert ( - hasattr(model, 'classification_heads') + hasattr(model, "classification_heads") and self.ranking_head_name in model.classification_heads - ), 'model must provide sentence ranking head for --criterion=sentence_ranking' + ), "model must provide sentence ranking head for 
--criterion=sentence_ranking" scores = [] for idx in range(self.num_classes): score, _ = model( - **sample['net_input{idx}'.format(idx=idx+1)], + **sample["net_input{idx}".format(idx=idx + 1)], classification_head_name=self.ranking_head_name, ) scores.append(score) @@ -62,49 +60,55 @@ def forward(self, model, sample, reduce=True): logits = torch.cat(scores, dim=1) sample_size = logits.size(0) - if 'target' in sample: + if "target" in sample: targets = model.get_targets(sample, [logits]).view(-1) lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) - loss = F.nll_loss(lprobs, targets, reduction='sum') + loss = F.nll_loss(lprobs, targets, reduction="sum") else: targets = None loss = torch.tensor(0.0, requires_grad=True) if self.prediction_h is not None: preds = logits.argmax(dim=1) - for i, (id, pred) in enumerate(zip(sample['id'].tolist(), preds.tolist())): + for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())): if targets is not None: label = targets[i].item() - print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h) + print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h) else: - print('{}\t{}'.format(id, pred), file=self.prediction_h) + print("{}\t{}".format(id, pred), file=self.prediction_h) logging_output = { - 'loss': loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample_size, - 'sample_size': sample_size, + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, } if targets is not None: - logging_output['ncorrect'] = (logits.argmax(dim=1) == targets).sum() + logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum() return loss, sample_size, logging_output @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" - loss_sum = sum(log.get('loss', 0) for log in logging_outputs) - ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) - nsentences = 
sum(log.get('nsentences', 0) for log in logging_outputs) - sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) - - metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) if sample_size != ntokens: - metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3) + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) - if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]: - ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs) - metrics.log_scalar('accuracy', 100.0 * ncorrect / nsentences, nsentences, round=1) + if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + metrics.log_scalar( + "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1 + ) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index cc743524d2..6ac7557dcc 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -7,15 +7,13 @@ import torch import torch.nn.functional as F - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.logging.meters import safe_round -@register_criterion('wav2vec') +@register_criterion("wav2vec") class Wav2vecCriterion(FairseqCriterion): - def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): super().__init__(task) self.infonce = infonce @@ -42,12 +40,12 @@ def 
forward(self, model, sample, reduce=True, log_pred=False): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample['net_input']) + net_output = model(**sample["net_input"]) logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) weights = None - if hasattr(model, 'get_target_weights') and not self.infonce: + if hasattr(model, "get_target_weights") and not self.infonce: weights = model.get_target_weights(target, net_output) if torch.is_tensor(weights): weights = weights.float() @@ -55,9 +53,18 @@ def forward(self, model, sample, reduce=True, log_pred=False): losses = [] if self.infonce: - loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",) + loss = F.cross_entropy( + logits, + target, + reduction="sum" if reduce else "none", + ) else: - loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",) + loss = F.binary_cross_entropy_with_logits( + logits, + target.float(), + weights, + reduction="sum" if reduce else "none", + ) sample_size = target.numel() if self.infonce else target.long().sum().item() losses.append(loss.detach().clone()) @@ -69,7 +76,9 @@ def forward(self, model, sample, reduce=True, log_pred=False): extra_losses = [extra_losses] if len(self.loss_weights) == 1 and len(extra_losses) != 1: self.loss_weights = [self.loss_weights[0]] * len(extra_losses) - assert len(extra_losses) == len(self.loss_weights), f'{len(extra_losses)}, {len(self.loss_weights)}' + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" for p, coef in zip(extra_losses, self.loss_weights): if coef != 0 and p is not None: p = coef * p.float() * sample_size @@ -77,10 +86,10 @@ def forward(self, model, sample, reduce=True, log_pred=False): losses.append(p) logging_output = { - 'loss': loss.item() if reduce else loss, - 
'ntokens': sample_size, - 'nsentences': sample['id'].numel(), - 'sample_size': sample_size, + "loss": loss.item() if reduce else loss, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, } for lk in self.log_keys: @@ -89,7 +98,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): if len(losses) > 1: for i, l in enumerate(losses): - logging_output[f'loss_{i}'] = l.item() + logging_output[f"loss_{i}"] = l.item() if self.infonce: with torch.no_grad(): @@ -108,21 +117,27 @@ def forward(self, model, sample, reduce=True, log_pred=False): logging_output["count"] = count if log_pred: - logging_output['logits'] = logits.cpu().numpy() - logging_output['target'] = target.cpu().numpy() + logging_output["logits"] = logits.cpu().numpy() + logging_output["target"] = target.cpu().numpy() return loss, sample_size, logging_output @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" - loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs)) - ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs)) - nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs)) - sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs)) - - metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) - metrics.log_scalar('ntokens', ntokens) - metrics.log_scalar('nsentences', nsentences) + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + 
metrics.log_scalar("nsentences", nsentences) correct = sum(log.get("correct", 0) for log in logging_outputs) metrics.log_scalar("_correct", correct) @@ -130,21 +145,31 @@ def reduce_metrics(logging_outputs) -> None: total = sum(log.get("count", 0) for log in logging_outputs) metrics.log_scalar("_total", total) - if total > 0: metrics.log_derived( "accuracy", - lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5) + lambda meters: safe_round( + meters["_correct"].sum / meters["_total"].sum, 5 + ) if meters["_total"].sum > 0 else float("nan"), ) - builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'} + builtin_keys = { + "loss", + "ntokens", + "nsentences", + "sample_size", + "correct", + "count", + } for k in logging_outputs[0]: if k not in builtin_keys: - val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs) - if k.startswith('loss'): + val = sum(log.get(k, 0) for log in logging_outputs) / len( + logging_outputs + ) + if k.startswith("loss"): metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) else: metrics.log_scalar(k, val, round=3) diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 785a0aa643..9b30813955 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -20,7 +20,12 @@ from .concat_sentences_dataset import ConcatSentencesDataset from .denoising_dataset import DenoisingDataset from .id_dataset import IdDataset -from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset +from .indexed_dataset import ( + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + MMapIndexedDataset, +) from .language_pair_dataset import LanguagePairDataset from .list_dataset import ListDataset from .lm_context_window_dataset import LMContextWindowDataset @@ -60,60 +65,60 @@ ) __all__ = [ - 'AddTargetDataset', - 'AppendTokenDataset', - 'BacktranslationDataset', - 'BaseWrapperDataset', - 
'BucketPadLengthDataset', - 'ColorizeDataset', - 'ConcatDataset', - 'ConcatSentencesDataset', - 'CountingIterator', - 'DenoisingDataset', - 'Dictionary', - 'EncodedFastaDataset', - 'EpochBatchIterator', - 'FairseqDataset', - 'FairseqIterableDataset', - 'FastaDataset', - 'GroupedIterator', - 'IdDataset', - 'IndexedCachedDataset', - 'IndexedDataset', - 'IndexedRawTextDataset', - 'LanguagePairDataset', - 'LeftPadDataset', - 'ListDataset', - 'LMContextWindowDataset', - 'LRUCacheDataset', - 'MaskTokensDataset', - 'MMapIndexedDataset', - 'MonolingualDataset', - 'MultiCorpusSampledDataset', - 'NestedDictionaryDataset', - 'NoisingDataset', - 'NumelDataset', - 'NumSamplesDataset', - 'OffsetTokensDataset', - 'PadDataset', - 'PrependDataset', - 'PrependTokenDataset', - 'ReplaceDataset', - 'RollDataset', - 'FileAudioDataset', - 'RawLabelDataset', - 'ResamplingDataset', - 'RightPadDataset', - 'RoundRobinZipDatasets', - 'SampledMultiDataset', - 'SampledMultiEpochDataset', - 'ShardedIterator', - 'SortDataset', - 'StripTokenDataset', - 'SubsampleDataset', - 'TokenBlockDataset', - 'TransformEosDataset', - 'TransformEosLangPairDataset', - 'TruncateDataset', - 'TruncatedDictionary', + "AddTargetDataset", + "AppendTokenDataset", + "BacktranslationDataset", + "BaseWrapperDataset", + "BucketPadLengthDataset", + "ColorizeDataset", + "ConcatDataset", + "ConcatSentencesDataset", + "CountingIterator", + "DenoisingDataset", + "Dictionary", + "EncodedFastaDataset", + "EpochBatchIterator", + "FairseqDataset", + "FairseqIterableDataset", + "FastaDataset", + "GroupedIterator", + "IdDataset", + "IndexedCachedDataset", + "IndexedDataset", + "IndexedRawTextDataset", + "LanguagePairDataset", + "LeftPadDataset", + "ListDataset", + "LMContextWindowDataset", + "LRUCacheDataset", + "MaskTokensDataset", + "MMapIndexedDataset", + "MonolingualDataset", + "MultiCorpusSampledDataset", + "NestedDictionaryDataset", + "NoisingDataset", + "NumelDataset", + "NumSamplesDataset", + "OffsetTokensDataset", + 
"PadDataset", + "PrependDataset", + "PrependTokenDataset", + "ReplaceDataset", + "RollDataset", + "FileAudioDataset", + "RawLabelDataset", + "ResamplingDataset", + "RightPadDataset", + "RoundRobinZipDatasets", + "SampledMultiDataset", + "SampledMultiEpochDataset", + "ShardedIterator", + "SortDataset", + "StripTokenDataset", + "SubsampleDataset", + "TokenBlockDataset", + "TransformEosDataset", + "TransformEosLangPairDataset", + "TruncateDataset", + "TruncatedDictionary", ] diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py index 3a42dd7a2e..9ef467058b 100644 --- a/fairseq/data/add_target_dataset.py +++ b/fairseq/data/add_target_dataset.py @@ -5,12 +5,20 @@ import torch -from . import BaseWrapperDataset -from . import data_utils +from . import BaseWrapperDataset, data_utils class AddTargetDataset(BaseWrapperDataset): - def __init__(self, dataset, labels, pad, eos, batch_targets, process_label=None, add_to_input=False): + def __init__( + self, + dataset, + labels, + pad, + eos, + batch_targets, + process_label=None, + add_to_input=False, + ): super().__init__(dataset) self.labels = labels self.batch_targets = batch_targets @@ -20,7 +28,11 @@ def __init__(self, dataset, labels, pad, eos, batch_targets, process_label=None, self.add_to_input = add_to_input def get_label(self, index): - return self.labels[index] if self.process_label is None else self.process_label(self.labels[index]) + return ( + self.labels[index] + if self.process_label is None + else self.process_label(self.labels[index]) + ) def __getitem__(self, index): item = self.dataset[index] @@ -51,6 +63,8 @@ def collater(self, samples): if self.add_to_input: eos = target.new_full((target.size(0), 1), self.eos) collated["target"] = torch.cat([target, eos], dim=-1).long() - collated["net_input"]["prev_output_tokens"] = torch.cat([eos, target], dim=-1).long() + collated["net_input"]["prev_output_tokens"] = torch.cat( + [eos, target], dim=-1 + ).long() collated["ntokens"] += 
target.size(0) - return collated \ No newline at end of file + return collated diff --git a/fairseq/data/append_token_dataset.py b/fairseq/data/append_token_dataset.py index 7298129f62..87695bd0f5 100644 --- a/fairseq/data/append_token_dataset.py +++ b/fairseq/data/append_token_dataset.py @@ -10,7 +10,6 @@ class AppendTokenDataset(BaseWrapperDataset): - def __init__(self, dataset, token=None): super().__init__(dataset) self.token = token diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index 3731721953..de08669851 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -1,11 +1,11 @@ import os.path as op -from typing import Union, BinaryIO, Optional, Tuple +from typing import BinaryIO, Optional, Tuple, Union import numpy as np def get_waveform( - path_or_fp: Union[str, BinaryIO], normalization=True + path_or_fp: Union[str, BinaryIO], normalization=True ) -> Tuple[np.ndarray, int]: """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. 
@@ -15,15 +15,15 @@ def get_waveform( """ if isinstance(path_or_fp, str): ext = op.splitext(op.basename(path_or_fp))[1] - if ext not in {'.flac', '.wav'}: - raise ValueError(f'Unsupported audio format: {ext}') + if ext not in {".flac", ".wav"}: + raise ValueError(f"Unsupported audio format: {ext}") try: import soundfile as sf except ImportError: - raise ImportError('Please install soundfile to load WAV/FLAC file') + raise ImportError("Please install soundfile to load WAV/FLAC file") - waveform, sample_rate = sf.read(path_or_fp, dtype='float32') + waveform, sample_rate = sf.read(path_or_fp, dtype="float32") if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers return waveform, sample_rate @@ -56,9 +56,11 @@ def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarr try: import torch import torchaudio.compliance.kaldi as ta_kaldi + waveform = torch.from_numpy(waveform).unsqueeze(0) - features = ta_kaldi.fbank(waveform, num_mel_bins=n_bins, - sample_frequency=sample_rate) + features = ta_kaldi.fbank( + waveform, num_mel_bins=n_bins, sample_frequency=sample_rate + ) return features.numpy() except ImportError: return None @@ -75,7 +77,9 @@ def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: if features is None: features = _get_torchaudio_fbank(sound, sample_rate, n_bins) if features is None: - raise ImportError('Please install pyKaldi or torchaudio to enable ' - 'online filterbank feature extraction') + raise ImportError( + "Please install pyKaldi or torchaudio to enable " + "online filterbank feature extraction" + ) return features diff --git a/fairseq/data/audio/feature_transforms/__init__.py b/fairseq/data/audio/feature_transforms/__init__.py index 399956a33b..359fa06971 100644 --- a/fairseq/data/audio/feature_transforms/__init__.py +++ b/fairseq/data/audio/feature_transforms/__init__.py @@ -1,7 +1,7 @@ import importlib import os -from typing import Optional, Dict from abc import ABC, 
abstractmethod +from typing import Dict, Optional class AudioFeatureTransform(ABC): @@ -18,14 +18,16 @@ def from_config_dict(cls, config: Optional[Dict] = None): def register_audio_feature_transform(name): def register_audio_feature_transform_cls(cls): if name in AUDIO_FEATURE_TRANSFORM_REGISTRY: - raise ValueError(f'Cannot register duplicate transform ({name})') + raise ValueError(f"Cannot register duplicate transform ({name})") if not issubclass(cls, AudioFeatureTransform): - raise ValueError(f'Transform ({name}: {cls.__name__}) must extend ' - 'AudioFeatureTransform') + raise ValueError( + f"Transform ({name}: {cls.__name__}) must extend " + "AudioFeatureTransform" + ) if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES: raise ValueError( - f'Cannot register audio feature transform with duplicate ' - f'class name ({cls.__name__})' + f"Cannot register audio feature transform with duplicate " + f"class name ({cls.__name__})" ) AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__) @@ -42,19 +44,19 @@ def get_audio_feature_transform(name): for file in os.listdir(transforms_dir): path = os.path.join(transforms_dir, file) if ( - not file.startswith('_') - and not file.startswith('.') - and (file.endswith('.py') or os.path.isdir(path)) + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) ): - name = file[:file.find('.py')] if file.endswith('.py') else file - importlib.import_module('fairseq.data.audio.feature_transforms.' + name) + name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module("fairseq.data.audio.feature_transforms." 
+ name) class CompositeAudioFeatureTransform(AudioFeatureTransform): @classmethod def from_config_dict(cls, config=None): _config = {} if config is None else config - _transforms = _config.get('transforms') + _transforms = _config.get("transforms") if _transforms is None: return None transforms = [ @@ -72,6 +74,9 @@ def __call__(self, x): return x def __repr__(self): - format_string = [self.__class__.__name__ + '('] + \ - [f" {t.__repr__()}" for t in self.transforms] + [')'] - return '\n'.join(format_string) + format_string = ( + [self.__class__.__name__ + "("] + + [f" {t.__repr__()}" for t in self.transforms] + + [")"] + ) + return "\n".join(format_string) diff --git a/fairseq/data/audio/feature_transforms/global_cmvn.py b/fairseq/data/audio/feature_transforms/global_cmvn.py index f9c92a66b1..d512fed300 100644 --- a/fairseq/data/audio/feature_transforms/global_cmvn.py +++ b/fairseq/data/audio/feature_transforms/global_cmvn.py @@ -1,10 +1,11 @@ import numpy as np from fairseq.data.audio.feature_transforms import ( - AudioFeatureTransform, register_audio_feature_transform + AudioFeatureTransform, + register_audio_feature_transform, ) -@register_audio_feature_transform('global_cmvn') +@register_audio_feature_transform("global_cmvn") class GlobalCMVN(AudioFeatureTransform): """Global CMVN (cepstral mean and variance normalization). 
The global mean and variance need to be pre-computed and stored in NumPy format (.npz).""" @@ -12,11 +13,11 @@ class GlobalCMVN(AudioFeatureTransform): @classmethod def from_config_dict(cls, config=None): _config = {} if config is None else config - return GlobalCMVN(_config.get('stats_npz_path')) + return GlobalCMVN(_config.get("stats_npz_path")) def __init__(self, stats_npz_path): stats = np.load(stats_npz_path) - self.mean, self.std = stats['mean'], stats['std'] + self.mean, self.std = stats["mean"], stats["std"] def __call__(self, x): x = np.subtract(x, self.mean) diff --git a/fairseq/data/audio/feature_transforms/specaugment.py b/fairseq/data/audio/feature_transforms/specaugment.py index e4c36bde3c..2ef4778b85 100644 --- a/fairseq/data/audio/feature_transforms/specaugment.py +++ b/fairseq/data/audio/feature_transforms/specaugment.py @@ -3,13 +3,13 @@ from typing import Optional import numpy as np - from fairseq.data.audio.feature_transforms import ( - AudioFeatureTransform, register_audio_feature_transform + AudioFeatureTransform, + register_audio_feature_transform, ) -@register_audio_feature_transform('specaugment') +@register_audio_feature_transform("specaugment") class SpecAugmentTransform(AudioFeatureTransform): """SpecAugment (https://arxiv.org/abs/1904.08779)""" @@ -17,13 +17,13 @@ class SpecAugmentTransform(AudioFeatureTransform): def from_config_dict(cls, config=None): _config = {} if config is None else config return SpecAugmentTransform( - _config.get('time_warp_W', 0), - _config.get('freq_mask_N', 0), - _config.get('freq_mask_F', 0), - _config.get('time_mask_N', 0), - _config.get('time_mask_T', 0), - _config.get('time_mask_p', 0.0), - _config.get('mask_value', None), + _config.get("time_warp_W", 0), + _config.get("freq_mask_N", 0), + _config.get("freq_mask_F", 0), + _config.get("time_mask_N", 0), + _config.get("time_mask_T", 0), + _config.get("time_mask_p", 0.0), + _config.get("mask_value", None), ) def __init__( @@ -41,15 +41,15 @@ def __init__( 
mask_value, numbers.Number ), f"mask_value (type: {type(mask_value)}) must be None or a number" if freq_mask_n > 0: - assert ( - freq_mask_f > 0 - ), f"freq_mask_F ({freq_mask_f}) " \ - f"must be larger than 0 when doing freq masking." + assert freq_mask_f > 0, ( + f"freq_mask_F ({freq_mask_f}) " + f"must be larger than 0 when doing freq masking." + ) if time_mask_n > 0: - assert ( - time_mask_t > 0 - ), f"time_mask_T ({time_mask_t}) must be larger than 0 when " \ - f"doing time masking." + assert time_mask_t > 0, ( + f"time_mask_T ({time_mask_t}) must be larger than 0 when " + f"doing time masking." + ) self.time_warp_w = time_warp_w self.freq_mask_n = freq_mask_n @@ -60,14 +60,21 @@ def __init__( self.mask_value = mask_value def __repr__(self): - return self.__class__.__name__ + '(' + ', '.join( - [f'time_warp_w={self.time_warp_w}', - f'freq_mask_n={self.freq_mask_n}', - f'freq_mask_f={self.freq_mask_f}', - f'time_mask_n={self.time_mask_n}', - f'time_mask_t={self.time_mask_t}', - f'time_mask_p={self.time_mask_p}'] - ) + ')' + return ( + self.__class__.__name__ + + "(" + + ", ".join( + [ + f"time_warp_w={self.time_warp_w}", + f"freq_mask_n={self.freq_mask_n}", + f"freq_mask_f={self.freq_mask_f}", + f"time_mask_n={self.time_mask_n}", + f"time_mask_t={self.time_mask_t}", + f"time_mask_p={self.time_mask_p}", + ] + ) + + ")" + ) def __call__(self, spectrogram): assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." 
@@ -89,14 +96,12 @@ def __call__(self, spectrogram): if self.time_warp_w > 0: if 2 * self.time_warp_w < num_frames: import cv2 - w0 = np.random.randint( - self.time_warp_w, num_frames - self.time_warp_w - ) + + w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w) w = np.random.randint(0, self.time_warp_w) upper, lower = distorted[:w0, :], distorted[w0:, :] upper = cv2.resize( - upper, dsize=(num_freqs, w0 + w), - interpolation=cv2.INTER_LINEAR + upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR ) lower = cv2.resize( lower, @@ -109,7 +114,7 @@ def __call__(self, spectrogram): f = np.random.randint(0, self.freq_mask_f) f0 = np.random.randint(0, num_freqs - f) if f != 0: - distorted[:, f0: f0 + f] = mask_value + distorted[:, f0 : f0 + f] = mask_value max_time_mask_t = min( self.time_mask_t, math.floor(num_frames * self.time_mask_p) @@ -121,6 +126,6 @@ def __call__(self, spectrogram): t = np.random.randint(0, max_time_mask_t) t0 = np.random.randint(0, num_frames - t) if t != 0: - distorted[t0: t0 + t, :] = mask_value + distorted[t0 : t0 + t, :] = mask_value return distorted diff --git a/fairseq/data/audio/feature_transforms/utterance_cmvn.py b/fairseq/data/audio/feature_transforms/utterance_cmvn.py index cbedd360d0..6bbd0ae821 100644 --- a/fairseq/data/audio/feature_transforms/utterance_cmvn.py +++ b/fairseq/data/audio/feature_transforms/utterance_cmvn.py @@ -1,11 +1,11 @@ import numpy as np - from fairseq.data.audio.feature_transforms import ( - AudioFeatureTransform, register_audio_feature_transform + AudioFeatureTransform, + register_audio_feature_transform, ) -@register_audio_feature_transform('utterance_cmvn') +@register_audio_feature_transform("utterance_cmvn") class UtteranceCMVN(AudioFeatureTransform): """Utterance-level CMVN (cepstral mean and variance normalization)""" @@ -13,16 +13,18 @@ class UtteranceCMVN(AudioFeatureTransform): def from_config_dict(cls, config=None): _config = {} if config is None else config return 
UtteranceCMVN( - _config.get('norm_means', True), - _config.get('norm_vars', True), + _config.get("norm_means", True), + _config.get("norm_vars", True), ) def __init__(self, norm_means=True, norm_vars=True): self.norm_means, self.norm_vars = norm_means, norm_vars def __repr__(self): - return self.__class__.__name__ + \ - f'(norm_means={self.norm_means}, norm_vars={self.norm_vars})' + return ( + self.__class__.__name__ + + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})" + ) def __call__(self, x): mean = x.mean(axis=0) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 09838a54e0..8d6ce85ecc 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -4,16 +4,17 @@ # LICENSE file in the root directory of this source tree. -import os import logging -import numpy as np +import os import sys +import numpy as np import torch import torch.nn.functional as F from .. import FairseqDataset + logger = logging.getLogger(__name__) @@ -72,11 +73,7 @@ def crop_to_max_size(self, wav, target_size): return wav[start:end] def collater(self, samples): - samples = [ - s - for s in samples - if s["source"] is not None - ] + samples = [s for s in samples if s["source"] is not None] if len(samples) == 0: return {} diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index df360b2c74..aefe95658d 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -3,54 +3,61 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import logging -import re -from typing import List, Tuple, Optional, Dict -import os.path as op import csv import io +import logging +import os.path as op +import re +from typing import Dict, List, Optional, Tuple import numpy as np import torch -from fairseq.data import (FairseqDataset, Dictionary, ResamplingDataset, - ConcatDataset, data_utils as fairseq_data_utils) +from fairseq.data import ( + ConcatDataset, + Dictionary, + FairseqDataset, + ResamplingDataset, + data_utils as fairseq_data_utils, +) from fairseq.data.audio.audio_utils import get_fbank, get_waveform from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform + logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, ) logger = logging.getLogger(__name__) class S2TDataConfig(object): """Wrapper class for data config YAML""" + def __init__(self, yaml_path): try: import yaml except ImportError: - print('Please install PyYAML to load YAML files for ' - 'S2T data config') + print("Please install PyYAML to load YAML files for " "S2T data config") self.config = {} if op.isfile(yaml_path): try: with open(yaml_path) as f: self.config = yaml.load(f, Loader=yaml.FullLoader) except Exception as e: - logger.info(f'Failed to load config from {yaml_path}: {e}') + logger.info(f"Failed to load config from {yaml_path}: {e}") else: - logger.info(f'Cannot find {yaml_path}') + logger.info(f"Cannot find {yaml_path}") @property def vocab_filename(self): """fairseq vocabulary file under data root""" - return self.config.get('vocab_filename', 'dict.txt') + return self.config.get("vocab_filename", "dict.txt") @property def shuffle(self) -> bool: """Shuffle dataset samples before batching""" - return self.config.get('shuffle', False) + return self.config.get("shuffle", False) @property def 
pre_tokenizer(self) -> Dict: @@ -58,7 +65,7 @@ def pre_tokenizer(self) -> Dict: a dictionary with `tokenizer` providing the tokenizer name and the other items providing the tokenizer-specific arguments. Tokenizers are defined in `fairseq.data.encoders.*`""" - return self.config.get('pre_tokenizer', {'tokenizer': None}) + return self.config.get("pre_tokenizer", {"tokenizer": None}) @property def bpe_tokenizer(self) -> Dict: @@ -66,54 +73,55 @@ def bpe_tokenizer(self) -> Dict: a dictionary with `bpe` providing the tokenizer name and the other items providing the tokenizer-specific arguments. Tokenizers are defined in `fairseq.data.encoders.*`""" - return self.config.get('bpe_tokenizer', None) + return self.config.get("bpe_tokenizer", None) @property def prepend_tgt_lang_tag(self) -> bool: """Prepend target lang ID token as the target BOS (e.g. for to-many multilingual setting). During inference, this requires `--prefix-size 1` to force BOS to be lang ID token.""" - return self.config.get('prepend_tgt_lang_tag', False) + return self.config.get("prepend_tgt_lang_tag", False) @property def input_feat_per_channel(self): """The dimension of input features (per audio channel)""" - return self.config.get('input_feat_per_channel', 80) + return self.config.get("input_feat_per_channel", 80) @property def input_channels(self): """The number of channels in the input audio""" - return self.config.get('input_channels', 1) + return self.config.get("input_channels", 1) @property def sampling_alpha(self): """Hyper-parameter alpha = 1/T for temperature-based resampling. (alpha = 1 for no resampling)""" - return self.config.get('sampling_alpha', 1.) 
+ return self.config.get("sampling_alpha", 1.0) @property def use_audio_input(self): """Needed by the dataset loader to see if the model requires raw audio as inputs.""" - return self.config.get('use_audio_input', False) + return self.config.get("use_audio_input", False) @property def audio_root(self): """Audio paths in the manifest TSV can be relative and this provides the root path. Set this to empty string when using absolute paths.""" - return self.config.get('audio_root', '') + return self.config.get("audio_root", "") def get_feature_transforms(self, split, is_train): """Split-specific feature transforms. Allowing train set wildcard `_train`, evaluation set wildcard `_eval` and general wildcard `*` for matching.""" from copy import deepcopy + cfg = deepcopy(self.config) - _cur = cfg.get('transforms', {}) + _cur = cfg.get("transforms", {}) cur = _cur.get(split) - cur = _cur.get('_train') if cur is None and is_train else cur - cur = _cur.get('_eval') if cur is None and not is_train else cur - cur = _cur.get('*') if cur is None else cur - cfg['transforms'] = cur + cur = _cur.get("_train") if cur is None and is_train else cur + cur = _cur.get("_eval") if cur is None and not is_train else cur + cur = _cur.get("*") if cur is None else cur + cfg["transforms"] = cur return cfg @@ -122,13 +130,13 @@ def is_npy_data(data: bytes) -> bool: def is_flac_or_wav_data(data: bytes) -> bool: - is_flac = (data[0] == 102 and data[1] == 76) - is_wav = (data[0] == 82 and data[1] == 73) + is_flac = data[0] == 102 and data[1] == 76 + is_wav = data[0] == 82 and data[1] == 73 return is_flac or is_wav def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes: - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: f.seek(offset) data = f.read(file_size) return data @@ -136,15 +144,15 @@ def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes: def get_features_from_npy_or_audio(path): ext = op.splitext(op.basename(path))[1] - if ext not in {'.npy', 
'.flac', '.wav'}: + if ext not in {".npy", ".flac", ".wav"}: raise ValueError(f'Unsupported file format for "{path}"') - return np.load(path) if ext == '.npy' else get_fbank(path) + return np.load(path) if ext == ".npy" else get_fbank(path) def get_features_or_waveform_from_uncompressed_zip( - path, byte_offset, byte_size, need_waveform=False + path, byte_offset, byte_size, need_waveform=False ): - assert path.endswith('.zip') + assert path.endswith(".zip") data = read_from_uncompressed_zip(path, byte_offset, byte_size) f = io.BytesIO(data) if is_npy_data(data): @@ -169,9 +177,9 @@ def get_features_or_waveform(path: str, need_waveform=False): Returns: features_or_waveform (numpy.ndarray): speech features or waveform. """ - _path, *extra = path.split(':') + _path, *extra = path.split(":") if not op.exists(_path): - raise FileNotFoundError(f'File not found: {_path}') + raise FileNotFoundError(f"File not found: {_path}") if len(extra) == 0: if need_waveform: @@ -183,13 +191,14 @@ def get_features_or_waveform(path: str, need_waveform=False): _path, extra[0], extra[1], need_waveform=need_waveform ) else: - raise ValueError(f'Invalid path: {path}') + raise ValueError(f"Invalid path: {path}") return features_or_waveform -def _collate_frames(frames: List[torch.Tensor], - is_audio_input: bool = False) -> torch.Tensor: +def _collate_frames( + frames: List[torch.Tensor], is_audio_input: bool = False +) -> torch.Tensor: """ Convert a list of 2D frames into a padded 3D tensor Args: @@ -209,24 +218,24 @@ def _collate_frames(frames: List[torch.Tensor], class SpeechToTextDataset(FairseqDataset): - LANG_TAG_TEMPLATE = '<lang:{}>' + LANG_TAG_TEMPLATE = "<lang:{}>" def __init__( - self, - split: str, - is_train_split: bool, - data_cfg: S2TDataConfig, - audio_paths: List[str], - n_frames: List[int], - src_texts: Optional[List[str]] = None, - tgt_texts: Optional[List[str]] = None, - speakers: Optional[List[str]] = None, - src_langs: Optional[List[str]] = None, - tgt_langs: Optional[List[str]] = None, - ids: Optional[List[str]] = None, - tgt_dict: Optional[Dictionary] = None, - pre_tokenizer=None, - bpe_tokenizer=None, + self, + split: str, + is_train_split: bool, + data_cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, +
ids: Optional[List[str]] = None, - tgt_dict: Optional[Dictionary] = None, - pre_tokenizer=None, - bpe_tokenizer=None, + self, + split: str, + is_train_split: bool, + data_cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, ): self.split, self.is_train_split = split, is_train_split self.data_cfg = data_cfg @@ -239,8 +248,9 @@ def __init__( assert src_langs is None or len(src_langs) == self.n_samples assert tgt_langs is None or len(tgt_langs) == self.n_samples assert ids is None or len(ids) == self.n_samples - assert (tgt_dict is None and tgt_texts is None) or \ - (tgt_dict is not None and tgt_texts is not None) + assert (tgt_dict is None and tgt_texts is None) or ( + tgt_dict is not None and tgt_texts is not None + ) self.tgt_dict = tgt_dict self.check_tgt_lang_tag() self.src_texts, self.tgt_texts = src_texts, tgt_texts @@ -258,21 +268,24 @@ def __init__( logger.info(self.__repr__()) def __repr__(self): - return self.__class__.__name__ + \ - f'(split="{self.split}", n_samples={self.n_samples}, ' \ - f'prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, ' \ - f'shuffle={self.shuffle}, transforms={self.feature_transforms})' + return ( + self.__class__.__name__ + + f'(split="{self.split}", n_samples={self.n_samples}, ' + f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, " + f"shuffle={self.shuffle}, transforms={self.feature_transforms})" + ) @classmethod def is_lang_tag(cls, token): - pattern = cls.LANG_TAG_TEMPLATE.replace('{}', '(.*)') + pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") return re.match(pattern, token) def check_tgt_lang_tag(self): if self.data_cfg.prepend_tgt_lang_tag: assert self.tgt_langs is 
not None and self.tgt_dict is not None - tgt_lang_tags = [self.LANG_TAG_TEMPLATE.format(t) - for t in set(self.tgt_langs)] + tgt_lang_tags = [ + self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs) + ] assert all(t in self.tgt_dict for t in tgt_lang_tags) def tokenize_text(self, text: str): @@ -283,7 +296,7 @@ def tokenize_text(self, text: str): return text def __getitem__( - self, index: int + self, index: int ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]: source = get_features_or_waveform( self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input @@ -308,18 +321,15 @@ def __getitem__( def __len__(self): return self.n_samples - def collater( - self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]] - ) -> Dict: + def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dict: if len(samples) == 0: return {} indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long) - frames = _collate_frames([s for _, s, _ in samples], - self.data_cfg.use_audio_input) - # sort samples by descending number of frames - n_frames = torch.tensor( - [s.size(0) for _, s, _ in samples], dtype=torch.long + frames = _collate_frames( + [s for _, s, _ in samples], self.data_cfg.use_audio_input ) + # sort samples by descending number of frames + n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long) n_frames, order = n_frames.sort(descending=True) indices = indices.index_select(0, order) frames = frames.index_select(0, order) @@ -329,16 +339,22 @@ def collater( ntokens = None if self.tgt_texts is not None: target = fairseq_data_utils.collate_tokens( - [t for _, _, t in samples], self.tgt_dict.pad(), - self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=False + [t for _, _, t in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, ) target = target.index_select(0, order) target_lengths = torch.tensor( [t.size(0) for _, _, t in samples], dtype=torch.long 
).index_select(0, order) prev_output_tokens = fairseq_data_utils.collate_tokens( - [t for _, _, t in samples], self.tgt_dict.pad(), - self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=True + [t for _, _, t in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=True, ) prev_output_tokens = prev_output_tokens.index_select(0, order) ntokens = sum(t.size(0) for _, _, t in samples) @@ -364,7 +380,7 @@ def size(self, index): t_len = 0 if self.tgt_texts is not None: tokenized = self.tokenize_text(self.tgt_texts[index]) - t_len = len(tokenized.split(' ')) + t_len = len(tokenized.split(" ")) return self.n_frames[index], t_len @property @@ -390,43 +406,59 @@ def prefetch(self, indices): class SpeechToTextDatasetCreator(object): # mandatory columns - KEY_ID, KEY_AUDIO, KEY_N_FRAMES = 'id', 'audio', 'n_frames' - KEY_TGT_TEXT = 'tgt_text' + KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" + KEY_TGT_TEXT = "tgt_text" # optional columns - KEY_SPEAKER, KEY_SRC_TEXT = 'speaker', 'src_text' - KEY_SRC_LANG, KEY_TGT_LANG = 'src_lang', 'tgt_lang' + KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text" + KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" # default values - DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = '' + DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = "" @classmethod - def _from_list(cls, split_name: str, is_train_split, - samples: List[List[Dict]], data_cfg: S2TDataConfig, tgt_dict, - pre_tokenizer, bpe_tokenizer) -> SpeechToTextDataset: + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[List[Dict]], + data_cfg: S2TDataConfig, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + ) -> SpeechToTextDataset: audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], [] speakers, src_langs, tgt_langs = [], [], [] for s in samples: ids.extend([ss[cls.KEY_ID] for ss in s]) - audio_paths.extend([op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) - for ss in s]) + 
audio_paths.extend( + [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s] + ) n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s]) tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s]) - src_texts.extend([ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) - for ss in s]) - speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) - for ss in s]) - src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) - for ss in s]) - tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) - for ss in s]) + src_texts.extend( + [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s] + ) + speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]) + src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]) + tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]) return SpeechToTextDataset( - split_name, is_train_split, data_cfg, audio_paths, n_frames, - src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict, - pre_tokenizer, bpe_tokenizer + split_name, + is_train_split, + data_cfg, + audio_paths, + n_frames, + src_texts, + tgt_texts, + speakers, + src_langs, + tgt_langs, + ids, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, ) @classmethod - def _get_size_ratios(cls, ids: List[str], sizes: List[int], - alpha: float = 1.): + def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0): """Size ratios for temperature-based sampling (https://arxiv.org/abs/1907.05019)""" _sizes = np.array(sizes) @@ -444,35 +476,58 @@ def _get_size_ratios(cls, ids: List[str], sizes: List[int], return size_ratio.tolist() @classmethod - def from_tsv(cls, root: str, data_cfg: S2TDataConfig, splits: str, tgt_dict, - pre_tokenizer, bpe_tokenizer, is_train_split: bool, epoch: int, - seed: int) -> SpeechToTextDataset: + def from_tsv( + cls, + root: str, + data_cfg: S2TDataConfig, + splits: str, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split: bool, + epoch: int, + seed: int, + ) -> 
SpeechToTextDataset: samples = [] - _splits = splits.split(',') + _splits = splits.split(",") for split in _splits: - tsv_path = op.join(root, f'{split}.tsv') + tsv_path = op.join(root, f"{split}.tsv") if not op.isfile(tsv_path): raise FileNotFoundError(f"Dataset not found: {tsv_path}") with open(tsv_path) as f: reader = csv.DictReader( - f, delimiter='\t', quotechar=None, doublequote=False, - lineterminator='\n', quoting=csv.QUOTE_NONE + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, ) samples.append([dict(e) for e in reader]) assert len(samples) > 0 - datasets = [cls._from_list(name, is_train_split, [s], data_cfg, tgt_dict, - pre_tokenizer, bpe_tokenizer) - for name, s in zip(_splits, samples)] + datasets = [ + cls._from_list( + name, + is_train_split, + [s], + data_cfg, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + ) + for name, s in zip(_splits, samples) + ] - if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.: + if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0: # temperature-based sampling size_ratios = cls._get_size_ratios( _splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha ) datasets = [ - ResamplingDataset(d, size_ratio=r, seed=seed, epoch=epoch, - replace=(r >= 1.)) + ResamplingDataset( + d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) + ) for d, r in zip(datasets, size_ratios) ] return ConcatDataset(datasets) diff --git a/fairseq/data/backtranslation_dataset.py b/fairseq/data/backtranslation_dataset.py index 0007a01506..8f70c90df3 100644 --- a/fairseq/data/backtranslation_dataset.py +++ b/fairseq/data/backtranslation_dataset.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import torch - from fairseq import utils from . 
import FairseqDataset @@ -36,16 +35,18 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True): s = utils.move_to_cuda(collated_samples) if cuda else collated_samples generated_sources = generate_fn(s) - id_to_src = { - sample['id']: sample['source'] for sample in samples - } + id_to_src = {sample["id"]: sample["source"] for sample in samples} # Go through each tgt sentence in batch and its corresponding best # generated hypothesis and create a backtranslation data pair # {id: id, source: generated backtranslation, target: original tgt} return [ - {'id': id.item(), 'target': id_to_src[id.item()], 'source': hypos[0]['tokens'].cpu()} - for id, hypos in zip(collated_samples['id'], generated_sources) + { + "id": id.item(), + "target": id_to_src[id.item()], + "source": hypos[0]["tokens"].cpu(), + } + for id, hypos in zip(collated_samples["id"], generated_sources) ] @@ -87,8 +88,9 @@ def __init__( ): self.tgt_dataset = tgt_dataset self.backtranslation_fn = backtranslation_fn - self.output_collater = output_collater if output_collater is not None \ - else tgt_dataset.collater + self.output_collater = ( + output_collater if output_collater is not None else tgt_dataset.collater + ) self.cuda = cuda if torch.cuda.is_available() else False self.src_dict = src_dict self.tgt_dict = tgt_dict @@ -126,14 +128,12 @@ def collater(self, samples): Returns: dict: a mini-batch with keys coming from *output_collater* """ - if samples[0].get('is_dummy', False): + if samples[0].get("is_dummy", False): return samples samples = backtranslate_samples( samples=samples, collate_fn=self.tgt_dataset.collater, - generate_fn=( - lambda net_input: self.backtranslation_fn(net_input) - ), + generate_fn=(lambda net_input: self.backtranslation_fn(net_input)), cuda=self.cuda, ) return self.output_collater(samples) @@ -159,7 +159,7 @@ def size(self, index): @property def supports_prefetch(self): - return getattr(self.tgt_dataset, 'supports_prefetch', False) + return 
getattr(self.tgt_dataset, "supports_prefetch", False) def prefetch(self, indices): return self.tgt_dataset.prefetch(indices) diff --git a/fairseq/data/base_wrapper_dataset.py b/fairseq/data/base_wrapper_dataset.py index 680dcce9ae..134d398b47 100644 --- a/fairseq/data/base_wrapper_dataset.py +++ b/fairseq/data/base_wrapper_dataset.py @@ -9,7 +9,6 @@ class BaseWrapperDataset(FairseqDataset): - def __init__(self, dataset): super().__init__() self.dataset = dataset @@ -21,7 +20,7 @@ def __len__(self): return len(self.dataset) def collater(self, samples): - if hasattr(self.dataset, 'collater'): + if hasattr(self.dataset, "collater"): return self.dataset.collater(samples) else: return default_collate(samples) @@ -41,7 +40,7 @@ def ordered_indices(self): @property def supports_prefetch(self): - return getattr(self.dataset, 'supports_prefetch', False) + return getattr(self.dataset, "supports_prefetch", False) def attr(self, attr: str, index: int): return self.dataset.attr(attr, index) @@ -75,5 +74,5 @@ def can_reuse_epoch_itr_across_epochs(self): def set_epoch(self, epoch): super().set_epoch(epoch) - if hasattr(self.dataset, 'set_epoch'): + if hasattr(self.dataset, "set_epoch"): self.dataset.set_epoch(epoch) diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index 6f53d01188..cda8834ac8 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -5,7 +5,6 @@ import numpy as np import torch.nn.functional as F - from fairseq.data import BaseWrapperDataset @@ -40,7 +39,7 @@ def __init__( np.percentile( sizes, np.linspace(0, 100, num_buckets + 1), - interpolation='lower', + interpolation="lower", )[1:] ) diff --git a/fairseq/data/colorize_dataset.py b/fairseq/data/colorize_dataset.py index 89e0e04142..6ef097bff1 100644 --- a/fairseq/data/colorize_dataset.py +++ b/fairseq/data/colorize_dataset.py @@ -10,6 +10,7 @@ class ColorizeDataset(BaseWrapperDataset): """ Adds 'colors' property 
to net input that is obtained from the provided color getter for use by models """ + def __init__(self, dataset, color_getter): super().__init__(dataset) self.color_getter = color_getter diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py index 0091a28e47..01a4078bb1 100644 --- a/fairseq/data/concat_dataset.py +++ b/fairseq/data/concat_dataset.py @@ -49,7 +49,7 @@ def _get_dataset_and_sample_index(self, idx: int): def collater(self, samples, **extra_args): # For now only supports datasets with same underlying collater implementations - if hasattr(self.datasets[0], 'collater'): + if hasattr(self.datasets[0], "collater"): return self.datasets[0].collater(samples, **extra_args) else: return default_collate(samples, **extra_args) @@ -92,14 +92,16 @@ def ordered_indices(self): # special handling for concatenating lang_pair_datasets indices = np.arange(len(self)) sizes = self.sizes - tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None - src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + tgt_sizes = ( + sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + ) + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) # sort by target length, then source length if tgt_sizes is not None: - indices = indices[ - np.argsort(tgt_sizes[indices], kind='mergesort') - ] - return indices[np.argsort(src_sizes[indices], kind='mergesort')] + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(src_sizes[indices], kind="mergesort")] else: return np.argsort(self.sizes) @@ -107,7 +109,7 @@ def prefetch(self, indices): frm = 0 for to, ds in zip(self.cumulative_sizes, self.datasets): real_size = len(ds) - if getattr(ds, 'supports_prefetch', False): + if getattr(ds, "supports_prefetch", False): ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) frm = to @@ -118,5 +120,5 @@ def 
can_reuse_epoch_itr_across_epochs(self): def set_epoch(self, epoch): super().set_epoch(epoch) for ds in self.datasets: - if hasattr(ds, 'set_epoch'): + if hasattr(ds, "set_epoch"): ds.set_epoch(epoch) diff --git a/fairseq/data/concat_sentences_dataset.py b/fairseq/data/concat_sentences_dataset.py index 55445ee1c7..625a29370e 100644 --- a/fairseq/data/concat_sentences_dataset.py +++ b/fairseq/data/concat_sentences_dataset.py @@ -9,12 +9,12 @@ class ConcatSentencesDataset(FairseqDataset): - def __init__(self, *datasets): super().__init__() self.datasets = datasets - assert all(len(ds) == len(datasets[0]) for ds in datasets), \ - 'datasets must have the same length' + assert all( + len(ds) == len(datasets[0]) for ds in datasets + ), "datasets must have the same length" def __getitem__(self, index): return torch.cat([ds[index] for ds in self.datasets]) @@ -40,17 +40,15 @@ def ordered_indices(self): @property def supports_prefetch(self): - return any( - getattr(ds, 'supports_prefetch', False) for ds in self.datasets - ) + return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets) def prefetch(self, indices): for ds in self.datasets: - if getattr(ds, 'supports_prefetch', False): + if getattr(ds, "supports_prefetch", False): ds.prefetch(indices) def set_epoch(self, epoch): super().set_epoch(epoch) for ds in self.datasets: - if hasattr(ds, 'set_epoch'): + if hasattr(ds, "set_epoch"): ds.set_epoch(epoch) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index a8c480c5b1..81f457365a 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -12,8 +12,7 @@ import logging import os import warnings - -from typing import Tuple, Optional +from typing import Optional, Tuple import numpy as np import torch @@ -26,19 +25,26 @@ def infer_language_pair(path): """Infer language pair from filename: .-.(...).idx""" src, dst = None, None for filename in os.listdir(path): - parts = filename.split('.') - if len(parts) >= 3 and 
len(parts[1].split('-')) == 2: - return parts[1].split('-') + parts = filename.split(".") + if len(parts) >= 3 and len(parts[1].split("-")) == 2: + return parts[1].split("-") return src, dst -def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False, - pad_to_length=None, pad_to_multiple=1): +def collate_tokens( + values, + pad_idx, + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + pad_to_length=None, + pad_to_multiple=1, +): """Convert a list of 1d tensors into a padded 2d tensor.""" size = max(v.size(0) for v in values) size = size if pad_to_length is None else max(size, pad_to_length) if pad_to_multiple != 1 and size % pad_to_multiple != 0: - size = int(((size-0.1)//pad_to_multiple + 1) * pad_to_multiple) + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) res = values[0].new(len(values), size).fill_(pad_idx) def copy_tensor(src, dst): @@ -54,11 +60,13 @@ def copy_tensor(src, dst): dst.copy_(src) for i, v in enumerate(values): - copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) return res -def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False, default='cached'): +def load_indexed_dataset( + path, dictionary=None, dataset_impl=None, combine=False, default="cached" +): """A helper function for loading indexed datasets. 
Args: @@ -74,9 +82,10 @@ def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False """ from fairseq.data.concat_dataset import ConcatDataset import fairseq.data.indexed_dataset as indexed_dataset + datasets = [] for k in itertools.count(): - path_k = path + (str(k) if k > 0 else '') + path_k = path + (str(k) if k > 0 else "") path_k = indexed_dataset.get_indexed_dataset_to_local(path_k) dataset_impl_k = dataset_impl @@ -90,7 +99,7 @@ def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False ) if dataset is None: break - logger.info('loaded {} examples from: {}'.format(len(dataset), path_k)) + logger.info("loaded {} examples from: {}".format(len(dataset), path_k)) datasets.append(dataset) if not combine: break @@ -148,8 +157,10 @@ def check_size(idx): assert isinstance(idx_size, dict) intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) return all( - all(a is None or b is None or a <= b - for a, b in zip(idx_size[key], max_positions[key])) + all( + a is None or b is None or a <= b + for a, b in zip(idx_size[key], max_positions[key]) + ) for key in intersect_keys ) else: @@ -166,6 +177,7 @@ def check_size(idx): a is None or b is None or a <= b for a, b in zip(size_fn(idx), max_positions) ) + ignored = [] itr = collect_filtered(check_size, indices, ignored) indices = np.fromiter(itr, dtype=np.int64, count=-1) @@ -186,37 +198,47 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False): any elements are filtered (default: False). """ warnings.warn( - 'data_utils.filter_by_size is deprecated. ' - 'Use `FairseqDataset::filter_indices_by_size` instead.', - stacklevel=2 + "data_utils.filter_by_size is deprecated. 
" + "Use `FairseqDataset::filter_indices_by_size` instead.", + stacklevel=2, ) if isinstance(max_positions, float) or isinstance(max_positions, int): - if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray): + if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray): ignored = indices[dataset.sizes[indices] > max_positions].tolist() indices = indices[dataset.sizes[indices] <= max_positions] - elif hasattr(dataset, 'sizes') and isinstance(dataset.sizes, list) and len(dataset.sizes) == 1: + elif ( + hasattr(dataset, "sizes") + and isinstance(dataset.sizes, list) + and len(dataset.sizes) == 1 + ): ignored = indices[dataset.sizes[0][indices] > max_positions].tolist() indices = indices[dataset.sizes[0][indices] <= max_positions] else: - indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions) + indices, ignored = _filter_by_size_dynamic( + indices, dataset.size, max_positions + ) else: indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions) if len(ignored) > 0 and raise_exception: - raise Exception(( - 'Size of sample #{} is invalid (={}) since max_positions={}, ' - 'skip this example with --skip-invalid-size-inputs-valid-test' - ).format(ignored[0], dataset.size(ignored[0]), max_positions)) + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) if len(ignored) > 0: - logger.warning(( - '{} samples have invalid sizes and will be skipped, ' - 'max_positions={}, first few sample ids={}' - ).format(len(ignored), max_positions, ignored[:10])) + logger.warning( + ( + "{} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) return indices def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes): - """ Filter a 
list of sample indices. Remove those that are longer + """Filter a list of sample indices. Remove those that are longer than specified in max_sizes. Args: @@ -238,21 +260,26 @@ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_siz ignored = indices[src_sizes[indices] > max_src_size] else: ignored = indices[ - (src_sizes[indices] > max_src_size) | - (tgt_sizes[indices] > max_tgt_size)] + (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size) + ] if len(ignored) > 0: if tgt_sizes is None: indices = indices[src_sizes[indices] <= max_src_size] else: indices = indices[ - (src_sizes[indices] <= max_src_size) & - (tgt_sizes[indices] <= max_tgt_size)] + (src_sizes[indices] <= max_src_size) + & (tgt_sizes[indices] <= max_tgt_size) + ] return indices, ignored.tolist() def batch_by_size( - indices, num_tokens_fn, max_tokens=None, max_sentences=None, - required_batch_size_multiple=1, fixed_shapes=None, + indices, + num_tokens_fn, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + fixed_shapes=None, ): """ Yield mini-batches of indices bucketed by size. 
Batches may contain @@ -274,12 +301,13 @@ def batch_by_size( """ try: from fairseq.data.data_utils_fast import ( - batch_by_size_fast, batch_fixed_shapes_fast, + batch_by_size_fast, + batch_fixed_shapes_fast, ) except ImportError: raise ImportError( - 'Please build Cython components with: `pip install --editable .` ' - 'or `python setup.py build_ext --inplace`' + "Please build Cython components with: `pip install --editable .` " + "or `python setup.py build_ext --inplace`" ) max_tokens = max_tokens if max_tokens is not None else -1 @@ -291,14 +319,20 @@ def batch_by_size( if fixed_shapes is None: return batch_by_size_fast( - indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult, + indices, + num_tokens_fn, + max_tokens, + max_sentences, + bsz_mult, ) else: fixed_shapes = np.array(fixed_shapes, dtype=np.int64) - sort_order = np.lexsort([ - fixed_shapes[:, 1].argsort(), # length - fixed_shapes[:, 0].argsort(), # bsz - ]) + sort_order = np.lexsort( + [ + fixed_shapes[:, 1].argsort(), # length + fixed_shapes[:, 0].argsort(), # bsz + ] + ) fixed_shapes_sorted = fixed_shapes[sort_order] return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted) @@ -306,26 +340,27 @@ def batch_by_size( def post_process(sentence: str, symbol: str): if symbol == "sentencepiece": sentence = sentence.replace(" ", "").replace("\u2581", " ").strip() - elif symbol == 'wordpiece': + elif symbol == "wordpiece": sentence = sentence.replace(" ", "").replace("_", " ").strip() - elif symbol == 'letter': + elif symbol == "letter": sentence = sentence.replace(" ", "").replace("|", " ").strip() elif symbol == "_EOW": sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() - elif symbol is not None and symbol != 'none': + elif symbol is not None and symbol != "none": sentence = (sentence + " ").replace(symbol, "").rstrip() return sentence + def compute_mask_indices( - shape: Tuple[int, int], - padding_mask: Optional[torch.Tensor], - mask_prob: float, - mask_length: int, - 
mask_type: str = "static", - mask_other: float = 0.0, - min_masks: int = 0, - no_overlap: bool = False, - min_space: int = 0, + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, ) -> np.ndarray: """ Computes random mask spans for a given shape @@ -390,13 +425,14 @@ def compute_mask_indices( if no_overlap: mask_idc = [] + def arrange(s, e, length, keep_length): - span_start = np.random.randint(s, e-length) + span_start = np.random.randint(s, e - length) mask_idc.extend(span_start + i for i in range(length)) new_parts = [] if span_start - s - min_space >= keep_length: - new_parts.append((s, span_start-min_space+1)) + new_parts.append((s, span_start - min_space + 1)) if e - span_start - keep_length - min_space > keep_length: new_parts.append((span_start + length + min_space, e)) return new_parts @@ -404,7 +440,10 @@ def arrange(s, e, length, keep_length): parts = [(0, sz)] min_length = min(lengths) for length in sorted(lengths, reverse=True): - lens = np.fromiter((e - s if e-s >= length+min_space else 0 for s, e in parts), np.int) + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) l_sum = np.sum(lens) if l_sum == 0: break @@ -416,7 +455,7 @@ def arrange(s, e, length, keep_length): else: min_len = min(lengths) if sz - min_len <= num_mask: - min_len = sz - num_mask - 1 + min_len = sz - num_mask - 1 mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) @@ -442,10 +481,11 @@ def arrange(s, e, length, keep_length): def get_mem_usage(): try: import psutil + mb = 1024 * 1024 - return f'used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb' + return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" except ImportError: - return 'N/A' + return "N/A" 
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py index 4fe560b0a7..bdb62c8d5d 100644 --- a/fairseq/data/denoising_dataset.py +++ b/fairseq/data/denoising_dataset.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import math + import numpy as np import torch -import math -from . import data_utils, FairseqDataset +from . import FairseqDataset, data_utils def collate( @@ -34,53 +35,59 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): pad_to_length=pad_to_length, ) - id = torch.LongTensor([s['id'] for s in samples]) + id = torch.LongTensor([s["id"] for s in samples]) src_tokens = merge( - 'source', left_pad=left_pad_source, - pad_to_length=pad_to_length['source'] if pad_to_length is not None else None, + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, ) # sort by descending source length - src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) + src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) src_tokens = src_tokens.index_select(0, sort_order) prev_output_tokens = None target = None - if samples[0].get('target', None) is not None: + if samples[0].get("target", None) is not None: target = merge( - 'target', left_pad=left_pad_target, - pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, ) target = target.index_select(0, sort_order) - ntokens = sum(len(s['target']) for s in samples) + ntokens = sum(len(s["target"]) for s in samples) if input_feeding: # we create a shifted version of targets for 
feeding the # previous output token(s) into the next decoder step prev_output_tokens = merge( - 'target', + "target", left_pad=left_pad_target, move_eos_to_beginning=True, - pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, ) prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: - ntokens = sum(len(s['source']) for s in samples) + ntokens = sum(len(s["source"]) for s in samples) batch = { - 'id': id, - 'ntokens': ntokens, - 'net_input': { - 'src_tokens': src_tokens, - 'src_lengths': src_lengths, + "id": id, + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, }, - 'target': target, - 'nsentences': samples[0]['source'].size(0), - 'sort_order': sort_order, + "target": target, + "nsentences": samples[0]["source"].size(0), + "sort_order": sort_order, } if prev_output_tokens is not None: - batch['net_input']['prev_output_tokens'] = prev_output_tokens + batch["net_input"]["prev_output_tokens"] = prev_output_tokens return batch @@ -130,25 +137,25 @@ def __init__( self.insert_ratio = args.insert self.rotate_ratio = args.rotate self.permute_sentence_ratio = args.permute_sentences - self.eos = (eos if eos is not None else vocab.eos()) + self.eos = eos if eos is not None else vocab.eos() self.item_transform_func = item_transform_func - if args.bpe != 'gpt2': + if args.bpe != "gpt2": self.full_stop_index = self.vocab.eos() else: - assert args.bpe == 'gpt2' - self.full_stop_index = self.vocab.index('13') + assert args.bpe == "gpt2" + self.full_stop_index = self.vocab.index("13") self.replace_length = args.replace_length if self.replace_length not in [-1, 0, 1]: - raise ValueError(f'invalid arg: replace_length={self.replace_length}') - if args.mask_length not in ['subword', 'word', 'span-poisson']: - raise ValueError(f'invalid arg: mask-length={args.mask_length}') - if args.mask_length == 'subword' and 
args.replace_length not in [0, 1]: - raise ValueError(f'if using subwords, use replace-length=1 or 0') + raise ValueError(f"invalid arg: replace_length={self.replace_length}") + if args.mask_length not in ["subword", "word", "span-poisson"]: + raise ValueError(f"invalid arg: mask-length={args.mask_length}") + if args.mask_length == "subword" and args.replace_length not in [0, 1]: + raise ValueError(f"if using subwords, use replace-length=1 or 0") self.mask_span_distribution = None - if args.mask_length == 'span-poisson': + if args.mask_length == "span-poisson": _lambda = args.poisson_lambda lambda_to_the_k = 1 @@ -158,7 +165,7 @@ def __init__( for k in range(0, 128): ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) lambda_to_the_k *= _lambda - k_factorial *= (k + 1) + k_factorial *= k + 1 if ps[-1] < 0.0000001: break ps = torch.FloatTensor(ps) @@ -200,16 +207,16 @@ def __getitem__(self, index): assert source[0] == self.vocab.bos() assert source[-1] == self.eos return { - 'id': index, - 'source': source, - 'target': target, + "id": index, + "source": source, + "target": target, } def __len__(self): return len(self.dataset) def permute_sentences(self, source, p=1.0): - full_stops = (source == self.full_stop_index) + full_stops = source == self.full_stop_index # Pretend it ends with a full stop so last span is a sentence full_stops[-2] = 1 @@ -226,8 +233,8 @@ def permute_sentences(self, source, p=1.0): # Ignore at start index = 1 for i in ordering: - sentence = source[(sentence_ends[i - 1] if i > 0 else 1):sentence_ends[i]] - result[index:index + sentence.size(0)] = sentence + sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]] + result[index : index + sentence.size(0)] = sentence index += sentence.size(0) return result @@ -253,7 +260,13 @@ def add_whole_word_mask(self, source, p): # Make sure we have enough to mask cum_length = torch.cumsum(lengths, 0) while cum_length[-1] < num_to_mask: - lengths = torch.cat([lengths, 
self.mask_span_distribution.sample(sample_shape=(num_to_mask,))], dim=0) + lengths = torch.cat( + [ + lengths, + self.mask_span_distribution.sample(sample_shape=(num_to_mask,)), + ], + dim=0, + ) cum_length = torch.cumsum(lengths, 0) # Trim to masking budget @@ -276,19 +289,25 @@ def add_whole_word_mask(self, source, p): lengths = torch.ones((num_to_mask,)).long() assert is_word_start[-1] == 0 word_starts = is_word_start.nonzero(as_tuple=False) - indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1) + indices = word_starts[ + torch.randperm(word_starts.size(0))[:num_to_mask] + ].squeeze(1) mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio source_length = source.size(0) assert source_length - 1 not in indices to_keep = torch.ones(source_length, dtype=torch.bool) - is_word_start[-1] = 255 # acts as a long length, so spans don't go over the end of doc + is_word_start[ + -1 + ] = 255 # acts as a long length, so spans don't go over the end of doc if self.replace_length == 0: to_keep[indices] = 0 else: # keep index, but replace it with [MASK] source[indices] = self.mask_idx - source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),)) + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) if self.mask_span_distribution is not None: assert len(lengths.size()) == 1 @@ -307,7 +326,9 @@ def add_whole_word_mask(self, source, p): else: # keep index, but replace it with [MASK] source[indices] = self.mask_idx - source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),)) + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) else: # A bit faster when all lengths are 1 while indices.size(0) > 0: @@ -320,7 +341,9 @@ def add_whole_word_mask(self, source, p): else: # keep index, but replace it with [MASK] source[indices] = self.mask_idx - source[indices[mask_random]] = 
torch.randint(1, len(self.vocab), size=(mask_random.sum(),)) + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) assert source_length - 1 not in indices @@ -360,7 +383,9 @@ def add_insertion_noise(self, tokens, p): num_random = int(math.ceil(n * self.random_ratio)) result[noise_indices[num_random:]] = self.mask_idx - result[noise_indices[:num_random]] = torch.randint(low=1, high=len(self.vocab), size=(num_random,)) + result[noise_indices[:num_random]] = torch.randint( + low=1, high=len(self.vocab), size=(num_random,) + ) result[~noise_mask] = tokens @@ -375,8 +400,8 @@ def collater(self, samples, pad_to_length=None): dict: a mini-batch of data """ return collate( - samples, self.vocab.pad(), self.eos, self.vocab, - pad_to_length=pad_to_length) + samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length + ) def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to @@ -395,7 +420,7 @@ def ordered_indices(self): indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) - return indices[np.argsort(self.sizes[indices], kind='mergesort')] + return indices[np.argsort(self.sizes[indices], kind="mergesort")] def prefetch(self, indices): self.src.prefetch(indices) @@ -404,8 +429,8 @@ def prefetch(self, indices): @property def supports_prefetch(self): return ( - hasattr(self.src, 'supports_prefetch') + hasattr(self.src, "supports_prefetch") and self.src.supports_prefetch - and hasattr(self.tgt, 'supports_prefetch') + and hasattr(self.tgt, "supports_prefetch") and self.tgt.supports_prefetch ) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 3d11f93137..e2df08e092 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -251,8 +251,7 @@ def add_from_file(self, f): "Duplicate words can overwrite earlier ones by adding the " "#fairseq:overwrite flag at the end of the corresponding row " "in the dictionary file. 
If using the Camembert model, please " - "download an updated copy of the model file." - .format(word) + "download an updated copy of the model file.".format(word) ) self.add_symbol(word, n=count, overwrite=overwrite) except ValueError: diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py index d796496b86..2e807d8ae7 100644 --- a/fairseq/data/encoders/__init__.py +++ b/fairseq/data/encoders/__init__.py @@ -11,19 +11,19 @@ build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( - '--tokenizer', + "--tokenizer", default=None, ) build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( - '--bpe', + "--bpe", default=None, ) # automatically import any Python files in the encoders/ directory for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('fairseq.data.encoders.' + module) + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.data.encoders." 
+ module) diff --git a/fairseq/data/encoders/byte_bpe.py b/fairseq/data/encoders/byte_bpe.py index 1d78ff9150..0d2da3ea1a 100644 --- a/fairseq/data/encoders/byte_bpe.py +++ b/fairseq/data/encoders/byte_bpe.py @@ -6,11 +6,15 @@ from fairseq import file_utils from fairseq.data.encoders import register_bpe -from fairseq.data.encoders.byte_utils import (byte_encode, smart_byte_decode, - SPACE, SPACE_ESCAPE) +from fairseq.data.encoders.byte_utils import ( + SPACE, + SPACE_ESCAPE, + byte_encode, + smart_byte_decode, +) -@register_bpe('byte_bpe') +@register_bpe("byte_bpe") class ByteBPE(object): @staticmethod def add_args(parser): @@ -23,10 +27,13 @@ def __init__(self, args): vocab = file_utils.cached_path(args.sentencepiece_model_path) try: import sentencepiece as spm + self.sp = spm.SentencePieceProcessor() self.sp.Load(vocab) except ImportError: - raise ImportError('Please install sentencepiece with: pip install sentencepiece') + raise ImportError( + "Please install sentencepiece with: pip install sentencepiece" + ) def encode(self, x: str) -> str: byte_encoded = byte_encode(x) @@ -34,5 +41,5 @@ def encode(self, x: str) -> str: @staticmethod def decode(x: str) -> str: - unescaped = x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE) + unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) return smart_byte_decode(unescaped) diff --git a/fairseq/data/encoders/byte_utils.py b/fairseq/data/encoders/byte_utils.py index 7c4bb74713..a305c08092 100644 --- a/fairseq/data/encoders/byte_utils.py +++ b/fairseq/data/encoders/byte_utils.py @@ -5,13 +5,13 @@ import re -WHITESPACE_NORMALIZER = re.compile(r'\s+') + +WHITESPACE_NORMALIZER = re.compile(r"\s+") SPACE = chr(32) SPACE_ESCAPE = chr(9601) # excluding non-breaking space (160) here PRINTABLE_LATIN = set( - list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + - list(range(174, 255 + 1)) + list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1)) ) BYTE_TO_BCHAR = { b: chr(b) if b in PRINTABLE_LATIN 
else chr(256 + b) for b in range(256) @@ -21,19 +21,19 @@ def byte_encode(x: str) -> str: normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) - return ''.join([BYTE_TO_BCHAR[b] for b in normalized.encode('utf-8')]) + return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) def byte_decode(x: str) -> str: try: - return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode('utf-8') + return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") except ValueError: - return '' + return "" def smart_byte_decode(x: str) -> str: output = byte_decode(x) - if output == '': + if output == "": # DP the best recovery (max valid chars) if it's broken n_bytes = len(x) f = [0 for _ in range(n_bytes + 1)] @@ -41,11 +41,11 @@ def smart_byte_decode(x: str) -> str: for i in range(1, n_bytes + 1): f[i], pt[i] = f[i - 1], i - 1 for j in range(1, min(4, i) + 1): - if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j: i])) > 0: + if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: f[i], pt[i] = f[i - j] + 1, i - j cur_pt = n_bytes while cur_pt > 0: if f[cur_pt] == f[pt[cur_pt]] + 1: - output = byte_decode(x[pt[cur_pt]: cur_pt]) + output + output = byte_decode(x[pt[cur_pt] : cur_pt]) + output cur_pt = pt[cur_pt] return output diff --git a/fairseq/data/encoders/bytes.py b/fairseq/data/encoders/bytes.py index 8bace19c53..bb9554ed53 100644 --- a/fairseq/data/encoders/bytes.py +++ b/fairseq/data/encoders/bytes.py @@ -5,11 +5,15 @@ from fairseq.data.encoders import register_bpe -from fairseq.data.encoders.byte_utils import (byte_encode, smart_byte_decode, - SPACE, SPACE_ESCAPE) +from fairseq.data.encoders.byte_utils import ( + SPACE, + SPACE_ESCAPE, + byte_encode, + smart_byte_decode, +) -@register_bpe('bytes') +@register_bpe("bytes") class Bytes(object): def __init__(self, args): pass @@ -26,5 +30,5 @@ def encode(x: str) -> str: @staticmethod def decode(x: str) -> str: - unescaped = x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE) + unescaped = x.replace(SPACE, 
"").replace(SPACE_ESCAPE, SPACE) return smart_byte_decode(unescaped) diff --git a/fairseq/data/encoders/characters.py b/fairseq/data/encoders/characters.py index db6a58a650..cffc57511c 100644 --- a/fairseq/data/encoders/characters.py +++ b/fairseq/data/encoders/characters.py @@ -6,11 +6,12 @@ from fairseq.data.encoders import register_bpe + SPACE = chr(32) SPACE_ESCAPE = chr(9601) -@register_bpe('characters') +@register_bpe("characters") class Characters(object): def __init__(self, args): pass @@ -26,4 +27,4 @@ def encode(x: str) -> str: @staticmethod def decode(x: str) -> str: - return x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE) + return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py index ea0badd544..74d4ad8504 100644 --- a/fairseq/data/encoders/fastbpe.py +++ b/fairseq/data/encoders/fastbpe.py @@ -7,9 +7,8 @@ from fairseq.data.encoders import register_bpe -@register_bpe('fastbpe') +@register_bpe("fastbpe") class fastBPE(object): - @staticmethod def add_args(parser): # fmt: off @@ -19,17 +18,18 @@ def add_args(parser): def __init__(self, args): if args.bpe_codes is None: - raise ValueError('--bpe-codes is required for --bpe=fastbpe') + raise ValueError("--bpe-codes is required for --bpe=fastbpe") codes = file_utils.cached_path(args.bpe_codes) try: import fastBPE + self.bpe = fastBPE.fastBPE(codes) self.bpe_symbol = "@@ " except ImportError: - raise ImportError('Please install fastBPE with: pip install fastBPE') + raise ImportError("Please install fastBPE with: pip install fastBPE") def encode(self, x: str) -> str: return self.bpe.apply([x])[0] def decode(self, x: str) -> str: - return (x + ' ').replace(self.bpe_symbol, '').rstrip() + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq/data/encoders/gpt2_bpe.py b/fairseq/data/encoders/gpt2_bpe.py index 54e0593d00..8ac099a688 100644 --- a/fairseq/data/encoders/gpt2_bpe.py +++ 
b/fairseq/data/encoders/gpt2_bpe.py @@ -9,13 +9,12 @@ from .gpt2_bpe_utils import get_encoder -DEFAULT_ENCODER_JSON = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json' -DEFAULT_VOCAB_BPE = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe' +DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" +DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" -@register_bpe('gpt2') +@register_bpe("gpt2") class GPT2BPE(object): - @staticmethod def add_args(parser): # fmt: off @@ -29,21 +28,20 @@ def add_args(parser): def __init__(self, args): encoder_json = file_utils.cached_path( - getattr(args, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON) + getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON) ) vocab_bpe = file_utils.cached_path( - getattr(args, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE) + getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE) ) self.bpe = get_encoder(encoder_json, vocab_bpe) def encode(self, x: str) -> str: - return ' '.join(map(str, self.bpe.encode(x))) + return " ".join(map(str, self.bpe.encode(x))) def decode(self, x: str) -> str: - return self.bpe.decode([ - int(tok) if tok not in {'', ''} else tok - for tok in x.split() - ]) + return self.bpe.decode( + [int(tok) if tok not in {"", ""} else tok for tok in x.split()] + ) def is_beginning_of_word(self, x: str) -> bool: - return self.decode(x).startswith(' ') + return self.decode(x).startswith(" ") diff --git a/fairseq/data/encoders/gpt2_bpe_utils.py b/fairseq/data/encoders/gpt2_bpe_utils.py index 1917f82314..688d4e36e3 100644 --- a/fairseq/data/encoders/gpt2_bpe_utils.py +++ b/fairseq/data/encoders/gpt2_bpe_utils.py @@ -5,8 +5,8 @@ Original license: MIT """ -from functools import lru_cache import json +from functools import lru_cache @lru_cache() @@ -20,17 +20,22 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 
And avoids mapping to whitespace/control characters the bpe code barfs on. """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) cs = bs[:] n = 0 - for b in range(2**8): + for b in range(2 ** 8): if b not in bs: bs.append(b) - cs.append(2**8+n) + cs.append(2 ** 8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) + def get_pairs(word): """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings). @@ -42,25 +47,28 @@ def get_pairs(word): prev_char = char return pairs -class Encoder: - def __init__(self, encoder, bpe_merges, errors='replace'): +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): self.encoder = encoder - self.decoder = {v:k for k,v in self.encoder.items()} - self.errors = errors # how to handle errors in decoding + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.cache = {} try: import regex as re + self.re = re except ImportError: - raise ImportError('Please install regex with: pip install regex') + raise ImportError("Please install regex with: pip install regex") # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions - self.pat = self.re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + self.pat = self.re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) def bpe(self, token): if token in self.cache: @@ -72,7 +80,7 @@ def bpe(self, 
token): return token while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -87,8 +95,8 @@ def bpe(self, token): new_word.extend(word[i:]) break - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) i += 2 else: new_word.append(word[i]) @@ -99,28 +107,33 @@ def bpe(self, token): break else: pairs = get_pairs(word) - word = ' '.join(word) + word = " ".join(word) self.cache[token] = word return word def encode(self, text): bpe_tokens = [] for token in self.re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) return bpe_tokens def decode(self, tokens): - text = ''.join([self.decoder.get(token, token) for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + text = "".join([self.decoder.get(token, token) for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + "utf-8", errors=self.errors + ) return text + def get_encoder(encoder_json_path, vocab_bpe_path): - with open(encoder_json_path, 'r') as f: + with open(encoder_json_path, "r") as f: encoder = json.load(f) - with open(vocab_bpe_path, 'r', encoding="utf-8") as f: + with open(vocab_bpe_path, "r", encoding="utf-8") as f: bpe_data = f.read() - bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] + bpe_merges = [tuple(merge_str.split()) for merge_str in 
bpe_data.split("\n")[1:-1]] return Encoder( encoder=encoder, bpe_merges=bpe_merges, diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py index 16adc45aee..a968fe8857 100644 --- a/fairseq/data/encoders/hf_bert_bpe.py +++ b/fairseq/data/encoders/hf_bert_bpe.py @@ -6,9 +6,8 @@ from fairseq.data.encoders import register_bpe -@register_bpe('bert') +@register_bpe("bert") class BertBPE(object): - @staticmethod def add_args(parser): # fmt: off @@ -24,25 +23,26 @@ def __init__(self, args): from transformers import BertTokenizer except ImportError: raise ImportError( - 'Please install transformers with: pip install transformers' + "Please install transformers with: pip install transformers" ) - if 'bpe_vocab_file' in args: + if "bpe_vocab_file" in args: self.bert_tokenizer = BertTokenizer( - args.bpe_vocab_file, - do_lower_case=not args.bpe_cased + args.bpe_vocab_file, do_lower_case=not args.bpe_cased ) else: - vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased' + vocab_file_name = ( + "bert-base-cased" if args.bpe_cased else "bert-base-uncased" + ) self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) def encode(self, x: str) -> str: - return ' '.join(self.bert_tokenizer.tokenize(x)) + return " ".join(self.bert_tokenizer.tokenize(x)) def decode(self, x: str) -> str: return self.bert_tokenizer.clean_up_tokenization( - self.bert_tokenizer.convert_tokens_to_string(x.split(' ')) + self.bert_tokenizer.convert_tokens_to_string(x.split(" ")) ) def is_beginning_of_word(self, x: str) -> bool: - return not x.startswith('##') + return not x.startswith("##") diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py index 2767df044e..544d408273 100644 --- a/fairseq/data/encoders/hf_byte_bpe.py +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -6,9 +6,8 @@ from fairseq.data.encoders import register_bpe -@register_bpe('hf_byte_bpe') +@register_bpe("hf_byte_bpe") class 
HuggingFaceByteLevelBPE(object): - @staticmethod def add_args(parser): # fmt: off @@ -23,24 +22,22 @@ def __init__(self, args): from tokenizers import ByteLevelBPETokenizer except ImportError: raise ImportError( - 'Please install huggingface/tokenizers with: ' - 'pip install tokenizers' + "Please install huggingface/tokenizers with: " "pip install tokenizers" ) self.bpe = ByteLevelBPETokenizer( args.bpe_vocab, args.bpe_merges, - add_prefix_space=getattr(args, 'bpe_add_prefix_space', False), + add_prefix_space=getattr(args, "bpe_add_prefix_space", False), ) def encode(self, x: str) -> str: - return ' '.join(map(str, self.bpe.encode(x).ids)) + return " ".join(map(str, self.bpe.encode(x).ids)) def decode(self, x: str) -> str: - return self.bpe.decode([ - int(tok) if tok not in {'', ''} else tok - for tok in x.split() - ]) + return self.bpe.decode( + [int(tok) if tok not in {"", ""} else tok for tok in x.split()] + ) def is_beginning_of_word(self, x: str) -> bool: - return self.decode(x).startswith(' ') + return self.decode(x).startswith(" ") diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py index b1e7478b9d..8c24844263 100644 --- a/fairseq/data/encoders/moses_tokenizer.py +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -6,9 +6,8 @@ from fairseq.data.encoders import register_tokenizer -@register_tokenizer('moses') +@register_tokenizer("moses") class MosesTokenizer(object): - @staticmethod def add_args(parser): # fmt: off @@ -25,17 +24,20 @@ def add_args(parser): def __init__(self, args): self.args = args - if getattr(args, 'moses_source_lang', None) is None: - args.moses_source_lang = getattr(args, 'source_lang', 'en') - if getattr(args, 'moses_target_lang', None) is None: - args.moses_target_lang = getattr(args, 'target_lang', 'en') + if getattr(args, "moses_source_lang", None) is None: + args.moses_source_lang = getattr(args, "source_lang", "en") + if getattr(args, "moses_target_lang", None) is None: + 
args.moses_target_lang = getattr(args, "target_lang", "en") try: from sacremoses import MosesTokenizer, MosesDetokenizer + self.tok = MosesTokenizer(args.moses_source_lang) self.detok = MosesDetokenizer(args.moses_target_lang) except ImportError: - raise ImportError('Please install Moses tokenizer with: pip install sacremoses') + raise ImportError( + "Please install Moses tokenizer with: pip install sacremoses" + ) def encode(self, x: str) -> str: return self.tok.tokenize( diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py index 3db8ee5652..3b617e7314 100644 --- a/fairseq/data/encoders/nltk_tokenizer.py +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -6,18 +6,18 @@ from fairseq.data.encoders import register_tokenizer -@register_tokenizer('nltk') +@register_tokenizer("nltk") class NLTKTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): try: from nltk.tokenize import word_tokenize + self.word_tokenize = word_tokenize except ImportError: - raise ImportError('Please install nltk with: pip install nltk') + raise ImportError("Please install nltk with: pip install nltk") def encode(self, x: str) -> str: - return ' '.join(self.word_tokenize(x)) + return " ".join(self.word_tokenize(x)) def decode(self, x: str) -> str: return x diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py index e5ff5db389..b25c6caebe 100644 --- a/fairseq/data/encoders/sentencepiece_bpe.py +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -7,9 +7,8 @@ from fairseq.data.encoders import register_bpe -@register_bpe('sentencepiece') +@register_bpe("sentencepiece") class SentencepieceBPE(object): - @staticmethod def add_args(parser): # fmt: off @@ -21,23 +20,26 @@ def __init__(self, args): sentencepiece_model = file_utils.cached_path(args.sentencepiece_model) try: import sentencepiece as spm + self.sp = spm.SentencePieceProcessor() self.sp.Load(sentencepiece_model) except ImportError: - 
raise ImportError('Please install sentencepiece with: pip install sentencepiece') + raise ImportError( + "Please install sentencepiece with: pip install sentencepiece" + ) def encode(self, x: str) -> str: - return ' '.join(self.sp.EncodeAsPieces(x)) + return " ".join(self.sp.EncodeAsPieces(x)) def decode(self, x: str) -> str: - return x.replace(' ', '').replace('\u2581', ' ').strip() + return x.replace(" ", "").replace("\u2581", " ").strip() def is_beginning_of_word(self, x: str) -> bool: - if x in ['', '', '', '']: + if x in ["", "", "", ""]: # special elements are always considered beginnings # HACK: this logic is already present in fairseq/tasks/masked_lm.py # but these special tokens are also contained in the sentencepiece # vocabulary which causes duplicate special tokens. This hack makes # sure that they are all taken into account. return True - return x.startswith('\u2581') + return x.startswith("\u2581") diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py index 670001a8e8..3bc7ce4958 100644 --- a/fairseq/data/encoders/space_tokenizer.py +++ b/fairseq/data/encoders/space_tokenizer.py @@ -8,14 +8,13 @@ from fairseq.data.encoders import register_tokenizer -@register_tokenizer('space') +@register_tokenizer("space") class SpaceTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): self.space_tok = re.compile(r"\s+") def encode(self, x: str) -> str: - return self.space_tok.sub(' ', x) + return self.space_tok.sub(" ", x) def decode(self, x: str) -> str: return x diff --git a/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq/data/encoders/subword_nmt_bpe.py index 78f19b43ea..e85f99af39 100644 --- a/fairseq/data/encoders/subword_nmt_bpe.py +++ b/fairseq/data/encoders/subword_nmt_bpe.py @@ -7,9 +7,8 @@ from fairseq.data.encoders import register_bpe -@register_bpe('subword_nmt') +@register_bpe("subword_nmt") class SubwordNMTBPE(object): - @staticmethod def add_args(parser): # fmt: off @@ -21,15 +20,20 
@@ def add_args(parser): def __init__(self, args): if args.bpe_codes is None: - raise ValueError('--bpe-codes is required for --bpe=subword_nmt') + raise ValueError("--bpe-codes is required for --bpe=subword_nmt") codes = file_utils.cached_path(args.bpe_codes) try: from subword_nmt import apply_bpe + bpe_parser = apply_bpe.create_parser() - bpe_args = bpe_parser.parse_args([ - '--codes', codes, - '--separator', args.bpe_separator, - ]) + bpe_args = bpe_parser.parse_args( + [ + "--codes", + codes, + "--separator", + args.bpe_separator, + ] + ) self.bpe = apply_bpe.BPE( bpe_args.codes, bpe_args.merges, @@ -37,12 +41,14 @@ def __init__(self, args): None, bpe_args.glossaries, ) - self.bpe_symbol = bpe_args.separator + ' ' + self.bpe_symbol = bpe_args.separator + " " except ImportError: - raise ImportError('Please install subword_nmt with: pip install subword-nmt') + raise ImportError( + "Please install subword_nmt with: pip install subword-nmt" + ) def encode(self, x: str) -> str: return self.bpe.process_line(x) def decode(self, x: str) -> str: - return (x + ' ').replace(self.bpe_symbol, '').rstrip() + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq/data/encoders/utils.py b/fairseq/data/encoders/utils.py index a0e491c143..d93eb532ef 100644 --- a/fairseq/data/encoders/utils.py +++ b/fairseq/data/encoders/utils.py @@ -10,19 +10,21 @@ def get_whole_word_mask(args, dictionary): bpe = encoders.build_bpe(args) if bpe is not None: + def is_beginning_of_word(i): if i < dictionary.nspecial: # special elements are always considered beginnings return True tok = dictionary[i] - if tok.startswith('madeupword'): + if tok.startswith("madeupword"): return True try: return bpe.is_beginning_of_word(tok) except ValueError: return True - mask_whole_words = torch.ByteTensor(list( - map(is_beginning_of_word, range(len(dictionary))) - )) + + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(dictionary)))) + ) return mask_whole_words 
return None diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index caaef8f713..ed08c1ba20 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -5,7 +5,6 @@ import numpy as np import torch.utils.data - from fairseq.data import data_utils @@ -112,7 +111,7 @@ def batch_by_size( def adjust_bsz(bsz, num_tokens): if bsz is None: - assert max_tokens is not None, 'Must specify --max-tokens' + assert max_tokens is not None, "Must specify --max-tokens" bsz = max_tokens // num_tokens if max_sentences is not None: bsz = min(bsz, max_sentences) @@ -120,13 +119,15 @@ def adjust_bsz(bsz, num_tokens): bsz >= required_batch_size_multiple and bsz % required_batch_size_multiple != 0 ): - bsz -= (bsz % required_batch_size_multiple) + bsz -= bsz % required_batch_size_multiple return bsz - fixed_shapes = np.array([ - [adjust_bsz(bsz, num_tokens), num_tokens] - for (bsz, num_tokens) in fixed_shapes - ]) + fixed_shapes = np.array( + [ + [adjust_bsz(bsz, num_tokens), num_tokens] + for (bsz, num_tokens) in fixed_shapes + ] + ) return data_utils.batch_by_size( indices, @@ -154,16 +155,24 @@ def filter_indices_by_size(self, indices, max_sizes): list: list of removed indices """ if isinstance(max_sizes, float) or isinstance(max_sizes, int): - if hasattr(self, 'sizes') and isinstance(self.sizes, np.ndarray): + if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray): ignored = indices[self.sizes[indices] > max_sizes].tolist() indices = indices[self.sizes[indices] <= max_sizes] - elif hasattr(self, 'sizes') and isinstance(self.sizes, list) and len(self.sizes) == 1: + elif ( + hasattr(self, "sizes") + and isinstance(self.sizes, list) + and len(self.sizes) == 1 + ): ignored = indices[self.sizes[0][indices] > max_sizes].tolist() indices = indices[self.sizes[0][indices] <= max_sizes] else: - indices, ignored = data_utils._filter_by_size_dynamic(indices, self.size, max_sizes) + indices, ignored = 
data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) else: - indices, ignored = data_utils._filter_by_size_dynamic(indices, self.size, max_sizes) + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) return indices, ignored @property diff --git a/fairseq/data/id_dataset.py b/fairseq/data/id_dataset.py index 6a73ba1ff7..3e4d7969cf 100644 --- a/fairseq/data/id_dataset.py +++ b/fairseq/data/id_dataset.py @@ -9,7 +9,6 @@ class IdDataset(FairseqDataset): - def __getitem__(self, index): return index diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 55bf0ca585..3efecab3a6 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -3,18 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from functools import lru_cache import os import shutil import struct +from functools import lru_cache import numpy as np import torch - -from . import FairseqDataset from fairseq.data.fasta_dataset import FastaDataset from fairseq.file_io import PathManager +from . 
import FairseqDataset + def __best_fitting_dtype(vocab_size=None): if vocab_size is not None and vocab_size < 65500: @@ -24,56 +24,59 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ['raw', 'lazy', 'cached', 'mmap', 'fasta'] + return ["raw", "lazy", "cached", "mmap", "fasta"] def infer_dataset_impl(path): if IndexedRawTextDataset.exists(path): - return 'raw' + return "raw" elif IndexedDataset.exists(path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: - return 'cached' + return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: - return 'mmap' + return "mmap" else: return None elif FastaDataset.exists(path): - return 'fasta' + return "fasta" else: return None def make_builder(out_file, impl, vocab_size=None): - if impl == 'mmap': - return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) - elif impl == 'fasta': + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) + elif impl == "fasta": raise NotImplementedError else: return IndexedDatasetBuilder(out_file) def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None): - if impl == 'raw' and IndexedRawTextDataset.exists(path): + if impl == "raw" and IndexedRawTextDataset.exists(path): assert dictionary is not None return IndexedRawTextDataset(path, dictionary) - elif impl == 'lazy' and IndexedDataset.exists(path): + elif impl == "lazy" and IndexedDataset.exists(path): return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing) - elif impl == 'cached' and IndexedDataset.exists(path): + elif impl == "cached" and IndexedDataset.exists(path): return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing) - elif impl == 'mmap' and MMapIndexedDataset.exists(path): + elif impl == "mmap" and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path) - elif impl 
== 'fasta' and FastaDataset.exists(path): + elif impl == "fasta" and FastaDataset.exists(path): from fairseq.data.fasta_dataset import EncodedFastaDataset + return EncodedFastaDataset(path, dictionary) return None def dataset_exists(path, impl): - if impl == 'raw': + if impl == "raw": return IndexedRawTextDataset.exists(path) - elif impl == 'mmap': + elif impl == "mmap": return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) @@ -97,7 +100,7 @@ def write_longs(f, a): 5: np.int64, 6: np.float, 7: np.double, - 8: np.uint16 + 8: np.uint16, } @@ -109,16 +112,17 @@ def code(dtype): def index_file_path(prefix_path): - return prefix_path + '.idx' + return prefix_path + ".idx" def data_file_path(prefix_path): - return prefix_path + '.bin' + return prefix_path + ".bin" class IndexedDataset(FairseqDataset): """Loader for TorchNet IndexedDataset""" - _HDR_MAGIC = b'TNTIDX\x00\x00' + + _HDR_MAGIC = b"TNTIDX\x00\x00" def __init__(self, path, fix_lua_indexing=False): super().__init__() @@ -128,27 +132,27 @@ def __init__(self, path, fix_lua_indexing=False): self.read_index(path) def read_index(self, path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( - 'Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.' + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." 
) version = f.read(8) - assert struct.unpack('= self._len: - raise IndexError('index out of range') + raise IndexError("index out of range") def __del__(self): if self.data_file: @@ -159,7 +163,7 @@ def __getitem__(self, i): if not self.data_file: self.read_data(self.path) self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -179,8 +183,8 @@ def size(self, index): @staticmethod def exists(path): - return ( - PathManager.exists(index_file_path(path)) and PathManager.exists(data_file_path(path)) + return PathManager.exists(index_file_path(path)) and PathManager.exists( + data_file_path(path) ) @property @@ -189,7 +193,6 @@ def supports_prefetch(self): class IndexedCachedDataset(IndexedDataset): - def __init__(self, path, fix_lua_indexing=False): super().__init__(path, fix_lua_indexing=fix_lua_indexing) self.cache = None @@ -214,7 +217,7 @@ def prefetch(self, indices): for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] - a = self.cache[ptx: ptx + size] + a = self.cache[ptx : ptx + size] self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size @@ -226,10 +229,10 @@ def prefetch(self, indices): @lru_cache(maxsize=8) def __getitem__(self, i): self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] - np.copyto(a, self.cache[ptx: ptx + a.size]) + np.copyto(a, self.cache[ptx : ptx + a.size]) item = torch.from_numpy(a).long() if self.fix_lua_indexing: item -= 1 # subtract 1 for 0-based indexing @@ -250,12 +253,14 @@ def __init__(self, path, dictionary, 
append_eos=True, reverse_order=False): self.size = len(self.tokens_list) def read_data(self, path, dictionary): - with open(path, 'r', encoding='utf-8') as f: + with open(path, "r", encoding="utf-8") as f: for line in f: - self.lines.append(line.strip('\n')) + self.lines.append(line.strip("\n")) tokens = dictionary.encode_line( - line, add_if_not_exist=False, - append_eos=self.append_eos, reverse_order=self.reverse_order, + line, + add_if_not_exist=False, + append_eos=self.append_eos, + reverse_order=self.reverse_order, ).long() self.tokens_list.append(tokens) self.sizes.append(len(tokens)) @@ -263,7 +268,7 @@ def read_data(self, path, dictionary): def check_index(self, i): if i < 0 or i >= self.size: - raise IndexError('index out of range') + raise IndexError("index out of range") @lru_cache(maxsize=8) def __getitem__(self, i): @@ -299,11 +304,11 @@ class IndexedDatasetBuilder(object): np.int32: 4, np.int64: 8, np.float: 4, - np.double: 8 + np.double: 8, } def __init__(self, out_file, dtype=np.int32): - self.out_file = open(out_file, 'wb') + self.out_file = open(out_file, "wb") self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] @@ -330,7 +335,7 @@ def merge_file_(self, another_file): for dim_offset in index.dim_offsets[1:]: self.dim_offsets.append(begin + dim_offset) - with open(data_file_path(another_file), 'rb') as f: + with open(data_file_path(another_file), "rb") as f: while True: data = f.read(1024) if data: @@ -340,11 +345,11 @@ def merge_file_(self, another_file): def finalize(self, index_file): self.out_file.close() - index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack('= self.total: raise RuntimeError( - 'Mismatch between actual and expected iterable length. ' - 'Please report this to the fairseq developers.' + "Mismatch between actual and expected iterable length. " + "Please report this to the fairseq developers." 
) self.n += 1 yield x @@ -138,7 +137,11 @@ def load_state_dict(self, state_dict): class StreamingEpochBatchIterator(EpochBatchIterating): def __init__( - self, dataset, epoch=1, num_shards=1, shard_id=0, + self, + dataset, + epoch=1, + num_shards=1, + shard_id=0, ): assert isinstance(dataset, torch.utils.data.IterableDataset) self.dataset = dataset @@ -178,11 +181,11 @@ def iterations_in_epoch(self) -> int: def state_dict(self): return { - 'epoch': self.epoch, + "epoch": self.epoch, } def load_state_dict(self, state_dict): - self.epoch = state_dict['epoch'] + self.epoch = state_dict["epoch"] class EpochBatchIterator(EpochBatchIterating): @@ -222,14 +225,25 @@ class EpochBatchIterator(EpochBatchIterating): """ def __init__( - self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, - num_workers=0, epoch=1, buffer_size=0, timeout=0, + self, + dataset, + collate_fn, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + buffer_size=0, + timeout=0, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset self.collate_fn = collate_fn self.batch_sampler = batch_sampler - self._frozen_batches = tuple(batch_sampler) if not callable(batch_sampler) else None + self._frozen_batches = ( + tuple(batch_sampler) if not callable(batch_sampler) else None + ) self.seed = seed self.num_shards = num_shards self.shard_id = shard_id @@ -243,7 +257,7 @@ def __init__( self.shuffle = True self._cur_epoch_itr = None self._next_epoch_itr = None - self._supports_prefetch = getattr(dataset, 'supports_prefetch', False) + self._supports_prefetch = getattr(dataset, "supports_prefetch", False) @property def frozen_batches(self): @@ -303,7 +317,9 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): # reset _frozen_batches to refresh the next epoch self._frozen_batches = None self._cur_epoch_itr = self._get_iterator_for_epoch( - self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, + self.epoch, + 
shuffle, + fix_batches_to_gpus=fix_batches_to_gpus, ) self.shuffle = shuffle return self._cur_epoch_itr @@ -330,22 +346,22 @@ def state_dict(self): epoch = self.epoch iter_in_epoch = self.iterations_in_epoch return { - 'version': 2, - 'epoch': epoch, - 'iterations_in_epoch': iter_in_epoch, - 'shuffle': self.shuffle, + "version": 2, + "epoch": epoch, + "iterations_in_epoch": iter_in_epoch, + "shuffle": self.shuffle, } def load_state_dict(self, state_dict): """Copies the state of the iterator from the given *state_dict*.""" - self.epoch = state_dict['epoch'] - itr_pos = state_dict.get('iterations_in_epoch', 0) - version = state_dict.get('version', 1) + self.epoch = state_dict["epoch"] + itr_pos = state_dict.get("iterations_in_epoch", 0) + version = state_dict.get("version", 1) if itr_pos > 0: # fast-forward epoch iterator self._next_epoch_itr = self._get_iterator_for_epoch( self.epoch, - shuffle=state_dict.get('shuffle', True), + shuffle=state_dict.get("shuffle", True), offset=itr_pos, ) if self._next_epoch_itr is None: @@ -354,15 +370,16 @@ def load_state_dict(self, state_dict): self.epoch += 1 else: raise RuntimeError( - 'Cannot resume training due to dataloader mismatch, please ' - 'report this to the fairseq developers. You can relaunch ' - 'training with `--reset-dataloader` and it should work.' + "Cannot resume training due to dataloader mismatch, please " + "report this to the fairseq developers. You can relaunch " + "training with `--reset-dataloader` and it should work." 
) else: self._next_epoch_itr = None - def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0): - + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): def shuffle_batches(batches, seed): with data_utils.numpy_seed(seed): np.random.shuffle(batches) @@ -374,9 +391,9 @@ def shuffle_batches(batches, seed): if shuffle and not fix_batches_to_gpus: batches = shuffle_batches(list(batches), self.seed + epoch) - batches = list(ShardedIterator( - batches, self.num_shards, self.shard_id, fill_value=[] - )) + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) self.dataset.prefetch([i for s in batches for i in s]) if shuffle and fix_batches_to_gpus: @@ -386,15 +403,15 @@ def shuffle_batches(batches, seed): batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) else: batches = self.frozen_batches - batches = list(ShardedIterator( - batches, self.num_shards, self.shard_id, fill_value=[] - )) + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) if offset > 0 and offset >= len(batches): return None if self.num_workers > 0: - os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning' + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" # Create data loader itr = torch.utils.data.DataLoader( @@ -429,7 +446,7 @@ def __init__(self, iterable, chunk_size): itr = _chunk_iterator(iterable, chunk_size) super().__init__( itr, - start=int(math.ceil(getattr(iterable, 'n', 0) / float(chunk_size))), + start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), total=int(math.ceil(len(iterable) / float(chunk_size))), ) self.chunk_size = chunk_size @@ -462,7 +479,7 @@ class ShardedIterator(CountingIterator): def __init__(self, iterable, num_shards, shard_id, fill_value=None): if shard_id < 0 or shard_id >= num_shards: - raise ValueError('shard_id must be between 0 and num_shards') + 
raise ValueError("shard_id must be between 0 and num_shards") sharded_len = int(math.ceil(len(iterable) / float(num_shards))) itr = map( operator.itemgetter(1), @@ -474,7 +491,7 @@ def __init__(self, iterable, num_shards, shard_id, fill_value=None): ) super().__init__( itr, - start=int(math.ceil(getattr(iterable, 'n', 0) / float(num_shards))), + start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))), total=sharded_len, ) @@ -545,7 +562,10 @@ def __next__(self): # Notify the user if there is a data loading bottleneck if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): if time.time() - self.start_time > 5 * 60: - if self.warning_time is None or time.time() - self.warning_time > 15 * 60: + if ( + self.warning_time is None + or time.time() - self.warning_time > 15 * 60 + ): logger.debug( "Data loading buffer is empty or nearly empty. This may " "indicate a data loading bottleneck, and increasing the " diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 3014354e7c..62e7109b33 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -7,8 +7,7 @@ import numpy as np import torch - -from fairseq.data import data_utils, FairseqDataset +from fairseq.data import FairseqDataset, data_utils logger = logging.getLogger(__name__) @@ -30,7 +29,10 @@ def collate( def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): return data_utils.collate_tokens( [s[key] for s in samples], - pad_idx, eos_idx, left_pad, move_eos_to_beginning, + pad_idx, + eos_idx, + left_pad, + move_eos_to_beginning, pad_to_length=pad_to_length, pad_to_multiple=pad_to_multiple, ) @@ -38,7 +40,10 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): def check_alignment(alignment, src_len, tgt_len): if alignment is None or len(alignment) == 0: return False - if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1: + if ( + 
alignment[:, 0].max().item() >= src_len - 1 + or alignment[:, 1].max().item() >= tgt_len - 1 + ): logger.warning("alignment size mismatch found, skipping alignment!") return False return True @@ -53,78 +58,90 @@ def compute_alignment_weights(alignments): index 3 is repeated twice) """ align_tgt = alignments[:, 1] - _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True) + _, align_tgt_i, align_tgt_c = torch.unique( + align_tgt, return_inverse=True, return_counts=True + ) align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] - return 1. / align_weights.float() + return 1.0 / align_weights.float() - id = torch.LongTensor([s['id'] for s in samples]) + id = torch.LongTensor([s["id"] for s in samples]) src_tokens = merge( - 'source', left_pad=left_pad_source, - pad_to_length=pad_to_length['source'] if pad_to_length is not None else None + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, ) # sort by descending source length - src_lengths = torch.LongTensor([ - s['source'].ne(pad_idx).long().sum() for s in samples - ]) + src_lengths = torch.LongTensor( + [s["source"].ne(pad_idx).long().sum() for s in samples] + ) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) src_tokens = src_tokens.index_select(0, sort_order) prev_output_tokens = None target = None - if samples[0].get('target', None) is not None: + if samples[0].get("target", None) is not None: target = merge( - 'target', left_pad=left_pad_target, - pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, ) target = target.index_select(0, sort_order) - tgt_lengths = torch.LongTensor([ - s['target'].ne(pad_idx).long().sum() for s in samples - ]).index_select(0, sort_order) + tgt_lengths = torch.LongTensor( + 
[s["target"].ne(pad_idx).long().sum() for s in samples] + ).index_select(0, sort_order) ntokens = tgt_lengths.sum().item() - if samples[0].get('prev_output_tokens', None) is not None: - prev_output_tokens = merge('prev_output_tokens', left_pad=left_pad_target) + if samples[0].get("prev_output_tokens", None) is not None: + prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target) elif input_feeding: # we create a shifted version of targets for feeding the # previous output token(s) into the next decoder step prev_output_tokens = merge( - 'target', + "target", left_pad=left_pad_target, move_eos_to_beginning=True, - pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, ) else: ntokens = src_lengths.sum().item() batch = { - 'id': id, - 'nsentences': len(samples), - 'ntokens': ntokens, - 'net_input': { - 'src_tokens': src_tokens, - 'src_lengths': src_lengths, + "id": id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, }, - 'target': target, + "target": target, } if prev_output_tokens is not None: - batch['net_input']['prev_output_tokens'] = prev_output_tokens.index_select(0, sort_order) + batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select( + 0, sort_order + ) - if samples[0].get('alignment', None) is not None: - bsz, tgt_sz = batch['target'].shape - src_sz = batch['net_input']['src_tokens'].shape[1] + if samples[0].get("alignment", None) is not None: + bsz, tgt_sz = batch["target"].shape + src_sz = batch["net_input"]["src_tokens"].shape[1] offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) - offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz) + offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz if left_pad_source: - offsets[:, 0] += (src_sz - src_lengths) + offsets[:, 0] += src_sz - 
src_lengths if left_pad_target: - offsets[:, 1] += (tgt_sz - tgt_lengths) + offsets[:, 1] += tgt_sz - tgt_lengths alignments = [ alignment + offset - for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) - for alignment in [samples[align_idx]['alignment'].view(-1, 2)] + for align_idx, offset, src_len, tgt_len in zip( + sort_order, offsets, src_lengths, tgt_lengths + ) + for alignment in [samples[align_idx]["alignment"].view(-1, 2)] if check_alignment(alignment, src_len, tgt_len) ] @@ -132,8 +149,8 @@ def compute_alignment_weights(alignments): alignments = torch.cat(alignments, dim=0) align_weights = compute_alignment_weights(alignments) - batch['alignments'] = alignments - batch['align_weights'] = align_weights + batch["alignments"] = alignments + batch["align_weights"] = align_weights if samples[0].get("constraints", None) is not None: # Collate the packed constraints across the samples, padding to @@ -142,7 +159,7 @@ def compute_alignment_weights(alignments): max_len = max(lens) constraints = torch.zeros((len(samples), max(lens))).long() for i, sample in enumerate(samples): - constraints[i, 0:lens[i]] = samples[i].get("constraints") + constraints[i, 0 : lens[i]] = samples[i].get("constraints") batch["constraints"] = constraints return batch @@ -188,14 +205,23 @@ class LanguagePairDataset(FairseqDataset): """ def __init__( - self, src, src_sizes, src_dict, - tgt=None, tgt_sizes=None, tgt_dict=None, - left_pad_source=True, left_pad_target=False, - shuffle=True, input_feeding=True, - remove_eos_from_source=False, append_eos_to_target=False, + self, + src, + src_sizes, + src_dict, + tgt=None, + tgt_sizes=None, + tgt_dict=None, + left_pad_source=True, + left_pad_target=False, + shuffle=True, + input_feeding=True, + remove_eos_from_source=False, + append_eos_to_target=False, align_dataset=None, constraints=None, - append_bos=False, eos=None, + append_bos=False, + eos=None, num_buckets=0, src_lang_id=None, tgt_lang_id=None, @@ 
-206,12 +232,18 @@ def __init__( assert src_dict.eos() == tgt_dict.eos() assert src_dict.unk() == tgt_dict.unk() if tgt is not None: - assert len(src) == len(tgt), "Source and target must contain the same number of examples" + assert len(src) == len( + tgt + ), "Source and target must contain the same number of examples" self.src = src self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None - self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) self.src_dict = src_dict self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source @@ -222,14 +254,17 @@ def __init__( self.append_eos_to_target = append_eos_to_target self.align_dataset = align_dataset if self.align_dataset is not None: - assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" + assert ( + self.tgt_sizes is not None + ), "Both source and target needed when alignments are provided" self.constraints = constraints self.append_bos = append_bos - self.eos = (eos if eos is not None else src_dict.eos()) + self.eos = eos if eos is not None else src_dict.eos() self.src_lang_id = src_lang_id self.tgt_lang_id = tgt_lang_id if num_buckets > 0: from fairseq.data import BucketPadLengthDataset + self.src = BucketPadLengthDataset( self.src, sizes=self.src_sizes, @@ -238,7 +273,7 @@ def __init__( left_pad=self.left_pad_source, ) self.src_sizes = self.src.sizes - logger.info('bucketing source lengths: {}'.format(list(self.src.buckets))) + logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) if self.tgt is not None: self.tgt = BucketPadLengthDataset( self.tgt, @@ -248,15 +283,16 @@ def __init__( left_pad=self.left_pad_target, ) self.tgt_sizes = self.tgt.sizes - logger.info('bucketing target lengths: 
{}'.format(list(self.tgt.buckets))) + logger.info( + "bucketing target lengths: {}".format(list(self.tgt.buckets)) + ) # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to BucketPadLengthDataset) num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ - (None, num_tokens) - for num_tokens in np.unique(self.bucketed_num_tokens) + (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None @@ -292,12 +328,12 @@ def __getitem__(self, index): src_item = self.src[index][:-1] example = { - 'id': index, - 'source': src_item, - 'target': tgt_item, + "id": index, + "source": src_item, + "target": tgt_item, } if self.align_dataset is not None: - example['alignment'] = self.align_dataset[index] + example["alignment"] = self.align_dataset[index] if self.constraints is not None: example["constraints"] = self.constraints[index] return example @@ -352,27 +388,33 @@ def collater(self, samples, pad_to_length=None): pad_to_multiple=self.pad_to_multiple, ) if self.src_lang_id is not None or self.tgt_lang_id is not None: - src_tokens = res['net_input']['src_tokens'] + src_tokens = res["net_input"]["src_tokens"] bsz = src_tokens.size(0) if self.src_lang_id is not None: - res['net_input']['src_lang_id'] = torch.LongTensor( - [[self.src_lang_id]] - ).expand(bsz, 1).to(src_tokens) + res["net_input"]["src_lang_id"] = ( + torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens) + ) if self.tgt_lang_id is not None: - res['tgt_lang_id'] = torch.LongTensor( - [[self.tgt_lang_id]] - ).expand(bsz, 1).to(src_tokens) + res["tgt_lang_id"] = ( + torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens) + ) return res def num_tokens(self, index): """Return the number of tokens in a sample. 
This value is used to enforce ``--max-tokens`` during batching.""" - return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) + return max( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" - return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) + return ( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based @@ -384,22 +426,19 @@ def ordered_indices(self): if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: - indices = indices[ - np.argsort(self.tgt_sizes[indices], kind='mergesort') - ] - return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] else: # sort by bucketed_num_tokens, which is: # max(padded_src_len, padded_tgt_len) return indices[ - np.argsort(self.bucketed_num_tokens[indices], kind='mergesort') + np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") ] @property def supports_prefetch(self): - return ( - getattr(self.src, 'supports_prefetch', False) - and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None) + return getattr(self.src, "supports_prefetch", False) and ( + getattr(self.tgt, "supports_prefetch", False) or self.tgt is None ) def prefetch(self, indices): @@ -410,7 +449,7 @@ def prefetch(self, indices): self.align_dataset.prefetch(indices) def filter_indices_by_size(self, indices, max_sizes): - """ Filter a list of sample indices. Remove those that are longer + """Filter a list of sample indices. 
Remove those that are longer than specified in max_sizes. Args: diff --git a/fairseq/data/legacy/__init__.py b/fairseq/data/legacy/__init__.py index 1acaafeb09..9bd5c72b5e 100644 --- a/fairseq/data/legacy/__init__.py +++ b/fairseq/data/legacy/__init__.py @@ -3,13 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary from .block_pair_dataset import BlockPairDataset from .masked_lm_dataset import MaskedLMDataset +from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary + __all__ = [ - 'BertDictionary', - 'BlockPairDataset', - 'MaskedLMDataset', - 'MaskedLMDictionary', + "BertDictionary", + "BlockPairDataset", + "MaskedLMDataset", + "MaskedLMDictionary", ] diff --git a/fairseq/data/legacy/block_pair_dataset.py b/fairseq/data/legacy/block_pair_dataset.py index b9fc814147..ba069b4605 100644 --- a/fairseq/data/legacy/block_pair_dataset.py +++ b/fairseq/data/legacy/block_pair_dataset.py @@ -7,7 +7,6 @@ import numpy as np import torch - from fairseq.data import FairseqDataset diff --git a/fairseq/data/legacy/masked_lm_dataset.py b/fairseq/data/legacy/masked_lm_dataset.py index 953aa85dd4..dd8ea2c60a 100644 --- a/fairseq/data/legacy/masked_lm_dataset.py +++ b/fairseq/data/legacy/masked_lm_dataset.py @@ -4,18 +4,14 @@ # LICENSE file in the root directory of this source tree. 
import math +from typing import Dict, List, Tuple import numpy as np import torch - -from typing import Dict, List, Tuple - -from fairseq.data import FairseqDataset, data_utils - -from fairseq.data import Dictionary +from fairseq.data import Dictionary, FairseqDataset, data_utils +from fairseq.data.concat_dataset import ConcatDataset from fairseq.data.legacy.block_pair_dataset import BlockPairDataset from fairseq.data.token_block_dataset import TokenBlockDataset -from fairseq.data.concat_dataset import ConcatDataset class MaskedLMDataset(FairseqDataset): @@ -55,29 +51,31 @@ class MaskedLMDataset(FairseqDataset): """ def __init__( - self, - dataset: FairseqDataset, - sizes: np.ndarray, - vocab: Dictionary, - pad_idx: int, - mask_idx: int, - classif_token_idx: int, - sep_token_idx: int, - seed: int = 1, - shuffle: bool = True, - has_pairs: bool = True, - segment_id: int = 0, - masking_ratio: float = 0.15, - masking_prob: float = 0.8, - random_token_prob: float = 0.1 + self, + dataset: FairseqDataset, + sizes: np.ndarray, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + classif_token_idx: int, + sep_token_idx: int, + seed: int = 1, + shuffle: bool = True, + has_pairs: bool = True, + segment_id: int = 0, + masking_ratio: float = 0.15, + masking_prob: float = 0.8, + random_token_prob: float = 0.1, ): # Make sure the input datasets are the ones supported assert ( - isinstance(dataset, TokenBlockDataset) or - isinstance(dataset, BlockPairDataset) or - isinstance(dataset, ConcatDataset) - ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or " \ - "ConcatDataset" + isinstance(dataset, TokenBlockDataset) + or isinstance(dataset, BlockPairDataset) + or isinstance(dataset, ConcatDataset) + ), ( + "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or " + "ConcatDataset" + ) self.dataset = dataset self.sizes = np.array(sizes) @@ -99,10 +97,7 @@ def __init__( if not has_pairs: self.sizes = self.sizes + 1 - def __getitem__( - self, - index: 
int - ): + def __getitem__(self, index: int): # if has_pairs, then expect 2 blocks and a sentence target if self.has_pairs: (block_one, block_two, sentence_target) = self.dataset[index] @@ -120,11 +115,11 @@ def __len__(self): return len(self.dataset) def _mask_block( - self, - sentence: np.ndarray, - mask_idx: int, - pad_idx: int, - dictionary_token_range: Tuple, + self, + sentence: np.ndarray, + mask_idx: int, + pad_idx: int, + dictionary_token_range: Tuple, ): """ Mask tokens for Masked Language Model training @@ -166,22 +161,15 @@ def _mask_block( # masking_prob + random_token_prob (Eg: 0.9) elif rand < (self.masking_prob + self.random_token_prob): # sample random token from dictionary - masked_sent[i] = ( - np.random.randint( - dictionary_token_range[0], dictionary_token_range[1] - ) + masked_sent[i] = np.random.randint( + dictionary_token_range[0], dictionary_token_range[1] ) else: target[i] = pad_idx return masked_sent, target - def _collate( - self, - samples: List[Dict], - pad_idx: int, - eos_idx: int - ): + def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int): """ Does the heavy lifting for creating a batch from the input list of examples. The logic is as follows: @@ -215,12 +203,13 @@ def _collate( # mask according to specified probabilities. 
masked_blk_one, masked_tgt_one = self._mask_block( - s["block_one"], self.mask_idx, self.pad_idx, token_range, + s["block_one"], + self.mask_idx, + self.pad_idx, + token_range, ) - tokens = np.concatenate([ - [self.classif_token_idx], masked_blk_one - ]) + tokens = np.concatenate([[self.classif_token_idx], masked_blk_one]) targets = np.concatenate([[self.pad_idx], masked_tgt_one]) segments = np.ones(len(tokens)) * self.segment_id @@ -232,9 +221,9 @@ def _collate( targets_one = np.concatenate([targets, [self.pad_idx]]) masked_blk_two, masked_tgt_two = self._mask_block( - s["block_two"], self.mask_idx, self.pad_idx, token_range) - tokens_two = np.concatenate( - [masked_blk_two, [self.sep_token_idx]]) + s["block_two"], self.mask_idx, self.pad_idx, token_range + ) + tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]]) targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]]) # block + 1 sep + 1 special (CLS) @@ -254,6 +243,7 @@ def merge(key): return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx, left_pad=False ) + return { "id": torch.LongTensor([s["id"] for s in samples]), "ntokens": sum(len(s["source"]) for s in samples), @@ -262,16 +252,13 @@ def merge(key): "segment_labels": merge("segment_labels"), }, "lm_target": merge("lm_target"), - "sentence_target": torch.LongTensor( - [s["sentence_target"] for s in samples] - ) if self.has_pairs else None, + "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples]) + if self.has_pairs + else None, "nsentences": len(samples), } - def collater( - self, - samples: List[Dict] - ): + def collater(self, samples: List[Dict]): """Merge a list of samples to form a mini-batch. Args: @@ -282,20 +269,14 @@ def collater( """ return self._collate(samples, self.vocab.pad(), self.vocab.eos()) - def num_tokens( - self, - index: int - ): + def num_tokens(self, index: int): """ Return the number of tokens in a sample. This value is used to enforce max-tokens during batching. 
""" return self.sizes[index] - def size( - self, - index: int - ): + def size(self, index: int): """ Return an example's size as a float or tuple. This value is used when filtering a dataset with max-positions. diff --git a/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq/data/legacy/masked_lm_dictionary.py index bff4bcb5ec..dee88f7a3e 100644 --- a/fairseq/data/legacy/masked_lm_dictionary.py +++ b/fairseq/data/legacy/masked_lm_dictionary.py @@ -11,12 +11,13 @@ class MaskedLMDictionary(Dictionary): Dictionary for Masked Language Modelling tasks. This extends Dictionary by adding the mask symbol. """ + def __init__( self, - pad='<pad>', - eos='</s>', - unk='<unk>', - mask='<mask>', + pad="<pad>", + eos="</s>", + unk="<unk>", + mask="<mask>", ): super().__init__(pad=pad, eos=eos, unk=unk) self.mask_word = mask @@ -33,14 +34,15 @@ class BertDictionary(MaskedLMDictionary): Dictionary for BERT task. This extends MaskedLMDictionary by adding support for cls and sep symbols. """ + def __init__( self, - pad='<pad>', - eos='</s>', - unk='<unk>', - mask='<mask>', - cls='<cls>', - sep='<sep>' + pad="<pad>", + eos="</s>", + unk="<unk>", + mask="<mask>", + cls="<cls>", + sep="<sep>", ): super().__init__(pad=pad, eos=eos, unk=unk, mask=mask) self.cls_word = cls diff --git a/fairseq/data/list_dataset.py b/fairseq/data/list_dataset.py index b96bba3437..12f00aa436 100644 --- a/fairseq/data/list_dataset.py +++ b/fairseq/data/list_dataset.py @@ -7,7 +7,6 @@ class ListDataset(BaseWrapperDataset): - def __init__(self, dataset, sizes=None): super().__init__(dataset) self._sizes = sizes diff --git a/fairseq/data/lm_context_window_dataset.py b/fairseq/data/lm_context_window_dataset.py index 17ba08bc7f..29ad887b7d 100644 --- a/fairseq/data/lm_context_window_dataset.py +++ b/fairseq/data/lm_context_window_dataset.py @@ -5,7 +5,6 @@ import numpy as np import torch - from fairseq.data.monolingual_dataset import MonolingualDataset from .
import FairseqDataset @@ -35,11 +34,11 @@ def collater(self, samples): pad = self.pad_idx max_sample_len = self.tokens_per_sample + self.context_window - bsz, tsz = sample['net_input']['src_tokens'].shape + bsz, tsz = sample["net_input"]["src_tokens"].shape start_idxs = [0] * bsz - toks = sample['net_input']['src_tokens'] - lengths = sample['net_input']['src_lengths'] - tgt = sample['target'] + toks = sample["net_input"]["src_tokens"] + lengths = sample["net_input"]["src_lengths"] + tgt = sample["target"] new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64) new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64) sample_lens = toks.ne(pad).long().sum(dim=1).cpu() @@ -50,13 +49,15 @@ def collater(self, samples): self.prev_tokens = self.prev_tokens[extra:] pads = np.full(self.context_window - len(self.prev_tokens), pad) new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads]) - new_tgt[i, len(self.prev_tokens):len(self.prev_tokens) + len(tgt[i])] = tgt[i] + new_tgt[ + i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i]) + ] = tgt[i] start_idxs[i] = len(self.prev_tokens) lengths[i] += len(self.prev_tokens) - self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window:] - sample['net_input']['src_tokens'] = torch.from_numpy(new_toks) - sample['target'] = torch.from_numpy(new_tgt) - sample['start_indices'] = start_idxs + self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :] + sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks) + sample["target"] = torch.from_numpy(new_tgt) + sample["start_indices"] = start_idxs return sample @@ -72,7 +73,7 @@ def ordered_indices(self): @property def supports_prefetch(self): - return getattr(self.dataset, 'supports_prefetch', False) + return getattr(self.dataset, "supports_prefetch", False) def prefetch(self, indices): return self.dataset.prefetch(indices) diff --git a/fairseq/data/lru_cache_dataset.py b/fairseq/data/lru_cache_dataset.py 
index 833a2c75cb..a7854ac170 100644 --- a/fairseq/data/lru_cache_dataset.py +++ b/fairseq/data/lru_cache_dataset.py @@ -9,7 +9,6 @@ class LRUCacheDataset(BaseWrapperDataset): - def __init__(self, dataset, token=None): super().__init__(dataset) diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 31f5459307..8ea86245f7 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -7,8 +7,7 @@ import numpy as np import torch - -from fairseq.data import data_utils, Dictionary +from fairseq.data import Dictionary, data_utils from . import BaseWrapperDataset, LRUCacheDataset @@ -86,7 +85,7 @@ def __init__( weights = np.array(self.vocab.count) else: weights = np.ones(len(self.vocab)) - weights[:self.vocab.nspecial] = 0 + weights[: self.vocab.nspecial] = 0 self.weights = weights / weights.sum() self.epoch = 0 @@ -105,10 +104,11 @@ def __getitem__(self, index: int): item = self.dataset[index] sz = len(item) - assert self.mask_idx not in item, \ - 'Dataset contains mask_idx (={}), this is not expected!'.format( - self.mask_idx, - ) + assert ( + self.mask_idx not in item + ), "Dataset contains mask_idx (={}), this is not expected!".format( + self.mask_idx, + ) if self.mask_whole_words is not None: word_begins_mask = self.mask_whole_words.gather(0, item) @@ -122,7 +122,8 @@ def __getitem__(self, index: int): mask = np.full(sz, False) num_mask = int( # add a random number for probabilistic rounding - self.mask_prob * sz + np.random.rand() + self.mask_prob * sz + + np.random.rand() ) mask[np.random.choice(sz, num_mask, replace=False)] = True diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index 76c3772374..ec73f1fda8 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -6,7 +6,7 @@ import numpy as np import torch -from . import data_utils, FairseqDataset +from . 
import FairseqDataset, data_utils def collate(samples, pad_idx, eos_idx): @@ -17,33 +17,39 @@ def merge(key, is_list=False): if is_list: res = [] for i in range(len(samples[0][key])): - res.append(data_utils.collate_tokens( - [s[key][i] for s in samples], pad_idx, eos_idx, left_pad=False, - )) + res.append( + data_utils.collate_tokens( + [s[key][i] for s in samples], + pad_idx, + eos_idx, + left_pad=False, + ) + ) return res else: return data_utils.collate_tokens( - [s[key] for s in samples], pad_idx, eos_idx, left_pad=False, + [s[key] for s in samples], + pad_idx, + eos_idx, + left_pad=False, ) - src_tokens = merge('source') - if samples[0]['target'] is not None: - is_target_list = isinstance(samples[0]['target'], list) - target = merge('target', is_target_list) + src_tokens = merge("source") + if samples[0]["target"] is not None: + is_target_list = isinstance(samples[0]["target"], list) + target = merge("target", is_target_list) else: target = src_tokens return { - 'id': torch.LongTensor([s['id'] for s in samples]), - 'nsentences': len(samples), - 'ntokens': sum(len(s['source']) for s in samples), - 'net_input': { - 'src_tokens': src_tokens, - 'src_lengths': torch.LongTensor([ - s['source'].numel() for s in samples - ]), + "id": torch.LongTensor([s["id"] for s in samples]), + "nsentences": len(samples), + "ntokens": sum(len(s["source"]) for s in samples), + "net_input": { + "src_tokens": src_tokens, + "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]), }, - 'target': target, + "target": target, } @@ -59,8 +65,17 @@ class MonolingualDataset(FairseqDataset): (default: True). 
""" - def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle, - targets=None, add_bos_token=False): + def __init__( + self, + dataset, + sizes, + src_vocab, + tgt_vocab, + add_eos_for_other_targets, + shuffle, + targets=None, + add_bos_token=False, + ): self.dataset = dataset self.sizes = np.array(sizes) self.vocab = src_vocab @@ -69,8 +84,9 @@ def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targe self.shuffle = shuffle self.add_bos_token = add_bos_token - assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \ - "targets must be none or one of 'self', 'future', 'past'" + assert targets is None or all( + t in {"self", "future", "past"} for t in targets + ), "targets must be none or one of 'self', 'future', 'past'" if targets is not None and len(targets) == 0: targets = None self.targets = targets @@ -86,12 +102,14 @@ def __getitem__(self, index): # Right-to-left language models should condition on *source* and # predict *past_target*. 
source, future_target, past_target = self.dataset[index] - source, target = self._make_source_target(source, future_target, past_target) + source, target = self._make_source_target( + source, future_target, past_target + ) else: source = self.dataset[index] target = None source, target = self._maybe_add_bos(source, target) - return {'id': index, 'source': source, 'target': target} + return {"id": index, "source": source, "target": target} def __len__(self): return len(self.dataset) @@ -100,27 +118,38 @@ def _make_source_target(self, source, future_target, past_target): if self.targets is not None: target = [] - if self.add_eos_for_other_targets and (('self' in self.targets) or ('past' in self.targets)) \ - and source[-1] != self.vocab.eos(): + if ( + self.add_eos_for_other_targets + and (("self" in self.targets) or ("past" in self.targets)) + and source[-1] != self.vocab.eos() + ): # append eos at the end of source source = torch.cat([source, source.new([self.vocab.eos()])]) - if 'future' in self.targets: - future_target = torch.cat([future_target, future_target.new([self.vocab.pad()])]) - if 'past' in self.targets: + if "future" in self.targets: + future_target = torch.cat( + [future_target, future_target.new([self.vocab.pad()])] + ) + if "past" in self.targets: # first token is before the start of sentence which is only used in "none" break mode when # add_eos_for_other_targets is False - past_target = torch.cat([past_target.new([self.vocab.pad()]), past_target[1:], source[-2, None]]) + past_target = torch.cat( + [ + past_target.new([self.vocab.pad()]), + past_target[1:], + source[-2, None], + ] + ) for t in self.targets: - if t == 'self': + if t == "self": target.append(source) - elif t == 'future': + elif t == "future": target.append(future_target) - elif t == 'past': + elif t == "past": target.append(past_target) else: - raise Exception('invalid target ' + t) + raise Exception("invalid target " + t) if len(target) == 1: target = target[0] @@ -138,6 +167,7 @@ 
def _maybe_add_bos(self, source, target): def _filter_vocab(self, target): if len(self.tgt_vocab) != len(self.vocab): + def _filter(target): mask = target.ge(len(self.tgt_vocab)) if mask.any(): @@ -194,7 +224,7 @@ def ordered_indices(self): @property def supports_prefetch(self): - return getattr(self.dataset, 'supports_prefetch', False) + return getattr(self.dataset, "supports_prefetch", False) def prefetch(self, indices): self.dataset.prefetch(indices) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 7ce269a4df..8c14f4e3ad 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -6,9 +6,9 @@ import itertools import json import logging +import math import os from collections import OrderedDict, defaultdict -import math from fairseq import utils from fairseq.data import ( @@ -197,7 +197,7 @@ def add_args(parser): ) parser.add_argument( "--fixed-dictionary", - help='Fixed dictionary to use with model path', + help="Fixed dictionary to use with model path", default=None, type=str, ) @@ -266,7 +266,9 @@ def load_langs(cls, args, **kwargs): langs = sorted(langs) logger.info(f"inferred language list: {langs}") elif args.lang_dict: - with open(PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8") as f: + with open( + PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8" + ) as f: langs = [lang.strip() for lang in f.readlines() if lang.strip()] logger.info( f"loaded language list from {args.lang_dict} as they are ordered in file" @@ -292,7 +294,9 @@ def estimate_global_pass_epoch(self, epoch): if self.args.virtual_epoch_size is None or self.args.virtual_data_size is None: return None # one epoch more for remaining data in each shard - virtual_epochs_per_shard = math.ceil(self.args.virtual_data_size / self.args.virtual_epoch_size) + virtual_epochs_per_shard = math.ceil( + 
self.args.virtual_data_size / self.args.virtual_epoch_size + ) # note that fairseq epoch / shard_epoch starts from 1 shard_epoch = (epoch - 1) // virtual_epochs_per_shard + 1 return shard_epoch @@ -809,7 +813,7 @@ def _get_shard_num_dict(cls, split, paths): for f in files: if f.startswith(split) and f.endswith(".idx"): # idx files of the form "{split}.{src}-{tgt}.{lang}.idx" - direction = f.split('.')[-3] + direction = f.split(".")[-3] directions.add(direction) for direction in directions: shards[direction] += 1 diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index 5270675124..3f544b099f 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -3,25 +3,25 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import List -from enum import Enum -from collections import OrderedDict -from collections import defaultdict -from bisect import bisect_right +import datetime import hashlib import logging -import datetime import time +from bisect import bisect_right +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import List import numpy as np import torch - from fairseq import distributed_utils from fairseq.data import FairseqDataset, data_utils def get_time_gap(s, e): - return (datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)).__str__() + return ( + datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s) + ).__str__() logger = logging.getLogger(__name__) @@ -80,7 +80,7 @@ def __init__( eval_key=None, collate_format=CollateFormat.single, virtual_size=default_virtual_size_func, - split='', + split="", shared_collater=False, shuffle=True, ): @@ -126,9 +126,7 @@ def _clean_if_not_none(self, var_list): del v def _reset_cached_properties(self): - self._clean_if_not_none([ - self._sizes, 
self._cur_indices - ]) + self._clean_if_not_none([self._sizes, self._cur_indices]) self._sizes = None self._cur_indices = None @@ -142,10 +140,14 @@ def setup_sampling(self, sample_ratios, virtual_size): if not isinstance(sample_ratios, np.ndarray): sample_ratios = np.array(sample_ratios) self.sample_ratios = sample_ratios - virtual_size = default_virtual_size_func if virtual_size is None else virtual_size + virtual_size = ( + default_virtual_size_func if virtual_size is None else virtual_size + ) self.virtual_size = ( - virtual_size(self.datasets, self.sample_ratios) if callable(virtual_size) - else virtual_size) + virtual_size(self.datasets, self.sample_ratios) + if callable(virtual_size) + else virtual_size + ) def adjust_sampling(self, epoch, sampling_ratios, virtual_size): if sampling_ratios is not None: @@ -166,10 +168,12 @@ def _sync_sample_ratios(self, ratios): return ret def random_choice_in_dataset(self, rng, dataset, choice_size): - if hasattr(dataset, 'random_choice_in_dataset'): + if hasattr(dataset, "random_choice_in_dataset"): return dataset.random_choice_in_dataset(rng, choice_size) dataset_size = len(dataset) - return rng.choice(dataset_size, choice_size, replace=(choice_size > dataset_size)) + return rng.choice( + dataset_size, choice_size, replace=(choice_size > dataset_size) + ) def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size): def get_counts(sample_ratios): @@ -178,7 +182,9 @@ def get_counts(sample_ratios): assert diff >= 0 # due to round-offs, the size might not match the desired sizes if diff > 0: - dataset_indices = rng.choice(len(sample_ratios), size=diff, p=sample_ratios) + dataset_indices = rng.choice( + len(sample_ratios), size=diff, p=sample_ratios + ) for i in dataset_indices: counts[i] += 1 return counts @@ -189,7 +195,8 @@ def get_in_dataset_indices(datasets, sizes, sample_ratios): # if the desired counts are large, sample with replacement: indices = [ self.random_choice_in_dataset(rng, d, c) - for c, d in 
zip(counts, datasets)] + for c, d in zip(counts, datasets) + ] return indices sizes = [len(d) for d in datasets] @@ -207,8 +214,8 @@ def get_in_dataset_indices(datasets, sizes, sample_ratios): assert cumulative_sizes[-1] == virtual_size if virtual_size < sum(sizes): logger.warning( - f'virtual data size ({virtual_size}) is less than real data size ({sum(sizes)}).' - ' If virtual size << real data size, there could be data coverage issue.' + f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})." + " If virtual size << real data size, there could be data coverage issue." ) in_dataset_indices = np.hstack(in_dataset_indices) return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset @@ -237,26 +244,34 @@ def collater(self, samples, **extra_args): """Merge a list of samples to form a mini-batch.""" if len(samples) == 0: return None - if self.collate_format == 'ordered_dict': + if self.collate_format == "ordered_dict": collect_samples = [[] for _ in range(len(self.datasets))] for (i, sample) in samples: collect_samples[i].append(sample) - batch = OrderedDict([ - (self.keys[i], dataset.collater(collect_samples[i])) - for i, (key, dataset) in enumerate(zip(self.keys, self.datasets)) - if len(collect_samples[i]) > 0 - ]) - elif self.shared_collater: - batch = self.datasets[0].collater( - [s for _, s in samples] + batch = OrderedDict( + [ + (self.keys[i], dataset.collater(collect_samples[i])) + for i, (key, dataset) in enumerate(zip(self.keys, self.datasets)) + if len(collect_samples[i]) > 0 + ] ) + elif self.shared_collater: + batch = self.datasets[0].collater([s for _, s in samples]) else: samples_dict = defaultdict(list) - pad_to_length = defaultdict(int) if 'pad_to_length' not in extra_args else extra_args['pad_to_length'] + pad_to_length = ( + defaultdict(int) + if "pad_to_length" not in extra_args + else extra_args["pad_to_length"] + ) for ds_idx, s in samples: - pad_to_length['source'] = max(pad_to_length['source'], 
s['source'].size(0)) - if s['target'] is not None: - pad_to_length['target'] = max(pad_to_length['target'], s['target'].size(0)) + pad_to_length["source"] = max( + pad_to_length["source"], s["source"].size(0) + ) + if s["target"] is not None: + pad_to_length["target"] = max( + pad_to_length["target"], s["target"].size(0) + ) samples_dict[ds_idx].append(s) batches = [ self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length) @@ -268,7 +283,9 @@ def straight_data(tensors): batch = torch.cat(tensors, dim=0) return batch - src_lengths = straight_data([b['net_input']['src_lengths'] for b in batches]) + src_lengths = straight_data( + [b["net_input"]["src_lengths"] for b in batches] + ) src_lengths, sort_order = src_lengths.sort(descending=True) def straight_order(tensors): @@ -276,22 +293,31 @@ def straight_order(tensors): return batch.index_select(0, sort_order) batch = { - 'id': straight_order([b['id'] for b in batches]), - 'nsentences': sum(b['nsentences'] for b in batches), - 'ntokens': sum(b['ntokens'] for b in batches), - 'net_input': { - 'src_tokens': straight_order([b['net_input']['src_tokens'] for b in batches]), - 'src_lengths': src_lengths, + "id": straight_order([b["id"] for b in batches]), + "nsentences": sum(b["nsentences"] for b in batches), + "ntokens": sum(b["ntokens"] for b in batches), + "net_input": { + "src_tokens": straight_order( + [b["net_input"]["src_tokens"] for b in batches] + ), + "src_lengths": src_lengths, }, - 'target': straight_order([b['target'] for b in batches]) if batches[0]['target'] is not None else None, + "target": straight_order([b["target"] for b in batches]) + if batches[0]["target"] is not None + else None, } - if 'prev_output_tokens' in batches[0]['net_input']: - batch['net_input']['prev_output_tokens'] = straight_order( - [b['net_input']['prev_output_tokens'] for b in batches]) - if 'src_lang_id' in batches[0]['net_input']: - batch['net_input']['src_lang_id'] = straight_order([b['net_input']['src_lang_id'] for b 
in batches]) - if 'tgt_lang_id' in batches[0]: - batch['tgt_lang_id'] = straight_order([b['tgt_lang_id'] for b in batches]) + if "prev_output_tokens" in batches[0]["net_input"]: + batch["net_input"]["prev_output_tokens"] = straight_order( + [b["net_input"]["prev_output_tokens"] for b in batches] + ) + if "src_lang_id" in batches[0]["net_input"]: + batch["net_input"]["src_lang_id"] = straight_order( + [b["net_input"]["src_lang_id"] for b in batches] + ) + if "tgt_lang_id" in batches[0]: + batch["tgt_lang_id"] = straight_order( + [b["tgt_lang_id"] for b in batches] + ) return batch @property @@ -300,7 +326,9 @@ def sizes(self): return self._sizes start_time = time.time() in_sub_dataset_indices = [ - self._cur_indices[0 if i == 0 else self.cumulated_sizes[i-1]:self.cumulated_sizes[i]] + self._cur_indices[ + 0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i] + ] for i in range(len(self.datasets)) ] sub_dataset_sizes = [ @@ -308,7 +336,7 @@ def sizes(self): for d, indices in zip(self.datasets, in_sub_dataset_indices) ] self._sizes = np.vstack(sub_dataset_sizes) - logger.info(f'sizes() calling time: {get_time_gap(start_time, time.time())}') + logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}") return self._sizes def ordered_indices(self): @@ -319,14 +347,14 @@ def ordered_indices(self): sizes = self.sizes tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None - src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) # sort by target length, then source length if tgt_sizes is not None: - indices = indices[ - np.argsort(tgt_sizes[indices], kind='mergesort') - ] - sort_indices = indices[np.argsort(src_sizes[indices], kind='mergesort')] + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")] return 
sort_indices def prefetch(self, indices): @@ -347,7 +375,7 @@ def set_epoch(self, epoch): # re-enter so return return for d in self.datasets: - if hasattr(d, 'set_epoch'): + if hasattr(d, "set_epoch"): d.set_epoch(epoch) self._cur_epoch = epoch self._establish_virtual_datasets() @@ -362,37 +390,52 @@ def _establish_virtual_datasets(self): # Generate a weighted sample of indices as a function of the # random seed and the current epoch. rng = np.random.RandomState( - [ - int(hashlib.sha1(str(self.__class__.__name__).encode('utf-8')).hexdigest(), 16) % (2 ** 32), - self.seed % (2 ** 32), # global seed - self._cur_epoch, # epoch index, - ] + [ + int( + hashlib.sha1( + str(self.__class__.__name__).encode("utf-8") + ).hexdigest(), + 16, + ) + % (2 ** 32), + self.seed % (2 ** 32), # global seed + self._cur_epoch, # epoch index, + ] + ) + self._clean_if_not_none( + [self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes] ) - self._clean_if_not_none([ - self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes - ]) self._sizes = None indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices( - rng, self.datasets, self.sample_ratios, self.virtual_size) + rng, self.datasets, self.sample_ratios, self.virtual_size + ) self._cur_indices = indices self.cumulated_sizes = cumulated_sizes self.virtual_size_per_dataset = virtual_size_per_dataset raw_sizes = [len(d) for d in self.datasets] sampled_sizes = self.virtual_size_per_dataset - logger.info(f'[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; ' - f'raw total size: {sum(raw_sizes)}') - logger.info(f'[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; ' - f'resampled total size: {sum(sampled_sizes)}') + logger.info( + f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; " + f"raw total size: {sum(raw_sizes)}" + ) + logger.info( + f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; " + f"resampled total size: 
{sum(sampled_sizes)}" + ) if self.sample_ratios is not None: - logger.info(f'[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}') + logger.info( + f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}" + ) else: - logger.info(f'[{self.split}] A concat dataset') - logger.info(f'[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}') + logger.info(f"[{self.split}] A concat dataset") + logger.info( + f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}" + ) def filter_indices_by_size(self, indices, max_sizes): - """ Filter a list of sample indices. Remove those that are longer + """Filter a list of sample indices. Remove those that are longer than specified in max_sizes. Args: @@ -406,6 +449,10 @@ def filter_indices_by_size(self, indices, max_sizes): """ sizes = self.sizes tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None - src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) - return data_utils.filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes) + return data_utils.filter_paired_dataset_indices_by_size( + src_sizes, tgt_sizes, indices, max_sizes + ) diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py index 81ff78f705..17387b2f85 100644 --- a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -4,12 +4,13 @@ # LICENSE file in the root directory of this source tree. 
import hashlib -import math import logging +import math import numpy as np from fairseq.data import SampledMultiDataset -from .sampled_multi_dataset import default_virtual_size_func, CollateFormat + +from .sampled_multi_dataset import CollateFormat, default_virtual_size_func logger = logging.getLogger(__name__) @@ -44,6 +45,7 @@ class SampledMultiEpochDataset(SampledMultiDataset): shard_epoch (int): the real epoch number for shard selection. shuffle (bool): whether or not to shuffle data (default: True). """ + def __init__( self, datasets, @@ -53,7 +55,7 @@ def __init__( eval_key=None, collate_format=CollateFormat.single, virtual_size=default_virtual_size_func, - split='', + split="", virtual_epoch_size=None, shared_collater=False, shard_epoch=1, @@ -79,14 +81,22 @@ def __init__( ) def _setup(self, epoch): - self.virtual_epoch_size = self.virtual_epoch_size if self.virtual_epoch_size is not None else self.virtual_size + self.virtual_epoch_size = ( + self.virtual_epoch_size + if self.virtual_epoch_size is not None + else self.virtual_size + ) if self.virtual_epoch_size > self.virtual_size: - logger.warning(f'virtual epoch size {self.virtual_epoch_size} ' - f'is greater than virtual dataset size {self.virtual_size}') + logger.warning( + f"virtual epoch size {self.virtual_epoch_size} " + f"is greater than virtual dataset size {self.virtual_size}" + ) self.virtual_epoch_size = self.virtual_size self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size) self._current_epoch_start_index = self._get_epoch_start_index(epoch) - logger.info(f'virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}') + logger.info( + f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}" + ) def _map_epoch_index_to_global(self, index): index = self._current_epoch_start_index + index @@ -99,7 +109,8 @@ def sizes(self): return self._epoch_sizes _sizes = super().sizes indices = self._random_global_indices[ - 
self._current_epoch_start_index:self._current_epoch_start_index + len(self) + self._current_epoch_start_index : self._current_epoch_start_index + + len(self) ] self._epoch_sizes = _sizes[indices] # del super()._sizes to save memory @@ -114,7 +125,8 @@ def _get_dataset_and_index(self, index): def __len__(self): return ( self.virtual_epoch_size - if self._current_epoch_start_index + self.virtual_epoch_size < self.virtual_size + if self._current_epoch_start_index + self.virtual_epoch_size + < self.virtual_size else self.virtual_size - self._current_epoch_start_index ) @@ -136,38 +148,52 @@ def _get_epoch_start_index(self, epoch): def _next_global_indices(self, epoch): rng = np.random.RandomState( - [ - int(hashlib.sha1(str(self.__class__.__name__).encode('utf-8')).hexdigest(), 16) % (2 ** 32), - self.seed % (2 ** 32), # global seed - epoch, # epoch index, - ] + [ + int( + hashlib.sha1( + str(self.__class__.__name__).encode("utf-8") + ).hexdigest(), + 16, + ) + % (2 ** 32), + self.seed % (2 ** 32), # global seed + epoch, # epoch index, + ] ) del self._random_global_indices - self._random_global_indices = rng.choice(self.virtual_size, self.virtual_size, replace=False) + self._random_global_indices = rng.choice( + self.virtual_size, self.virtual_size, replace=False + ) if self.load_next_shard is None: self.load_next_shard = False else: # increase shard epoch for next loading self.shard_epoch += 1 self.load_next_shard = True - logger.info('to load next epoch/shard in next load_dataset: ' - f'epoch={epoch}/shard_epoch={self.shard_epoch}') + logger.info( + "to load next epoch/shard in next load_dataset: " + f"epoch={epoch}/shard_epoch={self.shard_epoch}" + ) def _next_virtual_epoch(self, epoch): index = self._get_epoch_start_index(epoch) if index == 0 or self._random_global_indices is None: # need to start from the beginning, # so call super().set_epoch(epoch) to establish the global virtual indices - logger.info('establishing a new set of global virtual indices for ' - 
f'epoch={epoch}/shard_epoch={self.shard_epoch}') + logger.info( + "establishing a new set of global virtual indices for " + f"epoch={epoch}/shard_epoch={self.shard_epoch}" + ) super().set_epoch(epoch) self._next_global_indices(epoch) else: self._cur_epoch = epoch # reset cache sizes and ordered_indices for the epoch after moving to a new epoch - self._clean_if_not_none([ - self._epoch_sizes, - ]) + self._clean_if_not_none( + [ + self._epoch_sizes, + ] + ) self._epoch_sizes = None self._current_epoch_start_index = index diff --git a/fairseq/data/multilingual/sampling_method.py b/fairseq/data/multilingual/sampling_method.py index 6a9d39f7a6..140c68f01d 100644 --- a/fairseq/data/multilingual/sampling_method.py +++ b/fairseq/data/multilingual/sampling_method.py @@ -3,8 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import List import logging +from typing import List logger = logging.getLogger(__name__) @@ -16,18 +16,20 @@ def uniform(dataset_sizes: List[int]): def temperature_sampling(dataset_sizes, temp): total_size = sum(dataset_sizes) - return [(size / total_size) ** (1.0/temp) for size in dataset_sizes] + return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes] def make_temperature_sampling(temp=1.0): def sampling_func(dataset_sizes): return temperature_sampling(dataset_sizes, temp) + return sampling_func def make_ratio_sampling(ratios): def sampling_func(dataset_sizes): return ratios + return sampling_func @@ -35,13 +37,23 @@ class SamplingMethod: @staticmethod def add_arguments(parser): parser.add_argument( - '--sampling-method', - choices=['uniform', 'temperature', 'concat', 'RoundRobin', ], + "--sampling-method", + choices=[ + "uniform", + "temperature", + "concat", + "RoundRobin", + ], type=str, - default='concat', - help='The method to sample data per language pairs') - parser.add_argument('--sampling-temperature', default=1.5, type=float, - 
help='only work with --sampling-method temperature') + default="concat", + help="The method to sample data per language pairs", + ) + parser.add_argument( + "--sampling-temperature", + default=1.5, + type=float, + help="only work with --sampling-method temperature", + ) @staticmethod def build_sampler(args, task): @@ -56,10 +68,10 @@ def is_adaptive(self): def sampling_method_selector(self): args = self.args - logger.info(f'selected sampler: {args.sampling_method}') - if args.sampling_method == 'uniform': + logger.info(f"selected sampler: {args.sampling_method}") + if args.sampling_method == "uniform": return uniform - elif args.sampling_method == 'temperature' or self.is_adaptive(): + elif args.sampling_method == "temperature" or self.is_adaptive(): return make_temperature_sampling(float(args.sampling_temperature)) else: # default to concating all data set together diff --git a/fairseq/data/nested_dictionary_dataset.py b/fairseq/data/nested_dictionary_dataset.py index ebc56303b9..52e74abdda 100644 --- a/fairseq/data/nested_dictionary_dataset.py +++ b/fairseq/data/nested_dictionary_dataset.py @@ -15,14 +15,14 @@ def _flatten(dico, prefix=None): """Flatten a nested dictionary.""" new_dico = OrderedDict() if isinstance(dico, dict): - prefix = prefix + '.' if prefix is not None else '' + prefix = prefix + "." 
if prefix is not None else "" for k, v in dico.items(): if v is None: continue new_dico.update(_flatten(v, prefix + k)) elif isinstance(dico, list): for i, v in enumerate(dico): - new_dico.update(_flatten(v, prefix + '.[' + str(i) + ']')) + new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]")) else: new_dico = OrderedDict({prefix: dico}) return new_dico @@ -32,10 +32,10 @@ def _unflatten(dico): """Unflatten a flattened dictionary into a nested dictionary.""" new_dico = OrderedDict() for full_k, v in dico.items(): - full_k = full_k.split('.') + full_k = full_k.split(".") node = new_dico for k in full_k[:-1]: - if k.startswith('[') and k.endswith(']'): + if k.startswith("[") and k.endswith("]"): k = int(k[1:-1]) if k not in node: node[k] = OrderedDict() @@ -45,7 +45,6 @@ def _unflatten(dico): class NestedDictionaryDataset(FairseqDataset): - def __init__(self, defn, sizes=None): super().__init__() self.defn = _flatten(defn) @@ -53,11 +52,17 @@ def __init__(self, defn, sizes=None): first = None for v in self.defn.values(): - if not isinstance(v, (FairseqDataset, torch.utils.data.Dataset, )): - raise ValueError('Expected Dataset but found: {}'.format(v.__class__)) + if not isinstance( + v, + ( + FairseqDataset, + torch.utils.data.Dataset, + ), + ): + raise ValueError("Expected Dataset but found: {}".format(v.__class__)) first = first or v if len(v) > 0: - assert len(v) == len(first), 'dataset lengths must match' + assert len(v) == len(first), "dataset lengths must match" self._len = len(first) @@ -107,7 +112,7 @@ def supports_prefetch(self): def prefetch(self, indices): """Prefetch the data required for this epoch.""" for ds in self.defn.values(): - if getattr(ds, 'supports_prefetch', False): + if getattr(ds, "supports_prefetch", False): ds.prefetch(indices) @property diff --git a/fairseq/data/noising.py b/fairseq/data/noising.py index 5801ae6eac..9643d1aa6a 100644 --- a/fairseq/data/noising.py +++ b/fairseq/data/noising.py @@ -3,32 +3,34 @@ # This source code is 
licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch import numpy as np - +import torch from fairseq.data import data_utils class WordNoising(object): """Generate a noisy version of a sentence, without changing words themselves.""" + def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None): self.dictionary = dictionary self.bpe_end = None if bpe_cont_marker: - self.bpe_end = np.array([ - not self.dictionary[i].endswith(bpe_cont_marker) - for i in range(len(self.dictionary)) - ]) + self.bpe_end = np.array( + [ + not self.dictionary[i].endswith(bpe_cont_marker) + for i in range(len(self.dictionary)) + ] + ) elif bpe_end_marker: - self.bpe_end = np.array([ - self.dictionary[i].endswith(bpe_end_marker) - for i in range(len(self.dictionary)) - ]) + self.bpe_end = np.array( + [ + self.dictionary[i].endswith(bpe_end_marker) + for i in range(len(self.dictionary)) + ] + ) self.get_word_idx = ( - self._get_bpe_word_idx - if self.bpe_end is not None - else self._get_token_idx + self._get_bpe_word_idx if self.bpe_end is not None else self._get_token_idx ) def noising(self, x, lengths, noising_prob=0.0): @@ -44,7 +46,7 @@ def _get_bpe_word_idx(self, x): # x: (T x B) bpe_end = self.bpe_end[x] - if (x.size(0) == 1 and x.size(1) == 1): + if x.size(0) == 1 and x.size(1) == 1: # Special case when we only have one word in x. If x = [[N]], # bpe_end is a scalar (bool) instead of a 2-dim array of bools, # which makes the sum operation below fail. @@ -70,7 +72,13 @@ class WordDropout(WordNoising): then dropped words will be removed. 
Otherwise, it will be replaced by the blank_idx.""" - def __init__(self, dictionary, default_dropout_prob=0.1, bpe_cont_marker="@@", bpe_end_marker=None): + def __init__( + self, + dictionary, + default_dropout_prob=0.1, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): super().__init__(dictionary, bpe_cont_marker, bpe_end_marker) self.default_dropout_prob = default_dropout_prob @@ -108,13 +116,12 @@ def noising(self, x, lengths, dropout_prob=None, blank_idx=None): else: keep = np.random.rand(num_words) >= dropout_prob - words = x[:lengths[i], i].tolist() + words = x[: lengths[i], i].tolist() # TODO: speed up the following loop # drop words from the input according to keep new_s = [ - w if keep[word_idx[j, i]] else blank_idx - for j, w in enumerate(words) + w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words) ] new_s = [w for w in new_s if w is not None] # we need to have at least one word in the sentence (more than the @@ -132,11 +139,10 @@ def noising(self, x, lengths, dropout_prob=None, blank_idx=None): # re-construct input modified_lengths = torch.LongTensor(modified_lengths) modified_x = torch.LongTensor( - modified_lengths.max(), - modified_lengths.size(0) + modified_lengths.max(), modified_lengths.size(0) ).fill_(self.dictionary.pad()) for i in range(modified_lengths.size(0)): - modified_x[:modified_lengths[i], i].copy_(torch.LongTensor(sentences[i])) + modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i])) return modified_x, modified_lengths @@ -144,7 +150,13 @@ def noising(self, x, lengths, dropout_prob=None, blank_idx=None): class WordShuffle(WordNoising): """Shuffle words by no more than k positions.""" - def __init__(self, dictionary, default_max_shuffle_distance=3, bpe_cont_marker="@@", bpe_end_marker=None): + def __init__( + self, + dictionary, + default_max_shuffle_distance=3, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): super().__init__(dictionary, bpe_cont_marker, bpe_end_marker) 
self.default_max_shuffle_distance = 3 @@ -189,6 +201,7 @@ class UnsupervisedMTNoising(WordNoising): Implements the default configuration for noising in UnsupervisedMT (github.com/facebookresearch/UnsupervisedMT) """ + def __init__( self, dictionary, @@ -275,8 +288,13 @@ def __init__( self.src_dataset = src_dataset self.src_dict = src_dict self.seed = seed - self.noiser = noiser if noiser is not None else noising_class( - dictionary=src_dict, **kwargs, + self.noiser = ( + noiser + if noiser is not None + else noising_class( + dictionary=src_dict, + **kwargs, + ) ) def __getitem__(self, index): diff --git a/fairseq/data/num_samples_dataset.py b/fairseq/data/num_samples_dataset.py index 9d7ea44019..99a17495c7 100644 --- a/fairseq/data/num_samples_dataset.py +++ b/fairseq/data/num_samples_dataset.py @@ -7,7 +7,6 @@ class NumSamplesDataset(FairseqDataset): - def __getitem__(self, index): return 1 diff --git a/fairseq/data/numel_dataset.py b/fairseq/data/numel_dataset.py index 50087e5857..ac86dfd2f1 100644 --- a/fairseq/data/numel_dataset.py +++ b/fairseq/data/numel_dataset.py @@ -10,7 +10,6 @@ class NumelDataset(BaseWrapperDataset): - def __init__(self, dataset, reduce=False): super().__init__(dataset) self.reduce = reduce diff --git a/fairseq/data/offset_tokens_dataset.py b/fairseq/data/offset_tokens_dataset.py index a6fd559a30..6fabbdcdaa 100644 --- a/fairseq/data/offset_tokens_dataset.py +++ b/fairseq/data/offset_tokens_dataset.py @@ -7,7 +7,6 @@ class OffsetTokensDataset(BaseWrapperDataset): - def __init__(self, dataset, offset): super().__init__(dataset) self.offset = offset diff --git a/fairseq/data/pad_dataset.py b/fairseq/data/pad_dataset.py index 4c13b549aa..8075bba6a9 100644 --- a/fairseq/data/pad_dataset.py +++ b/fairseq/data/pad_dataset.py @@ -9,7 +9,6 @@ class PadDataset(BaseWrapperDataset): - def __init__(self, dataset, pad_idx, left_pad): super().__init__(dataset) self.pad_idx = pad_idx @@ -20,12 +19,10 @@ def collater(self, samples): class 
LeftPadDataset(PadDataset): - def __init__(self, dataset, pad_idx): super().__init__(dataset, pad_idx, left_pad=True) class RightPadDataset(PadDataset): - def __init__(self, dataset, pad_idx): super().__init__(dataset, pad_idx, left_pad=False) diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py index 33f250eea9..2b12646783 100644 --- a/fairseq/data/plasma_utils.py +++ b/fairseq/data/plasma_utils.py @@ -33,6 +33,7 @@ def plasma(self): if self._plasma is None and not self.disable: try: import pyarrow.plasma as plasma + self._plasma = plasma except ImportError: self._plasma = None @@ -45,11 +46,15 @@ def start_server(self): assert self.path is None self._server_tmp = tempfile.NamedTemporaryFile() self.path = self._server_tmp.name - self._server = subprocess.Popen([ - 'plasma_store', - '-m', str(int(1.05 * self.array.nbytes)), - '-s', self.path, - ]) + self._server = subprocess.Popen( + [ + "plasma_store", + "-m", + str(int(1.05 * self.array.nbytes)), + "-s", + self.path, + ] + ) @property def client(self): @@ -65,11 +70,11 @@ def __getstate__(self): self.start_server() self.object_id = self.client.put(self.array) state = self.__dict__.copy() - del state['array'] - state['_client'] = None - state['_server'] = None - state['_server_tmp'] = None - state['_plasma'] = None + del state["array"] + state["_client"] = None + state["_server"] = None + state["_server_tmp"] = None + state["_plasma"] = None return state def __setstate__(self, state): diff --git a/fairseq/data/prepend_token_dataset.py b/fairseq/data/prepend_token_dataset.py index 9dac71badf..fd1331f4c4 100644 --- a/fairseq/data/prepend_token_dataset.py +++ b/fairseq/data/prepend_token_dataset.py @@ -10,7 +10,6 @@ class PrependTokenDataset(BaseWrapperDataset): - def __init__(self, dataset, token=None): super().__init__(dataset) self.token = token diff --git a/fairseq/data/raw_label_dataset.py b/fairseq/data/raw_label_dataset.py index e67170f1a5..d054904f41 100644 --- 
a/fairseq/data/raw_label_dataset.py +++ b/fairseq/data/raw_label_dataset.py @@ -9,7 +9,6 @@ class RawLabelDataset(FairseqDataset): - def __init__(self, labels): super().__init__() self.labels = labels diff --git a/fairseq/data/replace_dataset.py b/fairseq/data/replace_dataset.py index 3bc52f0fb5..5aac2ba96b 100644 --- a/fairseq/data/replace_dataset.py +++ b/fairseq/data/replace_dataset.py @@ -9,12 +9,12 @@ class ReplaceDataset(BaseWrapperDataset): """Replaces tokens found in the dataset by a specified replacement token - Args: - dataset (~torch.utils.data.Dataset): dataset to replace tokens in - replace_map(Dictionary[int,int]): map of token to replace -> replacement token - offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be - as many as the number of objects returned by the underlying dataset __getitem__ method. - """ + Args: + dataset (~torch.utils.data.Dataset): dataset to replace tokens in + replace_map(Dictionary[int,int]): map of token to replace -> replacement token + offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be + as many as the number of objects returned by the underlying dataset __getitem__ method. 
+ """ def __init__(self, dataset, replace_map, offsets): super().__init__(dataset) diff --git a/fairseq/data/resampling_dataset.py b/fairseq/data/resampling_dataset.py index ffb25ac668..3d3b993164 100644 --- a/fairseq/data/resampling_dataset.py +++ b/fairseq/data/resampling_dataset.py @@ -6,7 +6,6 @@ import logging import numpy as np - from fairseq.data import BaseWrapperDataset, plasma_utils @@ -112,7 +111,7 @@ def can_reuse_epoch_itr_across_epochs(self): return False def set_epoch(self, epoch): - logger.debug('ResamplingDataset.set_epoch: {}'.format(epoch)) + logger.debug("ResamplingDataset.set_epoch: {}".format(epoch)) super().set_epoch(epoch) if epoch == self._cur_epoch: diff --git a/fairseq/data/roll_dataset.py b/fairseq/data/roll_dataset.py index d07800d0f6..a2915eeb3e 100644 --- a/fairseq/data/roll_dataset.py +++ b/fairseq/data/roll_dataset.py @@ -9,7 +9,6 @@ class RollDataset(BaseWrapperDataset): - def __init__(self, dataset, shifts): super().__init__(dataset) self.shifts = shifts diff --git a/fairseq/data/round_robin_zip_datasets.py b/fairseq/data/round_robin_zip_datasets.py index 5bfc966ce8..690823fc86 100644 --- a/fairseq/data/round_robin_zip_datasets.py +++ b/fairseq/data/round_robin_zip_datasets.py @@ -40,16 +40,19 @@ def __init__(self, datasets, eval_key=None): self._ordered_indices = None def _map_index(self, key, index): - assert self._ordered_indices is not None, \ - 'Must call RoundRobinZipDatasets.ordered_indices() first' + assert ( + self._ordered_indices is not None + ), "Must call RoundRobinZipDatasets.ordered_indices() first" return self._ordered_indices[key][index % len(self.datasets[key])] def __getitem__(self, index): if self.eval_key is None: - return OrderedDict([ - (key, dataset[self._map_index(key, index)]) - for key, dataset in self.datasets.items() - ]) + return OrderedDict( + [ + (key, dataset[self._map_index(key, index)]) + for key, dataset in self.datasets.items() + ] + ) else: # at evaluation time it's useful to pass-through 
batches from a single key return self.datasets[self.eval_key][self._map_index(self.eval_key, index)] @@ -62,10 +65,12 @@ def collater(self, samples): if len(samples) == 0: return None if self.eval_key is None: - return OrderedDict([ - (key, dataset.collater([sample[key] for sample in samples])) - for key, dataset in self.datasets.items() - ]) + return OrderedDict( + [ + (key, dataset.collater([sample[key] for sample in samples])) + for key, dataset in self.datasets.items() + ] + ) else: # at evaluation time it's useful to pass-through batches from a single key return self.datasets[self.eval_key].collater(samples) @@ -92,16 +97,18 @@ def ordered_indices(self): # Call the underlying dataset's ordered_indices() here, so that we # get the same random ordering as we would have from using the # underlying dataset directly. - self._ordered_indices = OrderedDict([ - (key, dataset.ordered_indices()) - for key, dataset in self.datasets.items() - ]) + self._ordered_indices = OrderedDict( + [ + (key, dataset.ordered_indices()) + for key, dataset in self.datasets.items() + ] + ) return np.arange(len(self)) @property def supports_prefetch(self): return all( - getattr(dataset, 'supports_prefetch', False) + getattr(dataset, "supports_prefetch", False) for dataset in self.datasets.values() ) diff --git a/fairseq/data/shorten_dataset.py b/fairseq/data/shorten_dataset.py index 85659d101e..6ebb5d88fe 100644 --- a/fairseq/data/shorten_dataset.py +++ b/fairseq/data/shorten_dataset.py @@ -10,8 +10,7 @@ class TruncateDataset(BaseWrapperDataset): - """Truncate a sequence by returning the first truncation_length tokens - """ + """Truncate a sequence by returning the first truncation_length tokens""" def __init__(self, dataset, truncation_length): super().__init__(dataset) @@ -23,7 +22,7 @@ def __getitem__(self, index): item = self.dataset[index] item_len = item.size(0) if item_len > self.truncation_length: - item = item[:self.truncation_length] + item = item[: self.truncation_length] return 
item @property @@ -35,8 +34,7 @@ def __len__(self): class RandomCropDataset(TruncateDataset): - """Truncate a sequence by returning a random crop of truncation_length tokens - """ + """Truncate a sequence by returning a random crop of truncation_length tokens""" def __init__(self, dataset, truncation_length, seed=1): super().__init__(dataset, truncation_length) @@ -58,9 +56,10 @@ def __getitem__(self, index): excess = item_len - self.truncation_length if excess > 0: start_idx = np.random.randint(0, excess) - item = item[start_idx:start_idx+self.truncation_length] + item = item[start_idx : start_idx + self.truncation_length] return item + def maybe_shorten_dataset( dataset, split, @@ -69,10 +68,11 @@ def maybe_shorten_dataset( tokens_per_sample, seed, ): - truncate_split = split in shorten_data_split_list.split(',') \ - or len(shorten_data_split_list) == 0 - if shorten_method == 'truncate' and truncate_split: + truncate_split = ( + split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0 + ) + if shorten_method == "truncate" and truncate_split: dataset = TruncateDataset(dataset, tokens_per_sample) - elif shorten_method == 'random_crop' and truncate_split: + elif shorten_method == "random_crop" and truncate_split: dataset = RandomCropDataset(dataset, tokens_per_sample, seed) return dataset diff --git a/fairseq/data/sort_dataset.py b/fairseq/data/sort_dataset.py index 9b510b93a0..b3890e7279 100644 --- a/fairseq/data/sort_dataset.py +++ b/fairseq/data/sort_dataset.py @@ -9,7 +9,6 @@ class SortDataset(BaseWrapperDataset): - def __init__(self, dataset, sort_order): super().__init__(dataset) if not isinstance(sort_order, (list, tuple)): diff --git a/fairseq/data/strip_token_dataset.py b/fairseq/data/strip_token_dataset.py index e388db0e5f..cae39ba4d2 100644 --- a/fairseq/data/strip_token_dataset.py +++ b/fairseq/data/strip_token_dataset.py @@ -7,7 +7,6 @@ class StripTokenDataset(BaseWrapperDataset): - def __init__(self, dataset, id_to_strip): 
super().__init__(dataset) self.id_to_strip = id_to_strip diff --git a/fairseq/data/subsample_dataset.py b/fairseq/data/subsample_dataset.py index 7eca9d4cb3..48feaf883f 100644 --- a/fairseq/data/subsample_dataset.py +++ b/fairseq/data/subsample_dataset.py @@ -16,10 +16,10 @@ class SubsampleDataset(BaseWrapperDataset): """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples - Args: - dataset (~torch.utils.data.Dataset): dataset to subsample - size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive) - """ + Args: + dataset (~torch.utils.data.Dataset): dataset to subsample + size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive) + """ def __init__(self, dataset, size_ratio, shuffle=False): super().__init__(dataset) diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index cae872c310..aa33f9d06f 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -5,7 +5,6 @@ import numpy as np import torch - from fairseq.data import FairseqDataset, plasma_utils @@ -31,6 +30,7 @@ class TokenBlockDataset(FairseqDataset): 'complete_doc' break mode). Typically 1 if the sentences have eos and 0 otherwise. """ + def __init__( self, dataset, @@ -49,8 +49,8 @@ def __init__( ) except ImportError: raise ImportError( - 'Please build Cython components with: `pip install --editable .` ' - 'or `python setup.py build_ext --inplace`' + "Please build Cython components with: `pip install --editable .` " + "or `python setup.py build_ext --inplace`" ) super().__init__() @@ -69,13 +69,15 @@ def __init__( sizes = sizes.numpy() sizes = sizes.astype(np.int64) - break_mode = break_mode if break_mode is not None else 'none' + break_mode = break_mode if break_mode is not None else "none" # For "eos" break-mode, block_size is not required parameters. 
if break_mode == "eos" and block_size is None: block_size = 0 - slice_indices = _get_slice_indices_fast(sizes, str(break_mode), block_size, document_sep_len) + slice_indices = _get_slice_indices_fast( + sizes, str(break_mode), block_size, document_sep_len + ) self._sizes = slice_indices[:, 1] - slice_indices[:, 0] # build index mapping block indices to the underlying dataset indices diff --git a/fairseq/data/transform_eos_dataset.py b/fairseq/data/transform_eos_dataset.py index 4ce5ad811b..fb14ff018e 100644 --- a/fairseq/data/transform_eos_dataset.py +++ b/fairseq/data/transform_eos_dataset.py @@ -33,11 +33,11 @@ def __init__( has_target=True, ): if not isinstance(dataset, FairseqDataset): - raise ValueError('dataset must be an instance of FairseqDataset') + raise ValueError("dataset must be an instance of FairseqDataset") if append_eos_to_src and remove_eos_from_src: - raise ValueError('cannot combine append_eos_to_src and remove_eos_from_src') + raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src") if append_eos_to_tgt and remove_eos_from_tgt: - raise ValueError('cannot combine append_eos_to_tgt and remove_eos_from_tgt') + raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt") self.dataset = dataset self.eos = torch.LongTensor([eos]) @@ -75,24 +75,23 @@ def __len__(self): return len(self.dataset) def collater(self, samples): - def transform(item): if self.append_eos_to_src: - self.eos = self.eos.to(device=item['source'].device) - self._check_src(item['source'], expect_eos=False) - item['source'] = torch.cat([item['source'], self.eos]) + self.eos = self.eos.to(device=item["source"].device) + self._check_src(item["source"], expect_eos=False) + item["source"] = torch.cat([item["source"], self.eos]) if self.remove_eos_from_src: - self.eos = self.eos.to(device=item['source'].device) - self._check_src(item['source'], expect_eos=True) - item['source'] = item['source'][:-1] + self.eos = 
self.eos.to(device=item["source"].device) + self._check_src(item["source"], expect_eos=True) + item["source"] = item["source"][:-1] if self.append_eos_to_tgt: - self.eos = self.eos.to(device=item['target'].device) - self._check_tgt(item['target'], expect_eos=False) - item['target'] = torch.cat([item['target'], self.eos]) + self.eos = self.eos.to(device=item["target"].device) + self._check_tgt(item["target"], expect_eos=False) + item["target"] = torch.cat([item["target"], self.eos]) if self.remove_eos_from_tgt: - self.eos = self.eos.to(device=item['target'].device) - self._check_tgt(item['target'], expect_eos=True) - item['target'] = item['target'][:-1] + self.eos = self.eos.to(device=item["target"].device) + self._check_tgt(item["target"], expect_eos=True) + item["target"] = item["target"][:-1] return item samples = list(map(transform, samples)) @@ -115,7 +114,7 @@ def ordered_indices(self): @property def supports_prefetch(self): - return getattr(self.dataset, 'supports_prefetch', False) + return getattr(self.dataset, "supports_prefetch", False) def prefetch(self, indices): return self.dataset.prefetch(indices) diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py index 2783824838..1dd3d93d2b 100644 --- a/fairseq/data/transform_eos_lang_pair_dataset.py +++ b/fairseq/data/transform_eos_lang_pair_dataset.py @@ -4,10 +4,12 @@ # LICENSE file in the root directory of this source tree. -from . import FairseqDataset -import torch from typing import Optional +import torch + +from . 
import FairseqDataset + class TransformEosLangPairDataset(FairseqDataset): """A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on @@ -50,25 +52,37 @@ def collater(self, samples, **extra_args): if self.new_src_eos is not None: if self.dataset.left_pad_source: - assert(samples['net_input']['src_tokens'][:, -1] != self.src_eos).sum() == 0 - samples['net_input']['src_tokens'][:, -1] = self.new_src_eos + assert ( + samples["net_input"]["src_tokens"][:, -1] != self.src_eos + ).sum() == 0 + samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos else: - eos_idx = samples['net_input']['src_lengths'] - 1 - assert( - samples['net_input']['src_tokens'][torch.arange(eos_idx.size(0)), eos_idx] != self.src_eos + eos_idx = samples["net_input"]["src_lengths"] - 1 + assert ( + samples["net_input"]["src_tokens"][ + torch.arange(eos_idx.size(0)), eos_idx + ] + != self.src_eos ).sum() == 0 - eos_idx = eos_idx.resize_(len(samples['net_input']['src_lengths']), 1) - samples['net_input']['src_tokens'].scatter_(1, eos_idx, self.new_src_eos) + eos_idx = eos_idx.resize_(len(samples["net_input"]["src_lengths"]), 1) + samples["net_input"]["src_tokens"].scatter_( + 1, eos_idx, self.new_src_eos + ) - if self.new_tgt_bos is not None and 'prev_output_tokens' in samples['net_input']: + if ( + self.new_tgt_bos is not None + and "prev_output_tokens" in samples["net_input"] + ): if self.dataset.left_pad_target: # TODO: support different padding direction on target side raise NotImplementedError( - 'TransformEosLangPairDataset does not implement --left-pad-target True option' + "TransformEosLangPairDataset does not implement --left-pad-target True option" ) else: - assert (samples['net_input']['prev_output_tokens'][:, 0] != self.tgt_bos).sum() == 0 - samples['net_input']['prev_output_tokens'][:, 0] = self.new_tgt_bos + assert ( + samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos + ).sum() == 0 + samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos 
return samples @@ -88,7 +102,7 @@ def ordered_indices(self): @property def supports_prefetch(self): - return getattr(self.dataset, 'supports_prefetch', False) + return getattr(self.dataset, "supports_prefetch", False) def prefetch(self, indices): return self.dataset.prefetch(indices) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index 0685c968d5..ed1d12d865 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -113,12 +113,13 @@ class CommonParams(FairseqDataclass): default="", metadata={"help": "suffix to add to the checkpoint file name"} ) checkpoint_shard_count: int = field( - default=1, metadata={ + default=1, + metadata={ "help": "Number of shards containing the checkpoint - " - "if the checkpoint is over 300GB, it is preferable " - "to split it into shards to prevent OOM on CPU while loading " - "the checkpoint" - } + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + }, ) quantization_config_path: Optional[str] = field( default=None, metadata={"help": "path to quantization config file"} @@ -307,7 +308,10 @@ class DatasetParams(FairseqDataclass): default=8, metadata={"help": "batch size will be a multiplier of this value"} ) required_seq_len_multiple: int = field( - default=1, metadata={"help": "maximum sequence length in batch will be a multiplier of this value"} + default=1, + metadata={ + "help": "maximum sequence length in batch will be a multiplier of this value" + }, ) dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = field( default=None, metadata={"help": "output dataset implementation"} @@ -351,8 +355,7 @@ class DatasetParams(FairseqDataclass): batch_size_valid: Optional[int] = field( default=None, metadata={ - "help": "batch size of the validation batch" - " (defaults to --batch-size)" + "help": "batch size of the validation batch" " (defaults to --batch-size)" }, ) curriculum: int = 
field( diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 9ab235d16d..599cc2b4c2 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -164,7 +164,9 @@ def get_kwargs_from_dc( raise NotImplementedError() if field_default is not MISSING: kwargs["default"] = ",".join(map(str, field_default)) - elif (isinstance(inter_type, type) and issubclass(inter_type, Enum)) or "Enum" in str(inter_type): + elif ( + isinstance(inter_type, type) and issubclass(inter_type, Enum) + ) or "Enum" in str(inter_type): kwargs["type"] = str if field_default is not MISSING: if isinstance(field_default, Enum): @@ -184,7 +186,7 @@ def get_kwargs_from_dc( kwargs["help"] = field_help if field_const is not None: kwargs["const"] = field_const - kwargs["nargs"] = '?' + kwargs["nargs"] = "?" return kwargs for k in dataclass_instance._get_all_attributes(): diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index ab5aad1425..bcb0595e6e 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -16,7 +16,6 @@ import torch import torch.distributed as dist - from fairseq import utils @@ -28,83 +27,103 @@ def is_master(args): def infer_init_method(args, force_distributed=False): - if args.distributed_init_method is not None or getattr(args, 'tpu', False): + if args.distributed_init_method is not None or getattr(args, "tpu", False): return if args.pipeline_model_parallel: - balance_exists = args.pipeline_balance is not None or \ - args.pipeline_encoder_balance is not None or \ - args.pipeline_decoder_balance is not None - devices_exist = args.pipeline_devices is not None or \ - args.pipeline_encoder_devices is not None or \ - args.pipeline_decoder_devices is not None + balance_exists = ( + args.pipeline_balance is not None + or args.pipeline_encoder_balance is not None + or args.pipeline_decoder_balance is not None + ) + devices_exist = ( + args.pipeline_devices is not None + or args.pipeline_encoder_devices is not None 
+ or args.pipeline_decoder_devices is not None + ) if not balance_exists: - raise ValueError('--pipeline-balance is currently required for pipeline model parallelism') + raise ValueError( + "--pipeline-balance is currently required for pipeline model parallelism" + ) if not devices_exist: - raise ValueError('--pipeline-devices is currently required for pipeline model parallelism') + raise ValueError( + "--pipeline-devices is currently required for pipeline model parallelism" + ) args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int) if args.pipeline_devices is not None: args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) num_pipeline_devices = len(set(args.pipeline_devices)) else: - args.pipeline_encoder_devices = utils.eval_str_list(args.pipeline_encoder_devices, type=int) - args.pipeline_decoder_devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int) - num_pipeline_devices = len(set(args.pipeline_encoder_devices + args.pipeline_decoder_devices)) + args.pipeline_encoder_devices = utils.eval_str_list( + args.pipeline_encoder_devices, type=int + ) + args.pipeline_decoder_devices = utils.eval_str_list( + args.pipeline_decoder_devices, type=int + ) + num_pipeline_devices = len( + set(args.pipeline_encoder_devices + args.pipeline_decoder_devices) + ) gpus_per_node = torch.cuda.device_count() - assert gpus_per_node >= num_pipeline_devices and gpus_per_node % num_pipeline_devices == 0, ( - 'the number of unique device IDs in --pipeline-devices must evenly divide ' - 'the number of GPUs per node (multi-node pipelining is not yet supported)' + assert ( + gpus_per_node >= num_pipeline_devices + and gpus_per_node % num_pipeline_devices == 0 + ), ( + "the number of unique device IDs in --pipeline-devices must evenly divide " + "the number of GPUs per node (multi-node pipelining is not yet supported)" ) num_pipelines_per_node = gpus_per_node // num_pipeline_devices # support torch.distributed.launch - if all(key in 
os.environ for key in [ - 'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK' - ]): - args.distributed_init_method = 'env://' - args.distributed_world_size = int(os.environ['WORLD_SIZE']) - args.distributed_rank = int(os.environ['RANK']) + if all( + key in os.environ + for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] + ): + args.distributed_init_method = "env://" + args.distributed_world_size = int(os.environ["WORLD_SIZE"]) + args.distributed_rank = int(os.environ["RANK"]) # processes are created by torch.distributed.launch args.distributed_no_spawn = True # we can determine the init method automatically for Slurm elif args.distributed_port > 0: - node_list = os.environ.get('SLURM_STEP_NODELIST') + node_list = os.environ.get("SLURM_STEP_NODELIST") if node_list is None: - node_list = os.environ.get('SLURM_JOB_NODELIST') + node_list = os.environ.get("SLURM_JOB_NODELIST") if node_list is not None: try: - hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list]) - args.distributed_init_method = 'tcp://{host}:{port}'.format( - host=hostnames.split()[0].decode('utf-8'), + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", node_list] + ) + args.distributed_init_method = "tcp://{host}:{port}".format( + host=hostnames.split()[0].decode("utf-8"), port=args.distributed_port, ) - nnodes = int(os.environ.get('SLURM_NNODES')) - ntasks_per_node = os.environ.get('SLURM_NTASKS_PER_NODE') + nnodes = int(os.environ.get("SLURM_NNODES")) + ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") if ntasks_per_node is not None: ntasks_per_node = int(ntasks_per_node) else: - ntasks = int(os.environ.get('SLURM_NTASKS')) - nnodes = int(os.environ.get('SLURM_NNODES')) + ntasks = int(os.environ.get("SLURM_NTASKS")) + nnodes = int(os.environ.get("SLURM_NNODES")) assert ntasks % nnodes == 0 ntasks_per_node = int(ntasks / nnodes) if ntasks_per_node == 1: gpus_per_node = torch.cuda.device_count() - node_id = 
int(os.environ.get('SLURM_NODEID')) + node_id = int(os.environ.get("SLURM_NODEID")) args.distributed_rank = node_id * gpus_per_node args.distributed_world_size = nnodes * gpus_per_node elif args.pipeline_model_parallel: assert ntasks_per_node == num_pipelines_per_node, ( - 'SLURM --ntasks-per-node must match number of pipelines per ' - 'node (={})'.format(num_pipelines_per_node) + "SLURM --ntasks-per-node must match number of pipelines per " + "node (={})".format(num_pipelines_per_node) ) args.distributed_no_spawn = True # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on # the first node, [1, 2] on the second node, etc. This # matches torch.distributed.launch. - node_id = int(os.environ.get('SLURM_NODEID')) - local_id = int(os.environ.get('SLURM_LOCALID')) + node_id = int(os.environ.get("SLURM_NODEID")) + local_id = int(os.environ.get("SLURM_LOCALID")) args.distributed_rank = node_id * num_pipelines_per_node + local_id # In the above example, device_id will always be in [0, 1], # which also matches torch.distributed.launch. 
@@ -115,8 +134,8 @@ def infer_init_method(args, force_distributed=False): else: assert ntasks_per_node == args.distributed_world_size // nnodes args.distributed_no_spawn = True - args.distributed_rank = int(os.environ.get('SLURM_PROCID')) - args.device_id = int(os.environ.get('SLURM_LOCALID')) + args.distributed_rank = int(os.environ.get("SLURM_PROCID")) + args.device_id = int(os.environ.get("SLURM_LOCALID")) except subprocess.CalledProcessError as e: # scontrol failed raise e except FileNotFoundError: # Slurm is not installed @@ -126,7 +145,7 @@ def infer_init_method(args, force_distributed=False): # fallback for single node with multiple GPUs assert args.distributed_world_size <= torch.cuda.device_count() port = random.randint(10000, 20000) - args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + args.distributed_init_method = "tcp://localhost:{port}".format(port=port) if args.pipeline_model_parallel: if not args.distributed_no_spawn: @@ -134,7 +153,9 @@ def infer_init_method(args, force_distributed=False): # distributed_world_size to be based on the total number of GPUs, so # we need to correct them to be based on the number of pipelines. assert args.distributed_world_size % num_pipeline_devices == 0 - args.distributed_world_size = args.distributed_world_size // num_pipeline_devices + args.distributed_world_size = ( + args.distributed_world_size // num_pipeline_devices + ) # In the case of 4-way MP on nodes with 8 GPUs, we want # distributed_rank to be the starting GPU index for each pipeline # i.e., 0, 2, ... 
@@ -152,14 +173,16 @@ def infer_init_method(args, force_distributed=False): # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 # GPU node), we need to adjust pipeline_devices accordingly logger.debug( - "setting CUDA device={} on rank {}" - .format(args.device_id, args.distributed_rank) + "setting CUDA device={} on rank {}".format( + args.device_id, args.distributed_rank + ) ) torch.cuda.set_device(args.device_id) args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices] logger.info( - "setting pipeline_devices={} on rank {}" - .format(args.pipeline_devices, args.distributed_rank), + "setting pipeline_devices={} on rank {}".format( + args.pipeline_devices, args.distributed_rank + ), ) elif not args.distributed_no_spawn: args.distributed_num_procs = min( @@ -169,22 +192,30 @@ def infer_init_method(args, force_distributed=False): def distributed_init(args): - if not getattr(args, 'tpu', False): + if not getattr(args, "tpu", False): if torch.distributed.is_initialized(): - warnings.warn('Distributed is already initialized, cannot initialize twice!') + warnings.warn( + "Distributed is already initialized, cannot initialize twice!" 
+ ) else: - logger.info('distributed init (rank {}): {}'.format( - args.distributed_rank, args.distributed_init_method, - )) + logger.info( + "distributed init (rank {}): {}".format( + args.distributed_rank, + args.distributed_init_method, + ) + ) dist.init_process_group( backend=args.distributed_backend, init_method=args.distributed_init_method, world_size=args.distributed_world_size, rank=args.distributed_rank, ) - logger.info('initialized host {} as rank {}'.format( - socket.gethostname(), args.distributed_rank, - )) + logger.info( + "initialized host {} as rank {}".format( + socket.gethostname(), + args.distributed_rank, + ) + ) # perform a dummy all-reduce to initialize the NCCL communicator if torch.cuda.is_available(): @@ -193,10 +224,11 @@ def distributed_init(args): args.distributed_rank = torch.distributed.get_rank() else: import torch_xla.core.xla_model as xm + assert xm.xrt_world_size() == args.distributed_world_size args.device_id = xm.get_local_ordinal() args.distributed_rank = xm.get_ordinal() - xm.rendezvous('distributed_init') # wait for all workers + xm.rendezvous("distributed_init") # wait for all workers xm.mark_step() if not is_master(args): @@ -211,14 +243,14 @@ def distributed_init(args): ) except ImportError: raise ImportError( - '\n\nPlease install the megatron submodule:' - '\n\n git submodule update --init ' - 'fairseq/model_parallel/megatron' + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" ) initialize_model_parallel(args.model_parallel_size) model_parallel_cuda_manual_seed(args.seed) model_part_number = get_model_parallel_rank() - args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number) + args.checkpoint_suffix += "-model_part-{0}".format(model_part_number) return args.distributed_rank @@ -227,11 +259,11 @@ def distributed_main(i, main, args, kwargs): if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): 
torch.cuda.set_device(args.device_id) if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = kwargs.pop('start_rank', 0) + i + args.distributed_rank = kwargs.pop("start_rank", 0) + i args.distributed_rank = distributed_init(args) - after_distributed_init_fn = kwargs.pop('after_distributed_init_fn', None) + after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) if after_distributed_init_fn: args = after_distributed_init_fn(args) @@ -247,7 +279,7 @@ def call_main(args, main, **kwargs): if not args.distributed_no_spawn: start_rank = args.distributed_rank args.distributed_rank = None # assign automatically - kwargs['start_rank'] = start_rank + kwargs["start_rank"] = start_rank torch.multiprocessing.spawn( fn=distributed_main, args=(main, args, kwargs), @@ -257,6 +289,7 @@ def call_main(args, main, **kwargs): distributed_main(args.device_id, main, args, kwargs) elif getattr(args, "tpu", False) and args.distributed_world_size > 1: import torch_xla.distributed.xla_multiprocessing as xmp + torch.multiprocessing.set_sharing_strategy("file_system") xmp.spawn( fn=distributed_main, @@ -281,9 +314,10 @@ def get_default_group(): def all_reduce(tensor, group=None): - if isinstance(group, tuple) and group[0] == 'tpu': + if isinstance(group, tuple) and group[0] == "tpu": import torch_xla.core.xla_model as xm - return xm.all_reduce('sum', [tensor], groups=group[1]) + + return xm.all_reduce("sum", [tensor], groups=group[1]) else: if group is None: group = get_default_group() @@ -306,8 +340,10 @@ def all_gather_list(data, group=None, max_size=16384): world_size = get_world_size() buffer_size = max_size * world_size - if not hasattr(all_gather_list, '_buffer') or \ - all_gather_list._buffer.numel() < buffer_size: + if ( + not hasattr(all_gather_list, "_buffer") + or all_gather_list._buffer.numel() < buffer_size + ): all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) all_gather_list._cpu_buffer = 
torch.ByteTensor(max_size).pin_memory() buffer = all_gather_list._buffer @@ -320,12 +356,14 @@ def all_gather_list(data, group=None, max_size=16384): header_size = 4 # size of header that contains the length of the encoded data size = header_size + enc_size if size > max_size: - raise ValueError('encoded data size ({}) exceeds max_size ({})'.format(size, max_size)) + raise ValueError( + "encoded data size ({}) exceeds max_size ({})".format(size, max_size) + ) header = struct.pack(">I", enc_size) cpu_buffer[:size] = torch.ByteTensor(list(header + enc)) start = rank * max_size - buffer[start:start + size].copy_(cpu_buffer[:size]) + buffer[start : start + size].copy_(cpu_buffer[:size]) all_reduce(buffer, group=group) @@ -333,20 +371,24 @@ def all_gather_list(data, group=None, max_size=16384): try: result = [] for i in range(world_size): - out_buffer = buffer[i * max_size:(i + 1) * max_size] - enc_size, = struct.unpack(">I", bytes(out_buffer[:header_size].tolist())) + out_buffer = buffer[i * max_size : (i + 1) * max_size] + (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist())) if enc_size > 0: - result.append(pickle.loads(bytes(out_buffer[header_size:header_size + enc_size].tolist()))) + result.append( + pickle.loads( + bytes(out_buffer[header_size : header_size + enc_size].tolist()) + ) + ) return result except pickle.UnpicklingError: raise Exception( - 'Unable to unpickle data from other workers. all_gather_list requires all ' - 'workers to enter the function together, so this error usually indicates ' - 'that the workers have fallen out of sync somehow. Workers can fall out of ' - 'sync if one of them runs out of memory, or if there are other conditions ' - 'in your training script that can cause one worker to finish an epoch ' - 'while other workers are still iterating over their portions of the data. ' - 'Try rerunning with --ddp-backend=no_c10d and see if that helps.' + "Unable to unpickle data from other workers. 
all_gather_list requires all " + "workers to enter the function together, so this error usually indicates " + "that the workers have fallen out of sync somehow. Workers can fall out of " + "sync if one of them runs out of memory, or if there are other conditions " + "in your training script that can cause one worker to finish an epoch " + "while other workers are still iterating over their portions of the data. " + "Try rerunning with --ddp-backend=no_c10d and see if that helps." ) diff --git a/fairseq/file_utils.py b/fairseq/file_utils.py index 62278b367d..0a94ac7112 100644 --- a/fairseq/file_utils.py +++ b/fairseq/file_utils.py @@ -10,25 +10,28 @@ """ import fnmatch -from functools import wraps, partial -from hashlib import sha256 -from io import open import json import logging import os import shutil import tarfile import tempfile +from functools import partial, wraps +from hashlib import sha256 +from io import open try: from torch.hub import _get_torch_home + torch_cache_home = _get_torch_home() except ImportError: torch_cache_home = os.path.expanduser( - os.getenv('TORCH_HOME', os.path.join( - os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) -default_cache_path = os.path.join(torch_cache_home, 'pytorch_fairseq') + os.getenv( + "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch") + ) + ) +default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq") try: from urllib.parse import urlparse @@ -37,11 +40,10 @@ try: from pathlib import Path - PYTORCH_FAIRSEQ_CACHE = Path( - os.getenv('PYTORCH_FAIRSEQ_CACHE', default_cache_path)) + + PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)) except (AttributeError, ImportError): - PYTORCH_FAIRSEQ_CACHE = os.getenv( - 'PYTORCH_FAIRSEQ_CACHE', default_cache_path) + PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path) CONFIG_NAME = "config.json" WEIGHTS_NAME = "pytorch_model.bin" @@ -67,17 +69,23 @@ def 
load_archive_file(archive_file): if resolved_archive_file == archive_file: logger.info("loading archive file {}".format(archive_file)) else: - logger.info("loading archive file {} from cache at {}".format( - archive_file, resolved_archive_file)) + logger.info( + "loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file + ) + ) # Extract archive to temp dir and replace .tar.bz2 if necessary tempdir = None if not os.path.isdir(resolved_archive_file): tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, tempdir)) + logger.info( + "extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir + ) + ) ext = os.path.splitext(archive_file)[1][1:] - with tarfile.open(resolved_archive_file, 'r:' + ext) as archive: + with tarfile.open(resolved_archive_file, "r:" + ext) as archive: top_dir = os.path.commonprefix(archive.getnames()) archive.extractall(tempdir) os.remove(resolved_archive_file) @@ -93,14 +101,14 @@ def url_to_filename(url, etag=None): If `etag` is specified, append its hash to the URL's, delimited by a period. """ - url_bytes = url.encode('utf-8') + url_bytes = url.encode("utf-8") url_hash = sha256(url_bytes) filename = url_hash.hexdigest() if etag: - etag_bytes = etag.encode('utf-8') + etag_bytes = etag.encode("utf-8") etag_hash = sha256(etag_bytes) - filename += '.' + etag_hash.hexdigest() + filename += "." 
+ etag_hash.hexdigest() return filename @@ -119,14 +127,14 @@ def filename_to_url(filename, cache_dir=None): if not os.path.exists(cache_path): raise EnvironmentError("file {} not found".format(cache_path)) - meta_path = cache_path + '.json' + meta_path = cache_path + ".json" if not os.path.exists(meta_path): raise EnvironmentError("file {} not found".format(meta_path)) with open(meta_path, encoding="utf-8") as meta_file: metadata = json.load(meta_file) - url = metadata['url'] - etag = metadata['etag'] + url = metadata["url"] + etag = metadata["etag"] return url, etag @@ -147,18 +155,20 @@ def cached_path(url_or_filename, cache_dir=None): parsed = urlparse(url_or_filename) - if parsed.scheme in ('http', 'https', 's3'): + if parsed.scheme in ("http", "https", "s3"): # URL, so get it from the cache (downloading if necessary) return get_from_cache(url_or_filename, cache_dir) elif os.path.exists(url_or_filename): # File, and it exists. return url_or_filename - elif parsed.scheme == '': + elif parsed.scheme == "": # File, but it doesn't exist. 
raise EnvironmentError("file {} not found".format(url_or_filename)) else: # Something unknown - raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) def split_s3_path(url): @@ -183,6 +193,7 @@ def s3_request(func): @wraps(func) def wrapper(url, *args, **kwargs): from botocore.exceptions import ClientError + try: return func(url, *args, **kwargs) except ClientError as exc: @@ -198,6 +209,7 @@ def wrapper(url, *args, **kwargs): def s3_etag(url): """Check ETag on S3 object.""" import boto3 + s3_resource = boto3.resource("s3") bucket_name, s3_path = split_s3_path(url) s3_object = s3_resource.Object(bucket_name, s3_path) @@ -208,6 +220,7 @@ def s3_etag(url): def s3_get(url, temp_file): """Pull a file directly from S3.""" import boto3 + s3_resource = boto3.resource("s3") bucket_name, s3_path = split_s3_path(url) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) @@ -215,12 +228,18 @@ def s3_get(url, temp_file): def request_wrap_timeout(func, url): import requests + for attempt, timeout in enumerate([10, 20, 40, 60, 60]): try: return func(timeout=timeout) except requests.exceptions.Timeout as e: - logger.warning("Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs", - url, attempt, timeout, exc_info=e) + logger.warning( + "Request for %s timed-out (attempt %d). 
Retrying with a timeout of %d secs", + url, + attempt, + timeout, + exc_info=e, + ) continue raise RuntimeError(f"Unable to fetch file {url}") @@ -230,7 +249,7 @@ def http_get(url, temp_file): from tqdm import tqdm req = request_wrap_timeout(partial(requests.get, url, stream=True), url) - content_length = req.headers.get('Content-Length') + content_length = req.headers.get("Content-Length") total = int(content_length) if content_length is not None else None progress = tqdm(unit="B", total=total) for chunk in req.iter_content(chunk_size=1024): @@ -259,7 +278,10 @@ def get_from_cache(url, cache_dir=None): else: try: import requests - response = request_wrap_timeout(partial(requests.head, url, allow_redirects=True), url) + + response = request_wrap_timeout( + partial(requests.head, url, allow_redirects=True), url + ) if response.status_code != 200: etag = None else: @@ -275,8 +297,8 @@ def get_from_cache(url, cache_dir=None): # If we don't have a connection (etag is None) and can't identify the file # try to get the last downloaded one if not os.path.exists(cache_path) and etag is None: - matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') - matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*") + matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files)) if matching_files: cache_path = os.path.join(cache_dir, matching_files[-1]) @@ -298,13 +320,13 @@ def get_from_cache(url, cache_dir=None): temp_file.seek(0) logger.info("copying %s to cache at %s", temp_file.name, cache_path) - with open(cache_path, 'wb') as cache_file: + with open(cache_path, "wb") as cache_file: shutil.copyfileobj(temp_file, cache_file) logger.info("creating metadata file for %s", cache_path) - meta = {'url': url, 'etag': etag} - meta_path = cache_path + '.json' - with open(meta_path, 'w') as meta_file: + meta = {"url": url, "etag": etag} + meta_path = 
cache_path + ".json" + with open(meta_path, "w") as meta_file: output_string = json.dumps(meta) meta_file.write(output_string) @@ -314,12 +336,12 @@ def get_from_cache(url, cache_dir=None): def read_set_from_file(filename): - ''' + """ Extract a de-duped collection (set) of text from a file. Expected file format is one item per line. - ''' + """ collection = set() - with open(filename, 'r', encoding='utf-8') as file_: + with open(filename, "r", encoding="utf-8") as file_: for line in file_: collection.add(line.rstrip()) return collection diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index b56135abf3..b293e54e2a 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -8,13 +8,12 @@ import copy import logging import os -from typing import List, Dict, Iterator, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple import torch -from torch import nn - from fairseq import utils from fairseq.data import encoders +from torch import nn logger = logging.getLogger(__name__) @@ -22,8 +21,8 @@ def from_pretrained( model_name_or_path, - checkpoint_file='model.pt', - data_name_or_path='.', + checkpoint_file="model.pt", + data_name_or_path=".", archive_map=None, **kwargs ): @@ -39,34 +38,34 @@ def from_pretrained( # for each model if isinstance(model_name_or_path, dict): for k, v in model_name_or_path.items(): - if k == 'checkpoint_file': + if k == "checkpoint_file": checkpoint_file = v elif ( - k != 'path' + k != "path" # only set kwargs that don't already have overrides and k not in kwargs ): kwargs[k] = v - model_name_or_path = model_name_or_path['path'] + model_name_or_path = model_name_or_path["path"] model_path = file_utils.load_archive_file(model_name_or_path) # convenience hack for loading data and BPE codes from model archive - if data_name_or_path.startswith('.'): - kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path)) + if data_name_or_path.startswith("."): + kwargs["data"] = os.path.abspath(os.path.join(model_path, 
data_name_or_path)) else: - kwargs['data'] = file_utils.load_archive_file(data_name_or_path) + kwargs["data"] = file_utils.load_archive_file(data_name_or_path) for file, arg in { - 'code': 'bpe_codes', - 'bpecodes': 'bpe_codes', - 'sentencepiece.bpe.model': 'sentencepiece_model', + "code": "bpe_codes", + "bpecodes": "bpe_codes", + "sentencepiece.bpe.model": "sentencepiece_model", }.items(): path = os.path.join(model_path, file) if os.path.exists(path): kwargs[arg] = path - if 'user_dir' in kwargs: - utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir'])) + if "user_dir" in kwargs: + utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"])) models, args, task = checkpoint_utils.load_model_ensemble_and_task( [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)], @@ -74,9 +73,9 @@ def from_pretrained( ) return { - 'args': args, - 'task': task, - 'models': models, + "args": args, + "task": task, + "models": models, } @@ -100,7 +99,7 @@ def __init__(self, args, task, models): # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - self.align_dict = utils.load_align_dict(getattr(args, 'replace_unk', None)) + self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None)) self.tokenizer = encoders.build_tokenizer(args) self.bpe = encoders.build_bpe(args) @@ -110,28 +109,37 @@ def __init__(self, args, task, models): ) # this is useful for determining the device - self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float)) + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) @property def device(self): return self._float_tensor.device - def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]: + def translate( + self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs + ) -> List[str]: return self.sample(sentences, beam, 
verbose, **kwargs) - def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]: + def sample( + self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs + ) -> List[str]: if isinstance(sentences, str): return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0] tokenized_sentences = [self.encode(sentence) for sentence in sentences] batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs) - return [self.decode(hypos[0]['tokens']) for hypos in batched_hypos] + return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos] def score(self, sentences: List[str], **kwargs): if isinstance(sentences, str): return self.score([sentences], **kwargs)[0] # NOTE: this doesn't support translation tasks currently tokenized_sentences = [self.encode(sentence) for sentence in sentences] - return [hypos[0] for hypos in self.generate(tokenized_sentences, score_reference=True, **kwargs)] + return [ + hypos[0] + for hypos in self.generate( + tokenized_sentences, score_reference=True, **kwargs + ) + ] def generate( self, @@ -174,17 +182,33 @@ def getarg(name, default): for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): src_str_with_unk = self.string(source_tokens) - logger.info('S\t{}'.format(src_str_with_unk)) + logger.info("S\t{}".format(src_str_with_unk)) for hypo in target_hypotheses: - hypo_str = self.decode(hypo['tokens']) - logger.info('H\t{}\t{}'.format(hypo['score'], hypo_str)) - logger.info('P\t{}'.format( - ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist())) - )) - if hypo['alignment'] is not None and getarg('print_alignment', False): - logger.info('A\t{}'.format( - ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in hypo['alignment']]) - )) + hypo_str = self.decode(hypo["tokens"]) + logger.info("H\t{}\t{}".format(hypo["score"], hypo_str)) + logger.info( + "P\t{}".format( + " ".join( + map( + lambda x: 
"{:.4f}".format(x), + hypo["positional_scores"].tolist(), + ) + ) + ) + ) + if hypo["alignment"] is not None and getarg( + "print_alignment", False + ): + logger.info( + "A\t{}".format( + " ".join( + [ + "{}-{}".format(src_idx, tgt_idx) + for src_idx, tgt_idx in hypo["alignment"] + ] + ) + ) + ) return outputs def encode(self, sentence: str) -> torch.LongTensor: diff --git a/fairseq/incremental_decoding_utils.py b/fairseq/incremental_decoding_utils.py index 91128e8879..b26e6cd01c 100644 --- a/fairseq/incremental_decoding_utils.py +++ b/fairseq/incremental_decoding_utils.py @@ -3,14 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, Optional import uuid +from typing import Dict, Optional from torch import Tensor class FairseqIncrementalState(object): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.init_incremental_state() @@ -46,5 +45,7 @@ def set_incremental_state( def with_incremental_state(cls): - cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState) + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) return cls diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py index 6ac805988a..4fb0946f49 100644 --- a/fairseq/iterative_refinement_generator.py +++ b/fairseq/iterative_refinement_generator.py @@ -5,20 +5,15 @@ from collections import namedtuple -import torch import numpy as np - +import torch from fairseq import utils -DecoderOut = namedtuple('IterativeRefinementDecoderOut', [ - 'output_tokens', - 'output_scores', - 'attn', - 'step', - 'max_step', - 'history' -]) +DecoderOut = namedtuple( + "IterativeRefinementDecoderOut", + ["output_tokens", "output_scores", "attn", "step", "max_step", "history"], +) class IterativeRefinementGenerator(object): @@ -103,11 +98,12 
@@ def generate_batched_itr( ref = utils.strip_pad(sample["target"][i, :], self.pad) yield id, src, ref, hypos[i] - @torch.no_grad() def generate(self, models, sample, prefix_tokens=None, constraints=None): if constraints is not None: - raise NotImplementedError("Constrained decoding with the IterativeRefinementGenerator is not supported") + raise NotImplementedError( + "Constrained decoding with the IterativeRefinementGenerator is not supported" + ) # TODO: iterative refinement generator does not support ensemble for now. if not self.retain_dropout: @@ -117,13 +113,17 @@ def generate(self, models, sample, prefix_tokens=None, constraints=None): model, reranker = models[0], None if self.reranking: assert len(models) > 1, "Assuming the last checkpoint is the reranker" - assert self.beam_size > 1, "Reranking requires multiple translation for each example" + assert ( + self.beam_size > 1 + ), "Reranking requires multiple translation for each example" reranker = models[-1] models = models[:-1] - if len(models) > 1 and hasattr(model, 'enable_ensemble'): - assert model.allow_ensemble, "{} does not support ensembling".format(model.__class__.__name__) + if len(models) > 1 and hasattr(model, "enable_ensemble"): + assert model.allow_ensemble, "{} does not support ensembling".format( + model.__class__.__name__ + ) model.enable_ensemble(models) # TODO: better encoder inputs? 
@@ -136,13 +136,22 @@ def generate(self, models, sample, prefix_tokens=None, constraints=None):
         prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
 
         if self.beam_size > 1:
-            assert model.allow_length_beam, \
-                "{} does not support decoding with length beam.".format(model.__class__.__name__)
+            assert (
+                model.allow_length_beam
+            ), "{} does not support decoding with length beam.".format(
+                model.__class__.__name__
+            )
 
             # regenerate data based on length-beam
-            length_beam_order = utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
-            encoder_out = model.encoder.reorder_encoder_out(encoder_out, length_beam_order)
-            prev_decoder_out = model.regenerate_length_beam(prev_decoder_out, self.beam_size)
+            length_beam_order = (
+                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
+            )
+            encoder_out = model.encoder.reorder_encoder_out(
+                encoder_out, length_beam_order
+            )
+            prev_decoder_out = model.regenerate_length_beam(
+                prev_decoder_out, self.beam_size
+            )
             bsz = bsz * self.beam_size
 
         sent_idxs = torch.arange(bsz)
@@ -206,7 +215,10 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
             if self.adaptive:
                 # terminate if there is a loop
                 terminated, out_tokens, out_scores, out_attn = is_a_loop(
-                    prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
+                    prev_output_tokens,
+                    decoder_out.output_tokens,
+                    decoder_out.output_scores,
+                    decoder_out.attn,
                 )
                 decoder_out = decoder_out._replace(
                     output_tokens=out_tokens,
@@ -215,7 +227,9 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
                 )
 
             else:
-                terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()
+                terminated = decoder_out.output_tokens.new_zeros(
+                    decoder_out.output_tokens.size(0)
+                ).bool()
 
             if step == self.max_iter:  # reach last iteration, terminate
                 terminated.fill_(1)
@@ -225,7 +239,9 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
finalized_tokens = decoder_out.output_tokens[terminated] finalized_scores = decoder_out.output_scores[terminated] finalized_attn = ( - None if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) else decoder_out.attn[terminated] + None + if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) + else decoder_out.attn[terminated] ) if self.retain_history: @@ -242,13 +258,11 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): ] if self.retain_history: - finalized[finalized_idxs[i]][0]['history'] = [] + finalized[finalized_idxs[i]][0]["history"] = [] for j in range(len(finalized_history_tokens)): - finalized[finalized_idxs[i]][0]['history'].append( + finalized[finalized_idxs[i]][0]["history"].append( finalized_hypos( - step, - finalized_history_tokens[j][i], - None, None + step, finalized_history_tokens[j][i], None, None ) ) @@ -268,7 +282,9 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): if decoder_out.history is not None else None, ) - encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()) + encoder_out = model.encoder.reorder_encoder_out( + encoder_out, not_terminated.nonzero(as_tuple=False).squeeze() + ) sent_idxs = sent_idxs[not_terminated] prev_output_tokens = prev_decoder_out.output_tokens.clone() @@ -280,38 +296,64 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): # aggregate information from length beam finalized = [ - finalized[np.argmax( - [finalized[self.beam_size * i + j][0]['score'] for j in range(self.beam_size)] - ) + self.beam_size * i] for i in range(len(finalized) // self.beam_size) + finalized[ + np.argmax( + [ + finalized[self.beam_size * i + j][0]["score"] + for j in range(self.beam_size) + ] + ) + + self.beam_size * i ] + for i in range(len(finalized) // self.beam_size) + ] return finalized def rerank(self, reranker, finalized, encoder_input, beam_size): - def rebuild_batch(finalized): - 
finalized_tokens = [f[0]['tokens'] for f in finalized] + finalized_tokens = [f[0]["tokens"] for f in finalized] finalized_maxlen = max(f.size(0) for f in finalized_tokens) - final_output_tokens = finalized_tokens[0].new_zeros(len(finalized_tokens), finalized_maxlen).fill_(self.pad) + final_output_tokens = ( + finalized_tokens[0] + .new_zeros(len(finalized_tokens), finalized_maxlen) + .fill_(self.pad) + ) for i, f in enumerate(finalized_tokens): - final_output_tokens[i, :f.size(0)] = f + final_output_tokens[i, : f.size(0)] = f return final_output_tokens final_output_tokens = rebuild_batch(finalized) - final_output_tokens[:, 0] = self.eos # autoregressive model assumes starting with EOS + final_output_tokens[ + :, 0 + ] = self.eos # autoregressive model assumes starting with EOS reranker_encoder_out = reranker.encoder(*encoder_input) - length_beam_order = utils.new_arange( - final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)).t().reshape(-1) - reranker_encoder_out = reranker.encoder.reorder_encoder_out(reranker_encoder_out, length_beam_order) + length_beam_order = ( + utils.new_arange( + final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1) + ) + .t() + .reshape(-1) + ) + reranker_encoder_out = reranker.encoder.reorder_encoder_out( + reranker_encoder_out, length_beam_order + ) reranking_scores = reranker.get_normalized_probs( - reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out), True, None) + reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out), + True, + None, + ) reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None]) reranking_masks = final_output_tokens[:, 1:].ne(self.pad) - reranking_scores = reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1) - reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(reranking_scores) + reranking_scores = ( + reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1) + ) + reranking_scores = 
reranking_scores / reranking_masks.sum(1).type_as( + reranking_scores + ) for i in range(len(finalized)): - finalized[i][0]['score'] = reranking_scores[i] + finalized[i][0]["score"] = reranking_scores[i] return finalized diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/legacy_distributed_data_parallel.py index 9832f2c97a..44f87c7c42 100644 --- a/fairseq/legacy_distributed_data_parallel.py +++ b/fairseq/legacy_distributed_data_parallel.py @@ -14,9 +14,9 @@ training with `--update-freq`. """ +import copy from collections import OrderedDict from contextlib import contextmanager -import copy import torch from torch import nn @@ -42,7 +42,7 @@ class LegacyDistributedDataParallel(nn.Module): performing all-reduce (default: 256M). """ - def __init__(self, module, world_size, process_group=None, buffer_size=2**28): + def __init__(self, module, world_size, process_group=None, buffer_size=2 ** 28): super().__init__() self.module = module @@ -66,7 +66,6 @@ def __init__(self, module, world_size, process_group=None, buffer_size=2**28): paramlists[device] += [param] self.per_device_params = list(paramlists.values()) - def __getstate__(self): attrs = copy.copy(self.__dict__) return attrs @@ -99,10 +98,10 @@ def all_reduce_params(params): for p in params: sz = p.numel() if p.grad is not None: - buffer[offset:offset+sz].copy_(p.grad.data.view(-1)) + buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) nonzero_buffer = True else: - buffer[offset:offset+sz].zero_() + buffer[offset : offset + sz].zero_() offset += sz else: # we only have a single grad to all-reduce @@ -111,7 +110,7 @@ def all_reduce_params(params): buffer = p.grad.data nonzero_buffer = True elif p.numel() <= self.buffer.numel(): - buffer = buffer[:p.numel()] + buffer = buffer[: p.numel()] buffer.zero_() else: buffer = torch.zeros_like(p) @@ -126,9 +125,9 @@ def all_reduce_params(params): for p in params: sz = p.numel() if p.grad is not None: - 
p.grad.data.copy_(buffer[offset:offset+sz].view_as(p)) + p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) else: - p.grad = buffer[offset:offset+sz].view_as(p).clone() + p.grad = buffer[offset : offset + sz].view_as(p).clone() offset += sz def reduction_fn(): @@ -149,9 +148,11 @@ def reduction_fn(): if param.grad is None: param.grad = torch.zeros_like(param) if param.grad.requires_grad: - raise RuntimeError("DistributedDataParallel only works " - "with gradients that don't require " - "grad") + raise RuntimeError( + "DistributedDataParallel only works " + "with gradients that don't require " + "grad" + ) sz = param.numel() if sz > self.buffer.numel(): # all-reduce big params directly diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py index 78e6d4d224..6793ef54e6 100644 --- a/fairseq/logging/meters.py +++ b/fairseq/logging/meters.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. import bisect -from collections import OrderedDict import time +from collections import OrderedDict from typing import Dict, Optional + try: import torch @@ -16,6 +17,8 @@ def type_as(a, b): return a.to(b) else: return a + + except ImportError: torch = None @@ -51,11 +54,11 @@ def smoothed_value(self) -> float: def safe_round(number, ndigits): - if hasattr(number, '__round__'): + if hasattr(number, "__round__"): return round(number, ndigits) elif torch is not None and torch.is_tensor(number) and number.numel() == 1: return safe_round(number.item(), ndigits) - elif np is not None and np.ndim(number) == 0 and hasattr(number, 'item'): + elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): return safe_round(number.item(), ndigits) else: return number @@ -82,17 +85,17 @@ def update(self, val, n=1): def state_dict(self): return { - 'val': self.val, - 'sum': self.sum, - 'count': self.count, - 'round': self.round, + "val": self.val, + "sum": self.sum, + "count": self.count, + "round": self.round, } def load_state_dict(self, 
state_dict): - self.val = state_dict['val'] - self.sum = state_dict['sum'] - self.count = state_dict['count'] - self.round = state_dict.get('round', None) + self.val = state_dict["val"] + self.sum = state_dict["sum"] + self.count = state_dict["count"] + self.round = state_dict.get("round", None) @property def avg(self): @@ -130,18 +133,18 @@ def update(self, val=1): def state_dict(self): return { - 'init': self.elapsed_time, - 'n': self.n, - 'round': self.round, + "init": self.elapsed_time, + "n": self.n, + "round": self.round, } def load_state_dict(self, state_dict): - if 'start' in state_dict: + if "start" in state_dict: # backwards compatibility for old state_dicts - self.reset(init=state_dict['init']) + self.reset(init=state_dict["init"]) else: - self.reset(init=state_dict['init'], n=state_dict['n']) - self.round = state_dict.get('round', None) + self.reset(init=state_dict["init"], n=state_dict["n"]) + self.round = state_dict.get("round", None) @property def avg(self): @@ -186,16 +189,16 @@ def reset(self): def state_dict(self): return { - 'sum': self.sum, - 'n': self.n, - 'round': self.round, + "sum": self.sum, + "n": self.n, + "round": self.round, } def load_state_dict(self, state_dict): - self.sum = state_dict['sum'] - self.n = state_dict['n'] + self.sum = state_dict["sum"] + self.n = state_dict["n"] self.start_time = None - self.round = state_dict.get('round', None) + self.round = state_dict.get("round", None) @property def avg(self): @@ -204,7 +207,7 @@ def avg(self): @property def elapsed_time(self): if self.start_time is None: - return 0. 
+ return 0.0 return time.perf_counter() - self.start_time @property @@ -263,11 +266,13 @@ def get_smoothed_value(self, key: str) -> float: def get_smoothed_values(self) -> Dict[str, float]: """Get all smoothed values.""" - return OrderedDict([ - (key, self.get_smoothed_value(key)) - for key in self.keys() - if not key.startswith("_") - ]) + return OrderedDict( + [ + (key, self.get_smoothed_value(key)) + for key in self.keys() + if not key.startswith("_") + ] + ) def reset(self): """Reset Meter instances.""" diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 6ca1d201e0..7b56e31592 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -11,11 +11,11 @@ :func:`aggregate` context manager for more details. """ -from collections import defaultdict, OrderedDict import contextlib import time -from typing import Callable, Dict, List, Optional import uuid +from collections import OrderedDict, defaultdict +from typing import Callable, Dict, List, Optional from .meters import * @@ -184,7 +184,7 @@ def log_start_time(key: str, priority: int = 40, round: Optional[int] = None): agg[key].start() -def log_stop_time(key: str, weight: float = 0., prehook=None): +def log_stop_time(key: str, weight: float = 0.0, prehook=None): """Log the duration of some event in seconds. The duration will be computed since :func:`log_start_time` was called. 
@@ -279,10 +279,7 @@ def get_smoothed_values(name: str) -> Dict[str, float]: def state_dict(): - return OrderedDict([ - (name, agg.state_dict()) - for name, agg in _aggregators.items() - ]) + return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()]) def load_state_dict(state_dict): diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 97e4162ea0..63e5394815 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -32,29 +32,30 @@ def progress_bar( epoch: Optional[int] = None, prefix: Optional[str] = None, tensorboard_logdir: Optional[str] = None, - default_log_format: str = 'tqdm', + default_log_format: str = "tqdm", ): if log_format is None: log_format = default_log_format - if log_format == 'tqdm' and not sys.stderr.isatty(): - log_format = 'simple' + if log_format == "tqdm" and not sys.stderr.isatty(): + log_format = "simple" - if log_format == 'json': + if log_format == "json": bar = JsonProgressBar(iterator, epoch, prefix, log_interval) - elif log_format == 'none': + elif log_format == "none": bar = NoopProgressBar(iterator, epoch, prefix) - elif log_format == 'simple': + elif log_format == "simple": bar = SimpleProgressBar(iterator, epoch, prefix, log_interval) - elif log_format == 'tqdm': + elif log_format == "tqdm": bar = TqdmProgressBar(iterator, epoch, prefix) else: - raise ValueError('Unknown log format: {}'.format(log_format)) + raise ValueError("Unknown log format: {}".format(log_format)) if tensorboard_logdir: try: # [FB only] custom wrapper for TensorBoard import palaas # noqa from .fb_tbmf_wrapper import FbTbmfWrapper + bar = FbTbmfWrapper(bar, log_interval) except ImportError: bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) @@ -67,14 +68,14 @@ def build_progress_bar( iterator, epoch: Optional[int] = None, prefix: Optional[str] = None, - default: str = 'tqdm', - no_progress_bar: str = 'none', + default: str = "tqdm", + no_progress_bar: str = "none", 
): """Legacy wrapper that takes an argparse.Namespace.""" - if getattr(args, 'no_progress_bar', False): + if getattr(args, "no_progress_bar", False): default = no_progress_bar - if getattr(args, 'distributed_rank', 0) == 0: - tensorboard_logdir = getattr(args, 'tensorboard_logdir', None) + if getattr(args, "distributed_rank", 0) == 0: + tensorboard_logdir = getattr(args, "tensorboard_logdir", None) else: tensorboard_logdir = None return progress_bar( @@ -90,13 +91,13 @@ def build_progress_bar( def format_stat(stat): if isinstance(stat, Number): - stat = '{:g}'.format(stat) + stat = "{:g}".format(stat) elif isinstance(stat, AverageMeter): - stat = '{:.3f}'.format(stat.avg) + stat = "{:.3f}".format(stat.avg) elif isinstance(stat, TimeMeter): - stat = '{:g}'.format(round(stat.avg)) + stat = "{:g}".format(round(stat.avg)) elif isinstance(stat, StopwatchMeter): - stat = '{:g}'.format(round(stat.sum)) + stat = "{:g}".format(round(stat.sum)) elif torch.is_tensor(stat): stat = stat.tolist() return stat @@ -104,15 +105,16 @@ def format_stat(stat): class BaseProgressBar(object): """Abstract class for progress bars.""" + def __init__(self, iterable, epoch=None, prefix=None): self.iterable = iterable - self.n = getattr(iterable, 'n', 0) + self.n = getattr(iterable, "n", 0) self.epoch = epoch - self.prefix = '' + self.prefix = "" if epoch is not None: - self.prefix += 'epoch {:03d}'.format(epoch) + self.prefix += "epoch {:03d}".format(epoch) if prefix is not None: - self.prefix += ' | {}'.format(prefix) + self.prefix += " | {}".format(prefix) def __len__(self): return len(self.iterable) @@ -135,12 +137,10 @@ def print(self, stats, tag=None, step=None): raise NotImplementedError def _str_commas(self, stats): - return ', '.join(key + '=' + stats[key].strip() - for key in stats.keys()) + return ", ".join(key + "=" + stats[key].strip() for key in stats.keys()) def _str_pipes(self, stats): - return ' | '.join(key + ' ' + stats[key].strip() - for key in stats.keys()) + return " | 
".join(key + " " + stats[key].strip() for key in stats.keys()) def _format_stats(self, stats): postfix = OrderedDict(stats) @@ -177,11 +177,7 @@ def __iter__(self): def log(self, stats, tag=None, step=None): """Log intermediate stats according to log_interval.""" step = step or self.i or 0 - if ( - step > 0 - and self.log_interval is not None - and step % self.log_interval == 0 - ): + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: update = ( self.epoch - 1 + (self.i + 1) / float(self.size) if self.epoch is not None @@ -195,7 +191,9 @@ def print(self, stats, tag=None, step=None): """Print end-of-epoch stats.""" self.stats = stats if tag is not None: - self.stats = OrderedDict([(tag + '_' + k, v) for k, v in self.stats.items()]) + self.stats = OrderedDict( + [(tag + "_" + k, v) for k, v in self.stats.items()] + ) stats = self._format_stats(self.stats, epoch=self.epoch) with rename_logger(logger, tag): logger.info(json.dumps(stats)) @@ -203,9 +201,9 @@ def print(self, stats, tag=None, step=None): def _format_stats(self, stats, epoch=None, update=None): postfix = OrderedDict() if epoch is not None: - postfix['epoch'] = epoch + postfix["epoch"] = epoch if update is not None: - postfix['update'] = round(update, 3) + postfix["update"] = round(update, 3) # Preprocess stats according to datatype for key in stats.keys(): postfix[key] = format_stat(stats[key]) @@ -249,24 +247,21 @@ def __iter__(self): def log(self, stats, tag=None, step=None): """Log intermediate stats according to log_interval.""" step = step or self.i or 0 - if ( - step > 0 - and self.log_interval is not None - and step % self.log_interval == 0 - ): + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: stats = self._format_stats(stats) postfix = self._str_commas(stats) with rename_logger(logger, tag): logger.info( - '{}: {:5d} / {:d} {}' - .format(self.prefix, self.i + 1, self.size, postfix) + "{}: {:5d} / {:d} {}".format( + self.prefix, self.i 
+ 1, self.size, postfix + ) ) def print(self, stats, tag=None, step=None): """Print end-of-epoch stats.""" postfix = self._str_pipes(self._format_stats(stats)) with rename_logger(logger, tag): - logger.info('{} | {}'.format(self.prefix, postfix)) + logger.info("{} | {}".format(self.prefix, postfix)) class TqdmProgressBar(BaseProgressBar): @@ -275,6 +270,7 @@ class TqdmProgressBar(BaseProgressBar): def __init__(self, iterable, epoch=None, prefix=None): super().__init__(iterable, epoch, prefix) from tqdm import tqdm + self.tqdm = tqdm( iterable, self.prefix, @@ -293,7 +289,7 @@ def print(self, stats, tag=None, step=None): """Print end-of-epoch stats.""" postfix = self._str_pipes(self._format_stats(stats)) with rename_logger(logger, tag): - logger.info('{} | {}'.format(self.prefix, postfix)) + logger.info("{} | {}".format(self.prefix, postfix)) try: @@ -329,7 +325,7 @@ def _writer(self, key): _writers = _tensorboard_writers if key not in _writers: _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) - _writers[key].add_text('sys.argv', " ".join(sys.argv)) + _writers[key].add_text("sys.argv", " ".join(sys.argv)) return _writers[key] def __iter__(self): @@ -346,12 +342,12 @@ def print(self, stats, tag=None, step=None): self.wrapped_bar.print(stats, tag=tag, step=step) def _log_to_tensorboard(self, stats, tag=None, step=None): - writer = self._writer(tag or '') + writer = self._writer(tag or "") if writer is None: return if step is None: - step = stats['num_updates'] - for key in stats.keys() - {'num_updates'}: + step = stats["num_updates"] + for key in stats.keys() - {"num_updates"}: if isinstance(stats[key], AverageMeter): writer.add_scalar(key, stats[key].val, step) elif isinstance(stats[key], Number): diff --git a/fairseq/model_parallel/__init__.py b/fairseq/model_parallel/__init__.py index cc563db40b..69f2168487 100644 --- a/fairseq/model_parallel/__init__.py +++ b/fairseq/model_parallel/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed 
under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import criterions, modules, models # noqa +from . import criterions, models, modules # noqa diff --git a/fairseq/model_parallel/criterions/__init__.py b/fairseq/model_parallel/criterions/__init__.py index b74de55982..6239b50362 100644 --- a/fairseq/model_parallel/criterions/__init__.py +++ b/fairseq/model_parallel/criterions/__init__.py @@ -9,6 +9,6 @@ # automatically import any Python files in the criterions/ directory for file in os.listdir(os.path.dirname(__file__)): - if file.endswith('.py') and not file.startswith('_'): - module = file[:file.find('.py')] - importlib.import_module('fairseq.model_parallel.criterions.' + module) + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.model_parallel.criterions." + module) diff --git a/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py index eab8f9af4e..35c50ee152 100644 --- a/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py +++ b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py @@ -8,24 +8,27 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion + try: - from fairseq.model_parallel.megatron.mpu.cross_entropy import vocab_parallel_cross_entropy + from fairseq.model_parallel.megatron.mpu.cross_entropy import ( + vocab_parallel_cross_entropy, + ) + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False -@register_criterion('vocab_parallel_cross_entropy') +@register_criterion("vocab_parallel_cross_entropy") class VocabParallelCrossEntropyCriterion(FairseqCriterion): - def __init__(self, task, sentence_avg): super().__init__(task) self.sentence_avg = sentence_avg if not has_megatron_submodule: raise ImportError( - '\n\nPlease install 
the megatron submodule:' - '\n\n git submodule update --init ' - 'fairseq/model_parallel/megatron' + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" ) def forward(self, model, sample, reduce=True): @@ -36,33 +39,43 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample['net_input']) - target = sample['target'] + net_output = model(**sample["net_input"]) + target = sample["target"] loss = vocab_parallel_cross_entropy(net_output[0].float(), target) loss = (loss * (target != self.padding_idx)).sum() - sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens'] + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) logging_output = { - 'loss': utils.item(loss.data) if reduce else loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample['target'].size(0), - 'sample_size': sample_size, + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, } return loss, sample_size, logging_output @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" - loss_sum = sum(log.get('loss', 0) for log in logging_outputs) - ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) - sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) - metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) if 
sample_size != ntokens: - metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3) - metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg)) + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) else: - metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg)) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index d1142a993c..761ffc8e61 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -10,6 +10,7 @@ from fairseq import distributed_utils from fairseq.trainer import Trainer + try: from fairseq.model_parallel.megatron.mpu import ( get_data_parallel_group, @@ -18,20 +19,21 @@ get_model_parallel_group, get_model_parallel_src_rank, ) + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False class MegatronTrainer(Trainer): - """Main class for model parallel with data parallel training. 
- """ + """Main class for model parallel with data parallel training.""" + def __init__(self, args, task, model, criterion): if not has_megatron_submodule: raise ImportError( - '\n\nPlease install the megatron submodule:' - '\n\n git submodule update --init ' - 'fairseq/model_parallel/megatron' + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" ) super().__init__(args, task, model, criterion) @@ -57,6 +59,7 @@ def _aggregate_model_parallel_grad_norm(total_norm): distributed_utils.all_reduce(total_norm, group=get_model_parallel_group()) total_norm = total_norm ** 0.5 return total_norm + return self.optimizer.clip_grad_norm( clip_norm, aggregate_norm_fn=_aggregate_model_parallel_grad_norm, diff --git a/fairseq/model_parallel/models/__init__.py b/fairseq/model_parallel/models/__init__.py index a3207981ad..3532479e52 100644 --- a/fairseq/model_parallel/models/__init__.py +++ b/fairseq/model_parallel/models/__init__.py @@ -11,6 +11,10 @@ models_dir = os.path.dirname(__file__) for file in os.listdir(models_dir): path = os.path.join(models_dir, file) - if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): - model_name = file[:file.find('.py')] if file.endswith('.py') else file - module = importlib.import_module('fairseq.model_parallel.models.' + model_name) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.model_parallel.models." 
+ model_name) diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py index e11f491486..eb81ded341 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py @@ -3,31 +3,35 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import namedtuple import math +from collections import namedtuple import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import options, utils from fairseq.modules import ( AdaptiveSoftmax, LayerNorm, - PositionalEmbedding, MultiheadAttention, + PositionalEmbedding, ) -EncoderOut = namedtuple('TransformerEncoderOut', [ - 'encoder_out', # T x B x C - 'encoder_padding_mask', # B x T - 'encoder_embedding', # B x T x C - 'encoder_states', # List[T x B x C] -]) + +EncoderOut = namedtuple( + "TransformerEncoderOut", + [ + "encoder_out", # T x B x C + "encoder_padding_mask", # B x T + "encoder_embedding", # B x T x C + "encoder_states", # List[T x B x C] + ], +) class TransformerEncoderEmbedding(nn.Module): """ Encoder Embedding + Positional Embedding """ + def __init__(self, args, embed_tokens): super().__init__() self.dropout = args.dropout @@ -40,11 +44,17 @@ def __init__(self, args, embed_tokens): self.padding_idx = embed_tokens.padding_idx embed_dim = embed_tokens.embedding_dim self.embed_scale = math.sqrt(embed_dim) - self.embed_positions = PositionalEmbedding( - args.max_source_positions, embed_dim, self.padding_idx, - learned=args.encoder_learned_pos, - ) if not args.no_token_positional_embeddings else None - if getattr(args, 'layernorm_embedding', False): + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not 
args.no_token_positional_embeddings + else None + ) + if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None @@ -77,9 +87,10 @@ def forward(self, input): class TransformerEncoderLayerNorm(nn.Module): """ - Layer norm at the the end of all encoder layers if - args.encoder_enormalize_before = True + Layer norm at the the end of all encoder layers if + args.encoder_enormalize_before = True """ + def __init__(self, args, embed_dim): super().__init__() if args.encoder_normalize_before: @@ -99,30 +110,45 @@ def forward(self, input): class TransformerDecoderEmbedding(nn.Module): """ Decoder Embedding + Positional Embedding """ + def __init__(self, args, embed_tokens): super().__init__() self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed - input_embed_dim = sum(e.embedding_dim for e in embed_tokens) \ - if isinstance(embed_tokens, nn.ModuleList) \ + input_embed_dim = ( + sum(e.embedding_dim for e in embed_tokens) + if isinstance(embed_tokens, nn.ModuleList) else embed_tokens.embedding_dim + ) embed_dim = args.decoder_embed_dim self.output_embed_dim = args.decoder_output_dim - padding_idx = embed_tokens[0].padding_idx \ - if isinstance(embed_tokens, nn.ModuleList) \ + padding_idx = ( + embed_tokens[0].padding_idx + if isinstance(embed_tokens, nn.ModuleList) else embed_tokens.padding_idx + ) self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim - self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) - self.embed_positions = PositionalEmbedding( - args.max_target_positions, embed_dim, padding_idx, - learned=args.decoder_learned_pos, - ) if not 
args.no_token_positional_embeddings else None + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) def forward(self, input): mt_task = False @@ -147,10 +173,14 @@ def forward(self, input): encoder_padding_mask = None incremental_state = None - positions = self.embed_positions( - prev_output_tokens, - incremental_state=incremental_state, - ) if self.embed_positions is not None else None + positions = ( + self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) + if self.embed_positions is not None + else None + ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] @@ -190,8 +220,11 @@ def __init__(self, args, embed_tokens, dictionary): self.output_embed_dim = args.decoder_output_dim embed_dim = args.decoder_embed_dim - self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \ - if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + else None + ) self.adaptive_softmax = None if args.adaptive_softmax_cutoff is not None: assert not isinstance(embed_tokens, nn.ModuleList) @@ -205,10 +238,16 @@ def __init__(self, args, embed_tokens, dictionary): tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: - self.embed_tokens = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim)) - nn.init.normal_(self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5) + self.embed_tokens = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_( + self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5 + ) - if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False): + if 
args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None @@ -245,7 +284,7 @@ def output_layer(self, features, **kwargs): output = F.linear(features[:, :, sidx:eidx], emb.weight) else: output += F.linear(features[:, :, sidx:eidx], emb.weight) - + return output else: return F.linear(features, self.embed_tokens.weight) @@ -273,18 +312,20 @@ def __init__(self, args): super().__init__() self.embed_dim = args.encoder_embed_dim self.self_attn = MultiheadAttention( - self.embed_dim, args.encoder_attention_heads, - dropout=args.attention_dropout, self_attention=True + self.embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, ) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( - activation=getattr(args, 'activation_fn', 'relu') + activation=getattr(args, "activation_fn", "relu") ) - self.activation_dropout = getattr(args, 'activation_dropout', 0) + self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout - self.activation_dropout = getattr(args, 'relu_dropout', 0) + self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.encoder_normalize_before self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) @@ -296,17 +337,12 @@ def upgrade_state_dict_named(self, state_dict, name): `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to `...final_layer_norm.weight` """ - layer_norm_map = { - '0': 'self_attn_layer_norm', - '1': 'final_layer_norm' - } + layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} for old, new in layer_norm_map.items(): - for m in ('weight', 'bias'): - k = '{}.layer_norms.{}.{}'.format(name, old, m) + 
for m in ("weight", "bias"): + k = "{}.layer_norms.{}.{}".format(name, old, m) if k in state_dict: - state_dict[ - '{}.{}.{}'.format(name, new, m) - ] = state_dict[k] + state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] del state_dict[k] def forward(self, input): @@ -330,7 +366,9 @@ def forward(self, input): prev_output_tokens = input[2] residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) - x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask) + x, _ = self.self_attn( + query=x, key=x, value=x, key_padding_mask=encoder_padding_mask + ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) @@ -370,7 +408,9 @@ class TransformerDecoderLayer(nn.Module): (default: False). """ - def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): super().__init__() self.embed_dim = args.decoder_embed_dim self.self_attn = MultiheadAttention( @@ -379,22 +419,22 @@ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, - self_attention=True + self_attention=True, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( - activation=getattr(args, 'activation_fn', 'relu') + activation=getattr(args, "activation_fn", "relu") ) - self.activation_dropout = getattr(args, 'activation_dropout', 0) + self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout - self.activation_dropout = getattr(args, 'relu_dropout', 0) + self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for 
exporting. # char_inputs can be used to determint this. # TODO remove this once we update apex with the fix - export = getattr(args, 'char_inputs', False) + export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: @@ -404,8 +444,8 @@ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, - kdim=getattr(args, 'encoder_embed_dim', None), - vdim=getattr(args, 'encoder_embed_dim', None), + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) @@ -520,10 +560,18 @@ def forward(self, input): def buffered_future_mask(self, tensor): dim = tensor.size(0) - if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: - self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) if self._future_mask.size(0) < dim: - self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) return self._future_mask[:dim, :dim] def maybe_layer_norm(self, layer_norm, x, before=False, after=False): @@ -548,5 +596,5 @@ def Linear(in_features, out_features, bias=True): m = nn.Linear(in_features, out_features, bias) nn.init.xavier_uniform_(m.weight) if bias: - nn.init.constant_(m.bias, 0.) 
+ nn.init.constant_(m.bias, 0.0) return m diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index 65a087a3fb..cbfc6ae4a0 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -5,7 +5,19 @@ import logging +import torch +import torch.nn as nn +import torch.nn.functional as F from fairseq import utils +from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import ( + Embedding, + TransformerDecoderEmbedding, + TransformerDecoderLayer, + TransformerDecoderOutputLayer, + TransformerEncoderEmbedding, + TransformerEncoderLayer, + TransformerEncoderLayerNorm, +) from fairseq.models import ( BaseFairseqModel, FairseqDecoder, @@ -13,25 +25,14 @@ register_model, register_model_architecture, ) +from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( base_architecture, transformer_iwslt_de_en, transformer_wmt_en_de_big, ) from fairseq.modules import SinusoidalPositionalEmbedding -from fairseq.models.fairseq_encoder import EncoderOut -from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import ( - Embedding, - TransformerEncoderLayer, - TransformerDecoderLayer, - TransformerEncoderEmbedding, - TransformerEncoderLayerNorm, - TransformerDecoderEmbedding, - TransformerDecoderOutputLayer, -) -import torch -import torch.nn as nn -import torch.nn.functional as F + logger = logging.getLogger(__name__) @@ -40,24 +41,27 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 -@register_model('pipeline_parallel_transformer') +@register_model("pipeline_parallel_transformer") class PipelineParallelTransformerModel(BaseFairseqModel): def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): try: from fairscale.nn import Pipe except ImportError: - raise ImportError('Please install fairscale with: pip install 
fairscale') + raise ImportError("Please install fairscale with: pip install fairscale") super().__init__() assert isinstance(encoder, FairseqEncoder) assert isinstance(decoder, FairseqDecoder) - encoder_module_list = \ - [encoder.embedding_layer] + \ - list(encoder.encoder_layers) + \ - [encoder.final_layer_norm] + encoder_module_list = ( + [encoder.embedding_layer] + + list(encoder.encoder_layers) + + [encoder.final_layer_norm] + ) self.num_encoder_modules = len(encoder_module_list) - decoder_module_list = [decoder.embedding_layer] + \ - list(decoder.decoder_layers) + \ - [decoder.decoder_output_layer] + decoder_module_list = ( + [decoder.embedding_layer] + + list(decoder.decoder_layers) + + [decoder.decoder_output_layer] + ) self.num_decoder_modules = len(decoder_module_list) module_list = encoder_module_list + decoder_module_list self.devices = devices @@ -69,14 +73,12 @@ def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): checkpoint=checkpoint, ) self.encoder_max_positions = self.max_positions_helper( - encoder.embedding_layer, - 'max_source_positions' + encoder.embedding_layer, "max_source_positions" ) self.decoder_max_positions = self.max_positions_helper( - decoder.embedding_layer, - 'max_target_positions' + decoder.embedding_layer, "max_target_positions" ) - self.adaptive_softmax = getattr(decoder, 'adaptive_softmax', None) + self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None) # Note: To be populated during inference self.encoder = None self.decoder = None @@ -87,9 +89,10 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) return self.model(input) else: - assert self.encoder is not None and self.decoder is not None, \ - "encoder and decoder need to be initialized by " + \ - "calling the `prepare_for_inference_()` method" + assert self.encoder is not None and self.decoder is not None, ( + "encoder and decoder need to be initialized 
by " + + "calling the `prepare_for_inference_()` method" + ) encoder_output_tuple = self.encoder(input) return self.decoder(encoder_output_tuple) @@ -109,7 +112,9 @@ def prepare_for_inference_(self, args): module_count += 1 self.model = None self.encoder = TransformerEncoder(args, None, None, encoder_module_list) - self.decoder = TransformerDecoder(args, None, None, decoder_module_list=decoder_module_list) + self.decoder = TransformerDecoder( + args, None, None, decoder_module_list=decoder_module_list + ) @staticmethod def add_args(parser): @@ -178,20 +183,22 @@ def build_model_base(cls, args, task): # make sure all arguments are present in older models base_architecture(args) - if not hasattr(args, 'max_source_positions'): + if not hasattr(args, "max_source_positions"): args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS - if not hasattr(args, 'max_target_positions'): + if not hasattr(args, "max_target_positions"): args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS src_dict, tgt_dict = task.source_dictionary, task.target_dictionary def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1): - assert embed_dim % num_embed_chunks == 0, \ - f"Number of embedding chunks = {num_embed_chunks} should be " + \ - f"divisible by the embedding dimension = {embed_dim}" - assert path is None or num_embed_chunks == 1, \ - "Loading embedding from a path with number of embedding chunks > 1" + \ - " is not yet supported" + assert embed_dim % num_embed_chunks == 0, ( + f"Number of embedding chunks = {num_embed_chunks} should be " + + f"divisible by the embedding dimension = {embed_dim}" + ) + assert path is None or num_embed_chunks == 1, ( + "Loading embedding from a path with number of embedding chunks > 1" + + " is not yet supported" + ) num_embeddings = len(dictionary) padding_idx = dictionary.pad() # if provided, load from preloaded dictionaries @@ -205,30 +212,45 @@ def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1): for i in 
range(num_embed_chunks): emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx)) return emb + num_embed_chunks = args.num_embedding_chunks if args.share_all_embeddings: if src_dict != tgt_dict: - raise ValueError('--share-all-embeddings requires a joined dictionary') + raise ValueError("--share-all-embeddings requires a joined dictionary") if args.encoder_embed_dim != args.decoder_embed_dim: raise ValueError( - '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim') + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) if args.decoder_embed_path and ( - args.decoder_embed_path != args.encoder_embed_path): - raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path') + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) encoder_embed_tokens = build_embedding( - src_dict, args.encoder_embed_dim, args.encoder_embed_path, num_embed_chunks, + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, ) decoder_embed_tokens = encoder_embed_tokens args.share_decoder_input_output_embed = True else: - assert args.share_decoder_input_output_embed or num_embed_chunks == 1, \ - "Not sharing decoder I/O embeddings is not yet supported with number of " + \ - "embedding chunks > 1" + assert args.share_decoder_input_output_embed or num_embed_chunks == 1, ( + "Not sharing decoder I/O embeddings is not yet supported with number of " + + "embedding chunks > 1" + ) encoder_embed_tokens = build_embedding( - src_dict, args.encoder_embed_dim, args.encoder_embed_path, num_embed_chunks, + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, ) decoder_embed_tokens = build_embedding( - tgt_dict, args.decoder_embed_dim, args.decoder_embed_path, num_embed_chunks, + tgt_dict, + args.decoder_embed_dim, + args.decoder_embed_path, + num_embed_chunks, ) 
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) @@ -263,21 +285,24 @@ def max_positions(self): """Maximum length supported by the model.""" return (self.encoder_max_positions, self.decoder_max_positions) - def max_positions_helper(self, embedding_layer, - max_positions_field='max_source_positions'): + def max_positions_helper( + self, embedding_layer, max_positions_field="max_source_positions" + ): """Maximum input length supported by the encoder or decoder.""" if embedding_layer.embed_positions is None: return getattr(embedding_layer, max_positions_field) - return min(getattr(embedding_layer, max_positions_field), - embedding_layer.embed_positions.max_positions) + return min( + getattr(embedding_layer, max_positions_field), + embedding_layer.embed_positions.max_positions, + ) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" - if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: if sample is not None: - assert 'target' in sample - target = sample['target'] + assert "target" in sample + target = sample["target"] else: target = None out = self.adaptive_softmax.get_log_prob(net_output, target=target) @@ -303,7 +328,7 @@ def load_state_dict(self, state_dict, strict=True, args=None): this additionally "upgrades" *state_dicts* from old checkpoints. 
""" self.upgrade_state_dict(state_dict) - is_regular_transformer = not any('model.partitions' in k for k in state_dict) + is_regular_transformer = not any("model.partitions" in k for k in state_dict) if is_regular_transformer: state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict) return super().load_state_dict(state_dict, strict) @@ -313,27 +338,50 @@ def convert_to_pipeline_parallel_state_dict(self, state_dict): encoder_layer_idx = 0 decoder_layer_idx = 0 encoder_key_suffixes = [ - 'self_attn.k_proj.weight', 'self_attn.k_proj.bias', - 'self_attn.v_proj.weight', 'self_attn.v_proj.bias', - 'self_attn.q_proj.weight', 'self_attn.q_proj.bias', - 'self_attn.out_proj.weight', 'self_attn.out_proj.bias', - 'self_attn_layer_norm.weight', 'self_attn_layer_norm.bias', 'fc1.weight', - 'fc1.bias', 'fc2.weight', 'fc2.bias', 'final_layer_norm.weight', - 'final_layer_norm.bias', + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", ] decoder_key_suffixes = [ - 'self_attn.k_proj.weight', 'self_attn.k_proj.bias', - 'self_attn.v_proj.weight', 'self_attn.v_proj.bias', - 'self_attn.q_proj.weight', 'self_attn.q_proj.bias', - 'self_attn.out_proj.weight', 'self_attn.out_proj.bias', - 'self_attn_layer_norm.weight', 'self_attn_layer_norm.bias', - 'encoder_attn.k_proj.weight', 'encoder_attn.k_proj.bias', - 'encoder_attn.v_proj.weight', 'encoder_attn.v_proj.bias', - 'encoder_attn.q_proj.weight', 'encoder_attn.q_proj.bias', - 'encoder_attn.out_proj.weight', 'encoder_attn.out_proj.bias', - 'encoder_attn_layer_norm.weight', 'encoder_attn_layer_norm.bias', - 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', - 
'final_layer_norm.weight', 'final_layer_norm.bias' + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "encoder_attn.k_proj.weight", + "encoder_attn.k_proj.bias", + "encoder_attn.v_proj.weight", + "encoder_attn.v_proj.bias", + "encoder_attn.q_proj.weight", + "encoder_attn.q_proj.bias", + "encoder_attn.out_proj.weight", + "encoder_attn.out_proj.bias", + "encoder_attn_layer_norm.weight", + "encoder_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", ] for pid, partition in enumerate(self.model.partitions): logger.info(f"Begin Partition {pid}") @@ -376,29 +424,32 @@ class TransformerEncoder(FairseqEncoder): def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): super().__init__(dictionary) - self.register_buffer('version', torch.Tensor([3])) + self.register_buffer("version", torch.Tensor([3])) try: from fairscale.nn import Pipe except ImportError: - raise ImportError('Please install fairscale with: pip install fairscale') + raise ImportError("Please install fairscale with: pip install fairscale") if encoder_module_list is None: embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) - layers = [ - TransformerEncoderLayer(args) for i in range(args.encoder_layers) - ] + layers = [TransformerEncoderLayer(args) for i in range(args.encoder_layers)] if isinstance(embed_tokens, nn.ModuleList): emb_dim = sum(e.embedding_dim for e in embed_tokens) else: emb_dim = embed_tokens.embedding_dim final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim) encoder_module_list = [embedding_layer] + layers + [final_layer_norm] - self.use_pipeline = (getattr(args, "pipeline_encoder_balance", None) is not None) + self.use_pipeline 
= getattr(args, "pipeline_encoder_balance", None) is not None if self.use_pipeline: - encoder_balance = utils.eval_str_list(args.pipeline_encoder_balance, type=int) - encoder_devices = utils.eval_str_list(args.pipeline_encoder_devices, type=int) - assert sum(encoder_balance) == len(encoder_module_list), \ - f"Sum of encoder_balance={encoder_balance} is not equal " + \ - f"to num_encoder_modules={len(encoder_module_list)}" + encoder_balance = utils.eval_str_list( + args.pipeline_encoder_balance, type=int + ) + encoder_devices = utils.eval_str_list( + args.pipeline_encoder_devices, type=int + ) + assert sum(encoder_balance) == len(encoder_module_list), ( + f"Sum of encoder_balance={encoder_balance} is not equal " + + f"to num_encoder_modules={len(encoder_module_list)}" + ) self.model = Pipe( module=nn.Sequential(*encoder_module_list), balance=encoder_balance, @@ -433,7 +484,9 @@ def forward(self, src_tokens, src_lengths): Only populated if *return_all_hiddens* is True. ) """ - dummy_prev_output_tokens = torch.zeros(1, dtype=src_tokens.dtype, device=src_tokens.device) + dummy_prev_output_tokens = torch.zeros( + 1, dtype=src_tokens.dtype, device=src_tokens.device + ) input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) @@ -465,11 +518,15 @@ def reorder_encoder_out(self, encoder_out, new_order): ) if encoder_out.encoder_padding_mask is not None: encoder_out = encoder_out._replace( - encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order) + encoder_padding_mask=encoder_out.encoder_padding_mask.index_select( + 0, new_order + ) ) if encoder_out.encoder_embedding is not None: encoder_out = encoder_out._replace( - encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order) + encoder_embedding=encoder_out.encoder_embedding.index_select( + 0, new_order + ) ) if encoder_out.encoder_states is not None: for idx, state in 
enumerate(encoder_out.encoder_states): @@ -480,8 +537,10 @@ def max_positions(self): """Maximum input length supported by the encoder.""" if self.embedding_layer.embed_positions is None: return self.embedding_layer.max_source_positions - return min(self.embedding_layer.max_source_positions, - self.embedding_layer.embed_positions.max_positions) + return min( + self.embedding_layer.max_source_positions, + self.embedding_layer.embed_positions.max_positions, + ) class TransformerDecoder(FairseqDecoder): @@ -497,28 +556,42 @@ class TransformerDecoder(FairseqDecoder): (default: False). """ - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, decoder_module_list=None): + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + decoder_module_list=None, + ): super().__init__(dictionary) - self.register_buffer('version', torch.Tensor([3])) + self.register_buffer("version", torch.Tensor([3])) try: from fairscale.nn import Pipe except ImportError: - raise ImportError('Please install fairscale with: pip install fairscale') + raise ImportError("Please install fairscale with: pip install fairscale") if decoder_module_list is None: embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) layers = [ TransformerDecoderLayer(args, no_encoder_attn) for _ in range(args.decoder_layers) ] - decoder_output_layer = TransformerDecoderOutputLayer(args, embed_tokens, dictionary) + decoder_output_layer = TransformerDecoderOutputLayer( + args, embed_tokens, dictionary + ) decoder_module_list = [embedding_layer] + layers + [decoder_output_layer] - self.use_pipeline = (getattr(args, "pipeline_decoder_balance", None) is not None) + self.use_pipeline = getattr(args, "pipeline_decoder_balance", None) is not None if self.use_pipeline: - decoder_balance = utils.eval_str_list(args.pipeline_decoder_balance, type=int) - decoder_devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int) - assert sum(decoder_balance) == 
len(decoder_module_list), \ - f"Sum of decoder_balance={decoder_balance} is not equal " + \ - f"to num_decoder_modules={len(decoder_module_list)}" + decoder_balance = utils.eval_str_list( + args.pipeline_decoder_balance, type=int + ) + decoder_devices = utils.eval_str_list( + args.pipeline_decoder_devices, type=int + ) + assert sum(decoder_balance) == len(decoder_module_list), ( + f"Sum of decoder_balance={decoder_balance} is not equal " + + f"to num_decoder_modules={len(decoder_module_list)}" + ) self.model = Pipe( module=nn.Sequential(*decoder_module_list), balance=decoder_balance, @@ -531,7 +604,11 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, decode self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1]) self.decoder_output_layer = decoder_module_list[-1] - def forward(self, prev_output_tokens, encoder_out=None,): + def forward( + self, + prev_output_tokens, + encoder_out=None, + ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape @@ -548,14 +625,18 @@ def forward(self, prev_output_tokens, encoder_out=None,): - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ - input_tuple = (encoder_out.encoder_out, encoder_out.encoder_padding_mask, prev_output_tokens) + input_tuple = ( + encoder_out.encoder_out, + encoder_out.encoder_padding_mask, + prev_output_tokens, + ) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) - return (self.model(input_tuple), ) + return (self.model(input_tuple),) else: embed_layer_output = self.embedding_layer(input_tuple) state = self.decoder_layers(embed_layer_output) - return (self.decoder_output_layer(state), ) + return (self.decoder_output_layer(state),) def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" @@ -572,43 +653,51 @@ def max_positions(self): """Maximum output length supported by the decoder.""" if 
self.embedding_layer.embed_positions is None: return self.embedding_layer.max_target_positions - return min(self.embedding_layer.max_target_positions, - self.embedding_layer.embed_positions.max_positions) + return min( + self.embedding_layer.max_target_positions, + self.embedding_layer.embed_positions.max_positions, + ) def buffered_future_mask(self, tensor): dim = tensor.size(0) if ( - not hasattr(self, '_future_mask') + not hasattr(self, "_future_mask") or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size(0) < dim ): - self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): - weights_key = '{}.embed_positions.weights'.format(name) + weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: del state_dict[weights_key] - state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) + state_dict[ + "{}.embed_positions._float_tensor".format(name) + ] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms layer_norm_map = { - '0': 'self_attn_layer_norm', - '1': 'encoder_attn_layer_norm', - '2': 'final_layer_norm' + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", } for old, new in layer_norm_map.items(): - for m in ('weight', 'bias'): - k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m) + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) if k in state_dict: - state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k] + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] del 
state_dict[k] - version_key = '{}.version'.format(name) + version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None @@ -618,13 +707,15 @@ def upgrade_state_dict_named(self, state_dict, name): return state_dict -@register_model_architecture('pipeline_parallel_transformer', - 'transformer_iwslt_de_en_pipeline_parallel') +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel" +) def transformer_iwslt_de_en_dist(args): transformer_iwslt_de_en(args) -@register_model_architecture('pipeline_parallel_transformer', - 'transformer_wmt_en_de_big_pipeline_parallel') +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel" +) def transformer_wmt_en_de_big_dist(args): transformer_wmt_en_de_big(args) diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py index ed49fbb338..68ad88d2a5 100644 --- a/fairseq/model_parallel/models/roberta/model.py +++ b/fairseq/model_parallel/models/roberta/model.py @@ -11,27 +11,19 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils -from fairseq.models import ( - FairseqEncoder, - register_model, - register_model_architecture, -) +from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder +from fairseq.models import FairseqEncoder, register_model, register_model_architecture from fairseq.models.roberta import ( - RobertaModel, + RobertaClassificationHead, RobertaEncoder, RobertaLMHead, - RobertaClassificationHead, -) -from fairseq.modules import ( - LayerNorm, - TransformerSentenceEncoder, -) -from fairseq.model_parallel.modules import ( - ModelParallelTransformerSentenceEncoder, + RobertaModel, ) +from fairseq.modules import LayerNorm, TransformerSentenceEncoder from 
fairseq.modules.transformer_sentence_encoder import init_bert_params + + try: from fairseq.model_parallel.megatron.mpu import ( copy_to_model_parallel_region, @@ -39,6 +31,7 @@ ColumnParallelLinear, RowParallelLinear, ) + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False @@ -46,10 +39,8 @@ logger = logging.getLogger(__name__) -@register_model('model_parallel_roberta') +@register_model("model_parallel_roberta") class ModelParallelRobertaModel(RobertaModel): - - def __init__(self, args, encoder): super().__init__(args, encoder) @@ -69,18 +60,25 @@ def build_model(cls, args, task): task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) - if not hasattr(args, 'max_positions'): + if not hasattr(args, "max_positions"): args.max_positions = args.tokens_per_sample - if getattr(args, 'untie_weights_roberta', False): + if getattr(args, "untie_weights_roberta", False): raise NotImplementedError( - '--untie-weights-roberta is not supported in model parallel mode' + "--untie-weights-roberta is not supported in model parallel mode" ) encoder = ModelParallelRobertaEncoder(args, task.source_dictionary) return cls(args, encoder) - def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs): + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + classification_head_name=None, + **kwargs + ): if classification_head_name is not None: features_only = True @@ -90,7 +88,9 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, cla x = self.classification_heads[classification_head_name](x) return x, extra - def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): """Register a classification head.""" if name 
in self.classification_heads: prev_num_classes = self.classification_heads[name].out_proj.out_features @@ -98,7 +98,7 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, * if num_classes != prev_num_classes or inner_dim != prev_inner_dim: logger.warning( 're-registering head "{}" with num_classes {} (prev: {}) ' - 'and inner_dim {} (prev: {})'.format( + "and inner_dim {} (prev: {})".format( name, num_classes, prev_num_classes, inner_dim, prev_inner_dim ) ) @@ -146,7 +146,9 @@ def forward(self, features, masked_tokens=None, **kwargs): class ModelParallelRobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" - def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout): + def __init__( + self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout + ): super().__init__() self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True) self.activation_fn = utils.get_activation_fn(activation_fn) @@ -206,7 +208,14 @@ def __init__(self, args, dictionary): weight=self.sentence_encoder.embed_tokens.weight, ) - def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused): + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + masked_tokens=None, + **unused + ): """ Args: src_tokens (LongTensor): input tokens of shape `(batch, src_len)` @@ -223,7 +232,9 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas is a list of hidden states. Note that the hidden states have shape `(src_len, batch, vocab)`. 
""" - x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens) + x, extra = self.extract_features( + src_tokens, return_all_hiddens=return_all_hiddens + ) if not features_only: x = self.output_layer(x, masked_tokens=masked_tokens) return x, extra @@ -234,7 +245,7 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **unused): last_state_only=not return_all_hiddens, ) features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {'inner_states': inner_states if return_all_hiddens else None} + return features, {"inner_states": inner_states if return_all_hiddens else None} def output_layer(self, features, masked_tokens=None, **unused): return self.lm_head(features, masked_tokens) @@ -244,33 +255,33 @@ def max_positions(self): return self.args.max_positions -@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta') +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, 'encoder_layers', 12) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - 
args.activation_dropout = getattr(args, 'activation_dropout', 0.0) - args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) - args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None) - args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) -@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_base') +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base") def roberta_base_architecture(args): base_architecture(args) -@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_large') +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large") def roberta_large_architecture(args): - args.encoder_layers = getattr(args, 'encoder_layers', 24) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) base_architecture(args) diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py index 3ba539319f..4f34645226 100644 --- a/fairseq/model_parallel/models/transformer.py +++ b/fairseq/model_parallel/models/transformer.py @@ -7,21 +7,17 @@ import torch.nn as nn import 
torch.nn.functional as F - -from fairseq.models import ( - register_model, +from fairseq.model_parallel.modules import ( + ModelParallelTransformerDecoderLayer, + ModelParallelTransformerEncoderLayer, ) - +from fairseq.models import register_model from fairseq.models.transformer import ( TransformerDecoder, TransformerEncoder, TransformerModel, ) -from fairseq.model_parallel.modules import ( - ModelParallelTransformerDecoderLayer, - ModelParallelTransformerEncoderLayer, -) try: from fairseq.model_parallel.megatron.mpu import ( @@ -29,6 +25,7 @@ gather_from_model_parallel_region, VocabParallelEmbedding, ) + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False @@ -37,18 +34,19 @@ logger = logging.getLogger(__name__) -@register_model('model_parallel_transformer') +@register_model("model_parallel_transformer") class ModelParallelTransformerModel(TransformerModel): """ Model parallel Transformer model. """ + @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): if not has_megatron_submodule: raise ImportError( - '\n\nPlease install the megatron submodule:' - '\n\n git submodule update --init ' - 'fairseq/model_parallel/megatron' + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" ) dictionary.pad_to_multiple_(args.model_parallel_size * 8) num_embeddings = len(dictionary) @@ -57,10 +55,15 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): def _vocab_init(tensor, **kwargs): nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5) nn.init.constant_(tensor[1], 0) - emb = VocabParallelEmbedding(num_embeddings, embed_dim, padding_idx, init_method=_vocab_init) + + emb = VocabParallelEmbedding( + num_embeddings, embed_dim, padding_idx, init_method=_vocab_init + ) # if provided, load from preloaded dictionaries if path: - raise NotImplementedError("Loading of embedding from path is not supported for model parallel") 
+ raise NotImplementedError( + "Loading of embedding from path is not supported for model parallel" + ) return emb @classmethod @@ -73,7 +76,7 @@ def build_decoder(cls, args, tgt_dict, embed_tokens): args, tgt_dict, embed_tokens, - no_encoder_attn=getattr(args, 'no_cross_attention', False), + no_encoder_attn=getattr(args, "no_cross_attention", False), ) @@ -100,7 +103,7 @@ def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" if not self.share_input_output_embed: raise NotImplementedError( - 'Model parallel training currently requires --share-decoder-input-output-embed' + "Model parallel training currently requires --share-decoder-input-output-embed" ) features = copy_to_model_parallel_region(features) @@ -108,6 +111,6 @@ def output_layer(self, features, **kwargs): # project back to size of vocabulary x = self.output_projection(features) - if getattr(self.args, 'criterion') != 'vocab_parallel_cross_entropy': + if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy": x = gather_from_model_parallel_region(x).contiguous() return x diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 492dad653c..ed378c4320 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -4,14 +4,14 @@ # LICENSE file in the root directory of this source tree. 
import torch.nn as nn - -from fairseq.models import register_model, register_model_architecture -from fairseq.models.transformer_lm import ( - TransformerLanguageModel, -) from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer_lm import TransformerLanguageModel + + try: from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False @@ -20,17 +20,16 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 -@register_model('model_parallel_transformer_lm') +@register_model("model_parallel_transformer_lm") class ModelParallelTransformerLanguageModel(TransformerLanguageModel): - @classmethod def build_model(cls, args, task): """Build a new model instance.""" if not has_megatron_submodule: raise ImportError( - '\n\nPlease install the megatron submodule:' - '\n\n git submodule update --init ' - 'fairseq/model_parallel/megatron' + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" ) # make sure all arguments are present in older models @@ -42,18 +41,29 @@ def build_model(cls, args, task): if args.decoder_layers_to_keep: args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) - if getattr(args, 'max_target_positions', None) is None: - args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS) + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) if args.character_embeddings: - raise NotImplementedError("Character embeddings is not supported for model parallel") + raise NotImplementedError( + "Character embeddings is not supported for model parallel" + ) elif args.adaptive_input: - raise NotImplementedError("Adaptive input is not 
supported for model parallel") + raise NotImplementedError( + "Adaptive input is not supported for model parallel" + ) else: - embed_tokens = cls.build_embedding(args, task.source_dictionary, args.decoder_input_dim) + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) decoder = ModelParallelTransformerDecoder( - args, task.target_dictionary, embed_tokens, no_encoder_attn=True, + args, + task.target_dictionary, + embed_tokens, + no_encoder_attn=True, ) return cls(decoder) @@ -62,78 +72,94 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): def _vocab_init(tensor, **kwargs): nn.init.normal_(tensor, mean=0, std=embed_dim ** -0.5) nn.init.constant_(tensor[1], 0) - embed_tokens = VocabParallelEmbedding(len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init) + + embed_tokens = VocabParallelEmbedding( + len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init + ) return embed_tokens def base_lm_architecture(args): # backward compatibility for older model checkpoints - if hasattr(args, 'no_tie_adaptive_proj'): + if hasattr(args, "no_tie_adaptive_proj"): # previous models defined --no-tie-adaptive-proj, so use the existence of # that option to determine if this is an "old" model checkpoint args.no_decoder_final_norm = True # old models always set this to True if args.no_tie_adaptive_proj is False: args.tie_adaptive_proj = True - if hasattr(args, 'decoder_final_norm'): + if hasattr(args, "decoder_final_norm"): args.no_decoder_final_norm = not args.decoder_final_norm - args.activation_fn = getattr(args, 'activation_fn', 'relu') - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.0) - args.activation_dropout = getattr(args, 'activation_dropout', 0.0) - args.relu_dropout = getattr(args, 'relu_dropout', 0.0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_output_dim = getattr(args, 
'decoder_output_dim', args.decoder_embed_dim) - args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048) - args.decoder_layers = getattr(args, 'decoder_layers', 6) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) # Model training is not stable without this args.decoder_normalize_before = True - args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', False) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) - args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4) - args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.character_embeddings = getattr(args, 'character_embeddings', False) - args.character_filters = getattr(args, 'character_filters', '[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]') - args.character_embedding_dim = getattr(args, 'character_embedding_dim', 4) - 
args.char_embedder_highway_layers = getattr(args, 'char_embedder_highway_layers', 2) - args.adaptive_input = getattr(args, 'adaptive_input', False) - args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4) - args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None) - args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False) - args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False) - args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) - args.decoder_layerdrop = getattr(args, 'decoder_layerdrop', 0.0) - args.decoder_layers_to_keep = getattr(args, 'decoder_layers_to_keep', None) - args.layernorm_embedding = getattr(args, 'layernorm_embedding', False) - args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) - args.quant_noise_pq = getattr(args, 'quant_noise_pq', 0.0) - args.quant_noise_pq_block_size = getattr(args, 'quant_noise_pq_block_size', 8) - args.quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0.0) - args.add_bos_token = getattr(args, 'add_bos_token', False) - -@register_model_architecture('model_parallel_transformer_lm', 'transformer_lm_megatron') + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.character_embeddings = getattr(args, "character_embeddings", False) + args.character_filters = getattr( + args, + "character_filters", + "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + ) + args.character_embedding_dim = getattr(args, "character_embedding_dim", 4) + 
args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0) + args.add_bos_token = getattr(args, "add_bos_token", False) + + +@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron") def transformer_lm_megatron(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 3072) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 3072 * 4) - args.decoder_layers = getattr(args, 'decoder_layers', 72) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 32) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) 
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) -@register_model_architecture('model_parallel_transformer_lm', 'transformer_lm_megatron_11b') +@register_model_architecture( + "model_parallel_transformer_lm", "transformer_lm_megatron_11b" +) def transformer_lm_megatron_11b(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 3072) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 3072 * 6) - args.decoder_layers = getattr(args, 'decoder_layers', 72) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 32) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py index 26401dcc7c..fb45b3c9e0 100644 --- a/fairseq/model_parallel/modules/__init__.py +++ b/fairseq/model_parallel/modules/__init__.py @@ -5,14 +5,19 @@ """isort:skip_file""" from .multihead_attention import ModelParallelMultiheadAttention -from .transformer_layer import ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer -from .transformer_sentence_encoder_layer import ModelParallelTransformerSentenceEncoderLayer +from .transformer_layer import ( + ModelParallelTransformerEncoderLayer, + ModelParallelTransformerDecoderLayer, +) 
+from .transformer_sentence_encoder_layer import ( + ModelParallelTransformerSentenceEncoderLayer, +) from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder __all__ = [ - 'ModelParallelMultiheadAttention', - 'ModelParallelTransformerEncoderLayer', - 'ModelParallelTransformerDecoderLayer', - 'ModelParallelTransformerSentenceEncoder', - 'ModelParallelTransformerSentenceEncoderLayer', + "ModelParallelMultiheadAttention", + "ModelParallelTransformerEncoderLayer", + "ModelParallelTransformerDecoderLayer", + "ModelParallelTransformerSentenceEncoder", + "ModelParallelTransformerSentenceEncoderLayer", ] diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py index f55a712b01..4164bf9131 100644 --- a/fairseq/model_parallel/modules/multihead_attention.py +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -8,9 +8,10 @@ import torch import torch.nn.functional as F from fairseq import utils -from torch import Tensor, nn from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules.fairseq_dropout import FairseqDropout +from torch import Tensor, nn + try: from fairseq.model_parallel.megatron.mpu import ( @@ -19,6 +20,7 @@ ColumnParallelLinear, RowParallelLinear, ) + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False @@ -46,9 +48,9 @@ def __init__( super().__init__() if not has_megatron_submodule: raise ImportError( - '\n\nPlease install the megatron submodule:' - '\n\n git submodule update --init ' - 'fairseq/model_parallel/megatron' + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" ) self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim @@ -74,14 +76,22 @@ def __init__( self.self_attention = self_attention self.encoder_decoder_attention = encoder_decoder_attention - assert not self.self_attention or 
self.qkv_same_dim, ( - "Self-attention requires query, key and value to be of the same size" - ) + assert ( + not self.self_attention or self.qkv_same_dim + ), "Self-attention requires query, key and value to be of the same size" - self.k_proj = ColumnParallelLinear(self.kdim, embed_dim, bias=bias, gather_output=False) - self.v_proj = ColumnParallelLinear(self.vdim, embed_dim, bias=bias, gather_output=False) - self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, gather_output=False) - self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True) + self.k_proj = ColumnParallelLinear( + self.kdim, embed_dim, bias=bias, gather_output=False + ) + self.v_proj = ColumnParallelLinear( + self.vdim, embed_dim, bias=bias, gather_output=False + ) + self.q_proj = ColumnParallelLinear( + embed_dim, embed_dim, bias=bias, gather_output=False + ) + self.out_proj = RowParallelLinear( + embed_dim, embed_dim, bias=bias, input_is_parallel=True + ) self.tpu = False @@ -145,7 +155,6 @@ def forward( v = self.v_proj(value) q *= self.scaling - q = ( q.contiguous() .view(tgt_len, bsz * self.num_heads_partition, self.head_dim) @@ -169,7 +178,9 @@ def forward( if "prev_key" in saved_state: _prev_key = saved_state["prev_key"] assert _prev_key is not None - prev_key = _prev_key.view(bsz * self.num_heads_partition, -1, self.head_dim) + prev_key = _prev_key.view( + bsz * self.num_heads_partition, -1, self.head_dim + ) if static_kv: k = prev_key else: @@ -178,7 +189,9 @@ def forward( if "prev_value" in saved_state: _prev_value = saved_state["prev_value"] assert _prev_value is not None - prev_value = _prev_value.view(bsz * self.num_heads_partition, -1, self.head_dim) + prev_value = _prev_value.view( + bsz * self.num_heads_partition, -1, self.head_dim + ) if static_kv: v = prev_value else: @@ -188,16 +201,22 @@ def forward( if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"] assert k is not None and v is 
not None - key_padding_mask = ModelParallelMultiheadAttention._append_prev_key_padding_mask( - key_padding_mask=key_padding_mask, - prev_key_padding_mask=prev_key_padding_mask, - batch_size=bsz, - src_len=k.size(1), - static_kv=static_kv, + key_padding_mask = ( + ModelParallelMultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) ) - saved_state["prev_key"] = k.view(bsz, self.num_heads_partition, -1, self.head_dim) - saved_state["prev_value"] = v.view(bsz, self.num_heads_partition, -1, self.head_dim) + saved_state["prev_key"] = k.view( + bsz, self.num_heads_partition, -1, self.head_dim + ) + saved_state["prev_value"] = v.view( + bsz, self.num_heads_partition, -1, self.head_dim + ) saved_state["prev_key_padding_mask"] = key_padding_mask # In this branch incremental_state is never None assert incremental_state is not None @@ -216,7 +235,11 @@ def forward( attn_weights = torch.bmm(q, k.transpose(1, 2)) - assert list(attn_weights.size()) == [bsz * self.num_heads_partition, tgt_len, src_len] + assert list(attn_weights.size()) == [ + bsz * self.num_heads_partition, + tgt_len, + src_len, + ] if attn_mask is not None: attn_mask = attn_mask.unsqueeze(0) @@ -224,20 +247,23 @@ def forward( if key_padding_mask is not None: # don't attend to padding symbols - attn_weights = attn_weights.view(bsz, self.num_heads_partition, tgt_len, src_len) + attn_weights = attn_weights.view( + bsz, self.num_heads_partition, tgt_len, src_len + ) if not self.tpu: attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), ) else: attn_weights = attn_weights.transpose(0, 2) - attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf')) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) 
attn_weights = attn_weights.transpose(0, 2) - attn_weights = attn_weights.view(bsz * self.num_heads_partition, tgt_len, src_len) + attn_weights = attn_weights.view( + bsz * self.num_heads_partition, tgt_len, src_len + ) - attn_weights_float = utils.softmax( - attn_weights, dim=-1 - ) + attn_weights_float = utils.softmax(attn_weights, dim=-1) attn_weights = attn_weights_float.type_as(attn_weights) with get_cuda_rng_tracker().fork(): @@ -245,7 +271,11 @@ def forward( assert v is not None attn = torch.bmm(attn_probs, v) - assert list(attn.size()) == [bsz * self.num_heads_partition, tgt_len, self.head_dim] + assert list(attn.size()) == [ + bsz * self.num_heads_partition, + tgt_len, + self.head_dim, + ] embed_dim_partition = embed_dim // self.model_parallel_size attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition) attn = self.out_proj(attn) diff --git a/fairseq/model_parallel/modules/transformer_layer.py b/fairseq/model_parallel/modules/transformer_layer.py index 30b23d518c..7ab53c6e5f 100644 --- a/fairseq/model_parallel/modules/transformer_layer.py +++ b/fairseq/model_parallel/modules/transformer_layer.py @@ -3,18 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq.modules import ( - TransformerEncoderLayer, - TransformerDecoderLayer, -) - from fairseq.model_parallel.modules import ModelParallelMultiheadAttention +from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer + try: from fairseq.model_parallel.megatron.mpu import ( ColumnParallelLinear, RowParallelLinear, ) + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False @@ -23,7 +21,7 @@ class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer): """Encoder layer block over multiple gpus. - See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. 
+ See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. """ def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): @@ -48,8 +46,9 @@ def build_self_attention(self, embed_dim, args, **unused_kwargs): class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer): """Decoder layer block. - See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. """ + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): if q_noise > 0: raise NotImplementedError diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py index a2a6eb81fa..a5d50a33c6 100644 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder.py +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder.py @@ -3,11 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import random from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer from fairseq.modules import ( LayerNorm, MultiheadAttention, @@ -15,24 +17,21 @@ TransformerSentenceEncoder, ) -from fairseq.model_parallel.modules import ( - ModelParallelTransformerSentenceEncoderLayer, -) try: from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False -import random - class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder): """ Implementation for a Model Parallel Bi-directional Transformer based Sentence Encoder used in BERT/XLM style pre-trained models. 
""" + def build_embedding(self, vocab_size, embedding_dim, padding_idx): return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py index d09158b7f1..e10bf52332 100644 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py @@ -5,17 +5,17 @@ import torch import torch.nn.functional as F - from fairseq import utils -from fairseq.modules import ( - TransformerSentenceEncoderLayer -) from fairseq.model_parallel.modules import ModelParallelMultiheadAttention +from fairseq.modules import TransformerSentenceEncoderLayer + + try: from fairseq.model_parallel.megatron.mpu import ( ColumnParallelLinear, RowParallelLinear, ) + has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False @@ -26,6 +26,7 @@ class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLay Implements a Model Parallel Transformer Encoder Layer used in BERT/XLM style pre-trained models. 
""" + def build_fc1(self, input_dim, output_dim, **unused): return ColumnParallelLinear(input_dim, output_dim, gather_output=False) @@ -40,10 +41,7 @@ def build_self_attention( **kwargs, ): return ModelParallelMultiheadAttention( - embed_dim, - num_attention_heads, - dropout=dropout, - self_attention=True + embed_dim, num_attention_heads, dropout=dropout, self_attention=True ) def forward( diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 48c59cb91d..cdabe36010 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -5,14 +5,12 @@ import copy import logging +from typing import List import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -from typing import List - from fairseq import utils from fairseq.data import encoders @@ -34,19 +32,23 @@ def __init__(self, args, task, model): self.bpe = encoders.build_bpe(args) - self.max_positions = min(utils.resolve_max_positions( - self.task.max_positions(), - self.model.max_positions(), - )) + self.max_positions = min( + utils.resolve_max_positions( + self.task.max_positions(), + self.model.max_positions(), + ) + ) # this is useful for determining the device - self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float)) + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) @property def device(self): return self._float_tensor.device - def encode(self, sentence: str, *addl_sentences, no_separator=True) -> torch.LongTensor: + def encode( + self, sentence: str, *addl_sentences, no_separator=True + ) -> torch.LongTensor: """ BPE-encode a sentence (or multiple sentences). 
@@ -67,12 +69,12 @@ def encode(self, sentence: str, *addl_sentences, no_separator=True) -> torch.Lon [0, 8331, 2] """ tokens = self.bpe.encode(sentence) - if len(tokens.split(' ')) > self.max_positions - 2: - tokens = ' '.join(tokens.split(' ')[:self.max_positions - 2]) - bpe_sentence = '<s> ' + tokens + ' </s>' + if len(tokens.split(" ")) > self.max_positions - 2: + tokens = " ".join(tokens.split(" ")[: self.max_positions - 2]) + bpe_sentence = "<s> " + tokens + " </s>" for s in addl_sentences: - bpe_sentence += (' </s>' if not no_separator else '') - bpe_sentence += ' ' + self.bpe.encode(s) + ' </s>' + bpe_sentence += " </s>" if not no_separator else "" + bpe_sentence += " " + self.bpe.encode(s) + " </s>" tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False) return tokens.long() @@ -81,10 +83,12 @@ def decode(self, tokens: torch.LongTensor): tokens = tokens.cpu().numpy() if tokens[0] == self.task.source_dictionary.bos(): tokens = tokens[1:] # remove <s> - eos_mask = (tokens == self.task.source_dictionary.eos()) + eos_mask = tokens == self.task.source_dictionary.eos() doc_mask = eos_mask[1:] & eos_mask[:-1] sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) - sentences = [self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences] + sentences = [ + self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences + ] if len(sentences) == 1: return sentences[0] return sentences @@ -96,18 +100,23 @@ def _build_sample(self, src_tokens: List[torch.LongTensor]): [x.numel() for x in src_tokens], ) sample = dataset.collater(dataset) - sample = utils.apply_to_sample( - lambda tensor: tensor.to(self.device), - sample - ) + sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample) return sample - def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> str: + def sample( + self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs + ) -> str: input = [self.encode(sentence) for sentence 
in sentences] hypos = self.generate(input, beam, verbose, **kwargs) - return [self.decode(x['tokens']) for x in hypos] - - def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor: + return [self.decode(x["tokens"]) for x in hypos] + + def generate( + self, + tokens: List[torch.LongTensor], + beam: int = 5, + verbose: bool = False, + **kwargs + ) -> torch.LongTensor: sample = self._build_sample(tokens) # build generator using current args as well as any kwargs @@ -120,34 +129,40 @@ def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool generator, [self.model], sample, - prefix_tokens=sample['net_input']['src_tokens'].new_zeros((len(tokens), 1)).fill_(self.task.source_dictionary.bos()), + prefix_tokens=sample["net_input"]["src_tokens"] + .new_zeros((len(tokens), 1)) + .fill_(self.task.source_dictionary.bos()), ) if verbose: src_str_with_unk = self.string(tokens) - logger.info('S\t{}'.format(src_str_with_unk)) + logger.info("S\t{}".format(src_str_with_unk)) def getarg(name, default): return getattr(gen_args, name, getattr(self.args, name, default)) # Process top predictions hypos = [x[0] for x in translations] - hypos = [v for _, v in sorted(zip(sample['id'].tolist(), hypos))] + hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))] return hypos - def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor: + def extract_features( + self, tokens: torch.LongTensor, return_all_hiddens: bool = False + ) -> torch.Tensor: if tokens.dim() == 1: tokens = tokens.unsqueeze(0) if tokens.size(-1) > min(self.model.max_positions()): - raise ValueError('tokens exceeds maximum length: {} > {}'.format( - tokens.size(-1), self.model.max_positions() - )) + raise ValueError( + "tokens exceeds maximum length: {} > {}".format( + tokens.size(-1), self.model.max_positions() + ) + ) tokens.to(device=self.device), prev_output_tokens = tokens.clone() 
prev_output_tokens[:, 0] = tokens.gather( 1, - (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1)- 1).unsqueeze(-1), + (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1), ).squeeze() prev_output_tokens[:, 1:] = tokens[:, :-1] @@ -160,7 +175,7 @@ def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = ) if return_all_hiddens: # convert from T x B x C -> B x T x C - inner_states = extra['inner_states'] + inner_states = extra["inner_states"] return [inner_state.transpose(0, 1) for inner_state in inner_states] else: return features # just the last layer's features diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 90e79e4651..0f22352b68 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -11,12 +11,8 @@ import torch import torch.nn as nn - from fairseq import utils -from fairseq.models import ( - register_model, - register_model_architecture, -) +from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer import TransformerModel from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -26,17 +22,16 @@ logger = logging.getLogger(__name__) -@register_model('bart') +@register_model("bart") class BARTModel(TransformerModel): - @classmethod def hub_models(cls): return { - 'bart.base': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz', - 'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz', - 'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz', - 'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz', - 'bart.large.xsum': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz', + "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz", + "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz", + "bart.large.mnli": 
"http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz", + "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz", + "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz", } def __init__(self, args, encoder, decoder): @@ -51,28 +46,35 @@ def __init__(self, args, encoder, decoder): def add_args(parser): super(BARTModel, BARTModel).add_args(parser) parser.add_argument( - '--pooler-dropout', type=float, metavar='D', - help='dropout probability in the masked_lm pooler layers' + "--pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", ) parser.add_argument( - '--pooler-activation-fn', + "--pooler-activation-fn", choices=utils.get_available_activation_fns(), - help='activation function to use for pooler layer' + help="activation function to use for pooler layer", ) parser.add_argument( - '--spectral-norm-classification-head', - action='store_true', - help='Apply spectral normalization on the classification head' + "--spectral-norm-classification-head", + action="store_true", + help="Apply spectral normalization on the classification head", ) @property def supported_targets(self): - return {'self'} + return {"self"} def forward( - self, src_tokens, src_lengths, prev_output_tokens, - features_only=False, classification_head_name=None, - token_embeddings=None, **kwargs + self, + src_tokens, + src_lengths, + prev_output_tokens, + features_only=False, + classification_head_name=None, + token_embeddings=None, + **kwargs, ): if classification_head_name is not None: features_only = True @@ -103,12 +105,13 @@ def forward( def from_pretrained( cls, model_name_or_path, - checkpoint_file='model.pt', - data_name_or_path='.', - bpe='gpt2', + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="gpt2", **kwargs, ): from fairseq import hub_utils + x = hub_utils.from_pretrained( model_name_or_path, checkpoint_file, @@ -118,9 +121,11 @@ def 
from_pretrained( load_checkpoint_heads=True, **kwargs, ) - return BARTHubInterface(x['args'], x['task'], x['models'][0]) + return BARTHubInterface(x["args"], x["task"], x["models"][0]) - def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): """Register a classification head.""" logger.info("Registering classification head: {0}".format(name)) if name in self.classification_heads: @@ -129,7 +134,7 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, * if num_classes != prev_num_classes or inner_dim != prev_inner_dim: logger.warning( 're-registering head "{}" with num_classes {} (prev: {}) ' - 'and inner_dim {} (prev: {})'.format( + "and inner_dim {} (prev: {})".format( name, num_classes, prev_num_classes, inner_dim, prev_inner_dim ) ) @@ -139,43 +144,54 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, * num_classes=num_classes, activation_fn=self.args.pooler_activation_fn, pooler_dropout=self.args.pooler_dropout, - do_spectral_norm=self.args.spectral_norm_classification_head + do_spectral_norm=self.args.spectral_norm_classification_head, ) def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) - prefix = name + '.' if name != '' else '' - current_head_names = [] if not hasattr(self, 'classification_heads') else \ - self.classification_heads.keys() + prefix = name + "." if name != "" else "" + current_head_names = ( + [] + if not hasattr(self, "classification_heads") + else self.classification_heads.keys() + ) # Handle new classification heads present in the state dict. 
keys_to_delete = [] for k in state_dict.keys(): - if not k.startswith(prefix + 'classification_heads.'): + if not k.startswith(prefix + "classification_heads."): continue - head_name = k[len(prefix + 'classification_heads.'):].split('.')[0] - num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0) - inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0) + head_name = k[len(prefix + "classification_heads.") :].split(".")[0] + num_classes = state_dict[ + prefix + "classification_heads." + head_name + ".out_proj.weight" + ].size(0) + inner_dim = state_dict[ + prefix + "classification_heads." + head_name + ".dense.weight" + ].size(0) - if getattr(self.args, 'load_checkpoint_heads', False): + if getattr(self.args, "load_checkpoint_heads", False): if head_name not in current_head_names: self.register_classification_head(head_name, num_classes, inner_dim) else: if head_name not in current_head_names: logger.warning( - 'deleting classification head ({}) from checkpoint ' - 'not present in current model: {}'.format(head_name, k) + "deleting classification head ({}) from checkpoint " + "not present in current model: {}".format(head_name, k) ) keys_to_delete.append(k) elif ( - num_classes != self.classification_heads[head_name].out_proj.out_features - or inner_dim != self.classification_heads[head_name].dense.out_features + num_classes + != self.classification_heads[head_name].out_proj.out_features + or inner_dim + != self.classification_heads[head_name].dense.out_features ): logger.warning( - 'deleting classification head ({}) from checkpoint ' - 'with different dimensions than current model: {}'.format(head_name, k) + "deleting classification head ({}) from checkpoint " + "with different dimensions than current model: {}".format( + head_name, k + ) ) keys_to_delete.append(k) for k in keys_to_delete: @@ -187,55 +203,66 @@ def truncate_emb(key): # When finetuning on translation task, remove 
last row of # embedding matrix that corresponds to mask_idx token. - loaded_dict_size = state_dict['encoder.embed_tokens.weight'].size(0) - if loaded_dict_size == len(self.encoder.dictionary) + 1 and '<mask>' not in self.encoder.dictionary: - truncate_emb('encoder.embed_tokens.weight') - truncate_emb('decoder.embed_tokens.weight') - truncate_emb('encoder.output_projection.weight') - truncate_emb('decoder.output_projection.weight') + loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0) + if ( + loaded_dict_size == len(self.encoder.dictionary) + 1 + and "<mask>" not in self.encoder.dictionary + ): + truncate_emb("encoder.embed_tokens.weight") + truncate_emb("decoder.embed_tokens.weight") + truncate_emb("encoder.output_projection.weight") + truncate_emb("decoder.output_projection.weight") # When continued pretraining on new set of languages for mbart, # add extra lang embeddings at the end of embed_tokens. # Note: newly added languages are assumed to have been added at the end. - if self.args.task == 'multilingual_denoising' and loaded_dict_size < len(self.encoder.dictionary): + if self.args.task == "multilingual_denoising" and loaded_dict_size < len( + self.encoder.dictionary + ): logger.info( - "Adding extra language embeddings not found in pretrained model for "\ + "Adding extra language embeddings not found in pretrained model for " "continued pretraining of MBART on new set of languages." 
) - loaded_mask_token_embedding = state_dict['encoder.embed_tokens.weight'][-1, :] + loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][ + -1, : + ] num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size - embed_dim = state_dict['encoder.embed_tokens.weight'].size(1) + embed_dim = state_dict["encoder.embed_tokens.weight"].size(1) new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim) - nn.init.normal_( - new_lang_embed_to_add, - mean=0, - std=embed_dim ** -0.5 - ) + nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5) new_lang_embed_to_add = new_lang_embed_to_add.to( - dtype=state_dict['encoder.embed_tokens.weight'].dtype, + dtype=state_dict["encoder.embed_tokens.weight"].dtype, ) - state_dict['encoder.embed_tokens.weight'] = torch.cat([ - state_dict['encoder.embed_tokens.weight'][:loaded_dict_size-1, :], - new_lang_embed_to_add, - loaded_mask_token_embedding.unsqueeze(0)] + state_dict["encoder.embed_tokens.weight"] = torch.cat( + [ + state_dict["encoder.embed_tokens.weight"][ + : loaded_dict_size - 1, : + ], + new_lang_embed_to_add, + loaded_mask_token_embedding.unsqueeze(0), + ] ) - state_dict['decoder.embed_tokens.weight'] = torch.cat([ - state_dict['decoder.embed_tokens.weight'][:loaded_dict_size-1, :], - new_lang_embed_to_add, - loaded_mask_token_embedding.unsqueeze(0)] + state_dict["decoder.embed_tokens.weight"] = torch.cat( + [ + state_dict["decoder.embed_tokens.weight"][ + : loaded_dict_size - 1, : + ], + new_lang_embed_to_add, + loaded_mask_token_embedding.unsqueeze(0), + ] ) # Copy any newly-added classification heads into the state dict # with their current weights. - if hasattr(self, 'classification_heads'): + if hasattr(self, "classification_heads"): cur_state = self.classification_heads.state_dict() for k, v in cur_state.items(): - if prefix + 'classification_heads.' + k not in state_dict: - logger.info('Overwriting', prefix + 'classification_heads.' 
+ k) - state_dict[prefix + 'classification_heads.' + k] = v + if prefix + "classification_heads." + k not in state_dict: + logger.info("Overwriting", prefix + "classification_heads." + k) + state_dict[prefix + "classification_heads." + k] = v class BARTClassificationHead(nn.Module): @@ -248,7 +275,7 @@ def __init__( num_classes, activation_fn, pooler_dropout, - do_spectral_norm=False + do_spectral_norm=False, ): super().__init__() self.dense = nn.Linear(input_dim, inner_dim) @@ -269,67 +296,73 @@ def forward(self, features, **kwargs): return x -@register_model_architecture('bart', 'bart_large') +@register_model_architecture("bart", "bart_large") def bart_large_architecture(args): - args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*1024) - args.encoder_layers = getattr(args, 'encoder_layers', 12) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) - args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim) - args.decoder_layers = getattr(args, 'decoder_layers', 12) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) - args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False) - args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True) - args.attention_dropout = getattr(args, 'attention_dropout', 0.) - args.relu_dropout = getattr(args, 'relu_dropout', 0.) 
- args.dropout = getattr(args, 'dropout', 0.1) - args.max_target_positions = getattr(args, 'max_target_positions', 1024) - args.max_source_positions = getattr(args, 'max_source_positions', 1024) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', True) - args.share_all_embeddings = getattr(args, 'share_all_embeddings', True) - - args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) - args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) - - args.no_scale_embedding = getattr(args, 'no_scale_embedding', True) - args.layernorm_embedding = getattr(args, 'layernorm_embedding', True) - - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') - args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) - - -@register_model_architecture('bart', 'bart_base') + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = getattr(args, 
"decoder_attention_heads", 16) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.dropout = getattr(args, "dropout", 0.1) + args.max_target_positions = getattr(args, "max_target_positions", 1024) + args.max_source_positions = getattr(args, "max_source_positions", 1024) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", True + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", True) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + + +@register_model_architecture("bart", "bart_base") def bart_base_architecture(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*768) - args.encoder_layers = getattr(args, 'encoder_layers', 6) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) - args.decoder_layers = getattr(args, 'decoder_layers', 6) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, 
"encoder_ffn_embed_dim", 4 * 768) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) bart_large_architecture(args) -@register_model_architecture('bart', 'mbart_large') +@register_model_architecture("bart", "mbart_large") def mbart_large_architecture(args): - args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) bart_large_architecture(args) -@register_model_architecture('bart', 'mbart_base') +@register_model_architecture("bart", "mbart_base") def mbart_base_architecture(args): - args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) bart_base_architecture(args) -@register_model_architecture('bart', 'mbart_base_wmt20') +@register_model_architecture("bart", "mbart_base_wmt20") def mbart_base_wmt20_architecture(args): - args.layernorm_embedding = getattr(args, 'layernorm_embedding', False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) mbart_base_architecture(args) diff --git a/fairseq/models/composite_encoder.py b/fairseq/models/composite_encoder.py index 60d1473f5f..4e20fe3a83 100644 --- a/fairseq/models/composite_encoder.py +++ b/fairseq/models/composite_encoder.py @@ -43,7 +43,9 @@ def forward(self, src_tokens, src_lengths): def reorder_encoder_out(self, encoder_out, new_order): """Reorder encoder output according to new_order.""" for key in self.encoders: - encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order) + encoder_out[key] = self.encoders[key].reorder_encoder_out( + encoder_out[key], new_order + ) return encoder_out def max_positions(self): diff --git a/fairseq/models/distributed_fairseq_model.py 
b/fairseq/models/distributed_fairseq_model.py index 4fe02b20dd..ece10c6333 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -6,7 +6,6 @@ import inspect import torch.nn as nn - from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel @@ -32,7 +31,7 @@ def DistributedFairseqModel(args, model, process_group=None): """ # determine which DDP class to extend assert isinstance(model, nn.Module) - if args.distributed_wrapper == 'DDP' and args.ddp_backend == 'c10d': + if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d": ddp_class = nn.parallel.DistributedDataParallel init_kwargs = dict( module=model, @@ -43,23 +42,23 @@ def DistributedFairseqModel(args, model, process_group=None): process_group=process_group, ) # Maintain backward compatibility - if 'check_reduction' in inspect.getargspec(ddp_class)[0]: - init_kwargs['check_reduction'] = True - if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]: - init_kwargs['find_unused_parameters'] = args.find_unused_parameters - elif args.distributed_wrapper == 'DDP' and args.ddp_backend == 'no_c10d': + if "check_reduction" in inspect.getargspec(ddp_class)[0]: + init_kwargs["check_reduction"] = True + if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]: + init_kwargs["find_unused_parameters"] = args.find_unused_parameters + elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d": ddp_class = LegacyDistributedDataParallel init_kwargs = dict( module=model, world_size=args.distributed_world_size, - buffer_size=2**28, + buffer_size=2 ** 28, process_group=process_group, ) - elif args.distributed_wrapper == 'SlowMo': + elif args.distributed_wrapper == "SlowMo": if _GOSSIP_DISABLED: raise ImportError( - 'Cannot find gossip library. Please install from: ' - 'github.com/facebookresearch/stochastic_gradient_push' + "Cannot find gossip library. 
Please install from: " + "github.com/facebookresearch/stochastic_gradient_push" ) ddp_class = gossip.GossipDataParallel @@ -82,11 +81,11 @@ def DistributedFairseqModel(args, model, process_group=None): broadcast_buffers=args.broadcast_buffers, nprocs_per_node=args.nprocs_per_node, slowmo_momentum=args.slowmo_momentum, - localsgd=(args.slowmo_algorithm == 'LocalSGD'), - localsgd_frequency=args.localsgd_frequency + localsgd=(args.slowmo_algorithm == "LocalSGD"), + localsgd_frequency=args.localsgd_frequency, ) else: - raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend) + raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) class _DistributedFairseqModel(ddp_class): """Extend DistributedDataParallel to check for missing @@ -96,7 +95,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __getattr__(self, name): - wrapped_module = super().__getattr__('module') + wrapped_module = super().__getattr__("module") if hasattr(wrapped_module, name): return getattr(wrapped_module, name) return super().__getattr__(name) diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py index 7ddc0fba01..c8873daa28 100644 --- a/fairseq/models/fairseq_encoder.py +++ b/fairseq/models/fairseq_encoder.py @@ -3,11 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Dict, List, NamedTuple, Optional + import torch import torch.nn as nn -from typing import Dict, List, NamedTuple, Optional from torch import Tensor + EncoderOut = NamedTuple( "EncoderOut", [ @@ -55,9 +57,7 @@ def forward_torchscript(self, net_input: Dict[str, Tensor]): @torch.jit.unused def forward_non_torchscript(self, net_input: Dict[str, Tensor]): encoder_input = { - k: v - for k, v in net_input.items() - if k != "prev_output_tokens" + k: v for k, v in net_input.items() if k != "prev_output_tokens" } return self.forward(**encoder_input) @@ -86,6 +86,7 @@ def set_num_updates(self, num_updates): """State from trainer to pass along to model at every update.""" def _apply(m): - if hasattr(m, 'set_num_updates') and m != self: + if hasattr(m, "set_num_updates") and m != self: m.set_num_updates(num_updates) + self.apply(_apply) diff --git a/fairseq/models/fairseq_incremental_decoder.py b/fairseq/models/fairseq_incremental_decoder.py index 68e583fea8..cc72a0f8f3 100644 --- a/fairseq/models/fairseq_incremental_decoder.py +++ b/fairseq/models/fairseq_incremental_decoder.py @@ -6,10 +6,9 @@ import logging from typing import Dict, Optional -from torch import Tensor - -from fairseq.models import FairseqDecoder from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.models import FairseqDecoder +from torch import Tensor logger = logging.getLogger(__name__) @@ -41,7 +40,9 @@ class FairseqIncrementalDecoder(FairseqDecoder): def __init__(self, dictionary): super().__init__(dictionary) - def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): """ Args: prev_output_tokens (LongTensor): shifted output tokens of shape @@ -58,7 +59,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, """ raise NotImplementedError - def extract_features(self, prev_output_tokens, 
encoder_out=None, incremental_state=None, **kwargs): + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): """ Returns: tuple: @@ -92,19 +95,22 @@ def reorder_incremental_state_scripting( calling :func:`reorder_incremental_state` directly. """ for module in self.modules(): - if hasattr(module, 'reorder_incremental_state'): + if hasattr(module, "reorder_incremental_state"): result = module.reorder_incremental_state(incremental_state, new_order) if result is not None: incremental_state = result def set_beam_size(self, beam_size): """Sets the beam size in the decoder and all children.""" - if getattr(self, '_beam_size', -1) != beam_size: + if getattr(self, "_beam_size", -1) != beam_size: seen = set() def apply_set_beam_size(module): - if module != self and hasattr(module, 'set_beam_size') \ - and module not in seen: + if ( + module != self + and hasattr(module, "set_beam_size") + and module not in seen + ): seen.add(module) module.set_beam_size(beam_size) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index facb7d011b..bfd41777b2 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -223,7 +223,7 @@ def apply_prepare_for_tpu_(module): @classmethod def upgrade_args(cls, args): - if hasattr(args, 'max_sentences') and not hasattr(args, 'batch_size'): + if hasattr(args, "max_sentences") and not hasattr(args, "batch_size"): args.batch_size = args.max_sentences @classmethod diff --git a/fairseq/models/fconv.py b/fairseq/models/fconv.py index c60a2f4e5f..c99a215101 100644 --- a/fairseq/models/fconv.py +++ b/fairseq/models/fconv.py @@ -8,22 +8,25 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils from fairseq.models import ( FairseqEncoder, - FairseqIncrementalDecoder, FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, register_model, register_model_architecture, ) from fairseq.modules import ( - 
AdaptiveSoftmax, BeamableMM, FairseqDropout, GradMultiply, LearnedPositionalEmbedding, + AdaptiveSoftmax, + BeamableMM, + FairseqDropout, + GradMultiply, + LearnedPositionalEmbedding, LinearizedConvolution, ) -@register_model('fconv') +@register_model("fconv") class FConvModel(FairseqEncoderDecoderModel): """ A fully convolutional model, i.e. a convolutional encoder and a @@ -44,23 +47,30 @@ class FConvModel(FairseqEncoderDecoderModel): @classmethod def hub_models(cls): - def moses_subword(path): return { - 'path': path, - 'tokenizer': 'moses', - 'bpe': 'subword_nmt', + "path": path, + "tokenizer": "moses", + "bpe": "subword_nmt", } return { - 'conv.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2'), - 'conv.wmt14.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2'), - 'conv.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2'), + "conv.wmt14.en-fr": moses_subword( + "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2" + ), + "conv.wmt14.en-de": moses_subword( + "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2" + ), + "conv.wmt17.en-de": moses_subword( + "https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2" + ), } def __init__(self, encoder, decoder): super().__init__(encoder, decoder) - self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention) + self.encoder.num_attention_layers = sum( + layer is not None for layer in decoder.attention + ) @staticmethod def add_args(parser): @@ -147,8 +157,13 @@ class FConvEncoder(FairseqEncoder): """ def __init__( - self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024, - convolutions=((512, 3),) * 20, dropout=0.1, + self, + dictionary, + embed_dim=512, + embed_dict=None, + max_positions=1024, + convolutions=((512, 3),) * 20, + dropout=0.1, ): 
super().__init__(dictionary) self.dropout_module = FairseqDropout( @@ -160,7 +175,9 @@ def __init__( self.padding_idx = dictionary.pad() self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) if embed_dict: - self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens) + self.embed_tokens = utils.load_embedding( + embed_dict, self.dictionary, self.embed_tokens + ) self.embed_positions = PositionalEmbedding( max_positions, @@ -181,15 +198,23 @@ def __init__( residual_dim = out_channels else: residual_dim = layer_in_channels[-residual] - self.projections.append(Linear(residual_dim, out_channels) - if residual_dim != out_channels else None) + self.projections.append( + Linear(residual_dim, out_channels) + if residual_dim != out_channels + else None + ) if kernel_size % 2 == 1: padding = kernel_size // 2 else: padding = 0 self.convolutions.append( - ConvTBC(in_channels, out_channels * 2, kernel_size, - dropout=dropout, padding=padding) + ConvTBC( + in_channels, + out_channels * 2, + kernel_size, + dropout=dropout, + padding=padding, + ) ) self.residuals.append(residual) in_channels = out_channels @@ -232,7 +257,9 @@ def forward(self, src_tokens, src_lengths): residuals = [x] # temporal convolutions - for proj, conv, res_layer in zip(self.projections, self.convolutions, self.residuals): + for proj, conv, res_layer in zip( + self.projections, self.convolutions, self.residuals + ): if res_layer > 0: residual = residuals[-res_layer] residual = residual if proj is None else proj(residual) @@ -274,19 +301,20 @@ def forward(self, src_tokens, src_lengths): y = (x + input_embedding) * math.sqrt(0.5) return { - 'encoder_out': (x, y), - 'encoder_padding_mask': encoder_padding_mask, # B x T + "encoder_out": (x, y), + "encoder_padding_mask": encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): - if encoder_out['encoder_out'] is not None: - encoder_out['encoder_out'] = ( - 
encoder_out['encoder_out'][0].index_select(0, new_order), - encoder_out['encoder_out'][1].index_select(0, new_order), + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = ( + encoder_out["encoder_out"][0].index_select(0, new_order), + encoder_out["encoder_out"][1].index_select(0, new_order), ) - if encoder_out['encoder_padding_mask'] is not None: - encoder_out['encoder_padding_mask'] = \ - encoder_out['encoder_padding_mask'].index_select(0, new_order) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) return encoder_out def max_positions(self): @@ -313,10 +341,11 @@ def forward(self, x, target_embedding, encoder_out, encoder_padding_mask): # don't attend over padding if encoder_padding_mask is not None: - x = x.float().masked_fill( - encoder_padding_mask.unsqueeze(1), - float('-inf') - ).type_as(x) # FP16 support: cast to float and back + x = ( + x.float() + .masked_fill(encoder_padding_mask.unsqueeze(1), float("-inf")) + .type_as(x) + ) # FP16 support: cast to float and back # softmax over last dim sz = x.size() @@ -331,7 +360,9 @@ def forward(self, x, target_embedding, encoder_out, encoder_padding_mask): if encoder_padding_mask is None: x = x * (s * math.sqrt(1.0 / s)) else: - s = s - encoder_padding_mask.type_as(x).sum(dim=1, keepdim=True) # exclude padding + s = s - encoder_padding_mask.type_as(x).sum( + dim=1, keepdim=True + ) # exclude padding s = s.unsqueeze(-1) x = x * (s * s.rsqrt()) @@ -343,20 +374,29 @@ def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs): """Replace torch.bmm with BeamableMM.""" if beamable_mm_beam_size is not None: del self.bmm - self.add_module('bmm', BeamableMM(beamable_mm_beam_size)) + self.add_module("bmm", BeamableMM(beamable_mm_beam_size)) class FConvDecoder(FairseqIncrementalDecoder): """Convolutional decoder""" def __init__( - self, dictionary, embed_dim=512, embed_dict=None, 
out_embed_dim=256, - max_positions=1024, convolutions=((512, 3),) * 20, attention=True, - dropout=0.1, share_embed=False, positional_embeddings=True, - adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0., + self, + dictionary, + embed_dim=512, + embed_dict=None, + out_embed_dim=256, + max_positions=1024, + convolutions=((512, 3),) * 20, + attention=True, + dropout=0.1, + share_embed=False, + positional_embeddings=True, + adaptive_softmax_cutoff=None, + adaptive_softmax_dropout=0.0, ): super().__init__(dictionary) - self.register_buffer('version', torch.Tensor([2])) + self.register_buffer("version", torch.Tensor([2])) self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__ ) @@ -368,20 +408,28 @@ def __init__( # expand True into [True, True, ...] and do the same with False attention = [attention] * len(convolutions) if not isinstance(attention, list) or len(attention) != len(convolutions): - raise ValueError('Attention is expected to be a list of booleans of ' - 'length equal to the number of layers.') + raise ValueError( + "Attention is expected to be a list of booleans of " + "length equal to the number of layers." 
+ ) num_embeddings = len(dictionary) padding_idx = dictionary.pad() self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) if embed_dict: - self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens) + self.embed_tokens = utils.load_embedding( + embed_dict, self.dictionary, self.embed_tokens + ) - self.embed_positions = PositionalEmbedding( - max_positions, - embed_dim, - padding_idx, - ) if positional_embeddings else None + self.embed_positions = ( + PositionalEmbedding( + max_positions, + embed_dim, + padding_idx, + ) + if positional_embeddings + else None + ) self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) self.projections = nn.ModuleList() @@ -395,14 +443,23 @@ def __init__( residual_dim = out_channels else: residual_dim = layer_in_channels[-residual] - self.projections.append(Linear(residual_dim, out_channels) - if residual_dim != out_channels else None) + self.projections.append( + Linear(residual_dim, out_channels) + if residual_dim != out_channels + else None + ) self.convolutions.append( - LinearizedConv1d(in_channels, out_channels * 2, kernel_size, - padding=(kernel_size - 1), dropout=dropout) + LinearizedConv1d( + in_channels, + out_channels * 2, + kernel_size, + padding=(kernel_size - 1), + dropout=dropout, + ) + ) + self.attention.append( + AttentionLayer(out_channels, embed_dim) if attention[i] else None ) - self.attention.append(AttentionLayer(out_channels, embed_dim) - if attention[i] else None) self.residuals.append(residual) in_channels = out_channels layer_in_channels.append(out_channels) @@ -412,26 +469,35 @@ def __init__( if adaptive_softmax_cutoff is not None: assert not share_embed - self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff, - dropout=adaptive_softmax_dropout) + self.adaptive_softmax = AdaptiveSoftmax( + num_embeddings, + in_channels, + adaptive_softmax_cutoff, + dropout=adaptive_softmax_dropout, + ) else: self.fc2 = Linear(in_channels, 
out_embed_dim) if share_embed: - assert out_embed_dim == embed_dim, \ - "Shared embed weights implies same dimensions " \ + assert out_embed_dim == embed_dim, ( + "Shared embed weights implies same dimensions " " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim) + ) self.fc3 = nn.Linear(out_embed_dim, num_embeddings) self.fc3.weight = self.embed_tokens.weight else: self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) - def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): if encoder_out is not None: - encoder_padding_mask = encoder_out['encoder_padding_mask'] - encoder_out = encoder_out['encoder_out'] + encoder_padding_mask = encoder_out["encoder_padding_mask"] + encoder_out = encoder_out["encoder_out"] # split and transpose encoder outputs - encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state) + encoder_a, encoder_b = self._split_encoder_out( + encoder_out, incremental_state + ) if self.embed_positions is not None: pos_embed = self.embed_positions(prev_output_tokens, incremental_state) @@ -457,8 +523,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, avg_attn_scores = None num_attn_layers = len(self.attention) residuals = [x] - for proj, conv, attention, res_layer in zip(self.projections, self.convolutions, self.attention, - self.residuals): + for proj, conv, attention, res_layer in zip( + self.projections, self.convolutions, self.attention, self.residuals + ): if res_layer > 0: residual = residuals[-res_layer] residual = residual if proj is None else proj(residual) @@ -473,7 +540,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, if attention is not None: x = self._transpose_if_training(x, incremental_state) - x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask) + x, 
attn_scores = attention( + x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask + ) if not self.training and self.need_attn: attn_scores = attn_scores / num_attn_layers @@ -502,23 +571,31 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) - encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out') + encoder_out = utils.get_incremental_state( + self, incremental_state, "encoder_out" + ) if encoder_out is not None: encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out) - utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out) + utils.set_incremental_state( + self, incremental_state, "encoder_out", encoder_out + ) def max_positions(self): """Maximum output length supported by the decoder.""" - return self.embed_positions.max_positions if self.embed_positions is not None else float('inf') + return ( + self.embed_positions.max_positions + if self.embed_positions is not None + else float("inf") + ) def upgrade_state_dict(self, state_dict): - if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2: + if utils.item(state_dict.get("decoder.version", torch.Tensor([1]))[0]) < 2: # old models use incorrect weight norm dimension for i, conv in enumerate(self.convolutions): # reconfigure weight norm nn.utils.remove_weight_norm(conv) self.convolutions[i] = nn.utils.weight_norm(conv, dim=0) - state_dict['decoder.version'] = torch.Tensor([1]) + state_dict["decoder.version"] = torch.Tensor([1]) return state_dict def make_generation_fast_(self, need_attn=False, **kwargs): @@ -535,7 +612,9 @@ def _split_encoder_out(self, encoder_out, incremental_state): This is cached when doing incremental inference. 
""" - cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out') + cached_result = utils.get_incremental_state( + self, incremental_state, "encoder_out" + ) if cached_result is not None: return cached_result @@ -545,7 +624,7 @@ def _split_encoder_out(self, encoder_out, incremental_state): result = (encoder_a, encoder_b) if incremental_state is not None: - utils.set_incremental_state(self, incremental_state, 'encoder_out', result) + utils.set_incremental_state(self, incremental_state, "encoder_out", result) return result def _transpose_if_training(self, x, incremental_state): @@ -567,7 +646,11 @@ def extend_conv_spec(convolutions): elif len(spec) == 2: extended.append(spec + (1,)) else: - raise Exception('invalid number of parameters in convolution spec ' + str(spec) + '. expected 2 or 3') + raise Exception( + "invalid number of parameters in convolution spec " + + str(spec) + + ". expected 2 or 3" + ) return tuple(extended) @@ -585,7 +668,7 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx): return m -def Linear(in_features, out_features, dropout=0.): +def Linear(in_features, out_features, dropout=0.0): """Weight-normalized Linear layer (input: N x T x C)""" m = nn.Linear(in_features, out_features) nn.init.normal_(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features)) @@ -593,7 +676,7 @@ def Linear(in_features, out_features, dropout=0.): return nn.utils.weight_norm(m) -def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwargs): +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): """Weight-normalized Conv1d layer optimized for decoding""" m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) @@ -602,9 +685,10 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwarg return nn.utils.weight_norm(m, dim=2) -def ConvTBC(in_channels, 
out_channels, kernel_size, dropout=0., **kwargs): +def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): """Weight-normalized Conv1d layer""" from fairseq.modules import ConvTBC + m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs) std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) nn.init.normal_(m.weight, mean=0, std=std) @@ -612,61 +696,61 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs): return nn.utils.weight_norm(m, dim=2) -@register_model_architecture('fconv', 'fconv') +@register_model_architecture("fconv", "fconv") def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) - args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20') - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20') - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) - args.decoder_attention = getattr(args, 'decoder_attention', 'True') - args.share_input_output_embed = getattr(args, 'share_input_output_embed', False) - - -@register_model_architecture('fconv', 'fconv_iwslt_de_en') + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 20") + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 20") + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_attention = getattr(args, 
"decoder_attention", "True") + args.share_input_output_embed = getattr(args, "share_input_output_embed", False) + + +@register_model_architecture("fconv", "fconv_iwslt_de_en") def fconv_iwslt_de_en(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) - args.encoder_layers = getattr(args, 'encoder_layers', '[(256, 3)] * 4') - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256) - args.decoder_layers = getattr(args, 'decoder_layers', '[(256, 3)] * 3') - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr(args, "encoder_layers", "[(256, 3)] * 4") + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_layers = getattr(args, "decoder_layers", "[(256, 3)] * 3") + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) base_architecture(args) -@register_model_architecture('fconv', 'fconv_wmt_en_ro') +@register_model_architecture("fconv", "fconv_wmt_en_ro") def fconv_wmt_en_ro(args): - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) base_architecture(args) -@register_model_architecture('fconv', 'fconv_wmt_en_de') +@register_model_architecture("fconv", "fconv_wmt_en_de") def fconv_wmt_en_de(args): - convs = '[(512, 3)] * 9' # first 9 layers have 512 units - convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units - convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions - - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) - args.encoder_layers = getattr(args, 'encoder_layers', convs) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768) - args.decoder_layers = getattr(args, 'decoder_layers', convs) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) + convs = "[(512, 3)] * 9" # first 9 layers have 512 units + convs 
+= " + [(1024, 3)] * 4" # next 4 layers have 1024 units + convs += " + [(2048, 1)] * 2" # final 2 layers use 1x1 convolutions + + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_layers = getattr(args, "encoder_layers", convs) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_layers = getattr(args, "decoder_layers", convs) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) base_architecture(args) -@register_model_architecture('fconv', 'fconv_wmt_en_fr') +@register_model_architecture("fconv", "fconv_wmt_en_fr") def fconv_wmt_en_fr(args): - convs = '[(512, 3)] * 6' # first 6 layers have 512 units - convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units - convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units - convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions - convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions - - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) - args.encoder_layers = getattr(args, 'encoder_layers', convs) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768) - args.decoder_layers = getattr(args, 'decoder_layers', convs) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) + convs = "[(512, 3)] * 6" # first 6 layers have 512 units + convs += " + [(768, 3)] * 4" # next 4 layers have 768 units + convs += " + [(1024, 3)] * 3" # next 3 layers have 1024 units + convs += " + [(2048, 1)] * 1" # next 1 layer uses 1x1 convolutions + convs += " + [(4096, 1)] * 1" # final 1 layer uses 1x1 convolutions + + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_layers = getattr(args, "encoder_layers", convs) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_layers = getattr(args, "decoder_layers", convs) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) base_architecture(args) diff --git 
a/fairseq/models/fconv_lm.py b/fairseq/models/fconv_lm.py
index 4c3c5c66dd..07391eaa29 100644
--- a/fairseq/models/fconv_lm.py
+++ b/fairseq/models/fconv_lm.py
@@ -12,7 +12,7 @@
 from fairseq.models.fconv import FConvDecoder
 
 
-@register_model('fconv_lm')
+@register_model("fconv_lm")
 class FConvLanguageModel(FairseqLanguageModel):
     def __init__(self, decoder):
         super().__init__(decoder)
@@ -20,21 +20,45 @@ def __init__(self, decoder):
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
-        parser.add_argument('--dropout', type=float, metavar='D',
-                            help='dropout probability')
-        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
-                            help='decoder embedding dimension')
-        parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
-                            help='decoder layers [(dim, kernel_size), ...]')
-        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
-                            help='decoder output embedding dimension')
-        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
-                            help='comma separated list of adaptive softmax cutoff points. '
-                                 'Must be used with adaptive_loss criterion')
-        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
-                            help='sets adaptive softmax dropout for the tail projections')
-        parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
-                            help='decoder attention [True, ...]')
+        parser.add_argument(
+            "--dropout", type=float, metavar="D", help="dropout probability"
+        )
+        parser.add_argument(
+            "--decoder-embed-dim",
+            type=int,
+            metavar="N",
+            help="decoder embedding dimension",
+        )
+        parser.add_argument(
+            "--decoder-layers",
+            type=str,
+            metavar="EXPR",
+            help="decoder layers [(dim, kernel_size), ...]",
+        )
+        parser.add_argument(
+            "--decoder-out-embed-dim",
+            type=int,
+            metavar="N",
+            help="decoder output embedding dimension",
+        )
+        parser.add_argument(
+            "--adaptive-softmax-cutoff",
+            metavar="EXPR",
+            help="comma separated list of adaptive softmax cutoff points. "
+            "Must be used with adaptive_loss criterion",
+        )
+        parser.add_argument(
+            "--adaptive-softmax-dropout",
+            type=float,
+            metavar="D",
+            help="sets adaptive softmax dropout for the tail projections",
+        )
+        parser.add_argument(
+            "--decoder-attention",
+            type=str,
+            metavar="EXPR",
+            help="decoder attention [True, ...]",
+        )
 
     @classmethod
     def build_model(cls, args, task):
@@ -42,7 +66,9 @@ def build_model(cls, args, task):
         # make sure all arguments are present in older models
         base_lm_architecture(args)
 
-        if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
+        if hasattr(args, "max_target_positions") and not hasattr(
+            args, "tokens_per_sample"
+        ):
             args.tokens_per_sample = args.max_target_positions
 
         decoder = FConvDecoder(
@@ -57,48 +83,53 @@ def build_model(cls, args, task):
             positional_embeddings=False,
             adaptive_softmax_cutoff=(
                 utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
-                if args.criterion == 'adaptive_loss' else None
+                if args.criterion == "adaptive_loss"
+                else None
             ),
             adaptive_softmax_dropout=args.adaptive_softmax_dropout,
         )
         return FConvLanguageModel(decoder)
 
 
-@register_model_architecture('fconv_lm', 'fconv_lm')
+@register_model_architecture("fconv_lm", "fconv_lm")
 def base_lm_architecture(args):
-    args.dropout = getattr(args, 'dropout', 0.1)
-    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)
-    args.decoder_layers = getattr(args, 'decoder_layers', '[(1268, 4)] * 13')
-    args.decoder_attention = getattr(args, 'decoder_attention', 'False')
-    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
-    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
+    args.dropout = getattr(args, "dropout", 0.1)
+    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128)
+    args.decoder_layers = getattr(args, "decoder_layers", "[(1268, 4)] * 13")
+    args.decoder_attention = getattr(args, "decoder_attention", "False")
+    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
 
 
-@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
+@register_model_architecture("fconv_lm", "fconv_lm_dauphin_wikitext103")
 def fconv_lm_dauphin_wikitext103(args):
-    layers = '[(850, 6)] * 3'
-    layers += ' + [(850, 1)] * 1'
-    layers += ' + [(850, 5)] * 4'
-    layers += ' + [(850, 1)] * 1'
-    layers += ' + [(850, 4)] * 3'
-    layers += ' + [(1024, 4)] * 1'
-    layers += ' + [(2048, 4)] * 1'
-    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 280)
-    args.decoder_layers = getattr(args, 'decoder_layers', layers)
-    args.decoder_attention = getattr(args, 'decoder_attention', 'False')
-    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
+    layers = "[(850, 6)] * 3"
+    layers += " + [(850, 1)] * 1"
+    layers += " + [(850, 5)] * 4"
+    layers += " + [(850, 1)] * 1"
+    layers += " + [(850, 4)] * 3"
+    layers += " + [(1024, 4)] * 1"
+    layers += " + [(2048, 4)] * 1"
+    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 280)
+    args.decoder_layers = getattr(args, "decoder_layers", layers)
+    args.decoder_attention = getattr(args, "decoder_attention", "False")
+    args.adaptive_softmax_cutoff = getattr(
+        args, "adaptive_softmax_cutoff", "10000,20000,200000"
+    )
     base_lm_architecture(args)
 
 
-@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_gbw')
+@register_model_architecture("fconv_lm", "fconv_lm_dauphin_gbw")
 def fconv_lm_dauphin_gbw(args):
-    layers = '[(512, 5)]'
-    layers += ' + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3'
-    layers += ' + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3'
-    layers += ' + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6'
-    layers += ' + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]'
-    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)
-    args.decoder_layers = getattr(args, 'decoder_layers', layers)
-    args.decoder_attention = getattr(args,
'decoder_attention', 'False') - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') + layers = "[(512, 5)]" + layers += " + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3" + layers += " + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3" + layers += " + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6" + layers += " + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]" + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128) + args.decoder_layers = getattr(args, "decoder_layers", layers) + args.decoder_attention = getattr(args, "decoder_attention", "False") + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) base_lm_architecture(args) diff --git a/fairseq/models/fconv_self_att.py b/fairseq/models/fconv_self_att.py index c3582da96f..8357ef7847 100644 --- a/fairseq/models/fconv_self_att.py +++ b/fairseq/models/fconv_self_att.py @@ -10,8 +10,8 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import checkpoint_utils +from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.models import ( CompositeEncoder, FairseqDecoder, @@ -21,48 +21,49 @@ register_model_architecture, ) from fairseq.modules import ( - FairseqDropout, DownsampledMultiHeadAttention, + FairseqDropout, GradMultiply, LayerNorm, LearnedPositionalEmbedding, LinearizedConvolution, ) -from fairseq.incremental_decoding_utils import with_incremental_state + logger = logging.getLogger(__name__) -@register_model('fconv_self_att') +@register_model("fconv_self_att") class FConvModelSelfAtt(FairseqEncoderDecoderModel): - @classmethod def hub_models(cls): return { - 'conv.stories.pretrained': { - 'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz', - 'checkpoint_file': 'pretrained_checkpoint.pt', - 'tokenizer': 'nltk', + "conv.stories.pretrained": { + "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz", + 
"checkpoint_file": "pretrained_checkpoint.pt", + "tokenizer": "nltk", }, - 'conv.stories': { - 'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz', - 'checkpoint_file': 'fusion_checkpoint.pt', - 'tokenizer': 'nltk', - 'pretrained': 'True', - 'pretrained_checkpoint': './pretrained_checkpoint.pt', + "conv.stories": { + "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz", + "checkpoint_file": "fusion_checkpoint.pt", + "tokenizer": "nltk", + "pretrained": "True", + "pretrained_checkpoint": "./pretrained_checkpoint.pt", }, # Test set containing dictionaries - 'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2', + "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2", } def __init__(self, encoder, decoder, pretrained_encoder=None): super().__init__(encoder, decoder) - self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention) + self.encoder.num_attention_layers = sum( + layer is not None for layer in decoder.attention + ) self.pretrained_encoder = pretrained_encoder if self.pretrained_encoder is None: - encoders = {'encoder': encoder} + encoders = {"encoder": encoder} else: - encoders = {'encoder': encoder, 'pretrained': self.pretrained_encoder} + encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder} # for fusion model, CompositeEncoder contains both pretrained and training encoders # these are forwarded and then combined in the decoder self.encoder = CompositeEncoder(encoders) @@ -113,9 +114,11 @@ def build_model(cls, args, task): trained_encoder, trained_decoder = None, None pretrained = eval(args.pretrained) if pretrained: - logger.info('loading pretrained model') + logger.info("loading pretrained model") if not os.path.exists(args.pretrained_checkpoint): - new_pretrained_checkpoint = os.path.join(args.data, args.pretrained_checkpoint) + new_pretrained_checkpoint = os.path.join( + args.data, 
args.pretrained_checkpoint + ) if os.path.exists(new_pretrained_checkpoint): args.pretrained_checkpoint = new_pretrained_checkpoint trained_model = checkpoint_utils.load_model_ensemble( @@ -169,9 +172,15 @@ def pretrained(self): class FConvEncoder(FairseqEncoder): """Convolutional encoder""" + def __init__( - self, dictionary, embed_dim=512, max_positions=1024, - convolutions=((512, 3),) * 20, dropout=0.1, attention=False, + self, + dictionary, + embed_dim=512, + max_positions=1024, + convolutions=((512, 3),) * 20, + dropout=0.1, + attention=False, attention_nheads=1, ): super().__init__(dictionary) @@ -205,14 +214,18 @@ def expand_bool_array(val): self.attproj = nn.ModuleList() for i, (out_channels, kernel_size) in enumerate(convolutions): self.projections.append( - Linear(in_channels, out_channels) if in_channels != out_channels else None + Linear(in_channels, out_channels) + if in_channels != out_channels + else None ) self.convolutions.append( ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout) ) self.attention.append( - SelfAttention(out_channels, embed_dim, attention_nheads) if attention[i] else None + SelfAttention(out_channels, embed_dim, attention_nheads) + if attention[i] + else None ) in_channels = out_channels @@ -235,7 +248,9 @@ def forward(self, src_tokens, src_lengths): x = x.transpose(0, 1) # temporal convolutions - for proj, conv, attention in zip(self.projections, self.convolutions, self.attention): + for proj, conv, attention in zip( + self.projections, self.convolutions, self.attention + ): residual = x if proj is None else proj(x) if encoder_padding_mask is not None: @@ -268,23 +283,24 @@ def forward(self, src_tokens, src_lengths): y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5) return { - 'encoder_out': (x, y), - 'encoder_padding_mask': encoder_padding_mask, # B x T + "encoder_out": (x, y), + "encoder_padding_mask": encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): - 
encoder_out['encoder_out'] = tuple( - eo.index_select(0, new_order) for eo in encoder_out['encoder_out'] + encoder_out["encoder_out"] = tuple( + eo.index_select(0, new_order) for eo in encoder_out["encoder_out"] ) - if encoder_out['encoder_padding_mask'] is not None: - encoder_out['encoder_padding_mask'] = \ - encoder_out['encoder_padding_mask'].index_select(0, new_order) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) - if 'pretrained' in encoder_out: - encoder_out['pretrained']['encoder_out'] = tuple( + if "pretrained" in encoder_out: + encoder_out["pretrained"]["encoder_out"] = tuple( eo.index_select(0, new_order) - for eo in encoder_out['pretrained']['encoder_out'] + for eo in encoder_out["pretrained"]["encoder_out"] ) return encoder_out @@ -297,15 +313,27 @@ def max_positions(self): @with_incremental_state class FConvDecoder(FairseqDecoder): """Convolutional decoder""" + def __init__( - self, dictionary, embed_dim=512, out_embed_dim=256, max_positions=1024, - convolutions=((512, 3),) * 8, attention=True, dropout=0.1, - selfattention=False, attention_nheads=1, selfattention_nheads=1, - project_input=False, gated_attention=False, downsample=False, - pretrained=False, trained_decoder=None, + self, + dictionary, + embed_dim=512, + out_embed_dim=256, + max_positions=1024, + convolutions=((512, 3),) * 8, + attention=True, + dropout=0.1, + selfattention=False, + attention_nheads=1, + selfattention_nheads=1, + project_input=False, + gated_attention=False, + downsample=False, + pretrained=False, + trained_decoder=None, ): super().__init__(dictionary) - self.register_buffer('version', torch.Tensor([2])) + self.register_buffer("version", torch.Tensor([2])) self.pretrained = pretrained self.pretrained_decoder = trained_decoder self.dropout_module = FairseqDropout( @@ -324,8 +352,10 @@ def expand_bool_array(val): selfattention = 
expand_bool_array(selfattention) if not isinstance(attention, list) or len(attention) != len(convolutions): - raise ValueError('Attention is expected to be a list of booleans of ' - 'length equal to the number of layers.') + raise ValueError( + "Attention is expected to be a list of booleans of " + "length equal to the number of layers." + ) num_embeddings = len(dictionary) padding_idx = dictionary.pad() @@ -345,31 +375,49 @@ def expand_bool_array(val): self.attproj = nn.ModuleList() for i, (out_channels, kernel_size) in enumerate(convolutions): self.projections.append( - Linear(in_channels, out_channels) if in_channels != out_channels else None + Linear(in_channels, out_channels) + if in_channels != out_channels + else None ) self.convolutions.append( LinearizedConv1d( - in_channels, out_channels * 2, kernel_size, - padding=(kernel_size - 1), dropout=dropout, + in_channels, + out_channels * 2, + kernel_size, + padding=(kernel_size - 1), + dropout=dropout, ) ) self.attention.append( DownsampledMultiHeadAttention( - out_channels, embed_dim, attention_nheads, - project_input=project_input, gated=False, downsample=False, - ) if attention[i] else None + out_channels, + embed_dim, + attention_nheads, + project_input=project_input, + gated=False, + downsample=False, + ) + if attention[i] + else None ) self.attproj.append( - Linear(out_channels, embed_dim, dropout=dropout) if attention[i] else None + Linear(out_channels, embed_dim, dropout=dropout) + if attention[i] + else None ) self.selfattention.append( SelfAttention( - out_channels, embed_dim, selfattention_nheads, - project_input=project_input, gated=gated_attention, + out_channels, + embed_dim, + selfattention_nheads, + project_input=project_input, + gated=gated_attention, downsample=downsample, - ) if selfattention[i] else None + ) + if selfattention[i] + else None ) in_channels = out_channels @@ -379,18 +427,22 @@ def expand_bool_array(val): # model fusion if self.pretrained: # independent gates are learned from 
the concatenated input - self.gate1 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid()) - self.gate2 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid()) + self.gate1 = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid() + ) + self.gate2 = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid() + ) # pretrained and trained models are joined self.joining = nn.Sequential( - Linear(out_embed_dim*2, out_embed_dim*2), - LayerNorm(out_embed_dim*2), + Linear(out_embed_dim * 2, out_embed_dim * 2), + LayerNorm(out_embed_dim * 2), nn.GLU(), - Linear(out_embed_dim, out_embed_dim*2), - LayerNorm(out_embed_dim*2), + Linear(out_embed_dim, out_embed_dim * 2), + LayerNorm(out_embed_dim * 2), nn.GLU(), Linear(out_embed_dim, out_embed_dim), - LayerNorm(out_embed_dim) + LayerNorm(out_embed_dim), ) # pretrained model contains an output layer that is nhid -> vocab size # but the models are combined in their hidden state @@ -400,13 +452,14 @@ def expand_bool_array(val): def save_output(): def hook(a, b, output): self.pretrained_outputs["out"] = output + return hook self.pretrained_decoder.fc2.register_forward_hook(save_output()) def forward(self, prev_output_tokens, encoder_out): - trained_encoder_out = encoder_out['pretrained'] if self.pretrained else None - encoder_out = encoder_out['encoder']['encoder_out'] + trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None + encoder_out = encoder_out["encoder"]["encoder_out"] encoder_a, encoder_b = self._split_encoder_out(encoder_out) @@ -427,7 +480,11 @@ def forward(self, prev_output_tokens, encoder_out): # temporal convolutions avg_attn_scores = None for proj, conv, attention, selfattention, attproj in zip( - self.projections, self.convolutions, self.attention, self.selfattention, self.attproj + self.projections, + self.convolutions, + self.attention, + self.selfattention, + self.attproj, ): residual = x if proj is None else proj(x) @@ -438,7 +495,9 
@@ def forward(self, prev_output_tokens, encoder_out): # attention if attention is not None: r = x - x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b) + x, attn_scores = attention( + attproj(x) + target_embedding, encoder_a, encoder_b + ) x = x + r if not self.training and self.need_attn: if avg_attn_scores is None: @@ -462,7 +521,9 @@ def forward(self, prev_output_tokens, encoder_out): # fusion gating if self.pretrained: - trained_x, _ = self.pretrained_decoder.forward(prev_output_tokens, trained_encoder_out) + trained_x, _ = self.pretrained_decoder.forward( + prev_output_tokens, trained_encoder_out + ) y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1) gate1 = self.gate1(y) gate2 = self.gate2(y) @@ -493,12 +554,25 @@ def _split_encoder_out(self, encoder_out): class SelfAttention(nn.Module): - - def __init__(self, out_channels, embed_dim, num_heads, project_input=False, gated=False, downsample=False): + def __init__( + self, + out_channels, + embed_dim, + num_heads, + project_input=False, + gated=False, + downsample=False, + ): super().__init__() self.attention = DownsampledMultiHeadAttention( - out_channels, embed_dim, num_heads, dropout=0, bias=True, - project_input=project_input, gated=gated, downsample=downsample, + out_channels, + embed_dim, + num_heads, + dropout=0, + bias=True, + project_input=project_input, + gated=gated, + downsample=downsample, ) self.in_proj_q = Linear(out_channels, embed_dim) self.in_proj_k = Linear(out_channels, embed_dim) @@ -510,7 +584,9 @@ def forward(self, x): query = self.in_proj_q(x) key = self.in_proj_k(x) value = self.in_proj_v(x) - x, _ = self.attention(query, key, value, mask_future_timesteps=True, use_scalar_bias=True) + x, _ = self.attention( + query, key, value, mask_future_timesteps=True, use_scalar_bias=True + ) return self.ln(x + residual) @@ -526,7 +602,7 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx): return m -def Linear(in_features, out_features, 
dropout=0.): +def Linear(in_features, out_features, dropout=0.0): """Weight-normalized Linear layer (input: N x T x C)""" m = nn.Linear(in_features, out_features) m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) @@ -534,7 +610,7 @@ def Linear(in_features, out_features, dropout=0.): return m -def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwargs): +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): """Weight-normalized Conv1d layer optimized for decoding""" m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) @@ -543,9 +619,10 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwarg return m -def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs): +def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): """Weight-normalized Conv1d layer""" from fairseq.modules import ConvTBC + m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs) std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) m.weight.data.normal_(mean=0, std=std) @@ -553,37 +630,45 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs): return m -@register_model_architecture('fconv_self_att', 'fconv_self_att') +@register_model_architecture("fconv_self_att", "fconv_self_att") def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 3') - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 8') - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) - args.decoder_attention = getattr(args, 'decoder_attention', 'True') - args.self_attention = getattr(args, 
'self_attention', 'False') - args.encoder_attention = getattr(args, 'encoder_attention', 'False') - args.multihead_attention_nheads = getattr(args, 'multihead_attention_nheads', 1) - args.multihead_self_attention_nheads = getattr(args, 'multihead_self_attention_nheads', 1) - args.encoder_attention_nheads = getattr(args, 'encoder_attention_nheads', 1) - args.project_input = getattr(args, 'project_input', 'False') - args.gated_attention = getattr(args, 'gated_attention', 'False') - args.downsample = getattr(args, 'downsample', 'False') - args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '') - args.pretrained = getattr(args, 'pretrained', 'False') - - -@register_model_architecture('fconv_self_att', 'fconv_self_att_wp') + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 3") + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 8") + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_attention = getattr(args, "decoder_attention", "True") + args.self_attention = getattr(args, "self_attention", "False") + args.encoder_attention = getattr(args, "encoder_attention", "False") + args.multihead_attention_nheads = getattr(args, "multihead_attention_nheads", 1) + args.multihead_self_attention_nheads = getattr( + args, "multihead_self_attention_nheads", 1 + ) + args.encoder_attention_nheads = getattr(args, "encoder_attention_nheads", 1) + args.project_input = getattr(args, "project_input", "False") + args.gated_attention = getattr(args, "gated_attention", "False") + args.downsample = getattr(args, "downsample", "False") + args.pretrained_checkpoint = getattr(args, "pretrained_checkpoint", "") + args.pretrained = getattr(args, "pretrained", "False") + + +@register_model_architecture("fconv_self_att", 
"fconv_self_att_wp") def fconv_self_att_wp(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) - args.encoder_layers = getattr(args, 'encoder_layers', '[(128, 3)] * 2 + [(512,3)] * 1') - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256) - args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1') - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) - args.self_attention = getattr(args, 'self_attention', 'True') - args.multihead_self_attention_nheads = getattr(args, 'multihead_self_attention_nheads', 4) - args.project_input = getattr(args, 'project_input', 'True') - args.gated_attention = getattr(args, 'gated_attention', 'True') - args.downsample = getattr(args, 'downsample', 'True') + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr( + args, "encoder_layers", "[(128, 3)] * 2 + [(512,3)] * 1" + ) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_layers = getattr( + args, "decoder_layers", "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1" + ) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.self_attention = getattr(args, "self_attention", "True") + args.multihead_self_attention_nheads = getattr( + args, "multihead_self_attention_nheads", 4 + ) + args.project_input = getattr(args, "project_input", "True") + args.gated_attention = getattr(args, "gated_attention", "True") + args.downsample = getattr(args, "downsample", "True") base_architecture(args) diff --git a/fairseq/models/huggingface/__init__.py b/fairseq/models/huggingface/__init__.py index 633315f54d..f7911c2c8e 100644 --- a/fairseq/models/huggingface/__init__.py +++ b/fairseq/models/huggingface/__init__.py @@ -12,9 +12,9 @@ for file in os.listdir(models_dir): path = os.path.join(models_dir, file) if ( - not file.startswith('_') - and not file.startswith('.') - and (file.endswith('.py') or 
os.path.isdir(path)) + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) ): - model_name = file[:file.find('.py')] if file.endswith('.py') else file - module = importlib.import_module('fairseq.models.huggingface.' + model_name) + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.models.huggingface." + model_name) diff --git a/fairseq/models/huggingface/hf_gpt2.py b/fairseq/models/huggingface/hf_gpt2.py index e81954ff65..a823453794 100644 --- a/fairseq/models/huggingface/hf_gpt2.py +++ b/fairseq/models/huggingface/hf_gpt2.py @@ -16,13 +16,15 @@ register_model_architecture, ) + try: # Prepend the transformers submodule to the path, so that # it's prioritized over other installations. This allows # making local changes in the submodule. - hf_path = os.path.join(os.path.dirname(__file__), 'transformers', 'src') + hf_path = os.path.join(os.path.dirname(__file__), "transformers", "src") sys.path.insert(0, hf_path) from transformers import GPT2Config, GPT2LMHeadModel + sys.path.remove(hf_path) has_hf = True except ImportError: @@ -35,18 +37,17 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 -@register_model('hf_gpt2') +@register_model("hf_gpt2") class HuggingFaceGPT2LanguageModel(FairseqLanguageModel): - def __init__(self, decoder): super().__init__(decoder) if not has_hf: raise ImportError( - '\n\nPlease install huggingface/transformers with:' - '\n\n pip install transformers' - '\n\nOr to make local edits, install the submodule:' - '\n\n git submodule update --init ' - 'fairseq/models/huggingface/transformers' + "\n\nPlease install huggingface/transformers with:" + "\n\n pip install transformers" + "\n\nOr to make local edits, install the submodule:" + "\n\n git submodule update --init " + "fairseq/models/huggingface/transformers" ) @staticmethod @@ -74,17 +75,16 @@ def build_model(cls, args, task): class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder): - 
def __init__(self, args, task): super().__init__(task.target_dictionary) if not has_hf: raise ImportError( - '\n\nPlease install huggingface/transformers with:' - '\n\n pip install transformers' - '\n\nOr to make local edits, install the submodule:' - '\n\n git submodule update --init ' - 'fairseq/models/huggingface/transformers' + "\n\nPlease install huggingface/transformers with:" + "\n\n pip install transformers" + "\n\nOr to make local edits, install the submodule:" + "\n\n git submodule update --init " + "fairseq/models/huggingface/transformers" ) config = GPT2Config( @@ -115,7 +115,7 @@ def forward( ): features = self.extract_features(prev_output_tokens, incremental_state) lm_logits = self.model.lm_head(features) - return (lm_logits, ) + return (lm_logits,) def extract_features( self, @@ -154,38 +154,38 @@ def max_positions(self): return self.model.config.n_positions - 1 -@register_model_architecture('hf_gpt2', 'hf_gpt2') +@register_model_architecture("hf_gpt2", "hf_gpt2") def default_architecture(args): - if getattr(args, 'max_target_positions', None) is None: + if getattr(args, "max_target_positions", None) is None: args.max_target_positions = getattr( - args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS ) - args.embed_dim = getattr(args, 'embed_dim', 768) - args.num_attention_heads = getattr(args, 'num_attention_heads', 12) - args.num_layers = getattr(args, 'num_layers', 12) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) + args.embed_dim = getattr(args, "embed_dim", 768) + args.num_attention_heads = getattr(args, "num_attention_heads", 12) + args.num_layers = getattr(args, "num_layers", 12) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) -@register_model_architecture('hf_gpt2', 'hf_gpt2_medium') +@register_model_architecture("hf_gpt2", "hf_gpt2_medium") def 
hf_gpt2_medium(args): - args.embed_dim = getattr(args, 'embed_dim', 1024) - args.num_attention_heads = getattr(args, 'num_attention_heads', 16) - args.num_layers = getattr(args, 'num_layers', 24) + args.embed_dim = getattr(args, "embed_dim", 1024) + args.num_attention_heads = getattr(args, "num_attention_heads", 16) + args.num_layers = getattr(args, "num_layers", 24) default_architecture(args) -@register_model_architecture('hf_gpt2', 'hf_gpt2_large') +@register_model_architecture("hf_gpt2", "hf_gpt2_large") def hf_gpt2_large(args): - args.embed_dim = getattr(args, 'embed_dim', 1280) - args.num_attention_heads = getattr(args, 'num_attention_heads', 20) - args.num_layers = getattr(args, 'num_layers', 36) + args.embed_dim = getattr(args, "embed_dim", 1280) + args.num_attention_heads = getattr(args, "num_attention_heads", 20) + args.num_layers = getattr(args, "num_layers", 36) default_architecture(args) -@register_model_architecture('hf_gpt2', 'hf_gpt2_xl') +@register_model_architecture("hf_gpt2", "hf_gpt2_xl") def hf_gpt2_xl(args): - args.embed_dim = getattr(args, 'embed_dim', 1600) - args.num_attention_heads = getattr(args, 'num_attention_heads', 25) - args.num_layers = getattr(args, 'num_layers', 48) + args.embed_dim = getattr(args, "embed_dim", 1600) + args.num_attention_heads = getattr(args, "num_attention_heads", 25) + args.num_layers = getattr(args, "num_layers", 48) default_architecture(args) diff --git a/fairseq/models/lightconv.py b/fairseq/models/lightconv.py index 09d4d0be2e..b614da3665 100644 --- a/fairseq/models/lightconv.py +++ b/fairseq/models/lightconv.py @@ -8,12 +8,11 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils from fairseq.models import ( FairseqEncoder, - FairseqIncrementalDecoder, FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, register_model, register_model_architecture, ) @@ -22,13 +21,13 @@ DynamicConv, FairseqDropout, LayerNorm, - PositionalEmbedding, LightweightConv, 
MultiheadAttention, + PositionalEmbedding, ) -@register_model('lightconv') +@register_model("lightconv") class LightConvModel(FairseqEncoderDecoderModel): """ LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019) @@ -81,75 +80,175 @@ def __init__(self, encoder, decoder): @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', type=float, metavar='D', - help='dropout probability for attention weights') - parser.add_argument('--relu-dropout', type=float, metavar='D', - help='dropout probability after ReLU in FFN') - parser.add_argument('--input-dropout', type=float, metavar='D', - help='dropout probability of the inputs') - parser.add_argument('--encoder-embed-path', type=str, metavar='STR', - help='path to pre-trained encoder embedding') - parser.add_argument('--encoder-embed-dim', type=int, metavar='N', - help='encoder embedding dimension') - parser.add_argument('--encoder-conv-dim', type=int, metavar='N', - help='encoder embedding dimension') - parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', - help='encoder embedding dimension for FFN') - parser.add_argument('--encoder-layers', type=int, metavar='N', - help='num encoder layers') - parser.add_argument('--encoder-attention-heads', type=int, metavar='N', - help='num encoder attention heads or LightConv/DynamicConv heads') - parser.add_argument('--encoder-normalize-before', action='store_true', - help='apply layernorm before each encoder block') - parser.add_argument('--encoder-learned-pos', action='store_true', - help='use learned positional embeddings in the encoder') - parser.add_argument('--decoder-embed-path', type=str, metavar='STR', - help='path to pre-trained decoder embedding') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder 
embedding dimension') - parser.add_argument('--decoder-conv-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', - help='decoder embedding dimension for FFN') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='num decoder layers') - parser.add_argument('--decoder-attention-heads', type=int, metavar='N', - help='num decoder attention heads or LightConv/DynamicConv heads') - parser.add_argument('--decoder-learned-pos', action='store_true', - help='use learned positional embeddings in the decoder') - parser.add_argument('--decoder-normalize-before', action='store_true', - help='apply layernorm before each decoder block') - parser.add_argument('--share-decoder-input-output-embed', action='store_true', - help='share decoder input and output embeddings') - parser.add_argument('--share-all-embeddings', action='store_true', - help='share encoder, decoder and output embeddings' - ' (requires shared dictionary and embed dim)') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. 
' - 'Must be used with adaptive_loss criterion'), - parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', - help='sets adaptive softmax dropout for the tail projections') + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after ReLU in FFN", + ) + parser.add_argument( + "--input-dropout", + type=float, + metavar="D", + help="dropout probability of the inputs", + ) + parser.add_argument( + "--encoder-embed-path", + type=str, + metavar="STR", + help="path to pre-trained encoder embedding", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-conv-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads or LightConv/DynamicConv heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--encoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the encoder", + ) + parser.add_argument( + "--decoder-embed-path", + type=str, + metavar="STR", + help="path to pre-trained decoder embedding", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-conv-dim", + type=int, + metavar="N", 
+ help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads or LightConv/DynamicConv heads", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the decoder", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--share-all-embeddings", + action="store_true", + help="share encoder, decoder and output embeddings" + " (requires shared dictionary and embed dim)", + ) + parser.add_argument( + "--adaptive-softmax-cutoff", + metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. 
" + "Must be used with adaptive_loss criterion", + ), + parser.add_argument( + "--adaptive-softmax-dropout", + type=float, + metavar="D", + help="sets adaptive softmax dropout for the tail projections", + ) """LightConv and DynamicConv arguments""" - parser.add_argument('--encoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int), - help='list of kernel size (default: "[3,7,15,31,31,31,31]")') - parser.add_argument('--decoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int), - help='list of kernel size (default: "[3,7,15,31,31,31]")') - parser.add_argument('--encoder-glu', type=utils.eval_bool, - help='glu after in proj') - parser.add_argument('--decoder-glu', type=utils.eval_bool, - help='glu after in proj') - parser.add_argument('--encoder-conv-type', default='dynamic', type=str, - choices=['dynamic', 'lightweight'], - help='type of convolution') - parser.add_argument('--decoder-conv-type', default='dynamic', type=str, - choices=['dynamic', 'lightweight'], - help='type of convolution') - parser.add_argument('--weight-softmax', default=True, type=utils.eval_bool) - parser.add_argument('--weight-dropout', type=float, metavar='D', - help='dropout probability for conv weights') + parser.add_argument( + "--encoder-kernel-size-list", + type=lambda x: utils.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31,31]")', + ) + parser.add_argument( + "--decoder-kernel-size-list", + type=lambda x: utils.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31]")', + ) + parser.add_argument( + "--encoder-glu", type=utils.eval_bool, help="glu after in proj" + ) + parser.add_argument( + "--decoder-glu", type=utils.eval_bool, help="glu after in proj" + ) + parser.add_argument( + "--encoder-conv-type", + default="dynamic", + type=str, + choices=["dynamic", "lightweight"], + help="type of convolution", + ) + parser.add_argument( + "--decoder-conv-type", + default="dynamic", + type=str, + choices=["dynamic", 
"lightweight"], + help="type of convolution", + ) + parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool) + parser.add_argument( + "--weight-dropout", + type=float, + metavar="D", + help="dropout probability for conv weights", + ) @classmethod def build_model(cls, args, task): @@ -158,9 +257,9 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_architecture(args) - if not hasattr(args, 'max_source_positions'): + if not hasattr(args, "max_source_positions"): args.max_source_positions = 1024 - if not hasattr(args, 'max_target_positions'): + if not hasattr(args, "max_target_positions"): args.max_target_positions = 1024 src_dict, tgt_dict = task.source_dictionary, task.target_dictionary @@ -177,13 +276,19 @@ def build_embedding(dictionary, embed_dim, path=None): if args.share_all_embeddings: if src_dict != tgt_dict: - raise RuntimeError('--share-all-embeddings requires a joined dictionary') + raise RuntimeError( + "--share-all-embeddings requires a joined dictionary" + ) if args.encoder_embed_dim != args.decoder_embed_dim: raise RuntimeError( - '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim') + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) if args.decoder_embed_path and ( - args.decoder_embed_path != args.encoder_embed_path): - raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path') + args.decoder_embed_path != args.encoder_embed_path + ): + raise RuntimeError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) encoder_embed_tokens = build_embedding( src_dict, args.encoder_embed_dim, args.encoder_embed_path ) @@ -215,7 +320,9 @@ class LightConvEncoder(FairseqEncoder): def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + 
args.dropout, module_name=self.__class__.__name__ + ) embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx @@ -223,17 +330,27 @@ def __init__(self, args, dictionary, embed_tokens): self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) - self.embed_positions = PositionalEmbedding( - args.max_source_positions, embed_dim, self.padding_idx, - learned=args.encoder_learned_pos, - ) if not args.no_token_positional_embeddings else None + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) self.layers = nn.ModuleList([]) - self.layers.extend([ - LightConvEncoderLayer(args, kernel_size=args.encoder_kernel_size_list[i]) - for i in range(args.encoder_layers) - ]) - self.register_buffer('version', torch.Tensor([2])) + self.layers.extend( + [ + LightConvEncoderLayer( + args, kernel_size=args.encoder_kernel_size_list[i] + ) + for i in range(args.encoder_layers) + ] + ) + self.register_buffer("version", torch.Tensor([2])) self.normalize = args.encoder_normalize_before if self.normalize: self.layer_norm = LayerNorm(embed_dim) @@ -273,8 +390,8 @@ def forward(self, src_tokens, **unused): x = self.layer_norm(x) return { - 'encoder_out': x, # T x B x C - 'encoder_padding_mask': encoder_padding_mask, # B x T + "encoder_out": x, # T x B x C + "encoder_padding_mask": encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): @@ -288,12 +405,14 @@ def reorder_encoder_out(self, encoder_out, new_order): Returns: *encoder_out* rearranged according to *new_order* """ - if encoder_out['encoder_out'] is not None: - encoder_out['encoder_out'] = \ - encoder_out['encoder_out'].index_select(1, new_order) - if encoder_out['encoder_padding_mask'] is not None: - encoder_out['encoder_padding_mask'] = \ - encoder_out['encoder_padding_mask'].index_select(0, new_order) + if 
encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) return encoder_out def max_positions(self): @@ -316,9 +435,13 @@ class LightConvDecoder(FairseqIncrementalDecoder): Default: ``False`` """ - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True): + def __init__( + self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True + ): super().__init__(dictionary) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim @@ -331,23 +454,40 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_ self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim - self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None - - self.embed_positions = PositionalEmbedding( - args.max_target_positions, embed_dim, padding_idx, - learned=args.decoder_learned_pos, - ) if not args.no_token_positional_embeddings else None + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) self.layers = nn.ModuleList([]) - self.layers.extend([ - LightConvDecoderLayer(args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i]) - for i in range(args.decoder_layers) - ]) 
+ self.layers.extend( + [ + LightConvDecoderLayer( + args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i] + ) + for i in range(args.decoder_layers) + ] + ) self.adaptive_softmax = None - self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \ - if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None + self.project_out_dim = ( + Linear(embed_dim, output_embed_dim, bias=False) + if embed_dim != output_embed_dim and not args.tie_adaptive_weights + else None + ) if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( @@ -360,14 +500,18 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_ tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: - self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim)) + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), output_embed_dim) + ) nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5) - self.register_buffer('version', torch.Tensor([2])) + self.register_buffer("version", torch.Tensor([2])) self.normalize = args.decoder_normalize_before and final_norm if self.normalize: self.layer_norm = LayerNorm(embed_dim) - def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape @@ -385,10 +529,14 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, tgt_len, src_len)` """ # embed positions - positions = self.embed_positions( - prev_output_tokens, - incremental_state=incremental_state, - ) if self.embed_positions is not None else None + positions = ( + self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) + if self.embed_positions is not None + else None + ) if incremental_state is not None: 
prev_output_tokens = prev_output_tokens[:, -1:] @@ -415,8 +563,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, for layer in self.layers: x, attn = layer( x, - encoder_out['encoder_out'] if encoder_out is not None else None, - encoder_out['encoder_padding_mask'] if encoder_out is not None else None, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["encoder_padding_mask"] + if encoder_out is not None + else None, incremental_state, ) inner_states.append(x) @@ -437,7 +587,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, else: x = F.linear(x, self.embed_out) - return x, {'attn': attn, 'inner_states': inner_states} + return x, {"attn": attn, "inner_states": inner_states} def max_positions(self): """Maximum output length supported by the decoder.""" @@ -447,10 +597,18 @@ def max_positions(self): def buffered_future_mask(self, tensor): dim = tensor.size(0) - if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: - self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) if self._future_mask.size(0) < dim: - self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) return self._future_mask[:dim, :dim] @@ -466,31 +624,49 @@ def __init__(self, args, kernel_size=0): super().__init__() self.embed_dim = args.encoder_embed_dim self.conv_dim = args.encoder_conv_dim - padding_l = kernel_size // 2 if kernel_size % 2 == 1 else ((kernel_size - 1) // 2, kernel_size // 2) + padding_l = ( + kernel_size // 2 + if kernel_size % 2 == 1 + else ((kernel_size 
- 1) // 2, kernel_size // 2) + ) if args.encoder_glu: - self.linear1 = Linear(self.embed_dim, 2*self.conv_dim) + self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim) self.act = nn.GLU() else: self.linear1 = Linear(self.embed_dim, self.conv_dim) self.act = None - if args.encoder_conv_type == 'lightweight': - self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=padding_l, - weight_softmax=args.weight_softmax, - num_heads=args.encoder_attention_heads, - weight_dropout=args.weight_dropout) - elif args.encoder_conv_type == 'dynamic': - self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=padding_l, - weight_softmax=args.weight_softmax, - num_heads=args.encoder_attention_heads, - weight_dropout=args.weight_dropout) + if args.encoder_conv_type == "lightweight": + self.conv = LightweightConv( + self.conv_dim, + kernel_size, + padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout, + ) + elif args.encoder_conv_type == "dynamic": + self.conv = DynamicConv( + self.conv_dim, + kernel_size, + padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout, + ) else: raise NotImplementedError self.linear2 = Linear(self.conv_dim, self.embed_dim) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) - self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__) - self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.relu_dropout_module = FairseqDropout( + args.relu_dropout, module_name=self.__class__.__name__ + ) + self.input_dropout_module = FairseqDropout( + args.input_dropout, module_name=self.__class__.__name__ + ) self.normalize_before = args.encoder_normalize_before self.fc1 = 
Linear(self.embed_dim, args.encoder_ffn_embed_dim) self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) @@ -538,8 +714,14 @@ def maybe_layer_norm(self, i, x, before=False, after=False): return x def extra_repr(self): - return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format( - self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before) + return ( + "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format( + self.dropout_module.p, + self.relu_dropout_module.p, + self.input_dropout_module.p, + self.normalize_before, + ) + ) class LightConvDecoderLayer(nn.Module): @@ -557,28 +739,42 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0): self.embed_dim = args.decoder_embed_dim self.conv_dim = args.decoder_conv_dim if args.decoder_glu: - self.linear1 = Linear(self.embed_dim, 2*self.conv_dim) + self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim) self.act = nn.GLU() else: self.linear1 = Linear(self.embed_dim, self.conv_dim) self.act = None - if args.decoder_conv_type == 'lightweight': - self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=kernel_size-1, - weight_softmax=args.weight_softmax, - num_heads=args.decoder_attention_heads, - weight_dropout=args.weight_dropout) - elif args.decoder_conv_type == 'dynamic': - self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=kernel_size-1, - weight_softmax=args.weight_softmax, - num_heads=args.decoder_attention_heads, - weight_dropout=args.weight_dropout) + if args.decoder_conv_type == "lightweight": + self.conv = LightweightConv( + self.conv_dim, + kernel_size, + padding_l=kernel_size - 1, + weight_softmax=args.weight_softmax, + num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout, + ) + elif args.decoder_conv_type == "dynamic": + self.conv = DynamicConv( + self.conv_dim, + kernel_size, + padding_l=kernel_size - 1, + weight_softmax=args.weight_softmax, + 
num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout, + ) else: raise NotImplementedError self.linear2 = Linear(self.conv_dim, self.embed_dim) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) - self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__) - self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.relu_dropout_module = FairseqDropout( + args.relu_dropout, module_name=self.__class__.__name__ + ) + self.input_dropout_module = FairseqDropout( + args.input_dropout, module_name=self.__class__.__name__ + ) self.normalize_before = args.decoder_normalize_before self.conv_layer_norm = LayerNorm(self.embed_dim) @@ -588,8 +784,10 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0): self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( - self.embed_dim, args.decoder_attention_heads, - dropout=args.attention_dropout, encoder_decoder_attention=True, + self.embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) @@ -599,9 +797,17 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0): self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True - def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, - prev_conv_state=None, prev_attn_state=None, conv_mask=None, - conv_padding_mask=None): + def forward( + self, + x, + encoder_out, + encoder_padding_mask, + incremental_state, + prev_conv_state=None, + prev_attn_state=None, + conv_mask=None, + conv_padding_mask=None, + ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` @@ -671,8 +877,14 @@ def make_generation_fast_(self, need_attn=False, **kwargs): 
self.need_attn = need_attn def extra_repr(self): - return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format( - self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before) + return ( + "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format( + self.dropout_module.p, + self.relu_dropout_module.p, + self.input_dropout_module.p, + self.normalize_before, + ) + ) def Embedding(num_embeddings, embedding_dim, padding_idx): @@ -686,101 +898,121 @@ def Linear(in_features, out_features, bias=True): m = nn.Linear(in_features, out_features, bias) nn.init.xavier_uniform_(m.weight) if bias: - nn.init.constant_(m.bias, 0.) + nn.init.constant_(m.bias, 0.0) return m -@register_model_architecture('lightconv', 'lightconv') +@register_model_architecture("lightconv", "lightconv") def base_architecture(args): - args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048) - args.encoder_layers = getattr(args, 'encoder_layers', 7) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8) - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) - args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim) - args.decoder_layers = getattr(args, 'decoder_layers', 6) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) - args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False) - args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) - args.attention_dropout = getattr(args, 'attention_dropout', 
0.) - args.relu_dropout = getattr(args, 'relu_dropout', 0.) - args.dropout = getattr(args, 'dropout', 0.1) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) - args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) - - args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) - args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) - - args.encoder_conv_dim = getattr(args, 'encoder_conv_dim', args.encoder_embed_dim) - args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim) - - args.encoder_kernel_size_list = getattr(args, 'encoder_kernel_size_list', [3, 7, 15, 31, 31, 31, 31]) - args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31]) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 7) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + 
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.encoder_conv_dim = getattr(args, "encoder_conv_dim", args.encoder_embed_dim) + args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim) + + args.encoder_kernel_size_list = getattr( + args, "encoder_kernel_size_list", [3, 7, 15, 31, 31, 31, 31] + ) + args.decoder_kernel_size_list = getattr( + args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31] + ) if len(args.encoder_kernel_size_list) == 1: - args.encoder_kernel_size_list = args.encoder_kernel_size_list * args.encoder_layers + args.encoder_kernel_size_list = ( + args.encoder_kernel_size_list * args.encoder_layers + ) if len(args.decoder_kernel_size_list) == 1: - args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers - assert len(args.encoder_kernel_size_list) == args.encoder_layers, "encoder_kernel_size_list doesn't match encoder_layers" - assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers" - args.encoder_glu = getattr(args, 'encoder_glu', True) - 
args.decoder_glu = getattr(args, 'decoder_glu', True) - args.input_dropout = getattr(args, 'input_dropout', 0.1) - args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout) - - -@register_model_architecture('lightconv', 'lightconv_iwslt_de_en') + args.decoder_kernel_size_list = ( + args.decoder_kernel_size_list * args.decoder_layers + ) + assert ( + len(args.encoder_kernel_size_list) == args.encoder_layers + ), "encoder_kernel_size_list doesn't match encoder_layers" + assert ( + len(args.decoder_kernel_size_list) == args.decoder_layers + ), "decoder_kernel_size_list doesn't match decoder_layers" + args.encoder_glu = getattr(args, "encoder_glu", True) + args.decoder_glu = getattr(args, "decoder_glu", True) + args.input_dropout = getattr(args, "input_dropout", 0.1) + args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout) + + +@register_model_architecture("lightconv", "lightconv_iwslt_de_en") def lightconv_iwslt_de_en(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4) - args.encoder_layers = getattr(args, 'encoder_layers', 7) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4) - args.decoder_layers = getattr(args, 'decoder_layers', 6) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.weight_dropout = getattr(args, 'weight_dropout', 0.1) - args.encoder_glu = getattr(args, 'encoder_glu', False) - args.decoder_glu = getattr(args, 'decoder_glu', False) - args.input_dropout = getattr(args, 'input_dropout', 0.0) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + 
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 7) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.weight_dropout = getattr(args, "weight_dropout", 0.1) + args.encoder_glu = getattr(args, "encoder_glu", False) + args.decoder_glu = getattr(args, "decoder_glu", False) + args.input_dropout = getattr(args, "input_dropout", 0.0) base_architecture(args) -@register_model_architecture('lightconv', 'lightconv_wmt_en_de') +@register_model_architecture("lightconv", "lightconv_wmt_en_de") def lightconv_wmt_en_de(args): base_architecture(args) -@register_model_architecture('lightconv', 'lightconv_wmt_en_de_big') +@register_model_architecture("lightconv", "lightconv_wmt_en_de_big") def lightconv_wmt_en_de_big(args): - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) - args.dropout = getattr(args, 'dropout', 0.3) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, 
"encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) base_architecture(args) -@register_model_architecture('lightconv', 'lightconv_wmt_en_fr_big') +@register_model_architecture("lightconv", "lightconv_wmt_en_fr_big") def lightconv_wmt_en_fr_big(args): - args.dropout = getattr(args, 'dropout', 0.1) + args.dropout = getattr(args, "dropout", 0.1) lightconv_wmt_en_de_big(args) -@register_model_architecture('lightconv', 'lightconv_wmt_zh_en_big') +@register_model_architecture("lightconv", "lightconv_wmt_zh_en_big") def lightconv_wmt_zh_en_big(args): - args.dropout = getattr(args, 'dropout', 0.2) - args.attention_dropout = getattr(args, 'attention_dropout', 0.2) - args.weight_dropout = getattr(args, 'weight_dropout', 0.2) + args.dropout = getattr(args, "dropout", 0.2) + args.attention_dropout = getattr(args, "attention_dropout", 0.2) + args.weight_dropout = getattr(args, "weight_dropout", 0.2) lightconv_wmt_en_de_big(args) diff --git a/fairseq/models/lightconv_lm.py b/fairseq/models/lightconv_lm.py index 861f6430e9..1d9efc4e42 100644 --- a/fairseq/models/lightconv_lm.py +++ b/fairseq/models/lightconv_lm.py @@ -9,17 +9,11 @@ register_model, register_model_architecture, ) -from fairseq.models.lightconv import ( - Embedding, - LightConvDecoder, -) -from fairseq.modules import ( - AdaptiveInput, - CharacterTokenEmbedder, -) +from fairseq.models.lightconv import Embedding, LightConvDecoder +from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder -@register_model('lightconv_lm') +@register_model("lightconv_lm") class LightConvLanguageModel(FairseqLanguageModel): def __init__(self, decoder): super().__init__(decoder) @@ -27,72 +21,182 @@ def 
__init__(self, decoder): @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" - parser.add_argument('--dropout', default=0.1, type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', default=0., type=float, metavar='D', - help='dropout probability for attention weights') - parser.add_argument('--relu-dropout', default=0., type=float, metavar='D', - help='dropout probability after ReLU in FFN') - parser.add_argument('--input-dropout', type=float, metavar='D', - help='dropout probability of the inputs') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-output-dim', type=int, metavar='N', - help='decoder output dimension') - parser.add_argument('--decoder-input-dim', type=int, metavar='N', - help='decoder input dimension') - parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', - help='decoder embedding dimension for FFN') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='num decoder layers') - parser.add_argument('--decoder-attention-heads', type=int, metavar='N', - help='num decoder attention heads or LightConv/DynamicConv heads') - parser.add_argument('--decoder-normalize-before', default=False, action='store_true', - help='apply layernorm before each decoder block') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. 
' - 'Must be used with adaptive_loss criterion') - parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', - help='sets adaptive softmax dropout for the tail projections') - parser.add_argument('--adaptive-softmax-factor', type=float, metavar='N', - help='adaptive input factor') - parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', - help='if set, disables positional embeddings (outside self attention)') - parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true', - help='share decoder input and output embeddings') - parser.add_argument('--character-embeddings', default=False, action='store_true', - help='if set, uses character embedding convolutions to produce token embeddings') - parser.add_argument('--character-filters', type=str, metavar='LIST', - default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]', - help='size of character embeddings') - parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4, - help='size of character embeddings') - parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2, - help='number of highway layers for character token embeddder') - parser.add_argument('--adaptive-input', default=False, action='store_true', - help='if set, uses adaptive input') - parser.add_argument('--adaptive-input-factor', type=float, metavar='N', - help='adaptive input factor') - parser.add_argument('--adaptive-input-cutoff', metavar='EXPR', - help='comma separated list of adaptive input cutoff points.') - parser.add_argument('--tie-adaptive-weights', action='store_true', - help='if set, ties the weights of adaptive softmax and adaptive input') - parser.add_argument('--tie-adaptive-proj', action='store_true', - help='if set, ties the projection weights of adaptive softmax and adaptive input') - parser.add_argument('--decoder-learned-pos', action='store_true', - help='use learned positional 
embeddings in the decoder') + parser.add_argument( + "--dropout", + default=0.1, + type=float, + metavar="D", + help="dropout probability", + ) + parser.add_argument( + "--attention-dropout", + default=0.0, + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--relu-dropout", + default=0.0, + type=float, + metavar="D", + help="dropout probability after ReLU in FFN", + ) + parser.add_argument( + "--input-dropout", + type=float, + metavar="D", + help="dropout probability of the inputs", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-output-dim", + type=int, + metavar="N", + help="decoder output dimension", + ) + parser.add_argument( + "--decoder-input-dim", type=int, metavar="N", help="decoder input dimension" + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads or LightConv/DynamicConv heads", + ) + parser.add_argument( + "--decoder-normalize-before", + default=False, + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--adaptive-softmax-cutoff", + metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. 
" + "Must be used with adaptive_loss criterion", + ) + parser.add_argument( + "--adaptive-softmax-dropout", + type=float, + metavar="D", + help="sets adaptive softmax dropout for the tail projections", + ) + parser.add_argument( + "--adaptive-softmax-factor", + type=float, + metavar="N", + help="adaptive input factor", + ) + parser.add_argument( + "--no-token-positional-embeddings", + default=False, + action="store_true", + help="if set, disables positional embeddings (outside self attention)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + default=False, + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--character-embeddings", + default=False, + action="store_true", + help="if set, uses character embedding convolutions to produce token embeddings", + ) + parser.add_argument( + "--character-filters", + type=str, + metavar="LIST", + default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + help="size of character embeddings", + ) + parser.add_argument( + "--character-embedding-dim", + type=int, + metavar="N", + default=4, + help="size of character embeddings", + ) + parser.add_argument( + "--char-embedder-highway-layers", + type=int, + metavar="N", + default=2, + help="number of highway layers for character token embeddder", + ) + parser.add_argument( + "--adaptive-input", + default=False, + action="store_true", + help="if set, uses adaptive input", + ) + parser.add_argument( + "--adaptive-input-factor", + type=float, + metavar="N", + help="adaptive input factor", + ) + parser.add_argument( + "--adaptive-input-cutoff", + metavar="EXPR", + help="comma separated list of adaptive input cutoff points.", + ) + parser.add_argument( + "--tie-adaptive-weights", + action="store_true", + help="if set, ties the weights of adaptive softmax and adaptive input", + ) + parser.add_argument( + "--tie-adaptive-proj", + action="store_true", + help="if set, ties the projection weights 
of adaptive softmax and adaptive input", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the decoder", + ) """LightConv and DynamicConv arguments""" - parser.add_argument('--decoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int), - help='list of kernel size (default: "[3,7,15,31,31,31]")') - parser.add_argument('--decoder-glu', type=utils.eval_bool, - help='glu after in proj') - parser.add_argument('--decoder-conv-type', default='dynamic', type=str, - choices=['dynamic', 'lightweight'], - help='type of convolution') - parser.add_argument('--weight-softmax', default=True, type=utils.eval_bool) - parser.add_argument('--weight-dropout', type=float, metavar='D', - help='dropout probability for conv weights') + parser.add_argument( + "--decoder-kernel-size-list", + type=lambda x: utils.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31]")', + ) + parser.add_argument( + "--decoder-glu", type=utils.eval_bool, help="glu after in proj" + ) + parser.add_argument( + "--decoder-conv-type", + default="dynamic", + type=str, + choices=["dynamic", "lightweight"], + help="type of convolution", + ) + parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool) + parser.add_argument( + "--weight-dropout", + type=float, + metavar="D", + help="dropout probability for conv weights", + ) @classmethod def build_model(cls, args, task): @@ -101,76 +205,102 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_lm_architecture(args) - if getattr(args, 'max_source_positions', None) is None: + if getattr(args, "max_source_positions", None) is None: args.max_source_positions = args.tokens_per_sample - if getattr(args, 'max_target_positions', None) is None: + if getattr(args, "max_target_positions", None) is None: args.max_target_positions = args.tokens_per_sample if args.character_embeddings: - embed_tokens = 
CharacterTokenEmbedder(task.dictionary, eval(args.character_filters), - args.character_embedding_dim, - args.decoder_embed_dim, - args.char_embedder_highway_layers, - ) + embed_tokens = CharacterTokenEmbedder( + task.dictionary, + eval(args.character_filters), + args.character_embedding_dim, + args.decoder_embed_dim, + args.char_embedder_highway_layers, + ) elif args.adaptive_input: - embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim, - args.adaptive_input_factor, args.decoder_embed_dim, - utils.eval_str_list(args.adaptive_input_cutoff, type=int)) + embed_tokens = AdaptiveInput( + len(task.dictionary), + task.dictionary.pad(), + args.decoder_input_dim, + args.adaptive_input_factor, + args.decoder_embed_dim, + utils.eval_str_list(args.adaptive_input_cutoff, type=int), + ) else: - embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()) + embed_tokens = Embedding( + len(task.dictionary), args.decoder_input_dim, task.dictionary.pad() + ) if args.tie_adaptive_weights: assert args.adaptive_input assert args.adaptive_input_factor == args.adaptive_softmax_factor - assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format( - args.adaptive_softmax_cutoff, args.adaptive_input_cutoff) + assert ( + args.adaptive_softmax_cutoff == args.adaptive_input_cutoff + ), "{} != {}".format( + args.adaptive_softmax_cutoff, args.adaptive_input_cutoff + ) assert args.decoder_input_dim == args.decoder_output_dim - decoder = LightConvDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False) + decoder = LightConvDecoder( + args, + task.output_dictionary, + embed_tokens, + no_encoder_attn=True, + final_norm=False, + ) return LightConvLanguageModel(decoder) -@register_model_architecture('lightconv_lm', 'lightconv_lm') +@register_model_architecture("lightconv_lm", "lightconv_lm") def base_lm_architecture(args): - args.decoder_embed_dim = getattr(args, 
'decoder_embed_dim', 512) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048) - args.decoder_layers = getattr(args, 'decoder_layers', 6) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) - args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4) - args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) - args.character_embeddings = getattr(args, 'character_embeddings', False) + args.character_embeddings = getattr(args, "character_embeddings", False) - args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) - args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) - args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim) # The model training is not stable without this args.decoder_normalize_before = True - args.adaptive_input = getattr(args, 'adaptive_input', False) - args.adaptive_input_factor = getattr(args, 
'adaptive_input_factor', 4) - args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) - args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False) - args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) - args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31]) + args.decoder_kernel_size_list = getattr( + args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31] + ) if len(args.decoder_kernel_size_list) == 1: - args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers - assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers" - args.decoder_glu = getattr(args, 'decoder_glu', True) - args.input_dropout = getattr(args, 'input_dropout', 0.1) - args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout) + args.decoder_kernel_size_list = ( + args.decoder_kernel_size_list * args.decoder_layers + ) + assert ( + len(args.decoder_kernel_size_list) == args.decoder_layers + ), "decoder_kernel_size_list doesn't match decoder_layers" + args.decoder_glu = getattr(args, "decoder_glu", True) + args.input_dropout = getattr(args, "input_dropout", 0.1) + args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout) -@register_model_architecture('lightconv_lm', 'lightconv_lm_gbw') +@register_model_architecture("lightconv_lm", "lightconv_lm_gbw") def lightconv_lm_gbw(args): - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 
'attention_dropout', 0.1) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) base_lm_architecture(args) diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 8404cafe1d..1a9dca3c75 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -3,28 +3,28 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict, List, Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils from fairseq.models import ( FairseqEncoder, - FairseqIncrementalDecoder, FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, register_model, register_model_architecture, ) from fairseq.modules import AdaptiveSoftmax, FairseqDropout from torch import Tensor -from typing import Dict, List, Optional, Tuple DEFAULT_MAX_SOURCE_POSITIONS = 1e5 DEFAULT_MAX_TARGET_POSITIONS = 1e5 -@register_model('lstm') +@register_model("lstm") class LSTMModel(FairseqEncoderDecoderModel): def __init__(self, encoder, decoder): super().__init__(encoder, decoder) @@ -89,10 +89,14 @@ def build_model(cls, args, task): base_architecture(args) if args.encoder_layers != args.decoder_layers: - raise ValueError('--encoder-layers must match --decoder-layers') + raise ValueError("--encoder-layers must match --decoder-layers") - max_source_positions = getattr(args, 'max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS) - max_target_positions = getattr(args, 'max_target_positions', DEFAULT_MAX_TARGET_POSITIONS) + 
max_source_positions = getattr( + args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS + ) + max_target_positions = getattr( + args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS + ) def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): num_embeddings = len(dictionary) @@ -104,7 +108,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.encoder_embed_path: pretrained_encoder_embed = load_pretrained_embedding_from_file( - args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim) + args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim + ) else: num_embeddings = len(task.source_dictionary) pretrained_encoder_embed = Embedding( @@ -114,16 +119,17 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.share_all_embeddings: # double check all parameters combinations are valid if task.source_dictionary != task.target_dictionary: - raise ValueError('--share-all-embeddings requires a joint dictionary') + raise ValueError("--share-all-embeddings requires a joint dictionary") if args.decoder_embed_path and ( - args.decoder_embed_path != args.encoder_embed_path): + args.decoder_embed_path != args.encoder_embed_path + ): raise ValueError( - '--share-all-embed not compatible with --decoder-embed-path' + "--share-all-embed not compatible with --decoder-embed-path" ) if args.encoder_embed_dim != args.decoder_embed_dim: raise ValueError( - '--share-all-embeddings requires --encoder-embed-dim to ' - 'match --decoder-embed-dim' + "--share-all-embeddings requires --encoder-embed-dim to " + "match --decoder-embed-dim" ) pretrained_decoder_embed = pretrained_encoder_embed args.share_decoder_input_output_embed = True @@ -134,14 +140,15 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_decoder_embed = load_pretrained_embedding_from_file( args.decoder_embed_path, task.target_dictionary, - args.decoder_embed_dim + 
args.decoder_embed_dim, ) # one last double check of parameter combinations if args.share_decoder_input_output_embed and ( - args.decoder_embed_dim != args.decoder_out_embed_dim): + args.decoder_embed_dim != args.decoder_out_embed_dim + ): raise ValueError( - '--share-decoder-input-output-embeddings requires ' - '--decoder-embed-dim to match --decoder-out-embed-dim' + "--share-decoder-input-output-embeddings requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" ) if args.encoder_freeze_embed: @@ -174,7 +181,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) - if args.criterion == 'adaptive_loss' else None + if args.criterion == "adaptive_loss" + else None ), max_target_positions=max_target_positions, residuals=False, @@ -190,23 +198,38 @@ def forward( ): encoder_out = self.encoder(src_tokens, src_lengths=src_lengths) decoder_out = self.decoder( - prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, ) return decoder_out class LSTMEncoder(FairseqEncoder): """LSTM encoder.""" + def __init__( - self, dictionary, embed_dim=512, hidden_size=512, num_layers=1, - dropout_in=0.1, dropout_out=0.1, bidirectional=False, - left_pad=True, pretrained_embed=None, padding_idx=None, + self, + dictionary, + embed_dim=512, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + left_pad=True, + pretrained_embed=None, + padding_idx=None, max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__(dictionary) self.num_layers = num_layers - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + 
self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.bidirectional = bidirectional self.hidden_size = hidden_size self.max_source_positions = max_source_positions @@ -222,7 +245,7 @@ def __init__( input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers, - dropout=self.dropout_out_module.p if num_layers > 1 else 0., + dropout=self.dropout_out_module.p if num_layers > 1 else 0.0, bidirectional=bidirectional, ) self.left_pad = left_pad @@ -281,7 +304,9 @@ def forward( packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) # unpack outputs and apply dropout - x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_idx*1.0) + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_idx * 1.0 + ) x = self.dropout_out_module(x) assert list(x.size()) == [seqlen, bsz, self.output_units] @@ -291,24 +316,28 @@ def forward( encoder_padding_mask = src_tokens.eq(self.padding_idx).t() - return tuple(( - x, # seq_len x batch x hidden - final_hiddens, # num_layers x batch x num_directions*hidden - final_cells, # num_layers x batch x num_directions*hidden - encoder_padding_mask, # seq_len x batch - )) + return tuple( + ( + x, # seq_len x batch x hidden + final_hiddens, # num_layers x batch x num_directions*hidden + final_cells, # num_layers x batch x num_directions*hidden + encoder_padding_mask, # seq_len x batch + ) + ) def combine_bidir(self, outs, bsz: int): out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() return out.view(self.num_layers, bsz, -1) def reorder_encoder_out(self, encoder_out, new_order): - return tuple(( - encoder_out[0].index_select(1, new_order), - encoder_out[1].index_select(1, new_order), - encoder_out[2].index_select(1, new_order), - encoder_out[3].index_select(1, new_order), - )) + return tuple( + ( + 
encoder_out[0].index_select(1, new_order), + encoder_out[1].index_select(1, new_order), + encoder_out[2].index_select(1, new_order), + encoder_out[3].index_select(1, new_order), + ) + ) def max_positions(self): """Maximum input length supported by the encoder.""" @@ -320,7 +349,9 @@ def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=Fal super().__init__() self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias) - self.output_proj = Linear(input_embed_dim + source_embed_dim, output_embed_dim, bias=bias) + self.output_proj = Linear( + input_embed_dim + source_embed_dim, output_embed_dim, bias=bias + ) def forward(self, input, source_hids, encoder_padding_mask): # input: bsz x input_embed_dim @@ -334,10 +365,11 @@ def forward(self, input, source_hids, encoder_padding_mask): # don't attend over padding if encoder_padding_mask is not None: - attn_scores = attn_scores.float().masked_fill_( - encoder_padding_mask, - float('-inf') - ).type_as(attn_scores) # FP16 support: cast to float and back + attn_scores = ( + attn_scores.float() + .masked_fill_(encoder_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz @@ -350,17 +382,31 @@ def forward(self, input, source_hids, encoder_padding_mask): class LSTMDecoder(FairseqIncrementalDecoder): """LSTM decoder.""" + def __init__( - self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, - encoder_output_units=512, pretrained_embed=None, - share_input_output_embed=False, adaptive_softmax_cutoff=None, + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + attention=True, + encoder_output_units=512, + pretrained_embed=None, + share_input_output_embed=False, + adaptive_softmax_cutoff=None, max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, 
residuals=False, ): super().__init__(dictionary) - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed self.need_attn = True @@ -386,17 +432,23 @@ def __init__( # disable input feeding if there is no encoder # input feeding is described in arxiv.org/abs/1508.04025 input_feed_size = 0 if encoder_output_units == 0 else hidden_size - self.layers = nn.ModuleList([ - LSTMCell( - input_size=input_feed_size + embed_dim if layer == 0 else hidden_size, - hidden_size=hidden_size, - ) - for layer in range(num_layers) - ]) + self.layers = nn.ModuleList( + [ + LSTMCell( + input_size=input_feed_size + embed_dim + if layer == 0 + else hidden_size, + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) if attention: # TODO make bias configurable - self.attention = AttentionLayer(hidden_size, encoder_output_units, hidden_size, bias=False) + self.attention = AttentionLayer( + hidden_size, encoder_output_units, hidden_size, bias=False + ) else: self.attention = None @@ -406,7 +458,10 @@ def __init__( if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined self.adaptive_softmax = AdaptiveSoftmax( - num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out, + num_embeddings, + hidden_size, + adaptive_softmax_cutoff, + dropout=dropout_out, ) elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) @@ -459,7 +514,9 @@ def extract_features( # initialize previous states (or get from cache during incremental generation) if 
incremental_state is not None and len(incremental_state) > 0: - prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) + prev_hiddens, prev_cells, input_feed = self.get_cached_state( + incremental_state + ) elif encoder_out is not None: # setup recurrent cells prev_hiddens = [encoder_hiddens[i] for i in range(self.num_layers)] @@ -475,9 +532,12 @@ def extract_features( prev_cells = [zero_state for i in range(self.num_layers)] input_feed = None - assert srclen > 0 or self.attention is None, \ - "attention is not supported if there are no encoder outputs" - attn_scores = x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None + assert ( + srclen > 0 or self.attention is None + ), "attention is not supported if there are no encoder outputs" + attn_scores = ( + x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None + ) outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step @@ -502,7 +562,9 @@ def extract_features( # apply attention using the last layer's hidden state if self.attention is not None: assert attn_scores is not None - out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_mask) + out, attn_scores[:, j, :] = self.attention( + hidden, encoder_outs, encoder_padding_mask + ) else: out = hidden out = self.dropout_out_module(out) @@ -523,9 +585,9 @@ def extract_features( "prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": input_feed, - } + }, ) - self.set_incremental_state(incremental_state, 'cached_state', cache_state) + self.set_incremental_state(incremental_state, "cached_state", cache_state) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) @@ -533,7 +595,7 @@ def extract_features( # T x B x C -> B x T x C x = x.transpose(1, 0) - if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: + if hasattr(self, "additional_fc") and 
self.adaptive_softmax is None: x = self.additional_fc(x) x = self.dropout_out_module(x) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen @@ -557,7 +619,7 @@ def get_cached_state( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], ) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]: - cached_state = self.get_incremental_state(incremental_state, 'cached_state') + cached_state = self.get_incremental_state(incremental_state, "cached_state") assert cached_state is not None prev_hiddens_ = cached_state["prev_hiddens"] assert prev_hiddens_ is not None @@ -565,7 +627,9 @@ def get_cached_state( assert prev_cells_ is not None prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] prev_cells = [prev_cells_[j] for j in range(self.num_layers)] - input_feed = cached_state["input_feed"] # can be None for decoder-only language models + input_feed = cached_state[ + "input_feed" + ] # can be None for decoder-only language models return prev_hiddens, prev_cells, input_feed def reorder_incremental_state( @@ -586,9 +650,9 @@ def reorder_incremental_state( "prev_hiddens": torch.stack(prev_hiddens), "prev_cells": torch.stack(prev_cells), "input_feed": input_feed, - } + }, ) - self.set_incremental_state(incremental_state, 'cached_state', cached_state_new), + self.set_incremental_state(incremental_state, "cached_state", cached_state_new), return def max_positions(self): @@ -609,7 +673,7 @@ def Embedding(num_embeddings, embedding_dim, padding_idx): def LSTM(input_size, hidden_size, **kwargs): m = nn.LSTM(input_size, hidden_size, **kwargs) for name, param in m.named_parameters(): - if 'weight' in name or 'bias' in name: + if "weight" in name or "bias" in name: param.data.uniform_(-0.1, 0.1) return m @@ -617,12 +681,12 @@ def LSTM(input_size, hidden_size, **kwargs): def LSTMCell(input_size, hidden_size, **kwargs): m = nn.LSTMCell(input_size, hidden_size, **kwargs) for name, param in m.named_parameters(): - if 'weight' in name or 'bias' in name: + if "weight" in 
name or "bias" in name: param.data.uniform_(-0.1, 0.1) return m -def Linear(in_features, out_features, bias=True, dropout=0.): +def Linear(in_features, out_features, bias=True, dropout=0.0): """Linear layer (input: N x T x C)""" m = nn.Linear(in_features, out_features, bias=bias) m.weight.data.uniform_(-0.1, 0.1) @@ -631,51 +695,59 @@ def Linear(in_features, out_features, bias=True, dropout=0.): return m -@register_model_architecture('lstm', 'lstm') +@register_model_architecture("lstm", "lstm") def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) - args.encoder_freeze_embed = getattr(args, 'encoder_freeze_embed', False) - args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim) - args.encoder_layers = getattr(args, 'encoder_layers', 1) - args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False) - args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout) - args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim) - args.decoder_layers = getattr(args, 'decoder_layers', 1) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) - args.decoder_attention = getattr(args, 'decoder_attention', '1') - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.share_all_embeddings = getattr(args, 'share_all_embeddings', 
False) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') - - -@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en') + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_freeze_embed = getattr(args, "encoder_freeze_embed", False) + args.encoder_hidden_size = getattr( + args, "encoder_hidden_size", args.encoder_embed_dim + ) + args.encoder_layers = getattr(args, "encoder_layers", 1) + args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_freeze_embed = getattr(args, "decoder_freeze_embed", False) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_attention = getattr(args, "decoder_attention", "1") + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + + +@register_model_architecture("lstm", "lstm_wiseman_iwslt_de_en") def lstm_wiseman_iwslt_de_en(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) - 
args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0) - args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", 0) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", 0) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) base_architecture(args) -@register_model_architecture('lstm', 'lstm_luong_wmt_en_de') +@register_model_architecture("lstm", "lstm_luong_wmt_en_de") def lstm_luong_wmt_en_de(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) - args.encoder_layers = getattr(args, 'encoder_layers', 4) - args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1000) - args.decoder_layers = getattr(args, 'decoder_layers', 4) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1000) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', 0) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000) + args.encoder_layers = getattr(args, "encoder_layers", 4) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1000) + args.decoder_layers = getattr(args, "decoder_layers", 4) + args.decoder_out_embed_dim = getattr(args, 
"decoder_out_embed_dim", 1000) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", 0) base_architecture(args) diff --git a/fairseq/models/lstm_lm.py b/fairseq/models/lstm_lm.py index 1a39b95289..454f0ac36f 100644 --- a/fairseq/models/lstm_lm.py +++ b/fairseq/models/lstm_lm.py @@ -5,15 +5,17 @@ from fairseq import utils from fairseq.models import ( - FairseqLanguageModel, register_model, register_model_architecture -) -from fairseq.models.lstm import ( - LSTMDecoder, Embedding + FairseqLanguageModel, + register_model, + register_model_architecture, ) +from fairseq.models.lstm import Embedding, LSTMDecoder + DEFAULT_MAX_TARGET_POSITIONS = 1e5 -@register_model('lstm_lm') + +@register_model("lstm_lm") class LSTMLanguageModel(FairseqLanguageModel): def __init__(self, decoder): super().__init__(decoder) @@ -60,10 +62,12 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_architecture(args) - if getattr(args, 'max_target_positions', None) is not None: + if getattr(args, "max_target_positions", None) is not None: max_target_positions = args.max_target_positions else: - max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS) + max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): num_embeddings = len(dictionary) @@ -76,21 +80,21 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_decoder_embed = None if args.decoder_embed_path: pretrained_decoder_embed = load_pretrained_embedding_from_file( - args.decoder_embed_path, - task.target_dictionary, - args.decoder_embed_dim + args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim ) if args.share_decoder_input_output_embed: # double check all parameters combinations are valid if task.source_dictionary != task.target_dictionary: - raise 
ValueError('--share-decoder-input-output-embeddings requires a joint dictionary') + raise ValueError( + "--share-decoder-input-output-embeddings requires a joint dictionary" + ) if args.decoder_embed_dim != args.decoder_out_embed_dim: raise ValueError( - '--share-decoder-input-output-embeddings requires ' - '--decoder-embed-dim to match --decoder-out-embed-dim' - ) + "--share-decoder-input-output-embeddings requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" + ) decoder = LSTMDecoder( dictionary=task.dictionary, @@ -106,26 +110,33 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) - if args.criterion == 'adaptive_loss' else None + if args.criterion == "adaptive_loss" + else None ), max_target_positions=max_target_positions, - residuals=args.residuals + residuals=args.residuals, ) return cls(decoder) -@register_model_architecture('lstm_lm', 'lstm_lm') +@register_model_architecture("lstm_lm", "lstm_lm") def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim) - args.decoder_layers = getattr(args, 'decoder_layers', 1) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) - args.decoder_attention = getattr(args, 'decoder_attention', '0') - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') - args.residuals = getattr(args, 
'residuals', False) + args.dropout = getattr(args, "dropout", 0.1) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_attention = getattr(args, "decoder_attention", "0") + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + args.residuals = getattr(args, "residuals", False) diff --git a/fairseq/models/masked_lm.py b/fairseq/models/masked_lm.py index 35a6323ef2..c786de9125 100644 --- a/fairseq/models/masked_lm.py +++ b/fairseq/models/masked_lm.py @@ -8,11 +8,10 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils from fairseq.models import ( - FairseqEncoderModel, FairseqEncoder, + FairseqEncoderModel, register_model, register_model_architecture, ) @@ -27,12 +26,13 @@ logger = logging.getLogger(__name__) -@register_model('masked_lm') +@register_model("masked_lm") class MaskedLMModel(FairseqEncoderModel): """ Class for training a Masked Language Model. It also supports an additional sentence level prediction if the sent-loss argument is set. """ + def __init__(self, args, encoder): super().__init__(encoder) self.args = args @@ -40,66 +40,111 @@ def __init__(self, args, encoder): # if specified then apply bert initialization on the model. 
We need # to explictly call this to make sure that the output embeddings # and projection layers are also correctly initialized - if getattr(args, 'apply_bert_init', False): + if getattr(args, "apply_bert_init", False): self.apply(init_bert_params) @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" # Arguments related to dropout - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', type=float, - metavar='D', help='dropout probability for' - ' attention weights') - parser.add_argument('--act-dropout', type=float, - metavar='D', help='dropout probability after' - ' activation in FFN') + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for" " attention weights", + ) + parser.add_argument( + "--act-dropout", + type=float, + metavar="D", + help="dropout probability after" " activation in FFN", + ) # Arguments related to hidden states and self-attention - parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', - help='encoder embedding dimension for FFN') - parser.add_argument('--encoder-layers', type=int, metavar='N', - help='num encoder layers') - parser.add_argument('--encoder-attention-heads', type=int, metavar='N', - help='num encoder attention heads') + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) # Arguments related to input and output embeddings - parser.add_argument('--encoder-embed-dim', type=int, metavar='N', - help='encoder embedding dimension') - 
parser.add_argument('--share-encoder-input-output-embed', - action='store_true', help='share encoder input' - ' and output embeddings') - parser.add_argument('--encoder-learned-pos', action='store_true', - help='use learned positional embeddings in the encoder') - parser.add_argument('--no-token-positional-embeddings', - action='store_true', - help='if set, disables positional embeddings' - ' (outside self attention)') - parser.add_argument('--num-segment', type=int, metavar='N', - help='num segment in the input') - parser.add_argument('--max-positions', type=int, - help='number of positional embeddings to learn') + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--share-encoder-input-output-embed", + action="store_true", + help="share encoder input" " and output embeddings", + ) + parser.add_argument( + "--encoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the encoder", + ) + parser.add_argument( + "--no-token-positional-embeddings", + action="store_true", + help="if set, disables positional embeddings" " (outside self attention)", + ) + parser.add_argument( + "--num-segment", type=int, metavar="N", help="num segment in the input" + ) + parser.add_argument( + "--max-positions", type=int, help="number of positional embeddings to learn" + ) # Arguments related to sentence level prediction - parser.add_argument('--sentence-class-num', type=int, metavar='N', - help='number of classes for sentence task') - parser.add_argument('--sent-loss', action='store_true', help='if set,' - ' calculate sentence level predictions') + parser.add_argument( + "--sentence-class-num", + type=int, + metavar="N", + help="number of classes for sentence task", + ) + parser.add_argument( + "--sent-loss", + action="store_true", + help="if set," " calculate sentence level predictions", + ) # Arguments related to parameter initialization - 
parser.add_argument('--apply-bert-init', action='store_true', - help='use custom param initialization for BERT') + parser.add_argument( + "--apply-bert-init", + action="store_true", + help="use custom param initialization for BERT", + ) # misc params - parser.add_argument('--activation-fn', - choices=utils.get_available_activation_fns(), - help='activation function to use') - parser.add_argument('--pooler-activation-fn', - choices=utils.get_available_activation_fns(), - help='Which activation function to use for pooler layer.') - parser.add_argument('--encoder-normalize-before', action='store_true', - help='apply layernorm before each encoder block') + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--pooler-activation-fn", + choices=utils.get_available_activation_fns(), + help="Which activation function to use for pooler layer.", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) def forward(self, src_tokens, segment_labels=None, **kwargs): return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs) @@ -113,7 +158,7 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_architecture(args) - if not hasattr(args, 'max_positions'): + if not hasattr(args, "max_positions"): args.max_positions = args.tokens_per_sample logger.info(args) @@ -160,14 +205,16 @@ def __init__(self, args, dictionary): self.lm_output_learned_bias = None # Remove head is set to true during fine-tuning - self.load_softmax = not getattr(args, 'remove_head', False) + self.load_softmax = not getattr(args, "remove_head", False) self.masked_lm_pooler = nn.Linear( args.encoder_embed_dim, args.encoder_embed_dim ) self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn) - self.lm_head_transform_weight = nn.Linear(args.encoder_embed_dim, 
args.encoder_embed_dim) + self.lm_head_transform_weight = nn.Linear( + args.encoder_embed_dim, args.encoder_embed_dim + ) self.activation_fn = utils.get_activation_fn(args.activation_fn) self.layer_norm = LayerNorm(args.encoder_embed_dim) @@ -177,16 +224,12 @@ def __init__(self, args, dictionary): if not self.share_input_output_embed: self.embed_out = nn.Linear( - args.encoder_embed_dim, - self.vocab_size, - bias=False + args.encoder_embed_dim, self.vocab_size, bias=False ) if args.sent_loss: self.sentence_projection_layer = nn.Linear( - args.encoder_embed_dim, - self.sentence_out_dim, - bias=False + args.encoder_embed_dim, self.sentence_out_dim, bias=False ) def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused): @@ -227,8 +270,9 @@ def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused) pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep)) # project back to size of vocabulary - if self.share_input_output_embed \ - and hasattr(self.sentence_encoder.embed_tokens, 'weight'): + if self.share_input_output_embed and hasattr( + self.sentence_encoder.embed_tokens, "weight" + ): x = F.linear(x, self.sentence_encoder.embed_tokens.weight) elif self.embed_out is not None: x = self.embed_out(x) @@ -239,9 +283,9 @@ def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused) sentence_logits = self.sentence_projection_layer(pooled_output) return x, { - 'inner_states': inner_states, - 'pooled_output': pooled_output, - 'sentence_logits': sentence_logits + "inner_states": inner_states, + "pooled_output": pooled_output, + "sentence_logits": sentence_logits, } def max_positions(self): @@ -250,103 +294,110 @@ def max_positions(self): def upgrade_state_dict_named(self, state_dict, name): if isinstance( - self.sentence_encoder.embed_positions, - SinusoidalPositionalEmbedding + self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding ): state_dict[ - name + 
'.sentence_encoder.embed_positions._float_tensor' + name + ".sentence_encoder.embed_positions._float_tensor" ] = torch.FloatTensor(1) if not self.load_softmax: for k in list(state_dict.keys()): if ( - "embed_out.weight" in k or - "sentence_projection_layer.weight" in k or - "lm_output_learned_bias" in k + "embed_out.weight" in k + or "sentence_projection_layer.weight" in k + or "lm_output_learned_bias" in k ): del state_dict[k] return state_dict -@register_model_architecture('masked_lm', 'masked_lm') +@register_model_architecture("masked_lm", "masked_lm") def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.act_dropout = getattr(args, 'act_dropout', 0.0) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.act_dropout = getattr(args, "act_dropout", 0.0) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) - args.encoder_layers = getattr(args, 'encoder_layers', 6) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) - args.share_encoder_input_output_embed = getattr(args, 'share_encoder_input_output_embed', False) - args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False) - args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) - args.num_segment = getattr(args, 'num_segment', 2) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", False + ) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + 
args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.num_segment = getattr(args, "num_segment", 2) - args.sentence_class_num = getattr(args, 'sentence_class_num', 2) - args.sent_loss = getattr(args, 'sent_loss', False) + args.sentence_class_num = getattr(args, "sentence_class_num", 2) + args.sent_loss = getattr(args, "sent_loss", False) - args.apply_bert_init = getattr(args, 'apply_bert_init', False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) - args.activation_fn = getattr(args, 'activation_fn', 'relu') - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) -@register_model_architecture('masked_lm', 'bert_base') +@register_model_architecture("masked_lm", "bert_base") def bert_base_architecture(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) args.share_encoder_input_output_embed = getattr( - args, 'share_encoder_input_output_embed', True) + args, "share_encoder_input_output_embed", True + ) args.no_token_positional_embeddings = getattr( - args, 'no_token_positional_embeddings', False) - args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True) - args.num_segment = getattr(args, 'num_segment', 2) + args, "no_token_positional_embeddings", False + ) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.num_segment = getattr(args, "num_segment", 2) - args.encoder_layers = getattr(args, 'encoder_layers', 12) + args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) - 
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.sentence_class_num = getattr(args, 'sentence_class_num', 2) - args.sent_loss = getattr(args, 'sent_loss', True) + args.sentence_class_num = getattr(args, "sentence_class_num", 2) + args.sent_loss = getattr(args, "sent_loss", True) - args.apply_bert_init = getattr(args, 'apply_bert_init', True) + args.apply_bert_init = getattr(args, "apply_bert_init", True) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) base_architecture(args) -@register_model_architecture('masked_lm', 'bert_large') +@register_model_architecture("masked_lm", "bert_large") def bert_large_architecture(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) - args.encoder_layers = getattr(args, 'encoder_layers', 24) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) bert_base_architecture(args) -@register_model_architecture('masked_lm', 'xlm_base') +@register_model_architecture("masked_lm", "xlm_base") def xlm_architecture(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 
1024) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) args.share_encoder_input_output_embed = getattr( - args, 'share_encoder_input_output_embed', True) + args, "share_encoder_input_output_embed", True + ) args.no_token_positional_embeddings = getattr( - args, 'no_token_positional_embeddings', False) - args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True) - args.num_segment = getattr(args, 'num_segment', 1) + args, "no_token_positional_embeddings", False + ) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.num_segment = getattr(args, "num_segment", 1) - args.encoder_layers = getattr(args, 'encoder_layers', 6) + args.encoder_layers = getattr(args, "encoder_layers", 6) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) - args.sent_loss = getattr(args, 'sent_loss', False) + args.sent_loss = getattr(args, "sent_loss", False) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') - args.apply_bert_init = getattr(args, 'apply_bert_init', True) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.apply_bert_init = getattr(args, "apply_bert_init", True) base_architecture(args) diff --git a/fairseq/models/model_utils.py b/fairseq/models/model_utils.py index 46ec62f772..732d66b1d5 100644 --- a/fairseq/models/model_utils.py +++ b/fairseq/models/model_utils.py @@ -60,7 +60,9 @@ def coalesce(x: Optional[Tensor], y: Tensor) -> 
Tensor: @torch.jit.script -def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]: +def fill_tensors( + x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int +) -> Optional[Tensor]: """ Filling tensor x with y at masked positions (dim=0). """ @@ -82,9 +84,9 @@ def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: in elif x.size(1) > y.size(1): x[mask] = torch.tensor(padding_idx).type_as(x) if x.dim() == 2: - x[mask, :y.size(1)] = y + x[mask, : y.size(1)] = y else: - x[mask, :y.size(1), :] = y + x[mask, : y.size(1), :] = y else: x[mask] = y return x diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py index 91a413753c..e3fbbd5710 100644 --- a/fairseq/models/multilingual_transformer.py +++ b/fairseq/models/multilingual_transformer.py @@ -12,15 +12,15 @@ register_model_architecture, ) from fairseq.models.transformer import ( - base_architecture, Embedding, - TransformerModel, - TransformerEncoder, TransformerDecoder, + TransformerEncoder, + TransformerModel, + base_architecture, ) -@register_model('multilingual_transformer') +@register_model("multilingual_transformer") class MultilingualTransformerModel(FairseqMultiModel): """Train Transformer models for multiple language pairs simultaneously. 
@@ -44,31 +44,44 @@ def __init__(self, encoders, decoders): def add_args(parser): """Add model-specific arguments to the parser.""" TransformerModel.add_args(parser) - parser.add_argument('--share-encoder-embeddings', action='store_true', - help='share encoder embeddings across languages') - parser.add_argument('--share-decoder-embeddings', action='store_true', - help='share decoder embeddings across languages') - parser.add_argument('--share-encoders', action='store_true', - help='share encoders across languages') - parser.add_argument('--share-decoders', action='store_true', - help='share decoders across languages') + parser.add_argument( + "--share-encoder-embeddings", + action="store_true", + help="share encoder embeddings across languages", + ) + parser.add_argument( + "--share-decoder-embeddings", + action="store_true", + help="share decoder embeddings across languages", + ) + parser.add_argument( + "--share-encoders", + action="store_true", + help="share encoders across languages", + ) + parser.add_argument( + "--share-decoders", + action="store_true", + help="share decoders across languages", + ) @classmethod def build_model(cls, args, task): """Build a new model instance.""" from fairseq.tasks.multilingual_translation import MultilingualTranslationTask + assert isinstance(task, MultilingualTranslationTask) # make sure all arguments are present in older models base_multilingual_architecture(args) - if not hasattr(args, 'max_source_positions'): + if not hasattr(args, "max_source_positions"): args.max_source_positions = 1024 - if not hasattr(args, 'max_target_positions'): + if not hasattr(args, "max_target_positions"): args.max_target_positions = 1024 - src_langs = [lang_pair.split('-')[0] for lang_pair in task.model_lang_pairs] - tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.model_lang_pairs] + src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs] + tgt_langs = [lang_pair.split("-")[1] for lang_pair in 
task.model_lang_pairs] if args.share_encoders: args.share_encoder_embeddings = True @@ -90,10 +103,14 @@ def build_embedding(dictionary, embed_dim, path=None): if args.share_all_embeddings: if args.encoder_embed_dim != args.decoder_embed_dim: raise ValueError( - '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim') + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) if args.decoder_embed_path and ( - args.decoder_embed_path != args.encoder_embed_path): - raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path') + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( dicts=task.dicts, langs=task.langs, @@ -105,24 +122,20 @@ def build_embedding(dictionary, embed_dim, path=None): args.share_decoder_input_output_embed = True else: if args.share_encoder_embeddings: - shared_encoder_embed_tokens = ( - FairseqMultiModel.build_shared_embeddings( - dicts=task.dicts, - langs=src_langs, - embed_dim=args.encoder_embed_dim, - build_embedding=build_embedding, - pretrained_embed_path=args.encoder_embed_path, - ) + shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=src_langs, + embed_dim=args.encoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.encoder_embed_path, ) if args.share_decoder_embeddings: - shared_decoder_embed_tokens = ( - FairseqMultiModel.build_shared_embeddings( - dicts=task.dicts, - langs=tgt_langs, - embed_dim=args.decoder_embed_dim, - build_embedding=build_embedding, - pretrained_embed_path=args.decoder_embed_path, - ) + shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=tgt_langs, + embed_dim=args.decoder_embed_dim, + build_embedding=build_embedding, + 
pretrained_embed_path=args.decoder_embed_path, ) # encoders/decoders for each language @@ -134,10 +147,13 @@ def get_encoder(lang): encoder_embed_tokens = shared_encoder_embed_tokens else: encoder_embed_tokens = build_embedding( - task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path + task.dicts[lang], + args.encoder_embed_dim, + args.encoder_embed_path, ) lang_encoders[lang] = cls._get_module_class( - True, args, task.dicts[lang], encoder_embed_tokens, src_langs) + True, args, task.dicts[lang], encoder_embed_tokens, src_langs + ) return lang_encoders[lang] def get_decoder(lang): @@ -146,10 +162,13 @@ def get_decoder(lang): decoder_embed_tokens = shared_decoder_embed_tokens else: decoder_embed_tokens = build_embedding( - task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path + task.dicts[lang], + args.decoder_embed_dim, + args.decoder_embed_path, ) lang_decoders[lang] = cls._get_module_class( - False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs) + False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs + ) return lang_decoders[lang] # shared encoders/decoders (if applicable) @@ -161,8 +180,12 @@ def get_decoder(lang): encoders, decoders = OrderedDict(), OrderedDict() for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs): - encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src) - decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt) + encoders[lang_pair] = ( + shared_encoder if shared_encoder is not None else get_encoder(src) + ) + decoders[lang_pair] = ( + shared_decoder if shared_decoder is not None else get_decoder(tgt) + ) return MultilingualTransformerModel(encoders, decoders) @@ -174,30 +197,32 @@ def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): def load_state_dict(self, state_dict, strict=True, args=None): state_dict_subset = state_dict.copy() for k, _ in state_dict.items(): - assert 
k.startswith('models.') - lang_pair = k.split('.')[1] + assert k.startswith("models.") + lang_pair = k.split(".")[1] if lang_pair not in self.models: del state_dict_subset[k] super().load_state_dict(state_dict_subset, strict=strict, args=args) -@register_model_architecture('multilingual_transformer', 'multilingual_transformer') +@register_model_architecture("multilingual_transformer", "multilingual_transformer") def base_multilingual_architecture(args): base_architecture(args) - args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', False) - args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', False) - args.share_encoders = getattr(args, 'share_encoders', False) - args.share_decoders = getattr(args, 'share_decoders', False) + args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False) + args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False) + args.share_encoders = getattr(args, "share_encoders", False) + args.share_decoders = getattr(args, "share_decoders", False) -@register_model_architecture('multilingual_transformer', 'multilingual_transformer_iwslt_de_en') +@register_model_architecture( + "multilingual_transformer", "multilingual_transformer_iwslt_de_en" +) def multilingual_transformer_iwslt_de_en(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4) - args.encoder_layers = getattr(args, 'encoder_layers', 6) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4) - args.decoder_layers = getattr(args, 'decoder_layers', 6) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, 
"encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) base_multilingual_architecture(args) diff --git a/fairseq/models/nat/cmlm_transformer.py b/fairseq/models/nat/cmlm_transformer.py index 86c770569d..c876e9453c 100644 --- a/fairseq/models/nat/cmlm_transformer.py +++ b/fairseq/models/nat/cmlm_transformer.py @@ -38,26 +38,34 @@ def forward( # encoding encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) # length prediction - length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out) - length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens) + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) # decoding word_ins_out = self.decoder( normalize=False, prev_output_tokens=prev_output_tokens, - encoder_out=encoder_out) + encoder_out=encoder_out, + ) word_ins_mask = prev_output_tokens.eq(self.unk) return { "word_ins": { - "out": word_ins_out, "tgt": tgt_tokens, - "mask": word_ins_mask, "ls": self.args.label_smoothing, - "nll_loss": True + "out": word_ins_out, + "tgt": tgt_tokens, + "mask": word_ins_mask, + "ls": self.args.label_smoothing, + "nll_loss": True, }, "length": { - "out": length_out, "tgt": length_tgt, - "factor": self.decoder.length_loss_factor - } + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, } def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): @@ -98,7 +106,7 @@ def forward_decoder(self, 
decoder_out, encoder_out, decoding_format=None, **kwar output_tokens=output_tokens, output_scores=output_scores, attn=None, - history=history + history=history, ) diff --git a/fairseq/models/nat/fairseq_nat_model.py b/fairseq/models/nat/fairseq_nat_model.py index d37a234ba9..1dbc29d0f4 100644 --- a/fairseq/models/nat/fairseq_nat_model.py +++ b/fairseq/models/nat/fairseq_nat_model.py @@ -4,9 +4,13 @@ # LICENSE file in the root directory of this source tree. import math -import torch -from fairseq.models.transformer import TransformerModel, TransformerEncoder, TransformerDecoder +import torch +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, +) from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -22,22 +26,31 @@ def stack(key): return torch.stack(outs, -1) if outs[0] is not None else None return _encoder_out._replace( - encoder_out=stack('encoder_out'), - encoder_embedding=stack('encoder_embedding'), - encoder_states=stack('encoder_states') + encoder_out=stack("encoder_out"), + encoder_embedding=stack("encoder_embedding"), + encoder_states=stack("encoder_states"), ) + return wrapper def ensemble_decoder(func): def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs): if self.ensemble_models is None or len(self.ensemble_models) == 1: - return func(self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs) + return func( + self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs + ) action_outs = [ - func(model, normalize=normalize, encoder_out=encoder_out._replace( - encoder_out=encoder_out.encoder_out[:, :, :, i] - ), *args, **kwargs) + func( + model, + normalize=normalize, + encoder_out=encoder_out._replace( + encoder_out=encoder_out.encoder_out[:, :, :, i] + ), + *args, + **kwargs + ) for i, model in enumerate(self.ensemble_models) ] @@ -51,19 +64,19 @@ def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs): if i == 0 and normalize: 
ensembled_outs += [ torch.logsumexp( - torch.stack([a[i] for a in action_outs], -1), - dim=-1) - math.log(len(self.ensemble_models)) + torch.stack([a[i] for a in action_outs], -1), dim=-1 + ) + - math.log(len(self.ensemble_models)) ] elif action_outs[0][i] is not None: - ensembled_outs += [ - torch.stack([a[i] for a in action_outs], -1) - ] + ensembled_outs += [torch.stack([a[i] for a in action_outs], -1)] else: ensembled_outs += [None] if len(ensembled_outs) == 1: return ensembled_outs[0] return tuple(ensembled_outs) + return wrapper @@ -71,6 +84,7 @@ class FairseqNATModel(TransformerModel): """ Abstract class for all nonautoregressive-based models """ + def __init__(self, args, encoder, decoder): super().__init__(args, encoder, decoder) self.tgt_dict = decoder.dictionary diff --git a/fairseq/models/nat/insertion_transformer.py b/fairseq/models/nat/insertion_transformer.py index a5f3c1abc5..bc28000f59 100644 --- a/fairseq/models/nat/insertion_transformer.py +++ b/fairseq/models/nat/insertion_transformer.py @@ -6,17 +6,16 @@ import numpy as np import torch import torch.nn.functional as F - from fairseq.models import register_model, register_model_architecture from fairseq.models.nat import ( + FairseqNATModel, LevenshteinTransformerDecoder, LevenshteinTransformerModel, - FairseqNATModel, - ensemble_decoder + ensemble_decoder, ) from fairseq.models.transformer import Linear -from fairseq.utils import new_arange from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import new_arange class NegativeDistanceScore(object): @@ -58,7 +57,8 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, ta from fairseq import libnat except ImportError as e: import sys - sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n') + + sys.stderr.write("ERROR: missing libnat. 
run `pip install --editable .`\n") raise e B = in_tokens.size(0) @@ -147,7 +147,7 @@ def forward( word_ins_out = self.decoder.forward_word_ins( normalize=False, prev_output_tokens=prev_output_tokens, - encoder_out=encoder_out + encoder_out=encoder_out, ) word_ins_tgt = _get_ins_targets( @@ -162,9 +162,11 @@ def forward( return { "word_ins": { - "out": word_ins_out, "tgt": word_ins_tgt, - "mask": word_ins_masks, "ls": self.args.label_smoothing, - "nll_loss": True + "out": word_ins_out, + "tgt": word_ins_tgt, + "mask": word_ins_masks, + "ls": self.args.label_smoothing, + "nll_loss": True, } } @@ -178,9 +180,7 @@ def forward_decoder( # TODO: decoding for InsertionTransformer word_ins_score = self.decoder.forward_word_ins( - normalize=True, - prev_output_tokens=output_tokens, - encoder_out=encoder_out + normalize=True, prev_output_tokens=output_tokens, encoder_out=encoder_out ) if eos_penalty > 0.0: @@ -202,7 +202,7 @@ def forward_decoder( output_tokens=output_tokens, output_scores=output_scores, attn=None, - history=history + history=history, ) diff --git a/fairseq/models/nat/iterative_nonautoregressive_transformer.py b/fairseq/models/nat/iterative_nonautoregressive_transformer.py index dc340c387d..bc39509980 100644 --- a/fairseq/models/nat/iterative_nonautoregressive_transformer.py +++ b/fairseq/models/nat/iterative_nonautoregressive_transformer.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
import torch - from fairseq.models import register_model, register_model_architecture from fairseq.models.nat import NATransformerModel @@ -44,8 +43,16 @@ def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1): def gumbel_noise(input, TINY=1e-8): - return input.new_zeros(*input.size()).uniform_().add_( - TINY).log_().neg_().add_(TINY).log_().neg_() + return ( + input.new_zeros(*input.size()) + .uniform_() + .add_(TINY) + .log_() + .neg_() + .add_(TINY) + .log_() + .neg_() + ) @register_model("iterative_nonautoregressive_transformer") @@ -53,12 +60,21 @@ class IterNATransformerModel(NATransformerModel): @staticmethod def add_args(parser): NATransformerModel.add_args(parser) - parser.add_argument("--train-step", type=int, - help="number of refinement iterations during training") - parser.add_argument("--dae-ratio", type=float, - help="the probability of switching to the denoising auto-encoder loss") - parser.add_argument("--stochastic-approx", action="store_true", - help="sampling from the decoder as the inputs for next iteration") + parser.add_argument( + "--train-step", + type=int, + help="number of refinement iterations during training", + ) + parser.add_argument( + "--dae-ratio", + type=float, + help="the probability of switching to the denoising auto-encoder loss", + ) + parser.add_argument( + "--stochastic-approx", + action="store_true", + help="sampling from the decoder as the inputs for next iteration", + ) @classmethod def build_model(cls, args, task): @@ -78,14 +94,18 @@ def forward( encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) # length prediction - length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out) - length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens) + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) # decoding 
word_ins_outs, word_ins_tgts, word_ins_masks = [], [], [] for t in range(self.train_step): word_ins_out = self.decoder( - normalize=False, + normalize=False, prev_output_tokens=prev_output_tokens, encoder_out=encoder_out, step=t, @@ -133,14 +153,17 @@ def forward( return { "word_ins": { - "out": word_ins_out, "tgt": word_ins_tgt, - "mask": word_ins_mask, "ls": self.args.label_smoothing, - "nll_loss": True + "out": word_ins_out, + "tgt": word_ins_tgt, + "mask": word_ins_mask, + "ls": self.args.label_smoothing, + "nll_loss": True, }, "length": { - "out": length_out, "tgt": length_tgt, - "factor": self.decoder.length_loss_factor - } + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, } diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py index e1748145c3..f7a3f003ca 100644 --- a/fairseq/models/nat/levenshtein_transformer.py +++ b/fairseq/models/nat/levenshtein_transformer.py @@ -6,33 +6,26 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq.iterative_refinement_generator import DecoderOut from fairseq.models import register_model, register_model_architecture -from fairseq.models.transformer import ( - Embedding, - TransformerDecoderLayer -) - -from fairseq.models.nat import ( - FairseqNATModel, - FairseqNATDecoder, - ensemble_decoder -) - +from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder +from fairseq.models.transformer import Embedding, TransformerDecoderLayer from fairseq.modules.transformer_sentence_encoder import init_bert_params - from .levenshtein_utils import ( - _skip, _skip_encoder_out, _fill, - _get_ins_targets, _get_del_targets, - _apply_ins_masks, _apply_ins_words, _apply_del_words + _apply_del_words, + _apply_ins_masks, + _apply_ins_words, + _fill, + _get_del_targets, + _get_ins_targets, + _skip, + _skip_encoder_out, ) @register_model("levenshtein_transformer") class 
LevenshteinTransformerModel(FairseqNATModel): - @property def allow_length_beam(self): return False @@ -63,8 +56,8 @@ def add_args(parser): ) parser.add_argument( "--sampling-for-deletion", - action='store_true', - help='instead of argmax, use sampling to predict the tokens' + action="store_true", + help="instead of argmax, use sampling to predict the tokens", ) @classmethod @@ -93,19 +86,19 @@ def forward( mask_ins_out, _ = self.decoder.forward_mask_ins( normalize=False, prev_output_tokens=prev_output_tokens, - encoder_out=encoder_out + encoder_out=encoder_out, ) word_ins_out, _ = self.decoder.forward_word_ins( normalize=False, prev_output_tokens=masked_tgt_tokens, - encoder_out=encoder_out + encoder_out=encoder_out, ) # make online prediction if self.decoder.sampling_for_deletion: word_predictions = torch.multinomial( - F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view( - word_ins_out.size(0), -1) + F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1 + ).view(word_ins_out.size(0), -1) else: word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1] @@ -118,23 +111,29 @@ def forward( word_del_out, _ = self.decoder.forward_word_del( normalize=False, prev_output_tokens=word_predictions, - encoder_out=encoder_out) + encoder_out=encoder_out, + ) word_del_masks = word_predictions.ne(self.pad) return { "mask_ins": { - "out": mask_ins_out, "tgt": mask_ins_targets, - "mask": mask_ins_masks, "ls": 0.01, + "out": mask_ins_out, + "tgt": mask_ins_targets, + "mask": mask_ins_masks, + "ls": 0.01, }, "word_ins": { - "out": word_ins_out, "tgt": tgt_tokens, - "mask": masked_tgt_masks, "ls": self.args.label_smoothing, - "nll_loss": True + "out": word_ins_out, + "tgt": tgt_tokens, + "mask": masked_tgt_masks, + "ls": self.args.label_smoothing, + "nll_loss": True, }, "word_del": { - "out": word_del_out, "tgt": word_del_targets, - "mask": word_del_masks - } + "out": word_del_out, + "tgt": word_del_targets, + "mask": word_del_masks, + }, } def 
forward_decoder( @@ -164,7 +163,7 @@ def forward_decoder( word_del_score, word_del_attn = self.decoder.forward_word_del( normalize=True, prev_output_tokens=_skip(output_tokens, can_del_word), - encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_del_word) + encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_del_word), ) word_del_pred = word_del_score.max(-1)[1].bool() @@ -179,7 +178,7 @@ def forward_decoder( ) output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad) output_scores = _fill(output_scores, can_del_word, _scores, 0) - attn = _fill(attn, can_del_word, _attn, 0.) + attn = _fill(attn, can_del_word, _attn, 0.0) if history is not None: history.append(output_tokens.clone()) @@ -190,7 +189,7 @@ def forward_decoder( mask_ins_score, _ = self.decoder.forward_mask_ins( normalize=True, prev_output_tokens=_skip(output_tokens, can_ins_mask), - encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_mask) + encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_mask), ) if eos_penalty > 0.0: mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty @@ -219,7 +218,7 @@ def forward_decoder( word_ins_score, word_ins_attn = self.decoder.forward_word_ins( normalize=True, prev_output_tokens=_skip(output_tokens, can_ins_word), - encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_word) + encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_word), ) word_ins_score, word_ins_pred = word_ins_score.max(-1) _tokens, _scores = _apply_ins_words( @@ -232,7 +231,7 @@ def forward_decoder( output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad) output_scores = _fill(output_scores, can_ins_word, _scores, 0) - attn = _fill(attn, can_ins_word, word_ins_attn, 0.) 
+ attn = _fill(attn, can_ins_word, word_ins_attn, 0.0) if history is not None: history.append(output_tokens.clone()) @@ -247,7 +246,7 @@ def forward_decoder( output_tokens=output_tokens, output_scores=output_scores, attn=attn, - history=history + history=history, ) def initialize_output_tokens(self, encoder_out, src_tokens): @@ -265,7 +264,7 @@ def initialize_output_tokens(self, encoder_out, src_tokens): attn=None, step=0, max_step=0, - history=None + history=None, ) @@ -283,29 +282,40 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.embed_word_del = Embedding(2, self.output_embed_dim, None) # del_word, ins_mask, ins_word - self.early_exit = [int(i) for i in args.early_exit.split(',')] + self.early_exit = [int(i) for i in args.early_exit.split(",")] assert len(self.early_exit) == 3 # copy layers for mask-predict/deletion self.layers_msk = None if getattr(args, "no_share_maskpredictor", False): - self.layers_msk = nn.ModuleList([ - TransformerDecoderLayer(args, no_encoder_attn) - for _ in range(self.early_exit[1]) - ]) + self.layers_msk = nn.ModuleList( + [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(self.early_exit[1]) + ] + ) self.layers_del = None if getattr(args, "no_share_discriminator", False): - self.layers_del = nn.ModuleList([ - TransformerDecoderLayer(args, no_encoder_attn) - for _ in range(self.early_exit[0]) - ]) + self.layers_del = nn.ModuleList( + [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(self.early_exit[0]) + ] + ) if getattr(args, "share_discriminator_maskpredictor", False): - assert getattr(args, "no_share_discriminator", False), "must set saperate discriminator" + assert getattr( + args, "no_share_discriminator", False + ), "must set saperate discriminator" self.layers_msk = self.layers_del def extract_features( - self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused + self, + prev_output_tokens, + encoder_out=None, + early_exit=None, + 
layers=None, + **unused ): """ Similar to *forward* but only return features. @@ -344,7 +354,7 @@ def extract_features( decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) layers = self.layers if layers is None else layers early_exit = len(layers) if early_exit is None else early_exit - for _, layer in enumerate(layers[: early_exit]): + for _, layer in enumerate(layers[:early_exit]): x, attn, _ = layer( x, encoder_out.encoder_out if encoder_out is not None else None, @@ -368,33 +378,45 @@ def extract_features( @ensemble_decoder def forward_mask_ins(self, normalize, encoder_out, prev_output_tokens, **unused): features, extra = self.extract_features( - prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused + prev_output_tokens, + encoder_out=encoder_out, + early_exit=self.early_exit[1], + layers=self.layers_msk, + **unused ) features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) decoder_out = F.linear(features_cat, self.embed_mask_ins.weight) if normalize: - return F.log_softmax(decoder_out, -1), extra['attn'] - return decoder_out, extra['attn'] + return F.log_softmax(decoder_out, -1), extra["attn"] + return decoder_out, extra["attn"] @ensemble_decoder def forward_word_ins(self, normalize, encoder_out, prev_output_tokens, **unused): features, extra = self.extract_features( - prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused + prev_output_tokens, + encoder_out=encoder_out, + early_exit=self.early_exit[2], + layers=self.layers, + **unused ) decoder_out = self.output_layer(features) if normalize: - return F.log_softmax(decoder_out, -1), extra['attn'] - return decoder_out, extra['attn'] + return F.log_softmax(decoder_out, -1), extra["attn"] + return decoder_out, extra["attn"] @ensemble_decoder def forward_word_del(self, normalize, encoder_out, prev_output_tokens, **unused): features, extra = self.extract_features( - prev_output_tokens, 
encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused + prev_output_tokens, + encoder_out=encoder_out, + early_exit=self.early_exit[0], + layers=self.layers_del, + **unused ) decoder_out = F.linear(features, self.embed_word_del.weight) if normalize: - return F.log_softmax(decoder_out, -1), extra['attn'] - return decoder_out, extra['attn'] + return F.log_softmax(decoder_out, -1), extra["attn"] + return decoder_out, extra["attn"] @register_model_architecture("levenshtein_transformer", "levenshtein_transformer") @@ -439,7 +461,9 @@ def levenshtein_base_architecture(args): args.early_exit = getattr(args, "early_exit", "6,6,6") args.no_share_discriminator = getattr(args, "no_share_discriminator", False) args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False) - args.share_discriminator_maskpredictor = getattr(args, "share_discriminator_maskpredictor", False) + args.share_discriminator_maskpredictor = getattr( + args, "share_discriminator_maskpredictor", False + ) args.no_share_last_layer = getattr(args, "no_share_last_layer", False) diff --git a/fairseq/models/nat/levenshtein_utils.py b/fairseq/models/nat/levenshtein_utils.py index 11fb29578b..375a98c2e1 100644 --- a/fairseq/models/nat/levenshtein_utils.py +++ b/fairseq/models/nat/levenshtein_utils.py @@ -9,21 +9,27 @@ # -------------- Helper Functions --------------------------------------------------- # + def load_libnat(): try: from fairseq import libnat_cuda + return libnat_cuda, True except ImportError as e: - print(str(e) + '... fall back to CPU version') + print(str(e) + "... fall back to CPU version") try: from fairseq import libnat + return libnat, False except ImportError as e: import sys - sys.stderr.write("ERROR: missing libnat_cuda. run `python setup.py build_ext --inplace`\n") + + sys.stderr.write( + "ERROR: missing libnat_cuda. 
run `python setup.py build_ext --inplace`\n" + ) raise e @@ -34,14 +40,18 @@ def _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx): in_masks = in_tokens.ne(padding_idx) out_masks = out_tokens.ne(padding_idx) mask_ins_targets, masked_tgt_masks = libnat.generate_insertion_labels( - out_tokens.int(), libnat.levenshtein_distance( - in_tokens.int(), out_tokens.int(), - in_masks.sum(1).int(), out_masks.sum(1).int() - ) + out_tokens.int(), + libnat.levenshtein_distance( + in_tokens.int(), + out_tokens.int(), + in_masks.sum(1).int(), + out_masks.sum(1).int(), + ), ) masked_tgt_masks = masked_tgt_masks.bool() & out_masks - mask_ins_targets = mask_ins_targets.type_as( - in_tokens)[:, 1:in_masks.size(1)].masked_fill_(~in_masks[:, 1:], 0) + mask_ins_targets = mask_ins_targets.type_as(in_tokens)[ + :, 1 : in_masks.size(1) + ].masked_fill_(~in_masks[:, 1:], 0) masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx) return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets @@ -73,7 +83,8 @@ def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx): mask_label + [0 for _ in range(out_seq_len - len(mask_label))] ) mask_ins_targets = [ - mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))] + mask_input[1:-1] + + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))] for mask_input in mask_inputs ] @@ -100,18 +111,23 @@ def _get_del_targets_cuda(in_tokens, out_tokens, padding_idx): word_del_targets = libnat.generate_deletion_labels( in_tokens.int(), libnat.levenshtein_distance( - in_tokens.int(), out_tokens.int(), - in_masks.sum(1).int(), out_masks.sum(1).int() - ) + in_tokens.int(), + out_tokens.int(), + in_masks.sum(1).int(), + out_masks.sum(1).int(), + ), + ) + word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_( + ~in_masks, 0 ) - word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_(~in_masks, 0) return word_del_targets def _get_del_targets_cpu(in_tokens, out_tokens, 
padding_idx): out_seq_len = out_tokens.size(1) with torch.cuda.device_of(in_tokens): in_tokens_list = [ - [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + [t for t in s if t != padding_idx] + for i, s in enumerate(in_tokens.tolist()) ] out_tokens_list = [ [t for t in s if t != padding_idx] @@ -149,10 +165,7 @@ def _apply_ins_masks( out_lengths = in_lengths + mask_ins_pred.sum(1) out_max_len = out_lengths.max() - out_masks = ( - new_arange(out_lengths, out_max_len)[None, :] - < out_lengths[:, None] - ) + out_masks = new_arange(out_lengths, out_max_len)[None, :] < out_lengths[:, None] reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1) out_tokens = ( @@ -173,9 +186,7 @@ def _apply_ins_masks( return out_tokens, out_scores -def _apply_ins_words( - in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx -): +def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx): word_ins_masks = in_tokens.eq(unk_idx) out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks]) @@ -200,11 +211,7 @@ def _apply_del_words( word_del_pred.masked_fill_(~in_masks, 1) word_del_pred.masked_fill_(bos_eos_masks, 0) - reordering = ( - new_arange(in_tokens) - .masked_fill_(word_del_pred, max_len) - .sort(1)[1] - ) + reordering = new_arange(in_tokens).masked_fill_(word_del_pred, max_len).sort(1)[1] out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering) @@ -216,7 +223,7 @@ def _apply_del_words( if in_attn is not None: _mask = word_del_pred[:, :, None].expand_as(in_attn) _reordering = reordering[:, :, None].expand_as(in_attn) - out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering) + out_attn = in_attn.masked_fill(_mask, 0.0).gather(1, _reordering) return out_tokens, out_scores, out_attn @@ -250,7 +257,9 @@ def _skip_encoder_out(encoder, encoder_out, mask): if not mask.any(): return encoder_out else: - return encoder.reorder_encoder_out(encoder_out, 
mask.nonzero(as_tuple=False).squeeze()) + return encoder.reorder_encoder_out( + encoder_out, mask.nonzero(as_tuple=False).squeeze() + ) def _fill(x, mask, y, padding_idx): @@ -276,9 +285,9 @@ def _fill(x, mask, y, padding_idx): elif x.size(1) > y.size(1): x[mask] = padding_idx if x.dim() == 2: - x[mask, :y.size(1)] = y + x[mask, : y.size(1)] = y else: - x[mask, :y.size(1), :] = y + x[mask, : y.size(1), :] = y else: x[mask] = y return x diff --git a/fairseq/models/nat/nat_crf_transformer.py b/fairseq/models/nat/nat_crf_transformer.py index 8dd3a08f72..d4b3cd931c 100644 --- a/fairseq/models/nat/nat_crf_transformer.py +++ b/fairseq/models/nat/nat_crf_transformer.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. -from fairseq.models.nat import NATransformerModel, base_architecture from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import NATransformerModel, base_architecture from fairseq.modules import DynamicCRF @@ -16,7 +16,7 @@ def __init__(self, args, encoder, decoder): self.crf_layer = DynamicCRF( num_embedding=len(self.tgt_dict), low_rank=args.crf_lowrank_approx, - beam_size=args.crf_beam_approx + beam_size=args.crf_beam_approx, ) @property @@ -26,12 +26,21 @@ def allow_ensemble(self): @staticmethod def add_args(parser): NATransformerModel.add_args(parser) - parser.add_argument("--crf-lowrank-approx", type=int, - help="the dimension of low-rank approximation of transition") - parser.add_argument("--crf-beam-approx", type=int, - help="the beam size for apporixmating the normalizing factor") - parser.add_argument("--word-ins-loss-factor", type=float, - help="weights on NAT loss used to co-training with CRF loss.") + parser.add_argument( + "--crf-lowrank-approx", + type=int, + help="the dimension of low-rank approximation of transition", + ) + parser.add_argument( + "--crf-beam-approx", + type=int, + help="the beam size for apporixmating the normalizing factor", + ) + parser.add_argument( + 
"--word-ins-loss-factor", + type=float, + help="weights on NAT loss used to co-training with CRF loss.", + ) def forward( self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs @@ -40,14 +49,19 @@ def forward( encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) # length prediction - length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out) - length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens) + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) # decoding word_ins_out = self.decoder( normalize=False, prev_output_tokens=prev_output_tokens, - encoder_out=encoder_out) + encoder_out=encoder_out, + ) word_ins_tgt, word_ins_mask = tgt_tokens, tgt_tokens.ne(self.pad) # compute the log-likelihood of CRF @@ -56,17 +70,19 @@ def forward( return { "word_ins": { - "out": word_ins_out, "tgt": word_ins_tgt, - "mask": word_ins_mask, "ls": self.args.label_smoothing, - "nll_loss": True, "factor": self.args.word_ins_loss_factor - }, - "word_crf": { - "loss": crf_nll + "out": word_ins_out, + "tgt": word_ins_tgt, + "mask": word_ins_mask, + "ls": self.args.label_smoothing, + "nll_loss": True, + "factor": self.args.word_ins_loss_factor, }, + "word_crf": {"loss": crf_nll}, "length": { - "out": length_out, "tgt": length_tgt, - "factor": self.decoder.length_loss_factor - } + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, } def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): @@ -77,9 +93,7 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar # execute the decoder and get emission scores output_masks = output_tokens.ne(self.pad) word_ins_out = self.decoder( - normalize=False, - prev_output_tokens=output_tokens, - encoder_out=encoder_out + normalize=False, 
prev_output_tokens=output_tokens, encoder_out=encoder_out ) # run viterbi decoding through CRF @@ -93,7 +107,7 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar output_tokens=output_tokens, output_scores=output_scores, attn=None, - history=history + history=history, ) diff --git a/fairseq/models/nat/nonautoregressive_ensembles.py b/fairseq/models/nat/nonautoregressive_ensembles.py index 2ed4d956e0..46bb8aac43 100644 --- a/fairseq/models/nat/nonautoregressive_ensembles.py +++ b/fairseq/models/nat/nonautoregressive_ensembles.py @@ -7,14 +7,13 @@ import torch import torch.nn.functional as F - from fairseq.models.nat import ( + _apply_del_words, + _apply_ins_masks, + _apply_ins_words, _fill, _skip, _skip_encoder_out, - _apply_ins_masks, - _apply_ins_words, - _apply_del_words, ) @@ -43,7 +42,7 @@ def __init__(self, models): self.encoder = _EnsembleModelEncoder(self.models) def has_encoder(self): - return hasattr(self.models[0], 'encoder') + return hasattr(self.models[0], "encoder") def max_decoder_positions(self): return min(m.max_decoder_positions() for m in self.models) @@ -69,7 +68,9 @@ def __init__(self, models): super().__init__(models) @torch.no_grad() - def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs): + def forward_decoder( + self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs + ): # LevT ensembling # A pipeline of three steps: deletion, placeholder, and word insertion. # We need to average scores in each step in a pipeline way because of dependence. 
@@ -83,7 +84,11 @@ def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio= max_lens = output_tokens.new().fill_(255) else: if encoder_outs[0].encoder_padding_mask is None: - src_lens = encoder_outs[0].encoder_out.new(bsz).fill_(encoder_outs[0].encoder_out.size(1)) + src_lens = ( + encoder_outs[0] + .encoder_out.new(bsz) + .fill_(encoder_outs[0].encoder_out.size(1)) + ) else: src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1) max_lens = (src_lens * max_ratio).clamp(min=10).long() @@ -104,13 +109,13 @@ def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio= can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens if can_ins_mask.sum() != 0: output_tokens, output_scores = self.forward_mask_ins( - encoder_outs, - output_tokens, - output_scores, - can_ins_mask, - eos_penalty, - max_lens, - ) + encoder_outs, + output_tokens, + output_scores, + can_ins_mask, + eos_penalty, + max_lens, + ) # insert words can_ins_word = output_tokens.eq(self.unk).sum(1) > 0 @@ -132,10 +137,12 @@ def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio= output_tokens=output_tokens, output_scores=output_scores, attn=attn, - history=None + history=None, ) - def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can_del_word): + def forward_word_del( + self, encoder_outs, output_tokens, output_scores, attn, can_del_word + ): word_del_score_avg = [] word_del_attn_avg = [] for model, encoder_out in zip(self.models, encoder_outs): @@ -146,10 +153,12 @@ def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can word_del_score = F.log_softmax(word_del_out, 2) word_del_score_avg.append(word_del_score) word_del_attn_avg.append(word_del_attn) - word_del_score_avg = torch.logsumexp(torch.stack(word_del_score_avg, dim=0), dim=0) - math.log(len(self.models)) + word_del_score_avg = torch.logsumexp( + torch.stack(word_del_score_avg, dim=0), dim=0 + ) - math.log(len(self.models)) 
word_del_pred = word_del_score_avg.max(-1)[1].bool() if word_del_attn_avg[0] is not None: - word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0)/len(self.models) + word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0) / len(self.models) else: word_del_attn_avg = None @@ -164,10 +173,18 @@ def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can ) output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad) output_scores = _fill(output_scores, can_del_word, _scores, 0) - attn = _fill(attn, can_del_word, _attn, 0.) + attn = _fill(attn, can_del_word, _attn, 0.0) return output_tokens, output_scores, attn - def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_mask, eos_penalty, max_lens): + def forward_mask_ins( + self, + encoder_outs, + output_tokens, + output_scores, + can_ins_mask, + eos_penalty, + max_lens, + ): mask_ins_score_avg = [] for model, encoder_out in zip(self.models, encoder_outs): mask_ins_out, _ = model.decoder.forward_mask_ins( @@ -178,7 +195,9 @@ def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_m if eos_penalty > 0.0: mask_ins_score[:, :, 0] -= eos_penalty mask_ins_score_avg.append(mask_ins_score) - mask_ins_score_avg = torch.logsumexp(torch.stack(mask_ins_score_avg, dim=0), dim=0) - math.log(len(self.models)) + mask_ins_score_avg = torch.logsumexp( + torch.stack(mask_ins_score_avg, dim=0), dim=0 + ) - math.log(len(self.models)) mask_ins_pred = mask_ins_score_avg.max(-1)[1] mask_ins_pred = torch.min( mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred) @@ -195,7 +214,9 @@ def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_m output_scores = _fill(output_scores, can_ins_mask, _scores, 0) return output_tokens, output_scores - def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can_ins_word): + def forward_word_ins( + self, encoder_outs, output_tokens, output_scores, attn, can_ins_word + ): 
word_ins_score_avg = [] word_ins_attn_avg = [] for model, encoder_out in zip(self.models, encoder_outs): @@ -206,9 +227,11 @@ def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can word_ins_score = F.log_softmax(word_ins_out, 2) word_ins_score_avg.append(word_ins_score) word_ins_attn_avg.append(word_ins_attn) - word_ins_score_avg = torch.logsumexp(torch.stack(word_ins_score_avg, dim=0), dim=0) - math.log(len(self.models)) + word_ins_score_avg = torch.logsumexp( + torch.stack(word_ins_score_avg, dim=0), dim=0 + ) - math.log(len(self.models)) if word_ins_attn_avg[0] is not None: - word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0)/len(self.models) + word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0) / len(self.models) else: word_ins_attn_avg = None word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1) @@ -223,7 +246,7 @@ def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad) output_scores = _fill(output_scores, can_ins_word, _scores, 0) - attn = _fill(attn, can_ins_word, word_ins_attn, 0.) 
+ attn = _fill(attn, can_ins_word, word_ins_attn, 0.0) return output_tokens, output_scores, attn def initialize_output_tokens(self, encoder_outs, src_tokens): diff --git a/fairseq/models/nat/nonautoregressive_transformer.py b/fairseq/models/nat/nonautoregressive_transformer.py index 050755c308..735297fc29 100644 --- a/fairseq/models/nat/nonautoregressive_transformer.py +++ b/fairseq/models/nat/nonautoregressive_transformer.py @@ -5,17 +5,11 @@ import torch import torch.nn.functional as F - from fairseq import utils from fairseq.iterative_refinement_generator import DecoderOut from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder from fairseq.models.transformer import Embedding - -from fairseq.models.nat import ( - FairseqNATModel, - FairseqNATDecoder, - ensemble_decoder -) from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -48,7 +42,6 @@ def _uniform_assignment(src_lens, trg_lens): @register_model("nonautoregressive_transformer") class NATransformerModel(FairseqNATModel): - @property def allow_length_beam(self): return True @@ -58,14 +51,26 @@ def add_args(parser): FairseqNATModel.add_args(parser) # length prediction - parser.add_argument("--src-embedding-copy", action="store_true", - help="copy encoder word embeddings as the initial input of the decoder") - parser.add_argument("--pred-length-offset", action="store_true", - help="predicting the length difference between the target and source sentences") - parser.add_argument("--sg-length-pred", action="store_true", - help="stop the gradients back-propagated from the length predictor") - parser.add_argument("--length-loss-factor", type=float, - help="weights on the length prediction loss") + parser.add_argument( + "--src-embedding-copy", + action="store_true", + help="copy encoder word embeddings as the initial input of the decoder", + ) + parser.add_argument( + "--pred-length-offset", + 
action="store_true", + help="predicting the length difference between the target and source sentences", + ) + parser.add_argument( + "--sg-length-pred", + action="store_true", + help="stop the gradients back-propagated from the length predictor", + ) + parser.add_argument( + "--length-loss-factor", + type=float, + help="weights on the length prediction loss", + ) @classmethod def build_decoder(cls, args, tgt_dict, embed_tokens): @@ -81,25 +86,33 @@ def forward( encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) # length prediction - length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out) - length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens) + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) # decoding word_ins_out = self.decoder( normalize=False, prev_output_tokens=prev_output_tokens, - encoder_out=encoder_out) + encoder_out=encoder_out, + ) return { "word_ins": { - "out": word_ins_out, "tgt": tgt_tokens, - "mask": tgt_tokens.ne(self.pad), "ls": self.args.label_smoothing, - "nll_loss": True + "out": word_ins_out, + "tgt": tgt_tokens, + "mask": tgt_tokens.ne(self.pad), + "ls": self.args.label_smoothing, + "nll_loss": True, }, "length": { - "out": length_out, "tgt": length_tgt, - "factor": self.decoder.length_loss_factor - } + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, } def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): @@ -126,14 +139,14 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar output_tokens=output_tokens, output_scores=output_scores, attn=None, - history=history + history=history, ) def initialize_output_tokens(self, encoder_out, src_tokens): # length prediction length_tgt = self.decoder.forward_length_prediction( 
self.decoder.forward_length(normalize=True, encoder_out=encoder_out), - encoder_out=encoder_out + encoder_out=encoder_out, ) max_length = length_tgt.clamp_(min=2).max() @@ -158,13 +171,17 @@ def initialize_output_tokens(self, encoder_out, src_tokens): attn=None, step=0, max_step=0, - history=None + history=None, ) def regenerate_length_beam(self, decoder_out, beam_size): output_tokens = decoder_out.output_tokens length_tgt = output_tokens.ne(self.pad).sum(1) - length_tgt = length_tgt[:, None] + utils.new_arange(length_tgt, 1, beam_size) - beam_size // 2 + length_tgt = ( + length_tgt[:, None] + + utils.new_arange(length_tgt, 1, beam_size) + - beam_size // 2 + ) length_tgt = length_tgt.view(-1).clamp_(min=2) max_length = length_tgt.max() idx_length = utils.new_arange(length_tgt, max_length) @@ -183,8 +200,7 @@ def regenerate_length_beam(self, decoder_out, beam_size): ).type_as(decoder_out.output_scores) return decoder_out._replace( - output_tokens=initial_output_tokens, - output_scores=initial_output_scores + output_tokens=initial_output_tokens, output_scores=initial_output_scores ) diff --git a/fairseq/models/roberta/alignment_utils.py b/fairseq/models/roberta/alignment_utils.py index 45d2e37194..ccc7f74cb9 100644 --- a/fairseq/models/roberta/alignment_utils.py +++ b/fairseq/models/roberta/alignment_utils.py @@ -29,23 +29,25 @@ def clean(text): # remove whitespaces to simplify alignment bpe_tokens = [roberta.task.source_dictionary.string([x]) for x in bpe_tokens] - bpe_tokens = [clean(roberta.bpe.decode(x) if x not in {'<s>', ''} else x) for x in bpe_tokens] + bpe_tokens = [ + clean(roberta.bpe.decode(x) if x not in {"<s>", ""} else x) for x in bpe_tokens + ] other_tokens = [clean(str(o)) for o in other_tokens] # strip leading <s> bpe_tokens = bpe_tokens[1:] - assert ''.join(bpe_tokens) == ''.join(other_tokens) + assert "".join(bpe_tokens) == "".join(other_tokens) # create alignment from every word to a list of BPE tokens alignment = [] - bpe_toks = filter(lambda item: 
item[1] != '', enumerate(bpe_tokens, start=1)) + bpe_toks = filter(lambda item: item[1] != "", enumerate(bpe_tokens, start=1)) j, bpe_tok = next(bpe_toks) for other_tok in other_tokens: bpe_indices = [] while True: if other_tok.startswith(bpe_tok): bpe_indices.append(j) - other_tok = other_tok[len(bpe_tok):] + other_tok = other_tok[len(bpe_tok) :] try: j, bpe_tok = next(bpe_toks) except StopIteration: @@ -53,11 +55,11 @@ def clean(text): elif bpe_tok.startswith(other_tok): # other_tok spans multiple BPE tokens bpe_indices.append(j) - bpe_tok = bpe_tok[len(other_tok):] - other_tok = '' + bpe_tok = bpe_tok[len(other_tok) :] + other_tok = "" else: raise Exception('Cannot align "{}" and "{}"'.format(other_tok, bpe_tok)) - if other_tok == '': + if other_tok == "": break assert len(bpe_indices) > 0 alignment.append(bpe_indices) @@ -96,20 +98,21 @@ def align_features_to_words(roberta, features, alignment): def spacy_nlp(): - if getattr(spacy_nlp, '_nlp', None) is None: + if getattr(spacy_nlp, "_nlp", None) is None: try: from spacy.lang.en import English + spacy_nlp._nlp = English() except ImportError: - raise ImportError('Please install spacy with: pip install spacy') + raise ImportError("Please install spacy with: pip install spacy") return spacy_nlp._nlp def spacy_tokenizer(): - if getattr(spacy_tokenizer, '_tokenizer', None) is None: + if getattr(spacy_tokenizer, "_tokenizer", None) is None: try: nlp = spacy_nlp() spacy_tokenizer._tokenizer = nlp.Defaults.create_tokenizer(nlp) except ImportError: - raise ImportError('Please install spacy with: pip install spacy') + raise ImportError("Please install spacy with: pip install spacy") return spacy_tokenizer._tokenizer diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index 20456b3f5c..526823bd1f 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - 
from fairseq import utils from fairseq.data import encoders @@ -27,13 +26,15 @@ def __init__(self, args, task, model): self.bpe = encoders.build_bpe(args) # this is useful for determining the device - self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float)) + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) @property def device(self): return self._float_tensor.device - def encode(self, sentence: str, *addl_sentences, no_separator=False) -> torch.LongTensor: + def encode( + self, sentence: str, *addl_sentences, no_separator=False + ) -> torch.LongTensor: """ BPE-encode a sentence (or multiple sentences). @@ -54,11 +55,13 @@ def encode(self, sentence: str, *addl_sentences, no_separator=False) -> torch.Lo >>> roberta.encode('world').tolist() [0, 8331, 2] """ - bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>' + bpe_sentence = "<s> " + self.bpe.encode(sentence) + " </s>" for s in addl_sentences: - bpe_sentence += (' </s>' if not no_separator else '') - bpe_sentence += ' ' + self.bpe.encode(s) + ' </s>' - tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False, add_if_not_exist=False) + bpe_sentence += " </s>" if not no_separator else "" + bpe_sentence += " " + self.bpe.encode(s) + " </s>" + tokens = self.task.source_dictionary.encode_line( + bpe_sentence, append_eos=False, add_if_not_exist=False + ) return tokens.long() def decode(self, tokens: torch.LongTensor): @@ -66,21 +69,27 @@ def decode(self, tokens: torch.LongTensor): tokens = tokens.numpy() if tokens[0] == self.task.source_dictionary.bos(): tokens = tokens[1:] # remove <s> - eos_mask = (tokens == self.task.source_dictionary.eos()) + eos_mask = tokens == self.task.source_dictionary.eos() doc_mask = eos_mask[1:] & eos_mask[:-1] sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) - sentences = [self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences] + sentences = [ + self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences + ] 
if len(sentences) == 1: return sentences[0] return sentences - def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor: + def extract_features( + self, tokens: torch.LongTensor, return_all_hiddens: bool = False + ) -> torch.Tensor: if tokens.dim() == 1: tokens = tokens.unsqueeze(0) if tokens.size(-1) > self.model.max_positions(): - raise ValueError('tokens exceeds maximum length: {} > {}'.format( - tokens.size(-1), self.model.max_positions() - )) + raise ValueError( + "tokens exceeds maximum length: {} > {}".format( + tokens.size(-1), self.model.max_positions() + ) + ) features, extra = self.model( tokens.to(device=self.device), features_only=True, @@ -88,7 +97,7 @@ def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = ) if return_all_hiddens: # convert from T x B x C -> B x T x C - inner_states = extra['inner_states'] + inner_states = extra["inner_states"] return [inner_state.transpose(0, 1) for inner_state in inner_states] else: return features # just the last layer's features @@ -107,7 +116,9 @@ def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = Fal return logits return F.log_softmax(logits, dim=-1) - def extract_features_aligned_to_words(self, sentence: str, return_all_hiddens: bool = False) -> torch.Tensor: + def extract_features_aligned_to_words( + self, sentence: str, return_all_hiddens: bool = False + ) -> torch.Tensor: """Extract RoBERTa features, aligned to spaCy's word-level tokenizer.""" from fairseq.models.roberta import alignment_utils from spacy.tokens import Doc @@ -122,31 +133,42 @@ def extract_features_aligned_to_words(self, sentence: str, return_all_hiddens: b alignment = alignment_utils.align_bpe_to_words(self, bpe_toks, spacy_toks_ws) # extract features and align them - features = self.extract_features(bpe_toks, return_all_hiddens=return_all_hiddens) + features = self.extract_features( + bpe_toks, return_all_hiddens=return_all_hiddens + ) features 
= features.squeeze(0) - aligned_feats = alignment_utils.align_features_to_words(self, features, alignment) + aligned_feats = alignment_utils.align_features_to_words( + self, features, alignment + ) # wrap in spaCy Doc doc = Doc( nlp.vocab, - words=['<s>'] + [x.text for x in spacy_toks] + ['</s>'], - spaces=[True] + [x.endswith(' ') for x in spacy_toks_ws[:-1]] + [True, False], + words=["<s>"] + [x.text for x in spacy_toks] + ["</s>"], + spaces=[True] + + [x.endswith(" ") for x in spacy_toks_ws[:-1]] + + [True, False], ) assert len(doc) == aligned_feats.size(0) - doc.user_token_hooks['vector'] = lambda token: aligned_feats[token.i] + doc.user_token_hooks["vector"] = lambda token: aligned_feats[token.i] return doc def fill_mask(self, masked_input: str, topk: int = 5): - masked_token = '<mask>' - assert masked_token in masked_input and masked_input.count(masked_token) == 1, \ - "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token) + masked_token = "<mask>" + assert ( + masked_token in masked_input and masked_input.count(masked_token) == 1 + ), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format( + masked_token + ) text_spans = masked_input.split(masked_token) - text_spans_bpe = (' {0} '.format(masked_token)).join( - [self.bpe.encode(text_span.rstrip()) for text_span in text_spans] - ).strip() + text_spans_bpe = ( + (" {0} ".format(masked_token)) + .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans]) + .strip() + ) tokens = self.task.source_dictionary.encode_line( - '<s> ' + text_spans_bpe + ' </s>', + "<s> " + text_spans_bpe + " </s>", append_eos=False, add_if_not_exist=False, ) @@ -167,25 +189,31 @@ def fill_mask(self, masked_input: str, topk: int = 5): topk_predicted_token_bpe = self.task.source_dictionary.string(index) topk_filled_outputs = [] - for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')): + for index, predicted_token_bpe in enumerate( + topk_predicted_token_bpe.split(" ") + ): predicted_token = 
self.bpe.decode(predicted_token_bpe) # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306 - if predicted_token_bpe.startswith('\u2581'): - predicted_token = ' ' + predicted_token + if predicted_token_bpe.startswith("\u2581"): + predicted_token = " " + predicted_token if " {0}".format(masked_token) in masked_input: - topk_filled_outputs.append(( - masked_input.replace( - ' {0}'.format(masked_token), predicted_token - ), - values[index].item(), - predicted_token, - )) + topk_filled_outputs.append( + ( + masked_input.replace( + " {0}".format(masked_token), predicted_token + ), + values[index].item(), + predicted_token, + ) + ) else: - topk_filled_outputs.append(( - masked_input.replace(masked_token, predicted_token), - values[index].item(), - predicted_token, - )) + topk_filled_outputs.append( + ( + masked_input.replace(masked_token, predicted_token), + values[index].item(), + predicted_token, + ) + ) return topk_filled_outputs def disambiguate_pronoun(self, sentence: str) -> bool: @@ -198,7 +226,10 @@ def disambiguate_pronoun(self, sentence: str) -> bool: >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.') 'The trophy' """ - assert hasattr(self.task, 'disambiguate_pronoun'), \ - 'roberta.disambiguate_pronoun() requires a model trained with the WSC task.' + assert hasattr( + self.task, "disambiguate_pronoun" + ), "roberta.disambiguate_pronoun() requires a model trained with the WSC task." 
with utils.model_eval(self.model): - return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda') + return self.task.disambiguate_pronoun( + self.model, sentence, use_cuda=self.device.type == "cuda" + ) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 0917927e34..6ce216a6bf 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -11,7 +11,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils from fairseq.models import ( FairseqEncoder, @@ -19,12 +18,9 @@ register_model, register_model_architecture, ) -from fairseq.modules import ( - LayerNorm, - TransformerSentenceEncoder, -) -from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.modules import LayerNorm, TransformerSentenceEncoder from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ +from fairseq.modules.transformer_sentence_encoder import init_bert_params from .hub_interface import RobertaHubInterface @@ -32,16 +28,15 @@ logger = logging.getLogger(__name__) -@register_model('roberta') +@register_model("roberta") class RobertaModel(FairseqEncoderModel): - @classmethod def hub_models(cls): return { - 'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz', - 'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz', - 'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz', - 'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz', + "roberta.base": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz", + "roberta.large": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz", + "roberta.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz", + "roberta.large.wsc": 
"http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz", } def __init__(self, args, encoder): @@ -56,50 +51,117 @@ def __init__(self, args, encoder): @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" - parser.add_argument('--encoder-layers', type=int, metavar='L', - help='num encoder layers') - parser.add_argument('--encoder-embed-dim', type=int, metavar='H', - help='encoder embedding dimension') - parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='F', - help='encoder embedding dimension for FFN') - parser.add_argument('--encoder-attention-heads', type=int, metavar='A', - help='num encoder attention heads') - parser.add_argument('--activation-fn', - choices=utils.get_available_activation_fns(), - help='activation function to use') - parser.add_argument('--pooler-activation-fn', - choices=utils.get_available_activation_fns(), - help='activation function to use for pooler layer') - parser.add_argument('--encoder-normalize-before', action='store_true', - help='apply layernorm before each encoder block') - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', type=float, metavar='D', - help='dropout probability for attention weights') - parser.add_argument('--activation-dropout', type=float, metavar='D', - help='dropout probability after activation in FFN') - parser.add_argument('--pooler-dropout', type=float, metavar='D', - help='dropout probability in the masked_lm pooler layers') - parser.add_argument('--max-positions', type=int, - help='number of positional embeddings to learn') - parser.add_argument('--load-checkpoint-heads', action='store_true', - help='(re-)register and load heads when loading checkpoints') + parser.add_argument( + "--encoder-layers", type=int, metavar="L", help="num encoder layers" + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="H", + help="encoder embedding dimension", + ) 
+ parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="F", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="A", + help="num encoder attention heads", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--pooler-activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use for pooler layer", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN", + ) + parser.add_argument( + "--pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", + ) + parser.add_argument( + "--max-positions", type=int, help="number of positional embeddings to learn" + ) + parser.add_argument( + "--load-checkpoint-heads", + action="store_true", + help="(re-)register and load heads when loading checkpoints", + ) # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) - parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, - help='LayerDrop probability for encoder') - parser.add_argument('--encoder-layers-to-keep', default=None, - help='which layers to *keep* when pruning as a comma-separated list') + parser.add_argument( + "--encoder-layerdrop", + type=float, + metavar="D", + default=0, + help="LayerDrop probability for encoder", + ) + parser.add_argument( + "--encoder-layers-to-keep", + default=None, + 
help="which layers to *keep* when pruning as a comma-separated list", + ) # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) - parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, - help='iterative PQ quantization noise at training time') - parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, - help='block size of quantization noise at training time') - parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, - help='scalar quantization noise and scalar quantization at training time') - parser.add_argument('--untie-weights-roberta', action='store_true', - help='Untie weights between embeddings and classifiers in RoBERTa') - parser.add_argument('--spectral-norm-classification-head', action='store_true', default=False, - help='Apply spectral normalization on the classification head') + parser.add_argument( + "--quant-noise-pq", + type=float, + metavar="D", + default=0, + help="iterative PQ quantization noise at training time", + ) + parser.add_argument( + "--quant-noise-pq-block-size", + type=int, + metavar="D", + default=8, + help="block size of quantization noise at training time", + ) + parser.add_argument( + "--quant-noise-scalar", + type=float, + metavar="D", + default=0, + help="scalar quantization noise and scalar quantization at training time", + ) + parser.add_argument( + "--untie-weights-roberta", + action="store_true", + help="Untie weights between embeddings and classifiers in RoBERTa", + ) + parser.add_argument( + "--spectral-norm-classification-head", + action="store_true", + default=False, + help="Apply spectral normalization on the classification head", + ) @classmethod def build_model(cls, args, task): @@ -108,13 +170,20 @@ def build_model(cls, args, task): # make sure all arguments are present base_architecture(args) - if not hasattr(args, 'max_positions'): + if not hasattr(args, "max_positions"): args.max_positions = 
args.tokens_per_sample encoder = RobertaEncoder(args, task.source_dictionary) return cls(args, encoder) - def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs): + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + classification_head_name=None, + **kwargs + ): if classification_head_name is not None: features_only = True @@ -132,7 +201,9 @@ def get_normalized_probs(self, net_output, log_probs, sample=None): else: return F.softmax(logits, dim=-1) - def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): """Register a classification head.""" if name in self.classification_heads: prev_num_classes = self.classification_heads[name].out_proj.out_features @@ -140,7 +211,7 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, * if num_classes != prev_num_classes or inner_dim != prev_inner_dim: logger.warning( 're-registering head "{}" with num_classes {} (prev: {}) ' - 'and inner_dim {} (prev: {})'.format( + "and inner_dim {} (prev: {})".format( name, num_classes, prev_num_classes, inner_dim, prev_inner_dim ) ) @@ -157,11 +228,19 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, * @property def supported_targets(self): - return {'self'} + return {"self"} @classmethod - def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='gpt2', **kwargs): + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="gpt2", + **kwargs + ): from fairseq import hub_utils + x = hub_utils.from_pretrained( model_name_or_path, checkpoint_file, @@ -171,15 +250,15 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na load_checkpoint_heads=True, **kwargs, ) - return 
RobertaHubInterface(x['args'], x['task'], x['models'][0]) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) def upgrade_state_dict_named(self, state_dict, name): - prefix = name + '.' if name != '' else '' + prefix = name + "." if name != "" else "" # rename decoder -> encoder before upgrading children modules for k in list(state_dict.keys()): - if k.startswith(prefix + 'decoder'): - new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):] + if k.startswith(prefix + "decoder"): + new_k = prefix + "encoder" + k[len(prefix + "decoder") :] state_dict[new_k] = state_dict[k] del state_dict[k] @@ -188,35 +267,44 @@ def upgrade_state_dict_named(self, state_dict, name): # Handle new classification heads present in the state dict. current_head_names = ( - [] if not hasattr(self, 'classification_heads') + [] + if not hasattr(self, "classification_heads") else self.classification_heads.keys() ) keys_to_delete = [] for k in state_dict.keys(): - if not k.startswith(prefix + 'classification_heads.'): + if not k.startswith(prefix + "classification_heads."): continue - head_name = k[len(prefix + 'classification_heads.'):].split('.')[0] - num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0) - inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0) + head_name = k[len(prefix + "classification_heads.") :].split(".")[0] + num_classes = state_dict[ + prefix + "classification_heads." + head_name + ".out_proj.weight" + ].size(0) + inner_dim = state_dict[ + prefix + "classification_heads." 
+ head_name + ".dense.weight" + ].size(0) - if getattr(self.args, 'load_checkpoint_heads', False): + if getattr(self.args, "load_checkpoint_heads", False): if head_name not in current_head_names: self.register_classification_head(head_name, num_classes, inner_dim) else: if head_name not in current_head_names: logger.warning( - 'deleting classification head ({}) from checkpoint ' - 'not present in current model: {}'.format(head_name, k) + "deleting classification head ({}) from checkpoint " + "not present in current model: {}".format(head_name, k) ) keys_to_delete.append(k) elif ( - num_classes != self.classification_heads[head_name].out_proj.out_features - or inner_dim != self.classification_heads[head_name].dense.out_features + num_classes + != self.classification_heads[head_name].out_proj.out_features + or inner_dim + != self.classification_heads[head_name].dense.out_features ): logger.warning( - 'deleting classification head ({}) from checkpoint ' - 'with different dimensions than current model: {}'.format(head_name, k) + "deleting classification head ({}) from checkpoint " + "with different dimensions than current model: {}".format( + head_name, k + ) ) keys_to_delete.append(k) for k in keys_to_delete: @@ -224,12 +312,12 @@ def upgrade_state_dict_named(self, state_dict, name): # Copy any newly-added classification heads into the state dict # with their current weights. - if hasattr(self, 'classification_heads'): + if hasattr(self, "classification_heads"): cur_state = self.classification_heads.state_dict() for k, v in cur_state.items(): - if prefix + 'classification_heads.' + k not in state_dict: - logger.info('Overwriting ' + prefix + 'classification_heads.' + k) - state_dict[prefix + 'classification_heads.' + k] = v + if prefix + "classification_heads." + k not in state_dict: + logger.info("Overwriting " + prefix + "classification_heads." + k) + state_dict[prefix + "classification_heads." 
+ k] = v class RobertaLMHead(nn.Module): @@ -284,7 +372,8 @@ def __init__( if do_spectral_norm: if q_noise != 0: raise NotImplementedError( - "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported") + "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported" + ) self.out_proj = torch.nn.utils.spectral_norm(self.out_proj) def forward(self, features, **kwargs): @@ -326,7 +415,7 @@ def __init__(self, args, dictionary): q_noise=args.quant_noise_pq, qn_block_size=args.quant_noise_pq_block_size, ) - args.untie_weights_roberta = getattr(args, 'untie_weights_roberta', False) + args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) self.lm_head = RobertaLMHead( embed_dim=args.encoder_embed_dim, @@ -339,7 +428,14 @@ def __init__(self, args, dictionary): ), ) - def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused): + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + masked_tokens=None, + **unused + ): """ Args: src_tokens (LongTensor): input tokens of shape `(batch, src_len)` @@ -356,7 +452,9 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas is a list of hidden states. Note that the hidden states have shape `(src_len, batch, vocab)`. 
""" - x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens) + x, extra = self.extract_features( + src_tokens, return_all_hiddens=return_all_hiddens + ) if not features_only: x = self.output_layer(x, masked_tokens=masked_tokens) return x, extra @@ -365,10 +463,10 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs): inner_states, _ = self.sentence_encoder( src_tokens, last_state_only=not return_all_hiddens, - token_embeddings=kwargs.get('token_embeddings', None), + token_embeddings=kwargs.get("token_embeddings", None), ) features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {'inner_states': inner_states if return_all_hiddens else None} + return features, {"inner_states": inner_states if return_all_hiddens else None} def output_layer(self, features, masked_tokens=None, **unused): return self.lm_head(features, masked_tokens) @@ -378,44 +476,46 @@ def max_positions(self): return self.args.max_positions -@register_model_architecture('roberta', 'roberta') +@register_model_architecture("roberta", "roberta") def base_architecture(args): - args.encoder_layers = getattr(args, 'encoder_layers', 12) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) - - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') - - args.dropout = getattr(args, 'dropout', 0.1) - args.attention_dropout = getattr(args, 'attention_dropout', 0.1) - args.activation_dropout = getattr(args, 'activation_dropout', 0.0) - args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) - args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None) - args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) - args.encoder_layerdrop = getattr(args, 
'encoder_layerdrop', 0.0) - args.spectral_norm_classification_head = getattr(args, 'spectral_nrom_classification_head', False) - - -@register_model_architecture('roberta', 'roberta_base') + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.spectral_norm_classification_head = getattr( + args, "spectral_nrom_classification_head", False + ) + + +@register_model_architecture("roberta", "roberta_base") def roberta_base_architecture(args): base_architecture(args) -@register_model_architecture('roberta', 'roberta_large') +@register_model_architecture("roberta", "roberta_large") def roberta_large_architecture(args): - args.encoder_layers = getattr(args, 'encoder_layers', 24) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, 
"encoder_attention_heads", 16) base_architecture(args) -@register_model_architecture('roberta', 'xlm') +@register_model_architecture("roberta", "xlm") def xlm_architecture(args): - args.encoder_layers = getattr(args, 'encoder_layers', 16) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1280) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1280*4) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1280 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) base_architecture(args) diff --git a/fairseq/models/roberta/model_camembert.py b/fairseq/models/roberta/model_camembert.py index eb57d81d8d..46447546fa 100644 --- a/fairseq/models/roberta/model_camembert.py +++ b/fairseq/models/roberta/model_camembert.py @@ -12,25 +12,32 @@ from .model import RobertaModel -@register_model('camembert') +@register_model("camembert") class CamembertModel(RobertaModel): - @classmethod def hub_models(cls): return { - 'camembert': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz', - 'camembert.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz', - 'camembert-base': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz', - 'camembert-large': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz', - 'camembert-base-ccnet': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz', - 'camembert-base-ccnet-4gb': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz', - 'camembert-base-wikipedia-4gb': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz', - 'camembert-base-oscar-4gb': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz', + 
"camembert": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", + "camembert.v0": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", + "camembert-base": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", + "camembert-large": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz", + "camembert-base-ccnet": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz", + "camembert-base-ccnet-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz", + "camembert-base-wikipedia-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz", + "camembert-base-oscar-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz", } @classmethod - def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs): + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="sentencepiece", + **kwargs + ): from fairseq import hub_utils + x = hub_utils.from_pretrained( model_name_or_path, checkpoint_file, @@ -40,4 +47,4 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na load_checkpoint_heads=True, **kwargs, ) - return RobertaHubInterface(x['args'], x['task'], x['models'][0]) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) diff --git a/fairseq/models/roberta/model_xlmr.py b/fairseq/models/roberta/model_xlmr.py index fa71a27d12..5886880f73 100644 --- a/fairseq/models/roberta/model_xlmr.py +++ b/fairseq/models/roberta/model_xlmr.py @@ -12,19 +12,26 @@ from .model import RobertaModel -@register_model('xlmr') +@register_model("xlmr") class XLMRModel(RobertaModel): - @classmethod def hub_models(cls): return { - 'xlmr.base': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz', - 'xlmr.large': 
'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz', + "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz", + "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz", } @classmethod - def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs): + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="sentencepiece", + **kwargs + ): from fairseq import hub_utils + x = hub_utils.from_pretrained( model_name_or_path, checkpoint_file, @@ -34,4 +41,4 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na load_checkpoint_heads=True, **kwargs, ) - return RobertaHubInterface(x['args'], x['task'], x['models'][0]) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index 351c16fee5..5d7f59b3a6 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -4,4 +4,4 @@ # LICENSE file in the root directory of this source tree. 
from .berard import * # noqa -from .s2t_transformer import * # noqa +from .s2t_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/berard.py b/fairseq/models/speech_to_text/berard.py index f5ae46eeb2..c505e3acaa 100644 --- a/fairseq/models/speech_to_text/berard.py +++ b/fairseq/models/speech_to_text/berard.py @@ -6,16 +6,15 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import checkpoint_utils, utils +from fairseq.data.data_utils import lengths_to_padding_mask from fairseq.models import ( FairseqEncoder, - FairseqIncrementalDecoder, FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, register_model, register_model_architecture, ) -from fairseq.data.data_utils import lengths_to_padding_mask @register_model("s2t_berard") @@ -40,48 +39,85 @@ def __init__(self, encoder, decoder): @staticmethod def add_args(parser): - parser.add_argument("--input-layers", type=str, metavar="EXPR", - help="List of linear layer dimensions. These " - "layers are applied to the input features and " - "are followed by tanh and possibly dropout.") parser.add_argument( - "--dropout", type=float, metavar="D", + "--input-layers", + type=str, + metavar="EXPR", + help="List of linear layer dimensions. These " + "layers are applied to the input features and " + "are followed by tanh and possibly dropout.", + ) + parser.add_argument( + "--dropout", + type=float, + metavar="D", help="Dropout probability to use in the encoder/decoder. " - "Note that this parameters control dropout in various places, " - "there is no fine-grained control for dropout for embeddings " - "vs LSTM layers for example." + "Note that this parameters control dropout in various places, " + "there is no fine-grained control for dropout for embeddings " + "vs LSTM layers for example.", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="Number of encoder input channels. 
" "Typically value is 1.", + ) + parser.add_argument( + "--conv-layers", + type=str, + metavar="EXPR", + help="List of conv layers " "(format: (channels, kernel, stride)).", + ) + parser.add_argument( + "--num-blstm-layers", + type=int, + metavar="N", + help="Number of encoder bi-LSTM layers.", ) - parser.add_argument("--in-channels", type=int, metavar="N", - help="Number of encoder input channels. " - "Typically value is 1.") - parser.add_argument("--conv-layers", type=str, metavar="EXPR", - help="List of conv layers " - "(format: (channels, kernel, stride)).") - parser.add_argument("--num-blstm-layers", type=int, metavar="N", - help="Number of encoder bi-LSTM layers.") - parser.add_argument("--lstm-size", type=int, metavar="N", - help="LSTM hidden size.") parser.add_argument( - "--decoder-embed-dim", type=int, metavar="N", - help="Embedding dimension of the decoder target tokens." + "--lstm-size", type=int, metavar="N", help="LSTM hidden size." ) - parser.add_argument("--decoder-hidden-dim", type=int, metavar="N", - help="Decoder LSTM hidden dimension.") - parser.add_argument("--decoder-num-layers", type=int, metavar="N", - help="Number of decoder LSTM layers.") - parser.add_argument("--attention-dim", type=int, metavar="N", - help="Hidden layer dimension in MLP attention.") parser.add_argument( - "--output-layer-dim", type=int, metavar="N", - help="Hidden layer dim for linear layer prior to output projection." 
+ "--decoder-embed-dim", + type=int, + metavar="N", + help="Embedding dimension of the decoder target tokens.", ) parser.add_argument( - "--load-pretrained-encoder-from", type=str, metavar="STR", - help="model to take encoder weights from (for initialization)" + "--decoder-hidden-dim", + type=int, + metavar="N", + help="Decoder LSTM hidden dimension.", ) parser.add_argument( - "--load-pretrained-decoder-from", type=str, metavar="STR", - help="model to take decoder weights from (for initialization)" + "--decoder-num-layers", + type=int, + metavar="N", + help="Number of decoder LSTM layers.", + ) + parser.add_argument( + "--attention-dim", + type=int, + metavar="N", + help="Hidden layer dimension in MLP attention.", + ) + parser.add_argument( + "--output-layer-dim", + type=int, + metavar="N", + help="Hidden layer dim for linear layer prior to output projection.", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", ) @classmethod @@ -170,8 +206,7 @@ def __init__( if dropout > 0: self.input_layers.append( nn.Sequential( - nn.Linear(in_features, out_features), - nn.Dropout(p=dropout) + nn.Linear(in_features, out_features), nn.Dropout(p=dropout) ) ) else: @@ -194,9 +229,7 @@ def __init__( padding=conv_kernel_size // 2, ) ) - self.conv_kernel_sizes_and_strides.append( - (conv_kernel_size, conv_stride) - ) + self.conv_kernel_sizes_and_strides.append((conv_kernel_size, conv_stride)) in_channels = out_channels lstm_input_dim //= conv_stride @@ -241,8 +274,7 @@ def forward(self, src_tokens, src_lengths=None, **kwargs): # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> # (T, B, C * feat) - x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, - bsz, -1) + x = x.transpose(1, 2).transpose(0, 
1).contiguous().view(output_seq_len, bsz, -1) input_lengths = src_lengths.clone() for k, s in self.conv_kernel_sizes_and_strides: @@ -261,8 +293,9 @@ def forward(self, src_tokens, src_lengths=None, **kwargs): if self.dropout is not None: x = self.dropout(x) - encoder_padding_mask = lengths_to_padding_mask(output_lengths).to( - src_tokens.device).t() + encoder_padding_mask = ( + lengths_to_padding_mask(output_lengths).to(src_tokens.device).t() + ) return { "encoder_out": x, # (T, B, C) @@ -293,8 +326,7 @@ def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim): self.context_dim = context_dim self.attention_dim = attention_dim # W_ae and b_a - self.encoder_proj = nn.Linear(context_dim, self.attention_dim, - bias=True) + self.encoder_proj = nn.Linear(context_dim, self.attention_dim, bias=True) # W_ad self.decoder_proj = nn.Linear( decoder_hidden_state_dim, self.attention_dim, bias=False @@ -314,8 +346,7 @@ def forward(self, decoder_state, source_hids, encoder_padding_mask): # (src_len*bsz) x attention_dim encoder_component = self.encoder_proj(flat_source_hids) # src_len x bsz x attention_dim - encoder_component = encoder_component.view(src_len, bsz, - self.attention_dim) + encoder_component = encoder_component.view(src_len, bsz, self.attention_dim) # 1 x bsz x attention_dim decoder_component = self.decoder_proj(decoder_state).unsqueeze(0) # Sum with broadcasting and apply the non linearity @@ -400,8 +431,9 @@ def __init__( ) self.output_projection = nn.Linear(output_layer_dim, num_embeddings) - def forward(self, prev_output_tokens, encoder_out=None, - incremental_state=None, **kwargs): + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): encoder_padding_mask = encoder_out["encoder_padding_mask"] encoder_outs = encoder_out["encoder_out"] @@ -428,9 +460,7 @@ def forward(self, prev_output_tokens, encoder_out=None, if cached_state is not None: prev_hiddens, prev_cells = cached_state else: - prev_hiddens = [ - 
encoder_out["encoder_out"].mean(dim=0) - ] * self.num_layers + prev_hiddens = [encoder_out["encoder_out"].mean(dim=0)] * self.num_layers prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers attn_scores = x.new_zeros(bsz, srclen) @@ -510,9 +540,7 @@ def reorder_state(state): return state.index_select(0, new_order) new_state = tuple(map(reorder_state, cached_state)) - utils.set_incremental_state( - self, incremental_state, "cached_state", new_state - ) + utils.set_incremental_state(self, incremental_state, "cached_state", new_state) @register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard") @@ -538,8 +566,7 @@ def berard(args): ) -@register_model_architecture(model_name="s2t_berard", - arch_name="s2t_berard_256_3_3") +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_256_3_3") def berard_256_3_3(args): """Used in * "Harnessing Indirect Training Data for End-to-End Automatic Speech @@ -553,8 +580,7 @@ def berard_256_3_3(args): berard(args) -@register_model_architecture(model_name="s2t_berard", - arch_name="s2t_berard_512_3_2") +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_3_2") def berard_512_3_2(args): args.num_blstm_layers = getattr(args, "num_blstm_layers", 3) args.lstm_size = getattr(args, "lstm_size", 512) @@ -567,8 +593,7 @@ def berard_512_3_2(args): berard(args) -@register_model_architecture(model_name="s2t_berard", - arch_name="s2t_berard_512_5_3") +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_5_3") def berard_512_5_3(args): args.num_blstm_layers = getattr(args, "num_blstm_layers", 5) args.lstm_size = getattr(args, "lstm_size", 512) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 3492f691f7..8e48964f79 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -6,14 +6,22 @@ import torch import torch.nn 
as nn -from fairseq import utils, checkpoint_utils -from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, - register_model, register_model_architecture) -from fairseq.models.fairseq_encoder import EncoderOut +from fairseq import checkpoint_utils, utils from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import Embedding, TransformerDecoder -from fairseq.modules import (PositionalEmbedding, TransformerEncoderLayer, - FairseqDropout, LayerNorm) +from fairseq.modules import ( + FairseqDropout, + LayerNorm, + PositionalEmbedding, + TransformerEncoderLayer, +) from torch import Tensor @@ -31,15 +39,23 @@ class Conv1dSubsampler(nn.Module): out_channels (int): the number of output channels kernel_sizes (List[int]): the kernel size for each convolutional layer """ - def __init__(self, in_channels: int, mid_channels: int, out_channels: int, - kernel_sizes: List[int] = (3, 3)): + + def __init__( + self, + in_channels: int, + mid_channels: int, + out_channels: int, + kernel_sizes: List[int] = (3, 3), + ): super(Conv1dSubsampler, self).__init__() self.n_layers = len(kernel_sizes) self.conv_layers = nn.ModuleList( nn.Conv1d( in_channels if i == 0 else mid_channels // 2, mid_channels if i < self.n_layers - 1 else out_channels * 2, - k, stride=2, padding=k // 2 + k, + stride=2, + padding=k // 2, ) for i, k in enumerate(kernel_sizes) ) @@ -76,48 +92,109 @@ def __init__(self, encoder, decoder): def add_args(parser): """Add model-specific arguments to the parser.""" # input - parser.add_argument("--conv-kernel-sizes", type=str, metavar="N", - help="kernel sizes of Conv1d subsampling layers") - parser.add_argument("--conv-channels", type=int, metavar="N", - help="# of channels in Conv1d subsampling layers") + parser.add_argument( + 
"--conv-kernel-sizes", + type=str, + metavar="N", + help="kernel sizes of Conv1d subsampling layers", + ) + parser.add_argument( + "--conv-channels", + type=int, + metavar="N", + help="# of channels in Conv1d subsampling layers", + ) # Transformer - parser.add_argument("--activation-fn", type=str, default='relu', - choices=utils.get_available_activation_fns(), - help="activation function to use") - parser.add_argument("--dropout", type=float, metavar="D", - help="dropout probability") - parser.add_argument("--attention-dropout", type=float, metavar="D", - help="dropout probability for attention weights") - parser.add_argument("--activation-dropout", "--relu-dropout", - type=float, metavar="D", - help="dropout probability after activation in FFN.") - parser.add_argument("--encoder-embed-dim", type=int, metavar="N", - help="encoder embedding dimension") - parser.add_argument("--encoder-ffn-embed-dim", type=int, metavar="N", - help="encoder embedding dimension for FFN") - parser.add_argument("--encoder-layers", type=int, metavar="N", - help="num encoder layers") - parser.add_argument("--encoder-attention-heads", type=int, metavar="N", - help="num encoder attention heads") - parser.add_argument("--encoder-normalize-before", action="store_true", - help="apply layernorm before each encoder block") - parser.add_argument("--decoder-embed-dim", type=int, metavar="N", - help="decoder embedding dimension") - parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N", - help="decoder embedding dimension for FFN") - parser.add_argument("--decoder-layers", type=int, metavar="N", - help="num decoder layers") - parser.add_argument("--decoder-attention-heads", type=int, metavar="N", - help="num decoder attention heads") - parser.add_argument("--decoder-normalize-before", action="store_true", - help="apply layernorm before each decoder block") - parser.add_argument("--layernorm-embedding", action="store_true", - help="add layernorm to embedding") - 
parser.add_argument("--no-scale-embedding", action="store_true", - help="if True, dont scale embeddings") parser.add_argument( - "--load-pretrained-encoder-from", type=str, metavar="STR", - help="model to take encoder weights from (for initialization)" + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + 
parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", ) @classmethod @@ -127,14 +204,15 @@ def build_encoder(cls, args): encoder = checkpoint_utils.load_pretrained_component_from_model( component=encoder, checkpoint=args.load_pretrained_encoder_from ) - logger.info(f'loaded pretrained encoder from: ' - f'{args.load_pretrained_encoder_from}') + logger.info( + f"loaded pretrained encoder from: " + f"{args.load_pretrained_encoder_from}" + ) return encoder @classmethod def build_decoder(cls, args, task, embed_tokens): - return TransformerDecoderScriptable(args, task.target_dictionary, - embed_tokens) + return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens) @classmethod def build_model(cls, args, task): @@ -148,8 +226,9 @@ def build_embedding(dictionary, embed_dim): padding_idx = dictionary.pad() return Embedding(num_embeddings, embed_dim, padding_idx) - decoder_embed_tokens = build_embedding(task.target_dictionary, - args.decoder_embed_dim) + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) encoder = cls.build_encoder(args) decoder = cls.build_decoder(args, task, decoder_embed_tokens) return cls(encoder, decoder) @@ -161,8 +240,7 @@ def get_normalized_probs( sample: Optional[Dict[str, Tensor]] = None, ): # net_output['encoder_out'] is a (B, T, D) tensor - lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, - sample) + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) lprobs.batch_first = True return lprobs @@ -172,10 +250,10 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): argument in its input, which is not supported 
in torchscript. This method overrites the forward method definition without **kwargs. """ - encoder_out = self.encoder(src_tokens=src_tokens, - src_lengths=src_lengths) - decoder_out = self.decoder(prev_output_tokens=prev_output_tokens, - encoder_out=encoder_out) + encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) return decoder_out @@ -196,13 +274,13 @@ def __init__(self, args): self.subsample = Conv1dSubsampler( args.input_feat_per_channel * args.input_channels, - args.conv_channels, args.encoder_embed_dim, - [int(k) for k in args.conv_kernel_sizes.split(',')] + args.conv_channels, + args.encoder_embed_dim, + [int(k) for k in args.conv_kernel_sizes.split(",")], ) self.embed_positions = PositionalEmbedding( - args.max_source_positions, args.encoder_embed_dim, - self.padding_idx + args.max_source_positions, args.encoder_embed_dim, self.padding_idx ) self.transformer_layers = nn.ModuleList( @@ -232,9 +310,12 @@ def forward(self, src_tokens, src_lengths): x = self.layer_norm(x) return EncoderOut( - encoder_out=x, encoder_padding_mask=encoder_padding_mask, - encoder_embedding=None, encoder_states=None, src_tokens=None, - src_lengths=None + encoder_out=x, + encoder_padding_mask=encoder_padding_mask, + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=None, ) @torch.jit.export @@ -245,8 +326,7 @@ def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): variables for Torchscript Optional refinement """ - encoder_padding_mask: Optional[Tensor] = \ - encoder_out.encoder_padding_mask + encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = ( @@ -294,40 +374,40 @@ def extract_features( ): # call scriptable method from parent class x, _ = self.extract_features_scriptable( - prev_output_tokens, encoder_out, 
incremental_state, - full_context_alignment, alignment_layer, alignment_heads, + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, ) return x, None -@register_model_architecture(model_name="s2t_transformer", - arch_name="s2t_transformer") +@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer") def base_architecture(args): # Convolutional subsampler - args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", '5,5') + args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") args.conv_channels = getattr(args, "conv_channels", 1024) # Transformer args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) args.encoder_layers = getattr(args, "encoder_layers", 12) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", - True) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", - args.encoder_embed_dim) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", - args.encoder_ffn_embed_dim) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) args.decoder_layers = getattr(args, "decoder_layers", 6) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) - args.decoder_normalize_before = getattr(args, "decoder_normalize_before", - True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) args.dropout = getattr(args, "dropout", 0.1) args.attention_dropout = getattr(args, "attention_dropout", args.dropout) args.activation_dropout = 
getattr(args, "activation_dropout", args.dropout) args.activation_fn = getattr(args, "activation_fn", "relu") - args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", - None) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.share_decoder_input_output_embed = getattr( args, "share_decoder_input_output_embed", False @@ -337,10 +417,10 @@ def base_architecture(args): ) args.adaptive_input = getattr(args, "adaptive_input", False) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) - args.decoder_output_dim = getattr(args, "decoder_output_dim", - args.decoder_embed_dim) - args.decoder_input_dim = getattr(args, "decoder_input_dim", - args.decoder_embed_dim) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) @@ -380,8 +460,7 @@ def s2t_transformer_mp(args): @register_model_architecture("s2t_transformer", "s2t_transformer_l") def s2t_transformer_l(args): args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", - 1024 * 4) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) args.dropout = getattr(args, "dropout", 0.2) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index ca1c6aaf5c..fbb7ce2338 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -30,6 +30,7 @@ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from torch import Tensor + 
DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -308,7 +309,9 @@ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim @@ -543,7 +546,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) self.decoder_layerdrop = args.decoder_layerdrop self.share_input_output_embed = args.share_decoder_input_output_embed diff --git a/fairseq/models/transformer_align.py b/fairseq/models/transformer_align.py index c80cc4341c..eaf585bd10 100644 --- a/fairseq/models/transformer_align.py +++ b/fairseq/models/transformer_align.py @@ -5,9 +5,9 @@ from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer import ( + TransformerModel, base_architecture, transformer_wmt_en_de_big, - TransformerModel, ) diff --git a/fairseq/models/transformer_from_pretrained_xlm.py b/fairseq/models/transformer_from_pretrained_xlm.py index bd03c8450f..236d9942e1 100644 --- a/fairseq/models/transformer_from_pretrained_xlm.py +++ b/fairseq/models/transformer_from_pretrained_xlm.py @@ -19,7 +19,6 @@ @register_model("transformer_from_pretrained_xlm") class TransformerFromPretrainedXLMModel(TransformerModel): - @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" @@ -96,25 +95,24 @@ def upgrade_state_dict_with_xlm_weights( for search_key in ["embed_tokens", "embed_positions", 
"layers"]: if search_key in key: - subkey = key[key.find(search_key):] + subkey = key[key.find(search_key) :] assert subkey in state_dict, ( "{} Transformer encoder / decoder " "state_dict does not contain {}. Cannot " "load {} from pretrained XLM checkpoint " "{} into Transformer.".format( - str(state_dict.keys()), - subkey, key, pretrained_xlm_checkpoint) + str(state_dict.keys()), subkey, key, pretrained_xlm_checkpoint ) + ) state_dict[subkey] = xlm_state_dict[key] return state_dict class TransformerEncoderFromPretrainedXLM(TransformerEncoder): - def __init__(self, args, dictionary, embed_tokens): super().__init__(args, dictionary, embed_tokens) - if getattr(args, 'init_decoder_only', False): + if getattr(args, "init_decoder_only", False): # Don't load XLM weights for encoder if --init-decoder-only return @@ -130,10 +128,9 @@ def __init__(self, args, dictionary, embed_tokens): class TransformerDecoderFromPretrainedXLM(TransformerDecoder): - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): super().__init__(args, dictionary, embed_tokens, no_encoder_attn) - if getattr(args, 'init_encoder_only', False): + if getattr(args, "init_encoder_only", False): # Don't load XLM weights for decoder if --init-encoder-only return assert hasattr(args, "pretrained_xlm_checkpoint"), ( diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py index 905df824f3..772995b526 100644 --- a/fairseq/models/wav2vec/wav2vec.py +++ b/fairseq/models/wav2vec/wav2vec.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq.models import BaseFairseqModel, register_model, register_model_architecture from fairseq.modules import ( Fp32GroupNorm, @@ -21,6 +20,7 @@ ) from fairseq.utils import buffered_arange + logger = logging.getLogger(__name__) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 4f1ab2277f..6a0f787601 100644 --- a/fairseq/models/wav2vec/wav2vec2.py 
+++ b/fairseq/models/wav2vec/wav2vec2.py @@ -5,14 +5,12 @@ import logging import math -import numpy as np +from typing import List, Tuple +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -from typing import List, Tuple - from fairseq import utils from fairseq.data.data_utils import compute_mask_indices from fairseq.models import BaseFairseqModel, register_model, register_model_architecture diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index e47e1f7009..52ca9a8007 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -6,19 +6,17 @@ import contextlib import copy import math -import numpy as np +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import checkpoint_utils, tasks, utils - from fairseq.models import ( + BaseFairseqModel, FairseqEncoder, - FairseqIncrementalDecoder, FairseqEncoderDecoderModel, - BaseFairseqModel, + FairseqIncrementalDecoder, register_model, register_model_architecture, ) @@ -335,7 +333,9 @@ def __init__(self, args, tgt_dict=None): state = None w2v_args = args.w2v_args - assert args.normalize == w2v_args.normalize, 'Fine-tuning works best when data normalization is the same' + assert ( + args.normalize == w2v_args.normalize + ), "Fine-tuning works best when data normalization is the same" w2v_args.data = args.data task = tasks.setup_task(w2v_args) @@ -358,7 +358,7 @@ def __init__(self, args, tgt_dict=None): if tgt_dict is not None: self.proj = Linear(d, len(tgt_dict)) - elif getattr(args, 'decoder_embed_dim', d) != d: + elif getattr(args, "decoder_embed_dim", d) != d: self.proj = Linear(d, args.decoder_embed_dim) else: self.proj = None @@ -668,6 +668,8 @@ def seq2seq_architecture(args): args.decoder_dropout = getattr(args, "decoder_dropout", 0) args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0) args.decoder_activation_dropout = 
getattr(args, "decoder_activation_dropout", 0) - args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) base_architecture(args) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 52432e0de4..e2326ac6e3 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -37,40 +37,40 @@ from .vggblock import VGGBlock __all__ = [ - 'AdaptiveInput', - 'AdaptiveSoftmax', - 'BeamableMM', - 'CharacterTokenEmbedder', - 'ConvTBC', - 'cross_entropy', - 'DownsampledMultiHeadAttention', - 'DynamicConv1dTBC', - 'DynamicConv', - 'DynamicCRF', - 'FairseqDropout', - 'Fp32GroupNorm', - 'Fp32LayerNorm', - 'gelu', - 'gelu_accurate', - 'GradMultiply', - 'GumbelVectorQuantizer', - 'KmeansVectorQuantizer', - 'LayerDropModuleList', - 'LayerNorm', - 'LearnedPositionalEmbedding', - 'LightweightConv1dTBC', - 'LightweightConv', - 'LinearizedConvolution', - 'MultiheadAttention', - 'PositionalEmbedding', - 'SamePad', - 'ScalarBias', - 'SinusoidalPositionalEmbedding', - 'TransformerSentenceEncoderLayer', - 'TransformerSentenceEncoder', - 'TransformerDecoderLayer', - 'TransformerEncoderLayer', - 'TransposeLast', - 'VGGBlock', - 'unfold1d', + "AdaptiveInput", + "AdaptiveSoftmax", + "BeamableMM", + "CharacterTokenEmbedder", + "ConvTBC", + "cross_entropy", + "DownsampledMultiHeadAttention", + "DynamicConv1dTBC", + "DynamicConv", + "DynamicCRF", + "FairseqDropout", + "Fp32GroupNorm", + "Fp32LayerNorm", + "gelu", + "gelu_accurate", + "GradMultiply", + "GumbelVectorQuantizer", + "KmeansVectorQuantizer", + "LayerDropModuleList", + "LayerNorm", + "LearnedPositionalEmbedding", + "LightweightConv1dTBC", + "LightweightConv", + "LinearizedConvolution", + "MultiheadAttention", + "PositionalEmbedding", + "SamePad", + "ScalarBias", + "SinusoidalPositionalEmbedding", + "TransformerSentenceEncoderLayer", + 
"TransformerSentenceEncoder", + "TransformerDecoderLayer", + "TransformerEncoderLayer", + "TransposeLast", + "VGGBlock", + "unfold1d", ] diff --git a/fairseq/modules/adaptive_input.py b/fairseq/modules/adaptive_input.py index 4cfe8fca66..446534a9f8 100644 --- a/fairseq/modules/adaptive_input.py +++ b/fairseq/modules/adaptive_input.py @@ -4,15 +4,14 @@ # LICENSE file in the root directory of this source tree. +from typing import List + import torch -from torch import nn from fairseq.modules.quant_noise import quant_noise - -from typing import List +from torch import nn class AdaptiveInput(nn.Module): - def __init__( self, vocab_size: int, @@ -29,8 +28,9 @@ def __init__( if vocab_size > cutoff[-1]: cutoff = cutoff + [vocab_size] else: - assert vocab_size == cutoff[ - -1], 'cannot specify cutoff larger than vocab size' + assert ( + vocab_size == cutoff[-1] + ), "cannot specify cutoff larger than vocab size" self.cutoff = cutoff self.embedding_dim = output_dim @@ -43,7 +43,9 @@ def __init__( dim = int(initial_dim // (factor ** i)) seq = nn.Sequential( nn.Embedding(size, dim, self.padding_idx), - quant_noise(nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size), + quant_noise( + nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size + ), ) self.embeddings.append(seq) @@ -54,12 +56,12 @@ def init_weights(m): if isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5) nn.init.constant_(m.weight[padding_idx], 0) - elif hasattr(m, 'weight'): + elif hasattr(m, "weight"): nn.init.xavier_uniform_(m.weight) self.apply(init_weights) - self.register_buffer('_float_tensor', torch.FloatTensor(1)) + self.register_buffer("_float_tensor", torch.FloatTensor(1)) def weights_for_band(self, band: int): return self.embeddings[band][0].weight, self.embeddings[band][1].weight diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py index 8e47134a70..ae0c77ba0f 100644 --- a/fairseq/modules/adaptive_softmax.py 
+++ b/fairseq/modules/adaptive_softmax.py @@ -3,13 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import operator import functools +import operator import torch import torch.nn.functional as F -from fairseq.modules.quant_noise import quant_noise from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise from torch import nn @@ -29,23 +29,29 @@ def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size): tied_emb, _ = weights self.num_words, emb_dim = tied_emb.size() - self.word_proj = quant_noise(TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size) + self.word_proj = quant_noise( + TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size + ) if input_dim != emb_dim: self.word_proj = nn.Sequential( - quant_noise(nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size), + quant_noise( + nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size + ), self.word_proj, ) - self.class_proj = quant_noise(nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size) + self.class_proj = quant_noise( + nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size + ) self.out_dim = self.num_words + num_classes - self.register_buffer('_float_tensor', torch.FloatTensor(1)) + self.register_buffer("_float_tensor", torch.FloatTensor(1)) def forward(self, input): inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1) out = self._float_tensor.new(inp_sz, self.out_dim) - out[:, :self.num_words] = self.word_proj(input.view(inp_sz, -1)) - out[:, self.num_words:] = self.class_proj(input.view(inp_sz, -1)) + out[:, : self.num_words] = self.word_proj(input.view(inp_sz, -1)) + out[:, self.num_words :] = self.class_proj(input.view(inp_sz, -1)) return out @@ -56,21 +62,34 @@ class AdaptiveSoftmax(nn.Module): approximation for GPUs" (http://arxiv.org/abs/1609.04309). 
""" - def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_inputs=None, tie_proj=False, - q_noise=0, qn_block_size=8): + def __init__( + self, + vocab_size, + input_dim, + cutoff, + dropout, + factor=4.0, + adaptive_inputs=None, + tie_proj=False, + q_noise=0, + qn_block_size=8, + ): super().__init__() if vocab_size > cutoff[-1]: cutoff = cutoff + [vocab_size] else: - assert vocab_size == cutoff[ - -1], 'cannot specify cutoff larger than vocab size' + assert ( + vocab_size == cutoff[-1] + ), "cannot specify cutoff larger than vocab size" output_dim = cutoff[0] + len(cutoff) - 1 self.vocab_size = vocab_size self.cutoff = cutoff - self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.input_dim = input_dim self.factor = factor self.q_noise = q_noise @@ -79,38 +98,69 @@ def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_i self.lsm = nn.LogSoftmax(dim=1) if adaptive_inputs is not None: - self.head = TiedHeadModule(adaptive_inputs.weights_for_band(0), input_dim, len(cutoff) - 1, self.q_noise, self.qn_block_size) + self.head = TiedHeadModule( + adaptive_inputs.weights_for_band(0), + input_dim, + len(cutoff) - 1, + self.q_noise, + self.qn_block_size, + ) else: - self.head = quant_noise(nn.Linear(input_dim, output_dim, bias=False), self.q_noise, self.qn_block_size) + self.head = quant_noise( + nn.Linear(input_dim, output_dim, bias=False), + self.q_noise, + self.qn_block_size, + ) self._make_tail(adaptive_inputs, tie_proj) def init_weights(m): - if hasattr(m, 'weight') and not isinstance(m, TiedLinear) and not isinstance(m, TiedHeadModule): + if ( + hasattr(m, "weight") + and not isinstance(m, TiedLinear) + and not isinstance(m, TiedHeadModule) + ): nn.init.xavier_uniform_(m.weight) self.apply(init_weights) - self.register_buffer('version', torch.LongTensor([1])) + self.register_buffer("version", 
torch.LongTensor([1])) def _make_tail(self, adaptive_inputs=None, tie_proj=False): self.tail = nn.ModuleList() for i in range(len(self.cutoff) - 1): dim = int(self.input_dim // self.factor ** (i + 1)) - tied_emb, tied_proj = adaptive_inputs.weights_for_band(i + 1) \ - if adaptive_inputs is not None else (None, None) + tied_emb, tied_proj = ( + adaptive_inputs.weights_for_band(i + 1) + if adaptive_inputs is not None + else (None, None) + ) if tied_proj is not None: if tie_proj: - proj = quant_noise(TiedLinear(tied_proj, transpose=True), self.q_noise, self.qn_block_size) + proj = quant_noise( + TiedLinear(tied_proj, transpose=True), + self.q_noise, + self.qn_block_size, + ) else: - proj = quant_noise(nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False), self.q_noise, self.qn_block_size) + proj = quant_noise( + nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False), + self.q_noise, + self.qn_block_size, + ) else: - proj = quant_noise(nn.Linear(self.input_dim, dim, bias=False), self.q_noise, self.qn_block_size) + proj = quant_noise( + nn.Linear(self.input_dim, dim, bias=False), + self.q_noise, + self.qn_block_size, + ) if tied_emb is None: - out_proj = nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) + out_proj = nn.Linear( + dim, self.cutoff[i + 1] - self.cutoff[i], bias=False + ) else: out_proj = TiedLinear(tied_emb, transpose=False) @@ -123,9 +173,9 @@ def _make_tail(self, adaptive_inputs=None, tie_proj=False): self.tail.append(m) def upgrade_state_dict_named(self, state_dict, name): - version_name = name + '.version' + version_name = name + ".version" if version_name not in state_dict: - raise Exception('This version of the model is no longer supported') + raise Exception("This version of the model is no longer supported") def adapt_target(self, target): """ @@ -194,7 +244,7 @@ def get_log_prob(self, input, target): head_sz = self.cutoff[0] + len(self.tail) log_probs[:, :head_sz] = self.lsm(head_y) - tail_priors = log_probs[:, 
self.cutoff[0]: head_sz].clone() + tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone() for i in range(len(self.tail)): start = self.cutoff[i] @@ -203,12 +253,16 @@ def get_log_prob(self, input, target): if target_idxs is None: tail_out = log_probs[:, start:end] tail_out.copy_(self.tail[i](input)) - log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None]) + log_probs[:, start:end] = self.lsm(tail_out).add_( + tail_priors[:, i, None] + ) elif target_idxs[i] is not None: idxs = target_idxs[i] tail_out = log_probs[idxs, start:end] tail_out.copy_(self.tail[i](input[idxs])) - log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None]) + log_probs[idxs, start:end] = self.lsm(tail_out).add_( + tail_priors[idxs, i, None] + ) log_probs = log_probs.view(bsz, length, -1) return log_probs diff --git a/fairseq/modules/beamable_mm.py b/fairseq/modules/beamable_mm.py index df77105a94..eff1a4607f 100644 --- a/fairseq/modules/beamable_mm.py +++ b/fairseq/modules/beamable_mm.py @@ -15,16 +15,18 @@ class BeamableMM(nn.Module): inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. 
""" + def __init__(self, beam_size=None): super(BeamableMM, self).__init__() self.beam_size = beam_size def forward(self, input1, input2): if ( - not self.training and # test mode - self.beam_size is not None and # beam size is set - input1.dim() == 3 and # only support batched input - input1.size(1) == 1 # single time step update + not self.training + and self.beam_size is not None # test mode + and input1.dim() == 3 # beam size is set + and input1.size(1) # only support batched input + == 1 # single time step update ): bsz, beam = input1.size(0), self.beam_size diff --git a/fairseq/modules/character_token_embedder.py b/fairseq/modules/character_token_embedder.py index 3abdaf4f28..181221b61b 100644 --- a/fairseq/modules/character_token_embedder.py +++ b/fairseq/modules/character_token_embedder.py @@ -7,10 +7,10 @@ from typing import List, Tuple import torch -from torch import nn import torch.nn.functional as F - from fairseq.data import Dictionary +from torch import nn + CHAR_PAD_IDX = 0 CHAR_EOS_IDX = 257 @@ -21,14 +21,14 @@ class CharacterTokenEmbedder(torch.nn.Module): def __init__( - self, - vocab: Dictionary, - filters: List[Tuple[int, int]], - char_embed_dim: int, - word_embed_dim: int, - highway_layers: int, - max_char_len: int = 50, - char_inputs: bool = False + self, + vocab: Dictionary, + filters: List[Tuple[int, int]], + char_embed_dim: int, + word_embed_dim: int, + highway_layers: int, + max_char_len: int = 50, + char_inputs: bool = False, ): super(CharacterTokenEmbedder, self).__init__() @@ -52,7 +52,9 @@ def __init__( self.projection = nn.Linear(last_dim, word_embed_dim) - assert vocab is not None or char_inputs, "vocab must be set if not using char inputs" + assert ( + vocab is not None or char_inputs + ), "vocab must be set if not using char inputs" self.vocab = None if vocab is not None: self.set_vocab(vocab, max_char_len) @@ -79,7 +81,11 @@ def set_vocab(self, vocab, max_char_len): word_to_char[i] = torch.LongTensor(char_idxs) if truncated > 0: - 
logger.info('truncated {} words longer than {} characters'.format(truncated, max_char_len)) + logger.info( + "truncated {} words longer than {} characters".format( + truncated, max_char_len + ) + ) self.vocab = vocab self.word_to_char = word_to_char @@ -93,12 +99,14 @@ def reset_parameters(self): nn.init.xavier_normal_(self.symbol_embeddings) nn.init.xavier_uniform_(self.projection.weight) - nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.) - nn.init.constant_(self.projection.bias, 0.) + nn.init.constant_( + self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0 + ) + nn.init.constant_(self.projection.bias, 0.0) def forward( - self, - input: torch.Tensor, + self, + input: torch.Tensor, ): if self.char_inputs: chars = input.view(-1, self.max_char_len) @@ -113,7 +121,9 @@ def forward( unk = None else: flat_words = input.view(-1) - chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(input) + chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as( + input + ) pads = flat_words.eq(self.vocab.pad()) eos = flat_words.eq(self.vocab.eos()) unk = flat_words.eq(self.vocab.unk()) @@ -121,11 +131,17 @@ def forward( word_embs = self._convolve(chars) if self.onnx_trace: if pads.any(): - word_embs = torch.where(pads.unsqueeze(1), word_embs.new_zeros(1), word_embs) + word_embs = torch.where( + pads.unsqueeze(1), word_embs.new_zeros(1), word_embs + ) if eos.any(): - word_embs = torch.where(eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs) + word_embs = torch.where( + eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs + ) if unk is not None and unk.any(): - word_embs = torch.where(unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs) + word_embs = torch.where( + unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs + ) else: if pads.any(): word_embs[pads] = 0 @@ -137,8 +153,8 @@ def forward( return word_embs.view(input.size()[:2] + (-1,)) def 
_convolve( - self, - char_idxs: torch.Tensor, + self, + char_idxs: torch.Tensor, ): char_embs = self.char_embeddings(char_idxs) char_embs = char_embs.transpose(1, 2) # BTC -> BCT @@ -166,15 +182,12 @@ class Highway(torch.nn.Module): Adopted from the AllenNLP implementation. """ - def __init__( - self, - input_dim: int, - num_layers: int = 1 - ): + def __init__(self, input_dim: int, num_layers: int = 1): super(Highway, self).__init__() self.input_dim = input_dim - self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) - for _ in range(num_layers)]) + self.layers = nn.ModuleList( + [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)] + ) self.activation = nn.ReLU() self.reset_parameters() @@ -186,15 +199,12 @@ def reset_parameters(self): # setting the bias on `B(x)` to be positive, because that means `g` will be biased to # be high, so we will carry the input forward. The bias on `B(x)` is the second half # of the bias vector in each Linear layer. - nn.init.constant_(layer.bias[self.input_dim:], 1) + nn.init.constant_(layer.bias[self.input_dim :], 1) - nn.init.constant_(layer.bias[:self.input_dim], 0) + nn.init.constant_(layer.bias[: self.input_dim], 0) nn.init.xavier_normal_(layer.weight) - def forward( - self, - x: torch.Tensor - ): + def forward(self, x: torch.Tensor): for layer in self.layers: projection = layer(x) proj_x, gate = projection.chunk(2, dim=-1) diff --git a/fairseq/modules/conv_tbc.py b/fairseq/modules/conv_tbc.py index 1aa3eff9dc..2dc46c4b9b 100644 --- a/fairseq/modules/conv_tbc.py +++ b/fairseq/modules/conv_tbc.py @@ -13,6 +13,7 @@ class ConvTBC(torch.nn.Module): The implementation uses gemm to perform the convolution. This implementation is faster than cuDNN for small kernel sizes. 
""" + def __init__(self, in_channels, out_channels, kernel_size, padding=0): super(ConvTBC, self).__init__() self.in_channels = in_channels @@ -20,17 +21,22 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0): self.kernel_size = _single(kernel_size) self.padding = _single(padding) - self.weight = torch.nn.Parameter(torch.Tensor( - self.kernel_size[0], in_channels, out_channels)) + self.weight = torch.nn.Parameter( + torch.Tensor(self.kernel_size[0], in_channels, out_channels) + ) self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) def forward(self, input): - return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) + return torch.conv_tbc( + input.contiguous(), self.weight, self.bias, self.padding[0] + ) def __repr__(self): - s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' - ', padding={padding}') + s = ( + "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}" + ", padding={padding}" + ) if self.bias is None: - s += ', bias=False' - s += ')' + s += ", bias=False" + s += ")" return s.format(name=self.__class__.__name__, **self.__dict__) diff --git a/fairseq/modules/cross_entropy.py b/fairseq/modules/cross_entropy.py index b46143f3af..0d2beb44bb 100644 --- a/fairseq/modules/cross_entropy.py +++ b/fairseq/modules/cross_entropy.py @@ -12,10 +12,13 @@ logger = logging.getLogger(__name__) -def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction='mean'): +def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"): lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) return F.nll_loss( - lprobs, target, ignore_index=ignore_index, reduction=reduction, + lprobs, + target, + ignore_index=ignore_index, + reduction=reduction, ) @@ -23,29 +26,34 @@ def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction='mean'): import xentropy_cuda from apex.contrib import xentropy - logger.info('using fused cross entropy') + logger.info("using 
fused cross entropy") - def cross_entropy(logits, target, ignore_index=-100, reduction='mean'): - if logits.device == torch.device('cpu'): + def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): + if logits.device == torch.device("cpu"): return _cross_entropy_pytorch(logits, target, ignore_index, reduction) else: - half_to_float = (logits.dtype == torch.half) + half_to_float = logits.dtype == torch.half losses = xentropy.SoftmaxCrossEntropyLoss.apply( - logits, target, 0.0, ignore_index, half_to_float, + logits, + target, + 0.0, + ignore_index, + half_to_float, ) - if reduction == 'sum': + if reduction == "sum": return losses.sum() - elif reduction == 'mean': + elif reduction == "mean": if ignore_index >= 0: return losses.sum() / target.ne(ignore_index).sum() else: return losses.mean() - elif reduction == 'none': + elif reduction == "none": return losses else: raise NotImplementedError + except ImportError: - def cross_entropy(logits, target, ignore_index=-100, reduction='mean'): + def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): return _cross_entropy_pytorch(logits, target, ignore_index, reduction) diff --git a/fairseq/modules/downsampled_multihead_attention.py b/fairseq/modules/downsampled_multihead_attention.py index eeaf9bbdd3..2cdece3f7f 100644 --- a/fairseq/modules/downsampled_multihead_attention.py +++ b/fairseq/modules/downsampled_multihead_attention.py @@ -9,22 +9,33 @@ import torch import torch.nn as nn import torch.nn.functional as F -from fairseq.modules.scalar_bias import scalar_bias from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.scalar_bias import scalar_bias class SingleHeadAttention(nn.Module): """ Single-head attention that supports Gating and Downsampling """ + def __init__( - self, out_channels, embed_dim, head_dim, head_index, dropout=0., - bias=True, project_input=True, gated=False, downsample=False, + self, + out_channels, + embed_dim, + head_dim, + head_index, + 
dropout=0.0, + bias=True, + project_input=True, + gated=False, + downsample=False, num_heads=1, ): super().__init__() self.embed_dim = embed_dim - self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.head_index = head_index self.head_dim = head_dim self.project_input = project_input @@ -58,11 +69,16 @@ def __init__( else: self.out_proj = Linear(out_proj_size, out_channels, bias=bias) - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 def forward( - self, query, key, value, mask_future_timesteps=False, - key_padding_mask=None, use_scalar_bias=False, + self, + query, + key, + value, + mask_future_timesteps=False, + key_padding_mask=None, + use_scalar_bias=False, ): """Input shape: Time x Batch x Channel Self-attention can be implemented by passing in the same arguments for @@ -106,16 +122,17 @@ def forward( attn_weights = torch.bmm(q, k.transpose(1, 2)) if mask_future_timesteps: - assert query.size() == key.size(), \ - 'mask_future_timesteps only applies to self-attention' + assert ( + query.size() == key.size() + ), "mask_future_timesteps only applies to self-attention" attn_weights *= torch.tril( attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(), diagonal=-1, - )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0) + )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) attn_weights += torch.triu( attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(), - diagonal=0 - )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0) + diagonal=0, + )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) tgt_size = tgt_len if use_scalar_bias: attn_weights = scalar_bias(attn_weights, 2) @@ -128,7 +145,9 @@ def forward( if self.downsample: attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len) else: - attn_weights = attn_weights.view(size, self.num_heads, 
tgt_len, src_len) + attn_weights = attn_weights.view( + size, self.num_heads, tgt_len, src_len + ) attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), -math.inf, @@ -152,9 +171,17 @@ class DownsampledMultiHeadAttention(nn.ModuleList): """ Multi-headed attention with Gating and Downsampling """ + def __init__( - self, out_channels, embed_dim, num_heads, dropout=0., bias=True, - project_input=True, gated=False, downsample=False, + self, + out_channels, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + project_input=True, + gated=False, + downsample=False, ): self.embed_dim = embed_dim self.num_heads = num_heads @@ -169,9 +196,16 @@ def __init__( for index in range(self.num_heads): attention_heads.append( SingleHeadAttention( - out_channels, self.embed_dim, self.head_dim, index, - dropout, bias, self.project_input, self.gated, - self.downsample, self.num_heads, + out_channels, + self.embed_dim, + self.head_dim, + index, + dropout, + bias, + self.project_input, + self.gated, + self.downsample, + self.num_heads, ) ) super().__init__(modules=attention_heads) @@ -181,13 +215,26 @@ def __init__( # if not being downsampled, we can do the heads with one linear layer instead of separate ones super().__init__() self.attention_module = SingleHeadAttention( - out_channels, self.embed_dim, self.head_dim, 1, dropout, - bias, self.project_input, self.gated, self.downsample, self.num_heads, + out_channels, + self.embed_dim, + self.head_dim, + 1, + dropout, + bias, + self.project_input, + self.gated, + self.downsample, + self.num_heads, ) def forward( - self, query, key, value, mask_future_timesteps=False, - key_padding_mask=None, use_scalar_bias=False, + self, + query, + key, + value, + mask_future_timesteps=False, + key_padding_mask=None, + use_scalar_bias=False, ): src_len, bsz, embed_dim = key.size() tgt_len = query.size(0) @@ -205,7 +252,12 @@ def forward( for attention_head_number in range(self.num_heads): # call the forward of each 
attention head _attn, _attn_weight = self[attention_head_number]( - query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias, + query, + key, + value, + mask_future_timesteps, + key_padding_mask, + use_scalar_bias, ) attn.append(_attn) attn_weights.append(_attn_weight) @@ -214,13 +266,20 @@ def forward( return full_attn, attn_weights[0].clone() else: _attn, _attn_weight = self.attention_module( - query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias, + query, + key, + value, + mask_future_timesteps, + key_padding_mask, + use_scalar_bias, ) attn.append(_attn) attn_weights.append(_attn_weight) full_attn = torch.cat(attn, dim=2) full_attn_weights = torch.cat(attn_weights) - full_attn_weights = full_attn_weights.view(bsz, self.num_heads, tgt_size, src_len) + full_attn_weights = full_attn_weights.view( + bsz, self.num_heads, tgt_size, src_len + ) full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads return full_attn, full_attn_weights @@ -229,15 +288,16 @@ class Downsample(nn.Module): """ Selects every nth element, where n is the index """ + def __init__(self, index): super().__init__() self.index = index def forward(self, x): - return x[::self.index+1] + return x[:: self.index + 1] -def Linear(in_features, out_features, dropout=0., bias=True): +def Linear(in_features, out_features, dropout=0.0, bias=True): """Weight-normalized Linear layer (input: B x T x C)""" m = nn.Linear(in_features, out_features, bias=bias) m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) @@ -245,12 +305,12 @@ def Linear(in_features, out_features, dropout=0., bias=True): return nn.utils.weight_norm(m) -def GatedLinear(in_features, out_features, dropout=0., bias=True): +def GatedLinear(in_features, out_features, dropout=0.0, bias=True): """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units""" return nn.Sequential( - Linear(in_features, out_features*4, dropout, bias), + Linear(in_features, 
out_features * 4, dropout, bias), nn.GLU(), - Linear(out_features*2, out_features*2, dropout, bias), + Linear(out_features * 2, out_features * 2, dropout, bias), nn.GLU(), - Linear(out_features, out_features, dropout, bias) + Linear(out_features, out_features, dropout, bias), ) diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py index 5a8ecb99a8..5999a04539 100644 --- a/fairseq/modules/dynamic_convolution.py +++ b/fairseq/modules/dynamic_convolution.py @@ -6,43 +6,63 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils -from .unfold import unfold1d from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules.fairseq_dropout import FairseqDropout +from .unfold import unfold1d + -def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1, - weight_dropout=0., weight_softmax=False, - renorm_padding=False, bias=False, conv_bias=False, - query_size=None, in_proj=False): +def DynamicConv( + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + renorm_padding=False, + bias=False, + conv_bias=False, + query_size=None, + in_proj=False, +): if torch.cuda.is_available(): try: from fairseq.modules.dynamicconv_layer import DynamicconvLayer - return DynamicconvLayer(input_size, kernel_size=kernel_size, - padding_l=padding_l, num_heads=num_heads, - weight_dropout=weight_dropout, - weight_softmax=weight_softmax, bias=bias) + + return DynamicconvLayer( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) except ImportError as e: print(e) - return DynamicConv1dTBC(input_size, kernel_size=kernel_size, - padding_l=padding_l, num_heads=num_heads, - weight_dropout=weight_dropout, - weight_softmax=weight_softmax, bias=bias) + return DynamicConv1dTBC( + input_size, + kernel_size=kernel_size, + 
padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) def Linear(in_features, out_features, bias=True): m = nn.Linear(in_features, out_features, bias) nn.init.xavier_uniform_(m.weight) if bias: - nn.init.constant_(m.bias, 0.) + nn.init.constant_(m.bias, 0.0) return m @with_incremental_state class DynamicConv1dTBC(nn.Module): - '''Dynamic lightweight convolution taking T x B x C inputs + """Dynamic lightweight convolution taking T x B x C inputs Args: input_size: # of channels of the input kernel_size: convolution channels @@ -64,25 +84,42 @@ class DynamicConv1dTBC(nn.Module): weight: the learnable weights of the module of shape `(num_heads, 1, kernel_size)` bias: the learnable bias of the module of shape `(input_size)` - ''' - def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, - weight_dropout=0., weight_softmax=False, - renorm_padding=False, bias=False, conv_bias=False, - query_size=None, in_proj=False): + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + renorm_padding=False, + bias=False, + conv_bias=False, + query_size=None, + in_proj=False, + ): super().__init__() self.input_size = input_size self.query_size = input_size if query_size is None else query_size self.kernel_size = kernel_size self.padding_l = padding_l self.num_heads = num_heads - self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) self.weight_softmax = weight_softmax self.renorm_padding = renorm_padding if in_proj: - self.weight_linear = Linear(self.input_size, self.input_size + num_heads * kernel_size * 1) + self.weight_linear = Linear( + self.input_size, self.input_size + num_heads * kernel_size * 1 + ) else: - self.weight_linear = Linear(self.query_size, 
num_heads * kernel_size * 1, bias=bias) + self.weight_linear = Linear( + self.query_size, num_heads * kernel_size * 1, bias=bias + ) if conv_bias: self.conv_bias = nn.Parameter(torch.Tensor(input_size)) else: @@ -91,22 +128,27 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, @property def in_proj(self): - return self.weight_linear.out_features == self.input_size + self.num_heads * self.kernel_size + return ( + self.weight_linear.out_features + == self.input_size + self.num_heads * self.kernel_size + ) def reset_parameters(self): self.weight_linear.reset_parameters() if self.conv_bias is not None: - nn.init.constant_(self.conv_bias, 0.) + nn.init.constant_(self.conv_bias, 0.0) def forward(self, x, incremental_state=None, query=None, unfold=None): - '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C args: x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) incremental_state: A dict to keep the state unfold: unfold the input or not. If not, we use the matrix trick instead query: use the specified query to predict the conv filters - ''' - unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory + """ + unfold = ( + x.size(0) > 512 if unfold is None else unfold + ) # use unfold mode as default for long sequence to save memory unfold = unfold or (incremental_state is not None) assert query is None or not self.in_proj @@ -122,8 +164,8 @@ def forward(self, x, incremental_state=None, query=None, unfold=None): return output def _forward_unfolded(self, x, incremental_state, query): - '''The conventional implementation of convolutions. - Unfolding the input by having a window shifting to the right.''' + """The conventional implementation of convolutions. 
+ Unfolding the input by having a window shifting to the right.""" T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H @@ -132,9 +174,11 @@ def _forward_unfolded(self, x, incremental_state, query): if self.in_proj: proj = self.weight_linear(x) x = proj.narrow(2, 0, self.input_size).contiguous() - weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1) + weight = ( + proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1) + ) else: - weight = self.weight_linear(query).view(T*B*H, -1) + weight = self.weight_linear(query).view(T * B * H, -1) # renorm_padding is only implemented in _forward_expanded assert not self.renorm_padding or incremental_state is not None @@ -145,23 +189,25 @@ def _forward_unfolded(self, x, incremental_state, query): input_buffer = x.new() x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) if self.kernel_size > 1: - self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) - x_unfold = x_unfold.view(T*B*H, R, -1) + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) else: padding_l = self.padding_l - if K > T and padding_l == K-1: - weight = weight.narrow(1, K-T, T) - K, padding_l = T, T-1 + if K > T and padding_l == K - 1: + weight = weight.narrow(1, K - T, T) + K, padding_l = T, T - 1 # unfold the input: T x B x C --> T' x B x C x K x_unfold = unfold1d(x, K, padding_l, 0) - x_unfold = x_unfold.view(T*B*H, R, K) + x_unfold = x_unfold.view(T * B * H, R, K) if self.weight_softmax and not self.renorm_padding: weight = F.softmax(weight, dim=1) weight = weight.narrow(1, 0, K) if incremental_state is not None: - weight = weight[:, -x_unfold.size(2):] + weight = weight[:, -x_unfold.size(2) :] K = weight.size(1) if self.weight_softmax and self.renorm_padding: @@ -174,10 +220,10 @@ def _forward_unfolded(self, x, incremental_state, query): return output def _forward_expanded(self, x, 
incremental_stat, query): - '''Turn the convolution filters into band matrices and do matrix multiplication. + """Turn the convolution filters into band matrices and do matrix multiplication. This is faster when the sequence is short, but less memory efficient. This is not used in the decoder during inference. - ''' + """ T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H @@ -185,22 +231,26 @@ def _forward_expanded(self, x, incremental_stat, query): if self.in_proj: proj = self.weight_linear(x) x = proj.narrow(2, 0, self.input_size).contiguous() - weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1) + weight = ( + proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1) + ) else: - weight = self.weight_linear(query).view(T*B*H, -1) + weight = self.weight_linear(query).view(T * B * H, -1) if not self.renorm_padding: if self.weight_softmax: weight = F.softmax(weight, dim=1) weight = self.weight_dropout_module(weight, inplace=False) weight = weight.narrow(1, 0, K).contiguous() - weight = weight.view(T, B*H, K).transpose(0, 1) + weight = weight.view(T, B * H, K).transpose(0, 1) - x = x.view(T, B*H, R).transpose(0, 1) + x = x.view(T, B * H, R).transpose(0, 1) if self.weight_softmax and self.renorm_padding: # turn the convolution filters into band matrices - weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf')) - weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf")) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) weight_expanded = weight_expanded.narrow(2, self.padding_l, T) # normalize the weight over valid positions like self-attention weight_expanded = F.softmax(weight_expanded, dim=2) @@ -208,12 +258,14 @@ def _forward_expanded(self, x, incremental_stat, query): else: P = self.padding_l # For efficieny, we cut the kernel size and reduce the padding when the kernel 
is larger than the length - if K > T and P == K-1: - weight = weight.narrow(2, K-T, T) - K, P = T, T-1 + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 # turn the convolution filters into band matrices - weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) - weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T output = torch.bmm(weight_expanded, x) output = output.transpose(0, 1).contiguous().view(T, B, C) @@ -226,20 +278,27 @@ def reorder_incremental_state(self, incremental_state, new_order): self._set_input_buffer(incremental_state, input_buffer) def _get_input_buffer(self, incremental_state): - return utils.get_incremental_state(self, incremental_state, 'input_buffer') + return utils.get_incremental_state(self, incremental_state, "input_buffer") def _set_input_buffer(self, incremental_state, new_buffer): - return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) def extra_repr(self): - s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}'.format( - self.input_size, self.kernel_size, self.padding_l, - self.num_heads, self.weight_softmax, self.conv_bias is not None, self.renorm_padding, + s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format( + self.input_size, + self.kernel_size, + self.padding_l, + self.num_heads, + self.weight_softmax, + self.conv_bias is not None, + self.renorm_padding, self.in_proj, ) if self.query_size != self.input_size: - s += ', query_size={}'.format(self.query_size) - if 
self.weight_dropout_module.p > 0.: - s += ', weight_dropout={}'.format(self.weight_dropout_module.p) + s += ", query_size={}".format(self.query_size) + if self.weight_dropout_module.p > 0.0: + s += ", weight_dropout={}".format(self.weight_dropout_module.p) return s diff --git a/fairseq/modules/dynamic_crf_layer.py b/fairseq/modules/dynamic_crf_layer.py index 6f5acf3772..8fcc6b8d26 100644 --- a/fairseq/modules/dynamic_crf_layer.py +++ b/fairseq/modules/dynamic_crf_layer.py @@ -27,16 +27,16 @@ def logsumexp(x, dim=1): class DynamicCRF(nn.Module): """Dynamic CRF layer is used to approximate the traditional - Conditional Random Fields (CRF) - $P(y | x) = 1/Z(x) exp(sum_i s(y_i, x) + sum_i t(y_{i-1}, y_i, x))$ + Conditional Random Fields (CRF) + $P(y | x) = 1/Z(x) exp(sum_i s(y_i, x) + sum_i t(y_{i-1}, y_i, x))$ - where in this function, we assume the emition scores (s) are given, - and the transition score is a |V| x |V| matrix $M$ + where in this function, we assume the emition scores (s) are given, + and the transition score is a |V| x |V| matrix $M$ - in the following two aspects: - (1) it used a low-rank approximation for the transition matrix: - $M = E_1 E_2^T$ - (2) it used a beam to estimate the normalizing factor Z(x) + in the following two aspects: + (1) it used a low-rank approximation for the transition matrix: + $M = E_1 E_2^T$ + (2) it used a beam to estimate the normalizing factor Z(x) """ def __init__(self, num_embedding, low_rank=32, beam_size=64): @@ -51,7 +51,8 @@ def __init__(self, num_embedding, low_rank=32, beam_size=64): def extra_repr(self): return "vocab_size={}, low_rank={}, beam_size={}".format( - self.vocb, self.rank, self.beam) + self.vocb, self.rank, self.beam + ) def forward(self, emissions, targets, masks, beam=None): """ @@ -104,26 +105,27 @@ def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None): beam = beam if beam is not None else self.beam batch_size, seq_len = emissions.size()[:2] if targets is not None: - 
_emissions = emissions.scatter(2, targets[:, :, None], np.float('inf')) + _emissions = emissions.scatter(2, targets[:, :, None], np.float("inf")) beam_targets = _emissions.topk(beam, 2)[1] beam_emission_scores = emissions.gather(2, beam_targets) else: beam_emission_scores, beam_targets = emissions.topk(beam, 2) beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D - beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D + beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D beam_transition_matrix = torch.bmm( beam_transition_score1.view(-1, beam, self.rank), - beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2)) + beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2), + ) beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) # compute the normalizer in the log-space score = beam_emission_scores[:, 0] # B x K for i in range(1, seq_len): - next_score = score[:, :, None] + beam_transition_matrix[:, i-1] + next_score = score[:, :, None] + beam_transition_matrix[:, i - 1] next_score = logsumexp(next_score, dim=1) + beam_emission_scores[:, i] if masks is not None: - score = torch.where(masks[:, i:i+1], next_score, score) + score = torch.where(masks[:, i : i + 1], next_score, score) else: score = next_score @@ -137,10 +139,11 @@ def _viterbi_decode(self, emissions, masks=None, beam=None): batch_size, seq_len = emissions.size()[:2] beam_emission_scores, beam_targets = emissions.topk(beam, 2) beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D - beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D + beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D beam_transition_matrix = torch.bmm( beam_transition_score1.view(-1, beam, self.rank), - beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2)) + beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2), + ) beam_transition_matrix = 
beam_transition_matrix.view(batch_size, -1, beam, beam) traj_tokens, traj_scores = [], [] @@ -148,17 +151,19 @@ def _viterbi_decode(self, emissions, masks=None, beam=None): # compute the normalizer in the log-space score = beam_emission_scores[:, 0] # B x K - dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous() + dummy = ( + torch.arange(beam, device=score.device).expand(*score.size()).contiguous() + ) for i in range(1, seq_len): traj_scores.append(score) - _score = score[:, :, None] + beam_transition_matrix[:, i-1] + _score = score[:, :, None] + beam_transition_matrix[:, i - 1] _score, _index = _score.max(dim=1) _score = _score + beam_emission_scores[:, i] if masks is not None: - score = torch.where(masks[:, i: i+1], _score, score) - index = torch.where(masks[:, i: i+1], _index, dummy) + score = torch.where(masks[:, i : i + 1], _score, score) + index = torch.where(masks[:, i : i + 1], _index, dummy) else: score, index = _score, _index traj_tokens.append(index) diff --git a/fairseq/modules/dynamicconv_layer/cuda_function_gen.py b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py index 926d6ca846..9304f99eb8 100644 --- a/fairseq/modules/dynamicconv_layer/cuda_function_gen.py +++ b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py @@ -77,7 +77,7 @@ def gen_forward(): } """ - with open("dynamicconv_cuda_forward.cu", 'w') as forward: + with open("dynamicconv_cuda_forward.cu", "w") as forward: forward.write(head) forward.write(switch) for k in kernels: @@ -191,7 +191,7 @@ def gen_backward(): } """ - with open("dynamicconv_cuda_backward.cu", 'w') as backward: + with open("dynamicconv_cuda_backward.cu", "w") as backward: backward.write(head) for seq in seqs: backward.write(sequence_if.format(seq=seq)) diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py index 52cc1e8118..4a683d2690 100644 --- a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py +++ 
b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py @@ -3,20 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import dynamicconv_cuda import torch -from torch import nn -from torch.autograd import Function import torch.nn.functional as F - -import dynamicconv_cuda from fairseq import utils -from fairseq.modules.unfold import unfold1d from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.unfold import unfold1d +from torch import nn +from torch.autograd import Function class dynamicconvFunction(Function): - @staticmethod def forward(ctx, x, weights, padding_l): ctx.padding_l = padding_l @@ -28,9 +26,8 @@ def forward(ctx, x, weights, padding_l): @staticmethod def backward(ctx, grad_output): outputs = dynamicconv_cuda.backward( - grad_output.contiguous(), - ctx.padding_l, - *ctx.saved_tensors) + grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors + ) grad_input, grad_weights = outputs return grad_input, grad_weights, None @@ -38,17 +35,17 @@ def backward(ctx, grad_output): @with_incremental_state class DynamicconvLayer(nn.Module): def __init__( - self, - input_size, - kernel_size=1, - padding_l=None, - weight_softmax=False, - num_heads=1, - weight_dropout=0., - bias=False, - renorm_padding=False, - conv_bias=False, - query_size=None, + self, + input_size, + kernel_size=1, + padding_l=None, + weight_softmax=False, + num_heads=1, + weight_dropout=0.0, + bias=False, + renorm_padding=False, + conv_bias=False, + query_size=None, ): super(DynamicconvLayer, self).__init__() @@ -58,7 +55,9 @@ def __init__( self.padding_l = padding_l self.num_heads = num_heads self.weight_softmax = weight_softmax - self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) + self.weight_dropout_module = FairseqDropout( + weight_dropout, 
module_name=self.__class__.__name__ + ) self.renorm_padding = renorm_padding self.bias = bias @@ -72,8 +71,8 @@ def __init__( def reset_parameters(self): nn.init.xavier_uniform_(self.weight_linear.weight) if self.conv_bias is not None: - nn.init.constant_(self.conv_bias, 0.) - nn.init.constant_(self.weight_linaer.bias, 0.) + nn.init.constant_(self.conv_bias, 0.0) + nn.init.constant_(self.weight_linaer.bias, 0.0) def forward(self, x, incremental_state=None, query=None, unfold=None): @@ -83,7 +82,9 @@ def forward(self, x, incremental_state=None, query=None, unfold=None): # during inference time, incremental BMM is faster if incremental_state is not None: - unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory + unfold = ( + x.size(0) > 512 if unfold is None else unfold + ) # use unfold mode as default for long sequence to save memory unfold = unfold or (incremental_state is not None) assert query is None @@ -110,7 +111,9 @@ def forward(self, x, incremental_state=None, query=None, unfold=None): weight = weight.permute(1, 2, 3, 0).contiguous() self.filters = weight x = x.permute(1, 2, 0).contiguous() - output = dynamicconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1) + output = dynamicconvFunction.apply(x, weight, self.padding_l).permute( + 2, 0, 1 + ) if self.conv_bias is not None: output = output + self.conv_bias.view(1, 1, -1) return output @@ -122,20 +125,22 @@ def reorder_incremental_state(self, incremental_state, new_order): self._set_input_buffer(incremental_state, input_buffer) def _get_input_buffer(self, incremental_state): - return utils.get_incremental_state(self, incremental_state, 'input_buffer') + return utils.get_incremental_state(self, incremental_state, "input_buffer") def _set_input_buffer(self, incremental_state, new_buffer): - return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + return utils.set_incremental_state( + self, 
incremental_state, "input_buffer", new_buffer + ) def _forward_unfolded(self, x, incremental_state, query): - '''The conventional implementation of convolutions. - Unfolding the input by having a window shifting to the right.''' + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H assert R * H == C == self.input_size - weight = self.weight_linear(query).view(T*B*H, -1) + weight = self.weight_linear(query).view(T * B * H, -1) # renorm_padding is only implemented in _forward_expanded assert not self.renorm_padding or incremental_state is not None @@ -146,23 +151,25 @@ def _forward_unfolded(self, x, incremental_state, query): input_buffer = x.new() x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) if self.kernel_size > 1: - self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) - x_unfold = x_unfold.view(T*B*H, R, -1) + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) else: padding_l = self.padding_l - if K > T and padding_l == K-1: - weight = weight.narrow(1, K-T, T) - K, padding_l = T, T-1 + if K > T and padding_l == K - 1: + weight = weight.narrow(1, K - T, T) + K, padding_l = T, T - 1 # unfold the input: T x B x C --> T' x B x C x K x_unfold = unfold1d(x, K, padding_l, 0) - x_unfold = x_unfold.view(T*B*H, R, K) + x_unfold = x_unfold.view(T * B * H, R, K) if self.weight_softmax and not self.renorm_padding: weight = F.softmax(weight, dim=1) weight = weight.narrow(1, 0, K) if incremental_state is not None: - weight = weight[:, -x_unfold.size(2):] + weight = weight[:, -x_unfold.size(2) :] K = weight.size(1) if self.weight_softmax and self.renorm_padding: @@ -175,28 +182,30 @@ def _forward_unfolded(self, x, incremental_state, query): return output def _forward_expanded(self, x, incremental_stat, query): - 
'''Turn the convolution filters into band matrices and do matrix multiplication. + """Turn the convolution filters into band matrices and do matrix multiplication. This is faster when the sequence is short, but less memory efficient. This is not used in the decoder during inference. - ''' + """ T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H assert R * H == C == self.input_size - weight = self.weight_linear(query).view(T*B*H, -1) + weight = self.weight_linear(query).view(T * B * H, -1) if not self.renorm_padding: if self.weight_softmax: weight = F.softmax(weight, dim=1) weight = self.weight_dropout_module(weight, inplace=False) weight = weight.narrow(1, 0, K).contiguous() - weight = weight.view(T, B*H, K).transpose(0, 1) + weight = weight.view(T, B * H, K).transpose(0, 1) - x = x.view(T, B*H, R).transpose(0, 1) + x = x.view(T, B * H, R).transpose(0, 1) if self.weight_softmax and self.renorm_padding: # turn the convolution filters into band matrices - weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf')) - weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf")) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) weight_expanded = weight_expanded.narrow(2, self.padding_l, T) # normalize the weight over valid positions like self-attention weight_expanded = F.softmax(weight_expanded, dim=2) @@ -204,12 +213,14 @@ def _forward_expanded(self, x, incremental_stat, query): else: P = self.padding_l # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length - if K > T and P == K-1: - weight = weight.narrow(2, K-T, T) - K, P = T, T-1 + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 # turn the convolution filters into band matrices - weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) - 
weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T output = torch.bmm(weight_expanded, x) output = output.transpose(0, 1).contiguous().view(T, B, C) diff --git a/fairseq/modules/dynamicconv_layer/setup.py b/fairseq/modules/dynamicconv_layer/setup.py index 4d789c3283..6a21f7e2ee 100644 --- a/fairseq/modules/dynamicconv_layer/setup.py +++ b/fairseq/modules/dynamicconv_layer/setup.py @@ -5,19 +5,19 @@ # LICENSE file in the root directory of this source tree. from setuptools import setup -from torch.utils.cpp_extension import CUDAExtension, BuildExtension +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + setup( - name='dynamicconv_layer', + name="dynamicconv_layer", ext_modules=[ CUDAExtension( - name='dynamicconv_cuda', + name="dynamicconv_cuda", sources=[ - 'dynamicconv_cuda.cpp', - 'dynamicconv_cuda_kernel.cu', + "dynamicconv_cuda.cpp", + "dynamicconv_cuda_kernel.cu", ], ), ], - cmdclass={ - 'build_ext': BuildExtension - }) + cmdclass={"build_ext": BuildExtension}, +) diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py index cbfacf477f..f070a804e6 100644 --- a/fairseq/modules/fairseq_dropout.py +++ b/fairseq/modules/fairseq_dropout.py @@ -14,7 +14,6 @@ class FairseqDropout(nn.Module): - def __init__(self, p, module_name=None): super().__init__() self.p = p @@ -37,16 +36,16 @@ def make_generation_fast_( if retain_dropout: if retain_dropout_modules is not None and self.module_name is None: logger.warning( - 'Cannot enable dropout during inference for module {} ' - 'because module_name was not set'.format(name) + "Cannot enable dropout during inference for module {} " + "because module_name was not set".format(name) ) elif ( retain_dropout_modules is None # if 
None, apply to all modules or self.module_name in retain_dropout_modules ): logger.info( - 'Enabling dropout during inference for module: {}'.format(name) + "Enabling dropout during inference for module: {}".format(name) ) self.apply_during_inference = True else: - logger.info('Disabling dropout for module: {}'.format(name)) + logger.info("Disabling dropout for module: {}".format(name)) diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py index 01ddd2298b..47657bb0ab 100644 --- a/fairseq/modules/gumbel_vector_quantizer.py +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -83,6 +83,7 @@ def set_num_updates(self, num_updates): self.curr_temp = max( self.max_temp * self.temp_decay ** num_updates, self.min_temp ) + def get_codebook_indices(self): if self.codebook_indices is None: from itertools import product @@ -106,8 +107,8 @@ def codebook(self): indices = self.get_codebook_indices() return ( self.vars.squeeze(0) - .index_select(0, indices) - .view(self.num_vars ** self.groups, -1) + .index_select(0, indices) + .view(self.num_vars ** self.groups, -1) ) def sample_from_codebook(self, b, n): @@ -115,7 +116,7 @@ def sample_from_codebook(self, b, n): indices = indices.view(-1, self.groups) cb_size = indices.size(0) assert ( - n < cb_size + n < cb_size ), f"sample size {n} is greater than size of codebook {cb_size}" sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) indices = indices[sample_idx] diff --git a/fairseq/modules/kmeans_vector_quantizer.py b/fairseq/modules/kmeans_vector_quantizer.py index be56e6081b..040db1e83e 100644 --- a/fairseq/modules/kmeans_vector_quantizer.py +++ b/fairseq/modules/kmeans_vector_quantizer.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn - from fairseq.modules import Fp32GroupNorm @@ -13,17 +12,17 @@ class KmeansVectorQuantizer(nn.Module): def __init__( self, dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25 ): - '''Vector quantization using straight 
pass-through estimator (i.e. kmeans) - - Args: - dim: input dimension (channels) - num_vars: number of quantized vectors per group - groups: number of groups for vector quantization - combine_groups: whether to use the vectors for all groups - vq_dim: dimensionality of the resulting quantized vector - time_first: if true, expect input in BxTxC format, otherwise in BxCxT - gamma: commitment loss coefficient - ''' + """Vector quantization using straight pass-through estimator (i.e. kmeans) + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + gamma: commitment loss coefficient + """ super().__init__() self.groups = groups @@ -51,7 +50,7 @@ def __init__( self.mse_mean = nn.MSELoss(reduction="mean") def _pass_grad(self, x, y): - """ Manually set gradient for backward pass. + """Manually set gradient for backward pass. for y = f(x), ensure that during the backward pass, dL/dy = dL/dx regardless of f(x). 
Returns: @@ -102,9 +101,9 @@ def forward(self, x, produce_targets=False): x = self._pass_grad(ze, zq) hard_x = ( - idx.new_zeros(bsz*tsz*self.groups, self.num_vars) - .scatter_(-1, idx.view(-1, 1), 1.0) - .view(bsz * tsz, self.groups, -1) + idx.new_zeros(bsz * tsz * self.groups, self.num_vars) + .scatter_(-1, idx.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) ) hard_probs = torch.mean(hard_x.float(), dim=0) result["code_perplexity"] = torch.exp( diff --git a/fairseq/modules/layer_norm.py b/fairseq/modules/layer_norm.py index 7b1d241436..234609d9e2 100644 --- a/fairseq/modules/layer_norm.py +++ b/fairseq/modules/layer_norm.py @@ -22,6 +22,7 @@ def forward(self, x): with torch.cuda.device(x.device): return super().forward(x) + except ImportError: has_fused_layernorm = False diff --git a/fairseq/modules/lightconv_layer/cuda_function_gen.py b/fairseq/modules/lightconv_layer/cuda_function_gen.py index afec9e19e7..a25433dd8e 100644 --- a/fairseq/modules/lightconv_layer/cuda_function_gen.py +++ b/fairseq/modules/lightconv_layer/cuda_function_gen.py @@ -91,7 +91,7 @@ def gen_forward(): } """ - with open("lightconv_cuda_forward.cu", 'w') as forward: + with open("lightconv_cuda_forward.cu", "w") as forward: forward.write(head) for seq in seqs: forward.write(sequence_if.format(seq=seq)) @@ -261,7 +261,7 @@ def gen_backward(): thresh = [32, 32, 64, 128, 256, -1, -1, -1] max_mem = [-1, -1, -1, -1, -1, 192, 96, 64] - with open("lightconv_cuda_backward.cu", 'w') as backward: + with open("lightconv_cuda_backward.cu", "w") as backward: backward.write(head) for (k, t, mem) in zip(kernels, thresh, max_mem): backward.write(case_k.format(k=k)) diff --git a/fairseq/modules/lightconv_layer/lightconv_layer.py b/fairseq/modules/lightconv_layer/lightconv_layer.py index 9b4c9a951e..e7e597f474 100644 --- a/fairseq/modules/lightconv_layer/lightconv_layer.py +++ b/fairseq/modules/lightconv_layer/lightconv_layer.py @@ -3,19 +3,17 @@ # This source code is licensed under the MIT license 
found in the # LICENSE file in the root directory of this source tree. +import lightconv_cuda import torch -from torch import nn -from torch.autograd import Function import torch.nn.functional as F - -import lightconv_cuda from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules.fairseq_dropout import FairseqDropout +from torch import nn +from torch.autograd import Function class lightconvFunction(Function): - @staticmethod def forward(ctx, x, weights, padding_l): ctx.padding_l = padding_l @@ -27,9 +25,8 @@ def forward(ctx, x, weights, padding_l): @staticmethod def backward(ctx, grad_output): outputs = lightconv_cuda.backward( - grad_output.contiguous(), - ctx.padding_l, - *ctx.saved_tensors) + grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors + ) grad_input, grad_weights = outputs return grad_input, grad_weights, None @@ -37,14 +34,14 @@ def backward(ctx, grad_output): @with_incremental_state class LightconvLayer(nn.Module): def __init__( - self, - input_size, - kernel_size=1, - padding_l=None, - weight_softmax=False, - num_heads=1, - weight_dropout=0., - bias=False, + self, + input_size, + kernel_size=1, + padding_l=None, + weight_softmax=False, + num_heads=1, + weight_dropout=0.0, + bias=False, ): super(LightconvLayer, self).__init__() self.input_size = input_size @@ -52,7 +49,9 @@ def __init__( self.padding_l = padding_l self.num_heads = num_heads self.weight_softmax = weight_softmax - self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size)) if bias: @@ -62,16 +61,16 @@ def __init__( self.reset_parameters() def upgrade_state_dict_named(self, state_dict, name): - prefix = name + '.' if name != '' else '' + prefix = name + "." 
if name != "" else "" for k, v in state_dict.items(): - if k.endswith(prefix + 'weight'): + if k.endswith(prefix + "weight"): if v.dim() == 3 and v.size(1) == 1: state_dict[k] = v.squeeze(1) def reset_parameters(self): nn.init.xavier_uniform_(self.weight) if self.bias is not None: - nn.init.constant_(self.bias, 0.) + nn.init.constant_(self.bias, 0.0) def forward(self, x, incremental_state=None): @@ -85,18 +84,25 @@ def forward(self, x, incremental_state=None): input_buffer = x.new() x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) if self.kernel_size > 1: - self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) - x_unfold = x_unfold.view(T*B*H, R, -1) + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) weight = self.weight if self.weight_softmax: weight = F.softmax(weight.float(), dim=1).type_as(weight) - weight = weight[:, -x_unfold.size(2):] + weight = weight[:, -x_unfold.size(2) :] K = weight.size(1) - weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1) + weight = ( + weight.view(1, H, K) + .expand(T * B, H, K) + .contiguous() + .view(T * B * H, K, 1) + ) weight = self.weight_dropout_module(weight) output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 @@ -120,10 +126,12 @@ def reorder_incremental_state(self, incremental_state, new_order): self._set_input_buffer(incremental_state, input_buffer) def _get_input_buffer(self, incremental_state): - return utils.get_incremental_state(self, incremental_state, 'input_buffer') + return utils.get_incremental_state(self, incremental_state, "input_buffer") def _set_input_buffer(self, incremental_state, new_buffer): - return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) def half(self): return self._apply(lambda t: t.half() if t.is_floating_point() 
else t) diff --git a/fairseq/modules/lightconv_layer/setup.py b/fairseq/modules/lightconv_layer/setup.py index 0eac1df03c..052635be79 100644 --- a/fairseq/modules/lightconv_layer/setup.py +++ b/fairseq/modules/lightconv_layer/setup.py @@ -5,16 +5,19 @@ # LICENSE file in the root directory of this source tree. from setuptools import setup -from torch.utils.cpp_extension import CUDAExtension, BuildExtension +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + setup( - name='lightconv_layer', + name="lightconv_layer", ext_modules=[ - CUDAExtension('lightconv_cuda', [ - 'lightconv_cuda.cpp', - 'lightconv_cuda_kernel.cu', - ]), + CUDAExtension( + "lightconv_cuda", + [ + "lightconv_cuda.cpp", + "lightconv_cuda_kernel.cu", + ], + ), ], - cmdclass={ - 'build_ext': BuildExtension - }) + cmdclass={"build_ext": BuildExtension}, +) diff --git a/fairseq/modules/lightweight_convolution.py b/fairseq/modules/lightweight_convolution.py index 3d4cddb134..ec11a95079 100644 --- a/fairseq/modules/lightweight_convolution.py +++ b/fairseq/modules/lightweight_convolution.py @@ -6,32 +6,49 @@ import torch import torch.nn as nn import torch.nn.functional as F - from fairseq import utils -from fairseq.modules.unfold import unfold1d from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.unfold import unfold1d -def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1, - weight_dropout=0., weight_softmax=False, bias=False): +def LightweightConv( + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + bias=False, +): if torch.cuda.is_available(): try: from fairseq.modules.lightconv_layer import LightconvLayer - return LightconvLayer(input_size, kernel_size=kernel_size, - padding_l=padding_l, num_heads=num_heads, - weight_dropout=weight_dropout, - weight_softmax=weight_softmax, bias=bias) + + return 
LightconvLayer( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) except ImportError as e: print(e) - return LightweightConv1dTBC(input_size, kernel_size=kernel_size, - padding_l=padding_l, num_heads=num_heads, - weight_dropout=weight_dropout, - weight_softmax=weight_softmax, bias=bias) + return LightweightConv1dTBC( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) class LightweightConv1d(nn.Module): - '''Lightweight Convolution assuming the input is BxCxT + """Lightweight Convolution assuming the input is BxCxT This is just an example that explains LightConv clearer than the TBC version. We don't use this module in the model. @@ -51,10 +68,18 @@ class LightweightConv1d(nn.Module): weight: the learnable weights of the module of shape `(num_heads, 1, kernel_size)` bias: the learnable bias of the module of shape `(input_size)` - ''' - - def __init__(self, input_size, kernel_size=1, padding=0, num_heads=1, - weight_softmax=False, bias=False, weight_dropout=0.): + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding=0, + num_heads=1, + weight_softmax=False, + bias=False, + weight_dropout=0.0, + ): super().__init__() self.input_size = input_size self.kernel_size = kernel_size @@ -67,19 +92,21 @@ def __init__(self, input_size, kernel_size=1, padding=0, num_heads=1, self.bias = nn.Parameter(torch.Tensor(input_size)) else: self.bias = None - self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.weight) if self.bias is not None: - nn.init.constant_(self.bias, 0.) 
+ nn.init.constant_(self.bias, 0.0) def forward(self, input): - ''' + """ input size: B x C x T output size: B x C x T - ''' + """ B, C, T = input.size() H = self.num_heads @@ -103,7 +130,7 @@ def forward(self, input): @with_incremental_state class LightweightConv1dTBC(nn.Module): - '''Lightweight Convolution assuming the input is TxBxC + """Lightweight Convolution assuming the input is TxBxC Args: input_size: # of channels of the input kernel_size: convolution channels @@ -121,15 +148,26 @@ class LightweightConv1dTBC(nn.Module): weight: the learnable weights of the module of shape `(num_heads, 1, kernel_size)` bias: the learnable bias of the module of shape `(input_size)` - ''' - def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, - weight_dropout=0., weight_softmax=False, bias=False): + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + bias=False, + ): super().__init__() self.input_size = input_size self.kernel_size = kernel_size self.padding_l = padding_l self.num_heads = num_heads - self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__) + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) self.weight_softmax = weight_softmax self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) @@ -144,15 +182,15 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1, def reset_parameters(self): nn.init.xavier_uniform_(self.weight) if self.bias is not None: - nn.init.constant_(self.bias, 0.) + nn.init.constant_(self.bias, 0.0) def forward(self, x, incremental_state=None, unfold=False): - '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C args: x: Input of shape T x B x C, i.e. 
(timesteps, batch_size, input_size) incremental_state: A dict to keep the state unfold: unfold the input or not. If not, we use the matrix trick instead - ''' + """ unfold = unfold or (incremental_state is not None) if unfold: @@ -168,8 +206,8 @@ def prepare_for_onnx_export_(self): self.onnx_trace = True def _forward_unfolded(self, x, incremental_state): - '''The conventional implementation of convolutions. - Unfolding the input by having a window shifting to the right.''' + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H @@ -182,21 +220,27 @@ def _forward_unfolded(self, x, incremental_state): input_buffer = x.new() x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) if self.kernel_size > 1: - self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) - x_unfold = x_unfold.view(T*B*H, R, -1) + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) else: # unfold the input: T x B x C --> T' x B x C x K x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0) - x_unfold = x_unfold.view(T*B*H, R, K) + x_unfold = x_unfold.view(T * B * H, R, K) if self.weight_softmax: - weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight) + weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as( + weight + ) if incremental_state is not None: - weight = weight[:, -x_unfold.size(2):] + weight = weight[:, -x_unfold.size(2) :] K = weight.size(1) - weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1) + weight = ( + weight.view(1, H, K).expand(T * B, H, K).contiguous().view(T * B * H, K, 1) + ) weight = self.weight_dropout_module(weight) output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 @@ -204,10 +248,10 @@ def _forward_unfolded(self, x, incremental_state): 
return output def _forward_expanded(self, x, incremental_state): - '''Turn the convolution filters into band matrices and do matrix multiplication. + """Turn the convolution filters into band matrices and do matrix multiplication. This is faster when the sequence is short, but less memory efficient. This is not used in the decoder during inference. - ''' + """ T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H @@ -215,18 +259,22 @@ def _forward_expanded(self, x, incremental_state): weight = self.weight.view(H, K) if self.weight_softmax: - weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight) - weight = weight.view(1, H, K).expand(T*B, H, K).contiguous() - weight = weight.view(T, B*H, K).transpose(0, 1) + weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as( + weight + ) + weight = weight.view(1, H, K).expand(T * B, H, K).contiguous() + weight = weight.view(T, B * H, K).transpose(0, 1) - x = x.view(T, B*H, R).transpose(0, 1) + x = x.view(T, B * H, R).transpose(0, 1) P = self.padding_l - if K > T and P == K-1: - weight = weight.narrow(2, K-T, T) - K, P = T, T-1 + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 # turn the convolution filters into band matrices - weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) - weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_( + weight + ) weight_expanded = weight_expanded.narrow(2, P, T) weight_expanded = self.weight_dropout_module(weight_expanded) @@ -241,16 +289,22 @@ def reorder_incremental_state(self, incremental_state, new_order): self._set_input_buffer(incremental_state, input_buffer) def _get_input_buffer(self, incremental_state): - return utils.get_incremental_state(self, incremental_state, 'input_buffer') + return 
utils.get_incremental_state(self, incremental_state, "input_buffer") def _set_input_buffer(self, incremental_state, new_buffer): - return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) def extra_repr(self): - s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}'.format( - self.input_size, self.kernel_size, self.padding_l, - self.num_heads, self.weight_softmax, self.bias is not None + s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}".format( + self.input_size, + self.kernel_size, + self.padding_l, + self.num_heads, + self.weight_softmax, + self.bias is not None, ) - if self.weight_dropout_module.p > 0.: - s += ', weight_dropout={}'.format(self.weight_dropout_module.p) + if self.weight_dropout_module.p > 0.0: + s += ", weight_dropout={}".format(self.weight_dropout_module.p) return s diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py index 3dd4b151c1..09a8f201c0 100644 --- a/fairseq/modules/linearized_convolution.py +++ b/fairseq/modules/linearized_convolution.py @@ -5,11 +5,11 @@ import torch import torch.nn.functional as F - from fairseq import utils -from .conv_tbc import ConvTBC from fairseq.incremental_decoding_utils import with_incremental_state +from .conv_tbc import ConvTBC + @with_incremental_state class LinearizedConvolution(ConvTBC): @@ -26,17 +26,17 @@ def __init__(self, in_channels, out_channels, kernel_size, **kwargs): self._linearized_weight = None self.register_backward_hook(self._clear_linearized_weight) - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars) # don't store redundant _linearized_weight in checkpoints - if prefix + '_linearized_weight' in state: - 
del state[prefix + '_linearized_weight'] + if prefix + "_linearized_weight" in state: + del state[prefix + "_linearized_weight"] return state def upgrade_state_dict_named(self, state_dict, name): - prefix = name + '.' if name != '' else '' - if prefix + '_linearized_weight' in state_dict: - del state_dict[prefix + '_linearized_weight'] + prefix = name + "." if name != "" else "" + if prefix + "_linearized_weight" in state_dict: + del state_dict[prefix + "_linearized_weight"] def forward(self, input, incremental_state=None): """ @@ -52,7 +52,7 @@ def forward(self, input, incremental_state=None): output = super().forward(input) if self.kernel_size[0] > 1 and self.padding[0] > 0: # remove future timesteps added by padding - output = output[:-self.padding[0], :, :] + output = output[: -self.padding[0], :, :] return output # reshape weight @@ -83,17 +83,21 @@ def reorder_incremental_state(self, incremental_state, new_order): self._set_input_buffer(incremental_state, input_buffer) def _get_input_buffer(self, incremental_state): - return utils.get_incremental_state(self, incremental_state, 'input_buffer') + return utils.get_incremental_state(self, incremental_state, "input_buffer") def _set_input_buffer(self, incremental_state, new_buffer): - return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) def _get_linearized_weight(self): if self._linearized_weight is None: kw = self.kernel_size[0] weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() assert weight.size() == (self.out_channels, kw, self.in_channels) - self._linearized_weight = torch.nn.Parameter(weight.view(self.out_channels, -1)) + self._linearized_weight = torch.nn.Parameter( + weight.view(self.out_channels, -1) + ) return self._linearized_weight def _clear_linearized_weight(self, *args): diff --git a/fairseq/modules/multihead_attention.py 
b/fairseq/modules/multihead_attention.py index 90b635af2b..99f95deb5f 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -8,13 +8,12 @@ import torch import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn import Parameter - from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.quant_noise import quant_noise +from torch import Tensor, nn +from torch.nn import Parameter @with_incremental_state @@ -63,11 +62,19 @@ def __init__( "Self-attention requires query, key and " "value to be of the same size" ) - self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) - self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) - self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) - self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) if add_bias_kv: self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) @@ -102,7 +109,7 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.out_proj.weight) if self.out_proj.bias is not None: - nn.init.constant_(self.out_proj.bias, 0.) 
+ nn.init.constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k) if self.bias_v is not None: @@ -333,11 +340,11 @@ def forward( if not self.tpu: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), - float("-inf") + float("-inf"), ) else: attn_weights = attn_weights.transpose(0, 2) - attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf')) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -411,7 +418,9 @@ def _append_prev_key_padding_mask( @torch.jit.export def reorder_incremental_state( - self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, ): """Reorder buffered internal state (for incremental generation).""" input_buffer = self._get_input_buffer(incremental_state) @@ -419,7 +428,9 @@ def reorder_incremental_state( for k in input_buffer.keys(): input_buffer_k = input_buffer[k] if input_buffer_k is not None: - if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0): + if self.encoder_decoder_attention and input_buffer_k.size( + 0 + ) == new_order.size(0): break input_buffer[k] = input_buffer_k.index_select(0, new_order) incremental_state = self._set_input_buffer(incremental_state, input_buffer) diff --git a/fairseq/modules/positional_embedding.py b/fairseq/modules/positional_embedding.py index 511460fcb7..8e94e35edb 100644 --- a/fairseq/modules/positional_embedding.py +++ b/fairseq/modules/positional_embedding.py @@ -4,15 +4,16 @@ # LICENSE file in the root directory of this source tree. 
import torch.nn as nn + from .learned_positional_embedding import LearnedPositionalEmbedding from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding def PositionalEmbedding( - num_embeddings: int, - embedding_dim: int, - padding_idx: int, - learned: bool = False, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + learned: bool = False, ): if learned: # if padding_idx is specified then offset the embedding ids by @@ -27,6 +28,8 @@ def PositionalEmbedding( nn.init.constant_(m.weight[padding_idx], 0) else: m = SinusoidalPositionalEmbedding( - embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1, + embedding_dim, + padding_idx, + init_size=num_embeddings + padding_idx + 1, ) return m diff --git a/fairseq/modules/quant_noise.py b/fairseq/modules/quant_noise.py index b38ea263d3..d777dfbb6c 100644 --- a/fairseq/modules/quant_noise.py +++ b/fairseq/modules/quant_noise.py @@ -39,13 +39,17 @@ def quant_noise(module, p, block_size): # 2D matrix if not is_conv: - assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes" + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" # 4D matrix else: # 1x1 convolutions if module.kernel_size == (1, 1): - assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes" + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" # regular convolutions else: k = module.kernel_size[0] * module.kernel_size[1] @@ -61,7 +65,9 @@ def _forward_pre_hook(mod, input): out_features = weight.size(0) # split weight matrix into blocks and randomly drop selected blocks - mask = torch.zeros(in_features // block_size * out_features, device=weight.device) + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) mask.bernoulli_(p) mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) @@ 
-73,16 +79,27 @@ def _forward_pre_hook(mod, input): # split weight matrix into blocks and randomly drop selected blocks if mod.kernel_size == (1, 1): - mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device) + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) mask.bernoulli_(p) mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) else: - mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) mask.bernoulli_(p) - mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) # scale weights and apply mask - mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript s = 1 / (1 - p) mod.weight.data = s * weight.masked_fill(mask, 0) diff --git a/fairseq/modules/quantization/pq/em.py b/fairseq/modules/quantization/pq/em.py index 420d8afda2..6f15c3e46b 100644 --- a/fairseq/modules/quantization/pq/em.py +++ b/fairseq/modules/quantization/pq/em.py @@ -3,9 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import os import random -import logging from collections import Counter import torch diff --git a/fairseq/modules/quantization/pq/modules/__init__.py b/fairseq/modules/quantization/pq/modules/__init__.py index f52f6f37a6..b67c8e8ad6 100644 --- a/fairseq/modules/quantization/pq/modules/__init__.py +++ b/fairseq/modules/quantization/pq/modules/__init__.py @@ -4,5 +4,5 @@ # LICENSE file in the root directory of this source tree. 
from .qconv import PQConv2d # NOQA -from .qlinear import PQLinear # NOQA from .qemb import PQEmbedding # NOQA +from .qlinear import PQLinear # NOQA diff --git a/fairseq/modules/quantization/pq/modules/qemb.py b/fairseq/modules/quantization/pq/modules/qemb.py index 98d856d04e..3a74ad3c4c 100644 --- a/fairseq/modules/quantization/pq/modules/qemb.py +++ b/fairseq/modules/quantization/pq/modules/qemb.py @@ -27,9 +27,19 @@ class PQEmbedding(nn.Module): the non-quantized nn.Embedding module for a standard training loop. """ - def __init__(self, centroids, assignments, num_embeddings, embedding_dim, - padding_idx=None, max_norm=None, norm_type=2., - scale_grad_by_freq=False, sparse=False, _weight=None): + def __init__( + self, + centroids, + assignments, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + ): super(PQEmbedding, self).__init__() self.block_size = centroids.size(1) self.n_centroids = centroids.size(0) @@ -37,9 +47,13 @@ def __init__(self, centroids, assignments, num_embeddings, embedding_dim, self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm @@ -67,21 +81,27 @@ def weight(self): def forward(self, input): return F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + 
self.sparse, + ) def extra_repr(self): - s = '{num_embeddings}, {embedding_dim}' + s = "{num_embeddings}, {embedding_dim}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.max_norm is not None: - s += ', max_norm={max_norm}' + s += ", max_norm={max_norm}" if self.norm_type != 2: - s += ', norm_type={norm_type}' + s += ", norm_type={norm_type}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' - s += ', n_centroids={n_centroids}, block_size={block_size}' + s += ", sparse=True" + s += ", n_centroids={n_centroids}, block_size={block_size}" return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/pq/utils.py b/fairseq/modules/quantization/pq/utils.py index 57aaa1b7a3..03b15e4b1b 100644 --- a/fairseq/modules/quantization/pq/utils.py +++ b/fairseq/modules/quantization/pq/utils.py @@ -8,10 +8,10 @@ from operator import attrgetter, itemgetter import numpy as np -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn -from .modules import PQConv2d, PQLinear, PQEmbedding +from .modules import PQConv2d, PQEmbedding, PQLinear from .pq import PQ @@ -63,7 +63,9 @@ def quantize_model_( for layer in quantized_layers: # book-keeping - is_master_process = (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) + is_master_process = (not dist.is_initialized()) or ( + dist.is_initialized() and dist.get_rank() == 0 + ) verbose = verbose and is_master_process # get block size and centroids @@ -71,11 +73,13 @@ def quantize_model_( block_size = get_param(module, layer, block_sizes_config) n_centroids = get_param(module, layer, n_centroids_config) if verbose: - logging.info(f"Quantizing layer {layer} with block size {block_size} and {n_centroids} centroids") + logging.info( + f"Quantizing layer {layer} with block size 
{block_size} and {n_centroids} centroids" + ) # quantize layer weight = module.weight.data.clone() - is_bias = 'bias' in [x[0] for x in module.named_parameters()] + is_bias = "bias" in [x[0] for x in module.named_parameters()] bias = module.bias.data.clone() if is_bias else None quantizer = PQ( weight, @@ -238,9 +242,7 @@ def get_param(module, layer_name, param_config): if "*" in params: feature_value = "*" else: - raise KeyError( - f"name={layer_name} not in config for {module}" - ) + raise KeyError(f"name={layer_name} not in config for {module}") else: feature_value = feature_values[0] diff --git a/fairseq/modules/quantization/scalar/modules/__init__.py b/fairseq/modules/quantization/scalar/modules/__init__.py index ead4669611..8031d9cdb2 100644 --- a/fairseq/modules/quantization/scalar/modules/__init__.py +++ b/fairseq/modules/quantization/scalar/modules/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from .qact import ActivationQuantizer # NOQA from .qconv import IntConv2d # NOQA -from .qlinear import IntLinear # NOQA from .qemb import IntEmbedding # NOQA -from .qact import ActivationQuantizer # NOQA +from .qlinear import IntLinear # NOQA diff --git a/fairseq/modules/quantization/scalar/modules/qact.py b/fairseq/modules/quantization/scalar/modules/qact.py index a9f79011c1..c5dd1d6336 100644 --- a/fairseq/modules/quantization/scalar/modules/qact.py +++ b/fairseq/modules/quantization/scalar/modules/qact.py @@ -32,8 +32,16 @@ class ActivationQuantizer: - The activations are hard-clamped in [-clamp_threshold, clamp_threshold] to prevent overflow during the backward pass """ - def __init__(self, module, p=1, update_step=1000, bits=8, - method="histogram", clamp_threshold=5): + + def __init__( + self, + module, + p=1, + update_step=1000, + bits=8, + method="histogram", + clamp_threshold=5, + ): self.module = module self.p = p self.update_step = update_step @@ -72,7 +80,7 @@ def quantize_hook(module, x, y): noise = (y_q - y).masked_fill(mask.bool(), 0) # using straight-through estimator (STE) - clamp_low = - self.scale * self.zero_point + clamp_low = -self.scale * self.zero_point clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point) return torch.clamp(y, clamp_low.item(), clamp_high.item()) + noise.detach() diff --git a/fairseq/modules/quantization/scalar/modules/qconv.py b/fairseq/modules/quantization/scalar/modules/qconv.py index d718c9b90d..83788c6f71 100644 --- a/fairseq/modules/quantization/scalar/modules/qconv.py +++ b/fairseq/modules/quantization/scalar/modules/qconv.py @@ -118,9 +118,12 @@ def forward(self, input): noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) # using straight-through estimator (STE) - clamp_low = - self.scale * self.zero_point + clamp_low = -self.scale * self.zero_point clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point) - weight = torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + 
noise.detach() + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) # return output output = self._conv_forward(input, weight) diff --git a/fairseq/modules/quantization/scalar/modules/qemb.py b/fairseq/modules/quantization/scalar/modules/qemb.py index 835b2782a7..d6cf06e587 100644 --- a/fairseq/modules/quantization/scalar/modules/qemb.py +++ b/fairseq/modules/quantization/scalar/modules/qemb.py @@ -37,7 +37,7 @@ def __init__( embedding_dim, padding_idx=None, max_norm=None, - norm_type=2., + norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, @@ -51,9 +51,13 @@ def __init__( self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm @@ -63,8 +67,10 @@ def __init__( self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() else: - assert list(_weight.shape) == [num_embeddings, embedding_dim], \ - 'Shape of weight does not match num_embeddings and embedding_dim' + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" self.weight = nn.Parameter(_weight) self.sparse = sparse @@ -106,27 +112,36 @@ def forward(self, input): noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) # using straight-through estimator (STE) - clamp_low = - self.scale * self.zero_point + clamp_low = -self.scale * self.zero_point clamp_high = self.scale * (2 ** self.bits - 1 - 
self.zero_point) - weight = torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + noise.detach() + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) # return output output = F.embedding( - input, weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return output def extra_repr(self): - s = '{num_embeddings}, {embedding_dim}' + s = "{num_embeddings}, {embedding_dim}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.max_norm is not None: - s += ', max_norm={max_norm}' + s += ", max_norm={max_norm}" if self.norm_type != 2: - s += ', norm_type={norm_type}' + s += ", norm_type={norm_type}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' - s += 'quant_noise={p}, bits={bits}, method={method}' + s += ", sparse=True" + s += "quant_noise={p}, bits={bits}, method={method}" return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/scalar/modules/qlinear.py b/fairseq/modules/quantization/scalar/modules/qlinear.py index 2d4b27dc6c..9db1559386 100644 --- a/fairseq/modules/quantization/scalar/modules/qlinear.py +++ b/fairseq/modules/quantization/scalar/modules/qlinear.py @@ -91,9 +91,12 @@ def forward(self, input): noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) # using straight-through estimator (STE) - clamp_low = - self.scale * self.zero_point + clamp_low = -self.scale * self.zero_point clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point) - weight = torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + noise.detach() + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) 
+ + noise.detach() + ) # return output output = F.linear(input, weight, self.bias) diff --git a/fairseq/modules/quantization/scalar/ops.py b/fairseq/modules/quantization/scalar/ops.py index 90bc737cc8..2a855159be 100644 --- a/fairseq/modules/quantization/scalar/ops.py +++ b/fairseq/modules/quantization/scalar/ops.py @@ -12,7 +12,9 @@ def emulate_int(w, bits, method, scale=None, zero_point=None): def quantize(w, scale, zero_point): - return (torch.clamp(torch.round(w / scale + zero_point), 0, 255) - zero_point) * scale + return ( + torch.clamp(torch.round(w / scale + zero_point), 0, 255) - zero_point + ) * scale def emulate_int8_histogram(w, scale=None, zero_point=None): diff --git a/fairseq/modules/quantization/scalar/utils.py b/fairseq/modules/quantization/scalar/utils.py index 4071f7b80a..32cf616568 100644 --- a/fairseq/modules/quantization/scalar/utils.py +++ b/fairseq/modules/quantization/scalar/utils.py @@ -6,11 +6,11 @@ import logging from operator import attrgetter -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn -from ..pq.utils import get_layers, attrsetter -from .modules import IntConv2d, IntLinear, IntEmbedding, ActivationQuantizer +from ..pq.utils import attrsetter, get_layers +from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d} @@ -34,15 +34,25 @@ def quantize_model_(model, p=0.2, bits=8, update_step=3000): for layer in quantized_layers: # book-keeping - is_master_process = (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0) + is_master_process = (not dist.is_initialized()) or ( + dist.is_initialized() and dist.get_rank() == 0 + ) # recover module module = attrgetter(layer)(model) if is_master_process: - logging.info(f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}") + logging.info( + f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}" + ) # quantization params - 
q_params = {"p": p, "update_step": update_step, "bits": bits, "method": "histogram", "counter": 0} + q_params = { + "p": p, + "update_step": update_step, + "bits": bits, + "method": "histogram", + "counter": 0, + } # instantiate the quantized counterpart if isinstance(module, tuple(MAPPING.keys())): diff --git a/fairseq/modules/sparse_multihead_attention.py b/fairseq/modules/sparse_multihead_attention.py index 61430195c2..3cbd9d6785 100644 --- a/fairseq/modules/sparse_multihead_attention.py +++ b/fairseq/modules/sparse_multihead_attention.py @@ -4,12 +4,14 @@ # LICENSE file in the root directory of this source tree. import math + import torch + from .multihead_attention import MultiheadAttention class SparseMultiheadAttention(MultiheadAttention): - """ Sparse Multi-Headed Attention. + """Sparse Multi-Headed Attention. "Generating Long Sequences with Sparse Transformers". Implements fixed factorized self attention, where l=stride and c=expressivity. @@ -19,19 +21,40 @@ class SparseMultiheadAttention(MultiheadAttention): as in the paper. 
""" - def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, - add_bias_kv=False, add_zero_attn=False, self_attention=False, - encoder_decoder_attention=False, stride=32, expressivity=8, is_bidirectional=True): + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + stride=32, + expressivity=8, + is_bidirectional=True, + ): super().__init__( - embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv, - add_zero_attn, self_attention, encoder_decoder_attention + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn, + self_attention, + encoder_decoder_attention, ) self.is_bidirectional = is_bidirectional self.stride = stride self.expressivity = expressivity - assert(self.stride > 0 and self.stride >= self.expressivity) + assert self.stride > 0 and self.stride >= self.expressivity # Used for Ai(2) calculations - beginning of [l-c, l] range def compute_checkpoint(self, word_index): @@ -40,7 +63,8 @@ def compute_checkpoint(self, word_index): else: checkpoint_index = ( math.floor(word_index / self.stride) * self.stride - + self.stride - self.expressivity + + self.stride + - self.expressivity ) return checkpoint_index @@ -48,12 +72,15 @@ def compute_checkpoint(self, word_index): def compute_subset_summaries(self, absolute_max): checkpoint_index = self.compute_checkpoint(0) subset_two = set() - while checkpoint_index <= absolute_max-1: - summary = set(range(checkpoint_index, min( - checkpoint_index+self.expressivity+1, absolute_max) - )) + while checkpoint_index <= absolute_max - 1: + summary = set( + range( + checkpoint_index, + min(checkpoint_index + self.expressivity + 1, absolute_max), + ) + ) subset_two = subset_two.union(summary) - checkpoint_index = self.compute_checkpoint(checkpoint_index+self.stride) + checkpoint_index = 
self.compute_checkpoint(checkpoint_index + self.stride) return subset_two # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf @@ -65,12 +92,19 @@ def compute_fixed_attention_subset(self, word_index, tgt_len): absolute_max = tgt_len # Subset 1 - whole window - rounded_index = math.floor((word_index + self.stride) / self.stride) * self.stride + rounded_index = ( + math.floor((word_index + self.stride) / self.stride) * self.stride + ) if word_index % self.stride == 0 and word_index != 0: - subset_one = set(range(word_index-self.stride, min(absolute_max, word_index+1))) + subset_one = set( + range(word_index - self.stride, min(absolute_max, word_index + 1)) + ) else: - subset_one = set(range(max(0, rounded_index - self.stride), min( - absolute_max, rounded_index+1)) + subset_one = set( + range( + max(0, rounded_index - self.stride), + min(absolute_max, rounded_index + 1), + ) ) # Subset 2 - summary per window @@ -83,8 +117,8 @@ def compute_fixed_attention_subset(self, word_index, tgt_len): # Compute sparse mask - if bidirectional, can pre-compute and store def buffered_sparse_mask(self, tensor, tgt_len, src_len): - assert(tgt_len > self.stride) - sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float('-inf')) + assert tgt_len > self.stride + sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf")) # If bidirectional, subset 2 is the same for every index subset_summaries = set() @@ -100,5 +134,7 @@ def buffered_sparse_mask(self, tensor, tgt_len, src_len): def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len) - sparse_mask = sparse_mask.unsqueeze(0).expand(bsz * self.num_heads, tgt_len, src_len) + sparse_mask = sparse_mask.unsqueeze(0).expand( + bsz * self.num_heads, tgt_len, src_len + ) attn_weights += sparse_mask diff --git a/fairseq/modules/sparse_transformer_sentence_encoder.py 
b/fairseq/modules/sparse_transformer_sentence_encoder.py index 3d50d5a882..f41ec09327 100644 --- a/fairseq/modules/sparse_transformer_sentence_encoder.py +++ b/fairseq/modules/sparse_transformer_sentence_encoder.py @@ -5,7 +5,9 @@ import torch.nn as nn from fairseq.modules import TransformerSentenceEncoder -from fairseq.modules.sparse_transformer_sentence_encoder_layer import SparseTransformerSentenceEncoderLayer +from fairseq.modules.sparse_transformer_sentence_encoder_layer import ( + SparseTransformerSentenceEncoderLayer, +) class SparseTransformerSentenceEncoder(TransformerSentenceEncoder): @@ -43,12 +45,27 @@ def __init__( ) -> None: super().__init__( - padding_idx, vocab_size, num_encoder_layers, embedding_dim, - ffn_embedding_dim, num_attention_heads, dropout, attention_dropout, - activation_dropout, max_seq_len, num_segments, use_position_embeddings, - offset_positions_by_padding, encoder_normalize_before, apply_bert_init, - activation_fn, learned_pos_embedding, embed_scale, freeze_embeddings, - n_trans_layers_to_freeze, export + padding_idx, + vocab_size, + num_encoder_layers, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + max_seq_len, + num_segments, + use_position_embeddings, + offset_positions_by_padding, + encoder_normalize_before, + apply_bert_init, + activation_fn, + learned_pos_embedding, + embed_scale, + freeze_embeddings, + n_trans_layers_to_freeze, + export, ) self.layers = nn.ModuleList( diff --git a/fairseq/modules/sparse_transformer_sentence_encoder_layer.py b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py index 21c2fe4d5a..d95da59c24 100644 --- a/fairseq/modules/sparse_transformer_sentence_encoder_layer.py +++ b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py @@ -20,7 +20,7 @@ def __init__( dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, - activation_fn: str = 'relu', + activation_fn: str = "relu", export: 
bool = False, is_bidirectional: bool = True, stride: int = 32, @@ -28,8 +28,14 @@ def __init__( ) -> None: super().__init__( - embedding_dim, ffn_embedding_dim, num_attention_heads, dropout, - attention_dropout, activation_dropout, activation_fn, export + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, ) self.self_attn = SparseMultiheadAttention( diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 9965f2f26c..48cd4c7314 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -9,10 +9,11 @@ import torch.nn as nn from fairseq import utils from fairseq.modules import LayerNorm, MultiheadAttention -from fairseq.modules.quant_noise import quant_noise from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise from torch import Tensor + class TransformerEncoderLayer(nn.Module): """Encoder layer block. 
@@ -35,7 +36,9 @@ def __init__(self, args): self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn_layer_norm = LayerNorm(self.embed_dim) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu") ) @@ -48,19 +51,29 @@ def __init__(self, args): ) self.normalize_before = args.encoder_normalize_before self.fc1 = self.build_fc1( - self.embed_dim, args.encoder_ffn_embed_dim, self.quant_noise, self.quant_noise_block_size + self.embed_dim, + args.encoder_ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, ) self.fc2 = self.build_fc2( - args.encoder_ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size + args.encoder_ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, ) self.final_layer_norm = LayerNorm(self.embed_dim) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): - return quant_noise(nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size) + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): - return quant_noise(nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size) + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) def build_self_attention(self, embed_dim, args): return MultiheadAttention( @@ -164,7 +177,9 @@ def __init__( ): super().__init__() self.embed_dim = args.decoder_embed_dim - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) 
self.quant_noise = getattr(args, "quant_noise_pq", 0) self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) @@ -178,14 +193,17 @@ def __init__( ) self.activation_fn = utils.get_activation_fn( - activation=str(args.activation_fn) if getattr(args, "activation_fn", None) is not None else "relu" + activation=str(args.activation_fn) + if getattr(args, "activation_fn", None) is not None + else "relu" ) activation_dropout_p = getattr(args, "activation_dropout", 0) if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout activation_dropout_p = getattr(args, "relu_dropout", 0) self.activation_dropout_module = FairseqDropout( - float(activation_dropout_p), module_name=self.__class__.__name__) + float(activation_dropout_p), module_name=self.__class__.__name__ + ) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. @@ -202,10 +220,16 @@ def __init__( self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = self.build_fc1( - self.embed_dim, args.decoder_ffn_embed_dim, self.quant_noise, self.quant_noise_block_size + self.embed_dim, + args.decoder_ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, ) self.fc2 = self.build_fc2( - args.decoder_ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size + args.decoder_ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, ) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) @@ -219,7 +243,9 @@ def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) - def build_self_attention(self, embed_dim, args, add_bias_kv=False, add_zero_attn=False): + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): return MultiheadAttention( 
embed_dim, args.decoder_attention_heads, diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 74cd1d0664..208488f562 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -102,7 +102,9 @@ def __init__( super().__init__() self.padding_idx = padding_idx self.vocab_size = vocab_size - self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) self.layerdrop = layerdrop self.max_seq_len = max_seq_len self.embedding_dim = embedding_dim @@ -148,21 +150,23 @@ def __init__( self.layers = LayerDropModuleList(p=self.layerdrop) else: self.layers = nn.ModuleList([]) - self.layers.extend([ - self.build_transformer_sentence_encoder_layer( - embedding_dim=self.embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=self.dropout_module.p, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) - for _ in range(num_encoder_layers) - ]) + self.layers.extend( + [ + self.build_transformer_sentence_encoder_layer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=self.dropout_module.p, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + for _ in range(num_encoder_layers) + ] + ) if encoder_normalize_before: self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py index 383938f68f..3589c60fe6 100644 --- 
a/fairseq/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -7,15 +7,10 @@ import torch import torch.nn as nn - from fairseq import utils -from fairseq.modules import ( - LayerNorm, - MultiheadAttention, -) -from fairseq.modules.quant_noise import quant_noise +from fairseq.modules import LayerNorm, MultiheadAttention from fairseq.modules.fairseq_dropout import FairseqDropout - +from fairseq.modules.quant_noise import quant_noise class TransformerSentenceEncoderLayer(nn.Module): @@ -32,7 +27,7 @@ def __init__( dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, - activation_fn: str = 'relu', + activation_fn: str = "relu", export: bool = False, q_noise: float = 0.0, qn_block_size: int = 8, @@ -45,8 +40,12 @@ def __init__( # Initialize parameters self.embedding_dim = embedding_dim - self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) - self.activation_dropout_module = FairseqDropout(activation_dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.activation_dropout_module = FairseqDropout( + activation_dropout, module_name=self.__class__.__name__ + ) # Initialize blocks self.activation_fn = utils.get_activation_fn(activation_fn) @@ -79,14 +78,10 @@ def __init__( self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): - return quant_noise( - nn.Linear(input_dim, output_dim), q_noise, qn_block_size - ) + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): - return quant_noise( - nn.Linear(input_dim, output_dim), q_noise, qn_block_size - ) + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) def build_self_attention( self, diff --git a/fairseq/modules/unfold.py 
b/fairseq/modules/unfold.py index 3a142db698..138272f1ef 100644 --- a/fairseq/modules/unfold.py +++ b/fairseq/modules/unfold.py @@ -7,11 +7,13 @@ def unfold1d(x, kernel_size, padding_l, pad_value=0): - '''unfold T x B x C to T x B x C x K''' + """unfold T x B x C to T x B x C x K""" if kernel_size > 1: T, B, C = x.size() - x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) - x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) + x = F.pad( + x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value + ) + x = x.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C)) else: x = x.unsqueeze(3) return x diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py index 0d7d8d7d79..faa8031d46 100644 --- a/fairseq/nan_detector.py +++ b/fairseq/nan_detector.py @@ -4,14 +4,16 @@ # LICENSE file in the root directory of this source tree. import logging + import torch + logger = logging.getLogger(__name__) class NanDetector: """ - Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name + Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name """ def __init__(self, model, forward=True, backward=True): @@ -83,7 +85,7 @@ def _apply(self, module, inp, x, backward): f" input max: {inp.max().item()}, input min: {inp.min().item()}" ) - has_printed_attr = 'has_printed_b' if backward else 'has_printed_f' + has_printed_attr = "has_printed_b" if backward else "has_printed_f" logger.warning(err) setattr(self, has_printed_attr, True) elif isinstance(x, dict): diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py index 9b311ae38a..f1a2154977 100644 --- a/fairseq/optim/adadelta.py +++ b/fairseq/optim/adadelta.py @@ -5,10 +5,10 @@ import torch.optim -from . import register_optimizer, LegacyFairseqOptimizer +from . 
import LegacyFairseqOptimizer, register_optimizer -@register_optimizer('adadelta') +@register_optimizer("adadelta") class Adadelta(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) @@ -36,10 +36,10 @@ def optimizer_config(self): different learning rate. """ return { - 'lr': self.args.lr[0], - 'rho': self.args.adadelta_rho, - 'eps': self.args.adadelta_eps, - 'weight_decay': self.args.weight_decay, + "lr": self.args.lr[0], + "rho": self.args.adadelta_rho, + "eps": self.args.adadelta_eps, + "weight_decay": self.args.weight_decay, } @property diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index ab69e0e58d..91745ce10e 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -4,13 +4,14 @@ # LICENSE file in the root directory of this source tree. import math + import torch import torch.optim -from . import register_optimizer, LegacyFairseqOptimizer +from . import LegacyFairseqOptimizer, register_optimizer -@register_optimizer('adafactor') +@register_optimizer("adafactor") class FairseqAdafactor(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) @@ -50,15 +51,15 @@ def optimizer_config(self): Might require search for appropriate configuration. 
""" return { - 'lr': self.args.lr[0], - 'eps': eval(self.args.adafactor_eps), - 'clip_threshold': self.args.clip_threshold, - 'decay_rate': self.args.decay_rate, - 'beta1': self.args.beta1, - 'weight_decay': self.args.weight_decay, - 'scale_parameter': self.args.scale_parameter, # defaults to False - 'relative_step': self.args.relative_step, # defaults to False - 'warmup_init': self.args.warmup_init, + "lr": self.args.lr[0], + "eps": eval(self.args.adafactor_eps), + "clip_threshold": self.args.clip_threshold, + "decay_rate": self.args.decay_rate, + "beta1": self.args.beta1, + "weight_decay": self.args.weight_decay, + "scale_parameter": self.args.scale_parameter, # defaults to False + "relative_step": self.args.relative_step, # defaults to False + "warmup_init": self.args.warmup_init, } @@ -96,17 +97,35 @@ class Adafactor(torch.optim.Optimizer): whether warm-up initialization is being used (default: False) """ - def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0, - decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True, - relative_step=True, warmup_init=False): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): if lr is not None and relative_step: - raise ValueError('Cannot combine manual lr and relative_step options') + raise ValueError("Cannot combine manual lr and relative_step options") if warmup_init and not relative_step: - raise ValueError('warmup_init requires relative_step=True') - - defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate, - beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, - relative_step=relative_step, warmup_init=warmup_init) + raise ValueError("warmup_init requires relative_step=True") + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + 
beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + ) super(Adafactor, self).__init__(params, defaults) @property @@ -118,18 +137,20 @@ def supports_flat_params(self): return False def _get_lr(self, param_group, param_state): - rel_step_sz = param_group['lr'] - if param_group['relative_step']: - min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 - rel_step_sz = min(min_step, 1.0/math.sqrt(param_state['step'])) + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = ( + 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + ) + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) param_scale = 1.0 - if param_group['scale_parameter']: - param_scale = max(param_group['eps'][1], param_state['RMS']) + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) return param_scale * rel_step_sz def _get_options(self, param_group, param_shape): factored = len(param_shape) >= 2 - use_first_moment = param_group['beta1'] is not None + use_first_moment = param_group["beta1"] is not None return factored, use_first_moment def _rms(self, tensor): @@ -137,8 +158,10 @@ def _rms(self, tensor): def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): r_factor = ( - exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True) - ).rsqrt_().unsqueeze(-1) + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + .rsqrt_() + .unsqueeze(-1) + ) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) @@ -154,14 +177,14 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError('Adafactor does not support sparse gradients.') 
+ raise RuntimeError("Adafactor does not support sparse gradients.") state = self.state[p] grad_shape = grad.shape @@ -169,65 +192,73 @@ def step(self, closure=None): factored, use_first_moment = self._get_options(group, grad_shape) # State Initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 if use_first_moment: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(grad) + state["exp_avg"] = torch.zeros_like(grad) if factored: - state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) - state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).to(grad) else: - state['exp_avg_sq'] = torch.zeros_like(grad) + state["exp_avg_sq"] = torch.zeros_like(grad) - state['RMS'] = 0 + state["RMS"] = 0 else: if use_first_moment: - state['exp_avg'] = state['exp_avg'].to(grad) + state["exp_avg"] = state["exp_avg"].to(grad) if factored: - state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) - state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) else: - state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) p_data_fp32 = p.data if p.data.dtype in {torch.float16, torch.bfloat16}: p_data_fp32 = p_data_fp32.float() - state['step'] += 1 - state['RMS'] = self._rms(p_data_fp32) - group['lr'] = self._get_lr(group, state) + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + group["lr"] = self._get_lr(group, state) - beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) - update = (grad**2) + group['eps'][0] + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] if factored: - exp_avg_sq_row = state['exp_avg_sq_row'] - 
exp_avg_sq_col = state['exp_avg_sq_col'] + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] - exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) - exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) # Approximation of exponential moving average of square of gradient update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update.mul_(grad) else: - exp_avg_sq = state['exp_avg_sq'] + exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) update = exp_avg_sq.rsqrt().mul_(grad) update.div_( - (self._rms(update) / group['clip_threshold']).clamp_(min=1.0) + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0) ) - update.mul_(group['lr']) + update.mul_(group["lr"]) if use_first_moment: - exp_avg = state['exp_avg'] - exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) update = exp_avg - if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) p_data_fp32.add_(-update) diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py index 5056752776..a79b6c39da 100644 --- a/fairseq/optim/adagrad.py +++ b/fairseq/optim/adagrad.py @@ -5,10 +5,10 @@ import torch.optim -from . import register_optimizer, LegacyFairseqOptimizer +from . import LegacyFairseqOptimizer, register_optimizer -@register_optimizer('adagrad') +@register_optimizer("adagrad") class Adagrad(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) @@ -31,8 +31,8 @@ def optimizer_config(self): different learning rate. 
""" return { - 'lr': self.args.lr[0], - 'weight_decay': self.args.weight_decay, + "lr": self.args.lr[0], + "weight_decay": self.args.weight_decay, } @property diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py index 195e7a90d8..577a688166 100644 --- a/fairseq/optim/adamax.py +++ b/fairseq/optim/adamax.py @@ -6,10 +6,10 @@ import torch import torch.optim -from . import register_optimizer, LegacyFairseqOptimizer +from . import LegacyFairseqOptimizer, register_optimizer -@register_optimizer('adamax') +@register_optimizer("adamax") class FairseqAdamax(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) @@ -38,11 +38,11 @@ def optimizer_config(self): different learning rate. """ return { - 'lr': self.args.lr[0], - 'betas': eval(self.args.adamax_betas), - 'eps': self.args.adamax_eps, - 'weight_decay': self.args.weight_decay, - 'bias_correction': not self.args.no_bias_correction, + "lr": self.args.lr[0], + "betas": eval(self.args.adamax_betas), + "eps": self.args.adamax_eps, + "weight_decay": self.args.weight_decay, + "bias_correction": not self.args.no_bias_correction, } @@ -67,8 +67,15 @@ class Adamax(torch.optim.Optimizer): __ https://arxiv.org/abs/1412.6980 """ - def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, bias_correction=True): + def __init__( + self, + params, + lr=2e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + bias_correction=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -80,8 +87,13 @@ def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, - bias_correction=bias_correction) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + bias_correction=bias_correction, + ) super(Adamax, 
self).__init__(params, defaults) @property @@ -104,12 +116,12 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: - raise RuntimeError('Adamax does not support sparse gradients') + raise RuntimeError("Adamax does not support sparse gradients") p_data_fp32 = p.data if p.data.dtype in {torch.float16, torch.bfloat16}: @@ -119,18 +131,18 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_data_fp32) - state['exp_inf'] = torch.zeros_like(p_data_fp32) + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_inf"] = torch.zeros_like(p_data_fp32) else: - state['exp_avg'] = state['exp_avg'].to(p_data_fp32) - state['exp_inf'] = state['exp_inf'].to(p_data_fp32) + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_inf"] = state["exp_inf"].to(p_data_fp32) - exp_avg, exp_inf = state['exp_avg'], state['exp_inf'] - beta1, beta2 = group['betas'] - eps = group['eps'] + exp_avg, exp_inf = state["exp_avg"], state["exp_inf"] + beta1, beta2 = group["betas"] + eps = group["eps"] - state['step'] += 1 + state["step"] += 1 # Update biased first moment estimate. 
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) @@ -142,13 +154,15 @@ def step(self, closure=None): out=exp_inf, ) - step_size = group['lr'] - if group['bias_correction']: - bias_correction = 1 - beta1 ** state['step'] + step_size = group["lr"] + if group["bias_correction"]: + bias_correction = 1 - beta1 ** state["step"] step_size /= bias_correction - if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size) diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py index 9d1f0b2c05..c5da604220 100644 --- a/fairseq/optim/dynamic_loss_scaler.py +++ b/fairseq/optim/dynamic_loss_scaler.py @@ -3,11 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -class DynamicLossScaler(object): +class DynamicLossScaler(object): def __init__( - self, init_scale=2.**15, scale_factor=2., scale_window=2000, - tolerance=0.05, threshold=None, min_loss_scale=1e-4 + self, + init_scale=2.0 ** 15, + scale_factor=2.0, + scale_window=2000, + tolerance=0.05, + threshold=None, + min_loss_scale=1e-4, ): self.loss_scale = init_scale self.scale_factor = scale_factor @@ -36,7 +41,7 @@ def _decrease_loss_scale(self): def check_overflow(self, grad_norm): # detect inf and nan - if grad_norm == float('inf') or grad_norm != grad_norm: + if grad_norm == float("inf") or grad_norm != grad_norm: # overflow has occured prev_scale = self.loss_scale iter_since_rescale = self._iter - self._last_rescale_iter @@ -53,11 +58,13 @@ def check_overflow(self, grad_norm): # Use FloatingPointError as an uncommon error that parent # functions can safely catch to stop training. self.loss_scale = prev_scale - raise FloatingPointError(( - 'Minimum loss scale reached ({}). 
Your loss is probably exploding. ' - 'Try lowering the learning rate, using gradient clipping or ' - 'increasing the batch size.' - ).format(self.min_loss_scale)) + raise FloatingPointError( + ( + "Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try lowering the learning rate, using gradient clipping or " + "increasing the batch size." + ).format(self.min_loss_scale) + ) self._iter += 1 - raise OverflowError('setting loss scale to: ' + str(self.loss_scale)) + raise OverflowError("setting loss scale to: " + str(self.loss_scale)) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index b602e51818..8a10399a8b 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -109,8 +109,8 @@ def step(self, closure=None, scale=1.0): if self.supports_step_with_scale: self.optimizer.step(closure, scale=scale) else: - if scale != 1.: - self.multiply_grads(1. / scale) + if scale != 1.0: + self.multiply_grads(1.0 / scale) self.optimizer.step(closure) def zero_grad(self): diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index edb4f536ea..b622fbde44 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -3,41 +3,35 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from itertools import chain from collections import defaultdict +from itertools import chain import torch - from fairseq import optim, utils from .dynamic_loss_scaler import DynamicLossScaler class _FP16OptimizerMixin(object): - def __init__(self, *args, **kwargs): # forward __init__ call to the next class in mro(method resolution order) super().__init__(*args, **kwargs) - self._multiply_factor = 1. 
+ self._multiply_factor = 1.0 @property def has_flat_params(self): - return ( - torch.is_tensor(self.fp32_params) or - ( - isinstance(self.fp32_params, dict) and - all(torch.is_tensor(t) for t in self.fp32_params.values()) - ) + return torch.is_tensor(self.fp32_params) or ( + isinstance(self.fp32_params, dict) + and all(torch.is_tensor(t) for t in self.fp32_params.values()) ) @classmethod def build_fp32_params(cls, args, params, flatten=True): # create FP32 copy of parameters and grads if flatten: - is_pipeline_parallel = ( - getattr(args, 'pipeline_model_parallel', False) - and getattr(args, 'distributed_no_spawn', False) - ) + is_pipeline_parallel = getattr( + args, "pipeline_model_parallel", False + ) and getattr(args, "distributed_no_spawn", False) total_param_size = sum(p.data.numel() for p in params) devices = [torch.cuda.current_device()] if is_pipeline_parallel: @@ -45,19 +39,25 @@ def build_fp32_params(cls, args, params, flatten=True): fp32_params = {} for device in devices: if is_pipeline_parallel: - device_param_size = sum(p.data.numel() for p in params if p.device.index == device) + device_param_size = sum( + p.data.numel() for p in params if p.device.index == device + ) device_params = [p for p in params if p.device.index == device] else: device_param_size = total_param_size device_params = params - fp32_params[device] = device_params[0].new(0).float().new(device_param_size) + fp32_params[device] = ( + device_params[0].new(0).float().new(device_param_size) + ) offset = 0 for p in device_params: numel = p.data.numel() - fp32_params[device][offset:offset+numel].copy_(p.data.view(-1)) + fp32_params[device][offset : offset + numel].copy_(p.data.view(-1)) offset += numel fp32_params[device] = torch.nn.Parameter(fp32_params[device]) - fp32_params[device].grad = fp32_params[device].data.new(device_param_size) + fp32_params[device].grad = fp32_params[device].data.new( + device_param_size + ) return fp32_params else: fp32_params = [] @@ -71,7 +71,7 @@ def 
state_dict(self): """Return the optimizer's state dict.""" state_dict = self.fp32_optimizer.state_dict() if self.scaler is not None: - state_dict['loss_scale'] = self.scaler.loss_scale + state_dict["loss_scale"] = self.scaler.loss_scale return state_dict def load_state_dict(self, state_dict, optimizer_overrides=None): @@ -82,8 +82,8 @@ def load_state_dict(self, state_dict, optimizer_overrides=None): allows us to resume training from a checkpoint using a new set of optimizer args. """ - if 'loss_scale' in state_dict and self.scaler is not None: - self.scaler.loss_scale = state_dict['loss_scale'] + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) def backward(self, loss): @@ -111,9 +111,15 @@ def _sync_fp16_grads_to_fp32(self): device_params = device_params_dict[device] offset = 0 for p in device_params: - grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape) + grad_data = ( + p.grad.data + if p.grad is not None + else p.data.new_zeros(p.data.shape) + ) numel = grad_data.numel() - self.fp32_params[device].grad.data[offset:offset+numel].copy_(grad_data.view(-1)) + self.fp32_params[device].grad.data[ + offset : offset + numel + ].copy_(grad_data.view(-1)) offset += numel else: for p, p32 in zip(self.fp16_params, self.fp32_params): @@ -138,7 +144,11 @@ def _sync_fp32_params_to_fp16(self): offset = 0 for p in device_params: numel = p.data.numel() - p.data.copy_(self.fp32_params[device].data[offset:offset+numel].view_as(p.data)) + p.data.copy_( + self.fp32_params[device] + .data[offset : offset + numel] + .view_as(p.data) + ) offset += numel else: for p, p32 in zip(self.fp16_params, self.fp32_params): @@ -148,9 +158,9 @@ def _sync_fp32_params_to_fp16(self): def _unscale_grads(self): self._sync_fp16_grads_to_fp32() - if self._multiply_factor != 1.: + if self._multiply_factor != 1.0: 
self.fp32_optimizer.multiply_grads(self._multiply_factor) - self._multiply_factor = 1. + self._multiply_factor = 1.0 def multiply_grads(self, c): """Multiplies grads by a constant ``c``.""" @@ -160,7 +170,9 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm and updates dynamic loss scaler.""" self._sync_fp16_grads_to_fp32() - grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(0, aggregate_norm_fn) + grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) if self.scaler is not None: if grad_norm > max_norm > 0.0: @@ -177,8 +189,8 @@ def step(self, closure=None): """Performs a single optimization step.""" self._sync_fp16_grads_to_fp32() - if getattr(self, 'supports_step_with_scale', False): - self.fp32_optimizer.step(closure, scale=(1. / self._multiply_factor)) + if getattr(self, "supports_step_with_scale", False): + self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) else: self._unscale_grads() self.fp32_optimizer.step(closure) @@ -199,14 +211,14 @@ def zero_grad(self): for fp32_params in self.fp32_params.values(): fp32_params.grad.zero_() else: - raise("self.fp32_params must be a tensor or dict") + raise ("self.fp32_params must be a tensor or dict") else: for p32 in self.fp32_params: p32.grad.zero_() self._needs_sync = False if self.scaler is not None: - self._multiply_factor = 1. 
/ float(self.scaler.loss_scale) + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): @@ -220,24 +232,26 @@ def __init__(self, args, params, fp32_optimizer, fp32_params): self.fp32_optimizer = fp32_optimizer self.fp32_params = fp32_params - if getattr(args, 'fp16_scale_window', None) is None: + if getattr(args, "fp16_scale_window", None) is None: if len(args.update_freq) > 1: raise ValueError( - '--fp16-scale-window must be given explicitly when using a ' - 'custom --update-freq schedule' + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" ) - data_parallel_size = int(args.distributed_world_size / args.model_parallel_size) - scale_window = int(2**14 / data_parallel_size / args.update_freq[0]) + data_parallel_size = int( + args.distributed_world_size / args.model_parallel_size + ) + scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0]) else: scale_window = args.fp16_scale_window - if not getattr(args, 'bf16', False): + if not getattr(args, "bf16", False): self.scaler = DynamicLossScaler( init_scale=args.fp16_init_scale, scale_window=scale_window, tolerance=args.fp16_scale_tolerance, threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale + min_loss_scale=args.min_loss_scale, ) else: # disable loss scaling for bfloat16 @@ -250,8 +264,8 @@ def build_optimizer(cls, args, params): args (argparse.Namespace): fairseq args params (iterable): iterable of parameters to optimize """ - flatten = not getattr(args, 'fp16_no_flatten_grads', False) - if getattr(args, 'bf16', False): + flatten = not getattr(args, "fp16_no_flatten_grads", False) + if getattr(args, "bf16", False): flatten = False # mixed precision is faster on TPUs without flat grads fp32_params = cls.build_fp32_params(args, params, flatten=flatten) if flatten: @@ -260,8 +274,8 @@ def build_optimizer(cls, args, params): fp32_optimizer = optim.build_optimizer(args, 
fp32_params) if flatten and not fp32_optimizer.supports_flat_params: raise RuntimeError( - 'chosen optimizer does not support flat params, ' - 'please set --fp16-no-flatten-grads' + "chosen optimizer does not support flat params, " + "please set --fp16-no-flatten-grads" ) return cls(args, params, fp32_optimizer, fp32_params) @@ -285,11 +299,10 @@ def set_lr(self, lr): class _MemoryEfficientFP16OptimizerMixin(object): - def __init__(self, *args, **kwargs): # forward __init__ call to the next class in MRO (method resolution order) super().__init__(*args, **kwargs) - self._multiply_factor = 1. + self._multiply_factor = 1.0 @property def has_flat_params(self): @@ -299,7 +312,7 @@ def state_dict(self): """Return the optimizer's state dict.""" state_dict = self.wrapped_optimizer.state_dict() if self.scaler is not None: - state_dict['loss_scale'] = self.scaler.loss_scale + state_dict["loss_scale"] = self.scaler.loss_scale return state_dict def load_state_dict(self, state_dict, optimizer_overrides=None): @@ -310,8 +323,8 @@ def load_state_dict(self, state_dict, optimizer_overrides=None): allows us to resume training from a checkpoint using a new set of optimizer args. """ - if 'loss_scale' in state_dict and self.scaler is not None: - self.scaler.loss_scale = state_dict['loss_scale'] + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides) @@ -320,17 +333,17 @@ def load_state_dict(self, state_dict, optimizer_overrides=None): # params are FP16 while the optimizer state is FP32 and we don't want # to cast. A workaround is to manually copy back the original state # after the optimizer has been loaded. 
- if not getattr(self.optimizer, 'disable_mem_eff_fp16_loading_hack', False): + if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False): groups = self.optimizer.param_groups - saved_groups = state_dict['param_groups'] + saved_groups = state_dict["param_groups"] id_map = { old_id: p for old_id, p in zip( - chain(*(g['params'] for g in saved_groups)), - chain(*(g['params'] for g in groups)) + chain(*(g["params"] for g in saved_groups)), + chain(*(g["params"] for g in groups)), ) } - for k, v in state_dict['state'].items(): + for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] self.optimizer.state[param] = v @@ -347,9 +360,9 @@ def backward(self, loss): loss.backward() def _unscale_grads(self): - if self._multiply_factor != 1.: + if self._multiply_factor != 1.0: self.wrapped_optimizer.multiply_grads(self._multiply_factor) - self._multiply_factor = 1. + self._multiply_factor = 1.0 def multiply_grads(self, c): """Multiplies grads by a constant *c*.""" @@ -358,11 +371,13 @@ def multiply_grads(self, c): def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm and updates dynamic loss scaler.""" max_norm = float(max_norm) - grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(0, aggregate_norm_fn) + grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) if self.scaler is not None: grad_norm_cpu = float(grad_norm) - if grad_norm_cpu > max_norm > 0.: + if grad_norm_cpu > max_norm > 0.0: self._multiply_factor *= max_norm / grad_norm_cpu # detect overflow and adjust loss scale @@ -375,9 +390,9 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): def step(self, closure=None): """Performs a single optimization step.""" - if getattr(self, 'supports_step_with_scale', False): + if getattr(self, "supports_step_with_scale", False): # NOTE(msb) optimizer divides by scale factor - self.wrapped_optimizer.step(closure, scale=(1. 
/ self._multiply_factor)) + self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) else: self._unscale_grads() self.wrapped_optimizer.step(closure) @@ -389,12 +404,14 @@ def zero_grad(self): """Clears the gradients of all optimized parameters.""" self.wrapped_optimizer.zero_grad() if self.scaler is not None: - self._multiply_factor = 1. / float(self.scaler.loss_scale) + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) else: - self._multiply_factor = 1. + self._multiply_factor = 1.0 -class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer): +class MemoryEfficientFP16Optimizer( + _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer +): """ Wrap an *optimizer* to support FP16 (mixed precision) training. @@ -413,30 +430,32 @@ class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.Fai def __init__(self, args, params, optimizer): if not optimizer.supports_memory_efficient_fp16: raise ValueError( - 'Unsupported optimizer: {}'.format(optimizer.__class__.__name__) + "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) super().__init__(args) self.wrapped_optimizer = optimizer - if getattr(args, 'fp16_scale_window', None) is None: + if getattr(args, "fp16_scale_window", None) is None: if len(args.update_freq) > 1: raise ValueError( - '--fp16-scale-window must be given explicitly when using a ' - 'custom --update-freq schedule' + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" ) - data_parallel_size = int(args.distributed_world_size / args.model_parallel_size) - scale_window = 2**14 / data_parallel_size / args.update_freq[0] + data_parallel_size = int( + args.distributed_world_size / args.model_parallel_size + ) + scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0] else: scale_window = args.fp16_scale_window - if not getattr(args, 'bf16', False): + if not getattr(args, "bf16", False): self.scaler = 
DynamicLossScaler( init_scale=args.fp16_init_scale, scale_window=scale_window, tolerance=args.fp16_scale_tolerance, threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale + min_loss_scale=args.min_loss_scale, ) else: # disable loss scaling for bfloat16 diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py index 9024451aff..1780f9c0bb 100644 --- a/fairseq/optim/fused_adam.py +++ b/fairseq/optim/fused_adam.py @@ -21,6 +21,7 @@ def get_fused_adam_class(): # `--deprecated_fused_adam` option when building apex. global fused_adam_cuda import importlib + fused_adam_cuda = importlib.import_module("fused_adam_cuda") return FusedAdamV1 except ImportError: @@ -28,6 +29,7 @@ def get_fused_adam_class(): # fallback to the newer interface from apex.optimizers import FusedAdam as _FusedAdam # noqa from apex.multi_tensor_apply import multi_tensor_applier + if multi_tensor_applier.available: return FusedAdamV2 except ImportError: @@ -67,23 +69,32 @@ class FusedAdamV1(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, params, - lr=1e-3, bias_correction=True, - betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False, - weight_decay=0., max_grad_norm=0., amsgrad=False): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + eps_inside_sqrt=False, + weight_decay=0.0, + max_grad_norm=0.0, + amsgrad=False, + ): global fused_adam_cuda import importlib + fused_adam_cuda = importlib.import_module("fused_adam_cuda") if amsgrad: - raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") defaults = { - 'lr': lr, - 'bias_correction': bias_correction, - 'betas': betas, - 'eps': eps, - 'weight_decay': weight_decay, - 'max_grad_norm': max_grad_norm, + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "max_grad_norm": max_grad_norm, 
} super().__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 @@ -100,7 +111,7 @@ def supports_flat_params(self): def supports_step_with_scale(self): return True - def step(self, closure=None, grads=None, scale=1., grad_norms=None): + def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model @@ -130,23 +141,25 @@ def step(self, closure=None, grads=None, scale=1., grad_norms=None): grads_group = grads if grad_norms is None: - grad_norms = [None]*len(self.param_groups) + grad_norms = [None] * len(self.param_groups) - for group, grads_this_group, grad_norm in zip(self.param_groups, grads_group, grad_norms): + for group, grads_this_group, grad_norm in zip( + self.param_groups, grads_group, grad_norms + ): if grads_this_group is None: - grads_this_group = [None]*len(group['params']) + grads_this_group = [None] * len(group["params"]) # compute combined scale factor for this group combined_scale = scale - if group.get('max_grad_norm', 0) > 0: + if group.get("max_grad_norm", 0) > 0: # norm is in fact norm*scale - clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm'] + clip = ((grad_norm / scale) + 1e-6) / group["max_grad_norm"] if clip > 1: combined_scale = clip * scale - bias_correction = 1 if group.get('bias_correction', 1) else 0 + bias_correction = 1 if group.get("bias_correction", 1) else 0 - for p, grad in zip(group['params'], grads_this_group): + for p, grad in zip(group["params"], grads_this_group): # note: p.grad should not ever be set for correct # operation of mixed precision optimizer that sometimes # sends None gradients @@ -156,8 +169,8 @@ def step(self, closure=None, grads=None, scale=1., grad_norms=None): grad = p.grad.data if grad.is_sparse: raise RuntimeError( - 'FusedAdam does not support sparse gradients, ' - 'please consider SparseAdam instead' + "FusedAdam does not support sparse gradients, " + 
"please consider SparseAdam instead" ) p_data_fp32 = p.data.float() @@ -166,37 +179,39 @@ def step(self, closure=None, grads=None, scale=1., grad_norms=None): # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p_data_fp32) + state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) else: - state['exp_avg'] = state['exp_avg'].to(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].to(p_data_fp32) + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + beta1, beta2 = group["betas"] - state['step'] += 1 + state["step"] += 1 out_p = p.data with torch.cuda.device(p.device): - fused_adam_cuda.adam(p_data_fp32, - out_p, - exp_avg, - exp_avg_sq, - grad, - group['lr'], - beta1, - beta2, - group['eps'], - combined_scale, - state['step'], - self.eps_mode, - bias_correction, - group['weight_decay']) + fused_adam_cuda.adam( + p_data_fp32, + out_p, + exp_avg, + exp_avg_sq, + grad, + group["lr"], + beta1, + beta2, + group["eps"], + combined_scale, + state["step"], + self.eps_mode, + bias_correction, + group["weight_decay"], + ) return loss @@ -213,8 +228,10 @@ class FusedAdamV2(FusedAdam): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not hasattr(self, 'multi_tensor_adam'): - raise Exception('Apex installation is outdated. Please install an updated version of apex.') + if not hasattr(self, "multi_tensor_adam"): + raise Exception( + "Apex installation is outdated. Please install an updated version of apex." 
+ ) @property def supports_memory_efficient_fp16(self): @@ -224,89 +241,108 @@ def supports_memory_efficient_fp16(self): def supports_flat_params(self): return True - def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None): + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + ): """Performs a single optimization step.""" loss = None if closure is not None: loss = closure() for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_16, p_16, orig_p_16, m_16, v_16 = [], [], [], [], [] g_32, p_32, m_32, v_32 = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedAdam does not support sparse gradients, ' - 'please consider SparseAdam instead' + "FusedAdam does not support sparse gradients, " + "please consider SparseAdam instead" ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float) + state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float) + state["exp_avg_sq"] = torch.zeros_like( + p.data, dtype=torch.float + ) else: - state['exp_avg'] = state['exp_avg'].to(device=p.data.device, dtype=torch.float) - state['exp_avg_sq'] = state['exp_avg_sq'].to(device=p.data.device, 
dtype=torch.float) + state["exp_avg"] = state["exp_avg"].to( + device=p.data.device, dtype=torch.float + ) + state["exp_avg_sq"] = state["exp_avg_sq"].to( + device=p.data.device, dtype=torch.float + ) if p.dtype == torch.float16: g_16.append(p.grad.data.float()) p_16.append(p.data.float()) orig_p_16.append(p.data) - m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) elif p.dtype == torch.float32: g_32.append(p.grad.data) p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) else: - raise RuntimeError('FusedAdam only support fp16 and fp32.') + raise RuntimeError("FusedAdam only support fp16 and fp32.") with torch.cuda.device(p.device): - if(len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) for orig_p, p in zip(orig_p_16, p_16): orig_p.copy_(p.data) - if(len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) return loss + + except ImportError: pass diff --git a/fairseq/optim/fused_lamb.py 
b/fairseq/optim/fused_lamb.py index d48ecbc8e0..f4f2bdb0c6 100644 --- a/fairseq/optim/fused_lamb.py +++ b/fairseq/optim/fused_lamb.py @@ -3,10 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq.optim import register_optimizer, LegacyFairseqOptimizer +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer -@register_optimizer('lamb') +@register_optimizer("lamb") class FairseqLAMB(LegacyFairseqOptimizer): """LAMB optimizer.""" @@ -14,9 +14,10 @@ def __init__(self, args, params): super().__init__(args) try: from apex.optimizers import FusedLAMB + self._optimizer = FusedLAMB(params, **self.optimizer_config) except ImportError: - raise ImportError('Please install apex to use LAMB optimizer') + raise ImportError("Please install apex to use LAMB optimizer") @staticmethod def add_args(parser): @@ -39,10 +40,10 @@ def optimizer_config(self): different learning rate. """ return { - 'lr': self.args.lr[0], - 'betas': eval(self.args.lamb_betas), - 'eps': self.args.lamb_eps, - 'weight_decay': self.args.weight_decay, + "lr": self.args.lr[0], + "betas": eval(self.args.lamb_betas), + "eps": self.args.lamb_eps, + "weight_decay": self.args.weight_decay, } @property diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index 9a30195fab..7ca7826ed2 100644 --- a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -3,10 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import register_lr_scheduler, LegacyFairseqLRScheduler +from . 
import LegacyFairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler('fixed') +@register_lr_scheduler("fixed") class FixedSchedule(LegacyFairseqLRScheduler): """Decay the LR on a fixed schedule.""" @@ -14,11 +14,11 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) # set defaults - args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 + args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 self.lr = args.lr[0] if args.warmup_updates > 0: - self.warmup_factor = 1. / args.warmup_updates + self.warmup_factor = 1.0 / args.warmup_updates else: self.warmup_factor = 1 @@ -35,11 +35,11 @@ def add_args(parser): # fmt: on def state_dict(self): - return {'lr': self.lr} + return {"lr": self.lr} def load_state_dict(self, state_dict): - if 'lr' in state_dict: - self.lr = state_dict['lr'] + if "lr" in state_dict: + self.lr = state_dict["lr"] def get_next_lr(self, epoch): lrs = self.args.lr @@ -48,7 +48,9 @@ def get_next_lr(self, epoch): next_lr = lrs[min(epoch, len(lrs) - 1)] else: # annneal based on lr_shrink - next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal) + next_lr = lrs[-1] * self.args.lr_shrink ** ( + epoch + 1 - self.args.force_anneal + ) return next_lr def step(self, epoch, val_loss=None): diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index 73e8b170bc..ea8e647668 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -3,10 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import register_lr_scheduler, LegacyFairseqLRScheduler +from . 
import LegacyFairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler('polynomial_decay') +@register_lr_scheduler("polynomial_decay") class PolynomialDecaySchedule(LegacyFairseqLRScheduler): """Decay the LR on a fixed schedule.""" @@ -14,11 +14,11 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) # set defaults - args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 + args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 self.lr = args.lr[0] if args.warmup_updates > 0: - self.warmup_factor = 1. / args.warmup_updates + self.warmup_factor = 1.0 / args.warmup_updates else: self.warmup_factor = 1 self.end_learning_rate = args.end_learning_rate @@ -29,13 +29,23 @@ def __init__(self, args, optimizer): @staticmethod def add_args(parser): """Add arguments to the parser for this LR scheduler.""" - parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', - help='force annealing at specified epoch') - parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', - help='warmup the learning rate linearly for the first N updates') - parser.add_argument('--end-learning-rate', default=0.0, type=float) - parser.add_argument('--power', default=1.0, type=float) - parser.add_argument('--total-num-update', default=1000000, type=int) + parser.add_argument( + "--force-anneal", + "--fa", + type=int, + metavar="N", + help="force annealing at specified epoch", + ) + parser.add_argument( + "--warmup-updates", + default=0, + type=int, + metavar="N", + help="warmup the learning rate linearly for the first N updates", + ) + parser.add_argument("--end-learning-rate", default=0.0, type=float) + parser.add_argument("--power", default=1.0, type=float) + parser.add_argument("--total-num-update", default=1000000, type=int) def get_next_lr(self, epoch): lrs = self.args.lr @@ -64,7 +74,9 @@ def step_update(self, num_updates): else: warmup = self.args.warmup_updates lr_range = self.lr - self.end_learning_rate - pct_remaining = 1 - 
(num_updates - warmup) / (self.total_num_update - warmup) + pct_remaining = 1 - (num_updates - warmup) / ( + self.total_num_update - warmup + ) lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate self.optimizer.set_lr(lr) return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 5199b09a3e..82bb36efe9 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -5,10 +5,10 @@ import torch.optim.lr_scheduler -from . import register_lr_scheduler, LegacyFairseqLRScheduler +from . import LegacyFairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler('reduce_lr_on_plateau') +@register_lr_scheduler("reduce_lr_on_plateau") class ReduceLROnPlateau(LegacyFairseqLRScheduler): """ Decay the LR by a factor every time the validation loss plateaus. @@ -30,13 +30,16 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) if len(args.lr) > 1: raise ValueError( - 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.' - ' Consider --lr-scheduler=fixed instead.' + "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." + " Consider --lr-scheduler=fixed instead." 
) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink, - mode='max' if args.maximize_best_checkpoint_metric else 'min', - threshold=args.lr_threshold) + self.optimizer.optimizer, + patience=args.lr_patience, + factor=args.lr_shrink, + mode="max" if args.maximize_best_checkpoint_metric else "min", + threshold=args.lr_threshold, + ) warmup_end_lr = args.lr[0] # if no warm up, sets initial lr to be args.lr[0] if args.warmup_init_lr < 0: @@ -76,15 +79,15 @@ def add_args(parser): def state_dict(self): """Return the LR scheduler state dict.""" return { - 'best': self.lr_scheduler.best, - 'last_epoch': self.lr_scheduler.last_epoch, + "best": self.lr_scheduler.best, + "last_epoch": self.lr_scheduler.last_epoch, } def load_state_dict(self, state_dict): """Load an LR scheduler state dict.""" - self.lr_scheduler.best = state_dict['best'] - if 'last_epoch' in state_dict: - self.lr_scheduler.last_epoch = state_dict['last_epoch'] + self.lr_scheduler.best = state_dict["best"] + if "last_epoch" in state_dict: + self.lr_scheduler.last_epoch = state_dict["last_epoch"] def step(self, epoch, val_loss=None): """ @@ -103,7 +106,7 @@ def step_update(self, num_updates): # if there is warmup if self.args.warmup_updates > 0: if num_updates <= self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates*self.lr_step + self.lr = self.args.warmup_init_lr + num_updates * self.lr_step self.optimizer.set_lr(self.lr) else: if self.warmup_end is False: diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index 95c5576f20..c573237f11 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . 
import register_lr_scheduler, LegacyFairseqLRScheduler import math +from . import LegacyFairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler('tri_stage') + +@register_lr_scheduler("tri_stage") class TriStageLRSchedule(LegacyFairseqLRScheduler): """Tristage learning rate schedulr @@ -50,8 +51,8 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) if len(args.lr) > 1: raise ValueError( - 'Cannot use a fixed learning rate schedule with tri-stage lr.' - ' Consider --lr-scheduler=fixed instead.' + "Cannot use a fixed learning rate schedule with tri-stage lr." + " Consider --lr-scheduler=fixed instead." ) # calculate LR at each point @@ -65,7 +66,8 @@ def __init__(self, args, optimizer): self.decay_steps = args.decay_steps self.warmup_rate = ( - (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 + (self.peak_lr - self.init_lr) / self.warmup_steps + if self.warmup_steps != 0 else 0 ) self.decay_factor = -math.log(args.final_lr_scale) / args.decay_steps diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py index 67e1df65e1..0f3193f2b8 100644 --- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -5,10 +5,10 @@ import math -from . import register_lr_scheduler, LegacyFairseqLRScheduler +from . import LegacyFairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler('triangular') +@register_lr_scheduler("triangular") class TriangularSchedule(LegacyFairseqLRScheduler): """Assign LR based on a triangular cyclical schedule. @@ -19,13 +19,13 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) if len(args.lr) > 1: raise ValueError( - 'Cannot use a fixed learning rate schedule with triangular.' - ' Consider --lr-scheduler=fixed instead.' + "Cannot use a fixed learning rate schedule with triangular." + " Consider --lr-scheduler=fixed instead." 
) lr = args.lr[0] - assert args.max_lr > lr, 'max_lr must be more than lr' + assert args.max_lr > lr, "max_lr must be more than lr" self.min_lr = lr self.max_lr = args.max_lr self.stepsize = args.lr_period_updates // 2 diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py index b558f41ab0..8e34fb99a1 100644 --- a/fairseq/optim/sgd.py +++ b/fairseq/optim/sgd.py @@ -5,10 +5,10 @@ import torch.optim -from . import register_optimizer, LegacyFairseqOptimizer +from . import LegacyFairseqOptimizer, register_optimizer -@register_optimizer('sgd') +@register_optimizer("sgd") class SGD(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) @@ -33,9 +33,9 @@ def optimizer_config(self): different learning rate. """ return { - 'lr': self.args.lr[0], - 'momentum': self.args.momentum, - 'weight_decay': self.args.weight_decay, + "lr": self.args.lr[0], + "momentum": self.args.momentum, + "weight_decay": self.args.weight_decay, } @property diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index 8c508f41f2..a035a1c1f9 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -6,6 +6,7 @@ try: from fairscale.optim import OSS + _has_fairscale = True except ImportError: _has_fairscale = False @@ -14,8 +15,7 @@ def shard_(args, optimizer, group): if not _has_fairscale: raise ImportError( - '\n\nPlease install the fairscale package:' - '\n\n pip install fairscale' + "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" ) class FairseqOSS(OSS): @@ -26,9 +26,16 @@ def disable_mem_eff_fp16_loading_hack(self): def __getattr__(self, name): if name.startswith("supports") and hasattr(self.optim, name): return getattr(self.optim, name) - raise AttributeError("'FairseqOSS' object has no attribute {0!r}".format(name)) + raise AttributeError( + "'FairseqOSS' object has no attribute {0!r}".format(name) + ) torch_optimizer = optimizer.optimizer optim_cls = type(torch_optimizer) - - optimizer.optimizer = 
FairseqOSS(torch_optimizer.param_groups, optim_cls, group=group, **optimizer.optimizer_config) + + optimizer.optimizer = FairseqOSS( + torch_optimizer.param_groups, + optim_cls, + group=group, + **optimizer.optimizer_config + ) diff --git a/fairseq/options.py b/fairseq/options.py index 31ed28a80e..1a24fccaec 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -168,7 +168,9 @@ def parse_args_and_arch( args = parser.parse_args(input_args) extra = None # Post-process args. - if (hasattr(args, "batch_size_valid") and args.batch_size_valid is None) or not hasattr(args, "batch_size_valid"): + if ( + hasattr(args, "batch_size_valid") and args.batch_size_valid is None + ) or not hasattr(args, "batch_size_valid"): args.batch_size_valid = args.batch_size if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None: args.max_tokens_valid = args.max_tokens diff --git a/fairseq/pdb.py b/fairseq/pdb.py index f1ce3c46bc..1ba6ef0d33 100644 --- a/fairseq/pdb.py +++ b/fairseq/pdb.py @@ -9,7 +9,7 @@ import sys -__all__ = ['set_trace'] +__all__ = ["set_trace"] _stdin = [None] diff --git a/fairseq/quantization_utils.py b/fairseq/quantization_utils.py index a7f5ade9b3..69dd61d785 100644 --- a/fairseq/quantization_utils.py +++ b/fairseq/quantization_utils.py @@ -12,7 +12,7 @@ def quantize_model_scalar(model, args): - quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0) + quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) if quant_noise_scalar > 0: # quantize_model edits the model in place scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) @@ -20,12 +20,11 @@ def quantize_model_scalar(model, args): class Quantizer(object): - def __init__(self, config_path, max_epoch, max_update): try: import yaml except ImportError: - raise ImportError('Please install yaml with: pip install yaml') + raise ImportError("Please install yaml with: pip install yaml") # parse config if config_path: @@ -46,22 +45,23 @@ def __init__(self, config_path, 
max_epoch, max_update): num_iterations = len(self.layers_to_quantize) if max_epoch > 0: assert max_epoch % num_iterations == 0, ( - 'for iterative PQ, --max-epoch (={}) must be evenly divisible by ' - 'len(layers_to_quantize) (={})'.format(max_epoch, num_iterations) + "for iterative PQ, --max-epoch (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_epoch, num_iterations) ) self.epoch_schedule = max_epoch // num_iterations else: self.epoch_schedule = None if max_update > 0: assert max_update % num_iterations == 0, ( - 'for iterative PQ, --max-update (={}) must be evenly divisible by ' - 'len(layers_to_quantize) (={})'.format(max_update, num_iterations) + "for iterative PQ, --max-update (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_update, num_iterations) ) self.update_schedule = max_update // num_iterations else: self.update_schedule = None - assert (self.epoch_schedule is not None) ^ (self.update_schedule is not None), \ - 'for iterative PQ, cannot specify both --max-update and --max-epoch' + assert (self.epoch_schedule is not None) ^ ( + self.update_schedule is not None + ), "for iterative PQ, cannot specify both --max-update and --max-epoch" # 0 is a special value for quantization step, which will force # the first call to begin_epoch() to call step() @@ -80,7 +80,7 @@ def step(self): return logger.info( - 'quantizing model (step={}; layers_to_quantize[step]={})'.format( + "quantizing model (step={}; layers_to_quantize[step]={})".format( self.quantization_step, self.layers_to_quantize[self.quantization_step] ) ) @@ -92,7 +92,7 @@ def step(self): self.n_centroids_config, step=self.quantization_step, ) - logger.info('quantized layers: {}'.format(quantized_layers)) + logger.info("quantized layers: {}".format(quantized_layers)) logger.info(self.size_tracker) self.quantization_step += 1 @@ -125,18 +125,18 @@ def step_update(self, num_updates): def state_dict(self): return { - 'n_centroids_config': 
self.n_centroids_config, - 'block_sizes_config': self.block_sizes_config, - 'layers_to_quantize': self.layers_to_quantize, - 'epoch_schedule': self.epoch_schedule, - 'update_schedule': self.update_schedule, - 'quantization_step': self.quantization_step, + "n_centroids_config": self.n_centroids_config, + "block_sizes_config": self.block_sizes_config, + "layers_to_quantize": self.layers_to_quantize, + "epoch_schedule": self.epoch_schedule, + "update_schedule": self.update_schedule, + "quantization_step": self.quantization_step, } def load_state_dict(self, state_dict): - self.n_centroids_config = state_dict['n_centroids_config'] - self.block_sizes_config = state_dict['block_sizes_config'] - self.layers_to_quantize = state_dict['layers_to_quantize'] - self.epoch_schedule = state_dict['epoch_schedule'] - self.update_schedule = state_dict['update_schedule'] - self.quantization_step = state_dict['quantization_step'] + self.n_centroids_config = state_dict["n_centroids_config"] + self.block_sizes_config = state_dict["block_sizes_config"] + self.layers_to_quantize = state_dict["layers_to_quantize"] + self.epoch_schedule = state_dict["epoch_schedule"] + self.update_schedule = state_dict["update_schedule"] + self.quantization_step = state_dict["quantization_step"] diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 4468f2ad21..4be0cb5188 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -49,6 +49,7 @@ def build_scorer(args, tgt_dict): args.scoring = "sacrebleu" if args.scoring == "bleu": from fairseq.scoring import bleu + return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) return _build_scorer(args) diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index a45d44b003..7f8bd73bf5 100644 --- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -8,7 +8,6 @@ import sys import torch - from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer 
@@ -33,11 +32,12 @@ class SacrebleuScorer(BaseScorer): def __init__(self, args): super(SacrebleuScorer, self).__init__(args) import sacrebleu + self.sacrebleu = sacrebleu self.tokenizer = EvaluationTokenizer( tokenizer_type=self.args.sacrebleu_tokenizer, lowercase=self.args.sacrebleu_lowercase, - character_tokenization=self.args.sacrebleu_char_level + character_tokenization=self.args.sacrebleu_char_level, ) @staticmethod @@ -63,8 +63,9 @@ def result_string(self, order=4): if order != 4: raise NotImplementedError # tokenization and lowercasing are performed by self.tokenizer instead. - return self.sacrebleu.corpus_bleu(self.pred, [self.ref], - tokenize='none').format() + return self.sacrebleu.corpus_bleu( + self.pred, [self.ref], tokenize="none" + ).format() @register_scorer("bleu") @@ -78,7 +79,9 @@ def __init__(self, pad, eos, unk): try: from fairseq import libbleu except ImportError as e: - sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n") + sys.stderr.write( + "ERROR: missing libbleu.so. run `pip install --editable .`\n" + ) raise e self.C = ctypes.cdll.LoadLibrary(libbleu.__file__) diff --git a/fairseq/scoring/chrf.py b/fairseq/scoring/chrf.py index b932a43604..0d6cb77383 100644 --- a/fairseq/scoring/chrf.py +++ b/fairseq/scoring/chrf.py @@ -6,11 +6,12 @@ from fairseq.scoring import BaseScorer, register_scorer -@register_scorer('chrf') +@register_scorer("chrf") class ChrFScorer(BaseScorer): def __init__(self, args): super(ChrFScorer, self).__init__(args) import sacrebleu + self.sacrebleu = sacrebleu def add_string(self, ref, pred): diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py index c9d5218e1e..dbcc6e4d10 100644 --- a/fairseq/scoring/tokenizer.py +++ b/fairseq/scoring/tokenizer.py @@ -19,13 +19,18 @@ class EvaluationTokenizer(object): category) from text. character_tokenization (bool): tokenize the text to characters. 
""" + SPACE = chr(32) SPACE_ESCAPE = chr(9601) - ALL_TOKENIZER_TYPES = ['none', '13a', 'intl', 'zh', 'ja-mecab'] + ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"] - def __init__(self, tokenizer_type: str = '13a', lowercase: bool = False, - punctuation_removal: bool = False, - character_tokenization: bool = False): + def __init__( + self, + tokenizer_type: str = "13a", + lowercase: bool = False, + punctuation_removal: bool = False, + character_tokenization: bool = False, + ): from sacrebleu.tokenizers import TOKENIZERS assert tokenizer_type in self.ALL_TOKENIZER_TYPES @@ -38,8 +43,9 @@ def __init__(self, tokenizer_type: str = '13a', lowercase: bool = False, def remove_punctuation(cls, sent: str): """Remove punctuation based on Unicode category.""" return cls.SPACE.join( - t for t in sent.split(cls.SPACE) - if not all(unicodedata.category(c)[0] == 'P' for c in t) + t + for t in sent.split(cls.SPACE) + if not all(unicodedata.category(c)[0] == "P" for c in t) ) def tokenize(self, sent: str): diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 61c5fd950e..21efefd9b8 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from fairseq.scoring import register_scorer, BaseScorer +from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer @@ -15,7 +15,7 @@ def __init__(self, args): try: import editdistance as ed except ImportError: - raise ImportError('Please install editdistance to use WER scorer') + raise ImportError("Please install editdistance to use WER scorer") self.ed = ed self.tokenizer = EvaluationTokenizer( tokenizer_type=self.args.wer_tokenizer, @@ -52,7 +52,4 @@ def result_string(self): return f"WER: {self.score():.2f}" def score(self): - return ( - 100.0 * self.distance / self.ref_length if self.ref_length > 0 - else 0 - ) + return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 diff --git a/fairseq/search.py b/fairseq/search.py index 2c21b66bbd..d5ea68b4ce 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -4,17 +4,16 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Optional, List +from typing import List, Optional import torch import torch.nn as nn -from torch import Tensor - from fairseq.token_generation_constraints import ( ConstraintState, - UnorderedConstraintState, OrderedConstraintState, + UnorderedConstraintState, ) +from torch import Tensor class Search(nn.Module): diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 7ce797746f..ddfb67853f 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -34,7 +34,7 @@ def __init__( eos=None, symbols_to_strip_from_output=None, lm_model=None, - lm_weight=1.0 + lm_weight=1.0, ): """Generates translations of a given source sentence. 
@@ -69,7 +69,9 @@ def __init__( self.eos = tgt_dict.eos() if eos is None else eos self.symbols_to_strip_from_output = ( symbols_to_strip_from_output.union({self.eos}) - if symbols_to_strip_from_output is not None else {self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) self.vocab_size = len(tgt_dict) self.beam_size = beam_size # the max beam size is the dictionary size - 1, since we never select pad @@ -92,7 +94,9 @@ def __init__( # We only need to set src_lengths in LengthConstrainedBeamSearch. # As a module attribute, setting it would break in multithread # settings when the model is shared. - self.should_set_src_lengths = hasattr(self.search, 'needs_src_lengths') and self.search.needs_src_lengths + self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) self.model.eval() @@ -188,19 +192,21 @@ def _generate( ) net_input = sample["net_input"] - if 'src_tokens' in net_input: - src_tokens = net_input['src_tokens'] + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] # length of the source text being the character length except EndOfSentence and pad - src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) - elif 'source' in net_input: - src_tokens = net_input['source'] src_lengths = ( - net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1) - if net_input['padding_mask'] is not None + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None else torch.tensor(src_tokens.size(-1)).to(src_tokens) ) else: - raise Exception('expected src_tokens or source in net input') + raise Exception("expected src_tokens or source in net input") # bsz: total number of sentences in beam # Note that src_tokens may have more than 
2 dimenions (i.e. audio features) @@ -208,7 +214,9 @@ def _generate( beam_size = self.beam_size if constraints is not None and not self.search.supports_constraints: - raise NotImplementedError("Target-side constraints were provided, but search method doesn't support them") + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) # Initialize constraints, when active self.search.init_constraints(constraints, beam_size) @@ -421,10 +429,14 @@ def _generate( new_bsz = bsz - len(finalized_sents) # construct batch_idxs which holds indices of batches to keep for the next pass - batch_mask = torch.ones(bsz, dtype=torch.bool, device=cand_indices.device) + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) batch_mask[finalized_sents] = False # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it - batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask) + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) # Choose the subset of the hypothesized constraints that will continue self.search.prune_sentences(batch_idxs) @@ -519,10 +531,14 @@ def _generate( # sort by score descending for sent in range(len(finalized)): - scores = torch.tensor([float(elem["score"].item()) for elem in finalized[sent]]) + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) _, sorted_scores_indices = torch.sort(scores, descending=True) finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] - finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], finalized[sent]) + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) return finalized def _prefix_tokens( @@ -787,10 +803,7 @@ def max_decoder_positions(self): def forward_encoder(self, net_input: Dict[str, Tensor]): if not self.has_encoder(): return None - return [ - 
model.encoder.forward_torchscript(net_input) - for model in self.models - ] + return [model.encoder.forward_torchscript(net_input) for model in self.models] @torch.jit.export def forward_decoder( @@ -915,9 +928,12 @@ def generate(self, models, sample, **kwargs): src_tokens = sample["net_input"]["src_tokens"] bsz = src_tokens.shape[0] beam_size = self.beam_size - src_tokens, src_lengths, prev_output_tokens, tgt_tokens = self._prepare_batch_for_alignment( - sample, finalized - ) + ( + src_tokens, + src_lengths, + prev_output_tokens, + tgt_tokens, + ) = self._prepare_batch_for_alignment(sample, finalized) if any(getattr(m, "full_context_alignment", False) for m in self.model.models): attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) else: @@ -927,9 +943,9 @@ def generate(self, models, sample, **kwargs): ] if src_tokens.device != "cpu": - src_tokens = src_tokens.to('cpu') - tgt_tokens = tgt_tokens.to('cpu') - attn = [i.to('cpu') for i in attn] + src_tokens = src_tokens.to("cpu") + tgt_tokens = tgt_tokens.to("cpu") + attn = [i.to("cpu") for i in attn] # Process the attn matrix to extract hard alignments. for i in range(bsz * beam_size): diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py index c8ded1930c..411d4df444 100644 --- a/fairseq/sequence_scorer.py +++ b/fairseq/sequence_scorer.py @@ -3,9 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch import sys +import torch from fairseq import utils @@ -13,7 +13,11 @@ class SequenceScorer(object): """Scores the target for a given source sentence.""" def __init__( - self, tgt_dict, softmax_batch=None, compute_alignment=False, eos=None, + self, + tgt_dict, + softmax_batch=None, + compute_alignment=False, + eos=None, symbols_to_strip_from_output=None, ): self.pad = tgt_dict.pad() @@ -23,12 +27,14 @@ def __init__( self.compute_alignment = compute_alignment self.symbols_to_strip_from_output = ( symbols_to_strip_from_output.union({self.eos}) - if symbols_to_strip_from_output is not None else {self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) @torch.no_grad() def generate(self, models, sample, **kwargs): """Score a batch of translations.""" - net_input = sample['net_input'] + net_input = sample["net_input"] def batch_for_softmax(dec_out, target): # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) @@ -52,7 +58,7 @@ def gather_target_probs(probs, target): ) return probs - orig_target = sample['target'] + orig_target = sample["target"] # compute scores for each model in the ensemble avg_probs = None @@ -62,13 +68,15 @@ def gather_target_probs(probs, target): decoder_out = model(**net_input) attn = decoder_out[1] if len(decoder_out) > 1 else None if type(attn) is dict: - attn = attn.get('attn', None) + attn = attn.get("attn", None) batched = batch_for_softmax(decoder_out, orig_target) probs, idx = None, 0 for bd, tgt, is_single in batched: - sample['target'] = tgt - curr_prob = model.get_normalized_probs(bd, log_probs=len(models) == 1, sample=sample).data + sample["target"] = tgt + curr_prob = model.get_normalized_probs( + bd, log_probs=len(models) == 1, sample=sample + ).data if is_single: probs = gather_target_probs(curr_prob, orig_target) else: @@ -76,12 +84,14 @@ def gather_target_probs(probs, target): probs = curr_prob.new(orig_target.numel()) step = curr_prob.size(0) * 
curr_prob.size(1) end = step + idx - tgt_probs = gather_target_probs(curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt) + tgt_probs = gather_target_probs( + curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt + ) probs[idx:end] = tgt_probs.view(-1) idx = end - sample['target'] = orig_target + sample["target"] = orig_target - probs = probs.view(sample['target'].shape) + probs = probs.view(sample["target"].shape) if avg_probs is None: avg_probs = probs @@ -104,21 +114,24 @@ def gather_target_probs(probs, target): bsz = avg_probs.size(0) hypos = [] - start_idxs = sample['start_indices'] if 'start_indices' in sample else [0] * bsz + start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz for i in range(bsz): # remove padding from ref - ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \ - if sample['target'] is not None else None + ref = ( + utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad) + if sample["target"] is not None + else None + ) tgt_len = ref.numel() - avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] + avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len] score_i = avg_probs_i.sum() / tgt_len if avg_attn is not None: avg_attn_i = avg_attn[i] if self.compute_alignment: alignment = utils.extract_hard_alignment( avg_attn_i, - sample['net_input']['src_tokens'][i], - sample['target'][i], + sample["net_input"]["src_tokens"][i], + sample["target"][i], self.pad, self.eos, ) @@ -126,11 +139,15 @@ def gather_target_probs(probs, target): alignment = None else: avg_attn_i = alignment = None - hypos.append([{ - 'tokens': ref, - 'score': score_i, - 'attention': avg_attn_i, - 'alignment': alignment, - 'positional_scores': avg_probs_i, - }]) + hypos.append( + [ + { + "tokens": ref, + "score": score_i, + "attention": avg_attn_i, + "alignment": alignment, + "positional_scores": avg_probs_i, + } + ] + ) return hypos diff --git a/fairseq/tasks/audio_pretraining.py 
b/fairseq/tasks/audio_pretraining.py index 75bcfaa8db..ff2342afa9 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -8,7 +8,8 @@ import os import sys -from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset +from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset + from . import LegacyFairseqTask, register_task @@ -24,9 +25,7 @@ def __call__(self, label): @register_task("audio_pretraining") class AudioPretrainingTask(LegacyFairseqTask): - """ - - """ + """""" @staticmethod def add_args(parser): @@ -137,11 +136,11 @@ def max_positions(self): return (sys.maxsize, sys.maxsize) def filter_indices_by_size( - self, - indices, - dataset, - max_positions=None, - ignore_invalid_inputs=False, + self, + indices, + dataset, + max_positions=None, + ignore_invalid_inputs=False, ): # we do not need to filter by size in this task as dataloaders take care of this return indices diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py index a7ce1f1ad5..8f8fe7e2de 100644 --- a/fairseq/tasks/cross_lingual_lm.py +++ b/fairseq/tasks/cross_lingual_lm.py @@ -3,31 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from collections import OrderedDict import itertools import logging import os +from collections import OrderedDict import numpy as np - -from fairseq import tokenizer -from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary - -from fairseq.data import ( - Dictionary, - ConcatDataset, - data_utils, - TokenBlockDataset, -) +from fairseq import tokenizer, utils +from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset +from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset -from fairseq.tasks import register_task, LegacyFairseqTask -from fairseq import utils +from fairseq.tasks import LegacyFairseqTask, register_task + logger = logging.getLogger(__name__) -@register_task('cross_lingual_lm') +@register_task("cross_lingual_lm") class CrossLingualLMTask(LegacyFairseqTask): """ Task for training cross-lingual language models. 
@@ -41,17 +34,29 @@ class CrossLingualLMTask(LegacyFairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', help='colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner') - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments' - ' per sample') - parser.add_argument('--monolingual-langs', default='en', type=str, - help='comma separated list of languages for which we' - ' want to train XLM on') - parser.add_argument('--shuffle', action='store_true', - help='shuffle each monolingual dataset while' - ' training') + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" " per sample", + ) + parser.add_argument( + "--monolingual-langs", + default="en", + type=str, + help="comma separated list of languages for which we" + " want to train XLM on", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle each monolingual dataset while" " training", + ) def __init__(self, args, dictionary): super().__init__(args) @@ -60,16 +65,13 @@ def __init__(self, args, dictionary): self.distributed_world_size = args.distributed_world_size self.langs2id = self._lang_to_id(args.monolingual_langs) - def _lang_to_id( - self, - languages: str - ): + def _lang_to_id(self, languages: str): """ Build a map from languages to ids. These ids are used as segment labels for cross-lingual LM training. 
""" lang2id = {} - langs = [l.strip() for l in languages.split(',')] + langs = [l.strip() for l in languages.split(",")] for id, lang in enumerate(langs): lang2id[lang] = id return lang2id @@ -79,10 +81,14 @@ def load_dictionary(cls, filename): return MaskedLMDictionary.load(filename) @classmethod - def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): d = MaskedLMDictionary() for filename in filenames: - Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers) + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) return d @@ -93,8 +99,8 @@ def target_dictionary(self): @classmethod def setup_task(cls, args, **kwargs): """Setup the task.""" - dictionary = MaskedLMDictionary.load(os.path.join(args.data, 'dict.txt')) - logger.info('dictionary: {} types'.format(len(dictionary))) + dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) return cls(args, dictionary) def _load_single_lang_dataset(self, split, epoch): @@ -105,27 +111,36 @@ def _load_single_lang_dataset(self, split, epoch): data_path = paths[(epoch - 1) % len(paths)] for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') + split_k = split + (str(k) if k > 0 else "") path = os.path.join(data_path, split_k) - ds = data_utils.load_indexed_dataset(path, self.dictionary, self.args.dataset_impl) + ds = data_utils.load_indexed_dataset( + path, self.dictionary, self.args.dataset_impl + ) if ds is None: if k > 0: break else: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) # Since we append each block with the classification_token, 
# we need to effectively create blocks of length # tokens_per_sample-1 loaded_datasets.append( TokenBlockDataset( - ds, ds.sizes, self.args.tokens_per_sample - 1, - pad=self.dictionary.pad(), eos=self.dictionary.eos(), + ds, + ds.sizes, + self.args.tokens_per_sample - 1, + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), ) ) - logger.info('{} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1]))) + logger.info( + "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1])) + ) if len(loaded_datasets) == 1: dataset = loaded_datasets[0] @@ -146,9 +161,11 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): for lang in self.langs2id.keys(): # Datasets are expected to be in "split.lang" format (Eg: train.en) - language_split = '{}.{}'.format(split, lang) + language_split = "{}.{}".format(split, lang) - block_dataset, sizes = self._load_single_lang_dataset(split=language_split, epoch=epoch) + block_dataset, sizes = self._load_single_lang_dataset( + split=language_split, epoch=epoch + ) dataset_map[lang] = MaskedLMDataset( dataset=block_dataset, @@ -158,13 +175,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): mask_idx=self.dictionary.mask(), classif_token_idx=self.dictionary.eos(), sep_token_idx=self.dictionary.eos(), - shuffle=getattr(self.args, 'shuffle', False), + shuffle=getattr(self.args, "shuffle", False), has_pairs=False, segment_id=self.langs2id[lang], seed=self.seed, ) self.datasets[split] = MultiCorpusSampledDataset(dataset_map) - logger.info('{} {} {} examples'.format( - utils.split_paths(self.args.data)[epoch - 1], split, len(self.datasets[split])) + logger.info( + "{} {} {} examples".format( + utils.split_paths(self.args.data)[epoch - 1], + split, + len(self.datasets[split]), + ) ) diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index ea6db45c75..3e88bf0ed0 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -6,24 +6,24 @@ import logging import 
os +from fairseq import utils from fairseq.data import ( - data_utils, - Dictionary, AppendTokenDataset, DenoisingDataset, + Dictionary, PrependTokenDataset, StripTokenDataset, TokenBlockDataset, + data_utils, ) from fairseq.data.encoders.utils import get_whole_word_mask -from fairseq.tasks import register_task, LegacyFairseqTask -from fairseq import utils +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('denoising') +@register_task("denoising") class DenoisingTask(LegacyFairseqTask): """ Denoising task for applying sequence to sequence denoising. (ie. BART) @@ -32,58 +32,88 @@ class DenoisingTask(LegacyFairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', help='path to data directory') - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments' - ' per sample for dataset') + parser.add_argument("data", help="path to data directory") parser.add_argument( - '--sample-break-mode', default="complete_doc", type=str, - help='mode for breaking sentence', + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" + " per sample for dataset", ) parser.add_argument( - '--mask', default=0.0, type=float, - help='fraction of words/subwords that will be masked', + "--sample-break-mode", + default="complete_doc", + type=str, + help="mode for breaking sentence", ) parser.add_argument( - '--mask-random', default=0.0, type=float, - help='instead of using [MASK], use random token this often' + "--mask", + default=0.0, + type=float, + help="fraction of words/subwords that will be masked", ) parser.add_argument( - '--insert', default=0.0, type=float, - help='insert this percentage of additional random tokens', + "--mask-random", + default=0.0, + type=float, + help="instead of using [MASK], use random token this often", ) parser.add_argument( 
- '--permute', default=0.0, type=float, - help='take this proportion of subwords and permute them', + "--insert", + default=0.0, + type=float, + help="insert this percentage of additional random tokens", ) parser.add_argument( - '--rotate', default=0.5, type=float, - help='rotate this proportion of inputs', + "--permute", + default=0.0, + type=float, + help="take this proportion of subwords and permute them", ) parser.add_argument( - '--poisson-lambda', default=3.0, type=float, - help='randomly shuffle sentences for this proportion of inputs' + "--rotate", + default=0.5, + type=float, + help="rotate this proportion of inputs", ) parser.add_argument( - '--permute-sentences', default=0.0, type=float, - help='shuffle this proportion of sentences in all inputs' + "--poisson-lambda", + default=3.0, + type=float, + help="randomly shuffle sentences for this proportion of inputs", ) parser.add_argument( - '--mask-length', default="subword", type=str, - choices=['subword', 'word', 'span-poisson'], - help='mask length to choose' + "--permute-sentences", + default=0.0, + type=float, + help="shuffle this proportion of sentences in all inputs", ) parser.add_argument( - '--replace-length', default=-1, type=int, - help='when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)' + "--mask-length", + default="subword", + type=str, + choices=["subword", "word", "span-poisson"], + help="mask length to choose", ) parser.add_argument( - '--max-source-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the source sequence' + "--replace-length", + default=-1, + type=int, + help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)", ) parser.add_argument( - '--max-target-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the target sequence' + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + 
"--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", ) def __init__(self, args, dictionary): @@ -92,15 +122,14 @@ def __init__(self, args, dictionary): self.seed = args.seed # add mask token - self.mask_idx = self.dictionary.add_symbol('<mask>') + self.mask_idx = self.dictionary.add_symbol("<mask>") @classmethod def setup_task(cls, args, **kwargs): - """Setup the task. - """ - dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt')) - logger.info('dictionary: {} types'.format(len(dictionary))) - if not hasattr(args, 'shuffle_instance'): + """Setup the task.""" + dictionary = Dictionary.load(os.path.join(args.data, "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(args, "shuffle_instance"): args.shuffle_instance = False return cls(args, dictionary) @@ -122,32 +151,42 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, ) if dataset is None: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) dataset = StripTokenDataset(dataset, self.dictionary.eos()) # create continuous blocks of tokens dataset = TokenBlockDataset( - dataset, - dataset.sizes, - self.args.tokens_per_sample - 2, # one less for <s> and one for </s> - pad=self.dictionary.pad(), - eos=self.dictionary.eos(), - break_mode=self.args.sample_break_mode, - document_sep_len=0 + dataset, + dataset.sizes, + self.args.tokens_per_sample - 2, # one less for <s> and one for </s> + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.args.sample_break_mode, + document_sep_len=0, ) # prepend beginning-of-sentence token (<s>, equiv.
to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) - mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \ - if self.args.mask_length != 'subword' else None + mask_whole_words = ( + get_whole_word_mask(self.args, self.source_dictionary) + if self.args.mask_length != "subword" + else None + ) self.datasets[split] = DenoisingDataset( - dataset, dataset.sizes, self.dictionary, self.mask_idx, - mask_whole_words, shuffle=self.args.shuffle_instance, - seed=self.seed, args=self.args + dataset, + dataset.sizes, + self.dictionary, + self.mask_idx, + mask_whole_words, + shuffle=self.args.shuffle_instance, + seed=self.seed, + args=self.args, ) logger.info( "Split: {0}, Loaded {1} samples of denoising_dataset".format( diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index a8bfaa532d..0a96aeb1ea 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -10,9 +10,10 @@ import torch from fairseq import metrics, search, tokenizer, utils -from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators, encoders +from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators from fairseq.dataclass.utils import gen_parser_from_dataclass + logger = logging.getLogger(__name__) @@ -358,7 +359,7 @@ def build_generator( ) elif prefix_allowed_tokens_fn: search_strategy = search.PrefixConstrainedBeamSearch( - self.target_dictionary, prefix_allowed_tokens_fn + self.target_dictionary, prefix_allowed_tokens_fn ) else: search_strategy = search.BeamSearch(self.target_dictionary) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 5477c28aa9..8792c6481c 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -27,7 +27,7 @@ ) from fairseq.data.indexed_dataset import get_available_dataset_impl from 
fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.dataclass import FairseqDataclass, ChoiceEnum +from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.tasks import FairseqTask, register_task from omegaconf import II diff --git a/fairseq/tasks/legacy_masked_lm.py b/fairseq/tasks/legacy_masked_lm.py index 4e0390cdca..9754976549 100644 --- a/fairseq/tasks/legacy_masked_lm.py +++ b/fairseq/tasks/legacy_masked_lm.py @@ -8,26 +8,18 @@ import os import numpy as np - -from fairseq import tokenizer -from fairseq.data import ( - ConcatDataset, - indexed_dataset, - data_utils, -) - -from fairseq.data import Dictionary +from fairseq import tokenizer, utils +from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset from fairseq.data.legacy.block_pair_dataset import BlockPairDataset from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset from fairseq.data.legacy.masked_lm_dictionary import BertDictionary -from fairseq.tasks import register_task, LegacyFairseqTask -from fairseq import utils +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('legacy_masked_lm') +@register_task("legacy_masked_lm") class LegacyMaskedLMTask(LegacyFairseqTask): """ Task for training Masked LM (BERT) model. 
@@ -38,13 +30,22 @@ class LegacyMaskedLMTask(LegacyFairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', help='colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner') - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments' - ' per sample for BERT dataset') - parser.add_argument('--break-mode', default="doc", type=str, help='mode for breaking sentence') - parser.add_argument('--shuffle-dataset', action='store_true', default=False) + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" + " per sample for BERT dataset", + ) + parser.add_argument( + "--break-mode", default="doc", type=str, help="mode for breaking sentence" + ) + parser.add_argument("--shuffle-dataset", action="store_true", default=False) def __init__(self, args, dictionary): super().__init__(args) @@ -56,10 +57,14 @@ def load_dictionary(cls, filename): return BertDictionary.load(filename) @classmethod - def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): d = BertDictionary() for filename in filenames: - Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers) + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) return d @@ -69,12 +74,11 @@ def target_dictionary(self): @classmethod def setup_task(cls, args, **kwargs): - """Setup the task. 
- """ + """Setup the task.""" paths = utils.split_paths(args.data) assert len(paths) > 0 - dictionary = BertDictionary.load(os.path.join(paths[0], 'dict.txt')) - logger.info('dictionary: {} types'.format(len(dictionary))) + dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) return cls(args, dictionary) @@ -92,7 +96,7 @@ def load_dataset(self, split, epoch=1, combine=False): logger.info("data_path", data_path) for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') + split_k = split + (str(k) if k > 0 else "") path = os.path.join(data_path, split_k) ds = indexed_dataset.make_dataset( path, @@ -105,7 +109,9 @@ def load_dataset(self, split, epoch=1, combine=False): if k > 0: break else: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) with data_utils.numpy_seed(self.seed + k): loaded_datasets.append( @@ -119,7 +125,9 @@ def load_dataset(self, split, epoch=1, combine=False): ) ) - logger.info('{} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1]))) + logger.info( + "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1])) + ) if not combine: break diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 10b234a96b..56086f5e81 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -7,64 +7,99 @@ import os import numpy as np - +from fairseq import utils from fairseq.data import ( - data_utils, Dictionary, IdDataset, MaskTokensDataset, NestedDictionaryDataset, NumelDataset, NumSamplesDataset, - RightPadDataset, PrependTokenDataset, + RightPadDataset, SortDataset, TokenBlockDataset, + data_utils, ) -from fairseq.tasks import register_task, LegacyFairseqTask -from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.data.encoders.utils import get_whole_word_mask -from 
fairseq import utils +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('masked_lm') +@register_task("masked_lm") class MaskedLMTask(LegacyFairseqTask): """Task for training masked language models (e.g., BERT, RoBERTa).""" @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', help='colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner') - parser.add_argument('--sample-break-mode', default='complete', - choices=['none', 'complete', 'complete_doc', 'eos'], - help='If omitted or "none", fills each sample with tokens-per-sample ' - 'tokens. If set to "complete", splits samples only at the end ' - 'of sentence, but may include multiple sentences per sample. ' - '"complete_doc" is similar but respects doc boundaries. ' - 'If set to "eos", includes only one sentence per sample.') - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments ' - 'per sample for BERT dataset') - parser.add_argument('--mask-prob', default=0.15, type=float, - help='probability of replacing a token with mask') - parser.add_argument('--leave-unmasked-prob', default=0.1, type=float, - help='probability that a masked token is unmasked') - parser.add_argument('--random-token-prob', default=0.1, type=float, - help='probability of replacing a token with a random token') - parser.add_argument('--freq-weighted-replacement', default=False, action='store_true', - help='sample random replacement words based on word frequencies') - parser.add_argument('--mask-whole-words', default=False, action='store_true', - help='mask whole words; you may also want to set --bpe') - parser.add_argument('--shorten-method', default='none', - choices=['none', 'truncate', 'random_crop'], - help='if not none, shorten sequences 
that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-list', default='', - help='comma-separated list of dataset splits to apply shortening to, ' - 'e.g., "train,valid" (default: all dataset splits)') + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--sample-break-mode", + default="complete", + choices=["none", "complete", "complete_doc", "eos"], + help='If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.', + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) + parser.add_argument( + "--mask-prob", + default=0.15, + type=float, + help="probability of replacing a token with mask", + ) + parser.add_argument( + "--leave-unmasked-prob", + default=0.1, + type=float, + help="probability that a masked token is unmasked", + ) + parser.add_argument( + "--random-token-prob", + default=0.1, + type=float, + help="probability of replacing a token with a random token", + ) + parser.add_argument( + "--freq-weighted-replacement", + default=False, + action="store_true", + help="sample random replacement words based on word frequencies", + ) + parser.add_argument( + "--mask-whole-words", + default=False, + action="store_true", + help="mask whole words; you may also want to set --bpe", + ) + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + 
help="comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) def __init__(self, args, dictionary): super().__init__(args) @@ -72,14 +107,14 @@ def __init__(self, args, dictionary): self.seed = args.seed # add mask token - self.mask_idx = dictionary.add_symbol('<mask>') + self.mask_idx = dictionary.add_symbol("<mask>") @classmethod def setup_task(cls, args, **kwargs): paths = utils.split_paths(args.data) assert len(paths) > 0 - dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt')) - logger.info('dictionary: {} types'.format(len(dictionary))) + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) return cls(args, dictionary) def load_dataset(self, split, epoch=1, combine=False, **kwargs): @@ -100,7 +135,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, ) if dataset is None: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) dataset = maybe_shorten_dataset( dataset, @@ -120,14 +157,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): eos=self.source_dictionary.eos(), break_mode=self.args.sample_break_mode, ) - logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path)) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv.
to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) # create masked input and targets - mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \ - if self.args.mask_whole_words else None + mask_whole_words = ( + get_whole_word_mask(self.args, self.source_dictionary) + if self.args.mask_whole_words + else None + ) src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( dataset, @@ -148,20 +188,20 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): self.datasets[split] = SortDataset( NestedDictionaryDataset( { - 'id': IdDataset(), - 'net_input': { - 'src_tokens': RightPadDataset( + "id": IdDataset(), + "net_input": { + "src_tokens": RightPadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), ), - 'src_lengths': NumelDataset(src_dataset, reduce=False), + "src_lengths": NumelDataset(src_dataset, reduce=False), }, - 'target': RightPadDataset( + "target": RightPadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), ), - 'nsentences': NumSamplesDataset(), - 'ntokens': NumelDataset(src_dataset, reduce=True), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), }, sizes=[src_dataset.sizes], ), @@ -179,17 +219,17 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), - break_mode='eos', + break_mode="eos", ), pad_idx=self.source_dictionary.pad(), ) src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) src_dataset = NestedDictionaryDataset( { - 'id': IdDataset(), - 'net_input': { - 'src_tokens': src_dataset, - 'src_lengths': NumelDataset(src_dataset, reduce=False), + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), }, }, sizes=src_lengths, diff --git a/fairseq/tasks/multilingual_denoising.py
b/fairseq/tasks/multilingual_denoising.py index 18ee717fff..d1c914917f 100644 --- a/fairseq/tasks/multilingual_denoising.py +++ b/fairseq/tasks/multilingual_denoising.py @@ -7,62 +7,74 @@ import os import numpy as np - from fairseq.data import ( - data_utils, - Dictionary, AppendTokenDataset, ConcatDataset, DenoisingDataset, + Dictionary, PrependTokenDataset, ResamplingDataset, SortDataset, TokenBlockDataset, + data_utils, ) -from .denoising import DenoisingTask from fairseq.data.encoders.utils import get_whole_word_mask from fairseq.tasks import register_task +from .denoising import DenoisingTask + logger = logging.getLogger(__name__) -@register_task('multilingual_denoising') +@register_task("multilingual_denoising") class MultilingualDenoisingTask(DenoisingTask): - @staticmethod def add_args(parser): DenoisingTask.add_args(parser) - parser.add_argument('--multilang-sampling-alpha', type=float, default=1.0, - help='smoothing alpha for sample ratios across multiple datasets') - parser.add_argument('--add-lang-token', default=False, action='store_true') - parser.add_argument('--langs', type=str, help="language ids we are considering", default=None) - parser.add_argument('--no-whole-word-mask-langs', type=str, default='', metavar='N', - help='languages without spacing between words dont support whole word masking') + parser.add_argument( + "--multilang-sampling-alpha", + type=float, + default=1.0, + help="smoothing alpha for sample ratios across multiple datasets", + ) + parser.add_argument("--add-lang-token", default=False, action="store_true") + parser.add_argument( + "--langs", type=str, help="language ids we are considering", default=None + ) + parser.add_argument( + "--no-whole-word-mask-langs", + type=str, + default="", + metavar="N", + help="languages without spacing between words dont support whole word masking", + ) @classmethod def setup_task(cls, args, **kwargs): - """Setup the task. 
- """ - paths = args.data.split(':') + """Setup the task.""" + paths = args.data.split(":") assert len(paths) > 0 - dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt')) + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) data_path = paths[0] if args.langs is None: - languages = sorted([ - name for name in os.listdir(data_path) - if os.path.isdir(os.path.join(data_path, name)) - ]) + languages = sorted( + [ + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ] + ) else: - languages = args.langs.split(',') + languages = args.langs.split(",") if args.add_lang_token: for lang in languages: - dictionary.add_symbol('[{}]'.format(lang)) + dictionary.add_symbol("[{}]".format(lang)) logger.info("dictionary: {} types".format(len(dictionary))) - if not hasattr(args, 'shuffle_instance'): + if not hasattr(args, "shuffle_instance"): args.shuffle_instance = False return cls(args, dictionary) @@ -72,7 +84,7 @@ def __init__(self, args, dictionary): self.seed = args.seed # add mask token - self.mask_idx = self.dictionary.add_symbol('<mask>') + self.mask_idx = self.dictionary.add_symbol("<mask>") self.langs = args.langs self.args = args @@ -92,30 +104,32 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = self.args.data.split(':') + paths = self.args.data.split(":") assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) if self.langs is None: - languages = sorted([ - name for name in os.listdir(data_path) - if os.path.isdir(os.path.join(data_path, name)) - ]) + languages = sorted( + [ + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ] + ) else: - languages = self.langs.split(',') + languages = self.langs.split(",") for name in languages: p = os.path.join(data_path, name) assert os.path.exists(p), "data not found: {}".format(p)
logger.info("Training on {0} languages: {1}".format(len(languages), languages)) - logger.info("Language to id mapping: ", { - lang: id for id, lang in enumerate(languages) - } + logger.info( + "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)} ) mask_whole_words = get_whole_word_mask(self.args, self.dictionary) - language_without_segmentations = self.args.no_whole_word_mask_langs.split(',') + language_without_segmentations = self.args.no_whole_word_mask_langs.split(",") lang_datasets = [] for language in languages: split_path = os.path.join(data_path, language, split) @@ -127,10 +141,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, ) if dataset is None: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) - end_token = self.source_dictionary.index('[{}]'.format(language)) \ - if self.args.add_lang_token else self.source_dictionary.eos() + end_token = ( + self.source_dictionary.index("[{}]".format(language)) + if self.args.add_lang_token + else self.source_dictionary.eos() + ) # create continuous blocks of tokens dataset = TokenBlockDataset( @@ -141,13 +160,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): eos=end_token, break_mode=self.args.sample_break_mode, ) - logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path)) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv.
to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) dataset = AppendTokenDataset(dataset, end_token) - lang_mask_whole_words = mask_whole_words if language not in language_without_segmentations else None + lang_mask_whole_words = ( + mask_whole_words + if language not in language_without_segmentations + else None + ) lang_dataset = DenoisingDataset( dataset, dataset.sizes, @@ -157,7 +180,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): shuffle=self.args.shuffle_instance, seed=self.seed, args=self.args, - eos=None if not self.args.add_lang_token else self.source_dictionary.index('[{}]'.format(language)), + eos=None + if not self.args.add_lang_token + else self.source_dictionary.index("[{}]".format(language)), ) lang_datasets.append(lang_dataset) @@ -166,7 +191,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dtype=float, ) logger.info( - 'loaded total {} blocks for all languages'.format( + "loaded total {} blocks for all languages".format( int(dataset_lengths.sum()), ) ) @@ -174,17 +199,21 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): # For train subset, additionally up or down sample languages. 
sample_probs = self._get_sample_prob(dataset_lengths) logger.info( - "Sample probability by language: {}".format({ - lang: "{0:.4f}".format(sample_probs[id]) - for id, lang in enumerate(languages) - }) + "Sample probability by language: {}".format( + { + lang: "{0:.4f}".format(sample_probs[id]) + for id, lang in enumerate(languages) + } + ) ) size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths logger.info( - "Up/Down Sampling ratio by language: {}".format({ - lang: "{0:.2f}".format(size_ratio[id]) - for id, lang in enumerate(languages) - }) + "Up/Down Sampling ratio by language: {}".format( + { + lang: "{0:.2f}".format(size_ratio[id]) + for id, lang in enumerate(languages) + } + ) ) resampled_lang_datasets = [ @@ -204,13 +233,13 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dataset = ConcatDataset(lang_datasets) lang_splits = [split] for lang_id, lang_dataset in enumerate(lang_datasets): - split_name = split + '_' + languages[lang_id] + split_name = split + "_" + languages[lang_id] lang_splits.append(split_name) self.datasets[split_name] = lang_dataset if split in self.args.valid_subset: self.args.valid_subset = self.args.valid_subset.replace( - split, ','.join(lang_splits) + split, ",".join(lang_splits) ) with data_utils.numpy_seed(self.args.seed + epoch): diff --git a/fairseq/tasks/multilingual_masked_lm.py b/fairseq/tasks/multilingual_masked_lm.py index 110e580a73..9e6ce4b8a2 100644 --- a/fairseq/tasks/multilingual_masked_lm.py +++ b/fairseq/tasks/multilingual_masked_lm.py @@ -8,12 +8,10 @@ import numpy as np import torch - +from fairseq import utils from fairseq.data import ( - data_utils, - Dictionary, - encoders, ConcatDataset, + Dictionary, IdDataset, MaskTokensDataset, NestedDictionaryDataset, @@ -25,45 +23,79 @@ ResamplingDataset, SortDataset, TokenBlockDataset, + data_utils, + encoders, ) -from fairseq.tasks import register_task, LegacyFairseqTask -from fairseq import utils +from fairseq.tasks import 
LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('multilingual_masked_lm') +@register_task("multilingual_masked_lm") class MultiLingualMaskedLMTask(LegacyFairseqTask): """Task for training masked language models (e.g., BERT, RoBERTa).""" @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', help='colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner') - parser.add_argument('--sample-break-mode', default='complete', - choices=['none', 'complete', 'complete_doc', 'eos'], - help='If omitted or "none", fills each sample with tokens-per-sample ' - 'tokens. If set to "complete", splits samples only at the end ' - 'of sentence, but may include multiple sentences per sample. ' - '"complete_doc" is similar but respects doc boundaries. ' - 'If set to "eos", includes only one sentence per sample.') - parser.add_argument('--tokens-per-sample', default=512, type=int, - help='max number of total tokens over all segments ' - 'per sample for BERT dataset') - parser.add_argument('--mask-prob', default=0.15, type=float, - help='probability of replacing a token with mask') - parser.add_argument('--leave-unmasked-prob', default=0.1, type=float, - help='probability that a masked token is unmasked') - parser.add_argument('--random-token-prob', default=0.1, type=float, - help='probability of replacing a token with a random token') - parser.add_argument('--freq-weighted-replacement', action='store_true', - help='sample random replacement words based on word frequencies') - parser.add_argument('--mask-whole-words', default=False, action='store_true', - help='mask whole words; you may also want to set --bpe') - parser.add_argument('--multilang-sampling-alpha', type=float, default=1.0, - help='smoothing alpha for sample rations across multiple datasets') + parser.add_argument( + "data", + help="colon separated path to data directories list, 
\ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--sample-break-mode", + default="complete", + choices=["none", "complete", "complete_doc", "eos"], + help='If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.', + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) + parser.add_argument( + "--mask-prob", + default=0.15, + type=float, + help="probability of replacing a token with mask", + ) + parser.add_argument( + "--leave-unmasked-prob", + default=0.1, + type=float, + help="probability that a masked token is unmasked", + ) + parser.add_argument( + "--random-token-prob", + default=0.1, + type=float, + help="probability of replacing a token with a random token", + ) + parser.add_argument( + "--freq-weighted-replacement", + action="store_true", + help="sample random replacement words based on word frequencies", + ) + parser.add_argument( + "--mask-whole-words", + default=False, + action="store_true", + help="mask whole words; you may also want to set --bpe", + ) + parser.add_argument( + "--multilang-sampling-alpha", + type=float, + default=1.0, + help="smoothing alpha for sample rations across multiple datasets", + ) def __init__(self, args, dictionary): super().__init__(args) @@ -71,14 +103,14 @@ def __init__(self, args, dictionary): self.seed = args.seed # add mask token - self.mask_idx = dictionary.add_symbol('<mask>') + self.mask_idx = dictionary.add_symbol("<mask>") @classmethod def setup_task(cls, args, **kwargs): paths = utils.split_paths(args.data) assert len(paths) > 0 - dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt')) -
logger.info('dictionary: {} types'.format(len(dictionary))) + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) return cls(args, dictionary) def _get_whole_word_mask(self): @@ -92,16 +124,16 @@ def is_beginning_of_word(i): # special elements are always considered beginnings return True tok = self.source_dictionary[i] - if tok.startswith('madeupword'): + if tok.startswith("madeupword"): return True try: return bpe.is_beginning_of_word(tok) except ValueError: return True - mask_whole_words = torch.ByteTensor(list( - map(is_beginning_of_word, range(len(self.source_dictionary))) - )) + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(self.source_dictionary)))) + ) else: mask_whole_words = None return mask_whole_words @@ -127,14 +159,14 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): data_path = paths[(epoch - 1) % len(paths)] languages = sorted( - name for name in os.listdir(data_path) + name + for name in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, name)) ) logger.info("Training on {0} languages: {1}".format(len(languages), languages)) - logger.info("Language to id mapping: ", { - lang: id for id, lang in enumerate(languages) - } + logger.info( + "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)} ) mask_whole_words = self._get_whole_word_mask() @@ -149,7 +181,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, ) if dataset is None: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) # create continuous blocks of tokens dataset = TokenBlockDataset( @@ -160,7 +194,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): eos=self.source_dictionary.eos(), break_mode=self.args.sample_break_mode, ) - logger.info('loaded {} blocks from: 
{}'.format(len(dataset), split_path)) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) @@ -180,50 +214,53 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): lang_dataset = NestedDictionaryDataset( { - 'net_input': { - 'src_tokens': PadDataset( + "net_input": { + "src_tokens": PadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), - 'src_lengths': NumelDataset(src_dataset, reduce=False), + "src_lengths": NumelDataset(src_dataset, reduce=False), }, - 'target': PadDataset( + "target": PadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), - 'nsentences': NumSamplesDataset(), - 'ntokens': NumelDataset(src_dataset, reduce=True), - 'lang_id': RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + "lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]), }, sizes=[src_dataset.sizes], ) lang_datasets.append(lang_dataset) - dataset_lengths = np.array( [len(d) for d in lang_datasets], dtype=float, ) logger.info( - 'loaded total {} blocks for all languages'.format( + "loaded total {} blocks for all languages".format( dataset_lengths.sum(), ) ) if split == self.args.train_subset: # For train subset, additionally up or down sample languages.
sample_probs = self._get_sample_prob(dataset_lengths) - logger.info("Sample probability by language: ", { + logger.info( + "Sample probability by language: ", + { lang: "{0:.4f}".format(sample_probs[id]) for id, lang in enumerate(languages) - } + }, ) size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths - logger.info("Up/Down Sampling ratio by language: ", { + logger.info( + "Up/Down Sampling ratio by language: ", + { lang: "{0:.2f}".format(size_ratio[id]) for id, lang in enumerate(languages) - } + }, ) resampled_lang_datasets = [ @@ -241,7 +278,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dataset = ConcatDataset(lang_datasets) lang_splits = [split] for lang_id, lang_dataset in enumerate(lang_datasets): - split_name = split + '_' + languages[lang_id] + split_name = split + "_" + languages[lang_id] lang_splits.append(split_name) self.datasets[split_name] = lang_dataset @@ -250,7 +287,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): # in more generic ways. 
if split in self.args.valid_subset: self.args.valid_subset = self.args.valid_subset.replace( - split, ','.join(lang_splits) + split, ",".join(lang_splits) ) with data_utils.numpy_seed(self.args.seed + epoch): @@ -272,7 +309,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), - break_mode='eos', + break_mode="eos", ), pad_idx=self.source_dictionary.pad(), left_pad=False, @@ -280,10 +317,10 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) src_dataset = NestedDictionaryDataset( { - 'id': IdDataset(), - 'net_input': { - 'src_tokens': src_dataset, - 'src_lengths': NumelDataset(src_dataset, reduce=False), + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), }, }, sizes=src_lengths, diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 161eb436ec..f6cb17f12a 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -3,14 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import OrderedDict +import contextlib import logging import os -from fairseq import options -import contextlib -import torch +from collections import OrderedDict -from fairseq import metrics, options, utils +import torch +from fairseq import metrics, options, utils from fairseq.data import ( Dictionary, LanguagePairDataset, @@ -20,24 +19,24 @@ from fairseq.models import FairseqMultiModel from fairseq.tasks.translation import load_langpair_dataset -from . import register_task, LegacyFairseqTask +from .
import LegacyFairseqTask, register_task + logger = logging.getLogger(__name__) def _lang_token(lang: str): - return '__{}__'.format(lang) + return "__{}__".format(lang) def _lang_token_index(dic: Dictionary, lang: str): """Return language token index.""" idx = dic.index(_lang_token(lang)) - assert idx != dic.unk_index, \ - 'cannot find language token for lang {}'.format(lang) + assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang) return idx -@register_task('multilingual_translation') +@register_task("multilingual_translation") class MultilingualTranslationTask(LegacyFairseqTask): """A task for training multiple translation models simultaneously. @@ -99,7 +98,7 @@ def __init__(self, args, dicts, training): if training: self.lang_pairs = args.lang_pairs else: - self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)] + self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] # eval_lang_pairs for multilingual translation is usually all of the # lang_pairs. However for other multitask settings or when we want to # optimize for certain languages we want to use a different subset. Thus @@ -123,10 +122,14 @@ def prepare(cls, args, **kargs): args.left_pad_target = utils.eval_bool(args.left_pad_target) if args.lang_pairs is None: - raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.') + raise ValueError( + "--lang-pairs is required. List all the language pairs in the training objective." 
+ ) if isinstance(args.lang_pairs, str): - args.lang_pairs = args.lang_pairs.split(',') - sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')})) + args.lang_pairs = args.lang_pairs.split(",") + sorted_langs = sorted( + list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}) + ) if args.source_lang is not None or args.target_lang is not None: training = False else: @@ -137,7 +140,9 @@ def prepare(cls, args, **kargs): for lang in sorted_langs: paths = utils.split_paths(args.data) assert len(paths) > 0 - dicts[lang] = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) + dicts[lang] = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(lang)) + ) if len(dicts) > 0: assert dicts[lang].pad() == dicts[sorted_langs[0]].pad() assert dicts[lang].eos() == dicts[sorted_langs[0]].eos() @@ -145,13 +150,13 @@ def prepare(cls, args, **kargs): if args.encoder_langtok is not None or args.decoder_langtok: for lang_to_add in sorted_langs: dicts[lang].add_symbol(_lang_token(lang_to_add)) - logger.info('[{}] dictionary: {} types'.format(lang, len(dicts[lang]))) + logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) return dicts, training def get_encoder_langtok(self, src_lang, tgt_lang): if self.args.encoder_langtok is None: return self.dicts[src_lang].eos() - if self.args.encoder_langtok == 'src': + if self.args.encoder_langtok == "src": return _lang_token_index(self.dicts[src_lang], src_lang) else: return _lang_token_index(self.dicts[src_lang], tgt_lang) @@ -161,14 +166,24 @@ def get_decoder_langtok(self, tgt_lang): return self.dicts[tgt_lang].eos() return _lang_token_index(self.dicts[tgt_lang], tgt_lang) - def alter_dataset_langtok(self, lang_pair_dataset, - src_eos=None, src_lang=None, tgt_eos=None, tgt_lang=None): + def alter_dataset_langtok( + self, + lang_pair_dataset, + src_eos=None, + src_lang=None, + tgt_eos=None, + tgt_lang=None, + ): if self.args.encoder_langtok 
is None and not self.args.decoder_langtok: return lang_pair_dataset new_src_eos = None - if self.args.encoder_langtok is not None and src_eos is not None \ - and src_lang is not None and tgt_lang is not None: + if ( + self.args.encoder_langtok is not None + and src_eos is not None + and src_lang is not None + and tgt_lang is not None + ): new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang) else: src_eos = None @@ -194,10 +209,16 @@ def load_dataset(self, split, epoch=1, **kwargs): data_path = paths[(epoch - 1) % len(paths)] def language_pair_dataset(lang_pair): - src, tgt = lang_pair.split('-') + src, tgt = lang_pair.split("-") langpair_dataset = load_langpair_dataset( - data_path, split, src, self.dicts[src], tgt, self.dicts[tgt], - combine=True, dataset_impl=self.args.dataset_impl, + data_path, + split, + src, + self.dicts[src], + tgt, + self.dicts[tgt], + combine=True, + dataset_impl=self.args.dataset_impl, upsample_primary=self.args.upsample_primary, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, @@ -213,68 +234,100 @@ def language_pair_dataset(lang_pair): ) self.datasets[split] = RoundRobinZipDatasets( - OrderedDict([ - (lang_pair, language_pair_dataset(lang_pair)) - for lang_pair in self.lang_pairs - ]), - eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang), + OrderedDict( + [ + (lang_pair, language_pair_dataset(lang_pair)) + for lang_pair in self.lang_pairs + ] + ), + eval_key=None + if self.training + else "%s-%s" % (self.args.source_lang, self.args.target_lang), ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): if constraints is not None: - raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported") + raise NotImplementedError( + "Constrained decoding with the multilingual_translation task is not supported" + ) lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang) return 
RoundRobinZipDatasets( - OrderedDict([( - lang_pair, - self.alter_dataset_langtok( - LanguagePairDataset( - src_tokens, src_lengths, - self.source_dictionary - ), - src_eos=self.source_dictionary.eos(), - src_lang=self.args.source_lang, - tgt_eos=self.target_dictionary.eos(), - tgt_lang=self.args.target_lang, - ), - )]), + OrderedDict( + [ + ( + lang_pair, + self.alter_dataset_langtok( + LanguagePairDataset( + src_tokens, src_lengths, self.source_dictionary + ), + src_eos=self.source_dictionary.eos(), + src_lang=self.args.source_lang, + tgt_eos=self.target_dictionary.eos(), + tgt_lang=self.args.target_lang, + ), + ) + ] + ), eval_key=lang_pair, ) def build_model(self, args): def check_args(): messages = [] - if len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs)) != 0: - messages.append('--lang-pairs should include all the language pairs {}.'.format(args.lang_pairs)) + if ( + len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs)) + != 0 + ): + messages.append( + "--lang-pairs should include all the language pairs {}.".format( + args.lang_pairs + ) + ) if self.args.encoder_langtok != args.encoder_langtok: - messages.append('--encoder-langtok should be {}.'.format(args.encoder_langtok)) + messages.append( + "--encoder-langtok should be {}.".format(args.encoder_langtok) + ) if self.args.decoder_langtok != args.decoder_langtok: - messages.append('--decoder-langtok should {} be set.'.format("" if args.decoder_langtok else "not")) + messages.append( + "--decoder-langtok should {} be set.".format( + "" if args.decoder_langtok else "not" + ) + ) if len(messages) > 0: - raise ValueError(' '.join(messages)) + raise ValueError(" ".join(messages)) # Check if task args are consistant with model args check_args() from fairseq import models + model = models.build_model(args, self) if not isinstance(model, FairseqMultiModel): - raise ValueError('MultilingualTranslationTask requires a FairseqMultiModel architecture') + raise ValueError( + 
"MultilingualTranslationTask requires a FairseqMultiModel architecture" + ) return model - def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad): - loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) + def _per_lang_pair_train_loss( + self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad + ): + loss, sample_size, logging_output = criterion( + model.models[lang_pair], sample[lang_pair] + ) if ignore_grad: loss *= 0 optimizer.backward(loss) return loss, sample_size, logging_output - def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): model.train() from collections import defaultdict - agg_loss, agg_sample_size, agg_logging_output = 0., 0., defaultdict(float) + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float) curr_lang_pairs = [ lang_pair for lang_pair in self.model_lang_pairs @@ -282,18 +335,27 @@ def train_step(self, sample, model, criterion, optimizer, update_num, ignore_gra ] for idx, lang_pair in enumerate(curr_lang_pairs): + def maybe_no_sync(): if ( self.args.distributed_world_size > 1 - and hasattr(model, 'no_sync') + and hasattr(model, "no_sync") and idx < len(curr_lang_pairs) - 1 ): return model.no_sync() else: return contextlib.ExitStack() # dummy contextmanager + with maybe_no_sync(): loss, sample_size, logging_output = self._per_lang_pair_train_loss( - lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad) + lang_pair, + model, + update_num, + criterion, + sample, + optimizer, + ignore_grad, + ) agg_loss += loss.detach().item() # TODO make summing of the sample sizes configurable agg_sample_size += sample_size @@ -309,11 +371,18 @@ def valid_step(self, sample, model, criterion): model.eval() with torch.no_grad(): from collections import defaultdict - agg_loss, 
agg_sample_size, agg_logging_output = 0., 0., defaultdict(float) + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float) for lang_pair in self.eval_lang_pairs: - if lang_pair not in sample or sample[lang_pair] is None or len(sample[lang_pair]) == 0: + if ( + lang_pair not in sample + or sample[lang_pair] is None + or len(sample[lang_pair]) == 0 + ): continue - loss, sample_size, logging_output = self._per_lang_pair_valid_loss(lang_pair, model, criterion, sample) + loss, sample_size, logging_output = self._per_lang_pair_valid_loss( + lang_pair, model, criterion, sample + ) agg_loss += loss.data.item() # TODO make summing of the sample sizes configurable agg_sample_size += sample_size @@ -322,10 +391,14 @@ def valid_step(self, sample, model, criterion): agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k] return agg_loss, agg_sample_size, agg_logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): with torch.no_grad(): if self.args.decoder_langtok: - bos_token = _lang_token_index(self.target_dictionary, self.args.target_lang) + bos_token = _lang_token_index( + self.target_dictionary, self.args.target_lang + ) else: bos_token = self.target_dictionary.eos() return generator.generate( @@ -340,7 +413,7 @@ def reduce_metrics(self, logging_outputs, criterion): with metrics.aggregate(): # pass 'sample_size', 'nsentences', 'ntokens' stats to fairseq_task super().reduce_metrics(logging_outputs, criterion) - for k in ['sample_size', 'nsentences', 'ntokens']: + for k in ["sample_size", "nsentences", "ntokens"]: metrics.log_scalar(k, sum(l[k] for l in logging_outputs)) @property @@ -360,10 +433,17 @@ def target_dictionary(self): def max_positions(self): """Return the max sentence length allowed by the task.""" if len(self.datasets.values()) == 0: - return {'%s-%s' % (self.args.source_lang, 
self.args.target_lang): - (self.args.max_source_positions, self.args.max_target_positions)} - return OrderedDict([ - (key, (self.args.max_source_positions, self.args.max_target_positions)) - for split in self.datasets.keys() - for key in self.datasets[split].datasets.keys() - ]) + return { + "%s-%s" + % (self.args.source_lang, self.args.target_lang): ( + self.args.max_source_positions, + self.args.max_target_positions, + ) + } + return OrderedDict( + [ + (key, (self.args.max_source_positions, self.args.max_target_positions)) + for split in self.datasets.keys() + for key in self.datasets[split].datasets.keys() + ] + ) diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py index c81d362886..b2f9bf9a73 100644 --- a/fairseq/tasks/semisupervised_translation.py +++ b/fairseq/tasks/semisupervised_translation.py @@ -3,27 +3,28 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import OrderedDict import logging import os +from collections import OrderedDict +from fairseq import utils from fairseq.data import ( BacktranslationDataset, - data_utils, - indexed_dataset, IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, LanguagePairDataset, NoisingDataset, RoundRobinZipDatasets, + data_utils, + indexed_dataset, ) from fairseq.models import FairseqMultiModel from fairseq.sequence_generator import SequenceGenerator -from .multilingual_translation import MultilingualTranslationTask from . 
import register_task -from fairseq import utils +from .multilingual_translation import MultilingualTranslationTask + logger = logging.getLogger(__name__) @@ -46,18 +47,20 @@ def parse_lambda_config(x): x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 # iterations, then will linearly increase to 1 until iteration 2000 """ - split = x.split(',') + split = x.split(",") if len(split) == 1: return float(x), None else: split = [s.split(os.pathsep) for s in split] assert all(len(s) == 2 for s in split) assert all(k.isdigit() for k, _ in split) - assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)) + assert all( + int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1) + ) return float(split[0][1]), [(int(k), float(v)) for k, v in split] -@register_task('semisupervised_translation') +@register_task("semisupervised_translation") class SemisupervisedTranslationTask(MultilingualTranslationTask): """A task for training multiple translation models simultaneously. 
@@ -119,13 +122,19 @@ def add_args(parser): def __init__(self, args, dicts, training): super().__init__(args, dicts, training) - self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(args.lambda_parallel_config) - self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(args.lambda_otf_bt_config) - self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(args.lambda_denoising_config) - if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None): + self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config( + args.lambda_parallel_config + ) + self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config( + args.lambda_otf_bt_config + ) + self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config( + args.lambda_denoising_config + ) + if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None: denoising_lang_pairs = [ "%s-%s" % (tgt, tgt) - for tgt in {lang_pair.split('-')[1] for lang_pair in args.lang_pairs} + for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs} ] self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs self.backtranslate_datasets = {} @@ -144,39 +153,71 @@ def load_dataset(self, split, epoch=1, **kwargs): def split_exists(split, src, tgt, lang): if src is not None: - filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) + filename = os.path.join( + data_path, "{}.{}-{}.{}".format(split, src, tgt, lang) + ) else: - filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt)) + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, src, tgt) + ) return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl) def load_indexed_dataset(path, dictionary): - return data_utils.load_indexed_dataset(path, dictionary, self.args.dataset_impl) + return data_utils.load_indexed_dataset( + path, dictionary, self.args.dataset_impl + ) # load parallel datasets 
src_datasets, tgt_datasets = {}, {} - if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")): + if ( + self.lambda_parallel > 0.0 + or self.lambda_parallel_steps is not None + or not split.startswith("train") + ): for lang_pair in self.lang_pairs: - src, tgt = lang_pair.split('-') + src, tgt = lang_pair.split("-") if split_exists(split, src, tgt, src): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt)) + prefix = os.path.join( + data_path, "{}.{}-{}.".format(split, src, tgt) + ) elif split_exists(split, tgt, src, src): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src)) + prefix = os.path.join( + data_path, "{}.{}-{}.".format(split, tgt, src) + ) else: continue - src_datasets[lang_pair] = load_indexed_dataset(prefix + src, self.dicts[src]) - tgt_datasets[lang_pair] = load_indexed_dataset(prefix + tgt, self.dicts[tgt]) - logger.info('parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair]))) + src_datasets[lang_pair] = load_indexed_dataset( + prefix + src, self.dicts[src] + ) + tgt_datasets[lang_pair] = load_indexed_dataset( + prefix + tgt, self.dicts[tgt] + ) + logger.info( + "parallel-{} {} {} examples".format( + data_path, split, len(src_datasets[lang_pair]) + ) + ) if len(src_datasets) == 0: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) # back translation datasets backtranslate_datasets = {} - if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"): + if ( + self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None + ) and split.startswith("train"): for lang_pair in self.lang_pairs: - src, tgt = lang_pair.split('-') + src, tgt = lang_pair.split("-") if not split_exists(split, tgt, None, tgt): - raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, 
data_path)) - filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt)) + raise FileNotFoundError( + "Dataset not found: backtranslation {} ({})".format( + split, data_path + ) + ) + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, tgt, tgt) + ) dataset = load_indexed_dataset(filename, self.dicts[tgt]) lang_pair_dataset_tgt = LanguagePairDataset( dataset, @@ -203,7 +244,8 @@ def load_indexed_dataset(path, dictionary): tgt_lang=src, ), backtranslation_fn=self.backtranslators[lang_pair], - src_dict=self.dicts[src], tgt_dict=self.dicts[tgt], + src_dict=self.dicts[src], + tgt_dict=self.dicts[tgt], output_collater=self.alter_dataset_langtok( lang_pair_dataset=lang_pair_dataset, src_eos=self.dicts[src].eos(), @@ -212,19 +254,30 @@ def load_indexed_dataset(path, dictionary): tgt_lang=tgt, ).collater, ) - logger.info('backtranslate-{}: {} {} {} examples'.format( - tgt, data_path, split, len(backtranslate_datasets[lang_pair]), - )) - self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair] + logger.info( + "backtranslate-{}: {} {} {} examples".format( + tgt, + data_path, + split, + len(backtranslate_datasets[lang_pair]), + ) + ) + self.backtranslate_datasets[lang_pair] = backtranslate_datasets[ + lang_pair + ] # denoising autoencoder noising_datasets = {} - if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"): + if ( + self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None + ) and split.startswith("train"): for lang_pair in self.lang_pairs: - _, tgt = lang_pair.split('-') + _, tgt = lang_pair.split("-") if not split_exists(split, tgt, None, tgt): continue - filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt)) + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, tgt, tgt) + ) tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt]) tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt]) 
noising_dataset = NoisingDataset( @@ -251,17 +304,26 @@ def load_indexed_dataset(path, dictionary): tgt_eos=self.dicts[tgt].eos(), tgt_lang=tgt, ) - logger.info('denoising-{}: {} {} {} examples'.format( - tgt, data_path, split, len(noising_datasets[lang_pair]), - )) + logger.info( + "denoising-{}: {} {} {} examples".format( + tgt, + data_path, + split, + len(noising_datasets[lang_pair]), + ) + ) def language_pair_dataset(lang_pair): - src, tgt = lang_pair.split('-') + src, tgt = lang_pair.split("-") src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair] return self.alter_dataset_langtok( LanguagePairDataset( - src_dataset, src_dataset.sizes, self.dicts[src], - tgt_dataset, tgt_dataset.sizes, self.dicts[tgt], + src_dataset, + src_dataset.sizes, + self.dicts[src], + tgt_dataset, + tgt_dataset.sizes, + self.dicts[tgt], left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, ), @@ -272,31 +334,42 @@ def language_pair_dataset(lang_pair): ) self.datasets[split] = RoundRobinZipDatasets( - OrderedDict([ - (lang_pair, language_pair_dataset(lang_pair)) - for lang_pair in src_datasets.keys() - ] + [ - (_get_bt_dataset_key(lang_pair), dataset) - for lang_pair, dataset in backtranslate_datasets.items() - ] + [ - (_get_denoising_dataset_key(lang_pair), dataset) - for lang_pair, dataset in noising_datasets.items() - ]), - eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang), + OrderedDict( + [ + (lang_pair, language_pair_dataset(lang_pair)) + for lang_pair in src_datasets.keys() + ] + + [ + (_get_bt_dataset_key(lang_pair), dataset) + for lang_pair, dataset in backtranslate_datasets.items() + ] + + [ + (_get_denoising_dataset_key(lang_pair), dataset) + for lang_pair, dataset in noising_datasets.items() + ] + ), + eval_key=None + if self.training + else "%s-%s" % (self.args.source_lang, self.args.target_lang), ) def build_model(self, args): from fairseq import models + model = 
models.build_model(args, self) if not isinstance(model, FairseqMultiModel): - raise ValueError('SemisupervisedTranslationTask requires a FairseqMultiModel architecture') + raise ValueError( + "SemisupervisedTranslationTask requires a FairseqMultiModel architecture" + ) # create SequenceGenerator for each model that has backtranslation dependency on it self.sequence_generators = {} - if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and self.training: + if ( + self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None + ) and self.training: for lang_pair in self.lang_pairs: - src, tgt = lang_pair.split('-') - key = '{}-{}'.format(tgt, src) + src, tgt = lang_pair.split("-") + key = "{}-{}".format(tgt, src) self.sequence_generators[key] = SequenceGenerator( [model.models[key]], tgt_dict=self.dicts[src], @@ -307,7 +380,8 @@ def build_model(self, args): decoder_lang_tok_idx = self.get_decoder_langtok(src) def backtranslate_fn( - sample, model=model.models[key], + sample, + model=model.models[key], bos_token=decoder_lang_tok_idx, sequence_generator=self.sequence_generators[key], ): @@ -316,17 +390,20 @@ def backtranslate_fn( sample, bos_token=bos_token, ) + self.backtranslators[lang_pair] = backtranslate_fn return model - def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): model.train() if update_num > 0: self.update_step(update_num) - agg_loss, agg_sample_size, agg_logging_output = 0., 0., {} + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, {} def forward_backward(model, samples, logging_output_key, weight): nonlocal agg_loss, agg_sample_size, agg_logging_output @@ -347,18 +424,33 @@ def forward_backward(model, samples, logging_output_key, weight): if self.lambda_parallel > 0.0: for lang_pair in self.lang_pairs: - forward_backward(model.models[lang_pair], sample[lang_pair], lang_pair, 
self.lambda_parallel) + forward_backward( + model.models[lang_pair], + sample[lang_pair], + lang_pair, + self.lambda_parallel, + ) if self.lambda_otf_bt > 0.0: for lang_pair in self.lang_pairs: sample_key = _get_bt_dataset_key(lang_pair) - forward_backward(model.models[lang_pair], sample[sample_key], sample_key, self.lambda_otf_bt) + forward_backward( + model.models[lang_pair], + sample[sample_key], + sample_key, + self.lambda_otf_bt, + ) if self.lambda_denoising > 0.0: for lang_pair in self.lang_pairs: - _, tgt = lang_pair.split('-') + _, tgt = lang_pair.split("-") sample_key = _get_denoising_dataset_key(lang_pair) - forward_backward(model.models['{0}-{0}'.format(tgt)], sample[sample_key], sample_key, self.lambda_denoising) + forward_backward( + model.models["{0}-{0}".format(tgt)], + sample[sample_key], + sample_key, + self.lambda_denoising, + ) return agg_loss, agg_sample_size, agg_logging_output @@ -367,7 +459,11 @@ def lambda_step_func(config, n_iter): """ Update a lambda value according to its schedule configuration. 
""" - ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]] + ranges = [ + i + for i in range(len(config) - 1) + if config[i][0] <= n_iter < config[i + 1][0] + ] if len(ranges) == 0: assert n_iter >= config[-1][0] return config[-1][1] @@ -378,8 +474,12 @@ def lambda_step_func(config, n_iter): return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a) if self.lambda_parallel_steps is not None: - self.lambda_parallel = lambda_step_func(self.lambda_parallel_steps, num_updates) + self.lambda_parallel = lambda_step_func( + self.lambda_parallel_steps, num_updates + ) if self.lambda_denoising_steps is not None: - self.lambda_denoising = lambda_step_func(self.lambda_denoising_steps, num_updates) + self.lambda_denoising = lambda_step_func( + self.lambda_denoising_steps, num_updates + ) if self.lambda_otf_bt_steps is not None: self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates) diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index d9a82faddd..69dc996e6a 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -7,16 +7,14 @@ import os import numpy as np - from fairseq import utils from fairseq.data import ( ConcatSentencesDataset, - data_utils, Dictionary, IdDataset, NestedDictionaryDataset, - NumSamplesDataset, NumelDataset, + NumSamplesDataset, OffsetTokensDataset, PrependTokenDataset, RawLabelDataset, @@ -24,15 +22,16 @@ RollDataset, SortDataset, StripTokenDataset, + data_utils, ) -from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('sentence_prediction') +@register_task("sentence_prediction") class SentencePredictionTask(LegacyFairseqTask): """ Sentence (or sentence pair) prediction (classification or regression) task. 
@@ -44,30 +43,51 @@ class SentencePredictionTask(LegacyFairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', metavar='FILE', - help='file prefix for data') - parser.add_argument('--num-classes', type=int, default=-1, - help='number of classes or regression targets') - parser.add_argument('--init-token', type=int, default=None, - help='add token at the beginning of each batch item') - parser.add_argument('--separator-token', type=int, default=None, - help='add separator token between inputs') - parser.add_argument('--regression-target', action='store_true', default=False) - parser.add_argument('--no-shuffle', action='store_true', default=False) - parser.add_argument('--shorten-method', default='none', - choices=['none', 'truncate', 'random_crop'], - help='if not none, shorten sequences that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-list', default='', - help='comma-separated list of dataset splits to apply shortening to, ' - 'e.g., "train,valid" (default: all dataset splits)') - parser.add_argument('--add-prev-output-tokens', action='store_true', default=False, - help='add prev_output_tokens to sample, used for encoder-decoder arch') + parser.add_argument("data", metavar="FILE", help="file prefix for data") + parser.add_argument( + "--num-classes", + type=int, + default=-1, + help="number of classes or regression targets", + ) + parser.add_argument( + "--init-token", + type=int, + default=None, + help="add token at the beginning of each batch item", + ) + parser.add_argument( + "--separator-token", + type=int, + default=None, + help="add separator token between inputs", + ) + parser.add_argument("--regression-target", action="store_true", default=False) + parser.add_argument("--no-shuffle", action="store_true", default=False) + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten 
sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + help="comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) + parser.add_argument( + "--add-prev-output-tokens", + action="store_true", + default=False, + help="add prev_output_tokens to sample, used for encoder-decoder arch", + ) def __init__(self, args, data_dictionary, label_dictionary): super().__init__(args) self.dictionary = data_dictionary self._label_dictionary = label_dictionary - if not hasattr(args, 'max_positions'): + if not hasattr(args, "max_positions"): self._max_positions = ( args.max_source_positions, args.max_target_positions, @@ -84,36 +104,37 @@ def load_dictionary(cls, args, filename, source=True): filename (str): the filename """ dictionary = Dictionary.load(filename) - dictionary.add_symbol('<mask>') + dictionary.add_symbol("<mask>") return dictionary @classmethod def setup_task(cls, args, **kwargs): - assert args.num_classes > 0, 'Must set --num-classes' + assert args.num_classes > 0, "Must set --num-classes" # load data dictionary data_dict = cls.load_dictionary( args, - os.path.join(args.data, 'input0', 'dict.txt'), + os.path.join(args.data, "input0", "dict.txt"), source=True, ) - logger.info('[input] dictionary: {} types'.format(len(data_dict))) + logger.info("[input] dictionary: {} types".format(len(data_dict))) label_dict = None if not args.regression_target: # load label dictionary label_dict = cls.load_dictionary( args, - os.path.join(args.data, 'label', 'dict.txt'), + os.path.join(args.data, "label", "dict.txt"), source=False, ) - logger.info('[label] dictionary: {} types'.format(len(label_dict))) + logger.info("[label] dictionary: {} types".format(len(label_dict))) else: label_dict = data_dict return cls(args, data_dict, label_dict) def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" + def
get_path(type, split): return os.path.join(self.args.data, type, split) @@ -128,9 +149,11 @@ def make_dataset(type, dictionary): ) return dataset - input0 = make_dataset('input0', self.source_dictionary) - assert input0 is not None, 'could not find dataset: {}'.format(get_path(type, split)) - input1 = make_dataset('input1', self.source_dictionary) + input0 = make_dataset("input0", self.source_dictionary) + assert input0 is not None, "could not find dataset: {}".format( + get_path(type, split) + ) + input1 = make_dataset("input1", self.source_dictionary) if self.args.init_token is not None: input0 = PrependTokenDataset(input0, self.args.init_token) @@ -156,16 +179,16 @@ def make_dataset(type, dictionary): ) dataset = { - 'id': IdDataset(), - 'net_input': { - 'src_tokens': RightPadDataset( + "id": IdDataset(), + "net_input": { + "src_tokens": RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), - 'src_lengths': NumelDataset(src_tokens, reduce=False), + "src_lengths": NumelDataset(src_tokens, reduce=False), }, - 'nsentences': NumSamplesDataset(), - 'ntokens': NumelDataset(src_tokens, reduce=True), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens, reduce=True), } if self.args.add_prev_output_tokens: @@ -173,12 +196,12 @@ def make_dataset(type, dictionary): RollDataset(src_tokens, 1), pad_idx=self.dictionary.pad(), ) - dataset['net_input'].update( + dataset["net_input"].update( prev_output_tokens=prev_tokens_dataset, ) if not self.args.regression_target: - label_dataset = make_dataset('label', self.label_dictionary) + label_dataset = make_dataset("label", self.label_dictionary) if label_dataset is not None: dataset.update( target=OffsetTokensDataset( @@ -190,21 +213,24 @@ def make_dataset(type, dictionary): ) ) else: - label_path = "{0}.label".format(get_path('label', split)) + label_path = "{0}.label".format(get_path("label", split)) if os.path.exists(label_path): def parse_regression_target(i, line): values = line.split() - 
assert len(values) == self.args.num_classes, \ - f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"' + assert ( + len(values) == self.args.num_classes + ), f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"' return [float(x) for x in values] with open(label_path) as h: dataset.update( - target=RawLabelDataset([ - parse_regression_target(i, line.strip()) - for i, line in enumerate(h.readlines()) - ]) + target=RawLabelDataset( + [ + parse_regression_target(i, line.strip()) + for i, line in enumerate(h.readlines()) + ] + ) ) nested_dataset = NestedDictionaryDataset( @@ -228,10 +254,11 @@ def parse_regression_target(i, line): def build_model(self, args): from fairseq import models + model = models.build_model(args, self) model.register_classification_head( - getattr(args, 'classification_head_name', 'sentence_classification_head'), + getattr(args, "classification_head_name", "sentence_classification_head"), num_classes=self.args.num_classes, ) diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py index a1d332a3ca..bed44f34e5 100644 --- a/fairseq/tasks/sentence_ranking.py +++ b/fairseq/tasks/sentence_ranking.py @@ -7,30 +7,29 @@ import os import numpy as np - from fairseq import utils from fairseq.data import ( ConcatSentencesDataset, - data_utils, Dictionary, IdDataset, NestedDictionaryDataset, - NumSamplesDataset, NumelDataset, + NumSamplesDataset, PrependTokenDataset, RawLabelDataset, RightPadDataset, SortDataset, - TruncateDataset + TruncateDataset, + data_utils, ) -from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import LegacyFairseqTask, register_task logger = logging.getLogger(__name__) -@register_task('sentence_ranking') +@register_task("sentence_ranking") class SentenceRankingTask(LegacyFairseqTask): """ Ranking task on multiple sentences. 
@@ -42,23 +41,34 @@ class SentenceRankingTask(LegacyFairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', metavar='FILE', - help='file prefix for data') - parser.add_argument('--num-classes', type=int, - help='number of sentences to be ranked') - parser.add_argument('--init-token', type=int, - help='add token at the beginning of each batch item') - parser.add_argument('--separator-token', type=int, - help='add separator token between inputs') - parser.add_argument('--no-shuffle', action='store_true') - parser.add_argument('--shorten-method', default='none', - choices=['none', 'truncate', 'random_crop'], - help='if not none, shorten sequences that exceed --tokens-per-sample') - parser.add_argument('--shorten-data-split-list', default='', - help='comma-separated list of dataset splits to apply shortening to, ' - 'e.g., "train,valid" (default: all dataset splits)') - parser.add_argument('--max-option-length', type=int, - help='max length for each option') + parser.add_argument("data", metavar="FILE", help="file prefix for data") + parser.add_argument( + "--num-classes", type=int, help="number of sentences to be ranked" + ) + parser.add_argument( + "--init-token", + type=int, + help="add token at the beginning of each batch item", + ) + parser.add_argument( + "--separator-token", type=int, help="add separator token between inputs" + ) + parser.add_argument("--no-shuffle", action="store_true") + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + help="comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) + parser.add_argument( + "--max-option-length", type=int, help="max length for each option" + ) def __init__(self, args, 
dictionary): super().__init__(args) @@ -72,21 +82,22 @@ def load_dictionary(cls, args, filename, source=True): filename (str): the filename """ dictionary = Dictionary.load(filename) - dictionary.add_symbol('<mask>') + dictionary.add_symbol("<mask>") return dictionary @classmethod def setup_task(cls, args, **kwargs): - assert args.criterion == 'sentence_ranking', \ - 'Must set --criterion=sentence_ranking' + assert ( + args.criterion == "sentence_ranking" + ), "Must set --criterion=sentence_ranking" # load data dictionary data_dict = cls.load_dictionary( args, - os.path.join(args.data, 'input0', 'dict.txt'), + os.path.join(args.data, "input0", "dict.txt"), source=True, ) - logger.info('[input] dictionary: {} types'.format(len(data_dict))) + logger.info("[input] dictionary: {} types".format(len(data_dict))) return SentenceRankingTask(args, data_dict) def load_dataset(self, split, combine=False, **kwargs): @@ -106,12 +117,9 @@ def make_dataset(type, dictionary): ) return dataset - input0 = make_dataset('input0', self.source_dictionary) + input0 = make_dataset("input0", self.source_dictionary) input_options = [ - make_dataset( - 'input{idx}'.format(idx=idx + 1), - self.source_dictionary - ) + make_dataset("input{idx}".format(idx=idx + 1), self.source_dictionary) for idx in range(self.args.num_classes) ] @@ -123,7 +131,9 @@ def make_dataset(type, dictionary): if self.args.init_token is not None: input_option = PrependTokenDataset(input_option, self.args.init_token) if self.args.max_option_length is not None: - input_option = TruncateDataset(input_option, self.args.max_option_length) + input_option = TruncateDataset( + input_option, self.args.max_option_length + ) src_token = ConcatSentencesDataset(input_option, input0) src_token = maybe_shorten_dataset( src_token, @@ -139,31 +149,31 @@ def make_dataset(type, dictionary): shuffle = np.random.permutation(len(src_tokens[0])) dataset = { - 'id': IdDataset(), - 'nsentences': NumSamplesDataset(), - 'ntokens': NumelDataset(src_tokens[0],
reduce=True), + "id": IdDataset(), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens[0], reduce=True), } for src_token_idx in range(len(src_tokens)): dataset.update( { - 'net_input{idx}'.format(idx=src_token_idx+1): { - 'src_tokens': RightPadDataset( + "net_input{idx}".format(idx=src_token_idx + 1): { + "src_tokens": RightPadDataset( src_tokens[src_token_idx], pad_idx=self.source_dictionary.pad(), ), - 'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False), + "src_lengths": NumelDataset( + src_tokens[src_token_idx], reduce=False + ), } } ) - label_path = '{}.label'.format(get_path('label', split)) + label_path = "{}.label".format(get_path("label", split)) if os.path.exists(label_path): with open(label_path) as h: dataset.update( - target=RawLabelDataset([ - int(x.strip()) for x in h.readlines() - ]) + target=RawLabelDataset([int(x.strip()) for x in h.readlines()]) ) nested_dataset = NestedDictionaryDataset( @@ -187,10 +197,11 @@ def make_dataset(type, dictionary): def build_model(self, args): from fairseq import models + model = models.build_model(args, self) model.register_classification_head( - getattr(args, 'ranking_head_name', 'sentence_classification_head'), + getattr(args, "ranking_head_name", "sentence_classification_head"), num_classes=1, ) diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index b17ad22602..6d222f0de3 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -4,38 +4,51 @@ # LICENSE file in the root directory of this source tree. 
import logging -from argparse import Namespace import os.path as op +from argparse import Namespace -from fairseq.data import encoders, Dictionary +from fairseq.data import Dictionary, encoders from fairseq.data.audio.speech_to_text_dataset import ( - SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig + S2TDataConfig, + SpeechToTextDataset, + SpeechToTextDatasetCreator, ) from fairseq.tasks import FairseqTask, register_task + logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=logging.INFO, - ) + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) logger = logging.getLogger(__name__) -@register_task('speech_to_text') +@register_task("speech_to_text") class SpeechToTextTask(FairseqTask): @staticmethod def add_args(parser): - parser.add_argument('data', help='manifest root path') + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=6000, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) parser.add_argument( - '--config-yaml', type=str, default='config.yaml', - help='Configuration YAML filename (under manifest root)' + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", ) - parser.add_argument('--max-source-positions', default=6000, type=int, - metavar='N', - help='max number of tokens in the source sequence') - parser.add_argument('--max-target-positions', default=1024, type=int, - metavar='N', - help='max number of tokens in the target sequence') def __init__(self, args, tgt_dict): super().__init__(args) @@ -47,31 +60,41 @@ def setup_task(cls, args, **kwargs): data_cfg = S2TDataConfig(op.join(args.data, 
args.config_yaml)) dict_path = op.join(args.data, data_cfg.vocab_filename) if not op.isfile(dict_path): - raise FileNotFoundError(f'Dict not found: {dict_path}') + raise FileNotFoundError(f"Dict not found: {dict_path}") tgt_dict = Dictionary.load(dict_path) - logger.info(f'dictionary size ({data_cfg.vocab_filename}): ' - f'{len(tgt_dict):,}') + logger.info( + f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}" + ) - if getattr(args, 'train_subset', None) is not None: - if not all(s.startswith('train') for s in args.train_subset.split(',')): + if getattr(args, "train_subset", None) is not None: + if not all(s.startswith("train") for s in args.train_subset.split(",")): raise ValueError('Train splits should be named like "train*".') return cls(args, tgt_dict) def build_criterion(self, args): from fairseq import criterions + if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1: - raise ValueError('Please set "--ignore-prefix-size 1" since ' - 'target language ID token is prepended as BOS.') + raise ValueError( + 'Please set "--ignore-prefix-size 1" since ' + "target language ID token is prepended as BOS." 
+ ) return criterions.build_criterion(args, self) def load_dataset(self, split, epoch=1, combine=False, **kwargs): - is_train_split = split.startswith('train') + is_train_split = split.startswith("train") pre_tokenizer = self.build_tokenizer(self.args) bpe_tokenizer = self.build_bpe(self.args) self.datasets[split] = SpeechToTextDatasetCreator.from_tsv( - self.args.data, self.data_cfg, split, self.tgt_dict, - pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split, - epoch=epoch, seed=self.args.seed + self.args.data, + self.data_cfg, + split, + self.tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, ) @property @@ -91,30 +114,35 @@ def build_model(self, args): return super(SpeechToTextTask, self).build_model(args) def build_generator( - self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, ): if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1: - raise ValueError('Please set "--prefix-size 1" since ' - 'target language ID token is prepended as BOS.') + raise ValueError( + 'Please set "--prefix-size 1" since ' + "target language ID token is prepended as BOS." 
+ ) lang_token_ids = { - i for s, i in self.tgt_dict.indices.items() + i + for s, i in self.tgt_dict.indices.items() if SpeechToTextDataset.is_lang_tag(s) } - extra_gen_cls_kwargs = {'symbols_to_strip_from_output': lang_token_ids} + extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids} return super().build_generator( - models, args, seq_gen_cls=None, - extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) def build_tokenizer(self, args): - logger.info(f'pre-tokenizer: {self.data_cfg.pre_tokenizer}') + logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}") return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer)) def build_bpe(self, args): - logger.info(f'tokenizer: {self.data_cfg.bpe_tokenizer}') + logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}") return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer)) @classmethod def build_dataset_for_inference(cls, audio_paths, n_frames): - return SpeechToTextDataset('interactive', False, {}, audio_paths, - n_frames) + return SpeechToTextDataset("interactive", False, {}, audio_paths, n_frames) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index a04924605c..79007a6d9f 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -3,28 +3,27 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from argparse import Namespace -import json import itertools +import json import logging import os -from fairseq import options -import numpy as np +from argparse import Namespace -from fairseq import metrics, utils +import numpy as np +from fairseq import metrics, options, utils from fairseq.data import ( AppendTokenDataset, ConcatDataset, - data_utils, - encoders, - indexed_dataset, LanguagePairDataset, PrependTokenDataset, StripTokenDataset, TruncateDataset, + data_utils, + encoders, + indexed_dataset, ) +from fairseq.tasks import LegacyFairseqTask, register_task -from fairseq.tasks import register_task, LegacyFairseqTask EVAL_BLEU_ORDER = 4 @@ -33,40 +32,53 @@ def load_langpair_dataset( - data_path, split, - src, src_dict, - tgt, tgt_dict, - combine, dataset_impl, upsample_primary, - left_pad_source, left_pad_target, max_source_positions, - max_target_positions, prepend_bos=False, load_alignments=False, - truncate_source=False, append_source_id=False, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, + append_source_id=False, num_buckets=0, shuffle=True, pad_to_multiple=1, ): - def split_exists(split, src, tgt, lang, data_path): - filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) + filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') + split_k = split + (str(k) if k > 0 else "") # infer langcode if split_exists(split_k, src, tgt, src, data_path): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif split_exists(split_k, 
tgt, src, src, data_path): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) - src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) + src_dataset = data_utils.load_indexed_dataset( + prefix + src, src_dict, dataset_impl + ) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( @@ -77,13 +89,17 @@ def split_exists(split, src, tgt, lang, data_path): ) src_datasets.append(src_dataset) - tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) + tgt_dataset = data_utils.load_indexed_dataset( + prefix + tgt, tgt_dict, dataset_impl + ) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) - logger.info('{} {} {}-{} {} examples'.format( - data_path, split_k, src, tgt, len(src_datasets[-1]) - )) + logger.info( + "{} {} {}-{} {} examples".format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + ) + ) if not combine: break @@ -110,31 +126,42 @@ def split_exists(split, src, tgt, lang, data_path): eos = None if append_source_id: - src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) + src_dataset = AppendTokenDataset( + src_dataset, src_dict.index("[{}]".format(src)) + ) if tgt_dataset is not None: - tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) - eos = tgt_dict.index('[{}]'.format(tgt)) + tgt_dataset = AppendTokenDataset( + tgt_dataset, tgt_dict.index("[{}]".format(tgt)) + ) + eos = tgt_dict.index("[{}]".format(tgt)) align_dataset = None if load_alignments: - align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) + align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt)) if 
indexed_dataset.dataset_exists(align_path, impl=dataset_impl): - align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) + align_dataset = data_utils.load_indexed_dataset( + align_path, None, dataset_impl + ) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return LanguagePairDataset( - src_dataset, src_dataset.sizes, src_dict, - tgt_dataset, tgt_dataset_sizes, tgt_dict, + src_dataset, + src_dataset.sizes, + src_dict, + tgt_dataset, + tgt_dataset_sizes, + tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, - align_dataset=align_dataset, eos=eos, + align_dataset=align_dataset, + eos=eos, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, ) -@register_task('translation') +@register_task("translation") class TranslationTask(LegacyFairseqTask): """ Translate from one (source) language to another (target) language. @@ -227,18 +254,26 @@ def setup_task(cls, args, **kwargs): assert len(paths) > 0 # find language pair automatically if args.source_lang is None or args.target_lang is None: - args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0]) + args.source_lang, args.target_lang = data_utils.infer_language_pair( + paths[0] + ) if args.source_lang is None or args.target_lang is None: - raise Exception('Could not infer language pair, please provide it explicitly') + raise Exception( + "Could not infer language pair, please provide it explicitly" + ) # load dictionaries - src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang))) - tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang))) + src_dict = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(args.source_lang)) + ) + tgt_dict = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(args.target_lang)) + ) assert src_dict.pad() == tgt_dict.pad() assert src_dict.eos() == tgt_dict.eos() assert 
src_dict.unk() == tgt_dict.unk() - logger.info('[{}] dictionary: {} types'.format(args.source_lang, len(src_dict))) - logger.info('[{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict))) + logger.info("[{}] dictionary: {} types".format(args.source_lang, len(src_dict))) + logger.info("[{}] dictionary: {} types".format(args.target_lang, len(tgt_dict))) return cls(args, src_dict, tgt_dict) @@ -259,8 +294,14 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): src, tgt = self.args.source_lang, self.args.target_lang self.datasets[split] = load_langpair_dataset( - data_path, split, src, self.src_dict, tgt, self.tgt_dict, - combine=combine, dataset_impl=self.args.dataset_impl, + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, upsample_primary=self.args.upsample_primary, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, @@ -269,45 +310,52 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): load_alignments=self.args.load_alignments, truncate_source=self.args.truncate_source, num_buckets=self.args.num_batch_buckets, - shuffle=(split != 'test'), + shuffle=(split != "test"), pad_to_multiple=self.args.required_seq_len_multiple, ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): - return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary, - tgt_dict=self.target_dictionary, - constraints=constraints) + return LanguagePairDataset( + src_tokens, + src_lengths, + self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints, + ) def build_model(self, args): model = super().build_model(args) - if getattr(args, 'eval_bleu', False): - assert getattr(args, 'eval_bleu_detok', None) is not None, ( - '--eval-bleu-detok is required if using --eval-bleu; ' - 'try --eval-bleu-detok=moses (or --eval-bleu-detok=space ' - 'to disable detokenization, e.g., when using 
sentencepiece)' + if getattr(args, "eval_bleu", False): + assert getattr(args, "eval_bleu_detok", None) is not None, ( + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" + ) + detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}") + self.tokenizer = encoders.build_tokenizer( + Namespace( + tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args + ) + ) + + gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}") + self.sequence_generator = self.build_generator( + [model], Namespace(**gen_args) ) - detok_args = json.loads(getattr(args, 'eval_bleu_detok_args', '{}') or '{}') - self.tokenizer = encoders.build_tokenizer(Namespace( - tokenizer=getattr(args, 'eval_bleu_detok', None), - **detok_args - )) - - gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}') - self.sequence_generator = self.build_generator([model], Namespace(**gen_args)) return model def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) if self.args.eval_bleu: bleu = self._inference_with_bleu(self.sequence_generator, sample, model) - logging_output['_bleu_sys_len'] = bleu.sys_len - logging_output['_bleu_ref_len'] = bleu.ref_len + logging_output["_bleu_sys_len"] = bleu.sys_len + logging_output["_bleu_ref_len"] = bleu.ref_len # we split counts into separate entries so that they can be # summed efficiently across workers using fast-stat-sync assert len(bleu.counts) == EVAL_BLEU_ORDER for i in range(EVAL_BLEU_ORDER): - logging_output['_bleu_counts_' + str(i)] = bleu.counts[i] - logging_output['_bleu_totals_' + str(i)] = bleu.totals[i] + logging_output["_bleu_counts_" + str(i)] = bleu.counts[i] + logging_output["_bleu_totals_" + str(i)] = bleu.totals[i] return loss, sample_size, logging_output def reduce_metrics(self, logging_outputs, criterion): @@ 
-319,34 +367,35 @@ def sum_logs(key): counts, totals = [], [] for i in range(EVAL_BLEU_ORDER): - counts.append(sum_logs('_bleu_counts_' + str(i))) - totals.append(sum_logs('_bleu_totals_' + str(i))) + counts.append(sum_logs("_bleu_counts_" + str(i))) + totals.append(sum_logs("_bleu_totals_" + str(i))) if max(totals) > 0: # log counts as numpy arrays -- log_scalar will sum them correctly - metrics.log_scalar('_bleu_counts', np.array(counts)) - metrics.log_scalar('_bleu_totals', np.array(totals)) - metrics.log_scalar('_bleu_sys_len', sum_logs('_bleu_sys_len')) - metrics.log_scalar('_bleu_ref_len', sum_logs('_bleu_ref_len')) + metrics.log_scalar("_bleu_counts", np.array(counts)) + metrics.log_scalar("_bleu_totals", np.array(totals)) + metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len")) + metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len")) def compute_bleu(meters): import inspect import sacrebleu + fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0] - if 'smooth_method' in fn_sig: - smooth = {'smooth_method': 'exp'} + if "smooth_method" in fn_sig: + smooth = {"smooth_method": "exp"} else: - smooth = {'smooth': 'exp'} + smooth = {"smooth": "exp"} bleu = sacrebleu.compute_bleu( - correct=meters['_bleu_counts'].sum, - total=meters['_bleu_totals'].sum, - sys_len=meters['_bleu_sys_len'].sum, - ref_len=meters['_bleu_ref_len'].sum, + correct=meters["_bleu_counts"].sum, + total=meters["_bleu_totals"].sum, + sys_len=meters["_bleu_sys_len"].sum, + ref_len=meters["_bleu_ref_len"].sum, **smooth ) return round(bleu.score, 2) - metrics.log_derived('bleu', compute_bleu) + metrics.log_derived("bleu", compute_bleu) def max_positions(self): """Return the max sentence length allowed by the task.""" @@ -374,9 +423,7 @@ def decode(toks, escape_unk=False): # BLEU scores. Instead, we use a somewhat more verbose # alternative that is unlikely to appear in the real # reference, but doesn't get split into multiple tokens. 
- unk_string=( - "UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP" - ), + unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), ) if self.tokenizer: s = self.tokenizer.decode(s) @@ -385,15 +432,17 @@ def decode(toks, escape_unk=False): gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) hyps, refs = [], [] for i in range(len(gen_out)): - hyps.append(decode(gen_out[i][0]['tokens'])) - refs.append(decode( - utils.strip_pad(sample['target'][i], self.tgt_dict.pad()), - escape_unk=True, # don't count as matches to the hypo - )) + hyps.append(decode(gen_out[i][0]["tokens"])) + refs.append( + decode( + utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), + escape_unk=True, # don't count as matches to the hypo + ) + ) if self.args.eval_bleu_print_samples: - logger.info('example hypothesis: ' + hyps[0]) - logger.info('example reference: ' + refs[0]) + logger.info("example hypothesis: " + hyps[0]) + logger.info("example reference: " + refs[0]) if self.args.eval_tokenized_bleu: - return sacrebleu.corpus_bleu(hyps, [refs], tokenize='none') + return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none") else: return sacrebleu.corpus_bleu(hyps, [refs]) diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index 4d574ffc82..8710b7fe7d 100644 --- a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -4,15 +4,14 @@ # LICENSE file in the root directory of this source tree. import torch - -from fairseq.data import LanguagePairDataset from fairseq import utils +from fairseq.data import LanguagePairDataset -from .translation import load_langpair_dataset, TranslationTask from . 
import register_task +from .translation import TranslationTask, load_langpair_dataset -@register_task('translation_from_pretrained_bart') +@register_task("translation_from_pretrained_bart") class TranslationFromPretrainedBARTTask(TranslationTask): """ Translate from source language to target language with a model initialized with a multilingual pretrain. @@ -52,11 +51,11 @@ def add_args(parser): def __init__(self, args, src_dict, tgt_dict): super().__init__(args, src_dict, tgt_dict) - self.langs = args.langs.split(',') + self.langs = args.langs.split(",") for d in [src_dict, tgt_dict]: for l in self.langs: - d.add_symbol('[{}]'.format(l)) - d.add_symbol('<mask>') + d.add_symbol("[{}]".format(l)) + d.add_symbol("<mask>") def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. @@ -72,50 +71,62 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): src, tgt = self.args.source_lang, self.args.target_lang self.datasets[split] = load_langpair_dataset( - data_path, split, src, self.src_dict, tgt, self.tgt_dict, - combine=combine, dataset_impl=self.args.dataset_impl, + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, upsample_primary=self.args.upsample_primary, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, - max_source_positions=getattr(self.args, 'max_source_positions', 1024), - max_target_positions=getattr(self.args, 'max_target_positions', 1024), + max_source_positions=getattr(self.args, "max_source_positions", 1024), + max_target_positions=getattr(self.args, "max_target_positions", 1024), load_alignments=self.args.load_alignments, - prepend_bos=getattr(self.args, 'prepend_bos', False), - append_source_id=True - ) + prepend_bos=getattr(self.args, "prepend_bos", False), + append_source_id=True, + ) def build_generator(self, models, args, **unused): - if getattr(args, 'score_reference', False): + if getattr(args, 
"score_reference", False): from fairseq.sequence_scorer import SequenceScorer + return SequenceScorer( self.target_dictionary, - eos=self.tgt_dict.index('[{}]'.format(self.args.target_lang)) + eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), ) else: from fairseq.sequence_generator import SequenceGenerator + return SequenceGenerator( models, self.target_dictionary, - beam_size=getattr(args, 'beam', 5), - max_len_a=getattr(args, 'max_len_a', 0), - max_len_b=getattr(args, 'max_len_b', 200), - min_len=getattr(args, 'min_len', 1), - normalize_scores=(not getattr(args, 'unnormalized', False)), - len_penalty=getattr(args, 'lenpen', 1), - unk_penalty=getattr(args, 'unkpen', 0), - temperature=getattr(args, 'temperature', 1.), - match_source_len=getattr(args, 'match_source_len', False), - no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), - eos=self.tgt_dict.index('[{}]'.format(self.args.target_lang)) + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): - src_lang_id = self.source_dictionary.index('[{}]'.format(self.args.source_lang)) + src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang)) source_tokens = [] for s_t in src_tokens: s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)]) source_tokens.append(s_t) - dataset = LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary, - tgt_dict=self.target_dictionary, - 
constraints=constraints) + dataset = LanguagePairDataset( + source_tokens, + src_lengths, + self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints, + ) return dataset diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index 3af9bb2532..4678774922 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -6,15 +6,14 @@ import os import torch - +from fairseq import utils from fairseq.data import LanguagePairDataset - -from fairseq.utils import new_arange from fairseq.tasks import register_task from fairseq.tasks.translation import TranslationTask, load_langpair_dataset -from fairseq import utils +from fairseq.utils import new_arange + -@register_task('translation_lev') +@register_task("translation_lev") class TranslationLevenshteinTask(TranslationTask): """ Translation (Sequence Generation) task for Levenshtein Transformer @@ -46,8 +45,14 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): src, tgt = self.args.source_lang, self.args.target_lang self.datasets[split] = load_langpair_dataset( - data_path, split, src, self.src_dict, tgt, self.tgt_dict, - combine=combine, dataset_impl=self.args.dataset_impl, + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, upsample_primary=self.args.upsample_primary, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, @@ -66,24 +71,32 @@ def _random_delete(target_tokens): target_mask = target_tokens.eq(pad) target_score = target_tokens.clone().float().uniform_() target_score.masked_fill_( - target_tokens.eq(bos) | target_tokens.eq(eos), 0.0) + target_tokens.eq(bos) | target_tokens.eq(eos), 0.0 + ) target_score.masked_fill_(target_mask, 1) target_score, target_rank = target_score.sort(1) target_length = target_mask.size(1) - target_mask.float().sum( - 1, keepdim=True) + 1, keepdim=True + ) # do not delete <s> and </s> (we assign 0 
score for them) - target_cutoff = 2 + ((target_length - 2) * target_score.new_zeros( - target_score.size(0), 1).uniform_()).long() + target_cutoff = ( + 2 + + ( + (target_length - 2) + * target_score.new_zeros(target_score.size(0), 1).uniform_() + ).long() + ) target_cutoff = target_score.sort(1)[1] >= target_cutoff - prev_target_tokens = target_tokens.gather( - 1, target_rank).masked_fill_(target_cutoff, pad).gather( - 1, - target_rank.masked_fill_(target_cutoff, - max_len).sort(1)[1]) - prev_target_tokens = prev_target_tokens[:, :prev_target_tokens. - ne(pad).sum(1).max()] + prev_target_tokens = ( + target_tokens.gather(1, target_rank) + .masked_fill_(target_cutoff, pad) + .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1]) + ) + prev_target_tokens = prev_target_tokens[ + :, : prev_target_tokens.ne(pad).sum(1).max() + ] return prev_target_tokens @@ -93,9 +106,9 @@ def _random_mask(target_tokens): eos = self.tgt_dict.eos() unk = self.tgt_dict.unk() - target_masks = target_tokens.ne(pad) & \ - target_tokens.ne(bos) & \ - target_tokens.ne(eos) + target_masks = ( + target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos) + ) target_score = target_tokens.clone().float().uniform_() target_score.masked_fill_(~target_masks, 2.0) target_length = target_masks.sum(1).float() @@ -105,7 +118,8 @@ def _random_mask(target_tokens): _, target_rank = target_score.sort(1) target_cutoff = new_arange(target_rank) < target_length[:, None].long() prev_target_tokens = target_tokens.masked_fill( - target_cutoff.scatter(1, target_rank, target_cutoff), unk) + target_cutoff.scatter(1, target_rank, target_cutoff), unk + ) return prev_target_tokens def _full_mask(target_tokens): @@ -114,17 +128,18 @@ def _full_mask(target_tokens): eos = self.tgt_dict.eos() unk = self.tgt_dict.unk() - target_mask = target_tokens.eq(bos) | target_tokens.eq( - eos) | target_tokens.eq(pad) + target_mask = ( + target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad) + 
) return target_tokens.masked_fill(~target_mask, unk) - if self.args.noise == 'random_delete': + if self.args.noise == "random_delete": return _random_delete(target_tokens) - elif self.args.noise == 'random_mask': + elif self.args.noise == "random_mask": return _random_mask(target_tokens) - elif self.args.noise == 'full_mask': + elif self.args.noise == "full_mask": return _full_mask(target_tokens) - elif self.args.noise == 'no_noise': + elif self.args.noise == "no_noise": return target_tokens else: raise NotImplementedError @@ -132,34 +147,34 @@ def _full_mask(target_tokens): def build_generator(self, models, args, **unused): # add models input to match the API for SequenceGenerator from fairseq.iterative_refinement_generator import IterativeRefinementGenerator + return IterativeRefinementGenerator( self.target_dictionary, - eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0), - max_iter=getattr(args, 'iter_decode_max_iter', 10), - beam_size=getattr(args, 'iter_decode_with_beam', 1), - reranking=getattr(args, 'iter_decode_with_external_reranker', False), - decoding_format=getattr(args, 'decoding_format', None), - adaptive=not getattr(args, 'iter_decode_force_max_iter', False), - retain_history=getattr(args, 'retain_iter_history', False)) + eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0), + max_iter=getattr(args, "iter_decode_max_iter", 10), + beam_size=getattr(args, "iter_decode_with_beam", 1), + reranking=getattr(args, "iter_decode_with_external_reranker", False), + decoding_format=getattr(args, "decoding_format", None), + adaptive=not getattr(args, "iter_decode_force_max_iter", False), + retain_history=getattr(args, "retain_iter_history", False), + ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): if constraints is not None: # Though see Susanto et al. 
(ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/ - raise NotImplementedError("Constrained decoding with the translation_lev task is not supported") + raise NotImplementedError( + "Constrained decoding with the translation_lev task is not supported" + ) return LanguagePairDataset( src_tokens, src_lengths, self.source_dictionary, append_bos=True ) - def train_step(self, - sample, - model, - criterion, - optimizer, - update_num, - ignore_grad=False): + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): model.train() - sample['prev_target'] = self.inject_noise(sample['target']) + sample["prev_target"] = self.inject_noise(sample["target"]) loss, sample_size, logging_output = criterion(model, sample) if ignore_grad: loss *= 0 @@ -169,6 +184,6 @@ def train_step(self, def valid_step(self, sample, model, criterion): model.eval() with torch.no_grad(): - sample['prev_target'] = self.inject_noise(sample['target']) + sample["prev_target"] = self.inject_noise(sample["target"]) loss, sample_size, logging_output = criterion(model, sample) return loss, sample_size, logging_output diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 960b82e1e8..95a2d162c0 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -3,34 +3,40 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import logging import datetime +import logging import time import torch from fairseq.data import ( - data_utils, FairseqDataset, - iterators, LanguagePairDataset, ListDataset, + data_utils, + iterators, +) +from fairseq.data.multilingual.multilingual_data_manager import ( + MultilingualDatasetManager, ) - -from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.multilingual.sampling_method import SamplingMethod -from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager +from fairseq.tasks import LegacyFairseqTask, register_task from fairseq.utils import FileContentsAction + ### def get_time_gap(s, e): - return (datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)).__str__() + return ( + datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s) + ).__str__() + + ### logger = logging.getLogger(__name__) -@register_task('translation_multi_simple_epoch') +@register_task("translation_multi_simple_epoch") class TranslationMultiSimpleEpochTask(LegacyFairseqTask): """ Translate from one (source) language to another (target) language. @@ -79,7 +85,7 @@ def __init__(self, args, langs, dicts, training): if training: self.lang_pairs = args.lang_pairs else: - self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)] + self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] # eval_lang_pairs for multilingual translation is usually all of the # lang_pairs. However for other multitask settings or when we want to # optimize for certain languages we want to use a different subset. 
Thus @@ -92,7 +98,8 @@ def __init__(self, args, langs, dicts, training): self.model_lang_pairs = self.lang_pairs self.sampling_method = SamplingMethod.build_sampler(args, self) self.data_manager = MultilingualDatasetManager.setup_data_manager( - args, self.lang_pairs, langs, dicts, self.sampling_method) + args, self.lang_pairs, langs, dicts, self.sampling_method + ) @classmethod def setup_task(cls, args, **kwargs): @@ -130,59 +137,67 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): else: # estimate the shard epoch from virtual data size and virtual epoch size shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch) - logger.info(f'loading data for {split} epoch={epoch}/{shard_epoch}') + logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}") logger.info(f"mem usage: {data_utils.get_mem_usage()}") if split in self.datasets: del self.datasets[split] - logger.info('old dataset deleted manually') + logger.info("old dataset deleted manually") logger.info(f"mem usage: {data_utils.get_mem_usage()}") self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset( split, self.training, - epoch=epoch, combine=combine, shard_epoch=shard_epoch, **kwargs + epoch=epoch, + combine=combine, + shard_epoch=shard_epoch, + **kwargs, ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): if constraints is not None: - raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported") + raise NotImplementedError( + "Constrained decoding with the multilingual_translation task is not supported" + ) src_data = ListDataset(src_tokens, src_lengths) dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary) - src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main'] + src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"] if self.args.lang_tok_replacing_bos_eos: dataset = self.data_manager.alter_dataset_langtok( - dataset, - 
src_eos=self.source_dictionary.eos(), - src_lang=self.args.source_lang, - tgt_eos=self.target_dictionary.eos(), - tgt_lang=self.args.target_lang, - src_langtok_spec=src_langtok_spec, - tgt_langtok_spec=tgt_langtok_spec, - ) + dataset, + src_eos=self.source_dictionary.eos(), + src_lang=self.args.source_lang, + tgt_eos=self.target_dictionary.eos(), + tgt_lang=self.args.target_lang, + src_langtok_spec=src_langtok_spec, + tgt_langtok_spec=tgt_langtok_spec, + ) else: dataset.src = self.data_manager.src_dataset_tranform_func( self.args.source_lang, self.args.target_lang, dataset=dataset.src, spec=src_langtok_spec, - ) + ) return dataset def build_generator( - self, models, args, - seq_gen_cls=None, extra_gen_cls_kwargs=None, + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, ): - if not getattr(args, 'keep_inference_langtok', False): - _, tgt_langtok_spec = self.args.langtoks['main'] + if not getattr(args, "keep_inference_langtok", False): + _, tgt_langtok_spec = self.args.langtoks["main"] if tgt_langtok_spec: - tgt_lang_tok = self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec) + tgt_lang_tok = self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} - extra_gen_cls_kwargs['symbols_to_strip_from_output'] = {tgt_lang_tok} + extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok} return super().build_generator( - models, args, - seq_gen_cls=None, - extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) def build_model(self, args): @@ -192,30 +207,37 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) return loss, sample_size, logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None): + def inference_step( + self, generator, models, sample, 
prefix_tokens=None, constraints=None + ): with torch.no_grad(): - _, tgt_langtok_spec = self.args.langtoks['main'] + _, tgt_langtok_spec = self.args.langtoks["main"] if not self.args.lang_tok_replacing_bos_eos: if prefix_tokens is None and tgt_langtok_spec: - tgt_lang_tok = self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec) - src_tokens = sample['net_input']['src_tokens'] + tgt_lang_tok = self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + src_tokens = sample["net_input"]["src_tokens"] bsz = src_tokens.size(0) - prefix_tokens = torch.LongTensor( - [[tgt_lang_tok]] - ).expand(bsz, 1).to(src_tokens) + prefix_tokens = ( + torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens) + ) return generator.generate( - models, - sample, - prefix_tokens=prefix_tokens, - constraints=constraints, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, ) else: return generator.generate( - models, - sample, - prefix_tokens=prefix_tokens, - bos_token=self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec) - if tgt_langtok_spec else self.target_dictionary.eos(), + models, + sample, + prefix_tokens=prefix_tokens, + bos_token=self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + if tgt_langtok_spec + else self.target_dictionary.eos(), ) def reduce_metrics(self, logging_outputs, criterion): @@ -234,15 +256,18 @@ def target_dictionary(self): return next(iter(self.dicts.values())) def create_batch_sampler_func( - self, max_positions, ignore_invalid_inputs, - max_tokens, max_sentences, + self, + max_positions, + ignore_invalid_inputs, + max_tokens, + max_sentences, required_batch_size_multiple=1, seed=1, ): - def construct_batch_sampler( - dataset, epoch - ): - splits = [s for s, _ in self.datasets.items() if self.datasets[s] == dataset] + def construct_batch_sampler(dataset, epoch): + splits = [ + s for s, _ in self.datasets.items() if 
self.datasets[s] == dataset + ] split = splits[0] if len(splits) > 0 else None # NEW implementation if epoch is not None: @@ -255,7 +280,9 @@ def construct_batch_sampler( with data_utils.numpy_seed(seed): indices = dataset.ordered_indices() - logger.info(f'[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}') + logger.info( + f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}" + ) logger.info(f"mem usage: {data_utils.get_mem_usage()}") # filter examples that are too large @@ -264,7 +291,9 @@ def construct_batch_sampler( indices = self.filter_indices_by_size( indices, dataset, max_positions, ignore_invalid_inputs ) - logger.info(f'[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}') + logger.info( + f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}" + ) logger.info(f"mem usage: {data_utils.get_mem_usage()}") # create mini-batches with given size constraints @@ -276,19 +305,34 @@ def construct_batch_sampler( required_batch_size_multiple=required_batch_size_multiple, ) - logger.info(f'[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}') - logger.info(f'[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}') + logger.info( + f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}" + ) + logger.info( + f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}" + ) logger.info(f"mem usage: {data_utils.get_mem_usage()}") return batch_sampler + return construct_batch_sampler # we need to override get_batch_iterator because we want to reset the epoch iterator each time def get_batch_iterator( - self, dataset, max_tokens=None, max_sentences=None, max_positions=None, - ignore_invalid_inputs=False, required_batch_size_multiple=1, - seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, - data_buffer_size=0, 
disable_iterator_cache=False, + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, ): """ Get an iterator that yields batches of data from the given dataset. @@ -329,9 +373,7 @@ def get_batch_iterator( assert isinstance(dataset, FairseqDataset) if dataset in self.dataset_to_epoch_iter: return self.dataset_to_epoch_iter[dataset] - if ( - self.args.sampling_method == 'RoundRobin' - ): + if self.args.sampling_method == "RoundRobin": batch_iter = super().get_batch_iterator( dataset, max_tokens=max_tokens, @@ -351,8 +393,10 @@ def get_batch_iterator( return batch_iter construct_batch_sampler = self.create_batch_sampler_func( - max_positions, ignore_invalid_inputs, - max_tokens, max_sentences, + max_positions, + ignore_invalid_inputs, + max_tokens, + max_sentences, required_batch_size_multiple=required_batch_size_multiple, seed=seed, ) diff --git a/fairseq/token_generation_constraints.py b/fairseq/token_generation_constraints.py index 7077199fd9..e708dc51bc 100644 --- a/fairseq/token_generation_constraints.py +++ b/fairseq/token_generation_constraints.py @@ -27,10 +27,11 @@ that many times in the output. 
""" +from collections import Counter +from typing import List, Optional, Set, Tuple + import torch -from collections import Counter -from typing import Tuple, List, Optional, Set class ConstraintState: def __init__(self): @@ -70,7 +71,11 @@ def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tenso for sentence_constraints in batch_constraints: if len(sentence_constraints): # number of constraints, plus sum of constrain lens, plus a zero after each - constraints_len = 1 + sum([c.size(0) for c in sentence_constraints]) + len(sentence_constraints) + constraints_len = ( + 1 + + sum([c.size(0) for c in sentence_constraints]) + + len(sentence_constraints) + ) max_constraints_len = max(max_constraints_len, constraints_len) batch_size = len(batch_constraints) @@ -80,7 +85,7 @@ def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tenso offset = 1 for j, constraint in enumerate(sentence_constraints): this_len = constraint.size(0) - constraints_tensor[i, offset:offset+this_len] = constraint + constraints_tensor[i, offset : offset + this_len] = constraint offset += this_len + 1 return constraints_tensor.long() @@ -107,6 +112,7 @@ class ConstraintNode: """ Represents a node in a trie managing unordered constraints. """ + def __init__(self, token: int = None, parent=None): # The token associate with this node (None for the root) self.token = int(token) if token is not None else None @@ -198,9 +204,8 @@ class UnorderedConstraintState(ConstraintState): Records progress through the set of constraints for each item in the beam using a trie. """ - def __init__(self, - node: ConstraintNode, - copy_from: "ConstraintState" = None): + + def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None): self.node = node if copy_from is None: @@ -383,9 +388,8 @@ class OrderedConstraintState(ConstraintState): """ Records progress through the set of linear nonbranching constraints with gaps. 
""" - def __init__(self, - sequence: ConstraintSequence, - state: int = -1): + + def __init__(self, sequence: ConstraintSequence, state: int = -1): self.sequence = sequence self.state = state @@ -407,7 +411,9 @@ def copy(self): def num_completed(self): if self.state == -1: return 0 - count = len(list(filter(lambda x: x, self.sequence.endpoints[0:self.state+1]))) + count = len( + list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1])) + ) return count @property diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py index 8c4d694aa0..42131f7b1d 100644 --- a/fairseq/tokenizer.py +++ b/fairseq/tokenizer.py @@ -5,6 +5,7 @@ import re + SPACE_NORMALIZER = re.compile(r"\s+") diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5d68783bfb..0069b79425 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -8,14 +8,13 @@ """ import contextlib -from itertools import chain import logging import sys import time +from itertools import chain from typing import Any, Dict, List import torch - from fairseq import checkpoint_utils, distributed_utils, models, optim, utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics @@ -43,20 +42,21 @@ def __init__(self, args, task, model, criterion, quantizer=None): # catalog shared parameters shared_params = _catalog_shared_params(model) - self.tpu = getattr(args, 'tpu', False) + self.tpu = getattr(args, "tpu", False) self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu if self.cuda: - self.device = torch.device('cuda') + self.device = torch.device("cuda") elif self.tpu: self.device = utils.get_tpu_device(args) else: - self.device = torch.device('cpu') + self.device = torch.device("cpu") # copy model and criterion to current device/dtype self._criterion = criterion self._model = model if self.tpu: import torch_xla.core.xla_model as xm + self._model = xm.send_cpu_data_to_device(self._model, self.device) if args.fp16: self._criterion = self._criterion.half() @@ -77,7 
+77,7 @@ def __init__(self, args, task, model, criterion, quantizer=None): ref = _get_module_by_path(self._model, shared_param[0]) for path in shared_param[1:]: logger.info( - 'detected shared parameter: {} <- {}'.format(shared_param[0], path) + "detected shared parameter: {} <- {}".format(shared_param[0], path) ) _set_module_by_path(self._model, path, ref) @@ -134,7 +134,7 @@ def data_parallel_world_size(self): @property def data_parallel_process_group(self): if self.tpu: - return ('tpu', None) + return ("tpu", None) else: return None @@ -156,8 +156,9 @@ def criterion(self): and not self.tpu ): self._wrapped_criterion = models.DistributedFairseqModel( - self.args, self._criterion, - process_group=self.data_parallel_process_group + self.args, + self._criterion, + process_group=self.data_parallel_process_group, ) else: self._wrapped_criterion = self._criterion @@ -172,8 +173,9 @@ def model(self): and not self.tpu ): self._wrapped_model = models.DistributedFairseqModel( - self.args, self._model, - process_group=self.data_parallel_process_group + self.args, + self._model, + process_group=self.data_parallel_process_group, ) else: self._wrapped_model = self._model @@ -219,17 +221,20 @@ def _build_optimizer(self): if self.args.use_bmuf: self._optimizer = optim.FairseqBMUF(self.args, self._optimizer) - if self.args.zero_sharding == 'os': - if (self.args.fp16 - and not self.args.memory_efficient_fp16 - and not self.args.memory_efficient_bf16 + if self.args.zero_sharding == "os": + if ( + self.args.fp16 + and not self.args.memory_efficient_fp16 + and not self.args.memory_efficient_bf16 ) and not self.args.fp16_no_flatten_grads: raise ValueError( - "ZeRO is incomptabile with fp16 and flattened grads. " - "Please use --fp16-no-flatten-grads" + "ZeRO is incomptabile with fp16 and flattened grads. 
" + "Please use --fp16-no-flatten-grads" ) else: - optim.shard_(self.args, self._optimizer, self.data_parallel_process_group) + optim.shard_( + self.args, self._optimizer, self.data_parallel_process_group + ) # We should initialize the learning rate scheduler immediately after # building the optimizer, so that the initial learning rate is set. @@ -416,7 +421,7 @@ def begin_epoch(self, epoch): if self.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous('begin_epoch') # wait for all workers + xm.rendezvous("begin_epoch") # wait for all workers xm.mark_step() def begin_valid_epoch(self, epoch): @@ -511,13 +516,14 @@ def maybe_no_sync(): # To handle gradient accumulation use case, we explicitly # mark step here for every forward pass without a backward pass import torch_xla.core.xla_model as xm + xm.mark_step() if is_dummy_batch: if torch.is_tensor(sample_size): sample_size.zero_() else: - sample_size *= 0. + sample_size *= 0.0 if torch.is_tensor(sample_size): sample_size = sample_size.float() @@ -527,27 +533,42 @@ def maybe_no_sync(): # gather logging outputs from all replicas if self._sync_stats(): train_time = self._local_cumulative_training_time() - logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs( - logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, + logging_outputs, ( + sample_size, + ooms, + total_train_time, + ) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ooms, + train_time, + ignore=is_dummy_batch, + ) + self._cumulative_training_time = ( + total_train_time / self.data_parallel_world_size ) - self._cumulative_training_time = total_train_time / self.data_parallel_world_size - if hasattr(self.model, 'all_reduce'): + if hasattr(self.model, "all_reduce"): self.model.all_reduce() overflow = False try: if self.tpu and self.data_parallel_world_size > 1: import torch_xla.core.xla_model as xm + gradients = xm._fetch_gradients(self.optimizer.optimizer) - xm.all_reduce('sum', 
gradients, scale=1.0 / self.data_parallel_world_size) + xm.all_reduce( + "sum", gradients, scale=1.0 / self.data_parallel_world_size + ) with torch.autograd.profiler.record_function("multiply-grads"): # multiply gradients by (# GPUs / sample_size) since DDP # already normalizes by the number of GPUs. Thus we get # (sum_of_gradients / sample_size). if not self.args.use_bmuf: - self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size) + self.optimizer.multiply_grads( + self.data_parallel_world_size / sample_size + ) elif sample_size > 0: # BMUF needs to check sample size num = self.data_parallel_world_size if self._sync_stats() else 1 self.optimizer.multiply_grads(num / sample_size) @@ -559,7 +580,7 @@ def maybe_no_sync(): # check that grad norms are consistent across workers if ( not self.args.use_bmuf - and self.args.distributed_wrapper != 'SlowMo' + and self.args.distributed_wrapper != "SlowMo" and not self.tpu ): self._check_grad_norms(grad_norm) @@ -573,14 +594,18 @@ def maybe_no_sync(): # out where it fails with NanDetector(self.get_model()): self.task.train_step( - sample, self.model, self.criterion, self.optimizer, self.get_num_updates(), - ignore_grad=False + sample, + self.model, + self.criterion, + self.optimizer, + self.get_num_updates(), + ignore_grad=False, ) raise except OverflowError as e: overflow = True logger.info("NOTE: overflow detected, " + str(e)) - grad_norm = torch.tensor(0.).cuda() + grad_norm = torch.tensor(0.0).cuda() self.zero_grad() except RuntimeError as e: if "out of memory" in str(e): @@ -589,18 +614,23 @@ def maybe_no_sync(): raise e # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step - if hasattr(self.model, 'perform_additional_optimizer_actions'): - if hasattr(self.optimizer, 'fp32_params'): - self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params) + if hasattr(self.model, "perform_additional_optimizer_actions"): + if 
hasattr(self.optimizer, "fp32_params"): + self.model.perform_additional_optimizer_actions( + self.optimizer.optimizer, self.optimizer.fp32_params + ) else: - self.model.perform_additional_optimizer_actions(self.optimizer.optimizer) + self.model.perform_additional_optimizer_actions( + self.optimizer.optimizer + ) - if not overflow or self.args.distributed_wrapper == 'SlowMo': + if not overflow or self.args.distributed_wrapper == "SlowMo": self.set_num_updates(self.get_num_updates() + 1) if self.tpu: # mark step on TPUs import torch_xla.core.xla_model as xm + xm.mark_step() # only log stats every log_interval steps @@ -609,17 +639,27 @@ def maybe_no_sync(): if self.get_num_updates() % self.args.log_interval == 0: # log memory usage mem_info = xm.get_memory_info(self.device) - gb_free = mem_info['kb_free'] / 1024 / 1024 - gb_total = mem_info['kb_total'] / 1024 / 1024 + gb_free = mem_info["kb_free"] / 1024 / 1024 + gb_total = mem_info["kb_total"] / 1024 / 1024 metrics.log_scalar( - 'gb_free', gb_free, priority=1500, round=1, weight=0, + "gb_free", + gb_free, + priority=1500, + round=1, + weight=0, ) metrics.log_scalar( - 'gb_total', gb_total, priority=1600, round=1, weight=0, + "gb_total", + gb_total, + priority=1600, + round=1, + weight=0, ) logging_output = self._reduce_and_log_stats( - logging_outputs, sample_size, grad_norm, + logging_outputs, + sample_size, + grad_norm, ) # log whenever there's an XLA compilation, since these @@ -629,7 +669,9 @@ def maybe_no_sync(): else: # log stats logging_output = self._reduce_and_log_stats( - logging_outputs, sample_size, grad_norm, + logging_outputs, + sample_size, + grad_norm, ) # clear CUDA cache to reduce memory fragmentation @@ -639,7 +681,8 @@ def maybe_no_sync(): and ( (self.get_num_updates() + self.args.empty_cache_freq - 1) % self.args.empty_cache_freq - ) == 0 + ) + == 0 ): torch.cuda.empty_cache() @@ -660,7 +703,8 @@ def valid_step(self, sample, raise_oom=False): """Do forward pass in evaluation mode.""" if 
self.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous('valid_step') # wait for all workers + + xm.rendezvous("valid_step") # wait for all workers xm.mark_step() with torch.no_grad(): @@ -700,12 +744,14 @@ def valid_step(self, sample, raise_oom=False): if torch.is_tensor(sample_size): sample_size.zero_() else: - sample_size *= 0. + sample_size *= 0.0 # gather logging outputs from all replicas if self.data_parallel_world_size > 1: - logging_outputs, (sample_size, ) = self._aggregate_logging_outputs( - logging_outputs, sample_size, ignore=is_dummy_batch, + logging_outputs, (sample_size,) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ignore=is_dummy_batch, ) # log validation stats @@ -744,10 +790,10 @@ def get_meter(self, name): """[deprecated] Get a specific meter by name.""" from fairseq import meters - if 'get_meter' not in self._warn_once: - self._warn_once.add('get_meter') + if "get_meter" not in self._warn_once: + self._warn_once.add("get_meter") utils.deprecation_warning( - 'Trainer.get_meter is deprecated. Please use fairseq.metrics instead.' + "Trainer.get_meter is deprecated. Please use fairseq.metrics instead." 
) train_meters = metrics.get_meters("train") @@ -772,7 +818,7 @@ def get_meter(self, name): elif name in {"valid_loss", "valid_nll_loss"}: # support for legacy train.py, which assumed these meters # are always initialized - k = name[len("valid_"):] + k = name[len("valid_") :] m = metrics.get_meter("valid", k) return m or meters.AverageMeter() elif name == "oom": @@ -820,8 +866,10 @@ def _prepare_sample(self, sample): if self.cuda: if self.pipeline_model_parallel: - if 'target' in sample: - sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device) + if "target" in sample: + sample["target"] = utils.move_to_cuda( + sample["target"], device=self.last_device + ) else: sample = utils.move_to_cuda(sample) @@ -855,10 +903,9 @@ def _sync_stats(self): if self.data_parallel_world_size == 1: return False elif self.args.use_bmuf: - return ( - (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 - and (self.get_num_updates() + 1) > self.args.warmup_iterations - ) + return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and ( + self.get_num_updates() + 1 + ) > self.args.warmup_iterations else: return True @@ -899,13 +946,15 @@ def _all_gather_list_sync( raise NotImplementedError if ignore: logging_outputs = [] - results = list(zip( - *distributed_utils.all_gather_list( - [logging_outputs] + list(extra_stats_to_sum), - max_size=getattr(self.args, 'all_gather_list_size', 16384), - group=self.data_parallel_process_group, + results = list( + zip( + *distributed_utils.all_gather_list( + [logging_outputs] + list(extra_stats_to_sum), + max_size=getattr(self.args, "all_gather_list_size", 16384), + group=self.data_parallel_process_group, + ) ) - )) + ) logging_outputs, extra_stats_to_sum = results[0], results[1:] logging_outputs = list(chain.from_iterable(logging_outputs)) extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum] @@ -925,7 +974,7 @@ def _fast_stat_sync_sum( """ data = {} for i, stat in enumerate(extra_stats_to_sum): - 
data['extra_stats_' + str(i)] = stat + data["extra_stats_" + str(i)] = stat if len(logging_outputs) > 0: log_keys = list(logging_outputs[0].keys()) for k in log_keys: @@ -934,21 +983,19 @@ def _fast_stat_sync_sum( else: v = logging_outputs[0][k] v = torch.zeros_like(v) if torch.is_tensor(v) else 0 - data['logging_outputs_' + k] = v + data["logging_outputs_" + k] = v else: log_keys = None data = distributed_utils.all_reduce_dict( - data, - device=self.device, - group=self.data_parallel_process_group + data, device=self.device, group=self.data_parallel_process_group ) extra_stats_to_sum = [ - data['extra_stats_' + str(i)] for i in range(len(extra_stats_to_sum)) + data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum)) ] if log_keys is not None: - logging_outputs = [{k: data['logging_outputs_' + k] for k in log_keys}] + logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}] else: logging_outputs = [] return logging_outputs, extra_stats_to_sum @@ -959,8 +1006,7 @@ def _check_grad_norms(self, grad_norm): self._grad_norm_buf.zero_() self._grad_norm_buf[self.data_parallel_rank] = grad_norm distributed_utils.all_reduce( - self._grad_norm_buf, - group=self.data_parallel_process_group + self._grad_norm_buf, group=self.data_parallel_process_group ) def is_consistent(tensor): @@ -975,7 +1021,9 @@ def is_consistent(tensor): "rank {:3d} = {:.8f}".format(r, n) for r, n in enumerate(self._grad_norm_buf.tolist()) ) - error_detail = "grad_norm across the workers:\n{}\n".format(pretty_detail) + error_detail = "grad_norm across the workers:\n{}\n".format( + pretty_detail + ) raise RuntimeError( "Fatal error: gradients are inconsistent between workers. " "Try --ddp-backend=no_c10d. 
" @@ -988,7 +1036,7 @@ def is_consistent(tensor): def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): if grad_norm is not None: - metrics.log_speed("ups", 1., priority=100, round=2) + metrics.log_speed("ups", 1.0, priority=100, round=2) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) if self.args.clip_norm > 0: metrics.log_scalar( @@ -1030,6 +1078,7 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): def _check_xla_compilation(self): import torch_xla.debug.metrics as met + compile_stats = met.metric_data("CompileTime") if compile_stats is None: return @@ -1037,41 +1086,42 @@ def _check_xla_compilation(self): if num_xla_compiles > self._num_xla_compiles: logger.warning( "XLA compilation detected on device #{}; too many of these can lead " - "to slow training, but we expect a few in the beginning" - .format(self.args.distributed_rank) + "to slow training, but we expect a few in the beginning".format( + self.args.distributed_rank + ) ) self._num_xla_compiles = num_xla_compiles -def _catalog_shared_params(module, memo=None, prefix=''): +def _catalog_shared_params(module, memo=None, prefix=""): if memo is None: first_call = True memo = {} else: first_call = False for name, param in module._parameters.items(): - param_prefix = prefix + ('.' if prefix else '') + name + param_prefix = prefix + ("." if prefix else "") + name if param not in memo: memo[param] = [] memo[param].append(param_prefix) for name, m in module._modules.items(): if m is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name + submodule_prefix = prefix + ("." 
if prefix else "") + name _catalog_shared_params(m, memo, submodule_prefix) if first_call: return [x for x in memo.values() if len(x) > 1] def _get_module_by_path(module, path): - path = path.split('.') + path = path.split(".") for name in path: module = getattr(module, name) return module def _set_module_by_path(module, path, value): - path = path.split('.') + path = path.split(".") for name in path[:-1]: module = getattr(module, name) setattr(module, path[-1], value) diff --git a/fairseq/utils.py b/fairseq/utils.py index 1a18bf5e6c..fdbf66cf3f 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -445,7 +445,7 @@ def import_user_module(args): # temporary directory and symlink the user_dir under a new name, which is # a deterministic hash of the original module_path. with tempfile.TemporaryDirectory() as tmpdirname: - unique_mod_name = 'fairseq_user_dir_{}'.format(hash(module_path) % 100000) + unique_mod_name = "fairseq_user_dir_{}".format(hash(module_path) % 100000) os.symlink(module_path, os.path.join(tmpdirname, unique_mod_name)) sys.path.insert(0, tmpdirname) diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 64c83673e6..9a4ff8ee39 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -13,21 +13,19 @@ import os import torch - -from fairseq import checkpoint_utils, options, tasks, utils +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer -from fairseq import distributed_utils logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=os.environ.get('LOGLEVEL', 'INFO').upper(), + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), ) -logger = 
logging.getLogger('fairseq_cli.eval_lm') +logger = logging.getLogger("fairseq_cli.eval_lm") class WordStat(object): @@ -40,10 +38,10 @@ def __init__(self, word, is_bpe): self.missing_next_words = 0 def add(self, log_prob, next_word_prob): - """ increments counters for the sum of log probs of current word and next - word (given context ending at current word). Since the next word might be at the end of the example, - or it might be not counted because it is not an ending subword unit, - also keeps track of how many of those we have seen """ + """increments counters for the sum of log probs of current word and next + word (given context ending at current word). Since the next word might be at the end of the example, + or it might be not counted because it is not an ending subword unit, + also keeps track of how many of those we have seen""" if next_word_prob is not None: self.next_word_prob += next_word_prob else: @@ -52,12 +50,18 @@ def add(self, log_prob, next_word_prob): self.count += 1 def __str__(self): - return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe, - self.next_word_prob, self.count - self.missing_next_words) + return "{}\t{}\t{}\t{}\t{}\t{}".format( + self.word, + self.count, + self.log_prob, + self.is_bpe, + self.next_word_prob, + self.count - self.missing_next_words, + ) def main(parsed_args, **unused_kwargs): - assert parsed_args.path is not None, '--path required for evaluation!' + assert parsed_args.path is not None, "--path required for evaluation!" 
if torch.cuda.is_available() and not parsed_args.cpu: torch.cuda.set_device(parsed_args.device_id) @@ -71,7 +75,7 @@ def main(parsed_args, **unused_kwargs): task = tasks.setup_task(parsed_args) # Load ensemble - logger.info('loading model(s) from {}'.format(parsed_args.path)) + logger.info("loading model(s) from {}".format(parsed_args.path)) models, args = checkpoint_utils.load_model_ensemble( parsed_args.path.split(os.pathsep), arg_overrides=eval(parsed_args.model_overrides), @@ -83,8 +87,12 @@ def main(parsed_args, **unused_kwargs): for arg in vars(parsed_args).keys(): if arg not in { - 'self_target', 'future_target', 'past_target', 'tokens_per_sample', - 'output_size_dictionary', 'add_bos_token', + "self_target", + "future_target", + "past_target", + "tokens_per_sample", + "output_size_dictionary", + "add_bos_token", }: setattr(args, arg, getattr(parsed_args, arg)) @@ -102,7 +110,7 @@ def main(parsed_args, **unused_kwargs): context_window=args.context_window, pad_idx=task.source_dictionary.pad(), ) - logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset))) + logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset))) # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) for model in models: @@ -114,15 +122,17 @@ def main(parsed_args, **unused_kwargs): assert len(models) > 0 - logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters()))) + logger.info( + "num. 
model params: {}".format(sum(p.numel() for p in models[0].parameters())) + ) itr = task.get_batch_iterator( dataset=dataset, max_tokens=args.max_tokens or 36000, max_sentences=args.batch_size, - max_positions=utils.resolve_max_positions(*[ - model.max_positions() for model in models - ]), + max_positions=utils.resolve_max_positions( + *[model.max_positions() for model in models] + ), ignore_invalid_inputs=True, num_shards=args.num_shards, shard_id=args.shard_id, @@ -133,17 +143,17 @@ def main(parsed_args, **unused_kwargs): itr, log_format=args.log_format, log_interval=args.log_interval, - default_log_format=('tqdm' if not args.no_progress_bar else 'none'), + default_log_format=("tqdm" if not args.no_progress_bar else "none"), ) gen_timer = StopwatchMeter() scorer = SequenceScorer(task.target_dictionary, args.softmax_batch) - score_sum = 0. + score_sum = 0.0 count = 0 if args.remove_bpe is not None: - if args.remove_bpe == 'sentencepiece': + if args.remove_bpe == "sentencepiece": raise NotImplementedError else: bpe_cont = args.remove_bpe.rstrip() @@ -162,25 +172,25 @@ def main(parsed_args, **unused_kwargs): wps_meter = TimeMeter() for sample in progress: - if 'net_input' not in sample: + if "net_input" not in sample: continue sample = utils.move_to_cuda(sample) if use_cuda else sample gen_timer.start() hypos = scorer.generate(models, sample) - gen_timer.stop(sample['ntokens']) + gen_timer.stop(sample["ntokens"]) for i, hypos_i in enumerate(hypos): hypo = hypos_i[0] - sample_id = sample['id'][i] + sample_id = sample["id"][i] - tokens = hypo['tokens'] + tokens = hypo["tokens"] tgt_len = tokens.numel() - pos_scores = hypo['positional_scores'].float() + pos_scores = hypo["positional_scores"].float() - if getattr(args, 'add_bos_token', False): - assert hypo['tokens'][0].item() == task.target_dictionary.bos() + if getattr(args, "add_bos_token", False): + assert hypo["tokens"][0].item() == task.target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] @@ 
-192,18 +202,18 @@ def main(parsed_args, **unused_kwargs): pos_scores[i + 1] += pos_scores[i] pos_scores[i] = 0 - inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf')) + inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf")) if inf_scores.any(): logger.info( - 'skipping tokens with inf scores:', - task.target_dictionary.string(tokens[inf_scores.nonzero()]) + "skipping tokens with inf scores:", + task.target_dictionary.string(tokens[inf_scores.nonzero()]), ) pos_scores = pos_scores[(~inf_scores).nonzero()] score_sum += pos_scores.sum().cpu() count += pos_scores.numel() - skipped_toks if args.output_word_probs or args.output_word_stats: - w = '' + w = "" word_prob = [] is_bpe = False for i in range(len(tokens)): @@ -223,25 +233,36 @@ def main(parsed_args, **unused_kwargs): break ind += 1 - word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob) + word_stats.setdefault(w, WordStat(w, is_bpe)).add( + pos_scores[i].item(), next_prob + ) is_bpe = False - w = '' + w = "" if args.output_word_probs: logger.info( - str(int(sample_id)) + " " - + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)) + str(int(sample_id)) + + " " + + ( + "\t".join( + "{} [{:2f}]".format(x[0], x[1]) for x in word_prob + ) + ) ) - wps_meter.update(sample['ntokens']) - progress.log({'wps': round(wps_meter.avg)}) + wps_meter.update(sample["ntokens"]) + progress.log({"wps": round(wps_meter.avg)}) avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2 - logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format( - gen_timer.n, gen_timer.sum, 1. 
/ gen_timer.avg - )) - logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format( - avg_nll_loss, 2**avg_nll_loss - )) + logger.info( + "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format( + gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg + ) + ) + logger.info( + "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( + avg_nll_loss, 2 ** avg_nll_loss + ) + ) if args.output_word_stats: for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): @@ -254,5 +275,5 @@ def cli_main(): distributed_utils.call_main(args, main) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 0064b88a95..8ddf981cc3 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -8,39 +8,41 @@ """ import ast -from itertools import chain import logging import math import os import sys +from itertools import chain import numpy as np - import torch - from fairseq import checkpoint_utils, options, scoring, tasks, utils from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter def main(args): - assert args.path is not None, '--path required for generation!' - assert not args.sampling or args.nbest == args.beam, \ - '--sampling requires --nbest to be equal to --beam' - assert args.replace_unk is None or args.dataset_impl == 'raw', \ - '--replace-unk requires a raw text dataset (--dataset-impl=raw)' + assert args.path is not None, "--path required for generation!" 
+ assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + args.replace_unk is None or args.dataset_impl == "raw" + ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" if args.results_path is not None: os.makedirs(args.results_path, exist_ok=True) - output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset)) - with open(output_path, 'w', buffering=1, encoding='utf-8') as h: + output_path = os.path.join( + args.results_path, "generate-{}.txt".format(args.gen_subset) + ) + with open(output_path, "w", buffering=1, encoding="utf-8") as h: return _main(args, h) else: return _main(args, sys.stdout) def get_symbols_to_strip_from_output(generator): - if hasattr(generator, 'symbols_to_strip_from_output'): + if hasattr(generator, "symbols_to_strip_from_output"): return generator.symbols_to_strip_from_output else: return {generator.eos} @@ -48,12 +50,12 @@ def get_symbols_to_strip_from_output(generator): def _main(args, output_file): logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=os.environ.get('LOGLEVEL', 'INFO').upper(), + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=output_file, ) - logger = logging.getLogger('fairseq_cli.generate') + logger = logging.getLogger("fairseq_cli.generate") utils.import_user_module(args) @@ -74,7 +76,7 @@ def _main(args, output_file): # Set dictionaries try: - src_dict = getattr(task, 'source_dictionary', None) + src_dict = getattr(task, "source_dictionary", None) except NotImplementedError: src_dict = None tgt_dict = task.target_dictionary @@ -82,7 +84,7 @@ def _main(args, output_file): overrides = ast.literal_eval(args.model_overrides) # Load ensemble - logger.info('loading model(s) from {}'.format(args.path)) + logger.info("loading model(s) from 
{}".format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(args.path), arg_overrides=overrides, @@ -93,7 +95,7 @@ def _main(args, output_file): ) if args.lm_path is not None: - overrides['data'] = args.data + overrides["data"] = args.data try: lms, _ = checkpoint_utils.load_model_ensemble( @@ -102,8 +104,10 @@ def _main(args, output_file): task=None, ) except: - logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same " - f"as target dict and is located in the data dir ({args.data})") + logger.warning( + f"Failed to load language model! Please make sure that the language model dict is the same " + f"as target dict and is located in the data dir ({args.data})" + ) raise assert len(lms) == 1 @@ -130,8 +134,7 @@ def _main(args, output_file): max_tokens=args.max_tokens, max_sentences=args.batch_size, max_positions=utils.resolve_max_positions( - task.max_positions(), - *[model.max_positions() for model in models] + task.max_positions(), *[model.max_positions() for model in models] ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, @@ -144,17 +147,16 @@ def _main(args, output_file): itr, log_format=args.log_format, log_interval=args.log_interval, - default_log_format=('tqdm' if not args.no_progress_bar else 'none'), + default_log_format=("tqdm" if not args.no_progress_bar else "none"), ) # Initialize generator gen_timer = StopwatchMeter() - extra_gen_cls_kwargs = { - 'lm_model': lms[0], - 'lm_weight': args.lm_weight - } - generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs) + extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight} + generator = task.build_generator( + models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) # Handle tokenization and BPE tokenizer = task.build_tokenizer(args) @@ -174,39 +176,51 @@ def decode_fn(x): wps_meter = 
TimeMeter() for sample in progress: sample = utils.move_to_cuda(sample) if use_cuda else sample - if 'net_input' not in sample: + if "net_input" not in sample: continue prefix_tokens = None if args.prefix_size > 0: - prefix_tokens = sample['target'][:, :args.prefix_size] + prefix_tokens = sample["target"][:, : args.prefix_size] constraints = None if "constraints" in sample: constraints = sample["constraints"] gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints) - num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) + hypos = task.inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) gen_timer.stop(num_generated_tokens) - for i, sample_id in enumerate(sample['id'].tolist()): - has_target = sample['target'] is not None + for i, sample_id in enumerate(sample["id"].tolist()): + has_target = sample["target"] is not None # Remove padding - if 'src_tokens' in sample['net_input']: - src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) + if "src_tokens" in sample["net_input"]: + src_tokens = utils.strip_pad( + sample["net_input"]["src_tokens"][i, :], tgt_dict.pad() + ) else: src_tokens = None target_tokens = None if has_target: - target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu() + target_tokens = ( + utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu() + ) # Either retrieve the original sentences or regenerate them from tokens. 
if align_dict is not None: src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) - target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id) + target_str = task.dataset(args.gen_subset).tgt.get_original_text( + sample_id + ) else: if src_dict is not None: src_str = src_dict.string(src_tokens, args.remove_bpe) @@ -217,7 +231,9 @@ def decode_fn(x): target_tokens, args.remove_bpe, escape_unk=True, - extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), + extra_symbols_to_ignore=get_symbols_to_strip_from_output( + generator + ), ) src_str = decode_fn(src_str) @@ -226,16 +242,16 @@ def decode_fn(x): if not args.quiet: if src_dict is not None: - print('S-{}\t{}'.format(sample_id, src_str), file=output_file) + print("S-{}\t{}".format(sample_id, src_str), file=output_file) if has_target: - print('T-{}\t{}'.format(sample_id, target_str), file=output_file) + print("T-{}\t{}".format(sample_id, target_str), file=output_file) # Process top predictions - for j, hypo in enumerate(hypos[i][:args.nbest]): + for j, hypo in enumerate(hypos[i][: args.nbest]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( - hypo_tokens=hypo['tokens'].int().cpu(), + hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, - alignment=hypo['alignment'], + alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, @@ -243,71 +259,116 @@ def decode_fn(x): ) detok_hypo_str = decode_fn(hypo_str) if not args.quiet: - score = hypo['score'] / math.log(2) # convert to base 2 + score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) - print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file) + print( + "H-{}\t{}\t{}".format(sample_id, score, hypo_str), + file=output_file, + ) # detokenized hypothesis - print('D-{}\t{}\t{}'.format(sample_id, score, detok_hypo_str), file=output_file) - print('P-{}\t{}'.format( - sample_id, - ' '.join(map( 
- lambda x: '{:.4f}'.format(x), - # convert from base e to base 2 - hypo['positional_scores'].div_(math.log(2)).tolist(), - )) - ), file=output_file) + print( + "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str), + file=output_file, + ) + print( + "P-{}\t{}".format( + sample_id, + " ".join( + map( + lambda x: "{:.4f}".format(x), + # convert from base e to base 2 + hypo["positional_scores"] + .div_(math.log(2)) + .tolist(), + ) + ), + ), + file=output_file, + ) if args.print_alignment: - print('A-{}\t{}'.format( - sample_id, - ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment]) - ), file=output_file) + print( + "A-{}\t{}".format( + sample_id, + " ".join( + [ + "{}-{}".format(src_idx, tgt_idx) + for src_idx, tgt_idx in alignment + ] + ), + ), + file=output_file, + ) if args.print_step: - print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file) + print( + "I-{}\t{}".format(sample_id, hypo["steps"]), + file=output_file, + ) - if getattr(args, 'retain_iter_history', False): - for step, h in enumerate(hypo['history']): + if getattr(args, "retain_iter_history", False): + for step, h in enumerate(hypo["history"]): _, h_str, _ = utils.post_process_prediction( - hypo_tokens=h['tokens'].int().cpu(), + hypo_tokens=h["tokens"].int().cpu(), src_str=src_str, alignment=None, align_dict=None, tgt_dict=tgt_dict, remove_bpe=None, ) - print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file) + print( + "E-{}_{}\t{}".format(sample_id, step, h_str), + file=output_file, + ) # Score only the top hypothesis if has_target and j == 0: if align_dict is not None or args.remove_bpe is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE - target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True) - hypo_tokens = tgt_dict.encode_line(detok_hypo_str, add_if_not_exist=True) - if hasattr(scorer, 'add_string'): + target_tokens = tgt_dict.encode_line( + target_str, add_if_not_exist=True + ) + 
hypo_tokens = tgt_dict.encode_line( + detok_hypo_str, add_if_not_exist=True + ) + if hasattr(scorer, "add_string"): scorer.add_string(target_str, detok_hypo_str) else: scorer.add(target_tokens, hypo_tokens) wps_meter.update(num_generated_tokens) - progress.log({'wps': round(wps_meter.avg)}) - num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel() - - logger.info('NOTE: hypothesis and token scores are output in base 2') - logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( - num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) + progress.log({"wps": round(wps_meter.avg)}) + num_sentences += ( + sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + ) + + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info( + "Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( + num_sentences, + gen_timer.n, + gen_timer.sum, + num_sentences / gen_timer.sum, + 1.0 / gen_timer.avg, + ) + ) if has_target: if args.bpe and not args.sacrebleu: if args.remove_bpe: logger.warning( - "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization") + "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" + ) else: logger.warning( - "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization") + "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. 
Use --sacrebleu for standard 13a BLEU tokenization" + ) # use print to be consistent with other main outputs: S-, H-, T-, D- and so on print( - 'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()), - file=output_file) + "Generate {} with beam={}: {}".format( + args.gen_subset, args.beam, scorer.result_string() + ), + file=output_file, + ) return scorer @@ -318,5 +379,5 @@ def cli_main(): main(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index fc4b46e39d..de3893a385 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -7,34 +7,33 @@ Translate raw text with a trained model. Batches data on-the-fly. """ -from collections import namedtuple import fileinput import logging import math +import os import sys import time -import os +from collections import namedtuple import numpy as np - import torch - from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output + logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=os.environ.get('LOGLEVEL', 'INFO').upper(), + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) -logger = logging.getLogger('fairseq_cli.interactive') +logger = logging.getLogger("fairseq_cli.interactive") -Batch = namedtuple('Batch', 'ids src_tokens src_lengths constraints') -Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') +Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints") +Translation = namedtuple("Translation", "src_str hypos pos_scores alignments") def 
buffered_read(input, buffer_size): @@ -64,11 +63,14 @@ def encode_fn_target(x): # Convert each List[str] to List[Tensor] for i, constraint_list in enumerate(batch_constraints): - batch_constraints[i] = [task.target_dictionary.encode_line( - encode_fn_target(constraint), - append_eos=False, - add_if_not_exist=False, - ) for constraint in constraint_list] + batch_constraints[i] = [ + task.target_dictionary.encode_line( + encode_fn_target(constraint), + append_eos=False, + add_if_not_exist=False, + ) + for constraint in constraint_list + ] tokens = [ task.source_dictionary.encode_line( @@ -84,16 +86,18 @@ def encode_fn_target(x): lengths = [t.numel() for t in tokens] itr = task.get_batch_iterator( - dataset=task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor), + dataset=task.build_dataset_for_inference( + tokens, lengths, constraints=constraints_tensor + ), max_tokens=args.max_tokens, max_sentences=args.batch_size, max_positions=max_positions, - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, ).next_epoch_itr(shuffle=False) for batch in itr: - ids = batch['id'] - src_tokens = batch['net_input']['src_tokens'] - src_lengths = batch['net_input']['src_lengths'] + ids = batch["id"] + src_tokens = batch["net_input"]["src_tokens"] + src_lengths = batch["net_input"]["src_lengths"] constraints = batch.get("constraints", None) yield Batch( @@ -115,10 +119,12 @@ def main(args): if args.max_tokens is None and args.batch_size is None: args.batch_size = 1 - assert not args.sampling or args.nbest == args.beam, \ - '--sampling requires --nbest to be equal to --beam' - assert not args.batch_size or args.batch_size <= args.buffer_size, \ - '--batch-size cannot be larger than --buffer-size' + assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + not args.batch_size or args.batch_size <= args.buffer_size + ), 
"--batch-size cannot be larger than --buffer-size" logger.info(args) @@ -133,7 +139,7 @@ def main(args): task = tasks.setup_task(args) # Load ensemble - logger.info('loading model(s) from {}'.format(args.path)) + logger.info("loading model(s) from {}".format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( args.path.split(os.pathsep), arg_overrides=eval(args.model_overrides), @@ -181,17 +187,18 @@ def decode_fn(x): align_dict = utils.load_align_dict(args.replace_unk) max_positions = utils.resolve_max_positions( - task.max_positions(), - *[model.max_positions() for model in models] + task.max_positions(), *[model.max_positions() for model in models] ) if args.constraints: - logger.warning("NOTE: Constrained decoding currently assumes a shared subword vocabulary.") + logger.warning( + "NOTE: Constrained decoding currently assumes a shared subword vocabulary." + ) if args.buffer_size > 1: - logger.info('Sentence buffer size: %s', args.buffer_size) - logger.info('NOTE: hypothesis and token scores are output in base 2') - logger.info('Type the input sentence and press return:') + logger.info("Sentence buffer size: %s", args.buffer_size) + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info("Type the input sentence and press return:") start_id = 0 for inputs in buffered_read(args.input, args.buffer_size): results = [] @@ -207,13 +214,15 @@ def decode_fn(x): constraints = constraints.cuda() sample = { - 'net_input': { - 'src_tokens': src_tokens, - 'src_lengths': src_lengths, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, }, } translate_start_time = time.time() - translations = task.inference_step(generator, models, sample, constraints=constraints) + translations = task.inference_step( + generator, models, sample, constraints=constraints + ) translate_time = time.time() - translate_start_time total_translate_time += translate_time list_constraints = [[] for _ in range(bsz)] @@ -222,56 
+231,75 @@ def decode_fn(x): for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) constraints = list_constraints[i] - results.append((start_id + id, src_tokens_i, hypos, - { "constraints": constraints, - "time": translate_time / len(translations) } - )) + results.append( + ( + start_id + id, + src_tokens_i, + hypos, + { + "constraints": constraints, + "time": translate_time / len(translations), + }, + ) + ) # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): if src_dict is not None: src_str = src_dict.string(src_tokens, args.remove_bpe) - print('S-{}\t{}'.format(id_, src_str)) + print("S-{}\t{}".format(id_, src_str)) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) for constraint in info["constraints"]: - print("C-{}\t{}".format(id_, tgt_dict.string(constraint, args.remove_bpe))) + print( + "C-{}\t{}".format( + id_, tgt_dict.string(constraint, args.remove_bpe) + ) + ) # Process top predictions - for hypo in hypos[:min(len(hypos), args.nbest)]: + for hypo in hypos[: min(len(hypos), args.nbest)]: hypo_tokens, hypo_str, alignment = utils.post_process_prediction( - hypo_tokens=hypo['tokens'].int().cpu(), + hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, - alignment=hypo['alignment'], + alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) - score = hypo['score'] / math.log(2) # convert to base 2 + score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) - print('H-{}\t{}\t{}'.format(id_, score, hypo_str)) + print("H-{}\t{}\t{}".format(id_, score, hypo_str)) # detokenized hypothesis - print('D-{}\t{}\t{}'.format(id_, score, detok_hypo_str)) - print('P-{}\t{}'.format( - id_, - ' '.join(map( - lambda x: 
'{:.4f}'.format(x), - # convert from base e to base 2 - hypo['positional_scores'].div_(math.log(2)).tolist(), - )) - )) - if args.print_alignment: - alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment]) - print('A-{}\t{}'.format( + print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str)) + print( + "P-{}\t{}".format( id_, - alignment_str - )) + " ".join( + map( + lambda x: "{:.4f}".format(x), + # convert from base e to base 2 + hypo["positional_scores"].div_(math.log(2)).tolist(), + ) + ), + ) + ) + if args.print_alignment: + alignment_str = " ".join( + ["{}-{}".format(src, tgt) for src, tgt in alignment] + ) + print("A-{}\t{}".format(id_, alignment_str)) # update running id_ counter start_id += len(inputs) - logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(time.time() - start_time, total_translate_time)) + logger.info( + "Total time: {:.3f} seconds; translation time: {:.3f}".format( + time.time() - start_time, total_translate_time + ) + ) + def cli_main(): parser = options.get_interactive_generation_parser() @@ -279,5 +307,5 @@ def cli_main(): distributed_utils.call_main(args, main) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py index 3fe5131324..fa77da8dba 100644 --- a/fairseq_cli/preprocess.py +++ b/fairseq_cli/preprocess.py @@ -7,26 +7,26 @@ Data pre-processing: build vocabularies and binarize training data. 
""" -from collections import Counter -from itertools import zip_longest import logging -from multiprocessing import Pool import os import shutil import sys +from collections import Counter +from itertools import zip_longest +from multiprocessing import Pool from fairseq import options, tasks, utils -from fairseq.data import indexed_dataset from fairseq.binarizer import Binarizer +from fairseq.data import indexed_dataset logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=os.environ.get('LOGLEVEL', 'INFO').upper(), + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) -logger = logging.getLogger('fairseq_cli.preprocess') +logger = logging.getLogger("fairseq_cli.preprocess") def main(args): @@ -34,9 +34,11 @@ def main(args): os.makedirs(args.destdir, exist_ok=True) - logger.addHandler(logging.FileHandler( - filename=os.path.join(args.destdir, 'preprocess.log'), - )) + logger.addHandler( + logging.FileHandler( + filename=os.path.join(args.destdir, "preprocess.log"), + ) + ) logger.info(args) task = tasks.get_task(args.task) @@ -74,31 +76,39 @@ def build_dictionary(filenames, src=False, tgt=False): raise FileExistsError(dict_path(args.target_lang)) if args.joined_dictionary: - assert not args.srcdict or not args.tgtdict, \ - "cannot use both --srcdict and --tgtdict with --joined-dictionary" + assert ( + not args.srcdict or not args.tgtdict + ), "cannot use both --srcdict and --tgtdict with --joined-dictionary" if args.srcdict: src_dict = task.load_dictionary(args.srcdict) elif args.tgtdict: src_dict = task.load_dictionary(args.tgtdict) else: - assert args.trainpref, "--trainpref must be set if --srcdict is not specified" + assert ( + args.trainpref + ), "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary( - {train_path(lang) for lang in 
[args.source_lang, args.target_lang]}, src=True + {train_path(lang) for lang in [args.source_lang, args.target_lang]}, + src=True, ) tgt_dict = src_dict else: if args.srcdict: src_dict = task.load_dictionary(args.srcdict) else: - assert args.trainpref, "--trainpref must be set if --srcdict is not specified" + assert ( + args.trainpref + ), "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary([train_path(args.source_lang)], src=True) if target: if args.tgtdict: tgt_dict = task.load_dictionary(args.tgtdict) else: - assert args.trainpref, "--trainpref must be set if --tgtdict is not specified" + assert ( + args.trainpref + ), "--trainpref must be set if --tgtdict is not specified" tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True) else: tgt_dict = None @@ -135,18 +145,20 @@ def merge_result(worker_result): prefix, lang, offsets[worker_id], - offsets[worker_id + 1] + offsets[worker_id + 1], ), - callback=merge_result + callback=merge_result, ) pool.close() - ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), - impl=args.dataset_impl, vocab_size=len(vocab)) + ds = indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, lang, "bin"), + impl=args.dataset_impl, + vocab_size=len(vocab), + ) merge_result( Binarizer.binarize( - input_file, vocab, lambda t: ds.add_item(t), - offset=0, end=offsets[1] + input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1] ) ) if num_workers > 1: @@ -175,7 +187,7 @@ def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers): nseq = [0] def merge_result(worker_result): - nseq[0] += worker_result['nseq'] + nseq[0] += worker_result["nseq"] input_file = input_prefix offsets = Binarizer.find_offsets(input_file, num_workers) @@ -192,19 +204,23 @@ def merge_result(worker_result): utils.parse_alignment, prefix, offsets[worker_id], - offsets[worker_id + 1] + offsets[worker_id + 1], ), - callback=merge_result + 
callback=merge_result, ) pool.close() - ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"), - impl=args.dataset_impl) + ds = indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl + ) merge_result( Binarizer.binarize_alignments( - input_file, utils.parse_alignment, lambda t: ds.add_item(t), - offset=0, end=offsets[1] + input_file, + utils.parse_alignment, + lambda t: ds.add_item(t), + offset=0, + end=offsets[1], ) ) if num_workers > 1: @@ -218,12 +234,7 @@ def merge_result(worker_result): ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) - logger.info( - "[alignments] {}: parsed {} alignments".format( - input_file, - nseq[0] - ) - ) + logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0])) def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1): if args.dataset_impl == "raw": @@ -242,7 +253,9 @@ def make_all(lang, vocab): if args.validpref: for k, validpref in enumerate(args.validpref.split(",")): outprefix = "valid{}".format(k) if k > 0 else "valid" - make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers) + make_dataset( + vocab, validpref, outprefix, lang, num_workers=args.workers + ) if args.testpref: for k, testpref in enumerate(args.testpref.split(",")): outprefix = "test{}".format(k) if k > 0 else "test" @@ -250,11 +263,23 @@ def make_all(lang, vocab): def make_all_alignments(): if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix): - make_binary_alignment_dataset(args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers) + make_binary_alignment_dataset( + args.trainpref + "." + args.align_suffix, + "train.align", + num_workers=args.workers, + ) if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix): - make_binary_alignment_dataset(args.validpref + "." 
+ args.align_suffix, "valid.align", num_workers=args.workers) + make_binary_alignment_dataset( + args.validpref + "." + args.align_suffix, + "valid.align", + num_workers=args.workers, + ) if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix): - make_binary_alignment_dataset(args.testpref + "." + args.align_suffix, "test.align", num_workers=args.workers) + make_binary_alignment_dataset( + args.testpref + "." + args.align_suffix, + "test.align", + num_workers=args.workers, + ) make_all(args.source_lang, src_dict) if target: @@ -269,9 +294,9 @@ def make_all_alignments(): src_file_name = train_path(args.source_lang) tgt_file_name = train_path(args.target_lang) freq_map = {} - with open(args.alignfile, "r", encoding='utf-8') as align_file: - with open(src_file_name, "r", encoding='utf-8') as src_file: - with open(tgt_file_name, "r", encoding='utf-8') as tgt_file: + with open(args.alignfile, "r", encoding="utf-8") as align_file: + with open(src_file_name, "r", encoding="utf-8") as src_file: + with open(tgt_file_name, "r", encoding="utf-8") as tgt_file: for a, s, t in zip_longest(align_file, src_file, tgt_file): si = src_dict.encode_line(s, add_if_not_exist=False) ti = tgt_dict.encode_line(t, add_if_not_exist=False) @@ -297,38 +322,47 @@ def make_all_alignments(): align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get) with open( - os.path.join( - args.destdir, - "alignment.{}-{}.txt".format(args.source_lang, args.target_lang), - ), - "w", encoding='utf-8' + os.path.join( + args.destdir, + "alignment.{}-{}.txt".format(args.source_lang, args.target_lang), + ), + "w", + encoding="utf-8", ) as f: for k, v in align_dict.items(): print("{} {}".format(src_dict[k], tgt_dict[v]), file=f) def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True): - ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), - impl=args.dataset_impl, vocab_size=len(vocab)) + ds = 
indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, lang, "bin"), + impl=args.dataset_impl, + vocab_size=len(vocab), + ) def consumer(tensor): ds.add_item(tensor) - res = Binarizer.binarize(filename, vocab, consumer, append_eos=append_eos, - offset=offset, end=end) + res = Binarizer.binarize( + filename, vocab, consumer, append_eos=append_eos, offset=offset, end=end + ) ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) return res def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end): - ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"), - impl=args.dataset_impl, vocab_size=None) + ds = indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, None, "bin"), + impl=args.dataset_impl, + vocab_size=None, + ) def consumer(tensor): ds.add_item(tensor) - res = Binarizer.binarize_alignments(filename, parse_alignment, consumer, offset=offset, - end=end) + res = Binarizer.binarize_alignments( + filename, parse_alignment, consumer, offset=offset, end=end + ) ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) return res diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py index 59631c2d65..b8354eb95a 100644 --- a/fairseq_cli/score.py +++ b/fairseq_cli/score.py @@ -11,12 +11,14 @@ import os import sys -from fairseq.scoring import bleu from fairseq.data import dictionary +from fairseq.scoring import bleu def get_parser(): - parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') + parser = argparse.ArgumentParser( + description="Command-line script for BLEU scoring." 
+ ) # fmt: off parser.add_argument('-s', '--sys', default='-', help='system output') parser.add_argument('-r', '--ref', required=True, help='references') @@ -37,10 +39,10 @@ def cli_main(): args = parser.parse_args() print(args) - assert args.sys == '-' or os.path.exists(args.sys), \ - "System output file {} does not exist".format(args.sys) - assert os.path.exists(args.ref), \ - "Reference file {} does not exist".format(args.ref) + assert args.sys == "-" or os.path.exists( + args.sys + ), "System output file {} does not exist".format(args.sys) + assert os.path.exists(args.ref), "Reference file {} does not exist".format(args.ref) dict = dictionary.Dictionary() @@ -57,17 +59,23 @@ def readlines(fd): def score(fdsys): with open(args.ref) as fdref: print(sacrebleu.corpus_bleu(fdsys, [fdref])) + elif args.sentence_bleu: + def score(fdsys): with open(args.ref) as fdref: scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) - for i, (sys_tok, ref_tok) in enumerate(zip(readlines(fdsys), readlines(fdref))): + for i, (sys_tok, ref_tok) in enumerate( + zip(readlines(fdsys), readlines(fdref)) + ): scorer.reset(one_init=True) sys_tok = dict.encode_line(sys_tok) ref_tok = dict.encode_line(ref_tok) scorer.add(ref_tok, sys_tok) print(i, scorer.result_string(args.order)) + else: + def score(fdsys): with open(args.ref) as fdref: scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) @@ -77,12 +85,12 @@ def score(fdsys): scorer.add(ref_tok, sys_tok) print(scorer.result_string(args.order)) - if args.sys == '-': + if args.sys == "-": score(sys.stdin) else: - with open(args.sys, 'r') as f: + with open(args.sys, "r") as f: score(f) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 717a776c8f..df857550d1 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -5,31 +5,31 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this 
source tree. -from itertools import chain import logging import os import sys +from itertools import chain import torch - from fairseq import checkpoint_utils, distributed_utils, options, utils from fairseq.logging import metrics, progress_bar logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=os.environ.get('LOGLEVEL', 'INFO').upper(), + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) -logger = logging.getLogger('fairseq_cli.validate') +logger = logging.getLogger("fairseq_cli.validate") def main(args, override_args=None): utils.import_user_module(args) - assert args.max_tokens is not None or args.batch_size is not None, \ - 'Must specify batch size either with --max-tokens or --batch-size' + assert ( + args.max_tokens is not None or args.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" use_fp16 = args.fp16 use_cuda = torch.cuda.is_available() and not args.cpu @@ -39,12 +39,12 @@ def main(args, override_args=None): if override_args is not None: overrides = vars(override_args) - overrides.update(eval(getattr(override_args, 'model_overrides', '{}'))) + overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) else: overrides = None # Load ensemble - logger.info('loading model(s) from {}'.format(args.path)) + logger.info("loading model(s) from {}".format(args.path)) models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( [args.path], arg_overrides=overrides, @@ -66,12 +66,12 @@ def main(args, override_args=None): criterion = task.build_criterion(model_args) criterion.eval() - for subset in args.valid_subset.split(','): + for subset in args.valid_subset.split(","): try: task.load_dataset(subset, combine=False, epoch=1) dataset = task.dataset(subset) except KeyError: - raise Exception('Cannot find 
dataset: ' + subset) + raise Exception("Cannot find dataset: " + subset) # Initialize data iterator itr = task.get_batch_iterator( @@ -95,7 +95,7 @@ def main(args, override_args=None): log_format=args.log_format, log_interval=args.log_interval, prefix=f"valid on '{subset}' subset", - default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) log_outputs = [] @@ -108,7 +108,7 @@ def main(args, override_args=None): if args.distributed_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, - max_size=getattr(args, 'all_gather_list_size', 16384), + max_size=getattr(args, "all_gather_list_size", 16384), ) log_outputs = list(chain.from_iterable(log_outputs)) @@ -130,5 +130,5 @@ def cli_main(): distributed_utils.call_main(args, main, override_args=override_args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/hubconf.py b/hubconf.py index c63fa8ae89..ce7d76cfe1 100644 --- a/hubconf.py +++ b/hubconf.py @@ -6,14 +6,20 @@ import functools import importlib +from fairseq.hub_utils import ( # noqa; noqa + BPEHubInterface as bpe, + TokenizerHubInterface as tokenizer, +) +from fairseq.models import MODEL_REGISTRY # noqa + dependencies = [ - 'dataclasses', - 'hydra', - 'numpy', - 'regex', - 'requests', - 'torch', + "dataclasses", + "hydra", + "numpy", + "regex", + "requests", + "torch", ] @@ -26,11 +32,11 @@ # Hack: the hydra package is provided under the "hydra-core" name in # pypi. We don't want the user mistakenly calling `pip install hydra` # since that will install an unrelated package. 
- if dep == 'hydra': - dep = 'hydra-core' + if dep == "hydra": + dep = "hydra-core" missing_deps.append(dep) if len(missing_deps) > 0: - raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps))) + raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) # torch.hub doesn't build Cython components, so if they are not found then try @@ -42,22 +48,18 @@ import cython # noqa import os from setuptools import sandbox + sandbox.run_setup( - os.path.join(os.path.dirname(__file__), 'setup.py'), - ['build_ext', '--inplace'], + os.path.join(os.path.dirname(__file__), "setup.py"), + ["build_ext", "--inplace"], ) except ImportError: print( - 'Unable to build Cython components. Please make sure Cython is ' - 'installed if the torch.hub model you are loading depends on it.' + "Unable to build Cython components. Please make sure Cython is " + "installed if the torch.hub model you are loading depends on it." ) -from fairseq.hub_utils import BPEHubInterface as bpe # noqa -from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa -from fairseq.models import MODEL_REGISTRY # noqa - - # automatically expose models defined in FairseqModel::hub_models for _model_type, _cls in MODEL_REGISTRY.items(): for model_name in _cls.hub_models().keys(): diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py index 9d69671e7e..c512f802bc 100644 --- a/scripts/average_checkpoints.py +++ b/scripts/average_checkpoints.py @@ -6,10 +6,10 @@ import argparse import collections -import torch import os import re +import torch from fairseq.file_io import PathManager @@ -30,26 +30,26 @@ def average_checkpoints(inputs): num_models = len(inputs) for fpath in inputs: - with PathManager.open(fpath, 'rb') as f: + with PathManager.open(fpath, "rb") as f: state = torch.load( f, map_location=( - lambda s, _: torch.serialization.default_restore_location(s, 'cpu') + lambda s, _: torch.serialization.default_restore_location(s, "cpu") ), ) # 
Copies over the settings from the first checkpoint if new_state is None: new_state = state - model_params = state['model'] + model_params = state["model"] model_params_keys = list(model_params.keys()) if params_keys is None: params_keys = model_params_keys elif params_keys != model_params_keys: raise KeyError( - 'For checkpoint {}, expected list of params: {}, ' - 'but found: {}'.format(f, params_keys, model_params_keys) + "For checkpoint {}, expected list of params: {}, " + "but found: {}".format(f, params_keys, model_params_keys) ) for k in params_keys: @@ -69,7 +69,7 @@ def average_checkpoints(inputs): averaged_params[k].div_(num_models) else: averaged_params[k] //= num_models - new_state['model'] = averaged_params + new_state["model"] = averaged_params return new_state @@ -77,9 +77,9 @@ def last_n_checkpoints(paths, n, update_based, upper_bound=None): assert len(paths) == 1 path = paths[0] if update_based: - pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt') + pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt") else: - pt_regexp = re.compile(r'checkpoint(\d+)\.pt') + pt_regexp = re.compile(r"checkpoint(\d+)\.pt") files = PathManager.ls(path) entries = [] @@ -90,14 +90,16 @@ def last_n_checkpoints(paths, n, update_based, upper_bound=None): if upper_bound is None or sort_key <= upper_bound: entries.append((sort_key, m.group(0))) if len(entries) < n: - raise Exception('Found {} checkpoint files but need at least {}', len(entries), n) + raise Exception( + "Found {} checkpoint files but need at least {}", len(entries), n + ) return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] def main(): parser = argparse.ArgumentParser( - description='Tool to average the params of input checkpoints to ' - 'produce a new checkpoint', + description="Tool to average the params of input checkpoints to " + "produce a new checkpoint", ) # fmt: off parser.add_argument('--inputs', required=True, nargs='+', @@ -129,22 +131,28 @@ def main(): elif 
args.num_epoch_checkpoints is not None: num = args.num_epoch_checkpoints - assert args.checkpoint_upper_bound is None or (args.num_epoch_checkpoints is not None or args.num_update_checkpoints is not None), \ - '--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints' - assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \ - 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints' + assert args.checkpoint_upper_bound is None or ( + args.num_epoch_checkpoints is not None + or args.num_update_checkpoints is not None + ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints" + assert ( + args.num_epoch_checkpoints is None or args.num_update_checkpoints is None + ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints" if num is not None: args.inputs = last_n_checkpoints( - args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound, + args.inputs, + num, + is_update_based, + upper_bound=args.checkpoint_upper_bound, ) - print('averaging checkpoints: ', args.inputs) + print("averaging checkpoints: ", args.inputs) new_state = average_checkpoints(args.inputs) - with PathManager.open(args.output, 'wb') as f: + with PathManager.open(args.output, "wb") as f: torch.save(new_state, f) - print('Finished writing averaged checkpoint to {}'.format(args.output)) + print("Finished writing averaged checkpoint to {}".format(args.output)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/build_sym_alignment.py b/scripts/build_sym_alignment.py index bb0cac09dd..0ca5c18f7b 100644 --- a/scripts/build_sym_alignment.py +++ b/scripts/build_sym_alignment.py @@ -27,7 +27,7 @@ def main(): - parser = argparse.ArgumentParser(description='symmetric alignment builer') + parser = argparse.ArgumentParser(description="symmetric alignment builer") # fmt: off parser.add_argument('--fast_align_dir', help='path to fast_align build 
directory') @@ -47,40 +47,40 @@ def main(): # fmt: on args = parser.parse_args() - fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') - symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal') + fast_align_bin = os.path.join(args.fast_align_dir, "fast_align") + symal_bin = os.path.join(args.mosesdecoder_dir, "bin", "symal") sym_fast_align_bin = os.path.join( - args.mosesdecoder_dir, 'scripts', 'ems', - 'support', 'symmetrize-fast-align.perl') + args.mosesdecoder_dir, "scripts", "ems", "support", "symmetrize-fast-align.perl" + ) # create joined file - joined_file = os.path.join(args.output_dir, 'text.joined') - with open(args.source_file, 'r', encoding='utf-8') as src, open(args.target_file, 'r', encoding='utf-8') as tgt: - with open(joined_file, 'w', encoding='utf-8') as joined: + joined_file = os.path.join(args.output_dir, "text.joined") + with open(args.source_file, "r", encoding="utf-8") as src, open( + args.target_file, "r", encoding="utf-8" + ) as tgt: + with open(joined_file, "w", encoding="utf-8") as joined: for s, t in zip_longest(src, tgt): - print('{} ||| {}'.format(s.strip(), t.strip()), file=joined) + print("{} ||| {}".format(s.strip(), t.strip()), file=joined) - bwd_align_file = os.path.join(args.output_dir, 'align.backward') + bwd_align_file = os.path.join(args.output_dir, "align.backward") # run forward alignment - fwd_align_file = os.path.join(args.output_dir, 'align.forward') - fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format( - FASTALIGN=fast_align_bin, - JOINED=joined_file, - FWD=fwd_align_file) + fwd_align_file = os.path.join(args.output_dir, "align.forward") + fwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v > {FWD}".format( + FASTALIGN=fast_align_bin, JOINED=joined_file, FWD=fwd_align_file + ) assert os.system(fwd_fast_align_cmd) == 0 # run backward alignment - bwd_align_file = os.path.join(args.output_dir, 'align.backward') - bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > 
{BWD}'.format( - FASTALIGN=fast_align_bin, - JOINED=joined_file, - BWD=bwd_align_file) + bwd_align_file = os.path.join(args.output_dir, "align.backward") + bwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}".format( + FASTALIGN=fast_align_bin, JOINED=joined_file, BWD=bwd_align_file + ) assert os.system(bwd_fast_align_cmd) == 0 # run symmetrization - sym_out_file = os.path.join(args.output_dir, 'aligned') - sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format( + sym_out_file = os.path.join(args.output_dir, "aligned") + sym_cmd = "{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}".format( SYMFASTALIGN=sym_fast_align_bin, FWD=fwd_align_file, BWD=bwd_align_file, @@ -88,10 +88,10 @@ def main(): TGT=args.target_file, OUT=sym_out_file, HEURISTIC=args.sym_heuristic, - SYMAL=symal_bin + SYMAL=symal_bin, ) assert os.system(sym_cmd) == 0 -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/compare_namespaces.py b/scripts/compare_namespaces.py index db5121189a..bc24db624f 100644 --- a/scripts/compare_namespaces.py +++ b/scripts/compare_namespaces.py @@ -6,13 +6,13 @@ def main(): - ns1 = eval(input('Namespace 1: ')) - ns2 = eval(input('Namespace 2: ')) + ns1 = eval(input("Namespace 1: ")) + ns2 = eval(input("Namespace 2: ")) def keys(ns): ks = set() for k in dir(ns): - if not k.startswith('_'): + if not k.startswith("_"): ks.add(k) return ks @@ -22,23 +22,25 @@ def keys(ns): def print_keys(ks, ns1, ns2=None): for k in ks: if ns2 is None: - print('{}\t{}'.format(k, getattr(ns1, k, None))) + print("{}\t{}".format(k, getattr(ns1, k, None))) else: - print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None))) + print( + "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None)) + ) - print('Keys unique to namespace 1:') + print("Keys unique to namespace 1:") print_keys(k1 - k2, ns1) print() - print('Keys unique to namespace 2:') + print("Keys unique to namespace 2:") 
print_keys(k2 - k1, ns2) print() - print('Overlapping keys with different values:') - ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')] + print("Overlapping keys with different values:") + ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")] print_keys(ks, ns1, ns2) print() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/constraints/extract.py b/scripts/constraints/extract.py index 8f9bc4ad14..f6155d0a05 100755 --- a/scripts/constraints/extract.py +++ b/scripts/constraints/extract.py @@ -10,12 +10,13 @@ import argparse import random import sys + from sacrebleu import extract_ngrams def get_phrase(words, index, length): - assert(index < len(words) - length + 1) - phr = ' '.join(words[index:index+length]) + assert index < len(words) - length + 1 + phr = " ".join(words[index : index + length]) for i in range(index, index + length): words.pop(index) return phr @@ -33,8 +34,8 @@ def add_constraint(constraint): constraints.append(constraint) source = line.rstrip() - if '\t' in line: - source, target = line.split('\t') + if "\t" in line: + source, target = line.split("\t") if args.add_sos: target = f" {target}" if args.add_eos: @@ -53,8 +54,12 @@ def add_constraint(constraint): segment = words.pop(segmentno) tokens = segment.split() phrase_index = random.choice(range(len(tokens))) - choice = " ".join(tokens[phrase_index:min(len(tokens), phrase_index + args.len)]) - for j in range(phrase_index, min(len(tokens), phrase_index + args.len)): + choice = " ".join( + tokens[phrase_index : min(len(tokens), phrase_index + args.len)] + ) + for j in range( + phrase_index, min(len(tokens), phrase_index + args.len) + ): tokens.pop(phrase_index) if phrase_index > 0: words.append(" ".join(tokens[0:phrase_index])) @@ -73,11 +78,15 @@ def add_constraint(constraint): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--number', '-n', type=int, default=1, 
help="number of phrases") - parser.add_argument('--len', '-l', type=int, default=1, help="phrase length") - parser.add_argument('--add-sos', default=False, action='store_true', help='add token') - parser.add_argument('--add-eos', default=False, action='store_true', help='add token') - parser.add_argument('--seed', "-s", default=0, type=int) + parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases") + parser.add_argument("--len", "-l", type=int, default=1, help="phrase length") + parser.add_argument( + "--add-sos", default=False, action="store_true", help="add token" + ) + parser.add_argument( + "--add-eos", default=False, action="store_true", help="add token" + ) + parser.add_argument("--seed", "-s", default=0, type=int) args = parser.parse_args() main(args) diff --git a/scripts/constraints/validate.py b/scripts/constraints/validate.py index 6d1a4a0885..d531ad9f39 100755 --- a/scripts/constraints/validate.py +++ b/scripts/constraints/validate.py @@ -7,6 +7,7 @@ import sys + """Reads in a fairseq output file, and verifies that the constraints (C- lines) are present in the output (the first H- line). Assumes that constraints are listed prior to the first hypothesis. 
diff --git a/scripts/count_docs.py b/scripts/count_docs.py index 8d185398a7..58d85af85e 100644 --- a/scripts/count_docs.py +++ b/scripts/count_docs.py @@ -17,15 +17,15 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument('input') - parser.add_argument('--gzip', action='store_true') + parser.add_argument("input") + parser.add_argument("--gzip", action="store_true") args = parser.parse_args() def gopen(): if args.gzip: - return gzip.open(args.input, 'r') + return gzip.open(args.input, "r") else: - return open(args.input, 'r', encoding='utf-8') + return open(args.input, "r", encoding="utf-8") num_lines = [] num_toks = [] @@ -54,5 +54,5 @@ def gopen(): print("average num toks per doc: {}".format(np.mean(num_toks))) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/read_binarized.py b/scripts/read_binarized.py index f48409beb4..a414095d03 100644 --- a/scripts/read_binarized.py +++ b/scripts/read_binarized.py @@ -6,12 +6,13 @@ import argparse -from fairseq.data import data_utils, Dictionary, indexed_dataset +from fairseq.data import Dictionary, data_utils, indexed_dataset def get_parser(): parser = argparse.ArgumentParser( - description='writes text from binarized file to stdout') + description="writes text from binarized file to stdout" + ) # fmt: off parser.add_argument('--dataset-impl', help='dataset implementation', choices=indexed_dataset.get_available_dataset_impl()) @@ -31,17 +32,17 @@ def main(): args.input, dictionary, dataset_impl=args.dataset_impl, - default='lazy', + default="lazy", ) for tensor_line in dataset: if dictionary is None: - line = ' '.join([str(int(x)) for x in tensor_line]) + line = " ".join([str(int(x)) for x in tensor_line]) else: line = dictionary.string(tensor_line) print(line) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/rm_pt.py b/scripts/rm_pt.py index 21976cee4f..6cd063d21f 100644 --- a/scripts/rm_pt.py +++ b/scripts/rm_pt.py @@ -11,9 +11,9 @@ 
import sys -pt_regexp = re.compile(r'checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt') -pt_regexp_epoch_based = re.compile(r'checkpoint(\d+)\.pt') -pt_regexp_update_based = re.compile(r'checkpoint_\d+_(\d+)\.pt') +pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt") +pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt") +pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt") def parse_checkpoints(files): @@ -42,18 +42,31 @@ def every_n_checkpoints(files, n): def main(): parser = argparse.ArgumentParser( description=( - 'Recursively delete checkpoint files from `root_dir`, ' - 'but preserve checkpoint_best.pt and checkpoint_last.pt' + "Recursively delete checkpoint files from `root_dir`, " + "but preserve checkpoint_best.pt and checkpoint_last.pt" ) ) - parser.add_argument('root_dirs', nargs='*') - parser.add_argument('--save-last', type=int, default=0, help='number of last checkpoints to save') - parser.add_argument('--save-every', type=int, default=0, help='interval of checkpoints to save') - parser.add_argument('--preserve-test', action='store_true', - help='preserve checkpoints in dirs that start with test_ prefix (default: delete them)') - parser.add_argument('--delete-best', action='store_true', help='delete checkpoint_best.pt') - parser.add_argument('--delete-last', action='store_true', help='delete checkpoint_last.pt') - parser.add_argument('--no-dereference', action='store_true', help='don\'t dereference symlinks') + parser.add_argument("root_dirs", nargs="*") + parser.add_argument( + "--save-last", type=int, default=0, help="number of last checkpoints to save" + ) + parser.add_argument( + "--save-every", type=int, default=0, help="interval of checkpoints to save" + ) + parser.add_argument( + "--preserve-test", + action="store_true", + help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)", + ) + parser.add_argument( + "--delete-best", action="store_true", help="delete checkpoint_best.pt" + ) + 
parser.add_argument( + "--delete-last", action="store_true", help="delete checkpoint_last.pt" + ) + parser.add_argument( + "--no-dereference", action="store_true", help="don't dereference symlinks" + ) args = parser.parse_args() files_to_desymlink = [] @@ -72,15 +85,11 @@ def main(): continue full_path = os.path.join(root, file) if ( - ( - not os.path.basename(root).startswith('test_') - or args.preserve_test - ) - and ( - (file == 'checkpoint_last.pt' and not args.delete_last) - or (file == 'checkpoint_best.pt' and not args.delete_best) - or file in to_save - ) + not os.path.basename(root).startswith("test_") or args.preserve_test + ) and ( + (file == "checkpoint_last.pt" and not args.delete_last) + or (file == "checkpoint_best.pt" and not args.delete_best) + or file in to_save ): if os.path.islink(full_path) and not args.no_dereference: files_to_desymlink.append(full_path) @@ -90,43 +99,43 @@ def main(): files_to_delete.append(full_path) if len(files_to_desymlink) == 0 and len(files_to_delete) == 0: - print('Nothing to do.') + print("Nothing to do.") sys.exit(0) files_to_desymlink = sorted(files_to_desymlink) files_to_preserve = sorted(files_to_preserve) files_to_delete = sorted(files_to_delete) - print('Operations to perform (in order):') + print("Operations to perform (in order):") if len(files_to_desymlink) > 0: for file in files_to_desymlink: - print(' - preserve (and dereference symlink): ' + file) + print(" - preserve (and dereference symlink): " + file) if len(files_to_preserve) > 0: for file in files_to_preserve: - print(' - preserve: ' + file) + print(" - preserve: " + file) if len(files_to_delete) > 0: for file in files_to_delete: - print(' - delete: ' + file) + print(" - delete: " + file) while True: - resp = input('Continue? (Y/N): ') - if resp.strip().lower() == 'y': + resp = input("Continue? 
(Y/N): ") + if resp.strip().lower() == "y": break - elif resp.strip().lower() == 'n': + elif resp.strip().lower() == "n": sys.exit(0) - print('Executing...') + print("Executing...") if len(files_to_desymlink) > 0: for file in files_to_desymlink: realpath = os.path.realpath(file) - print('rm ' + file) + print("rm " + file) os.remove(file) - print('cp {} {}'.format(realpath, file)) + print("cp {} {}".format(realpath, file)) shutil.copyfile(realpath, file) if len(files_to_delete) > 0: for file in files_to_delete: - print('rm ' + file) + print("rm " + file) os.remove(file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/shard_docs.py b/scripts/shard_docs.py index 87d7c22d4f..97232c3c84 100644 --- a/scripts/shard_docs.py +++ b/scripts/shard_docs.py @@ -14,21 +14,23 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument('input') - parser.add_argument('--num-shards', type=int) + parser.add_argument("input") + parser.add_argument("--num-shards", type=int) args = parser.parse_args() assert args.num_shards is not None and args.num_shards > 1 - with open(args.input, 'r', encoding='utf-8') as h: + with open(args.input, "r", encoding="utf-8") as h: with contextlib.ExitStack() as stack: outputs = [ - stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8")) + stack.enter_context( + open(args.input + ".shard" + str(i), "w", encoding="utf-8") + ) for i in range(args.num_shards) ] doc = [] - first_doc = [True]*args.num_shards + first_doc = [True] * args.num_shards def output_doc(i): if not first_doc[i]: @@ -48,5 +50,5 @@ def output_doc(i): output_doc(num_docs % args.num_shards) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/split_train_valid_docs.py b/scripts/split_train_valid_docs.py index 9adf99634c..ff15978528 100644 --- a/scripts/split_train_valid_docs.py +++ b/scripts/split_train_valid_docs.py @@ -15,12 +15,13 @@ def main(): parser = argparse.ArgumentParser() - 
parser.add_argument('input') - parser.add_argument('sample_output', help='train output file') - parser.add_argument('remainder_output', help='valid output file') - parser.add_argument('-k', type=int, help="remainder size") - parser.add_argument('--lines', action='store_true', - help='split lines instead of docs') + parser.add_argument("input") + parser.add_argument("sample_output", help="train output file") + parser.add_argument("remainder_output", help="valid output file") + parser.add_argument("-k", type=int, help="remainder size") + parser.add_argument( + "--lines", action="store_true", help="split lines instead of docs" + ) args = parser.parse_args() assert args.k is not None @@ -43,7 +44,7 @@ def update_sample(doc): num_docs[0] += 1 doc.clear() - with open(args.input, 'r', encoding='utf-8') as h: + with open(args.input, "r", encoding="utf-8") as h: doc = [] for i, line in enumerate(h): if line.strip() == "": # empty line indicates new document @@ -62,7 +63,7 @@ def update_sample(doc): assert len(sample) == args.k - with open(args.sample_output, 'w', encoding='utf-8') as out: + with open(args.sample_output, "w", encoding="utf-8") as out: first = True for doc in sample: if not first and not args.lines: @@ -71,7 +72,7 @@ def update_sample(doc): for line in doc: out.write(line) - with open(args.remainder_output, 'w', encoding='utf-8') as out: + with open(args.remainder_output, "w", encoding="utf-8") as out: first = True for doc in remainder: if not first and not args.lines: @@ -81,5 +82,5 @@ def update_sample(doc): out.write(line) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/spm_decode.py b/scripts/spm_decode.py index bd3961ab97..1c18b1d2a7 100644 --- a/scripts/spm_decode.py +++ b/scripts/spm_decode.py @@ -14,8 +14,9 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument("--model", required=True, - help="sentencepiece model to use for decoding") + parser.add_argument( + "--model", required=True, 
help="sentencepiece model to use for decoding" + ) parser.add_argument("--input", required=True, help="input file to decode") parser.add_argument("--input_format", choices=["piece", "id"], default="piece") args = parser.parse_args() @@ -24,11 +25,15 @@ def main(): sp.Load(args.model) if args.input_format == "piece": + def decode(l): return "".join(sp.DecodePieces(l)) + elif args.input_format == "id": + def decode(l): return "".join(sp.DecodeIds(l)) + else: raise NotImplementedError @@ -43,5 +48,6 @@ def tok2int(tok): elif args.input_format == "piece": print(decode(line.rstrip().split())) + if __name__ == "__main__": main() diff --git a/scripts/spm_encode.py b/scripts/spm_encode.py index e1cb54192a..83facfb3b1 100644 --- a/scripts/spm_encode.py +++ b/scripts/spm_encode.py @@ -16,53 +16,73 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument("--model", required=True, - help="sentencepiece model to use for encoding") - parser.add_argument("--inputs", nargs="+", default=['-'], - help="input files to filter/encode") - parser.add_argument("--outputs", nargs="+", default=['-'], - help="path to save encoded outputs") + parser.add_argument( + "--model", required=True, help="sentencepiece model to use for encoding" + ) + parser.add_argument( + "--inputs", nargs="+", default=["-"], help="input files to filter/encode" + ) + parser.add_argument( + "--outputs", nargs="+", default=["-"], help="path to save encoded outputs" + ) parser.add_argument("--output_format", choices=["piece", "id"], default="piece") - parser.add_argument("--min-len", type=int, metavar="N", - help="filter sentence pairs with fewer than N tokens") - parser.add_argument("--max-len", type=int, metavar="N", - help="filter sentence pairs with more than N tokens") + parser.add_argument( + "--min-len", + type=int, + metavar="N", + help="filter sentence pairs with fewer than N tokens", + ) + parser.add_argument( + "--max-len", + type=int, + metavar="N", + help="filter sentence pairs with more 
than N tokens", + ) args = parser.parse_args() - assert len(args.inputs) == len(args.outputs), \ - "number of input and output paths should match" + assert len(args.inputs) == len( + args.outputs + ), "number of input and output paths should match" sp = spm.SentencePieceProcessor() sp.Load(args.model) if args.output_format == "piece": + def encode(l): return sp.EncodeAsPieces(l) + elif args.output_format == "id": + def encode(l): return list(map(str, sp.EncodeAsIds(l))) + else: raise NotImplementedError if args.min_len is not None or args.max_len is not None: + def valid(line): - return ( - (args.min_len is None or len(line) >= args.min_len) - and (args.max_len is None or len(line) <= args.max_len) + return (args.min_len is None or len(line) >= args.min_len) and ( + args.max_len is None or len(line) <= args.max_len ) + else: + def valid(lines): return True with contextlib.ExitStack() as stack: inputs = [ - stack.enter_context(open(input, "r", encoding="utf-8")) \ - if input != "-" else sys.stdin + stack.enter_context(open(input, "r", encoding="utf-8")) + if input != "-" + else sys.stdin for input in args.inputs ] outputs = [ - stack.enter_context(open(output, "w", encoding="utf-8")) \ - if output != "-" else sys.stdout + stack.enter_context(open(output, "w", encoding="utf-8")) + if output != "-" + else sys.stdout for output in args.outputs ] diff --git a/setup.py b/setup.py index 21e05d8da6..ad2ea2088b 100644 --- a/setup.py +++ b/setup.py @@ -5,22 +5,23 @@ # LICENSE file in the root directory of this source tree. 
import os -from setuptools import setup, find_packages, Extension import sys +from setuptools import Extension, find_packages, setup + if sys.version_info < (3, 6): - sys.exit('Sorry, Python >= 3.6 is required for fairseq.') + sys.exit("Sorry, Python >= 3.6 is required for fairseq.") -with open('README.md') as f: +with open("README.md") as f: readme = f.read() -if sys.platform == 'darwin': - extra_compile_args = ['-stdlib=libc++', '-O3'] +if sys.platform == "darwin": + extra_compile_args = ["-stdlib=libc++", "-O3"] else: - extra_compile_args = ['-std=c++11', '-O3'] + extra_compile_args = ["-std=c++11", "-O3"] class NumpyExtension(Extension): @@ -33,6 +34,7 @@ def __init__(self, *args, **kwargs): @property def include_dirs(self): import numpy + return self.__include_dirs + [numpy.get_include()] @include_dirs.setter @@ -42,23 +44,23 @@ def include_dirs(self, dirs): extensions = [ Extension( - 'fairseq.libbleu', + "fairseq.libbleu", sources=[ - 'fairseq/clib/libbleu/libbleu.cpp', - 'fairseq/clib/libbleu/module.cpp', + "fairseq/clib/libbleu/libbleu.cpp", + "fairseq/clib/libbleu/module.cpp", ], extra_compile_args=extra_compile_args, ), NumpyExtension( - 'fairseq.data.data_utils_fast', - sources=['fairseq/data/data_utils_fast.pyx'], - language='c++', + "fairseq.data.data_utils_fast", + sources=["fairseq/data/data_utils_fast.pyx"], + language="c++", extra_compile_args=extra_compile_args, ), NumpyExtension( - 'fairseq.data.token_block_utils_fast', - sources=['fairseq/data/token_block_utils_fast.pyx'], - language='c++', + "fairseq.data.token_block_utils_fast", + sources=["fairseq/data/token_block_utils_fast.pyx"], + language="c++", extra_compile_args=extra_compile_args, ), ] @@ -70,94 +72,104 @@ def include_dirs(self, dirs): try: # torch is not available when generating docs from torch.utils import cpp_extension - extensions.extend([ - cpp_extension.CppExtension( - 'fairseq.libnat', - sources=[ - 'fairseq/clib/libnat/edit_dist.cpp', - ], - ) - ]) - if 'CUDA_HOME' in 
os.environ: - extensions.extend([ + extensions.extend( + [ cpp_extension.CppExtension( - 'fairseq.libnat_cuda', + "fairseq.libnat", sources=[ - 'fairseq/clib/libnat_cuda/edit_dist.cu', - 'fairseq/clib/libnat_cuda/binding.cpp' + "fairseq/clib/libnat/edit_dist.cpp", ], - )]) - cmdclass['build_ext'] = cpp_extension.BuildExtension + ) + ] + ) + + if "CUDA_HOME" in os.environ: + extensions.extend( + [ + cpp_extension.CppExtension( + "fairseq.libnat_cuda", + sources=[ + "fairseq/clib/libnat_cuda/edit_dist.cu", + "fairseq/clib/libnat_cuda/binding.cpp", + ], + ) + ] + ) + cmdclass["build_ext"] = cpp_extension.BuildExtension except ImportError: pass -if 'READTHEDOCS' in os.environ: +if "READTHEDOCS" in os.environ: # don't build extensions when generating docs extensions = [] - if 'build_ext' in cmdclass: - del cmdclass['build_ext'] + if "build_ext" in cmdclass: + del cmdclass["build_ext"] # use CPU build of PyTorch dependency_links = [ - 'https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl' + "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl" ] else: dependency_links = [] -if 'clean' in sys.argv[1:]: +if "clean" in sys.argv[1:]: # Source: https://bit.ly/2NLVsgE print("deleting Cython files...") import subprocess - subprocess.run(['rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd'], shell=True) + + subprocess.run( + ["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"], + shell=True, + ) setup( - name='fairseq', - version='0.9.0', - description='Facebook AI Research Sequence-to-Sequence Toolkit', - url='https://github.com/pytorch/fairseq', + name="fairseq", + version="0.9.0", + description="Facebook AI Research Sequence-to-Sequence Toolkit", + url="https://github.com/pytorch/fairseq", classifiers=[ - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.6', - 'Topic :: Scientific/Engineering :: Artificial 
Intelligence', + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.6", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], long_description=readme, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", setup_requires=[ - 'cython', - 'numpy', - 'setuptools>=18.0', + "cython", + "numpy", + "setuptools>=18.0", ], install_requires=[ - 'cffi', - 'cython', - 'dataclasses', - 'editdistance', - 'hydra-core', - 'numpy', - 'regex', - 'sacrebleu>=1.4.12', - 'torch', - 'tqdm', + "cffi", + "cython", + "dataclasses", + "editdistance", + "hydra-core", + "numpy", + "regex", + "sacrebleu>=1.4.12", + "torch", + "tqdm", ], dependency_links=dependency_links, - packages=find_packages(exclude=['scripts', 'tests']), + packages=find_packages(exclude=["scripts", "tests"]), ext_modules=extensions, - test_suite='tests', + test_suite="tests", entry_points={ - 'console_scripts': [ - 'fairseq-eval-lm = fairseq_cli.eval_lm:cli_main', - 'fairseq-generate = fairseq_cli.generate:cli_main', - 'fairseq-interactive = fairseq_cli.interactive:cli_main', - 'fairseq-preprocess = fairseq_cli.preprocess:cli_main', - 'fairseq-score = fairseq_cli.score:cli_main', - 'fairseq-train = fairseq_cli.train:cli_main', - 'fairseq-validate = fairseq_cli.validate:cli_main', + "console_scripts": [ + "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", + "fairseq-generate = fairseq_cli.generate:cli_main", + "fairseq-interactive = fairseq_cli.interactive:cli_main", + "fairseq-preprocess = fairseq_cli.preprocess:cli_main", + "fairseq-score = fairseq_cli.score:cli_main", + "fairseq-train = fairseq_cli.train:cli_main", + "fairseq-validate = fairseq_cli.validate:cli_main", ], }, cmdclass=cmdclass, diff --git a/tests/speech_recognition/asr_test_base.py b/tests/speech_recognition/asr_test_base.py index 4f3d3fceb7..0341031394 100644 --- a/tests/speech_recognition/asr_test_base.py +++ 
b/tests/speech_recognition/asr_test_base.py @@ -7,6 +7,7 @@ import numpy as np import torch +from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask from fairseq.data import data_utils as fairseq_data_utils from fairseq.data.dictionary import Dictionary from fairseq.models import ( @@ -18,7 +19,6 @@ FairseqModel, ) from fairseq.tasks.fairseq_task import LegacyFairseqTask -from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask DEFAULT_TEST_VOCAB_SIZE = 100 @@ -172,9 +172,8 @@ def check_encoder_output(encoder_output, batch_size=None): "encoder_padding_mask must be a torch.Tensor" + _current_postion_info() ) return False, msg - if ( - mask.dtype != torch.uint8 - and (not hasattr(torch, 'bool') or mask.dtype != torch.bool) + if mask.dtype != torch.uint8 and ( + not hasattr(torch, "bool") or mask.dtype != torch.bool ): msg = ( "encoder_padding_mask must have dtype of uint8" @@ -516,14 +515,16 @@ def setUpArgs(self): def setUp(self): args = self.setUpArgs() self.model = DummyEncoderModel(encoder=DummyEncoder()) - self.criterion = self.criterion_cls.build_criterion(args=args, task=DummyTask(args)) + self.criterion = self.criterion_cls.build_criterion( + args=args, task=DummyTask(args) + ) def get_src_tokens(self, correct_prediction, aggregate): """ - correct_prediction: True if the net_output (src_tokens) should - predict the correct target - aggregate: True if the criterion expects net_output (src_tokens) - aggregated across time axis + correct_prediction: True if the net_output (src_tokens) should + predict the correct target + aggregate: True if the criterion expects net_output (src_tokens) + aggregated across time axis """ predicted_idx = 0 if correct_prediction else 1 if aggregate: diff --git a/tests/speech_recognition/test_cross_entropy.py b/tests/speech_recognition/test_cross_entropy.py index 508d490e01..b05400ed95 100644 --- a/tests/speech_recognition/test_cross_entropy.py +++ 
b/tests/speech_recognition/test_cross_entropy.py
@@ -4,7 +4,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion
+from examples.speech_recognition.criterions.cross_entropy_acc import (
+    CrossEntropyWithAccCriterion,
+)
+
 from .asr_test_base import CrossEntropyCriterionTestBase
 
 
diff --git a/tests/speech_recognition/test_data_utils.py b/tests/speech_recognition/test_data_utils.py
index 5ca7c5c2a1..a72e0b6694 100644
--- a/tests/speech_recognition/test_data_utils.py
+++ b/tests/speech_recognition/test_data_utils.py
@@ -6,18 +6,57 @@
 import unittest
 
 import torch
-
 from examples.speech_recognition.data import data_utils
 
 
 class DataUtilsTest(unittest.TestCase):
-
     def test_normalization(self):
-        sample_len1 = torch.tensor([[-0.7661, -1.3889, -2.0972, -0.9134, -0.7071, -0.9765, -0.8700, -0.8283,
-                                     0.7512, 1.3211, 2.1532, 2.1174, 1.2800, 1.2633, 1.6147, 1.6322,
-                                     2.0723, 3.1522, 3.2852, 2.2309, 2.5569, 2.2183, 2.2862, 1.5886,
-                                     0.8773, 0.8725, 1.2662, 0.9899, 1.1069, 1.3926, 1.2795, 1.1199,
-                                     1.1477, 1.2687, 1.3843, 1.1903, 0.8355, 1.1367, 1.2639, 1.4707]])
+        sample_len1 = torch.tensor(
+            [
+                [
+                    -0.7661,
+                    -1.3889,
+                    -2.0972,
+                    -0.9134,
+                    -0.7071,
+                    -0.9765,
+                    -0.8700,
+                    -0.8283,
+                    0.7512,
+                    1.3211,
+                    2.1532,
+                    2.1174,
+                    1.2800,
+                    1.2633,
+                    1.6147,
+                    1.6322,
+                    2.0723,
+                    3.1522,
+                    3.2852,
+                    2.2309,
+                    2.5569,
+                    2.2183,
+                    2.2862,
+                    1.5886,
+                    0.8773,
+                    0.8725,
+                    1.2662,
+                    0.9899,
+                    1.1069,
+                    1.3926,
+                    1.2795,
+                    1.1199,
+                    1.1477,
+                    1.2687,
+                    1.3843,
+                    1.1903,
+                    0.8355,
+                    1.1367,
+                    1.2639,
+                    1.4707,
+                ]
+            ]
+        )
         out = data_utils.apply_mv_norm(sample_len1)
         assert not torch.isnan(out).any()
         assert (out == sample_len1).all()
diff --git a/tests/test_average_checkpoints.py b/tests/test_average_checkpoints.py
index 8ed298c3c9..f348b56b86 100644
--- a/tests/test_average_checkpoints.py
+++ 
b/tests/test_average_checkpoints.py @@ -5,16 +5,14 @@ import collections import os +import shutil import tempfile import unittest -import shutil import numpy as np import torch -from torch import nn - - from scripts.average_checkpoints import average_checkpoints +from torch import nn class ModelWithSharedParameter(nn.Module): @@ -37,33 +35,33 @@ class TestAverageCheckpoints(unittest.TestCase): def test_average_checkpoints(self): params_0 = collections.OrderedDict( [ - ('a', torch.DoubleTensor([100.0])), - ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])), - ('c', torch.IntTensor([7, 8, 9])), + ("a", torch.DoubleTensor([100.0])), + ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])), + ("c", torch.IntTensor([7, 8, 9])), ] ) params_1 = collections.OrderedDict( [ - ('a', torch.DoubleTensor([1.0])), - ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])), - ('c', torch.IntTensor([2, 2, 2])), + ("a", torch.DoubleTensor([1.0])), + ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])), + ("c", torch.IntTensor([2, 2, 2])), ] ) params_avg = collections.OrderedDict( [ - ('a', torch.DoubleTensor([50.5])), - ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])), + ("a", torch.DoubleTensor([50.5])), + ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])), # We expect truncation for integer division - ('c', torch.IntTensor([4, 5, 5])), + ("c", torch.IntTensor([4, 5, 5])), ] ) fd_0, path_0 = tempfile.mkstemp() fd_1, path_1 = tempfile.mkstemp() - torch.save(collections.OrderedDict([('model', params_0)]), path_0) - torch.save(collections.OrderedDict([('model', params_1)]), path_1) + torch.save(collections.OrderedDict([("model", params_0)]), path_0) + torch.save(collections.OrderedDict([("model", params_1)]), path_1) - output = average_checkpoints([path_0, path_1])['model'] + output = average_checkpoints([path_0, path_1])["model"] os.close(fd_0) os.remove(path_0) @@ -71,28 +69,27 @@ def test_average_checkpoints(self): os.remove(path_1) 
for (k_expected, v_expected), (k_out, v_out) in zip( - params_avg.items(), output.items()): + params_avg.items(), output.items() + ): self.assertEqual( - k_expected, k_out, 'Key mismatch - expected {} but found {}. ' - '(Expected list of keys: {} vs actual list of keys: {})'.format( + k_expected, + k_out, + "Key mismatch - expected {} but found {}. " + "(Expected list of keys: {} vs actual list of keys: {})".format( k_expected, k_out, params_avg.keys(), output.keys() - ) + ), ) np.testing.assert_allclose( v_expected.numpy(), v_out.numpy(), - err_msg='Tensor value mismatch for key {}'.format(k_expected) + err_msg="Tensor value mismatch for key {}".format(k_expected), ) def test_average_checkpoints_with_shared_parameters(self): - def _construct_model_with_shared_parameters(path, value): m = ModelWithSharedParameter() nn.init.constant_(m.FC1.weight, value) - torch.save( - {'model': m.state_dict()}, - path - ) + torch.save({"model": m.state_dict()}, path) return m tmpdir = tempfile.mkdtemp() @@ -112,32 +109,26 @@ def _construct_model_with_shared_parameters(path, value): new_model = average_checkpoints(paths) self.assertTrue( torch.equal( - new_model['model']['embedding.weight'], - (m1.embedding.weight + - m2.embedding.weight + - m3.embedding.weight) / 3.0 + new_model["model"]["embedding.weight"], + (m1.embedding.weight + m2.embedding.weight + m3.embedding.weight) / 3.0, ) ) self.assertTrue( torch.equal( - new_model['model']['FC1.weight'], - (m1.FC1.weight + - m2.FC1.weight + - m3.FC1.weight) / 3.0 + new_model["model"]["FC1.weight"], + (m1.FC1.weight + m2.FC1.weight + m3.FC1.weight) / 3.0, ) ) self.assertTrue( torch.equal( - new_model['model']['FC2.weight'], - (m1.FC2.weight + - m2.FC2.weight + - m3.FC2.weight) / 3.0 + new_model["model"]["FC2.weight"], + (m1.FC2.weight + m2.FC2.weight + m3.FC2.weight) / 3.0, ) ) shutil.rmtree(tmpdir) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_backtranslation_dataset.py 
b/tests/test_backtranslation_dataset.py index 23ae333761..dffc3b4938 100644 --- a/tests/test_backtranslation_dataset.py +++ b/tests/test_backtranslation_dataset.py @@ -5,8 +5,8 @@ import unittest +import tests.utils as test_utils import torch - from fairseq.data import ( BacktranslationDataset, LanguagePairDataset, @@ -14,15 +14,17 @@ ) from fairseq.sequence_generator import SequenceGenerator -import tests.utils as test_utils - class TestBacktranslationDataset(unittest.TestCase): - def setUp(self): - self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = ( - test_utils.sequence_generator_setup() - ) + ( + self.tgt_dict, + self.w1, + self.w2, + self.src_tokens, + self.src_lengths, + self.model, + ) = test_utils.sequence_generator_setup() dummy_src_samples = self.src_tokens @@ -30,7 +32,9 @@ def setUp(self): self.cuda = torch.cuda.is_available() def _backtranslation_dataset_helper( - self, remove_eos_from_input_src, remove_eos_from_output_src, + self, + remove_eos_from_input_src, + remove_eos_from_output_src, ): tgt_dataset = LanguagePairDataset( src=self.tgt_dataset, @@ -94,17 +98,20 @@ def _backtranslation_dataset_helper( def test_backtranslation_dataset_no_eos_in_output_src(self): self._backtranslation_dataset_helper( - remove_eos_from_input_src=False, remove_eos_from_output_src=True, + remove_eos_from_input_src=False, + remove_eos_from_output_src=True, ) def test_backtranslation_dataset_with_eos_in_output_src(self): self._backtranslation_dataset_helper( - remove_eos_from_input_src=False, remove_eos_from_output_src=False, + remove_eos_from_input_src=False, + remove_eos_from_output_src=False, ) def test_backtranslation_dataset_no_eos_in_input_src(self): self._backtranslation_dataset_helper( - remove_eos_from_input_src=True, remove_eos_from_output_src=False, + remove_eos_from_input_src=True, + remove_eos_from_output_src=False, ) def assertTensorEqual(self, t1, t2): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 
aa5a6c69d1..4b87afea55 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -4,31 +4,27 @@ # LICENSE file in the root directory of this source tree. import contextlib -from io import StringIO import logging import os import random import tempfile import unittest +from io import StringIO import torch - from fairseq import options -from fairseq_cli import train -from fairseq_cli import eval_lm -from fairseq_cli import validate +from fairseq_cli import eval_lm, train, validate from tests.utils import ( create_dummy_data, + generate_main, preprocess_lm_data, - preprocess_translation_data, preprocess_summarization_data, + preprocess_translation_data, train_translation_model, - generate_main, ) class TestTranslation(unittest.TestCase): - def setUp(self): logging.disable(logging.CRITICAL) @@ -37,180 +33,271 @@ def tearDown(self): def test_fconv(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_fconv') as data_dir: + with tempfile.TemporaryDirectory("test_fconv") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'fconv_iwslt_de_en') + train_translation_model(data_dir, "fconv_iwslt_de_en") generate_main(data_dir) def test_raw(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir: + with tempfile.TemporaryDirectory("test_fconv_raw") as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir, ['--dataset-impl', 'raw']) - train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--dataset-impl', 'raw']) - generate_main(data_dir, ['--dataset-impl', 'raw']) + preprocess_translation_data(data_dir, ["--dataset-impl", "raw"]) + train_translation_model( + data_dir, "fconv_iwslt_de_en", ["--dataset-impl", "raw"] + ) + generate_main(data_dir, ["--dataset-impl", "raw"]) def test_update_freq(self): with contextlib.redirect_stdout(StringIO()): - with 
tempfile.TemporaryDirectory('test_update_freq') as data_dir: + with tempfile.TemporaryDirectory("test_update_freq") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3']) + train_translation_model( + data_dir, "fconv_iwslt_de_en", ["--update-freq", "3"] + ) generate_main(data_dir) def test_max_positions(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_max_positions') as data_dir: + with tempfile.TemporaryDirectory("test_max_positions") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) with self.assertRaises(Exception) as context: train_translation_model( - data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'], + data_dir, + "fconv_iwslt_de_en", + ["--max-target-positions", "5"], ) self.assertTrue( - 'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception) + "skip this example with --skip-invalid-size-inputs-valid-test" + in str(context.exception) ) train_translation_model( - data_dir, 'fconv_iwslt_de_en', - ['--max-target-positions', '5', '--skip-invalid-size-inputs-valid-test'], + data_dir, + "fconv_iwslt_de_en", + [ + "--max-target-positions", + "5", + "--skip-invalid-size-inputs-valid-test", + ], ) with self.assertRaises(Exception) as context: generate_main(data_dir) - generate_main(data_dir, ['--skip-invalid-size-inputs-valid-test']) + generate_main(data_dir, ["--skip-invalid-size-inputs-valid-test"]) def test_generation(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_sampling') as data_dir: + with tempfile.TemporaryDirectory("test_sampling") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'fconv_iwslt_de_en') - generate_main(data_dir, [ - '--sampling', - '--temperature', '2', - '--beam', '2', - '--nbest', '2', - ]) - 
generate_main(data_dir, [ - '--sampling', - '--sampling-topk', '3', - '--beam', '2', - '--nbest', '2', - ]) - generate_main(data_dir, [ - '--sampling', - '--sampling-topp', '0.2', - '--beam', '2', - '--nbest', '2', - ]) - generate_main(data_dir, [ - '--diversity-rate', '0.5', - '--beam', '6', - ]) + train_translation_model(data_dir, "fconv_iwslt_de_en") + generate_main( + data_dir, + [ + "--sampling", + "--temperature", + "2", + "--beam", + "2", + "--nbest", + "2", + ], + ) + generate_main( + data_dir, + [ + "--sampling", + "--sampling-topk", + "3", + "--beam", + "2", + "--nbest", + "2", + ], + ) + generate_main( + data_dir, + [ + "--sampling", + "--sampling-topp", + "0.2", + "--beam", + "2", + "--nbest", + "2", + ], + ) + generate_main( + data_dir, + [ + "--diversity-rate", + "0.5", + "--beam", + "6", + ], + ) with self.assertRaises(ValueError): - generate_main(data_dir, [ - '--diverse-beam-groups', '4', - '--match-source-len', - ]) - generate_main(data_dir, ['--prefix-size', '2']) - generate_main(data_dir, ['--retain-dropout']) + generate_main( + data_dir, + [ + "--diverse-beam-groups", + "4", + "--match-source-len", + ], + ) + generate_main(data_dir, ["--prefix-size", "2"]) + generate_main(data_dir, ["--retain-dropout"]) def test_eval_bleu(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_eval_bleu') as data_dir: + with tempfile.TemporaryDirectory("test_eval_bleu") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'fconv_iwslt_de_en', [ - '--eval-bleu', - '--eval-bleu-print-samples', - '--eval-bleu-remove-bpe', - '--eval-bleu-detok', 'space', - '--eval-bleu-args', '{"beam": 4, "min_len": 10}', - ]) + train_translation_model( + data_dir, + "fconv_iwslt_de_en", + [ + "--eval-bleu", + "--eval-bleu-print-samples", + "--eval-bleu-remove-bpe", + "--eval-bleu-detok", + "space", + "--eval-bleu-args", + '{"beam": 4, "min_len": 10}', + ], + ) def test_lstm(self): 
with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_lstm') as data_dir: + with tempfile.TemporaryDirectory("test_lstm") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en', [ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--decoder-out-embed-dim', '8', - ]) + train_translation_model( + data_dir, + "lstm_wiseman_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--decoder-out-embed-dim", + "8", + ], + ) generate_main(data_dir) def test_lstm_bidirectional(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_lstm_bidirectional') as data_dir: + with tempfile.TemporaryDirectory("test_lstm_bidirectional") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'lstm', [ - '--encoder-layers', '2', - '--encoder-bidirectional', - '--encoder-hidden-size', '16', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--decoder-out-embed-dim', '8', - '--decoder-layers', '2', - ]) + train_translation_model( + data_dir, + "lstm", + [ + "--encoder-layers", + "2", + "--encoder-bidirectional", + "--encoder-hidden-size", + "16", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--decoder-out-embed-dim", + "8", + "--decoder-layers", + "2", + ], + ) generate_main(data_dir) def test_transformer(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_transformer') as data_dir: + with tempfile.TemporaryDirectory("test_transformer") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'transformer_iwslt_de_en', [ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', 
'8', - '--decoder-embed-dim', '8', - ], run_validation=True) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + run_validation=True, + ) generate_main(data_dir) def test_multilingual_transformer(self): # test with all combinations of encoder/decoder lang tokens - encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']] - decoder_langtok_flags = [[], ['--decoder-langtok']] + encoder_langtok_flags = [ + [], + ["--encoder-langtok", "src"], + ["--encoder-langtok", "tgt"], + ] + decoder_langtok_flags = [[], ["--decoder-langtok"]] with contextlib.redirect_stdout(StringIO()): for i in range(len(encoder_langtok_flags)): for j in range(len(decoder_langtok_flags)): enc_ltok_flag = encoder_langtok_flags[i] dec_ltok_flag = decoder_langtok_flags[j] - with tempfile.TemporaryDirectory(f'test_multilingual_transformer_{i}_{j}') as data_dir: + with tempfile.TemporaryDirectory( + f"test_multilingual_transformer_{i}_{j}" + ) as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) train_translation_model( data_dir, - arch='multilingual_transformer', - task='multilingual_translation', + arch="multilingual_transformer", + task="multilingual_translation", extra_flags=[ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - ] + enc_ltok_flag + dec_ltok_flag, - lang_flags=['--lang-pairs', 'in-out,out-in'], + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out,out-in"], run_validation=True, extra_valid_flags=enc_ltok_flag + dec_ltok_flag, ) generate_main( data_dir, extra_flags=[ - '--task', 'multilingual_translation', - '--lang-pairs', 'in-out,out-in', - '--source-lang', 'in', - 
'--target-lang', 'out', - ] + enc_ltok_flag + dec_ltok_flag, + "--task", + "multilingual_translation", + "--lang-pairs", + "in-out,out-in", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, ) def test_multilingual_translation_latent_depth(self): # test with latent depth in encoder, decoder, or both - encoder_latent_layer = [[], ['--encoder-latent-layer']] - decoder_latent_layer = [[], ['--decoder-latent-layer']] + encoder_latent_layer = [[], ["--encoder-latent-layer"]] + decoder_latent_layer = [[], ["--decoder-latent-layer"]] with contextlib.redirect_stdout(StringIO()): for i in range(len(encoder_latent_layer)): for j in range(len(decoder_latent_layer)): @@ -218,186 +305,298 @@ def test_multilingual_translation_latent_depth(self): continue enc_ll_flag = encoder_latent_layer[i] dec_ll_flag = decoder_latent_layer[j] - with tempfile.TemporaryDirectory(f'test_multilingual_translation_latent_depth_{i}_{j}') as data_dir: + with tempfile.TemporaryDirectory( + f"test_multilingual_translation_latent_depth_{i}_{j}" + ) as data_dir: create_dummy_data(data_dir) preprocess_translation_data( - data_dir, - extra_flags=['--joined-dictionary'] + data_dir, extra_flags=["--joined-dictionary"] ) train_translation_model( data_dir, - arch='latent_multilingual_transformer', - task='multilingual_translation_latent_depth', + arch="latent_multilingual_transformer", + task="multilingual_translation_latent_depth", extra_flags=[ - '--user-dir', 'examples/latent_depth/src', - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--share-encoders', - '--share-decoders', - '--sparsity-weight', '0.1', - ] + enc_ll_flag + dec_ll_flag, - lang_flags=['--lang-pairs', 'in-out,out-in'], + "--user-dir", + "examples/latent_depth/src", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--share-encoders", + "--share-decoders", + 
"--sparsity-weight", + "0.1", + ] + + enc_ll_flag + + dec_ll_flag, + lang_flags=["--lang-pairs", "in-out,out-in"], run_validation=True, - extra_valid_flags=['--user-dir', 'examples/latent_depth/src'] + enc_ll_flag + dec_ll_flag, + extra_valid_flags=[ + "--user-dir", + "examples/latent_depth/src", + ] + + enc_ll_flag + + dec_ll_flag, ) generate_main( data_dir, extra_flags=[ - '--user-dir', 'examples/latent_depth/src', - '--task', 'multilingual_translation_latent_depth', - '--lang-pairs', 'in-out,out-in', - '--source-lang', 'in', - '--target-lang', 'out', - ] + enc_ll_flag + dec_ll_flag, + "--user-dir", + "examples/latent_depth/src", + "--task", + "multilingual_translation_latent_depth", + "--lang-pairs", + "in-out,out-in", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ll_flag + + dec_ll_flag, ) def test_translation_multi_simple_epoch(self): # test with all combinations of encoder/decoder lang tokens - encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']] - decoder_langtok_flags = [[], ['--decoder-langtok']] + encoder_langtok_flags = [ + [], + ["--encoder-langtok", "src"], + ["--encoder-langtok", "tgt"], + ] + decoder_langtok_flags = [[], ["--decoder-langtok"]] with contextlib.redirect_stdout(StringIO()): for i in range(len(encoder_langtok_flags)): for j in range(len(decoder_langtok_flags)): enc_ltok_flag = encoder_langtok_flags[i] dec_ltok_flag = decoder_langtok_flags[j] - with tempfile.TemporaryDirectory(f'test_translation_multi_simple_epoch_{i}_{j}') as data_dir: + with tempfile.TemporaryDirectory( + f"test_translation_multi_simple_epoch_{i}_{j}" + ) as data_dir: create_dummy_data(data_dir) preprocess_translation_data( - data_dir, - extra_flags=['--joined-dictionary'] + data_dir, extra_flags=["--joined-dictionary"] ) train_translation_model( data_dir, - arch='transformer', - task='translation_multi_simple_epoch', + arch="transformer", + task="translation_multi_simple_epoch", extra_flags=[ - 
'--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--sampling-method', 'temperature', - '--sampling-temperature', '1.5', - '--virtual-epoch-size', '1000', - ] + enc_ltok_flag + dec_ltok_flag, - lang_flags=['--lang-pairs', 'in-out,out-in'], + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + "--virtual-epoch-size", + "1000", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out,out-in"], run_validation=True, extra_valid_flags=enc_ltok_flag + dec_ltok_flag, ) generate_main( data_dir, extra_flags=[ - '--task', 'translation_multi_simple_epoch', - '--lang-pairs', 'in-out,out-in', - '--source-lang', 'in', - '--target-lang', 'out', - ] + enc_ltok_flag + dec_ltok_flag, + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out,out-in", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, ) def test_transformer_cross_self_attention(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir: + with tempfile.TemporaryDirectory( + "test_transformer_cross_self_attention" + ) as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'transformer_iwslt_de_en', [ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--no-cross-attention', - '--cross-self-attention', - ], run_validation=True) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--no-cross-attention", + 
"--cross-self-attention", + ], + run_validation=True, + ) generate_main(data_dir, extra_flags=[]) def test_transformer_pointer_generator(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_transformer_pointer_generator') as data_dir: + with tempfile.TemporaryDirectory( + "test_transformer_pointer_generator" + ) as data_dir: create_dummy_data(data_dir) preprocess_summarization_data(data_dir) train_translation_model( data_dir, - 'transformer_pointer_generator', + "transformer_pointer_generator", extra_flags=[ - '--user-dir', 'examples/pointer_generator/src', - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--alignment-layer', '-1', - '--alignment-heads', '1', - '--source-position-markers', '0', + "--user-dir", + "examples/pointer_generator/src", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--alignment-layer", + "-1", + "--alignment-heads", + "1", + "--source-position-markers", + "0", ], run_validation=True, - extra_valid_flags=['--user-dir', 'examples/pointer_generator/src'], + extra_valid_flags=["--user-dir", "examples/pointer_generator/src"], ) generate_main( data_dir, - extra_flags=['--user-dir', 'examples/pointer_generator/src'], + extra_flags=["--user-dir", "examples/pointer_generator/src"], ) def test_lightconv(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_lightconv') as data_dir: + with tempfile.TemporaryDirectory("test_lightconv") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'lightconv_iwslt_de_en', [ - '--encoder-conv-type', 'lightweight', - '--decoder-conv-type', 'lightweight', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - ]) + train_translation_model( + data_dir, + "lightconv_iwslt_de_en", + [ + "--encoder-conv-type", + "lightweight", + 
"--decoder-conv-type", + "lightweight", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + ) generate_main(data_dir) def test_dynamicconv(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_dynamicconv') as data_dir: + with tempfile.TemporaryDirectory("test_dynamicconv") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'lightconv_iwslt_de_en', [ - '--encoder-conv-type', 'dynamic', - '--decoder-conv-type', 'dynamic', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - ]) + train_translation_model( + data_dir, + "lightconv_iwslt_de_en", + [ + "--encoder-conv-type", + "dynamic", + "--decoder-conv-type", + "dynamic", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + ) generate_main(data_dir) def test_cmlm_transformer(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir: + with tempfile.TemporaryDirectory("test_cmlm_transformer") as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir, ['--joined-dictionary']) - train_translation_model(data_dir, 'cmlm_transformer', [ - '--apply-bert-init', - '--criterion', 'nat_loss', - '--noise', 'full_mask', - '--pred-length-offset', - '--length-loss-factor', '0.1' - ], task='translation_lev') - generate_main(data_dir, [ - '--task', 'translation_lev', - '--iter-decode-max-iter', '9', - '--iter-decode-eos-penalty', '0', - '--print-step', - ]) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "cmlm_transformer", + [ + "--apply-bert-init", + "--criterion", + "nat_loss", + "--noise", + "full_mask", + "--pred-length-offset", + "--length-loss-factor", + "0.1", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + 
"--print-step", + ], + ) def test_nonautoregressive_transformer(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir: + with tempfile.TemporaryDirectory( + "test_nonautoregressive_transformer" + ) as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir, ['--joined-dictionary']) - train_translation_model(data_dir, 'nonautoregressive_transformer', [ - '--apply-bert-init', '--src-embedding-copy', '--criterion', - 'nat_loss', '--noise', 'full_mask', '--pred-length-offset', - '--length-loss-factor', '0.1' - ], task='translation_lev') - generate_main(data_dir, [ - '--task', 'translation_lev', - '--iter-decode-max-iter', '0', - '--iter-decode-eos-penalty', '0', - '--print-step', - ]) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "nonautoregressive_transformer", + [ + "--apply-bert-init", + "--src-embedding-copy", + "--criterion", + "nat_loss", + "--noise", + "full_mask", + "--pred-length-offset", + "--length-loss-factor", + "0.1", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "0", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) # def test_nat_crf_transformer(self): # with contextlib.redirect_stdout(StringIO()): @@ -421,78 +620,139 @@ def test_nonautoregressive_transformer(self): def test_iterative_nonautoregressive_transformer(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir: + with tempfile.TemporaryDirectory( + "test_iterative_nonautoregressive_transformer" + ) as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir, ['--joined-dictionary']) - train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [ - '--apply-bert-init', '--src-embedding-copy', '--criterion', - 
'nat_loss', '--noise', 'full_mask', '--stochastic-approx', - '--dae-ratio', '0.5', '--train-step', '3' - ], task='translation_lev') - generate_main(data_dir, [ - '--task', 'translation_lev', - '--iter-decode-max-iter', '9', - '--iter-decode-eos-penalty', '0', - '--print-step', - ]) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "iterative_nonautoregressive_transformer", + [ + "--apply-bert-init", + "--src-embedding-copy", + "--criterion", + "nat_loss", + "--noise", + "full_mask", + "--stochastic-approx", + "--dae-ratio", + "0.5", + "--train-step", + "3", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) def test_insertion_transformer(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir: + with tempfile.TemporaryDirectory("test_insertion_transformer") as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir, ['--joined-dictionary']) - train_translation_model(data_dir, 'insertion_transformer', [ - '--apply-bert-init', '--criterion', 'nat_loss', '--noise', - 'random_mask' - ], task='translation_lev') - generate_main(data_dir, [ - '--task', 'translation_lev', - '--iter-decode-max-iter', '9', - '--iter-decode-eos-penalty', '0', - '--print-step', - ]) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "insertion_transformer", + [ + "--apply-bert-init", + "--criterion", + "nat_loss", + "--noise", + "random_mask", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) def test_mixture_of_experts(self): with contextlib.redirect_stdout(StringIO()): - with 
tempfile.TemporaryDirectory('test_moe') as data_dir: + with tempfile.TemporaryDirectory("test_moe") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'transformer_iwslt_de_en', [ - '--task', 'translation_moe', - '--user-dir', 'examples/translation_moe/src', - '--method', 'hMoElp', - '--mean-pool-gating-network', - '--num-experts', '3', - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - ]) - generate_main(data_dir, [ - '--task', 'translation_moe', - '--user-dir', 'examples/translation_moe/src', - '--method', 'hMoElp', - '--mean-pool-gating-network', - '--num-experts', '3', - '--gen-expert', '0' - ]) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--task", + "translation_moe", + "--user-dir", + "examples/translation_moe/src", + "--method", + "hMoElp", + "--mean-pool-gating-network", + "--num-experts", + "3", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + ], + ) + generate_main( + data_dir, + [ + "--task", + "translation_moe", + "--user-dir", + "examples/translation_moe/src", + "--method", + "hMoElp", + "--mean-pool-gating-network", + "--num-experts", + "3", + "--gen-expert", + "0", + ], + ) def test_alignment(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_alignment') as data_dir: + with tempfile.TemporaryDirectory("test_alignment") as data_dir: create_dummy_data(data_dir, alignment=True) - preprocess_translation_data(data_dir, ['--align-suffix', 'align']) + preprocess_translation_data(data_dir, ["--align-suffix", "align"]) train_translation_model( data_dir, - 'transformer_align', + "transformer_align", [ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--load-alignments', - '--alignment-layer', '1', - '--criterion', 
'label_smoothed_cross_entropy_with_alignment', + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--load-alignments", + "--alignment-layer", + "1", + "--criterion", + "label_smoothed_cross_entropy_with_alignment", ], run_validation=True, ) @@ -500,21 +760,27 @@ def test_alignment(self): def test_alignment_full_context(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_alignment') as data_dir: + with tempfile.TemporaryDirectory("test_alignment") as data_dir: create_dummy_data(data_dir, alignment=True) - preprocess_translation_data(data_dir, ['--align-suffix', 'align']) + preprocess_translation_data(data_dir, ["--align-suffix", "align"]) train_translation_model( data_dir, - 'transformer_align', + "transformer_align", [ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--load-alignments', - '--alignment-layer', '1', - '--criterion', 'label_smoothed_cross_entropy_with_alignment', - '--full-context-alignment', + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--load-alignments", + "--alignment-layer", + "1", + "--criterion", + "label_smoothed_cross_entropy_with_alignment", + "--full-context-alignment", ], run_validation=True, ) @@ -522,7 +788,6 @@ def test_alignment_full_context(self): class TestStories(unittest.TestCase): - def setUp(self): logging.disable(logging.CRITICAL) @@ -531,37 +796,55 @@ def tearDown(self): def test_fconv_self_att_wp(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_fconv_self_att_wp') as data_dir: + with tempfile.TemporaryDirectory("test_fconv_self_att_wp") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) config = [ - '--encoder-layers', '[(128, 3)] * 2', - '--decoder-layers', '[(128, 3)] * 2', - '--decoder-attention', 
'True', - '--encoder-attention', 'False', - '--gated-attention', 'True', - '--self-attention', 'True', - '--project-input', 'True', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--decoder-out-embed-dim', '8', - '--multihead-self-attention-nheads', '2' + "--encoder-layers", + "[(128, 3)] * 2", + "--decoder-layers", + "[(128, 3)] * 2", + "--decoder-attention", + "True", + "--encoder-attention", + "False", + "--gated-attention", + "True", + "--self-attention", + "True", + "--project-input", + "True", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--decoder-out-embed-dim", + "8", + "--multihead-self-attention-nheads", + "2", ] - train_translation_model(data_dir, 'fconv_self_att_wp', config) + train_translation_model(data_dir, "fconv_self_att_wp", config) generate_main(data_dir) # fusion model - os.rename(os.path.join(data_dir, 'checkpoint_last.pt'), os.path.join(data_dir, 'pretrained.pt')) - config.extend([ - '--pretrained', 'True', - '--pretrained-checkpoint', os.path.join(data_dir, 'pretrained.pt'), - '--save-dir', os.path.join(data_dir, 'fusion_model'), - ]) - train_translation_model(data_dir, 'fconv_self_att_wp', config) + os.rename( + os.path.join(data_dir, "checkpoint_last.pt"), + os.path.join(data_dir, "pretrained.pt"), + ) + config.extend( + [ + "--pretrained", + "True", + "--pretrained-checkpoint", + os.path.join(data_dir, "pretrained.pt"), + "--save-dir", + os.path.join(data_dir, "fusion_model"), + ] + ) + train_translation_model(data_dir, "fconv_self_att_wp", config) class TestLanguageModeling(unittest.TestCase): - def setUp(self): logging.disable(logging.CRITICAL) @@ -570,84 +853,134 @@ def tearDown(self): def test_fconv_lm(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir: + with tempfile.TemporaryDirectory("test_fconv_lm") as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) - train_language_model(data_dir, 'fconv_lm', [ - 
'--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]', - '--decoder-embed-dim', '280', - '--optimizer', 'nag', - '--lr', '0.1', - ]) + train_language_model( + data_dir, + "fconv_lm", + [ + "--decoder-layers", + "[(850, 3)] * 2 + [(1024,4)]", + "--decoder-embed-dim", + "280", + "--optimizer", + "nag", + "--lr", + "0.1", + ], + ) eval_lm_main(data_dir) - generate_main(data_dir, [ - '--task', 'language_modeling', - '--sample-break-mode', 'eos', - '--tokens-per-sample', '500', - ]) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) def test_transformer_lm(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir: + with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) train_language_model( - data_dir, 'transformer_lm', ['--add-bos-token'], run_validation=True, + data_dir, + "transformer_lm", + ["--add-bos-token"], + run_validation=True, ) eval_lm_main(data_dir) - generate_main(data_dir, [ - '--task', 'language_modeling', - '--sample-break-mode', 'eos', - '--tokens-per-sample', '500', - ]) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) def test_lightconv_lm(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_lightconv_lm') as data_dir: + with tempfile.TemporaryDirectory("test_lightconv_lm") as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) train_language_model( - data_dir, 'lightconv_lm', ['--add-bos-token'], run_validation=True, + data_dir, + "lightconv_lm", + ["--add-bos-token"], + run_validation=True, ) eval_lm_main(data_dir) - generate_main(data_dir, [ - '--task', 'language_modeling', - '--sample-break-mode', 'eos', - '--tokens-per-sample', '500', - ]) + generate_main( + data_dir, + [ + 
"--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) def test_lstm_lm(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_lstm_lm') as data_dir: + with tempfile.TemporaryDirectory("test_lstm_lm") as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) train_language_model( - data_dir, 'lstm_lm', ['--add-bos-token'], run_validation=True, + data_dir, + "lstm_lm", + ["--add-bos-token"], + run_validation=True, ) eval_lm_main(data_dir) - generate_main(data_dir, [ - '--task', 'language_modeling', - '--sample-break-mode', 'eos', - '--tokens-per-sample', '500', - ]) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) def test_lstm_lm_residuals(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_lstm_lm_residuals') as data_dir: + with tempfile.TemporaryDirectory("test_lstm_lm_residuals") as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) train_language_model( - data_dir, 'lstm_lm', ['--add-bos-token', '--residuals'], run_validation=True, + data_dir, + "lstm_lm", + ["--add-bos-token", "--residuals"], + run_validation=True, ) eval_lm_main(data_dir) - generate_main(data_dir, [ - '--task', 'language_modeling', - '--sample-break-mode', 'eos', - '--tokens-per-sample', '500', - ]) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) -class TestMaskedLanguageModel(unittest.TestCase): +class TestMaskedLanguageModel(unittest.TestCase): def setUp(self): logging.disable(logging.CRITICAL) @@ -666,32 +999,52 @@ def test_roberta_masked_lm(self): with tempfile.TemporaryDirectory("test_roberta_mlm") as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) - train_masked_lm(data_dir, "roberta_base", 
extra_flags=["--encoder-layers", "2"]) + train_masked_lm( + data_dir, "roberta_base", extra_flags=["--encoder-layers", "2"] + ) def test_roberta_sentence_prediction(self): num_classes = 3 with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_roberta_head") as data_dir: create_dummy_roberta_head_data(data_dir, num_classes=num_classes) - preprocess_lm_data(os.path.join(data_dir, 'input0')) - preprocess_lm_data(os.path.join(data_dir, 'label')) + preprocess_lm_data(os.path.join(data_dir, "input0")) + preprocess_lm_data(os.path.join(data_dir, "label")) train_roberta_head(data_dir, "roberta_base", num_classes=num_classes) def test_roberta_regression_single(self): num_classes = 1 with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_roberta_regression_single") as data_dir: - create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True) - preprocess_lm_data(os.path.join(data_dir, 'input0')) - train_roberta_head(data_dir, "roberta_base", num_classes=num_classes, extra_flags=['--regression-target']) + with tempfile.TemporaryDirectory( + "test_roberta_regression_single" + ) as data_dir: + create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) + train_roberta_head( + data_dir, + "roberta_base", + num_classes=num_classes, + extra_flags=["--regression-target"], + ) def test_roberta_regression_multiple(self): num_classes = 3 with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_roberta_regression_multiple") as data_dir: - create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True) - preprocess_lm_data(os.path.join(data_dir, 'input0')) - train_roberta_head(data_dir, "roberta_base", num_classes=num_classes, extra_flags=['--regression-target']) + with tempfile.TemporaryDirectory( + "test_roberta_regression_multiple" + ) as data_dir: + 
create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) + train_roberta_head( + data_dir, + "roberta_base", + num_classes=num_classes, + extra_flags=["--regression-target"], + ) def test_linformer_roberta_masked_lm(self): with contextlib.redirect_stdout(StringIO()): @@ -702,8 +1055,10 @@ def test_linformer_roberta_masked_lm(self): data_dir, "linformer_roberta_base", extra_flags=[ - "--user-dir", "examples/linformer/src", - "--encoder-layers", "2", + "--user-dir", + "examples/linformer/src", + "--encoder-layers", + "2", ], ) @@ -712,8 +1067,8 @@ def test_linformer_roberta_sentence_prediction(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_linformer_roberta_head") as data_dir: create_dummy_roberta_head_data(data_dir, num_classes=num_classes) - preprocess_lm_data(os.path.join(data_dir, 'input0')) - preprocess_lm_data(os.path.join(data_dir, 'label')) + preprocess_lm_data(os.path.join(data_dir, "input0")) + preprocess_lm_data(os.path.join(data_dir, "label")) train_roberta_head( data_dir, "linformer_roberta_base", @@ -724,27 +1079,43 @@ def test_linformer_roberta_sentence_prediction(self): def test_linformer_roberta_regression_single(self): num_classes = 1 with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_linformer_roberta_regression_single") as data_dir: - create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True) - preprocess_lm_data(os.path.join(data_dir, 'input0')) + with tempfile.TemporaryDirectory( + "test_linformer_roberta_regression_single" + ) as data_dir: + create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) train_roberta_head( data_dir, "linformer_roberta_base", num_classes=num_classes, - extra_flags=["--regression-target", "--user-dir", "examples/linformer/src"], + extra_flags=[ + 
"--regression-target", + "--user-dir", + "examples/linformer/src", + ], ) def test_linformer_roberta_regression_multiple(self): num_classes = 3 with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_linformer_roberta_regression_multiple") as data_dir: - create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True) - preprocess_lm_data(os.path.join(data_dir, 'input0')) + with tempfile.TemporaryDirectory( + "test_linformer_roberta_regression_multiple" + ) as data_dir: + create_dummy_roberta_head_data( + data_dir, num_classes=num_classes, regression=True + ) + preprocess_lm_data(os.path.join(data_dir, "input0")) train_roberta_head( data_dir, "linformer_roberta_base", num_classes=num_classes, - extra_flags=["--regression-target", "--user-dir", "examples/linformer/src"], + extra_flags=[ + "--regression-target", + "--user-dir", + "examples/linformer/src", + ], ) def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only): @@ -755,7 +1126,7 @@ def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_on train_legacy_masked_language_model( data_dir, arch="masked_lm", - extra_args=('--encoder-learned-pos',) if learned_pos_emb else () + extra_args=("--encoder-learned-pos",) if learned_pos_emb else (), ) with tempfile.TemporaryDirectory( "test_mlm_translation" @@ -793,10 +1164,13 @@ def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_on "500", "--max-target-positions", "500", - ] + ( + ] + + ( ["--encoder-learned-pos", "--decoder-learned-pos"] - if learned_pos_emb else [] - ) + (['--init-encoder-only'] if encoder_only else []), + if learned_pos_emb + else [] + ) + + (["--init-encoder-only"] if encoder_only else []), task="translation_from_pretrained_xlm", ) @@ -814,8 +1188,8 @@ def test_r4f_roberta(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_r4f_roberta_head") as data_dir: 
create_dummy_roberta_head_data(data_dir, num_classes=num_classes) - preprocess_lm_data(os.path.join(data_dir, 'input0')) - preprocess_lm_data(os.path.join(data_dir, 'label')) + preprocess_lm_data(os.path.join(data_dir, "input0")) + preprocess_lm_data(os.path.join(data_dir, "label")) train_roberta_head( data_dir, "roberta_base", @@ -824,8 +1198,8 @@ def test_r4f_roberta(self): "--user-dir", "examples/rxf/src", "--criterion", - 'sentence_prediction_r3f', - '--spectral-norm-classification-head', + "sentence_prediction_r3f", + "--spectral-norm-classification-head", ], ) @@ -890,13 +1264,13 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()): "raw", "--num-workers", "0", - ] + list(extra_args), + ] + + list(extra_args), ) train.main(train_args) class TestOptimizers(unittest.TestCase): - def setUp(self): logging.disable(logging.CRITICAL) @@ -905,27 +1279,39 @@ def tearDown(self): def test_optimizers(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_optimizers') as data_dir: + with tempfile.TemporaryDirectory("test_optimizers") as data_dir: # Use just a bit of data and tiny model to keep this test runtime reasonable create_dummy_data(data_dir, num_examples=10, maxlen=5) preprocess_translation_data(data_dir) - optimizers = ['adafactor', 'adam', 'nag', 'adagrad', 'sgd', 'adadelta'] - last_checkpoint = os.path.join(data_dir, 'checkpoint_last.pt') + optimizers = ["adafactor", "adam", "nag", "adagrad", "sgd", "adadelta"] + last_checkpoint = os.path.join(data_dir, "checkpoint_last.pt") for optimizer in optimizers: if os.path.exists(last_checkpoint): os.remove(last_checkpoint) - train_translation_model(data_dir, 'lstm', [ - '--required-batch-size-multiple', '1', - '--encoder-layers', '1', - '--encoder-hidden-size', '32', - '--decoder-layers', '1', - '--optimizer', optimizer, - ]) + train_translation_model( + data_dir, + "lstm", + [ + "--required-batch-size-multiple", + "1", + "--encoder-layers", + "1", + 
"--encoder-hidden-size", + "32", + "--decoder-layers", + "1", + "--optimizer", + optimizer, + ], + ) generate_main(data_dir) -def create_dummy_roberta_head_data(data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False): - input_dir = 'input0' +def create_dummy_roberta_head_data( + data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False +): + input_dir = "input0" + def _create_dummy_data(filename): random_data = torch.rand(num_examples * maxlen) input_data = 97 + torch.floor(26 * random_data).int() @@ -933,29 +1319,29 @@ def _create_dummy_data(filename): output_data = torch.rand((num_examples, num_classes)) else: output_data = 1 + torch.floor(num_classes * torch.rand(num_examples)).int() - with open(os.path.join(data_dir, input_dir, filename+'.out'), 'w') as f_in: - label_filename = filename+'.label' if regression else filename+'.out' - with open(os.path.join(data_dir, 'label', label_filename), 'w') as f_out: + with open(os.path.join(data_dir, input_dir, filename + ".out"), "w") as f_in: + label_filename = filename + ".label" if regression else filename + ".out" + with open(os.path.join(data_dir, "label", label_filename), "w") as f_out: offset = 0 for i in range(num_examples): # write example input ex_len = random.randint(1, maxlen) - ex_str = ' '.join(map(chr, input_data[offset:offset+ex_len])) + ex_str = " ".join(map(chr, input_data[offset : offset + ex_len])) print(ex_str, file=f_in) # write example label if regression: - class_str = ' '.join(map(str, output_data[i].numpy())) + class_str = " ".join(map(str, output_data[i].numpy())) print(class_str, file=f_out) else: - class_str = 'class{}'.format(output_data[i]) + class_str = "class{}".format(output_data[i]) print(class_str, file=f_out) offset += ex_len os.mkdir(os.path.join(data_dir, input_dir)) - os.mkdir(os.path.join(data_dir, 'label')) - _create_dummy_data('train') - _create_dummy_data('valid') - _create_dummy_data('test') + os.mkdir(os.path.join(data_dir, "label")) + 
_create_dummy_data("train") + _create_dummy_data("valid") + _create_dummy_data("test") def train_masked_lm(data_dir, arch, extra_flags=None): @@ -963,20 +1349,32 @@ def train_masked_lm(data_dir, arch, extra_flags=None): train_args = options.parse_args_and_arch( train_parser, [ - '--task', 'masked_lm', + "--task", + "masked_lm", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "masked_lm", + "--batch-size", + "500", + "--save-dir", data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'masked_lm', - '--batch-size', '500', - '--save-dir', data_dir, - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', '0', - ] + (extra_flags or []), + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + "0", + ] + + (extra_flags or []), ) train.main(train_args) @@ -986,24 +1384,40 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): train_args = options.parse_args_and_arch( train_parser, [ - '--task', 'sentence_prediction', + "--task", + "sentence_prediction", data_dir, - '--arch', arch, - '--encoder-layers', '2', - '--num-classes', str(num_classes), - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'sentence_prediction', - '--max-tokens', '500', - '--max-positions', '500', - '--batch-size', '500', - '--save-dir', data_dir, - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', '0', - ] + (extra_flags or []), + "--arch", + arch, + "--encoder-layers", + "2", + "--num-classes", + str(num_classes), + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "sentence_prediction", + "--max-tokens", + "500", + "--max-positions", + "500", + "--batch-size", + "500", + "--save-dir", + data_dir, + "--max-epoch", + "1", + 
"--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + "0", + ] + + (extra_flags or []), ) train.main(train_args) @@ -1013,22 +1427,36 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) train_args = options.parse_args_and_arch( train_parser, [ - '--task', 'language_modeling', + "--task", + "language_modeling", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'adaptive_loss', - '--adaptive-softmax-cutoff', '5,10,15', - '--max-tokens', '500', - '--tokens-per-sample', '500', - '--save-dir', data_dir, - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', '0', - ] + (extra_flags or []), + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + "0", + ] + + (extra_flags or []), ) train.main(train_args) @@ -1038,14 +1466,19 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) validate_args = options.parse_args_and_arch( validate_parser, [ - '--task', 'language_modeling', + "--task", + "language_modeling", data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - '--valid-subset', 'valid', - '--max-tokens', '500', - '--no-progress-bar', - '--num-workers', '0', - ] + "--path", + os.path.join(data_dir, "checkpoint_last.pt"), + "--valid-subset", + "valid", + "--max-tokens", + "500", + "--no-progress-bar", + "--num-workers", + "0", + ], ) validate.main(validate_args) @@ -1056,9 +1489,11 @@ def eval_lm_main(data_dir): eval_lm_parser, [ data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - 
'--no-progress-bar', - '--num-workers', '0', + "--path", + os.path.join(data_dir, "checkpoint_last.pt"), + "--no-progress-bar", + "--num-workers", + "0", ], ) eval_lm.main(eval_lm_args) @@ -1124,10 +1559,11 @@ def train_masked_language_model(data_dir, arch, extra_args=()): "raw", "--num-workers", "0", - ] + list(extra_args), + ] + + list(extra_args), ) train.main(train_args) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_bmuf.py b/tests/test_bmuf.py index 30563bdb50..0165b2955b 100644 --- a/tests/test_bmuf.py +++ b/tests/test_bmuf.py @@ -4,13 +4,12 @@ # LICENSE file in the root directory of this source tree. import argparse -from multiprocessing import Manager import random import unittest +from multiprocessing import Manager import torch import torch.nn as nn - from fairseq import distributed_utils, optim @@ -169,5 +168,5 @@ def assertAlmostEqual(self, t1, t2): self.assertLess((t1 - t2).abs().max(), 1e-4) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_character_token_embedder.py b/tests/test_character_token_embedder.py index 81042c2a3f..24940ebd21 100644 --- a/tests/test_character_token_embedder.py +++ b/tests/test_character_token_embedder.py @@ -3,9 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch
 import unittest

+import torch
 from fairseq.data import Dictionary
 from fairseq.modules import CharacterTokenEmbedder

@@ -13,12 +13,14 @@
 class TestCharacterTokenEmbedder(unittest.TestCase):
     def test_character_token_embedder(self):
         vocab = Dictionary()
-        vocab.add_symbol('hello')
-        vocab.add_symbol('there')
+        vocab.add_symbol("hello")
+        vocab.add_symbol("there")

-        embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)
+        embedder = CharacterTokenEmbedder(
+            vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2
+        )

-        test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
+        test_sents = [["hello", "unk", "there"], ["there"], ["hello", "there"]]
         max_len = max(len(s) for s in test_sents)
         input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
         for i in range(len(test_sents)):
@@ -42,5 +44,5 @@ def assertAlmostEqual(self, t1, t2):
         self.assertLess((t1 - t2).abs().max(), 1e-6)


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_concat_dataset.py b/tests/test_concat_dataset.py
index dbdb2ac518..d94aeffd48 100644
--- a/tests/test_concat_dataset.py
+++ b/tests/test_concat_dataset.py
@@ -40,25 +40,19 @@ def setUp(self):
         )

     def test_concat_dataset_basics(self):
-        d = ConcatDataset(
-            [self.dataset_1, self.dataset_2]
-        )
-        assert(len(d) == 2)
-        assert(d[0]['source'][0] == 1)
-        assert(d[1]['source'][0] == 2)
+        d = ConcatDataset([self.dataset_1, self.dataset_2])
+        assert len(d) == 2
+        assert d[0]["source"][0] == 1
+        assert d[1]["source"][0] == 2

-        d = ConcatDataset(
-            [self.dataset_1, self.dataset_2], sample_ratios=[1, 2]
-        )
-        assert(len(d) == 3)
-        assert(d[0]['source'][0] == 1)
-        assert(d[1]['source'][0] == 2)
-        assert(d[2]['source'][0] == 2)
+        d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[1, 2])
+        assert len(d) == 3
+        assert d[0]["source"][0] == 1
+        assert d[1]["source"][0] == 2
+        assert d[2]["source"][0] == 2

-        d = ConcatDataset(
-            [self.dataset_1, self.dataset_2], sample_ratios=[2, 1]
-        )
-        assert(len(d) == 3)
-        assert(d[0]['source'][0] == 1)
-        assert(d[1]['source'][0] == 1)
-        assert(d[2]['source'][0] == 2)
+        d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[2, 1])
+        assert len(d) == 3
+        assert d[0]["source"][0] == 1
+        assert d[1]["source"][0] == 1
+        assert d[2]["source"][0] == 2
diff --git a/tests/test_constraints.py b/tests/test_constraints.py
index 3f63c8ace5..1c37f7e1fb 100755
--- a/tests/test_constraints.py
+++ b/tests/test_constraints.py
@@ -4,9 +4,9 @@
 # LICENSE file in the root directory of this source tree.

 import sys
-import torch
 import unittest

+import torch
 from fairseq.token_generation_constraints import *


@@ -17,26 +17,27 @@ def tensorize(constraints: List[List[int]]) -> torch.Tensor:
 class TestHelperRoutines(unittest.TestCase):
     def setUp(self):
         self.examples = [
([4].False#2 ([5].True#2 ([6].False#1 [7].True#1))))", - { 1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1 } + {1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1}, + ), + ([], "[None].False#0", {}), + (tensorize([[0]]), "([None].False#1 [0].True#1)", {0: 1}), + ( + tensorize([[100000, 1, 2, 3, 4, 5]]), + "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))", + {100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, ), - ( [], "[None].False#0", {} ), - ( tensorize([[0]]), "([None].False#1 [0].True#1)", { 0: 1 } ), - ( tensorize([[100000, 1, 2, 3, 4, 5]]), "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))", { 100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1 } ), ( tensorize([[1, 2], [1, 2]]), "([None].False#2 ([1].False#2 [2].True#2))", - { 1: 2, 2: 2 }, + {1: 2, 2: 2}, ), ( tensorize([[1, 2], [3, 4]]), "([None].False#2 ([1].False#1 [2].True#1) ([3].False#1 [4].True#1))", - { 1: 1, 2: 1, 3: 1, 4: 1}, + {1: 1, 2: 1, 3: 1, 4: 1}, ), ] @@ -74,65 +79,65 @@ def setUp(self): ( self.examples[0][0], [], - { "bank": 0, "num_completed": 0, "finished": False, "is_root": True }, + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, ), ( self.examples[0][0], [1, 2], - { "bank": 2, "num_completed": 0, "finished": False, "is_root": False }, + {"bank": 2, "num_completed": 0, "finished": False, "is_root": False}, ), ( self.examples[0][0], [1, 2, 94], - { "bank": 1, "num_completed": 1, "finished": False, "is_root": True }, + {"bank": 1, "num_completed": 1, "finished": False, "is_root": True}, ), ( self.examples[0][0], [1, 3, 999, 1, 4], - { "bank": 4, "num_completed": 2, "finished": False, "is_root": False }, + {"bank": 4, "num_completed": 2, "finished": False, "is_root": False}, ), ( self.examples[0][0], [1, 3, 999, 1, 4, 999], - { "bank": 4, "num_completed": 2, "finished": False, "is_root": True }, + {"bank": 4, "num_completed": 2, "finished": False, "is_root": True}, ), ( self.examples[0][0], [4, 5, 
6, 8], - { "bank": 2, "num_completed": 1, "finished": False, "is_root": True }, + {"bank": 2, "num_completed": 1, "finished": False, "is_root": True}, ), ( self.examples[0][0], # Tricky, because in last three, goes down [1->4] branch, could miss [1] and [4->5] # [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]], [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5], - { "bank": 14, "num_completed": 6, "finished": True, "is_root": False }, + {"bank": 14, "num_completed": 6, "finished": True, "is_root": False}, ), ( self.examples[0][0], [1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117], - { "bank": 14, "num_completed": 6, "finished": True, "is_root": True }, + {"bank": 14, "num_completed": 6, "finished": True, "is_root": True}, ), ( tensorize([[1], [2, 3]]), # Should not be able to get credit for entering 1 a second time [1, 1], - { "bank": 1, "num_completed": 1, "finished": False, "is_root": True }, + {"bank": 1, "num_completed": 1, "finished": False, "is_root": True}, ), ( self.examples[4][0], [1, 2, 1, 2], - { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + {"bank": 4, "num_completed": 2, "finished": True, "is_root": False}, ), ( self.examples[4][0], [1, 2, 1, 2, 1], - { "bank": 4, "num_completed": 2, "finished": True, "is_root": True }, + {"bank": 4, "num_completed": 2, "finished": True, "is_root": True}, ), ( self.examples[5][0], [1, 2, 3, 4, 5], - { "bank": 4, "num_completed": 2, "finished": True, "is_root": True }, + {"bank": 4, "num_completed": 2, "finished": True, "is_root": True}, ), ] @@ -143,8 +148,12 @@ def test_graphs(self): for example in self.examples: constraints, expected, gold_counts = example c = ConstraintNode.create(constraints) - assert ConstraintNode.print_graph(c) == expected, f"got {ConstraintNode.print_graph(c)}, expected {expected}" - assert c.token_counts() == gold_counts, f"{c} got {c.token_counts()} wanted {gold_counts}" + assert ( + ConstraintNode.print_graph(c) == expected + ), f"got 
{ConstraintNode.print_graph(c)}, expected {expected}" + assert ( + c.token_counts() == gold_counts + ), f"{c} got {c.token_counts()} wanted {gold_counts}" def test_next_tokens(self): """ @@ -159,7 +168,9 @@ def test_next_tokens(self): state = UnorderedConstraintState(root) for token in sequence: all_tokens = root_tokens.union(state.node.children.keys()) - assert all_tokens == state.next_tokens(), f"ALL {all_tokens} NEXT {state.next_tokens()}" + assert ( + all_tokens == state.next_tokens() + ), f"ALL {all_tokens} NEXT {state.next_tokens()}" state = state.advance(token) def test_sequences(self): @@ -171,7 +182,9 @@ def test_sequences(self): for attr in expected.keys(): result[attr] = getattr(state, attr) - assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}" + assert ( + result == expected + ), f"TEST({tokens}) GOT: {result} WANTED: {expected}" class TestOrderedConstraintState(unittest.TestCase): @@ -180,62 +193,62 @@ def setUp(self): ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [], - { "bank": 0, "num_completed": 0, "finished": False, "is_root": True }, + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, ), ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [1, 2], - { "bank": 2, "num_completed": 0, "finished": False, "is_root": False }, + {"bank": 2, "num_completed": 0, "finished": False, "is_root": False}, ), ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [1, 2, 94], - { "bank": 0, "num_completed": 0, "finished": False, "is_root": True }, + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, ), ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [1, 3, 999, 1, 4], - { "bank": 0, "num_completed": 0, "finished": False, "is_root": True }, + {"bank": 0, "num_completed": 0, "finished": False, "is_root": True}, ), ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [1, 2, 3, 999, 999], - { "bank": 3, 
"num_completed": 1, "finished": False, "is_root": False }, + {"bank": 3, "num_completed": 1, "finished": False, "is_root": False}, ), ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [1, 2, 3, 77, 1, 3, 1], - { "bank": 6, "num_completed": 2, "finished": False, "is_root": False }, + {"bank": 6, "num_completed": 2, "finished": False, "is_root": False}, ), ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5], - { "bank": 14, "num_completed": 6, "finished": True, "is_root": False }, + {"bank": 14, "num_completed": 6, "finished": True, "is_root": False}, ), ( tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]), [1, 2, 999, 1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117], - { "bank": 14, "num_completed": 6, "finished": True, "is_root": False }, + {"bank": 14, "num_completed": 6, "finished": True, "is_root": False}, ), ( tensorize([[1], [2, 3]]), [1, 1], - { "bank": 1, "num_completed": 1, "finished": False, "is_root": False }, + {"bank": 1, "num_completed": 1, "finished": False, "is_root": False}, ), ( tensorize([[1, 2], [1, 2]]), [1, 2, 1, 2], - { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + {"bank": 4, "num_completed": 2, "finished": True, "is_root": False}, ), ( tensorize([[1, 2], [1, 2]]), [1, 2, 1, 2, 1], - { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + {"bank": 4, "num_completed": 2, "finished": True, "is_root": False}, ), ( tensorize([[1, 2], [3, 4]]), [1, 2, 3, 4, 5], - { "bank": 4, "num_completed": 2, "finished": True, "is_root": False }, + {"bank": 4, "num_completed": 2, "finished": True, "is_root": False}, ), ] @@ -247,8 +260,10 @@ def test_sequences(self): result = {} for attr in expected.keys(): result[attr] = getattr(state, attr) - assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}" + assert ( + result == expected + ), f"TEST({tokens}) GOT: {result} WANTED: {expected}" + if __name__ == 
"__main__": unittest.main() - diff --git a/tests/test_convtbc.py b/tests/test_convtbc.py index fc2ac0b5dc..3a3c9b91e7 100644 --- a/tests/test_convtbc.py +++ b/tests/test_convtbc.py @@ -3,14 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch import unittest -from fairseq.modules import ConvTBC + +import torch import torch.nn as nn +from fairseq.modules import ConvTBC class TestConvTBC(unittest.TestCase): - def test_convtbc(self): # ksz, in_channels, out_channels conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1) @@ -27,7 +27,9 @@ def test_convtbc(self): output_tbc = conv_tbc(input_tbc) output1d = conv1d(input1d) - self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data) + self.assertAlmostEqual( + output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data + ) grad_tbc = torch.randn(output_tbc.size()) grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous() @@ -35,14 +37,18 @@ def test_convtbc(self): output_tbc.backward(grad_tbc) output1d.backward(grad1d) - self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data) + self.assertAlmostEqual( + conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data + ) self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data) - self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data) + self.assertAlmostEqual( + input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data + ) def assertAlmostEqual(self, t1, t2): self.assertEqual(t1.size(), t2.size(), "size mismatch") self.assertLess((t1 - t2).abs().max(), 1e-4) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index d9a1ec72c8..81ce102f4f 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -8,31 +8,39 @@ import unittest import torch - from 
fairseq.data import Dictionary class TestDictionary(unittest.TestCase): - def test_finalize(self): txt = [ - 'A B C D', - 'B C D', - 'C D', - 'D', + "A B C D", + "B C D", + "C D", + "D", ] - ref_ids1 = list(map(torch.IntTensor, [ - [4, 5, 6, 7, 2], - [5, 6, 7, 2], - [6, 7, 2], - [7, 2], - ])) - ref_ids2 = list(map(torch.IntTensor, [ - [7, 6, 5, 4, 2], - [6, 5, 4, 2], - [5, 4, 2], - [4, 2], - ])) + ref_ids1 = list( + map( + torch.IntTensor, + [ + [4, 5, 6, 7, 2], + [5, 6, 7, 2], + [6, 7, 2], + [7, 2], + ], + ) + ) + ref_ids2 = list( + map( + torch.IntTensor, + [ + [7, 6, 5, 4, 2], + [6, 5, 4, 2], + [5, 4, 2], + [4, 2], + ], + ) + ) # build dictionary d = Dictionary() @@ -59,7 +67,7 @@ def assertMatch(ids, ref_ids): assertMatch(finalized_ids, ref_ids2) # write to disk and reload - with tempfile.NamedTemporaryFile(mode='w') as tmp_dict: + with tempfile.NamedTemporaryFile(mode="w") as tmp_dict: d.save(tmp_dict.name) d = Dictionary.load(tmp_dict.name) reload_ids = get_ids(d) @@ -77,40 +85,32 @@ def test_overwrite(self): ) d = Dictionary() d.add_from_file(dict_file) - self.assertEqual(d.index('<pad>'), 1) - self.assertEqual(d.index('foo'), 3) - self.assertEqual(d.index('<unk>'), 4) - self.assertEqual(d.index('<s>'), 5) - self.assertEqual(d.index('</s>'), 6) - self.assertEqual(d.index(','), 7) - self.assertEqual(d.index('▁de'), 8) + self.assertEqual(d.index("<pad>"), 1) + self.assertEqual(d.index("foo"), 3) + self.assertEqual(d.index("<unk>"), 4) + self.assertEqual(d.index("<s>"), 5) + self.assertEqual(d.index("</s>"), 6) + self.assertEqual(d.index(","), 7) + self.assertEqual(d.index("▁de"), 8) def test_no_overwrite(self): # for example, Camembert overwrites <unk>, <s> and </s> dict_file = io.StringIO( - "<unk> 999\n" - "<s> 999\n" - "</s> 999\n" - ", 999\n" - "▁de 999\n" + "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n" ) d = Dictionary() - with self.assertRaisesRegex(RuntimeError, 'Duplicate'): + with self.assertRaisesRegex(RuntimeError, "Duplicate"): d.add_from_file(dict_file) def test_space(self): # for example, character
models treat space as a symbol - dict_file = io.StringIO( - "  999\n" - "a 999\n" - "b 999\n" - ) + dict_file = io.StringIO("  999\n" "a 999\n" "b 999\n") d = Dictionary() d.add_from_file(dict_file) - self.assertEqual(d.index(' '), 4) - self.assertEqual(d.index('a'), 5) - self.assertEqual(d.index('b'), 6) + self.assertEqual(d.index(" "), 4) + self.assertEqual(d.index("a"), 5) + self.assertEqual(d.index("b"), 6) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_file_io.py b/tests/test_file_io.py index ffcc6a3eef..aef5b80d18 100644 --- a/tests/test_file_io.py +++ b/tests/test_file_io.py @@ -1,14 +1,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import sys -import tempfile import os import shutil - -from typing import Optional - +import sys +import tempfile import unittest +from typing import Optional from unittest.mock import MagicMock @@ -34,14 +32,16 @@ def tearDownClass(cls) -> None: def test_file_io(self): from fairseq.file_io import PathManager + with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: s = f.read() self.assertEqual(s, self._tmpfile_contents) def test_file_io_oss(self): # Mock fvcore to simulate oss environment.
- sys.modules['fvcore'] = MagicMock() + sys.modules["fvcore"] = MagicMock() from fairseq.file_io import PathManager + with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: s = f.read() self.assertEqual(s, self._tmpfile_contents) diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py index bca341af1a..c4195273e3 100644 --- a/tests/test_fp16_optimizer.py +++ b/tests/test_fp16_optimizer.py @@ -8,13 +8,11 @@ import unittest import torch - from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer -@unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") class TestGradientScaling(unittest.TestCase): - def setUp(self): self.x = torch.tensor([2.0]).cuda().half() weight = 3.0 @@ -30,16 +28,16 @@ def setUp(self): self.params = list(self.model.parameters()) self.namespace_dls = argparse.Namespace( - optimizer='adam', + optimizer="adam", lr=[0.1], - adam_betas='(0.9, 0.999)', + adam_betas="(0.9, 0.999)", adam_eps=1e-8, weight_decay=0.0, fp16_init_scale=1, fp16_scale_window=1, fp16_scale_tolerance=1, threshold_loss_scale=1, - min_loss_scale=1e-4 + min_loss_scale=1e-4, ) def run_iter(self, model, params, optimizer): @@ -47,15 +45,25 @@ def run_iter(self, model, params, optimizer): y = model(self.x) loss = self.loss_fn(y, self.target) optimizer.backward(loss) - self.assertEqual(loss, torch.tensor(1., device='cuda:0', dtype=torch.float16)) + self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16)) grad_norm = optimizer.clip_grad_norm(0) self.assertAlmostEqual(grad_norm.item(), 2.2361, 4) optimizer.step() - self.assertEqual(model.weight, torch.tensor([[3.0996]], device='cuda:0', dtype=torch.float16, requires_grad=True)) - self.assertEqual(model.bias, torch.tensor([5.1016], device='cuda:0', dtype=torch.float16, requires_grad=True)) - self.assertEqual(optimizer.scaler.loss_scale, 2.) 
+ self.assertEqual( + model.weight, + torch.tensor( + [[3.0996]], device="cuda:0", dtype=torch.float16, requires_grad=True + ), + ) + self.assertEqual( + model.bias, + torch.tensor( + [5.1016], device="cuda:0", dtype=torch.float16, requires_grad=True + ), + ) + self.assertEqual(optimizer.scaler.loss_scale, 2.0) def test_mixed_precision(self): model = copy.deepcopy(self.model) @@ -63,18 +71,28 @@ def test_mixed_precision(self): optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params) self.run_iter(model, params, optimizer) - self.assertTrue(all( - torch.all(fp32_params.eq(torch.tensor([3.1000, 5.1000], device='cuda:0', requires_grad=True))) - for fp32_params in optimizer.fp32_params.values() - )) + self.assertTrue( + all( + torch.all( + fp32_params.eq( + torch.tensor( + [3.1000, 5.1000], device="cuda:0", requires_grad=True + ) + ) + ) + for fp32_params in optimizer.fp32_params.values() + ) + ) def test_memory_efficient(self): model = copy.deepcopy(self.model) params = list(model.parameters()) - optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.namespace_dls, params) + optimizer = MemoryEfficientFP16Optimizer.build_optimizer( + self.namespace_dls, params + ) self.run_iter(model, params, optimizer) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py index 4857bc7a87..fd5edd43d6 100644 --- a/tests/test_inference_dropout.py +++ b/tests/test_inference_dropout.py @@ -6,12 +6,11 @@ import logging import unittest -from tests.test_sequence_generator import get_dummy_task_and_parser from fairseq.models.transformer import TransformerModel +from tests.test_sequence_generator import get_dummy_task_and_parser class TestInferenceDropout(unittest.TestCase): - def setUp(self): self.task, self.parser = get_dummy_task_and_parser() TransformerModel.add_args(self.parser) @@ -55,7 +54,10 @@ def test_applies_training_mode(self): def test_retain_modules(self): 
self.args.retain_dropout = True - self.args.retain_dropout_modules = ['TransformerEncoder', 'TransformerEncoderLayer'] + self.args.retain_dropout_modules = [ + "TransformerEncoder", + "TransformerEncoderLayer", + ] self.transformer_model = TransformerModel.build_model(self.args, self.task) self.transformer_model.prepare_for_inference_(self.args) assert self.transformer_model.encoder.dropout_module.apply_during_inference diff --git a/tests/test_iterators.py b/tests/test_iterators.py index 9e444d154b..3d2c4d6251 100644 --- a/tests/test_iterators.py +++ b/tests/test_iterators.py @@ -9,7 +9,6 @@ class TestIterators(unittest.TestCase): - def test_counting_iterator(self, ref=None, itr=None): if ref is None: assert itr is None @@ -109,7 +108,7 @@ def test_counting_iterator_buffered_iterator_take(self): self.assertFalse(itr.has_next()) self.assertRaises(StopIteration, next, buffered_itr) - ref = list(range(4,10)) + ref = list(range(4, 10)) buffered_itr = iterators.BufferedIterator(2, ref) itr = iterators.CountingIterator(buffered_itr, start=4) itr.take(5) @@ -120,5 +119,5 @@ def test_counting_iterator_buffered_iterator_take(self): self.assertRaises(StopIteration, next, buffered_itr) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_label_smoothing.py b/tests/test_label_smoothing.py index 94e5ccf1f3..04c0f974ac 100644 --- a/tests/test_label_smoothing.py +++ b/tests/test_label_smoothing.py @@ -7,16 +7,15 @@ import copy import unittest +import tests.utils as test_utils import torch - from fairseq.criterions.cross_entropy import CrossEntropyCriterion -from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion - -import tests.utils as test_utils +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, +) class TestLabelSmoothing(unittest.TestCase): - def setUp(self): # build dictionary self.d = test_utils.dummy_dictionary(3) @@ -30,8 +29,14 @@ def 
setUp(self): # build dataset self.data = [ # the first batch item has padding - {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])}, - {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])}, + { + "source": torch.LongTensor([w1, eos]), + "target": torch.LongTensor([w1, eos]), + }, + { + "source": torch.LongTensor([w1, eos]), + "target": torch.LongTensor([w1, w1, eos]), + }, ] self.sample = next(test_utils.dummy_dataloader(self.data)) @@ -39,23 +44,35 @@ def setUp(self): self.args = argparse.Namespace() self.args.sentence_avg = False self.args.report_accuracy = False - self.args.probs = torch.FloatTensor([ - # pad eos unk w1 w2 w3 - [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05], - [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10], - [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15], - ]).unsqueeze(0).expand(2, 3, 7) # add batch dimension + self.args.probs = ( + torch.FloatTensor( + [ + # pad eos unk w1 w2 w3 + [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05], + [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10], + [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15], + ] + ) + .unsqueeze(0) + .expand(2, 3, 7) + ) # add batch dimension self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d) self.model = self.task.build_model(self.args) def test_nll_loss(self): self.args.label_smoothing = 0.1 nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task) - smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task) - nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) - smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) - self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6) - self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6) + smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion( + self.args, self.task + ) + nll_loss, nll_sample_size, nll_logging_output = nll_crit( + self.model, self.sample + ) 
+ smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit( + self.model, self.sample + ) + self.assertLess(abs(nll_loss - nll_logging_output["loss"]), 1e-6) + self.assertLess(abs(nll_loss - smooth_logging_output["nll_loss"]), 1e-6) def test_padding(self): self.args.label_smoothing = 0.1 @@ -86,9 +103,15 @@ def test_reduction(self): def test_zero_eps(self): self.args.label_smoothing = 0.0 nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task) - smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task) - nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) - smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) + smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion( + self.args, self.task + ) + nll_loss, nll_sample_size, nll_logging_output = nll_crit( + self.model, self.sample + ) + smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit( + self.model, self.sample + ) self.assertAlmostEqual(nll_loss, smooth_loss) def assertAlmostEqual(self, t1, t2): @@ -96,5 +119,5 @@ def assertAlmostEqual(self, t1, t2): self.assertLess((t1 - t2).abs().max(), 1e-6) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_lstm_jitable.py b/tests/test_lstm_jitable.py index d97652fb77..38f79d1793 100644 --- a/tests/test_lstm_jitable.py +++ b/tests/test_lstm_jitable.py @@ -92,19 +92,21 @@ def test_assert_jit_vs_nonjit_(self): idx = len(task.source_dictionary) iter = 100 # Inject random input and check output - seq_len_tensor = torch.randint(1, 10, (iter, )) - num_samples_tensor = torch.randint(1, 10, (iter, )) + seq_len_tensor = torch.randint(1, 10, (iter,)) + num_samples_tensor = torch.randint(1, 10, (iter,)) for i in range(iter): seq_len = seq_len_tensor[i] num_samples = num_samples_tensor[i] - src_token = torch.randint(0, idx, (num_samples, seq_len)), - src_lengths = torch.randint(1, seq_len+1, 
(num_samples,)) + src_token = (torch.randint(0, idx, (num_samples, seq_len)),) + src_lengths = torch.randint(1, seq_len + 1, (num_samples,)) src_lengths, _ = torch.sort(src_lengths, descending=True) # Force the first sample to have seq_len src_lengths[0] = seq_len - prev_output_token = torch.randint(0, idx, (num_samples, 1)), + prev_output_token = (torch.randint(0, idx, (num_samples, 1)),) result = model(src_token[0], src_lengths, prev_output_token[0], None) - scripted_result = scripted_model(src_token[0], src_lengths, prev_output_token[0], None) + scripted_result = scripted_model( + src_token[0], src_lengths, prev_output_token[0], None + ) self.assertTensorEqual(result[0], scripted_result[0]) self.assertTensorEqual(result[1], scripted_result[1]) diff --git a/tests/test_memory_efficient_fp16.py b/tests/test_memory_efficient_fp16.py index bd2b8faeb4..e10636d96a 100644 --- a/tests/test_memory_efficient_fp16.py +++ b/tests/test_memory_efficient_fp16.py @@ -8,14 +8,12 @@ import unittest import torch - from fairseq.optim.adam import FairseqAdam from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer -@unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") class TestMemoryEfficientFP16(unittest.TestCase): - def setUp(self): logging.disable(logging.CRITICAL) @@ -31,7 +29,7 @@ def test_load_state_dict(self): optimizer = FairseqAdam( argparse.Namespace( lr=[0.00001], - adam_betas='(0.9, 0.999)', + adam_betas="(0.9, 0.999)", adam_eps=1e-8, weight_decay=0.0, ), @@ -64,5 +62,5 @@ def test_load_state_dict(self): self.assertTrue(v_i.dtype == torch.float32) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 060291808e..2de6969cf4 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -10,69 +10,68 @@ class TestMetrics(unittest.TestCase): - def test_nesting(self): with 
metrics.aggregate() as a: - metrics.log_scalar('loss', 1) + metrics.log_scalar("loss", 1) with metrics.aggregate() as b: - metrics.log_scalar('loss', 2) + metrics.log_scalar("loss", 2) - self.assertEqual(a.get_smoothed_values()['loss'], 1.5) - self.assertEqual(b.get_smoothed_values()['loss'], 2) + self.assertEqual(a.get_smoothed_values()["loss"], 1.5) + self.assertEqual(b.get_smoothed_values()["loss"], 2) def test_new_root(self): with metrics.aggregate() as a: - metrics.log_scalar('loss', 1) + metrics.log_scalar("loss", 1) with metrics.aggregate(new_root=True) as b: - metrics.log_scalar('loss', 2) + metrics.log_scalar("loss", 2) - self.assertEqual(a.get_smoothed_values()['loss'], 1) - self.assertEqual(b.get_smoothed_values()['loss'], 2) + self.assertEqual(a.get_smoothed_values()["loss"], 1) + self.assertEqual(b.get_smoothed_values()["loss"], 2) def test_nested_new_root(self): with metrics.aggregate() as layer1: - metrics.log_scalar('loss', 1) + metrics.log_scalar("loss", 1) with metrics.aggregate(new_root=True) as layer2: - metrics.log_scalar('loss', 2) + metrics.log_scalar("loss", 2) with metrics.aggregate() as layer3: - metrics.log_scalar('loss', 3) + metrics.log_scalar("loss", 3) with metrics.aggregate(new_root=True) as layer4: - metrics.log_scalar('loss', 4) - metrics.log_scalar('loss', 1.5) + metrics.log_scalar("loss", 4) + metrics.log_scalar("loss", 1.5) - self.assertEqual(layer4.get_smoothed_values()['loss'], 4) - self.assertEqual(layer3.get_smoothed_values()['loss'], 3) - self.assertEqual(layer2.get_smoothed_values()['loss'], 2.5) - self.assertEqual(layer1.get_smoothed_values()['loss'], 1.25) + self.assertEqual(layer4.get_smoothed_values()["loss"], 4) + self.assertEqual(layer3.get_smoothed_values()["loss"], 3) + self.assertEqual(layer2.get_smoothed_values()["loss"], 2.5) + self.assertEqual(layer1.get_smoothed_values()["loss"], 1.25) def test_named(self): name = str(uuid.uuid4()) metrics.reset_meters(name) with metrics.aggregate(name): - 
metrics.log_scalar('loss', 1) + metrics.log_scalar("loss", 1) - metrics.log_scalar('loss', 3) + metrics.log_scalar("loss", 3) with metrics.aggregate(name): - metrics.log_scalar('loss', 2) + metrics.log_scalar("loss", 2) - self.assertEqual(metrics.get_smoothed_values(name)['loss'], 1.5) + self.assertEqual(metrics.get_smoothed_values(name)["loss"], 1.5) def test_nested_duplicate_names(self): name = str(uuid.uuid4()) metrics.reset_meters(name) with metrics.aggregate(name): - metrics.log_scalar('loss', 1) + metrics.log_scalar("loss", 1) with metrics.aggregate() as other: with metrics.aggregate(name): - metrics.log_scalar('loss', 2) - metrics.log_scalar('loss', 6) + metrics.log_scalar("loss", 2) + metrics.log_scalar("loss", 6) - self.assertEqual(metrics.get_smoothed_values(name)['loss'], 3) - self.assertEqual(other.get_smoothed_values()['loss'], 2) + self.assertEqual(metrics.get_smoothed_values(name)["loss"], 3) + self.assertEqual(other.get_smoothed_values()["loss"], 2) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py index 324d8e3eb5..9aa9cb2f87 100644 --- a/tests/test_multihead_attention.py +++ b/tests/test_multihead_attention.py @@ -3,8 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch import unittest + +import torch from fairseq.modules.multihead_attention import MultiheadAttention @@ -47,8 +48,8 @@ def test_append_prev_key_padding_mask(self): if key_padding_mask is not None: self.assertTrue( torch.all(torch.eq(key_padding_mask, c[2])), - f'Unexpected resultant key padding mask: {key_padding_mask}' - f' given current: {c[0]} and previous: {c[1]}', + f"Unexpected resultant key padding mask: {key_padding_mask}" + f" given current: {c[0]} and previous: {c[1]}", ) self.assertEqual(key_padding_mask.size(0), bsz) self.assertEqual(key_padding_mask.size(1), src_len) @@ -56,5 +57,5 @@ def test_append_prev_key_padding_mask(self): self.assertIsNone(c[2]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_noising.py b/tests/test_noising.py index da792a1826..b3d0d123c4 100644 --- a/tests/test_noising.py +++ b/tests/test_noising.py @@ -408,7 +408,10 @@ def test_word_blank_without_eos(self): self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos()) def _get_noising_dataset_batch( - self, src_tokens_no_pad, src_dict, append_eos_to_tgt=False, + self, + src_tokens_no_pad, + src_dict, + append_eos_to_tgt=False, ): """ Constructs a NoisingDataset and the corresponding @@ -433,7 +436,8 @@ def _get_noising_dataset_batch( src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict ) language_pair_dataset = TransformEosDataset( - language_pair_dataset, src_dict.eos(), + language_pair_dataset, + src_dict.eos(), append_eos_to_tgt=append_eos_to_tgt, ) diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 80f2948250..517e23c39e 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -4,11 +4,11 @@ # LICENSE file in the root directory of this source tree. 
import contextlib -from io import StringIO import json import os import tempfile import unittest +from io import StringIO import torch @@ -16,13 +16,12 @@ class TestReproducibility(unittest.TestCase): - def _test_reproducibility( self, name, extra_flags=None, delta=0.0001, - resume_checkpoint='checkpoint1.pt', + resume_checkpoint="checkpoint1.pt", max_epoch=3, ): def get_last_log_stats_containing_string(log_records, search_string): @@ -41,63 +40,99 @@ def get_last_log_stats_containing_string(log_records, search_string): # train epochs 1 and 2 together with self.assertLogs() as logs: test_binaries.train_translation_model( - data_dir, 'fconv_iwslt_de_en', [ - '--dropout', '0.0', - '--log-format', 'json', - '--log-interval', '1', - '--max-epoch', str(max_epoch), - ] + extra_flags, + data_dir, + "fconv_iwslt_de_en", + [ + "--dropout", + "0.0", + "--log-format", + "json", + "--log-interval", + "1", + "--max-epoch", + str(max_epoch), + ] + + extra_flags, ) - train_log = get_last_log_stats_containing_string(logs.records, 'train_loss') - valid_log = get_last_log_stats_containing_string(logs.records, 'valid_loss') + train_log = get_last_log_stats_containing_string(logs.records, "train_loss") + valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss") # train epoch 2, resuming from previous checkpoint 1 os.rename( os.path.join(data_dir, resume_checkpoint), - os.path.join(data_dir, 'checkpoint_last.pt'), + os.path.join(data_dir, "checkpoint_last.pt"), ) with self.assertLogs() as logs: test_binaries.train_translation_model( - data_dir, 'fconv_iwslt_de_en', [ - '--dropout', '0.0', - '--log-format', 'json', - '--log-interval', '1', - '--max-epoch', str(max_epoch), - ] + extra_flags, + data_dir, + "fconv_iwslt_de_en", + [ + "--dropout", + "0.0", + "--log-format", + "json", + "--log-interval", + "1", + "--max-epoch", + str(max_epoch), + ] + + extra_flags, ) - train_res_log = get_last_log_stats_containing_string(logs.records, 'train_loss') - valid_res_log = 
get_last_log_stats_containing_string(logs.records, 'valid_loss') + train_res_log = get_last_log_stats_containing_string( + logs.records, "train_loss" + ) + valid_res_log = get_last_log_stats_containing_string( + logs.records, "valid_loss" + ) - for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']: - self.assertAlmostEqual(float(train_log[k]), float(train_res_log[k]), delta=delta) - for k in ['valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss']: - self.assertAlmostEqual(float(valid_log[k]), float(valid_res_log[k]), delta=delta) + for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]: + self.assertAlmostEqual( + float(train_log[k]), float(train_res_log[k]), delta=delta + ) + for k in [ + "valid_loss", + "valid_ppl", + "valid_num_updates", + "valid_best_loss", + ]: + self.assertAlmostEqual( + float(valid_log[k]), float(valid_res_log[k]), delta=delta + ) def test_reproducibility(self): - self._test_reproducibility('test_reproducibility') + self._test_reproducibility("test_reproducibility") - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_reproducibility_fp16(self): - self._test_reproducibility('test_reproducibility_fp16', [ - '--fp16', - '--fp16-init-scale', '4096', - ], delta=0.011) + self._test_reproducibility( + "test_reproducibility_fp16", + [ + "--fp16", + "--fp16-init-scale", + "4096", + ], + delta=0.011, + ) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_reproducibility_memory_efficient_fp16(self): - self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [ - '--memory-efficient-fp16', - '--fp16-init-scale', '4096', - ]) + self._test_reproducibility( + "test_reproducibility_memory_efficient_fp16", + [ + "--memory-efficient-fp16", + "--fp16-init-scale", + "4096", + ], + ) def 
test_mid_epoch_reproducibility(self): self._test_reproducibility( - 'test_mid_epoch_reproducibility', - ['--save-interval-updates', '3'], - resume_checkpoint='checkpoint_1_3.pt', + "test_mid_epoch_reproducibility", + ["--save-interval-updates", "3"], + resume_checkpoint="checkpoint_1_3.pt", max_epoch=1, ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_resampling_dataset.py b/tests/test_resampling_dataset.py index 0d142f5a8d..ccb53a253c 100644 --- a/tests/test_resampling_dataset.py +++ b/tests/test_resampling_dataset.py @@ -7,7 +7,6 @@ import unittest import numpy as np - from fairseq.data import ListDataset, ResamplingDataset diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py index 517aa77d59..c890b655ff 100644 --- a/tests/test_sequence_generator.py +++ b/tests/test_sequence_generator.py @@ -11,9 +11,8 @@ import torch from fairseq import search from fairseq.data.dictionary import Dictionary - from fairseq.models.transformer import TransformerModel -from fairseq.sequence_generator import SequenceGenerator, EnsembleModel +from fairseq.sequence_generator import EnsembleModel, SequenceGenerator from fairseq.tasks.fairseq_task import LegacyFairseqTask @@ -109,7 +108,6 @@ def _test_save_and_load(self, scripted_module): class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase): - @unittest.skipIf( torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" ) @@ -130,7 +128,6 @@ def test_ensemble_sequence_generator(self): class TestJitEnsemble(TestJitSequenceGeneratorBase): - @unittest.skipIf( torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" ) @@ -190,9 +187,14 @@ def assertTensorEqual(self, t1, t2): class TestSequeneceGenerator(TestSequenceGeneratorBase): def setUp(self): - self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = ( - test_utils.sequence_generator_setup() - ) + ( + self.tgt_dict, + self.w1, + self.w2, + 
src_tokens, + src_lengths, + self.model, + ) = test_utils.sequence_generator_setup() self.sample = { "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths} } @@ -276,7 +278,9 @@ def test_with_lenpen_favoring_long_hypos(self): self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen) def test_maxlen(self): - generator = SequenceGenerator([self.model], self.tgt_dict, beam_size=2, max_len_b=2) + generator = SequenceGenerator( + [self.model], self.tgt_dict, beam_size=2, max_len_b=2 + ) hypos = generator.forward(self.sample) eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 # sentence 1, beam 1 @@ -294,21 +298,27 @@ def test_maxlen(self): def test_encoder_with_different_output_len(self): args = self.model.encoder.args - task = test_utils.TestTranslationTask.setup_task(args, self.tgt_dict, self.tgt_dict) + task = test_utils.TestTranslationTask.setup_task( + args, self.tgt_dict, self.tgt_dict + ) reshaping_model = test_utils.TestReshapingModel.build_model(args, task) - generator = SequenceGenerator([reshaping_model], self.tgt_dict, beam_size=2, max_len_b=2) + generator = SequenceGenerator( + [reshaping_model], self.tgt_dict, beam_size=2, max_len_b=2 + ) hypos = generator.forward(self.sample) for sent in [0, 1]: for beam in [0, 1]: - assert hypos[sent][beam]['attention'] is not None + assert hypos[sent][beam]["attention"] is not None def test_generation_with_additional_input(self): args = self.model.encoder.args - task = test_utils.TestTranslationTask.setup_task(args, self.tgt_dict, self.tgt_dict) + task = test_utils.TestTranslationTask.setup_task( + args, self.tgt_dict, self.tgt_dict + ) add_input_model = test_utils.TestAdditionalInputModel.build_model(args, task) generator = SequenceGenerator([add_input_model], self.tgt_dict, beam_size=2) sample = self.sample.copy() - sample['net_input']['fancy_other_input'] = sample['net_input']['src_tokens'] + sample["net_input"]["fancy_other_input"] = sample["net_input"]["src_tokens"] hypos = 
generator.forward(self.sample) eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 # sentence 1, beam 1 @@ -317,7 +327,6 @@ def test_generation_with_additional_input(self): class TestDiverseBeamSearch(TestSequenceGeneratorBase): - def setUp(self): # construct dummy dictionary d = test_utils.dummy_dictionary(vocab_size=2) @@ -329,45 +338,53 @@ def setUp(self): self.w2 = 5 # construct source data - self.src_tokens = torch.LongTensor([ - [self.w1, self.w2, self.eos], - [self.w1, self.w2, self.eos], - ]) + self.src_tokens = torch.LongTensor( + [ + [self.w1, self.w2, self.eos], + [self.w1, self.w2, self.eos], + ] + ) self.src_lengths = torch.LongTensor([2, 2]) args = argparse.Namespace() - unk = 0. + unk = 0.0 args.beam_probs = [ # step 0: - torch.FloatTensor([ - # eos w1 w2 - # sentence 1: - [0.0, unk, 0.9, 0.1], # beam 1 - [0.0, unk, 0.9, 0.1], # beam 2 - # sentence 2: - [0.0, unk, 0.7, 0.3], - [0.0, unk, 0.7, 0.3], - ]), + torch.FloatTensor( + [ + # eos w1 w2 + # sentence 1: + [0.0, unk, 0.9, 0.1], # beam 1 + [0.0, unk, 0.9, 0.1], # beam 2 + # sentence 2: + [0.0, unk, 0.7, 0.3], + [0.0, unk, 0.7, 0.3], + ] + ), # step 1: - torch.FloatTensor([ - # eos w1 w2 - # sentence 1: - [0.0, unk, 0.6, 0.4], - [0.0, unk, 0.6, 0.4], - # sentence 2: - [0.25, unk, 0.35, 0.4], - [0.25, unk, 0.35, 0.4], - ]), + torch.FloatTensor( + [ + # eos w1 w2 + # sentence 1: + [0.0, unk, 0.6, 0.4], + [0.0, unk, 0.6, 0.4], + # sentence 2: + [0.25, unk, 0.35, 0.4], + [0.25, unk, 0.35, 0.4], + ] + ), # step 2: - torch.FloatTensor([ - # eos w1 w2 - # sentence 1: - [1.0, unk, 0.0, 0.0], - [1.0, unk, 0.0, 0.0], - # sentence 2: - [0.9, unk, 0.1, 0.0], - [0.9, unk, 0.1, 0.0], - ]), + torch.FloatTensor( + [ + # eos w1 w2 + # sentence 1: + [1.0, unk, 0.0, 0.0], + [1.0, unk, 0.0, 0.0], + # sentence 2: + [0.9, unk, 0.1, 0.0], + [0.9, unk, 0.1, 0.0], + ] + ), ] task = test_utils.TestTranslationTask.setup_task(args, d, d) @@ -375,11 +392,21 @@ def setUp(self): self.tgt_dict = task.target_dictionary def 
test_diverse_beam_search(self): - search_strategy = search.DiverseBeamSearch(self.tgt_dict, num_groups=2, diversity_strength=0.) + search_strategy = search.DiverseBeamSearch( + self.tgt_dict, num_groups=2, diversity_strength=0.0 + ) generator = SequenceGenerator( - [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy, + [self.model], + self.tgt_dict, + beam_size=2, + search_strategy=search_strategy, ) - sample = {'net_input': {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}} + sample = { + "net_input": { + "src_tokens": self.src_tokens, + "src_lengths": self.src_lengths, + } + } hypos = generator.forward(sample) eos, w1, w2 = self.eos, self.w1, self.w2 # sentence 1, beam 1 @@ -439,7 +466,6 @@ def test_diverse_beam_search(self): class TestTopPSamplingSearch(TestSequenceGeneratorBase): - def setUp(self): # construct dummy dictionary d = test_utils.dummy_dictionary(vocab_size=2) @@ -451,14 +477,16 @@ def setUp(self): self.w2 = 5 # construct source data - self.src_tokens = torch.LongTensor([ - [self.w1, self.w2, self.eos], - [self.w1, self.w2, self.eos], - ]) + self.src_tokens = torch.LongTensor( + [ + [self.w1, self.w2, self.eos], + [self.w1, self.w2, self.eos], + ] + ) self.src_lengths = torch.LongTensor([2, 2]) args = argparse.Namespace() - unk = 0. + unk = 0.0 # The minimal probability of top 2 tokens. self.min_top2_prob = 0.75 # The minimal probability of the top 1 token. 
@@ -470,29 +498,35 @@ def setUp(self): args.beam_probs = [ # step 0: - torch.FloatTensor([ - # eos w1 w2 - [0.0, unk, 1.0, 0.0], - [0.0, unk, 1.0, 0.0], - [0.0, unk, 1.0, 0.0], - [0.0, unk, 1.0, 0.0], - ]), + torch.FloatTensor( + [ + # eos w1 w2 + [0.0, unk, 1.0, 0.0], + [0.0, unk, 1.0, 0.0], + [0.0, unk, 1.0, 0.0], + [0.0, unk, 1.0, 0.0], + ] + ), # step 1: - torch.FloatTensor([ - # eos w1 w2 - [eos_prob, unk, w1_prob, w2_prob], - [eos_prob, unk, w1_prob, w2_prob], - [eos_prob, unk, w1_prob, w2_prob], - [eos_prob, unk, w1_prob, w2_prob], - ]), + torch.FloatTensor( + [ + # eos w1 w2 + [eos_prob, unk, w1_prob, w2_prob], + [eos_prob, unk, w1_prob, w2_prob], + [eos_prob, unk, w1_prob, w2_prob], + [eos_prob, unk, w1_prob, w2_prob], + ] + ), # step 2: - torch.FloatTensor([ - # eos w1 w2 - [1.0, unk, 0.0, 0.0], - [1.0, unk, 0.0, 0.0], - [1.0, unk, 0.0, 0.0], - [1.0, unk, 0.0, 0.0], - ]), + torch.FloatTensor( + [ + # eos w1 w2 + [1.0, unk, 0.0, 0.0], + [1.0, unk, 0.0, 0.0], + [1.0, unk, 0.0, 0.0], + [1.0, unk, 0.0, 0.0], + ] + ), ] task = test_utils.TestTranslationTask.setup_task(args, d, d) @@ -502,14 +536,17 @@ def setUp(self): def test_topp_sampling_search_low_prob(self): # Given a prob low enough to top-P sampling, we expect only the top # 1 token to be sampled, which always results in the same output. 
- low_sampling_topp = self.min_top1_prob/2.0 - search_strategy = search.Sampling(self.tgt_dict, sampling_topp=low_sampling_topp) + low_sampling_topp = self.min_top1_prob / 2.0 + search_strategy = search.Sampling( + self.tgt_dict, sampling_topp=low_sampling_topp + ) generator = SequenceGenerator( - [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy) + [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy + ) sample = { - 'net_input': { - 'src_tokens': self.src_tokens, - 'src_lengths': self.src_lengths + "net_input": { + "src_tokens": self.src_tokens, + "src_lengths": self.src_lengths, } } hypos = generator.forward(sample) @@ -530,55 +567,74 @@ def test_topp_sampling_search_low_prob(self): def test_topp_sampling_search_high_prob(self): # Given a prob high enough to top-P sampling, any of the top 2 # tokens could be sampled. This can cause different outputs. - high_sampling_topp = (self.min_top1_prob+self.min_top2_prob)/2.0 - search_strategy = search.Sampling(self.tgt_dict, sampling_topp=high_sampling_topp) + high_sampling_topp = (self.min_top1_prob + self.min_top2_prob) / 2.0 + search_strategy = search.Sampling( + self.tgt_dict, sampling_topp=high_sampling_topp + ) generator = SequenceGenerator( - [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy) + [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy + ) sample = { - 'net_input': { - 'src_tokens': self.src_tokens, - 'src_lengths': self.src_lengths + "net_input": { + "src_tokens": self.src_tokens, + "src_lengths": self.src_lengths, } } hypos = generator.forward(sample) eos, w1, w2 = self.eos, self.w1, self.w2 # sentence 1, beam 1 - self.assertTrue(self.hypoTokens(hypos[0][0], [w1, w1, eos]) or - self.hypoTokens(hypos[0][0], [w1, w2, eos])) - self.assertTrue(self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0]) or - self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0])) + self.assertTrue( + self.hypoTokens(hypos[0][0], [w1, w1, eos]) + or 
self.hypoTokens(hypos[0][0], [w1, w2, eos]) + ) + self.assertTrue( + self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0]) + or self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0]) + ) # sentence 1, beam 2 - self.assertTrue(self.hypoTokens(hypos[0][1], [w1, w1, eos]) or - self.hypoTokens(hypos[0][1], [w1, w2, eos])) - self.assertTrue(self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0]) or - self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0])) + self.assertTrue( + self.hypoTokens(hypos[0][1], [w1, w1, eos]) + or self.hypoTokens(hypos[0][1], [w1, w2, eos]) + ) + self.assertTrue( + self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0]) + or self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0]) + ) # sentence 2, beam 1 - self.assertTrue(self.hypoTokens(hypos[1][0], [w1, w1, eos]) or - self.hypoTokens(hypos[1][0], [w1, w2, eos])) - self.assertTrue(self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0]) or - self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0])) + self.assertTrue( + self.hypoTokens(hypos[1][0], [w1, w1, eos]) + or self.hypoTokens(hypos[1][0], [w1, w2, eos]) + ) + self.assertTrue( + self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0]) + or self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0]) + ) # sentence 2, beam 2 - self.assertTrue(self.hypoTokens(hypos[1][1], [w1, w1, eos]) or - self.hypoTokens(hypos[1][1], [w1, w2, eos])) - self.assertTrue(self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0]) or - self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0])) + self.assertTrue( + self.hypoTokens(hypos[1][1], [w1, w1, eos]) + or self.hypoTokens(hypos[1][1], [w1, w2, eos]) + ) + self.assertTrue( + self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0]) + or self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0]) + ) def hypoTokens(self, hypo, tokens): - return self.tensorEqual(hypo['tokens'], torch.LongTensor(tokens)) + return self.tensorEqual(hypo["tokens"], torch.LongTensor(tokens)) - def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.): + def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0): pos_scores = torch.FloatTensor(pos_probs).log() - if not 
self.almostEqual(hypo['positional_scores'], pos_scores): + if not self.almostEqual(hypo["positional_scores"], pos_scores): return False - if pos_scores.numel() != hypo['tokens'].numel(): + if pos_scores.numel() != hypo["tokens"].numel(): return False score = pos_scores.sum() if normalized: score /= pos_scores.numel() ** lenpen - return abs(score - hypo['score']) < 1e-6 + return abs(score - hypo["score"]) < 1e-6 def almostEqual(self, t1, t2): return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4 diff --git a/tests/test_sequence_scorer.py b/tests/test_sequence_scorer.py index a7c2a53a90..42f9447b59 100644 --- a/tests/test_sequence_scorer.py +++ b/tests/test_sequence_scorer.py @@ -6,15 +6,12 @@ import argparse import unittest +import tests.utils as test_utils import torch - from fairseq.sequence_scorer import SequenceScorer -import tests.utils as test_utils - class TestSequenceScorer(unittest.TestCase): - def test_sequence_scorer(self): # construct dummy dictionary d = test_utils.dummy_dictionary(vocab_size=2) @@ -28,52 +25,60 @@ def test_sequence_scorer(self): # construct dataloader data = [ { - 'source': torch.LongTensor([w1, w2, eos]), - 'target': torch.LongTensor([w1, w2, w1, eos]), + "source": torch.LongTensor([w1, w2, eos]), + "target": torch.LongTensor([w1, w2, w1, eos]), }, { - 'source': torch.LongTensor([w2, eos]), - 'target': torch.LongTensor([w2, w1, eos]), + "source": torch.LongTensor([w2, eos]), + "target": torch.LongTensor([w2, w1, eos]), }, { - 'source': torch.LongTensor([w2, eos]), - 'target': torch.LongTensor([w2, eos]), + "source": torch.LongTensor([w2, eos]), + "target": torch.LongTensor([w2, eos]), }, ] data_itr = test_utils.dummy_dataloader(data) # specify expected output probabilities args = argparse.Namespace() - unk = 0. 
+ unk = 0.0 args.beam_probs = [ # step 0: - torch.FloatTensor([ - # eos w1 w2 - [0.0, unk, 0.6, 0.4], # sentence 1 - [0.0, unk, 0.4, 0.6], # sentence 2 - [0.0, unk, 0.7, 0.3], # sentence 3 - ]), + torch.FloatTensor( + [ + # eos w1 w2 + [0.0, unk, 0.6, 0.4], # sentence 1 + [0.0, unk, 0.4, 0.6], # sentence 2 + [0.0, unk, 0.7, 0.3], # sentence 3 + ] + ), # step 1: - torch.FloatTensor([ - # eos w1 w2 - [0.0, unk, 0.2, 0.7], # sentence 1 - [0.0, unk, 0.8, 0.2], # sentence 2 - [0.7, unk, 0.1, 0.2], # sentence 3 - ]), + torch.FloatTensor( + [ + # eos w1 w2 + [0.0, unk, 0.2, 0.7], # sentence 1 + [0.0, unk, 0.8, 0.2], # sentence 2 + [0.7, unk, 0.1, 0.2], # sentence 3 + ] + ), # step 2: - torch.FloatTensor([ - # eos w1 w2 - [0.10, unk, 0.50, 0.4], # sentence 1 - [0.15, unk, 0.15, 0.7], # sentence 2 - [0.00, unk, 0.00, 0.0], # sentence 3 - ]), + torch.FloatTensor( + [ + # eos w1 w2 + [0.10, unk, 0.50, 0.4], # sentence 1 + [0.15, unk, 0.15, 0.7], # sentence 2 + [0.00, unk, 0.00, 0.0], # sentence 3 + ] + ), # step 3: - torch.FloatTensor([ - # eos w1 w2 - [0.9, unk, 0.05, 0.05], # sentence 1 - [0.0, unk, 0.00, 0.0], # sentence 2 - [0.0, unk, 0.00, 0.0], # sentence 3 - ]), + torch.FloatTensor( + [ + # eos w1 w2 + [0.9, unk, 0.05, 0.05], # sentence 1 + [0.0, unk, 0.00, 0.0], # sentence 2 + [0.0, unk, 0.00, 0.0], # sentence 3 + ] + ), ] expected_scores = [ [0.6, 0.7, 0.5, 0.9], # sentence 1 @@ -86,21 +91,21 @@ def test_sequence_scorer(self): scorer = SequenceScorer(task.target_dictionary) for sample in data_itr: hypos = task.inference_step(scorer, [model], sample) - for id, hypos_id in zip(sample['id'].tolist(), hypos): - self.assertHypoTokens(hypos_id[0], data[id]['target']) + for id, hypos_id in zip(sample["id"].tolist(), hypos): + self.assertHypoTokens(hypos_id[0], data[id]["target"]) self.assertHypoScore(hypos_id[0], expected_scores[id]) def assertHypoTokens(self, hypo, tokens): - self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens)) + 
self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens)) - def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.): + def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0): pos_scores = torch.FloatTensor(pos_probs).log() - self.assertAlmostEqual(hypo['positional_scores'], pos_scores) - self.assertEqual(pos_scores.numel(), hypo['tokens'].numel()) + self.assertAlmostEqual(hypo["positional_scores"], pos_scores) + self.assertEqual(pos_scores.numel(), hypo["tokens"].numel()) score = pos_scores.sum() if normalized: - score /= pos_scores.numel()**lenpen - self.assertLess(abs(score - hypo['score']), 1e-6) + score /= pos_scores.numel() ** lenpen + self.assertLess(abs(score - hypo["score"]), 1e-6) def assertAlmostEqual(self, t1, t2): self.assertEqual(t1.size(), t2.size(), "size mismatch") @@ -111,5 +116,5 @@ def assertTensorEqual(self, t1, t2): self.assertEqual(t1.ne(t2).long().sum(), 0) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_sparse_multihead_attention.py b/tests/test_sparse_multihead_attention.py index eaf9742cdf..3e32b25a7f 100644 --- a/tests/test_sparse_multihead_attention.py +++ b/tests/test_sparse_multihead_attention.py @@ -3,46 +3,112 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch import unittest + +import torch from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention class TestSparseMultiheadAttention(unittest.TestCase): def test_sparse_multihead_attention(self): attn_weights = torch.randn(1, 8, 8) - bidirectional_sparse_mask = torch.tensor([ - [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], - [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], - [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], - [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], - [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], - [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], - [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], - [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0] - ]) - - bidirectional_attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True) - bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8) - torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)) - - sparse_mask = torch.tensor([ - [0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), - float('-inf'), float('-inf')], - [0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], - [0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], - [0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')], - [0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')], - [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')], - [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')], - [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], - ]) - - attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False) + bidirectional_sparse_mask = torch.tensor( + [ + [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0], + [0, 0, 0, 0, 0, 
float("-inf"), float("-inf"), 0], + [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0], + [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0], + [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0], + [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0], + [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0], + [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0], + ] + ) + + bidirectional_attention = SparseMultiheadAttention( + 16, 1, stride=4, expressivity=1, is_bidirectional=True + ) + bidirectional_attention_sparse_mask = ( + bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8) + ) + torch.all( + torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask) + ) + + sparse_mask = torch.tensor( + [ + [ + 0, + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + ], + [ + 0, + 0, + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + ], + [ + 0, + 0, + 0, + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + ], + [ + 0, + 0, + 0, + 0, + float("-inf"), + float("-inf"), + float("-inf"), + float("-inf"), + ], + [0, 0, 0, 0, 0, float("-inf"), float("-inf"), float("-inf")], + [ + float("-inf"), + float("-inf"), + float("-inf"), + 0, + 0, + 0, + float("-inf"), + float("-inf"), + ], + [ + float("-inf"), + float("-inf"), + float("-inf"), + 0, + 0, + 0, + 0, + float("-inf"), + ], + [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0], + ] + ) + + attention = SparseMultiheadAttention( + 16, 1, stride=4, expressivity=1, is_bidirectional=False + ) attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8) torch.all(torch.eq(attention_sparse_mask, sparse_mask)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_token_block_dataset.py b/tests/test_token_block_dataset.py index 41abb194da..ea315b4e67 100644 --- 
a/tests/test_token_block_dataset.py +++ b/tests/test_token_block_dataset.py @@ -5,15 +5,12 @@ import unittest +import tests.utils as test_utils import torch - from fairseq.data import TokenBlockDataset -import tests.utils as test_utils - class TestTokenBlockDataset(unittest.TestCase): - def _build_dataset(self, data, **kwargs): sizes = [len(x) for x in data] underlying_ds = test_utils.TestDataset(data) @@ -25,7 +22,7 @@ def test_eos_break_mode(self): torch.tensor([1], dtype=torch.long), torch.tensor([8, 7, 6, 1], dtype=torch.long), ] - ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos') + ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos") self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) self.assertEqual(ds[1].tolist(), [1]) self.assertEqual(ds[2].tolist(), [8, 7, 6, 1]) @@ -35,7 +32,7 @@ def test_eos_break_mode(self): torch.tensor([8, 7, 6, 1], dtype=torch.long), torch.tensor([1], dtype=torch.long), ] - ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos') + ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos") self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) self.assertEqual(ds[1].tolist(), [8, 7, 6, 1]) self.assertEqual(ds[2].tolist(), [1]) @@ -46,7 +43,7 @@ def test_block_break_mode(self): torch.tensor([8, 7, 6, 1], dtype=torch.long), torch.tensor([9, 1], dtype=torch.long), ] - ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none') + ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode="none") self.assertEqual(ds[0].tolist(), [5, 4, 3]) self.assertEqual(ds[1].tolist(), [2, 1, 8]) self.assertEqual(ds[2].tolist(), [7, 6, 1]) @@ -58,7 +55,9 @@ def test_complete_break_mode(self): torch.tensor([8, 7, 6, 1], dtype=torch.long), torch.tensor([9, 1], dtype=torch.long), ] - ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete') + ds = self._build_dataset( + data, block_size=6, pad=0, eos=1, 
break_mode="complete" + ) self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1]) @@ -68,7 +67,9 @@ def test_complete_break_mode(self): torch.tensor([1], dtype=torch.long), torch.tensor([6, 1], dtype=torch.long), ] - ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete') + ds = self._build_dataset( + data, block_size=3, pad=0, eos=1, break_mode="complete" + ) self.assertEqual(ds[0].tolist(), [4, 3, 2, 1]) self.assertEqual(ds[1].tolist(), [5, 1, 1]) self.assertEqual(ds[2].tolist(), [6, 1]) diff --git a/tests/test_train.py b/tests/test_train.py index 048acaca54..1b7e027c0c 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -10,17 +10,16 @@ from unittest.mock import MagicMock, patch import torch - -from fairseq import data, checkpoint_utils +from fairseq import checkpoint_utils, data def mock_trainer(epoch, num_updates, iterations_in_epoch): trainer = MagicMock() trainer.load_checkpoint.return_value = { - 'train_iterator': { - 'epoch': epoch, - 'iterations_in_epoch': iterations_in_epoch, - 'shuffle': False, + "train_iterator": { + "epoch": epoch, + "iterations_in_epoch": iterations_in_epoch, + "shuffle": False, }, } trainer.get_num_updates.return_value = num_updates @@ -38,10 +37,17 @@ def mock_dict(): def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1) tokens_ds = data.TokenBlockDataset( - tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False, + tokens, + sizes=[tokens.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, ) trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) - dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False) + dataset = data.LanguagePairDataset( + tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False + ) epoch_itr = data.EpochBatchIterator( dataset=dataset, 
collate_fn=dataset.collater, @@ -52,7 +58,7 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc def get_mock_args(finetune_from_model=None): args_mock = MagicMock() - args_mock.optimizer_overrides = '{}' + args_mock.optimizer_overrides = "{}" args_mock.reset_dataloader = False args_mock.reset_meters = False args_mock.reset_optimizer = False @@ -63,15 +69,14 @@ def get_mock_args(finetune_from_model=None): class TestLoadCheckpoint(unittest.TestCase): - def setUp(self): self.args_mock = get_mock_args() self.patches = { - 'os.makedirs': MagicMock(), - 'os.path.join': MagicMock(), - 'os.path.isfile': MagicMock(return_value=True), - 'os.path.isabs': MagicMock(return_value=False), - 'fairseq.file_io.PathManager.exists': MagicMock(return_value=False), + "os.makedirs": MagicMock(), + "os.path.join": MagicMock(), + "os.path.isfile": MagicMock(return_value=True), + "os.path.isabs": MagicMock(return_value=False), + "fairseq.file_io.PathManager.exists": MagicMock(return_value=False), } self.applied_patches = [patch(p, d) for p, d in self.patches.items()] [p.start() for p in self.applied_patches] @@ -95,7 +100,7 @@ def test_load_partial_checkpoint(self): self.assertEqual(epoch_itr.epoch, 2) self.assertEqual(epoch_itr.iterations_in_epoch, 50) - self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50) + self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 50) self.assertEqual(epoch_itr.iterations_in_epoch, 51) for _ in range(150 - 52): @@ -120,27 +125,32 @@ def test_load_full_checkpoint(self): self.assertEqual(epoch_itr.epoch, 3) self.assertEqual(epoch_itr.iterations_in_epoch, 0) - self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) + self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0) def test_load_no_checkpoint(self): with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) - 
self.patches['os.path.isfile'].return_value = False + self.patches["os.path.isfile"].return_value = False _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 1) self.assertEqual(epoch_itr.iterations_in_epoch, 0) - self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) + self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0) def test_finetune_from_model_args_conflict(self): with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) - for arg in ['reset_optimizer', 'reset_lr_scheduler', 'reset_meters', 'reset_dataloader']: + for arg in [ + "reset_optimizer", + "reset_lr_scheduler", + "reset_meters", + "reset_dataloader", + ]: with self.subTest(arg=arg): args_mock = get_mock_args("/temp/checkpoint_pretrained.pt") setattr(args_mock, arg, True) @@ -149,7 +159,8 @@ def test_finetune_from_model_args_conflict(self): self.assertTrue( "--finetune-from-model can not be set together with either --reset-optimizer" - " or reset_lr_scheduler or reset_meters or reset_dataloader" in str(context.exception) + " or reset_lr_scheduler or reset_meters or reset_dataloader" + in str(context.exception) ) def test_finetune_from_model(self): @@ -165,11 +176,18 @@ def mock_finetune_exist(path): return True else: return False - self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist + + self.patches[ + "fairseq.file_io.PathManager.exists" + ].side_effect = mock_finetune_exist _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) - checkpoint_path, reset_optimizer, reset_lr_scheduler, \ - optimizer_overrides = trainer.load_checkpoint.call_args[0] - reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters'] + ( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + ) = 
trainer.load_checkpoint.call_args[0] + reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"] self.assertTrue(reset_optimizer) self.assertTrue(reset_lr_scheduler) self.assertTrue(reset_meters) @@ -185,19 +203,26 @@ def test_finetune_from_model_resume(self): # launch second time # both restore_file=checkpoint_last.pt and finetune_from_model are set def mock_finetune_exist(path): - if path == from_model_path or path.endsWith('checkpoint_last.pt'): + if path == from_model_path or path.endsWith("checkpoint_last.pt"): return True else: return False - self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist + + self.patches[ + "fairseq.file_io.PathManager.exists" + ].side_effect = mock_finetune_exist _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) - checkpoint_path, reset_optimizer, reset_lr_scheduler, \ - optimizer_overrides = trainer.load_checkpoint.call_args[0] - reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters'] + ( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + ) = trainer.load_checkpoint.call_args[0] + reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"] self.assertFalse(reset_optimizer) self.assertFalse(reset_lr_scheduler) self.assertFalse(reset_meters) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index 35fb115dda..79195903e0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,24 +6,26 @@ import unittest import torch - from fairseq import utils class TestUtils(unittest.TestCase): - def test_convert_padding_direction(self): pad = 1 - left_pad = torch.LongTensor([ - [2, 3, 4, 5, 6], - [1, 7, 8, 9, 10], - [1, 1, 1, 11, 12], - ]) - right_pad = torch.LongTensor([ - [2, 3, 4, 5, 6], - [7, 8, 9, 10, 1], - [11, 12, 1, 1, 1], - ]) + left_pad = torch.LongTensor( + [ + [2, 3, 4, 5, 6], + [1, 7, 8, 9, 10], + [1, 1, 1, 11, 12], + ] + ) + right_pad = 
torch.LongTensor( + [ + [2, 3, 4, 5, 6], + [7, 8, 9, 10, 1], + [11, 12, 1, 1, 1], + ] + ) self.assertAlmostEqual( right_pad, @@ -44,26 +46,34 @@ def test_convert_padding_direction(self): def test_make_positions(self): pad = 1 - left_pad_input = torch.LongTensor([ - [9, 9, 9, 9, 9], - [1, 9, 9, 9, 9], - [1, 1, 1, 9, 9], - ]) - left_pad_output = torch.LongTensor([ - [2, 3, 4, 5, 6], - [1, 2, 3, 4, 5], - [1, 1, 1, 2, 3], - ]) - right_pad_input = torch.LongTensor([ - [9, 9, 9, 9, 9], - [9, 9, 9, 9, 1], - [9, 9, 1, 1, 1], - ]) - right_pad_output = torch.LongTensor([ - [2, 3, 4, 5, 6], - [2, 3, 4, 5, 1], - [2, 3, 1, 1, 1], - ]) + left_pad_input = torch.LongTensor( + [ + [9, 9, 9, 9, 9], + [1, 9, 9, 9, 9], + [1, 1, 1, 9, 9], + ] + ) + left_pad_output = torch.LongTensor( + [ + [2, 3, 4, 5, 6], + [1, 2, 3, 4, 5], + [1, 1, 1, 2, 3], + ] + ) + right_pad_input = torch.LongTensor( + [ + [9, 9, 9, 9, 9], + [9, 9, 9, 9, 1], + [9, 9, 1, 1, 1], + ] + ) + right_pad_output = torch.LongTensor( + [ + [2, 3, 4, 5, 6], + [2, 3, 4, 5, 1], + [2, 3, 1, 1, 1], + ] + ) self.assertAlmostEqual( left_pad_output, @@ -82,9 +92,9 @@ def test_clip_grad_norm_(self): params = [torch.nn.Parameter(torch.zeros(5)) for i in range(3)] for p in params: - p.grad = torch.full((5,), fill_value=2.) 
+ p.grad = torch.full((5,), fill_value=2.0) grad_norm = utils.clip_grad_norm_(params, 1.0) - exp_grad_norm = torch.full((15,), fill_value=2.).norm() + exp_grad_norm = torch.full((15,), fill_value=2.0).norm() self.assertTrue(torch.is_tensor(grad_norm)) self.assertEqual(grad_norm, exp_grad_norm) @@ -100,5 +110,5 @@ def assertAlmostEqual(self, t1, t2): self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 44a35fdccf..91feca6b2a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,10 +7,10 @@ import os import random import sys +from io import StringIO + import torch import torch.nn.functional as F - -from io import StringIO from fairseq import options, utils from fairseq.data import Dictionary from fairseq.data.language_pair_dataset import collate @@ -20,18 +20,11 @@ FairseqIncrementalDecoder, ) from fairseq.models.fairseq_encoder import EncoderOut -from fairseq.tasks import LegacyFairseqTask -from fairseq.tasks import FairseqTask -from fairseq_cli import ( - generate, - interactive, - preprocess, - train, - validate, -) +from fairseq.tasks import FairseqTask, LegacyFairseqTask +from fairseq_cli import generate, interactive, preprocess, train, validate -def dummy_dictionary(vocab_size, prefix='token_'): +def dummy_dictionary(vocab_size, prefix="token_"): d = Dictionary() for i in range(vocab_size): token = prefix + str(i) @@ -51,8 +44,8 @@ def dummy_dataloader( # add any missing data to samples for i, sample in enumerate(samples): - if 'id' not in sample: - sample['id'] = i + if "id" not in sample: + sample["id"] = i # create dataloader dataset = TestDataset(samples) @@ -77,48 +70,86 @@ def sequence_generator_setup(): src_lengths = torch.LongTensor([2, 2]) args = argparse.Namespace() - unk = 0. 
+ unk = 0.0 args.beam_probs = [ # step 0: - torch.FloatTensor([ - # eos w1 w2 - # sentence 1: - [0.0, unk, 0.9, 0.1], # beam 1 - [0.0, unk, 0.9, 0.1], # beam 2 - # sentence 2: - [0.0, unk, 0.7, 0.3], - [0.0, unk, 0.7, 0.3], - ]), + torch.FloatTensor( + [ + # eos w1 w2 + # sentence 1: + [0.0, unk, 0.9, 0.1], # beam 1 + [0.0, unk, 0.9, 0.1], # beam 2 + # sentence 2: + [0.0, unk, 0.7, 0.3], + [0.0, unk, 0.7, 0.3], + ] + ), # step 1: - torch.FloatTensor([ - # eos w1 w2 prefix - # sentence 1: - [1.0, unk, 0.0, 0.0], # w1: 0.9 (emit: w1 : 0.9*1.0) - [0.0, unk, 0.9, 0.1], # w2: 0.1 - # sentence 2: - [0.25, unk, 0.35, 0.4], # w1: 0.7 (don't emit: w1 : 0.7*0.25) - [0.00, unk, 0.10, 0.9], # w2: 0.3 - ]), + torch.FloatTensor( + [ + # eos w1 w2 prefix + # sentence 1: + [1.0, unk, 0.0, 0.0], # w1: 0.9 (emit: w1 : 0.9*1.0) + [0.0, unk, 0.9, 0.1], # w2: 0.1 + # sentence 2: + [0.25, unk, 0.35, 0.4], # w1: 0.7 (don't emit: w1 : 0.7*0.25) + [0.00, unk, 0.10, 0.9], # w2: 0.3 + ] + ), # step 2: - torch.FloatTensor([ - # eos w1 w2 prefix - # sentence 1: - [0.0, unk, 0.1, 0.9], # w2 w1: 0.1*0.9 - [0.6, unk, 0.2, 0.2], # w2 w2: 0.1*0.1 (emit: w2 w2 : 0.1*0.1*0.6) - # sentence 2: - [0.60, unk, 0.4, 0.00], # w1 w2: 0.7*0.4 (emit: w1 w2 : 0.7*0.4*0.6) - [0.01, unk, 0.0, 0.99], # w2 w2: 0.3*0.9 - ]), + torch.FloatTensor( + [ + # eos w1 w2 prefix + # sentence 1: + [0.0, unk, 0.1, 0.9], # w2 w1: 0.1*0.9 + [ + 0.6, + unk, + 0.2, + 0.2, + ], # w2 w2: 0.1*0.1 (emit: w2 w2 : 0.1*0.1*0.6) + # sentence 2: + [ + 0.60, + unk, + 0.4, + 0.00, + ], # w1 w2: 0.7*0.4 (emit: w1 w2 : 0.7*0.4*0.6) + [0.01, unk, 0.0, 0.99], # w2 w2: 0.3*0.9 + ] + ), # step 3: - torch.FloatTensor([ - # eos w1 w2 prefix - # sentence 1: - [1.0, unk, 0.0, 0.0], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 : 0.1*0.9*0.9*1.0) - [1.0, unk, 0.0, 0.0], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 : 0.1*0.9*0.1*1.0) - # sentence 2: - [0.1, unk, 0.5, 0.4], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 : 0.3*0.9*0.99*0.1) - [1.0, unk, 0.0, 0.0], # w1 w2 
w1: 0.7*0.4*0.4 (emit: w1 w2 w1 : 0.7*0.4*0.4*1.0) - ]), + torch.FloatTensor( + [ + # eos w1 w2 prefix + # sentence 1: + [ + 1.0, + unk, + 0.0, + 0.0, + ], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 : 0.1*0.9*0.9*1.0) + [ + 1.0, + unk, + 0.0, + 0.0, + ], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 : 0.1*0.9*0.1*1.0) + # sentence 2: + [ + 0.1, + unk, + 0.5, + 0.4, + ], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 : 0.3*0.9*0.99*0.1) + [ + 1.0, + unk, + 0.0, + 0.0, + ], # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 : 0.7*0.4*0.4*1.0) + ] + ), ] task = TestTranslationTask.setup_task(args, d, d) @@ -132,18 +163,18 @@ def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False): def _create_dummy_data(filename): data = torch.rand(num_examples * maxlen) data = 97 + torch.floor(26 * data).int() - with open(os.path.join(data_dir, filename), 'w') as h: + with open(os.path.join(data_dir, filename), "w") as h: offset = 0 for _ in range(num_examples): ex_len = random.randint(1, maxlen) - ex_str = ' '.join(map(chr, data[offset:offset+ex_len])) + ex_str = " ".join(map(chr, data[offset : offset + ex_len])) print(ex_str, file=h) offset += ex_len def _create_dummy_alignment_data(filename_src, filename_tgt, filename): - with open(os.path.join(data_dir, filename_src), 'r') as src_f, \ - open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \ - open(os.path.join(data_dir, filename), 'w') as h: + with open(os.path.join(data_dir, filename_src), "r") as src_f, open( + os.path.join(data_dir, filename_tgt), "r" + ) as tgt_f, open(os.path.join(data_dir, filename), "w") as h: for src, tgt in zip(src_f, tgt_f): src_len = len(src.split()) tgt_len = len(tgt.split()) @@ -151,31 +182,42 @@ def _create_dummy_alignment_data(filename_src, filename_tgt, filename): num_alignments = random.randint(avg_len // 2, 2 * avg_len) src_indices = torch.floor(torch.rand(num_alignments) * src_len).int() tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int() - ex_str = ' 
'.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)]) + ex_str = " ".join( + [ + "{}-{}".format(src, tgt) + for src, tgt in zip(src_indices, tgt_indices) + ] + ) print(ex_str, file=h) - _create_dummy_data('train.in') - _create_dummy_data('train.out') - _create_dummy_data('valid.in') - _create_dummy_data('valid.out') - _create_dummy_data('test.in') - _create_dummy_data('test.out') + _create_dummy_data("train.in") + _create_dummy_data("train.out") + _create_dummy_data("valid.in") + _create_dummy_data("valid.out") + _create_dummy_data("test.in") + _create_dummy_data("test.out") if alignment: - _create_dummy_alignment_data('train.in', 'train.out', 'train.align') - _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align') - _create_dummy_alignment_data('test.in', 'test.out', 'test.align') + _create_dummy_alignment_data("train.in", "train.out", "train.align") + _create_dummy_alignment_data("valid.in", "valid.out", "valid.align") + _create_dummy_alignment_data("test.in", "test.out", "test.align") def preprocess_lm_data(data_dir): preprocess_parser = options.get_preprocessing_parser() - preprocess_args = preprocess_parser.parse_args([ - '--only-source', - '--trainpref', os.path.join(data_dir, 'train.out'), - '--validpref', os.path.join(data_dir, 'valid.out'), - '--testpref', os.path.join(data_dir, 'test.out'), - '--destdir', data_dir, - ]) + preprocess_args = preprocess_parser.parse_args( + [ + "--only-source", + "--trainpref", + os.path.join(data_dir, "train.out"), + "--validpref", + os.path.join(data_dir, "valid.out"), + "--testpref", + os.path.join(data_dir, "test.out"), + "--destdir", + data_dir, + ] + ) preprocess.main(preprocess_args) @@ -183,15 +225,24 @@ def preprocess_translation_data(data_dir, extra_flags=None): preprocess_parser = options.get_preprocessing_parser() preprocess_args = preprocess_parser.parse_args( [ - '--source-lang', 'in', - '--target-lang', 'out', - '--trainpref', os.path.join(data_dir, 'train'), - 
'--validpref', os.path.join(data_dir, 'valid'), - '--testpref', os.path.join(data_dir, 'test'), - '--thresholdtgt', '0', - '--thresholdsrc', '0', - '--destdir', data_dir, - ] + (extra_flags or []), + "--source-lang", + "in", + "--target-lang", + "out", + "--trainpref", + os.path.join(data_dir, "train"), + "--validpref", + os.path.join(data_dir, "valid"), + "--testpref", + os.path.join(data_dir, "test"), + "--thresholdtgt", + "0", + "--thresholdsrc", + "0", + "--destdir", + data_dir, + ] + + (extra_flags or []), ) preprocess.main(preprocess_args) @@ -200,43 +251,72 @@ def preprocess_summarization_data(data_dir, extra_flags=None): preprocess_parser = options.get_preprocessing_parser() preprocess_args = preprocess_parser.parse_args( [ - '--source-lang', 'in', - '--target-lang', 'out', - '--trainpref', os.path.join(data_dir, 'train'), - '--validpref', os.path.join(data_dir, 'valid'), - '--testpref', os.path.join(data_dir, 'test'), - '--thresholdtgt', '0', - '--thresholdsrc', '0', - '--joined-dictionary', - '--destdir', data_dir, - ] + (extra_flags or []), + "--source-lang", + "in", + "--target-lang", + "out", + "--trainpref", + os.path.join(data_dir, "train"), + "--validpref", + os.path.join(data_dir, "valid"), + "--testpref", + os.path.join(data_dir, "test"), + "--thresholdtgt", + "0", + "--thresholdsrc", + "0", + "--joined-dictionary", + "--destdir", + data_dir, + ] + + (extra_flags or []), ) preprocess.main(preprocess_args) -def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False, - lang_flags=None, extra_valid_flags=None): +def train_translation_model( + data_dir, + arch, + extra_flags=None, + task="translation", + run_validation=False, + lang_flags=None, + extra_valid_flags=None, +): if lang_flags is None: lang_flags = [ - '--source-lang', 'in', - '--target-lang', 'out', + "--source-lang", + "in", + "--target-lang", + "out", ] train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( 
train_parser, [ - '--task', task, + "--task", + task, data_dir, - '--save-dir', data_dir, - '--arch', arch, - '--optimizer', 'nag', - '--lr', '0.05', - '--max-tokens', '500', - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--num-workers', '0', - ] + lang_flags + (extra_flags or []), + "--save-dir", + data_dir, + "--arch", + arch, + "--optimizer", + "nag", + "--lr", + "0.05", + "--max-tokens", + "500", + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--num-workers", + "0", + ] + + lang_flags + + (extra_flags or []), ) train.main(train_args) @@ -246,14 +326,21 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation' validate_args = options.parse_args_and_arch( validate_parser, [ - '--task', task, + "--task", + task, data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - '--valid-subset', 'valid', - '--max-tokens', '500', - '--no-progress-bar', - '--num-workers', '0', - ] + lang_flags + (extra_valid_flags or []) + "--path", + os.path.join(data_dir, "checkpoint_last.pt"), + "--valid-subset", + "valid", + "--max-tokens", + "500", + "--no-progress-bar", + "--num-workers", + "0", + ] + + lang_flags + + (extra_valid_flags or []), ) validate.main(validate_args) @@ -261,21 +348,28 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation' def generate_main(data_dir, extra_flags=None): if extra_flags is None: extra_flags = [ - '--print-alignment', + "--print-alignment", ] generate_parser = options.get_generation_parser() generate_args = options.parse_args_and_arch( generate_parser, [ data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - '--beam', '3', - '--batch-size', '64', - '--max-len-b', '5', - '--gen-subset', 'valid', - '--no-progress-bar', - '--num-workers', '0', - ] + (extra_flags or []), + "--path", + os.path.join(data_dir, "checkpoint_last.pt"), + "--beam", + "3", + "--batch-size", + "64", + "--max-len-b", + "5", 
+ "--gen-subset", + "valid", + "--no-progress-bar", + "--num-workers", + "0", + ] + + (extra_flags or []), ) # evaluate model in batch mode @@ -283,16 +377,15 @@ def generate_main(data_dir, extra_flags=None): # evaluate model interactively generate_args.buffer_size = 0 - generate_args.input = '-' + generate_args.input = "-" generate_args.batch_size = None orig_stdin = sys.stdin - sys.stdin = StringIO('h e l l o\n') + sys.stdin = StringIO("h e l l o\n") interactive.main(generate_args) sys.stdin = orig_stdin class TestDataset(torch.utils.data.Dataset): - def __init__(self, data): super().__init__() self.data = data @@ -306,7 +399,6 @@ def __len__(self): class TestTranslationTask(LegacyFairseqTask): - def __init__(self, args, src_dict, tgt_dict, model): super().__init__(args) self.src_dict = src_dict @@ -369,8 +461,8 @@ def reorder_encoder_out(self, encoder_out, new_order): class TestIncrementalDecoder(FairseqIncrementalDecoder): def __init__(self, args, dictionary): super().__init__(dictionary) - assert hasattr(args, 'beam_probs') or hasattr(args, 'probs') - args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100) + assert hasattr(args, "beam_probs") or hasattr(args, "probs") + args.max_decoder_positions = getattr(args, "max_decoder_positions", 100) self.args = args def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): @@ -384,18 +476,19 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): # determine number of steps if incremental_state is not None: # cache step number - step = utils.get_incremental_state(self, incremental_state, 'step') + step = utils.get_incremental_state(self, incremental_state, "step") if step is None: step = 0 - utils.set_incremental_state(self, incremental_state, 'step', step + 1) + utils.set_incremental_state(self, incremental_state, "step", step + 1) steps = [step] else: steps = list(range(tgt_len)) # define output in terms of raw probs - if hasattr(self.args, 
'probs'): - assert self.args.probs.dim() == 3, \ - 'expected probs to have size bsz*steps*vocab' + if hasattr(self.args, "probs"): + assert ( + self.args.probs.dim() == 3 + ), "expected probs to have size bsz*steps*vocab" probs = self.args.probs.index_select(1, torch.LongTensor(steps)) else: probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_() @@ -403,7 +496,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): # args.beam_probs gives the probability for every vocab element, # starting with eos, then unknown, and then the rest of the vocab if step < len(self.args.beam_probs): - probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step] + probs[:, i, self.dictionary.eos() :] = self.args.beam_probs[step] else: probs[:, i, self.dictionary.eos()] = 1.0 @@ -475,8 +568,8 @@ def __init__(self, args, dictionary): self.args = args def forward(self, src_tokens, src_lengths=None, **kwargs): - assert 'fancy_other_input' in kwargs - assert kwargs['fancy_other_input'] is not None + assert "fancy_other_input" in kwargs + assert kwargs["fancy_other_input"] is not None return EncoderOut( encoder_out=src_tokens, encoder_padding_mask=None, @@ -508,8 +601,8 @@ def build_model(cls, args, task): return cls(encoder, decoder) def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): - encoder_out = self.encoder( - src_tokens, src_lengths=src_lengths, **kwargs) + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) decoder_out = self.decoder( - prev_output_tokens, encoder_out=encoder_out, **kwargs) + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) return decoder_out diff --git a/train.py b/train.py index 3967ef48f3..321de3d9b5 100644 --- a/train.py +++ b/train.py @@ -10,5 +10,5 @@ from fairseq_cli.train import cli_main -if __name__ == '__main__': +if __name__ == "__main__": cli_main() From 65e11a37d5f5660bc5a02d4779c16afb0101ec54 Mon Sep 17 00:00:00 2001 From: Shruti Bhosale Date: Mon, 19 Oct 2020 
06:10:36 -0700 Subject: [PATCH 223/707] Readme with instructions to generate and evaluate with a 12B model (#1351) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1351 Reviewed By: edunov Differential Revision: D24386349 Pulled By: huihuifan fbshipit-source-id: ade362d7cb64e24e6b2689ba87c53636073d2246 --- examples/m2m_100/README.md | 213 +++++++++++++++++- .../m2m_100/process_data/clean_histogram.py | 52 +++++ examples/m2m_100/process_data/dedup_data.py | 91 ++++++++ .../process_data/remove_too_much_punc.py | 38 ++++ examples/m2m_100/tokenizers/README.md | 18 ++ 5 files changed, 401 insertions(+), 11 deletions(-) create mode 100644 examples/m2m_100/process_data/clean_histogram.py create mode 100644 examples/m2m_100/process_data/dedup_data.py create mode 100644 examples/m2m_100/process_data/remove_too_much_punc.py create mode 100644 examples/m2m_100/tokenizers/README.md diff --git a/examples/m2m_100/README.md b/examples/m2m_100/README.md index d2892cb2d6..a87c0f5748 100644 --- a/examples/m2m_100/README.md +++ b/examples/m2m_100/README.md @@ -1,18 +1,209 @@ -# MMMT Tokenizer +# Beyond English-Centric Multilingual Machine Translation -We apply different tokenization strategies for different languages following the existing literature. Here we provide tok.sh a tokenizer that can be used to reproduce our results. +## Introduction +In this work, we create a true Many-to-Many multilingual translation model that can translate directly between any pair of 100 languages. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly translating between non-English directions while performing competitively with the best single systems of WMT. -To reproduce the results, follow these steps: +If you are new to using fairseq, read the following walkthrough. Otherwise, skip to the sections below. +0. **Generation Data** + +To download the generation data, follow the below commands. 
Note that all datasets need to be detokenized *before* applying SPM in the data preprocessing step. If you use these evaluation datasets, please cite their associated papers. +```bash +# WMT - use sacrebleu, example here: +sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr +sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en + +# WAT +wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2019.my-en.zip +unzip wat2019.my-en.zip + +# FLORES +# download from: https://github.com/facebookresearch/flores + +# TED - need to detokenize with Moses! +# from: https://github.com/neulab/word-embeddings-for-nmt +wget http://phontron.com/data/ted_talks.tar.gz + +# Autshumato +# request to download: https://repo.sadilar.org/handle/20.500.12185/397 + +# Tatoeba Challenge +# available here: https://github.com/Helsinki-NLP/Tatoeba-Challenge ``` -tgt_lang=... -reference_translation=... -cat generation_output | grep -P "^H" |sort -V |cut -f 3- |sh tok.sh $tgt_lang > hyp -cat $reference_translation |sh tok.sh $tgt_lang > ref -sacrebleu -tok 'none' ref < hyp + +1. **Training Data** + +To produce the training data, we use a combination of [CCMatrix](https://arxiv.org/abs/1911.04944) and [CCAligned](https://arxiv.org/abs/1911.06154). Check out the instructions [here](https://github.com/facebookresearch/LASER/tree/master/tasks/CCMatrix) to download the raw data. + +2. **Preprocess Data** + +After downloading raw data, you will need to postprocess the data, then apply SPM, then binarize. Note that it is very important you run the postprocessing script, because this removes any instance of the evaluation data in the mined training data. 
+ +```bash +# preprocess data + +# remove sentences with more than 50% punctuation +python /path/to/fairseq/examples/m2m_100/process_data/remove_too_much_punc.py + +# deduplicate training data +paste /path/to/datadir/train.$src /path/to/datadir/train.$tgt | awk '!x[$0]++' > /path/to/datadir/train.dedup +echo "keeping $(wc -l /path/to/datadir/train.dedup) bitext out of $(wc -l /path/to/datadir/train.$src)" +cut -f1 /path/to/datadir/train.dedup > /path/to/datadir/train.$src +cut -f2 /path/to/datadir/train.dedup > /path/to/datadir/train.$tgt + +# remove all instances of evaluation data from the training data +python /path/to/fairseq/examples/m2m_100/process_data/dedup_data.py + +# frequency cleaning +wget https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz +tar -xvzf histograms.tar.gz +python /path/to/fairseq/examples/m2m_100/process_data/clean_histogram.py --src $src --tgt $tgt --src-file /path/to/source/file --tgt-file /path/to/output/file --src-output-file source_output.$src --tgt-output-file target_output.$tgt --histograms /path/to/histograms + +# apply SPM +wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model +python /path/to/fairseq/scripts/spm_encode.py \ + --model spm.128k.model \ + --output_format=piece \ + --inputs=/path/to/input/file/here \ + --outputs=/path/to/output/file/here + +# length ratio cleaning +perl mosesdecoder/scripts/training/clean-corpus-n.perl --ratio 3 /path/to/training/data/train.spm.$src-$tgt $src $tgt /path/to/output/directory/train.spm.$src-$tgt 1 250 + +# binarize data +wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt +fairseq-preprocess \ + --source-lang $src --target-lang $tgt \ + --testpref spm.$src.$tgt \ + --thresholdsrc 0 --thresholdtgt 0 \ + --destdir data_bin \ + --srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt ``` -# Installation +3. 
**Training Scripts** + +To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/master/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale). + +4. **Generation** + +To generate from our models, follow the commands in the generation section below. + + +If you use any of the resources listed here, please cite: +```bibtex +@article{fan2020beyond, + title={Beyond English-Centric Multilingual Machine Translation}, + author={Fan, Angela and Bhosale, Shruti and Schwenk, Holger and Ma, Zhiyi and El-Kishky, Ahmed and Goyal, Siddharth and Baines, Mandeep and Celebi, Onur and Wenzek, Guillaume and Chaudhary, Vishrav and Goyal, Naman and Birch, Tom and Liptchinsky, Vitaliy and Edunov, Sergey and Grave, Edouard and Auli, Michael and Joulin, Armand}, + journal={arXiv preprint}, + year={2020} +} + +@article{schwenk2019ccmatrix, + title={Ccmatrix: Mining billions of high-quality parallel sentences on the web}, + author={Schwenk, Holger and Wenzek, Guillaume and Edunov, Sergey and Grave, Edouard and Joulin, Armand}, + journal={arXiv preprint arXiv:1911.04944}, + year={2019} +} + +@article{el2019massive, + title={A Massive Collection of Cross-Lingual Web-Document Pairs}, + author={El-Kishky, Ahmed and Chaudhary, Vishrav and Guzman, Francisco and Koehn, Philipp}, + journal={arXiv preprint arXiv:1911.06154}, + year={2019} +} +``` -Tools needed for all the languages except Arabic can be installed by running install_dependencies.sh -If you want to evaluate Arabic models, please follow the instructions provided here: http://alt.qcri.org/tools/arabic-normalizer/ to install + +## Trained Models + +Looking for other trained models? Check back soon.
+ +Model | Description | Download +---|---|--- +`12b_last_checkpoint` | 12B parameter model trained on many-to-many training data for 100 languages | [12b_last_checkpoint](https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt) + + +## SentencePiece Model + +```bash +wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model +``` + +## Generation with M2M-100 + +### Encode using our SentencePiece Model + +Note: Install SentencePiece from [here](https://github.com/google/sentencepiece) + +```bash +fairseq=/path/to/fairseq +cd $fairseq +sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de +sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr +wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model +for lang in de fr ; do + python scripts/spm_encode.py \ + --model spm.128k.model \ + --output_format=piece \ + --inputs=raw_input.de-fr.${lang} \ + --outputs=spm.de-fr.${lang} +done +``` + +### Binarization + +```bash +wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt +fairseq-preprocess \ + --source-lang de --target-lang fr \ + --testpref spm.de-fr \ + --thresholdsrc 0 --thresholdtgt 0 \ + --destdir data_bin \ + --srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt +``` + +### Generation on a V100 GPU + +```bash +wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt +wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs.txt +wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt +fairseq-generate \ + data_bin \ + --batch-size 1 \ + --path 12b_last_checkpoint.pt \ + --fixed-dictionary model_dict.128k.txt \ + -s de -t fr \ + --remove-bpe 'sentencepiece' \ + --beam 5 \ + --task translation_multi_simple_epoch \ + --lang-pairs language_pairs.txt \ + --decoder-langtok --encoder-langtok src \ + --gen-subset test \ + --fp16 \ + --dataset-impl mmap \ + --distributed-world-size 1 --distributed-no-spawn \ + --pipeline-model-parallel \ + --pipeline-chunks 1 \ + --pipeline-encoder-balance 
'[26]' \ + --pipeline-encoder-devices '[0]' \ + --pipeline-decoder-balance '[1,24,1]' \ + --pipeline-decoder-devices '[0,1,0]' > gen_out +``` +## Evaluation with M2M-100 + +### Tokenization + +Note: Refer to tokenizers/README.md for more details on tokenization. + +```bash +cd ${fairseq}/examples/m2m_100 +cat ${fairseq}/gen_out | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh fr > hyp +cat ${fairseq}/raw_input.de-fr.fr | sh tok.sh fr > ref +``` + +### BLEU + +```bash +sacrebleu -tok 'none' ref < hyp +``` diff --git a/examples/m2m_100/process_data/clean_histogram.py b/examples/m2m_100/process_data/clean_histogram.py new file mode 100644 index 0000000000..e24e073dc0 --- /dev/null +++ b/examples/m2m_100/process_data/clean_histogram.py @@ -0,0 +1,52 @@ +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--src', type=str, help='Source language') +parser.add_argument('--tgt', type=str, help='Target language') +parser.add_argument('--src-file', type=str, help='Input source file') +parser.add_argument('--tgt-file', type=str, help='Input target file') +parser.add_argument('--src-output-file', type=str, help='Output source file') +parser.add_argument('--tgt-output-file', type=str, help='Output target file') +parser.add_argument('--threshold', type=float, default=0.5, help='Threshold') +parser.add_argument('--threshold-character', type=str, default=']', help='Threshold character') +parser.add_argument('--histograms', type=str, help='Path to histograms') + +args = parser.parse_args() + + +def read_hist(f): + ch = [] + for line in f: + c = line[0] + if c == args.threshold_character: + break + ch.append(c) + return ch + + +with(open("{}/{}".format(args.histograms, args.src), 'r', encoding='utf8')) as f: + ch1 = read_hist(f) + +with(open("{}/{}".format(args.histograms, args.tgt), 'r', encoding='utf8')) as f: + ch2 = read_hist(f) + +print("Accepted characters for {}: {}".format(args.src, ch1)) +print("Accepted characters for {}: {}".format(args.tgt, ch2)) 
+ +with open(args.src_file, 'r', encoding='utf8') as fs1, open(args.tgt_file, 'r', encoding='utf8') as fs2, open(args.src_output_file, 'w', encoding='utf8') as fos1, open(args.tgt_output_file, 'w', encoding='utf8') as fos2: + ls1 = fs1.readline() + ls2 = fs2.readline() + + while ls1 or ls2: + cnt1 = len([c for c in ls1.strip() if c in ch1]) + cnt2 = len([c for c in ls2.strip() if c in ch2]) + + if cnt1 / len(ls1) > args.threshold and cnt2 / len(ls2) > args.threshold: + fos1.write(ls1) + fos2.write(ls2) + else: + print("{} {} {} \n{} {} {}".format(args.src, cnt1 / len(ls1), ls1.strip(), args.tgt, cnt2 / len(ls2), ls2.strip())) + + ls1 = fs1.readline() + ls2 = fs2.readline() + \ No newline at end of file diff --git a/examples/m2m_100/process_data/dedup_data.py b/examples/m2m_100/process_data/dedup_data.py new file mode 100644 index 0000000000..58d9ed1cd1 --- /dev/null +++ b/examples/m2m_100/process_data/dedup_data.py @@ -0,0 +1,91 @@ +import argparse +from collections import namedtuple +import os + +DATADIR = "/path/to/train_data" +DEDUP_FROM_DIR = "/path/to/eval/data" +OUTPUT_DIR = "/path/to/output/data" + + +def main(args): + languages = set() + for language_directory in os.listdir(DATADIR): + if "_" in language_directory: + src, tgt = language_directory.split("_") + languages.add(LanguagePair(src=src, tgt=tgt)) + + data = existing_data() + train_languages = sorted(languages) + for language_pair in train_languages[args.start_index:args.start_index + args.size]: + print(language_pair) + dedup(language_pair, data) + + +LanguagePair = namedtuple("LanguagePair", ["src", "tgt"]) + + +def existing_data(): + data = set() + for file in os.listdir(DEDUP_FROM_DIR): + with open(os.path.join(DEDUP_FROM_DIR, file)) as f: + data |= set(f.readlines()) + return data + +def dedup(language_pair, data, verbose=True, output=True): + train_filenames = LanguagePair( + src=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.src}", + 
tgt=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.tgt}", + ) + + output_filenames = LanguagePair( + src=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.src}", + tgt=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.tgt}" + ) + + # If output exists, skip this pair. It has already been done. + if (os.path.exists(output_filenames.src) and + os.path.exists(output_filenames.tgt)): + if verbose: + print(f"{language_pair.src}-{language_pair.tgt} already done.") + return + + if verbose: + print(f"{language_pair.src}-{language_pair.tgt} ready, will check dups.") + + # If there is no output, no need to actually do the loop. + if not output: + return + + if os.path.exists(train_filenames.src) and os.path.exists(train_filenames.tgt): + with open(train_filenames.src) as f: + train_source = f.readlines() + + with open(train_filenames.tgt) as f: + train_target = f.readlines() + + # do dedup + new_train_source = [] + new_train_target = [] + for i, train_line in enumerate(train_source): + if train_line not in data and train_target[i] not in data: + new_train_source.append(train_line) + new_train_target.append(train_target[i]) + + assert len(train_source) == len(train_target) + assert len(new_train_source) == len(new_train_target) + assert len(new_train_source) <= len(train_source) + + with open(output_filenames.src, "w") as o: + for line in new_train_source: + o.write(line) + + with open(output_filenames.tgt, "w") as o: + for line in new_train_target: + o.write(line) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("-s", "--start-index", required=True, type=int) + parser.add_argument("-n", "--size", required=True, type=int) + main(parser.parse_args()) diff --git a/examples/m2m_100/process_data/remove_too_much_punc.py b/examples/m2m_100/process_data/remove_too_much_punc.py new file mode 100644 index 0000000000..cd5842d919 --- /dev/null +++ 
b/examples/m2m_100/process_data/remove_too_much_punc.py @@ -0,0 +1,38 @@ +import gzip +import argparse +from string import punctuation + +def len_no_punc(s, punc): + return len([ch for ch in s if ch in punc]) + +def filter_overpunc(len_npunc, len_sen): + return len_npunc < 0.5*len_sen + +def main(args): + punc = punctuation + "—|–" + print('Processing file {}'.format(args.input)) + with gzip.open(args.input, 'rt', encoding=args.encoding) as tsv: + with open(args.bitext + '.' + args.src_lang, 'wt', encoding=args.encoding) as fsrc: + with open(args.bitext + '.' + args.tgt_lang, 'wt', encoding=args.encoding) as ftgt: + line = tsv.readline() + if not line: + continue + fields = line.split('\t') + + src, tgt = fields[1], fields[2] + + nchar_npunc_src = len_no_punc(src, punc) + nchar_npunc_tgt = len_no_punc(tgt, punc) + + if filter_overpunc(nchar_npunc_src, len(src)) and filter_overpunc(nchar_npunc_tgt, len(tgt)): + fsrc.write(src.strip() + '\n') + ftgt.write(tgt.strip() + '\n') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--input", required=True, type=str) + parser.add_argument('--encoding', default='utf-8', help='character encoding for input/output') + parser.add_argument('--bitext', type=str, required=True, help='language direction') + parser.add_argument('--src-lang', type=str, required=True, help='Source language') + parser.add_argument('--tgt-lang', type=str, required=True, help='Target language') + main(parser.parse_args()) diff --git a/examples/m2m_100/tokenizers/README.md b/examples/m2m_100/tokenizers/README.md new file mode 100644 index 0000000000..e116932bc8 --- /dev/null +++ b/examples/m2m_100/tokenizers/README.md @@ -0,0 +1,18 @@ +# M2M-100 Tokenization + +We apply different tokenization strategies for different languages following the existing literature. Here we provide tok.sh a tokenizer that can be used to reproduce our results. + +To reproduce the results, follow these steps: + +``` +tgt_lang=... 
+reference_translation=... +cat generation_output | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh $tgt_lang > hyp +cat $reference_translation |sh tok.sh $tgt_lang > ref +sacrebleu -tok 'none' ref < hyp +``` + +## Installation + +Tools needed for all the languages except Arabic can be installed by running install_dependencies.sh +If you want to evaluate Arabic models, please follow the instructions provided here: http://alt.qcri.org/tools/arabic-normalizer/ to install From e3168f74a84523415e46d848e4f4ec9a2713ad6f Mon Sep 17 00:00:00 2001 From: Angela Fan Date: Mon, 19 Oct 2020 09:09:24 -0700 Subject: [PATCH 224/707] minor fix for linter (#1360) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1360 Reviewed By: myleott Differential Revision: D24393217 Pulled By: huihuifan fbshipit-source-id: a110ef6958b1e15cd8c4e23b610db5cfc994f06d --- examples/m2m_100/process_data/remove_too_much_punc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/m2m_100/process_data/remove_too_much_punc.py b/examples/m2m_100/process_data/remove_too_much_punc.py index cd5842d919..6c280de240 100644 --- a/examples/m2m_100/process_data/remove_too_much_punc.py +++ b/examples/m2m_100/process_data/remove_too_much_punc.py @@ -15,8 +15,6 @@ def main(args): with open(args.bitext + '.' 
+ args.src_lang, 'wt', encoding=args.encoding) as fsrc: with open(args.bitext + '.' + args.tgt_lang, 'wt', encoding=args.encoding) as ftgt: line = tsv.readline() - if not line: - continue fields = line.split('\t') src, tgt = fields[1], fields[2] From 9b8b46407094ea1671e9ed89b6db3f57e8665536 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 19 Oct 2020 09:22:28 -0700 Subject: [PATCH 225/707] Package config and examples with fairseq (#1356) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1356 Reviewed By: alexeib Differential Revision: D24385688 Pulled By: myleott fbshipit-source-id: 72c4a702d93d2854a6409d42913d7413207cb61e --- examples/linformer/README.md | 2 +- fairseq/utils.py | 12 +++- setup.py | 129 ++++++++++++++++++++++------------- tests/test_binaries.py | 2 + 4 files changed, 94 insertions(+), 51 deletions(-) diff --git a/examples/linformer/README.md b/examples/linformer/README.md index e5c11e052d..cedd667835 100644 --- a/examples/linformer/README.md +++ b/examples/linformer/README.md @@ -6,7 +6,7 @@ This example contains code to train Linformer models as described in our paper ## Training a new Linformer RoBERTa model You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md), -but replace the architecture with `--arch linformer_roberta_base` in your training command. +updating your training command with `--user-dir examples/linformer/src --arch linformer_roberta_base`. 
## Citation diff --git a/fairseq/utils.py b/fairseq/utils.py index fdbf66cf3f..0044d76f3d 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -433,11 +433,17 @@ def import_user_module(args): if module_path is not None: module_path = os.path.abspath(args.user_dir) if not os.path.exists(module_path): - fairseq_rel_path = os.path.join( - os.path.dirname(__file__), "..", args.user_dir - ) + fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) if os.path.exists(fairseq_rel_path): module_path = fairseq_rel_path + else: + fairseq_rel_path = os.path.join( + os.path.dirname(__file__), "..", args.user_dir + ) + if os.path.exists(fairseq_rel_path): + module_path = fairseq_rel_path + else: + raise FileNotFoundError(module_path) # We want to import the module under a unique name so that it doesn't # collide with existing modules. At the same time we don't want to diff --git a/setup.py b/setup.py index ad2ea2088b..54c752d257 100644 --- a/setup.py +++ b/setup.py @@ -127,51 +127,86 @@ def include_dirs(self, dirs): ) -setup( - name="fairseq", - version="0.9.0", - description="Facebook AI Research Sequence-to-Sequence Toolkit", - url="https://github.com/pytorch/fairseq", - classifiers=[ - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3.6", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - ], - long_description=readme, - long_description_content_type="text/markdown", - setup_requires=[ - "cython", - "numpy", - "setuptools>=18.0", - ], - install_requires=[ - "cffi", - "cython", - "dataclasses", - "editdistance", - "hydra-core", - "numpy", - "regex", - "sacrebleu>=1.4.12", - "torch", - "tqdm", - ], - dependency_links=dependency_links, - packages=find_packages(exclude=["scripts", "tests"]), - ext_modules=extensions, - test_suite="tests", - entry_points={ - "console_scripts": [ - "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", - "fairseq-generate = 
fairseq_cli.generate:cli_main", - "fairseq-interactive = fairseq_cli.interactive:cli_main", - "fairseq-preprocess = fairseq_cli.preprocess:cli_main", - "fairseq-score = fairseq_cli.score:cli_main", - "fairseq-train = fairseq_cli.train:cli_main", - "fairseq-validate = fairseq_cli.validate:cli_main", +def do_setup(package_data): + setup( + name="fairseq", + version="0.9.0", + description="Facebook AI Research Sequence-to-Sequence Toolkit", + url="https://github.com/pytorch/fairseq", + classifiers=[ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.6", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], - }, - cmdclass=cmdclass, - zip_safe=False, -) + long_description=readme, + long_description_content_type="text/markdown", + setup_requires=[ + "cython", + "numpy", + "setuptools>=18.0", + ], + install_requires=[ + "cffi", + "cython", + "dataclasses", + "editdistance", + "hydra-core", + "numpy", + "regex", + "sacrebleu>=1.4.12", + "torch", + "tqdm", + ], + dependency_links=dependency_links, + packages=find_packages( + exclude=[ + "examples", + "examples.*", + "scripts", + "scripts.*", + "tests", + "tests.*", + ] + ), + package_data=package_data, + ext_modules=extensions, + test_suite="tests", + entry_points={ + "console_scripts": [ + "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", + "fairseq-generate = fairseq_cli.generate:cli_main", + "fairseq-interactive = fairseq_cli.interactive:cli_main", + "fairseq-preprocess = fairseq_cli.preprocess:cli_main", + "fairseq-score = fairseq_cli.score:cli_main", + "fairseq-train = fairseq_cli.train:cli_main", + "fairseq-validate = fairseq_cli.validate:cli_main", + ], + }, + cmdclass=cmdclass, + zip_safe=False, + ) + + +def get_files(path, relative_to="fairseq"): + all_files = [] + for root, _dirs, files in os.walk(path, followlinks=True): + root = os.path.relpath(root, relative_to) + for file in files: + if file.endswith(".pyc"): + continue + 
all_files.append(os.path.join(root, file)) + return all_files + + +try: + # symlink config and examples into fairseq package so package_data accepts them + os.symlink(os.path.join("..", "config"), "fairseq/config") + os.symlink(os.path.join("..", "examples"), "fairseq/examples") + package_data = { + "fairseq": get_files("fairseq/config") + get_files("fairseq/examples"), + } + do_setup(package_data) +finally: + os.unlink("fairseq/config") + os.unlink("fairseq/examples") diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 4b87afea55..c6722402a1 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -7,6 +7,7 @@ import logging import os import random +import sys import tempfile import unittest from io import StringIO @@ -294,6 +295,7 @@ def test_multilingual_transformer(self): + dec_ltok_flag, ) + @unittest.skipIf(sys.platform.lower() == "darwin", "skip latent depth test on MacOS") def test_multilingual_translation_latent_depth(self): # test with latent depth in encoder, decoder, or both encoder_latent_layer = [[], ["--encoder-latent-layer"]] From de5c2cb35aa57b7d95d75b574d937141707db0ea Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 19 Oct 2020 14:13:23 -0700 Subject: [PATCH 226/707] Fix model parallel LM (#1358) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1358 Reviewed By: alexeib Differential Revision: D24393064 Pulled By: myleott fbshipit-source-id: ee88fd1e7b203d7df6b7a65d3b1b1469e8fe9b6e --- fairseq/model_parallel/models/transformer_lm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index ed378c4320..5db6efb7b1 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -22,6 +22,11 @@ @register_model("model_parallel_transformer_lm") class ModelParallelTransformerLanguageModel(TransformerLanguageModel): + + @staticmethod + def 
add_args(parser): + TransformerLanguageModel.add_args(parser) + @classmethod def build_model(cls, args, task): """Build a new model instance.""" From c76cb6dfb93b531369a0a7593227b31c3b99c0a3 Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Mon, 19 Oct 2020 20:15:47 -0700 Subject: [PATCH 227/707] composite criterion should still use legacy criterion as it will break with subsequent diff Summary: see title Reviewed By: myleott Differential Revision: D24393903 fbshipit-source-id: 4b972b8150c7228fb32977675c6c60b13d5194d0 --- fairseq/criterions/composite_loss.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/fairseq/criterions/composite_loss.py b/fairseq/criterions/composite_loss.py index 65341c2d3b..98e835fa6e 100644 --- a/fairseq/criterions/composite_loss.py +++ b/fairseq/criterions/composite_loss.py @@ -4,18 +4,18 @@ # LICENSE file in the root directory of this source tree. from fairseq import utils -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import LegacyFairseqCriterion, register_criterion from torch import nn @register_criterion("composite_loss") -class CompositeLoss(FairseqCriterion): +class CompositeLoss(LegacyFairseqCriterion): """This is a composite loss that, given a list of model outputs and a list of targets, computes an average of losses for each output-target pair""" - def __init__(self, task, underlying_criterion): - super().__init__(task) - self.underlying_criterion = underlying_criterion + def __init__(self, args, task): + super().__init__(args, task) + self.underlying_criterion = args.underlying_criterion @staticmethod def add_args(parser): @@ -60,9 +60,9 @@ def get_targets(self, *unused): def decoder(self): return self.model.decoder - class _CompositeLoss(FairseqCriterion): - def __init__(self, task, underlying_criterion): - super().__init__(task) + class _CompositeLoss(LegacyFairseqCriterion): + def __init__(self, args, task, underlying_criterion): + 
super().__init__(args, task) self.underlying_criterion = underlying_criterion def forward(self, model, sample, reduce=True): @@ -97,4 +97,4 @@ def aggregate_logging_outputs(logging_outputs): def reduce_metrics(logging_outputs) -> None: underlying_criterion.__class__.reduce_metrics(logging_outputs) - return _CompositeLoss(task, underlying_criterion) + return _CompositeLoss(args, task, underlying_criterion) From 3b27ed7996b0315f471c795cf9b7dfcc18467cbe Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 20 Oct 2020 00:31:00 -0700 Subject: [PATCH 228/707] Enable Hydra configs in fairseq (#1343) (#1510) Summary: Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1510 this is the main pr that switches on hydra functionality in fairseq we migrate "args" object into omegaconf "DictConfig" at all legacy entry points in addition this migrates various components from secondary registries (like bpe encoders and tokenizers) to make the migration smoother i am going through code that references migrated fairseq components and changing it to inherit from "Legacy*" components instead. 
hopefully tests will catch most of this Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1343 Reviewed By: myleott Differential Revision: D23973928 Pulled By: alexeib fbshipit-source-id: dd9554981fff51ea75c1ff343874d1d6e61793c9 --- config/config.yaml | 116 ++++- config/config_eval_lm.yaml | 7 - config/criterion/adaptive_loss.yaml | 4 +- config/criterion/cross_entropy.yaml | 3 +- config/params/eval_lm_params.yaml | 105 ---- config/params/training_params.yaml | 95 ---- docs/hydra_integration.md | 26 +- docs/tutorial_classifying_names.rst | 2 +- examples/noisychannel/rerank.py | 8 +- examples/roberta/wsc/wsc_criterion.py | 4 +- .../unsupervised_quality_estimation/README.md | 2 +- fairseq/checkpoint_utils.py | 249 +++++---- fairseq/criterions/__init__.py | 6 +- fairseq/criterions/adaptive_loss.py | 12 +- fairseq/criterions/cross_entropy.py | 2 +- fairseq/criterions/ctc.py | 18 +- fairseq/criterions/fairseq_criterion.py | 11 +- fairseq/data/encoders/byte_bpe.py | 23 +- fairseq/data/encoders/bytes.py | 2 +- fairseq/data/encoders/characters.py | 2 +- fairseq/data/encoders/fastbpe.py | 23 +- fairseq/data/encoders/gpt2_bpe.py | 36 +- fairseq/data/encoders/hf_bert_bpe.py | 32 +- fairseq/data/encoders/hf_byte_bpe.py | 31 +- fairseq/data/encoders/moses_tokenizer.py | 48 +- fairseq/data/encoders/nltk_tokenizer.py | 2 +- fairseq/data/encoders/sentencepiece_bpe.py | 23 +- fairseq/data/encoders/space_tokenizer.py | 2 +- fairseq/data/encoders/subword_nmt_bpe.py | 28 +- fairseq/dataclass/constants.py | 2 + fairseq/dataclass/data_class.py | 487 +++++++++++------- fairseq/dataclass/utils.py | 174 +++++-- fairseq/distributed_utils.py | 228 ++++---- fairseq/hub_utils.py | 28 +- fairseq/model_parallel/megatron_trainer.py | 5 +- .../pipeline_parallel_transformer/model.py | 8 +- .../model_parallel/models/transformer_lm.py | 4 + fairseq/models/__init__.py | 24 +- fairseq/models/bart/hub_interface.py | 16 +- fairseq/models/bart/model.py | 4 +- 
fairseq/models/fairseq_model.py | 53 +- fairseq/models/multilingual_transformer.py | 4 +- fairseq/models/roberta/hub_interface.py | 6 +- fairseq/models/roberta/model.py | 2 +- fairseq/models/transformer.py | 11 +- fairseq/models/transformer_lm.py | 2 +- fairseq/modules/transformer_layer.py | 14 +- fairseq/optim/__init__.py | 8 +- fairseq/optim/adam.py | 25 +- fairseq/optim/bmuf.py | 23 +- fairseq/optim/fairseq_optimizer.py | 4 +- fairseq/optim/fp16_optimizer.py | 84 +-- fairseq/optim/lr_scheduler/__init__.py | 6 +- .../optim/lr_scheduler/cosine_lr_scheduler.py | 59 ++- .../lr_scheduler/fairseq_lr_scheduler.py | 4 +- .../inverse_square_root_schedule.py | 35 +- fairseq/optim/nag.py | 17 +- fairseq/optim/shard.py | 2 +- fairseq/options.py | 123 +---- fairseq/quantization_utils.py | 5 +- fairseq/registry.py | 60 +-- fairseq/scoring/__init__.py | 23 +- fairseq/scoring/bleu.py | 56 +- fairseq/scoring/tokenizer.py | 6 +- fairseq/scoring/wer.py | 45 +- fairseq/tasks/__init__.py | 13 +- fairseq/tasks/audio_pretraining.py | 2 +- fairseq/tasks/fairseq_task.py | 29 +- fairseq/tasks/language_modeling.py | 12 +- fairseq/tasks/multilingual_translation.py | 10 +- fairseq/tasks/speech_to_text.py | 4 +- fairseq/trainer.py | 170 +++--- fairseq_cli/eval_lm.py | 129 +++-- fairseq_cli/generate.py | 158 +++--- fairseq_cli/interactive.py | 105 ++-- fairseq_cli/score.py | 8 +- fairseq_cli/train.py | 211 ++++---- fairseq_cli/validate.py | 64 ++- tests/speech_recognition/asr_test_base.py | 5 +- tests/test_bmuf.py | 72 ++- tests/test_fp16_optimizer.py | 35 +- tests/test_inference_dropout.py | 10 +- tests/test_memory_efficient_fp16.py | 40 +- tests/test_train.py | 63 ++- tests/utils.py | 2 +- 85 files changed, 2037 insertions(+), 1684 deletions(-) delete mode 100644 config/config_eval_lm.yaml delete mode 100644 config/params/eval_lm_params.yaml delete mode 100644 config/params/training_params.yaml diff --git a/config/config.yaml b/config/config.yaml index 66723e706c..b9ee6c74ac 100644 --- 
a/config/config.yaml +++ b/config/config.yaml @@ -1,7 +1,111 @@ +# @package _group_ +common: + no_progress_bar: false + log_interval: 100 + log_format: null + tensorboard_logdir: null + seed: 1 + cpu: false + tpu: false + bf16: false + fp16: false + memory_efficient_fp16: false + memory_efficient_bf16: false + fp16_no_flatten_grads: false + fp16_init_scale: 128 + fp16_scale_window: null + fp16_scale_tolerance: 0.0 + min_loss_scale: 1.0e-4 + threshold_loss_scale: null + user_dir: null + empty_cache_freq: 0 + all_gather_list_size: 16384 + model_parallel_size: 1 + quantization_config_path: null + profile: false +distributed_training: + distributed_rank: 0 + distributed_backend: "nccl" + distributed_init_method: null + distributed_port: -1 + device_id: 0 + local_rank: 0 + distributed_no_spawn: false + ddp_backend: "c10d" + bucket_cap_mb: 25 + fix_batches_to_gpus: false + find_unused_parameters: false + fast_stat_sync: false + broadcast_buffers: false + distributed_wrapper: "DDP" + slowmo_momentum: null + slowmo_algorithm: "LocalSGD" + localsgd_frequency: 3 +dataset: + num_workers: 1 + skip_invalid_size_inputs_valid_test: false + max_tokens: null + batch_size: null + required_batch_size_multiple: 8 + dataset_impl: null + data_buffer_size: 10 + train_subset: "train" + valid_subset: "valid" + validate_interval: 1 + fixed_validation_seed: null + disable_validation: false + curriculum: 0 + gen_subset: "test" + num_shards: 1 + shard_id: 0 + max_tokens_valid: ${dataset.max_tokens} + batch_size_valid: ${dataset.batch_size} +optimization: + max_epoch: 0 + max_update: 0 + clip_norm: 25.0 + sentence_avg: false + update_freq: [ 1 ] + lr: [ 0.25 ] + min_lr: -1.0 + use_bmuf: false +checkpoint: + save_dir: "checkpoints" + restore_file: "checkpoint_last.pt" + reset_dataloader: false + reset_lr_scheduler: false + reset_meters: false + reset_optimizer: false + optimizer_overrides: "{}" + save_interval: 1 + save_interval_updates: 0 + keep_interval_updates: -1 + keep_last_epochs: -1 + 
keep_best_checkpoints: -1 + no_save: false + no_epoch_checkpoints: false + no_last_checkpoints: false + no_save_optimizer_state: false + best_checkpoint_metric: "loss" + maximize_best_checkpoint_metric: false + patience: -1 + checkpoint_suffix: "" +bmuf: + block_lr: 1 + block_momentum: 0.875 + global_sync_iter: 50 + warmup_iterations: 500 + use_nbm: false + average_sync: false defaults: - - params: training_params - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt + - task: language_modeling + - model: null + - criterion: null + - optimizer: null + - lr_scheduler: null + - bpe: null + - tokenizer: null + - scoring: null + - generation: null + - common_eval: null + - eval_lm: null diff --git a/config/config_eval_lm.yaml b/config/config_eval_lm.yaml deleted file mode 100644 index 5a93cb5d92..0000000000 --- a/config/config_eval_lm.yaml +++ /dev/null @@ -1,7 +0,0 @@ -defaults: - - params: eval_lm_params - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt diff --git a/config/criterion/adaptive_loss.yaml b/config/criterion/adaptive_loss.yaml index a85a7eed1c..7997b0766e 100644 --- a/config/criterion/adaptive_loss.yaml +++ b/config/criterion/adaptive_loss.yaml @@ -1,3 +1,3 @@ # @package _group_ -sentence_avg: ${params.optimization.sentence_avg} -ddp_backend: ${params.distributed_training.ddp_backend} +sentence_avg: ${optimization.sentence_avg} +ddp_backend: ${distributed_training.ddp_backend} diff --git a/config/criterion/cross_entropy.yaml b/config/criterion/cross_entropy.yaml index a85a7eed1c..ad3d4148c2 100644 --- a/config/criterion/cross_entropy.yaml +++ b/config/criterion/cross_entropy.yaml @@ -1,3 +1,2 @@ # @package _group_ -sentence_avg: ${params.optimization.sentence_avg} -ddp_backend: ${params.distributed_training.ddp_backend} +sentence_avg: ${optimization.sentence_avg} diff --git 
a/config/params/eval_lm_params.yaml b/config/params/eval_lm_params.yaml deleted file mode 100644 index 6f27055d64..0000000000 --- a/config/params/eval_lm_params.yaml +++ /dev/null @@ -1,105 +0,0 @@ -# @package _group_ -common: - no_progress_bar: false - log_interval: 100 - log_format: null - tensorboard_logdir: null - seed: 1 - cpu: false - fp16: false - memory_efficient_fp16: false - fp16_no_flatten_grads: false - fp16_init_scale: 128 - fp16_scale_window: null - fp16_scale_tolerance: 0.0 - min_loss_scale: 1.0e-4 - threshold_loss_scale: null - user_dir: null - empty_cache_freq: 0 - all_gather_list_size: 16384 - model_parallel_size: 1 - checkpoint_suffix: "" - quantization_config_path: null -distributed_training: - distributed_rank: 0 - distributed_backend: "nccl" - distributed_init_method: null - distributed_port: -1 - device_id: 0 - local_rank: 0 - distributed_no_spawn: false - ddp_backend: "c10d" - bucket_cap_mb: 25 - fix_batches_to_gpus: false - find_unused_parameters: false - fast_stat_sync: false - broadcast_buffers: false - distributed_wrapper: "DDP" - slowmo_momentum: null - slowmo_algorithm: "LocalSGD" - localsgd_frequency: 3 -dataset: - num_workers: 1 - skip_invalid_size_inputs_valid_test: false - max_tokens: null - batch_size: ${params.dataset.batch_size} - required_batch_size_multiple: 8 - dataset_impl: null - data_buffer_size: 10 - train_subset: "train" - valid_subset: "valid" - validate_interval: 1 - fixed_validation_seed: null - disable_validation: false - curriculum: 0 - gen_subset: "test" - num_shards: 1 - shard_id: 0 - max_tokens_valid: ${params.dataset.max_tokens} - batch_size_valid: ${params.dataset.batch_size} -optimization: - max_epoch: 0 - max_update: 0 - clip_norm: 25.0 - sentence_avg: false - update_freq: [1] - lr: [0.25] - min_lr: -1.0 - use_bmuf: false -checkpoint: - save_dir: "checkpoints" - restore_file: "checkpoint_last.pt" - reset_dataloader: false - reset_lr_scheduler: false - reset_meters: false - reset_optimizer: false - 
optimizer_overrides: "{}" - save_interval: 1 - save_interval_updates: 0 - keep_interval_updates: -1 - keep_last_epochs: -1 - keep_best_checkpoints: -1 - no_save: false - no_epoch_checkpoints: false - no_last_checkpoints: false - no_save_optimizer_state: false - best_checkpoint_metric: "loss" - maximize_best_checkpoint_metric: false - patience: -1 -common_eval: - path: null - remove_bpe: null - quiet: false - model_overrides: "{}" - results_path: null -eval_lm: - output_word_probs: false - output_word_stats: false - context_window: 0 -bmuf: - block_lr: 1 - block_momentum: 0.875 - global_sync_iter: 50 - warmup_iterations: 500 - use_nbm: false - average_sync: false diff --git a/config/params/training_params.yaml b/config/params/training_params.yaml deleted file mode 100644 index 2ce94f9290..0000000000 --- a/config/params/training_params.yaml +++ /dev/null @@ -1,95 +0,0 @@ -# @package _group_ -common: - no_progress_bar: false - log_interval: 100 - log_format: null - tensorboard_logdir: null - seed: 1 - cpu: false - fp16: false - memory_efficient_fp16: false - fp16_no_flatten_grads: false - fp16_init_scale: 128 - fp16_scale_window: null - fp16_scale_tolerance: 0.0 - min_loss_scale: 1.0e-4 - threshold_loss_scale: null - user_dir: null - empty_cache_freq: 0 - all_gather_list_size: 16384 - model_parallel_size: 1 - checkpoint_suffix: "" - quantization_config_path: null -distributed_training: - distributed_rank: 0 - distributed_backend: "nccl" - distributed_init_method: null - distributed_port: -1 - device_id: 0 - local_rank: 0 - distributed_no_spawn: false - ddp_backend: "c10d" - bucket_cap_mb: 25 - fix_batches_to_gpus: false - find_unused_parameters: false - fast_stat_sync: false - broadcast_buffers: false - distributed_wrapper: "DDP" - slowmo_momentum: null - slowmo_algorithm: "LocalSGD" - localsgd_frequency: 3 -dataset: - num_workers: 1 - skip_invalid_size_inputs_valid_test: false - max_tokens: null - batch_size: ${params.dataset.batch_size} - 
required_batch_size_multiple: 8 - dataset_impl: null - data_buffer_size: 10 - train_subset: "train" - valid_subset: "valid" - validate_interval: 1 - fixed_validation_seed: null - disable_validation: false - curriculum: 0 - gen_subset: "test" - num_shards: 1 - shard_id: 0 - max_tokens_valid: ${params.dataset.max_tokens} - batch_size_valid: ${params.dataset.batch_size} -optimization: - max_epoch: 0 - max_update: 0 - clip_norm: 25.0 - sentence_avg: false - update_freq: [1] - lr: [0.25] - min_lr: -1.0 - use_bmuf: false -checkpoint: - save_dir: "checkpoints" - restore_file: "checkpoint_last.pt" - reset_dataloader: false - reset_lr_scheduler: false - reset_meters: false - reset_optimizer: false - optimizer_overrides: "{}" - save_interval: 1 - save_interval_updates: 0 - keep_interval_updates: -1 - keep_last_epochs: -1 - keep_best_checkpoints: -1 - no_save: false - no_epoch_checkpoints: false - no_last_checkpoints: false - no_save_optimizer_state: false - best_checkpoint_metric: "loss" - maximize_best_checkpoint_metric: false - patience: -1 -bmuf: - block_lr: 1 - block_momentum: 0.875 - global_sync_iter: 50 - warmup_iterations: 500 - use_nbm: false - average_sync: false diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 9b77dd8351..0973cd279e 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -13,7 +13,6 @@ For example, if we'd like to train a language model with transformer, we could p ``` defaults: - - params: training_params - task: language_modeling - model: transformer_lm - criterion: cross_entropy @@ -21,7 +20,7 @@ defaults: - lr_scheduler: inverse_sqrt ``` -- Provide generic parameters common across different training jobs: `config/params/training_params.yaml` +- Provide generic parameters common across different jobs: `config.yaml` - Provide task parameters: `config/task/language_modeling.yaml` - Provide model parameters: `config/model/transformer_lm.yaml` - Provide criterion parameters: 
`config/criterion/cross_entropy.yaml` @@ -41,7 +40,6 @@ Alternatively, if we need to override certain params from the command line, we c ``` python fairseq_cli/train_hydra.py -params=training_params \ task=language_modeling \ task.data=/private/home/abaevski/data/wiki103 \ task.tokens_per_sample=512 \ @@ -56,17 +54,17 @@ lr_scheduler=inverse_sqrt \ lr_scheduler.warmup_updates=4000 \ lr_scheduler.warmup_init_lr=1e-07 \ criterion=cross_entropy \ -params.common.fp16=true \ -params.common.log_format=json \ -params.common.log_interval=1 \ -params.dataset.max_tokens=1024 \ -params.dataset.num_workers=4 \ -params.optimization.update_freq=[16] \ -params.optimization.max_update=50000 \ -params.optimization.clip_norm=0.0 \ -params.optimization.lr=[0.0005] \ -params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ -params.checkpoint.save_interval_updates=10 +common.fp16=true \ +common.log_format=json \ +common.log_interval=1 \ +dataset.max_tokens=1024 \ +dataset.num_workers=4 \ +optimization.update_freq=[16] \ +optimization.max_update=50000 \ +optimization.clip_norm=0.0 \ +optimization.lr=[0.0005] \ +checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ +checkpoint.save_interval_updates=10 ``` ## Migrate existing/Creating new modules to hydra interface diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index 40a3cb6f25..b02fec0489 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -212,7 +212,7 @@ following contents:: @register_task('simple_classification') - class SimpleClassificationTask(FairseqTask): + class SimpleClassificationTask(LegacyFairseqTask): @staticmethod def add_args(parser): diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py index 4df424e6b5..13036926e0 100644 --- a/examples/noisychannel/rerank.py +++ b/examples/noisychannel/rerank.py @@ -27,7 +27,13 @@ def score_target_hypo( print("lenpen", lenpen, 
"weight1", a, "weight2", b, "weight3", c) gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) dict = dictionary.Dictionary() - scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) ordered_hypos = {} ordered_targets = {} diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py index 1a5901234b..ed0251fdec 100644 --- a/examples/roberta/wsc/wsc_criterion.py +++ b/examples/roberta/wsc/wsc_criterion.py @@ -20,8 +20,8 @@ def __init__(self, args, task): self.prediction_h = open(self.args.save_predictions, "w") else: self.prediction_h = None - self.bpe = encoders.build_bpe(args) - self.tokenizer = encoders.build_tokenizer(args) + self.bpe = encoders.build_bpe(args.bpe) + self.tokenizer = encoders.build_tokenizer(args.tokenizer) def __del__(self): if self.prediction_h is not None: diff --git a/examples/unsupervised_quality_estimation/README.md b/examples/unsupervised_quality_estimation/README.md index 809a58e41b..aeb96a14b1 100644 --- a/examples/unsupervised_quality_estimation/README.md +++ b/examples/unsupervised_quality_estimation/README.md @@ -85,7 +85,7 @@ Produce model scores for the generated translations using `--retain-dropout` opt ``` CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout - --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer + --retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]' TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 75e2c68ca3..c036e12966 
100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -3,36 +3,42 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import ast import collections import logging import os import re import traceback from collections import OrderedDict -from typing import Union +from typing import Optional, Union import torch +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + overwrite_args_by_name, +) from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig, open_dict from torch.serialization import default_restore_location logger = logging.getLogger(__name__) -def save_checkpoint(args, trainer, epoch_itr, val_loss): - from fairseq import distributed_utils, meters +def save_checkpoint(cfg: DictConfig, trainer, epoch_itr, val_loss): + from fairseq import meters # only one worker should attempt to create the required dir - if args.distributed_rank == 0: - os.makedirs(args.save_dir, exist_ok=True) + if cfg.distributed_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) prev_best = getattr(save_checkpoint, "best", val_loss) if val_loss is not None: - best_function = max if args.maximize_best_checkpoint_metric else min + best_function = max if cfg.maximize_best_checkpoint_metric else min save_checkpoint.best = best_function(val_loss, prev_best) - if args.no_save: + if cfg.no_save: return trainer.consolidate_optimizer() @@ -41,7 +47,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): return def is_better(a, b): - return a >= b if args.maximize_best_checkpoint_metric else a <= b + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b write_timer = meters.StopwatchMeter() write_timer.start() @@ -50,38 +56,36 @@ def is_better(a, b): end_of_epoch = epoch_itr.end_of_epoch() updates = trainer.get_num_updates() - suffix = getattr(args, "checkpoint_suffix", "") + suffix = 
cfg.checkpoint_suffix or "" checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( - end_of_epoch - and not args.no_epoch_checkpoints - and epoch % args.save_interval == 0 + end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 ) checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( not end_of_epoch - and args.save_interval_updates > 0 - and updates % args.save_interval_updates == 0 + and cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 ) checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best) ) - if val_loss is not None and args.keep_best_checkpoints > 0: + if val_loss is not None and cfg.keep_best_checkpoints > 0: checkpoint_conds[ - "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss) + "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss) ] = not hasattr(save_checkpoint, "best") or is_better( val_loss, save_checkpoint.best ) checkpoint_conds[ "checkpoint_last{}.pt".format(suffix) - ] = not args.no_last_checkpoints + ] = not cfg.no_last_checkpoints extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} if hasattr(save_checkpoint, "best"): extra_state.update({"best": save_checkpoint.best}) checkpoints = [ - os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) @@ -95,51 +99,52 @@ def is_better(a, b): ) ) - if not end_of_epoch and args.keep_interval_updates > 0: + if not end_of_epoch and cfg.keep_interval_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" + 
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" ) - for old_chk in checkpoints[args.keep_interval_updates :]: + for old_chk in checkpoints[cfg.keep_interval_updates :]: if os.path.lexists(old_chk): os.remove(old_chk) - if args.keep_last_epochs > 0: + if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt") - for old_chk in checkpoints[args.keep_last_epochs :]: + checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt") + for old_chk in checkpoints[cfg.keep_last_epochs :]: if os.path.lexists(old_chk): os.remove(old_chk) - if args.keep_best_checkpoints > 0: + if cfg.keep_best_checkpoints > 0: # only keep the best N checkpoints according to validation metric checkpoints = checkpoint_paths( - args.save_dir, + cfg.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( - args.best_checkpoint_metric + cfg.best_checkpoint_metric ), ) - if not args.maximize_best_checkpoint_metric: + if not cfg.maximize_best_checkpoint_metric: checkpoints = checkpoints[::-1] - for old_chk in checkpoints[args.keep_best_checkpoints :]: + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: if os.path.lexists(old_chk): os.remove(old_chk) -def load_checkpoint(args, trainer, **passthrough_args): +def load_checkpoint(cfg: DictConfig, trainer, **passthrough_args): """ Load a checkpoint and restore the training iterator. *passthrough_args* will be passed through to ``trainer.get_train_iterator``. 
""" - reset_optimizer = args.reset_optimizer - reset_lr_scheduler = args.reset_lr_scheduler - optimizer_overrides = eval(args.optimizer_overrides) - reset_meters = args.reset_meters - reset_dataloader = args.reset_dataloader - if getattr(args, "finetune_from_model", None) is not None and ( + reset_optimizer = cfg.reset_optimizer + reset_lr_scheduler = cfg.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) + reset_meters = cfg.reset_meters + reset_dataloader = cfg.reset_dataloader + + if cfg.finetune_from_model is not None and ( reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader ): raise ValueError( @@ -147,19 +152,19 @@ def load_checkpoint(args, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - suffix = getattr(args, "checkpoint_suffix", "") + suffix = cfg.checkpoint_suffix if ( - args.restore_file == "checkpoint_last.pt" + cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join( - args.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path) - if getattr(args, "finetune_from_model", None) is not None and first_launch: + if cfg.finetune_from_model is not None and first_launch: # if there is no last checkpoint to restore, start the finetune from pretrained model # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
- if PathManager.exists(args.finetune_from_model): - checkpoint_path = args.finetune_from_model + if PathManager.exists(cfg.finetune_from_model): + checkpoint_path = cfg.finetune_from_model reset_optimizer = True reset_lr_scheduler = True reset_meters = True @@ -170,19 +175,17 @@ def load_checkpoint(args, trainer, **passthrough_args): ) else: raise ValueError( - f"--funetune-from-model {args.finetune_from_model} does not exist" + f"--finetune-from-model {cfg.finetune_from_model} does not exist" ) - elif getattr(args, "model_parallel_size", 1) > 1: - checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") + elif cfg.model_parallel_size > 1: + checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") else: - checkpoint_path = args.restore_file + checkpoint_path = cfg.restore_file - if args.restore_file != "checkpoint_last.pt" and getattr( - args, "finetune_from_model", None - ): + if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model: raise ValueError( "--finetune-from-model and --restore-file (non-default value) " - "can not be specified together: " + str(args) + "can not be specified together: " + str(cfg) ) extra_state = trainer.load_checkpoint( @@ -225,10 +228,14 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): f, map_location=lambda s, l: default_restore_location(s, "cpu") ) - args = state["args"] - if arg_overrides is not None: + if "args" in state and state["args"] is not None and arg_overrides is not None: + args = state["args"] for arg_name, arg_val in arg_overrides.items(): setattr(args, arg_name, arg_val) + + if "cfg" in state and state["cfg"] is not None and arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) + state = _upgrade_state_dict(state) return state @@ -274,19 +281,28 @@ def load_model_ensemble_and_task( filename = filename.replace(".pt", suffix + ".pt") else: filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if not PathManager.exists(filename): raise IOError("Model 
file not found: {}".format(filename)) state = load_checkpoint_to_cpu(filename, arg_overrides) - if shard_idx == 0: - args = state["args"] - if task is None: - task = tasks.setup_task(args) - - # build model for ensemble - model = task.build_model(args) - model.load_state_dict(state["model"], strict=strict, args=args) + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) + + if task is None: + task = tasks.setup_task(cfg.task) + + # build model for ensemble + model = task.build_model(cfg.model) + + model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) ensemble.append(model) - return ensemble, args, task + return ensemble, cfg, task def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): @@ -323,7 +339,7 @@ def torch_persistent_save(obj, f): def save_state( filename, - args, + cfg: DictConfig, model_state_dict, criterion, optimizer, @@ -331,6 +347,7 @@ def save_state( num_updates, optim_history=None, extra_state=None, + **kwargs, ): from fairseq import utils @@ -339,7 +356,8 @@ def save_state( if extra_state is None: extra_state = {} state_dict = { - "args": args, + "cfg": cfg, + "args": kwargs.get("args", None), "model": model_state_dict or {}, "optimizer_history": optim_history + [ @@ -354,11 +372,17 @@ def save_state( } if utils.has_parameters(criterion): state_dict["criterion"] = criterion.state_dict() - if not args.no_save_optimizer_state: - state_dict["last_optimizer_state"] = optimizer.state_dict() - # convert all state to CPU - state_dict = utils.move_to_cpu(state_dict) + if cfg is None: + cfg = state_dict["args"] + assert cfg is not None, "must provide cfg or args" + + if isinstance(cfg, DictConfig): + no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state + else: + no_save_optimizer_state = 
cfg.no_save_optimizer_state + if not no_save_optimizer_state: + state_dict["last_optimizer_state"] = optimizer.state_dict() with PathManager.open(filename, "wb") as f: torch_persistent_save(state_dict, f) @@ -403,46 +427,49 @@ def _upgrade_state_dict(state): # keep track of number of updates if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 - # old model checkpoints may not have separate source/target positions - if hasattr(state["args"], "max_positions") and not hasattr( - state["args"], "max_source_positions" - ): - state["args"].max_source_positions = state["args"].max_positions - state["args"].max_target_positions = state["args"].max_positions # use stateful training data iterator if "train_iterator" not in state["extra_state"]: state["extra_state"]["train_iterator"] = { "epoch": state["extra_state"]["epoch"], "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), } - # default to translation task - if not hasattr(state["args"], "task"): - state["args"].task = "translation" - # --raw-text and --lazy-load are deprecated - if getattr(state["args"], "raw_text", False): - state["args"].dataset_impl = "raw" - elif getattr(state["args"], "lazy_load", False): - state["args"].dataset_impl = "lazy" - # epochs start at 1 - if state["extra_state"]["train_iterator"] is not None: - state["extra_state"]["train_iterator"]["epoch"] = max( - state["extra_state"]["train_iterator"].get("epoch", 1), - 1, - ) - # set any missing default values in the task, model or other registries - registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task]) - registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch]) - for registry_name, REGISTRY in registry.REGISTRIES.items(): - choice = getattr(state["args"], registry_name, None) - if choice is not None: - cls = REGISTRY["registry"][choice] - registry.set_defaults(state["args"], cls) + # old model checkpoints may not have separate 
source/target positions + # backward compatibility, cfg updates + if "args" in state and state["args"] is not None: + # default to translation task + if not hasattr(state["args"], "task"): + state["args"].task = "translation" + # --raw-text and --lazy-load are deprecated + if getattr(state["args"], "raw_text", False): + state["args"].dataset_impl = "raw" + elif getattr(state["args"], "lazy_load", False): + state["args"].dataset_impl = "lazy" + # epochs start at 1 + if state["extra_state"]["train_iterator"] is not None: + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), 1 + ) + + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) + + if "cfg" in state and state["cfg"] is not None: + with open_dict(state["cfg"]): + if state["cfg"].task is not None: + if hasattr(state["cfg"].task, "max_positions") and not hasattr( + state["cfg"].task, "max_source_positions" + ): + state["cfg"].task.max_source_positions = state[ + "cfg" + ].task.max_positions + state["cfg"].task.max_target_positions = state[ + "cfg" + ].task.max_positions return state -def prune_state_dict(state_dict, args): +def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): """Prune the given state_dict if desired for LayerDrop (https://arxiv.org/abs/1909.11556). @@ -453,16 +480,20 @@ def prune_state_dict(state_dict, args): It's called by functions that load models from checkpoints and does not need to be called directly. """ - if not args or args.arch == "ptt_transformer": + arch = None + if model_cfg is not None: + arch = ( + model_cfg._name + if isinstance(model_cfg, DictConfig) + else getattr(model_cfg, "arch", None) + ) + + if not model_cfg or arch is None or arch == "ptt_transformer": # args should not be none, but don't crash if it is. 
return state_dict - encoder_layers_to_keep = ( - args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None - ) - decoder_layers_to_keep = ( - args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None - ) + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict @@ -474,7 +505,7 @@ def prune_state_dict(state_dict, args): def create_pruning_pass(layers_to_keep, layer_name): keep_layers = sorted( - [int(layer_string) for layer_string in layers_to_keep.split(",")] + int(layer_string) for layer_string in layers_to_keep.split(",") ) mapping_dict = {} for i in range(len(keep_layers)): @@ -518,10 +549,12 @@ def create_pruning_pass(layers_to_keep, layer_name): # Since layers are now pruned, *_layers_to_keep are no longer needed. # This is more of "It would make it work fix" rather than a proper fix. 
- if "encoder_layers_to_keep" in vars(args): - args.encoder_layers_to_keep = None - if "decoder_layers_to_keep" in vars(args): - args.decoder_layers_to_keep = None + + with open_dict(model_cfg): + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None return new_state_dict diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index a7eb5f6f3c..8cc6c0f043 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.criterions.fairseq_criterion import ( # noqa @@ -27,8 +25,8 @@ ) -def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task): - return build_criterion_(criterion_cfg, task) +def build_criterion(cfg: DictConfig, task): + return build_criterion_(cfg, task) # automatically import any Python files in the criterions/ directory diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 74ba37c321..04832295ec 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -11,13 +11,13 @@ from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.constants import DDP_BACKEND_CHOICES -from omegaconf import II +from omegaconf import II, DictConfig @dataclass class AdaptiveLossConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") - ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") + sentence_avg: bool = II("optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend") @register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig) @@ -31,14 +31,14 @@ def __init__(self, task, sentence_avg): 
self.sentence_avg = sentence_avg @classmethod - def build_criterion(cls, args, task): - if getattr(args, "ddp_backend", None) == "c10d": + def build_criterion(cls, cfg: DictConfig, task): + if cfg.ddp_backend == "c10d": raise Exception( "AdaptiveLoss is not compatible with the c10d " "version of DistributedDataParallel. Please use " "`--ddp-backend=no_c10d` instead." ) - return cls(task, args.sentence_avg) + return cls(task, cfg.sentence_avg) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index 91b58545ed..758e727660 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -15,7 +15,7 @@ @dataclass class CrossEntropyCriterionConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") + sentence_avg: bool = II("optimization.sentence_avg") @register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 4f93b3cbfd..9310024f29 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -10,24 +10,24 @@ import torch import torch.nn.functional as F from fairseq import metrics, utils -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import LegacyFairseqCriterion, register_criterion from fairseq.data.data_utils import post_process from fairseq.logging.meters import safe_round @register_criterion("ctc") -class CtcCriterion(FairseqCriterion): - def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe): - super().__init__(task) +class CtcCriterion(LegacyFairseqCriterion): + def __init__(self, args, task): + super().__init__(args, task) self.blank_idx = task.target_dictionary.bos() self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() - self.post_process = remove_bpe if remove_bpe else "letter" 
+ self.post_process = args.remove_bpe if args.remove_bpe else "letter" - if wer_args is not None: + if args.wer_args is not None: from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder - wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args) + wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args) dec_args = Namespace() dec_args.nbest = 1 @@ -46,8 +46,8 @@ def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe): else: self.w2l_decoder = None - self.zero_infinity = zero_infinity - self.sentence_avg = sentence_avg + self.zero_infinity = args.zero_infinity + self.sentence_avg = args.sentence_avg @staticmethod def add_args(parser): diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index ef94a86327..b2eda1a7e4 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -8,6 +8,7 @@ from fairseq import metrics, utils from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig from torch.nn.modules.loss import _Loss @@ -27,10 +28,8 @@ def add_args(cls, parser): gen_parser_from_dataclass(parser, dc()) @classmethod - def build_criterion(cls, args, task): + def build_criterion(cls, cfg: DictConfig, task): """Construct a criterion from command-line args.""" - # Criterions can override this, but for convenience we also try - # to automatically map argparse.Namespace keys to corresponding # arguments in the __init__. 
init_args = {} for p in inspect.signature(cls).parameters.values(): @@ -47,8 +46,8 @@ def build_criterion(cls, args, task): if p.name == "task": init_args["task"] = task - elif hasattr(args, p.name): - init_args[p.name] = getattr(args, p.name) + elif hasattr(cfg, p.name): + init_args[p.name] = getattr(cfg, p.name) elif p.default != p.empty: pass # we'll use the default value else: @@ -70,7 +69,7 @@ def forward(self, model, sample, reduce=True): @staticmethod def aggregate_logging_outputs( - logging_outputs: List[Dict[str, Any]], + logging_outputs: List[Dict[str, Any]] ) -> Dict[str, Any]: """Aggregate logging outputs from data parallel training.""" utils.deprecation_warning( diff --git a/fairseq/data/encoders/byte_bpe.py b/fairseq/data/encoders/byte_bpe.py index 0d2da3ea1a..31e3a06278 100644 --- a/fairseq/data/encoders/byte_bpe.py +++ b/fairseq/data/encoders/byte_bpe.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe from fairseq.data.encoders.byte_utils import ( @@ -12,19 +14,20 @@ byte_encode, smart_byte_decode, ) +from fairseq.dataclass import FairseqDataclass + +@dataclass +class ByteBpeConfig(FairseqDataclass): + sentencepiece_model_path: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) -@register_bpe("byte_bpe") + +@register_bpe("byte_bpe", dataclass=ByteBpeConfig) class ByteBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sentencepiece-model-path', type=str, - help='path to sentencepiece model') - # fmt: on - - def __init__(self, args): - vocab = file_utils.cached_path(args.sentencepiece_model_path) + def __init__(self, cfg): + vocab = file_utils.cached_path(cfg.sentencepiece_model_path) try: import sentencepiece as spm diff --git a/fairseq/data/encoders/bytes.py b/fairseq/data/encoders/bytes.py index bb9554ed53..f88f8f6929 100644 --- 
a/fairseq/data/encoders/bytes.py +++ b/fairseq/data/encoders/bytes.py @@ -15,7 +15,7 @@ @register_bpe("bytes") class Bytes(object): - def __init__(self, args): + def __init__(self, *unused): pass @staticmethod diff --git a/fairseq/data/encoders/characters.py b/fairseq/data/encoders/characters.py index cffc57511c..494ea21939 100644 --- a/fairseq/data/encoders/characters.py +++ b/fairseq/data/encoders/characters.py @@ -13,7 +13,7 @@ @register_bpe("characters") class Characters(object): - def __init__(self, args): + def __init__(self, *unused): pass @staticmethod diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py index 74d4ad8504..f7c2103954 100644 --- a/fairseq/data/encoders/fastbpe.py +++ b/fairseq/data/encoders/fastbpe.py @@ -3,23 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class fastBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"}) -@register_bpe("fastbpe") +@register_bpe("fastbpe", dataclass=fastBPEConfig) class fastBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-codes', type=str, - help='path to fastBPE BPE') - # fmt: on - - def __init__(self, args): - if args.bpe_codes is None: + def __init__(self, cfg): + if cfg.bpe_codes is None: raise ValueError("--bpe-codes is required for --bpe=fastbpe") - codes = file_utils.cached_path(args.bpe_codes) + codes = file_utils.cached_path(cfg.bpe_codes) try: import fastBPE diff --git a/fairseq/data/encoders/gpt2_bpe.py b/fairseq/data/encoders/gpt2_bpe.py index 8ac099a688..e661426a73 100644 --- a/fairseq/data/encoders/gpt2_bpe.py +++ b/fairseq/data/encoders/gpt2_bpe.py @@ -3,8 +3,11 @@ # This source code is licensed under 
the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass from .gpt2_bpe_utils import get_encoder @@ -13,26 +16,21 @@ DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" -@register_bpe("gpt2") +@dataclass +class GPT2BPEConfig(FairseqDataclass): + gpt2_encoder_json: str = field( + default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"} + ) + gpt2_vocab_bpe: str = field( + default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"} + ) + + +@register_bpe("gpt2", dataclass=GPT2BPEConfig) class GPT2BPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--gpt2-encoder-json', type=str, - default=DEFAULT_ENCODER_JSON, - help='path to encoder.json') - parser.add_argument('--gpt2-vocab-bpe', type=str, - default=DEFAULT_VOCAB_BPE, - help='path to vocab.bpe') - # fmt: on - - def __init__(self, args): - encoder_json = file_utils.cached_path( - getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON) - ) - vocab_bpe = file_utils.cached_path( - getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE) - ) + def __init__(self, cfg): + encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json) + vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe) self.bpe = get_encoder(encoder_json, vocab_bpe) def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py index a968fe8857..a41c059343 100644 --- a/fairseq/data/encoders/hf_bert_bpe.py +++ b/fairseq/data/encoders/hf_bert_bpe.py @@ -3,22 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field +from typing import Optional + from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class BertBPEConfig(FairseqDataclass): + bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"}) + bpe_vocab_file: Optional[str] = field( + default=None, metadata={"help": "bpe vocab file"} + ) -@register_bpe("bert") +@register_bpe("bert", dataclass=BertBPEConfig) class BertBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-cased', action='store_true', - help='set for cased BPE', - default=False) - parser.add_argument('--bpe-vocab-file', type=str, - help='bpe vocab file.') - # fmt: on - - def __init__(self, args): + def __init__(self, cfg): try: from transformers import BertTokenizer except ImportError: @@ -26,13 +28,13 @@ def __init__(self, args): "Please install transformers with: pip install transformers" ) - if "bpe_vocab_file" in args: + if cfg.bpe_vocab_file: self.bert_tokenizer = BertTokenizer( - args.bpe_vocab_file, do_lower_case=not args.bpe_cased + cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased ) else: vocab_file_name = ( - "bert-base-cased" if args.bpe_cased else "bert-base-uncased" + "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased" ) self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py index 544d408273..92d2c3922c 100644 --- a/fairseq/data/encoders/hf_byte_bpe.py +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -3,21 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field + from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class HuggingFaceByteLevelBPEConfig(FairseqDataclass): + bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"}) + bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"}) + bpe_add_prefix_space: bool = field( + default=False, metadata={"help": "add prefix space before encoding"} + ) -@register_bpe("hf_byte_bpe") +@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig) class HuggingFaceByteLevelBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-merges', help='path to merges.txt') - parser.add_argument('--bpe-vocab', help='path to vocab.json') - parser.add_argument('--bpe-add-prefix-space', action='store_true', - help='add prefix space before encoding') - # fmt: on - - def __init__(self, args): + def __init__(self, cfg): try: from tokenizers import ByteLevelBPETokenizer except ImportError: @@ -26,9 +29,9 @@ def __init__(self, args): ) self.bpe = ByteLevelBPETokenizer( - args.bpe_vocab, - args.bpe_merges, - add_prefix_space=getattr(args, "bpe_add_prefix_space", False), + cfg.bpe_vocab, + cfg.bpe_merges, + add_prefix_space=cfg.bpe_add_prefix_space, ) def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py index 8c24844263..fa004dd4af 100644 --- a/fairseq/data/encoders/moses_tokenizer.py +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -3,37 +3,35 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field + from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class MosesTokenizerConfig(FairseqDataclass): + source_lang: str = field(default="en", metadata={"help": "source language"}) + target_lang: str = field(default="en", metadata={"help": "target language"}) + moses_no_dash_splits: bool = field( + default=False, metadata={"help": "don't apply dash split rules"} + ) + moses_no_escape: bool = field( + default=False, + metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."}, + ) -@register_tokenizer("moses") +@register_tokenizer("moses", dataclass=MosesTokenizerConfig) class MosesTokenizer(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--moses-source-lang', metavar='SRC', - help='source language') - parser.add_argument('--moses-target-lang', metavar='TARGET', - help='target language') - parser.add_argument('--moses-no-dash-splits', action='store_true', default=False, - help='don\'t apply dash split rules') - parser.add_argument('--moses-no-escape', action='store_true', default=False, - help='don\'t perform HTML escaping on apostrophy, quotes, etc.') - # fmt: on - - def __init__(self, args): - self.args = args - - if getattr(args, "moses_source_lang", None) is None: - args.moses_source_lang = getattr(args, "source_lang", "en") - if getattr(args, "moses_target_lang", None) is None: - args.moses_target_lang = getattr(args, "target_lang", "en") + def __init__(self, cfg): + self.cfg = cfg try: from sacremoses import MosesTokenizer, MosesDetokenizer - self.tok = MosesTokenizer(args.moses_source_lang) - self.detok = MosesDetokenizer(args.moses_target_lang) + self.tok = MosesTokenizer(cfg.source_lang) + self.detok = MosesDetokenizer(cfg.target_lang) except ImportError: raise ImportError( "Please install Moses tokenizer with: pip install sacremoses" @@ -42,9 +40,9 @@ def __init__(self, args): def encode(self, x: 
str) -> str: return self.tok.tokenize( x, - aggressive_dash_splits=(not self.args.moses_no_dash_splits), + aggressive_dash_splits=(not self.cfg.moses_no_dash_splits), return_str=True, - escape=(not self.args.moses_no_escape), + escape=(not self.cfg.moses_no_escape), ) def decode(self, x: str) -> str: diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py index 3b617e7314..ee164710a0 100644 --- a/fairseq/data/encoders/nltk_tokenizer.py +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -8,7 +8,7 @@ @register_tokenizer("nltk") class NLTKTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): + def __init__(self, *unused): try: from nltk.tokenize import word_tokenize diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py index b25c6caebe..a76d46a201 100644 --- a/fairseq/data/encoders/sentencepiece_bpe.py +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -3,21 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class SentencepieceConfig(FairseqDataclass): + sentencepiece_model: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) -@register_bpe("sentencepiece") +@register_bpe("sentencepiece", dataclass=SentencepieceConfig) class SentencepieceBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sentencepiece-model', type=str, - help='path to sentencepiece model') - # fmt: on - - def __init__(self, args): - sentencepiece_model = file_utils.cached_path(args.sentencepiece_model) + def __init__(self, cfg): + sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model) try: import sentencepiece as spm diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py index 3bc7ce4958..7c7f644d5c 100644 --- a/fairseq/data/encoders/space_tokenizer.py +++ b/fairseq/data/encoders/space_tokenizer.py @@ -10,7 +10,7 @@ @register_tokenizer("space") class SpaceTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): + def __init__(self, *unused): self.space_tok = re.compile(r"\s+") def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq/data/encoders/subword_nmt_bpe.py index e85f99af39..5d724d2730 100644 --- a/fairseq/data/encoders/subword_nmt_bpe.py +++ b/fairseq/data/encoders/subword_nmt_bpe.py @@ -3,25 +3,25 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class SubwordNMTBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"}) + bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"}) -@register_bpe("subword_nmt") +@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig) class SubwordNMTBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-codes', type=str, - help='path to subword NMT BPE') - parser.add_argument('--bpe-separator', default='@@', - help='BPE separator') - # fmt: on - - def __init__(self, args): - if args.bpe_codes is None: + def __init__(self, cfg): + if cfg.bpe_codes is None: raise ValueError("--bpe-codes is required for --bpe=subword_nmt") - codes = file_utils.cached_path(args.bpe_codes) + codes = file_utils.cached_path(cfg.bpe_codes) try: from subword_nmt import apply_bpe @@ -31,7 +31,7 @@ def __init__(self, args): "--codes", codes, "--separator", - args.bpe_separator, + cfg.bpe_separator, ] ) self.bpe = apply_bpe.BPE( diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 21b36450f9..2fd87f5fc4 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -9,5 +9,7 @@ LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) +GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) +GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(["unigram", "ensemble", "vote", "dp", "bs"]) ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index ed1d12d865..b0c17ba0ad 
100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -3,32 +3,37 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import sys from argparse import Namespace -from dataclasses import dataclass, field +from dataclasses import _MISSING_TYPE, dataclass, field from typing import Any, Dict, List, Optional, Tuple, Type import torch -from fairseq.criterions import CRITERION_DATACLASS_REGISTRY from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.dataclass.constants import ( DDP_BACKEND_CHOICES, DISTRIBUTED_WRAPPER_CHOICES, + GENERATION_CONSTRAINTS_CHOICES, + GENERATION_DECODING_FORMAT_CHOICES, LOG_FORMAT_CHOICES, PIPELINE_CHECKPOINT_CHOICES, ZERO_SHARDING_CHOICES, ) from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY -from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY from fairseq.optim.bmuf import FairseqBMUFConfig -from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY +from fairseq.registry import REGISTRIES from fairseq.tasks import TASK_DATACLASS_REGISTRY from hydra.core.config_store import ConfigStore +from omegaconf import II + + +logger = logging.getLogger(__name__) @dataclass -class CommonParams(FairseqDataclass): +class CommonConfig(FairseqDataclass): # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. 
no_progress_bar: bool = field( @@ -109,18 +114,6 @@ class CommonParams(FairseqDataclass): model_parallel_size: int = field( default=1, metadata={"help": "total number of GPUs to parallelize model over"} ) - checkpoint_suffix: str = field( - default="", metadata={"help": "suffix to add to the checkpoint file name"} - ) - checkpoint_shard_count: int = field( - default=1, - metadata={ - "help": "Number of shards containing the checkpoint - " - "if the checkpoint is over 300GB, it is preferable " - "to split it into shards to prevent OOM on CPU while loading " - "the checkpoint" - }, - ) quantization_config_path: Optional[str] = field( default=None, metadata={"help": "path to quantization config file"} ) @@ -130,7 +123,7 @@ class CommonParams(FairseqDataclass): @dataclass -class DistributedTrainingParams(FairseqDataclass): +class DistributedTrainingConfig(FairseqDataclass): distributed_world_size: int = field( default=max(1, torch.cuda.device_count()), metadata={ @@ -229,7 +222,7 @@ class DistributedTrainingParams(FairseqDataclass): default=False, metadata={"help": "if set, use pipeline model parallelism across GPUs"}, ) - pipeline_balance: str = field( + pipeline_balance: Optional[str] = field( default=None, metadata={ "help": "partition the model into N_K pieces, where each piece " @@ -237,7 +230,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of layers in the model" }, ) - pipeline_devices: str = field( + pipeline_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -245,10 +238,10 @@ class DistributedTrainingParams(FairseqDataclass): "equal the length of the --pipeline-balance argument" }, ) - pipeline_chunks: int = field( + pipeline_chunks: Optional[int] = field( default=0, metadata={"help": "microbatch count for pipeline model parallelism"} ) - pipeline_encoder_balance: str = field( + pipeline_encoder_balance: Optional[str] = field( default=None, 
metadata={ "help": "partition the pipeline parallel encoder into N_K pieces, where each piece " @@ -256,7 +249,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of encoder layers in the model" }, ) - pipeline_encoder_devices: str = field( + pipeline_encoder_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -264,7 +257,7 @@ class DistributedTrainingParams(FairseqDataclass): "equal the length of the --pipeline-encoder-balance argument" }, ) - pipeline_decoder_balance: str = field( + pipeline_decoder_balance: Optional[str] = field( default=None, metadata={ "help": "partition the pipeline parallel decoder into N_K pieces, where each piece " @@ -272,7 +265,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of decoder layers in the model" }, ) - pipeline_decoder_devices: str = field( + pipeline_decoder_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -287,10 +280,11 @@ class DistributedTrainingParams(FairseqDataclass): zero_sharding: ZERO_SHARDING_CHOICES = field( default="none", metadata={"help": "ZeRO sharding"} ) + tpu: bool = II("common.tpu") @dataclass -class DatasetParams(FairseqDataclass): +class DatasetConfig(FairseqDataclass): num_workers: int = field( default=1, metadata={"help": "how many subprocesses to use for data loading"} ) @@ -374,7 +368,7 @@ class DatasetParams(FairseqDataclass): @dataclass -class OptimizationParams(FairseqDataclass): +class OptimizationConfig(FairseqDataclass): max_epoch: int = field( default=0, metadata={"help": "force stop training at specified epoch"} ) @@ -421,7 +415,7 @@ class OptimizationParams(FairseqDataclass): @dataclass -class CheckpointParams(FairseqDataclass): +class CheckpointConfig(FairseqDataclass): save_dir: str = field( default="checkpoints", metadata={"help": "path to save checkpoints"} ) 
@@ -514,12 +508,217 @@ class CheckpointParams(FairseqDataclass): ) }, ) + checkpoint_suffix: str = field( + default="", metadata={"help": "suffix to add to the checkpoint file name"} + ) + checkpoint_shard_count: int = field( + default=1, + metadata={ + "help": "Number of shards containing the checkpoint - " + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + }, + ) + model_parallel_size: int = II("common.model_parallel_size") + distributed_rank: int = II("distributed_training.distributed_rank") @dataclass -class CommonEvalParams(FairseqDataclass): +class GenerationConfig(FairseqDataclass): + beam: int = field( + default=5, + metadata={"help": "beam size"}, + ) + nbest: int = field( + default=1, + metadata={"help": "number of hypotheses to output"}, + ) + max_len_a: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_b: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + min_len: int = field( + default=1, + metadata={"help": "minimum generation length"}, + ) + match_source_len: bool = field( + default=False, + metadata={"help": "generations should match the source length"}, + ) + unnormalized: bool = field( + default=False, + metadata={"help": "compare unnormalized hypothesis scores"}, + ) + no_early_stop: bool = field( + default=False, + metadata={"help": "deprecated"}, + ) + no_beamable_mm: bool = field( + default=False, + metadata={"help": "don't use BeamableMM in attention layers"}, + ) + lenpen: float = field( + default=1, + metadata={ + "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + unkpen: float = field( + default=0, + metadata={ + "help": "unknown word penalty: <0 produces more unks, >0 produces fewer" + }, + ) + replace_unk: Optional[str] = field( + 
default=None, + metadata={ + "help": "perform unknown replacement (optionally with alignment dictionary)", + "argparse_const": "@@ ", + }, + ) + sacrebleu: bool = field( + default=False, + metadata={"help": "score with sacrebleu"}, + ) + score_reference: bool = field( + default=False, + metadata={"help": "just score the reference translation"}, + ) + prefix_size: int = field( + default=0, + metadata={"help": "initialize generation by target prefix of given length"}, + ) + no_repeat_ngram_size: int = field( + default=0, + metadata={ + "help": "ngram blocking such that this size ngram cannot be repeated in the generation" + }, + ) + sampling: bool = field( + default=False, + metadata={"help": "sample hypotheses instead of using beam search"}, + ) + sampling_topk: int = field( + default=-1, + metadata={"help": "sample from top K likely next words instead of all words"}, + ) + sampling_topp: float = field( + default=-1.0, + metadata={ + "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words" + }, + ) + constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( + default=None, + metadata={ + "help": "enables lexically constrained decoding", + "argparse_const": "ordered", + }, + ) + temperature: float = field( + default=1.0, + metadata={"help": "temperature for generation"}, + ) + diverse_beam_groups: int = field( + default=-1, + metadata={"help": "number of groups for Diverse Beam Search"}, + ) + diverse_beam_strength: float = field( + default=0.5, + metadata={"help": "strength of diversity penalty for Diverse Beam Search"}, + ) + diversity_rate: float = field( + default=-1.0, + metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, + ) + print_alignment: bool = field( + default=False, + metadata={ + "help": "if set, uses attention feedback to compute and print alignment to source tokens" + }, + ) + print_step: bool = field( + default=False, + metadata={"help": "print steps"}, + ) + lm_path: 
Optional[str] = field( + default=None, + metadata={"help": "path to lm checkpoint for lm fusion"}, + ) + lm_weight: float = field( + default=0.0, + metadata={"help": "weight for lm probs for lm fusion"}, + ) + + # arguments for iterative refinement generator + iter_decode_eos_penalty: float = field( + default=0.0, + metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, + ) + iter_decode_max_iter: int = field( + default=10, + metadata={"help": "maximum iterations for iterative refinement."}, + ) + iter_decode_force_max_iter: bool = field( + default=False, + metadata={ + "help": "if set, run exact the maximum number of iterations without early stop" + }, + ) + iter_decode_with_beam: int = field( + default=1, + metadata={ + "help": "if > 1, model will generate translations varying by the lengths." + }, + ) + iter_decode_with_external_reranker: bool = field( + default=False, + metadata={ + "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations" + }, + ) + retain_iter_history: bool = field( + default=False, + metadata={ + "help": "if set, decoding returns the whole history of iterative refinement" + }, + ) + retain_dropout: bool = field( + default=False, + metadata={"help": "Use dropout at inference time"}, + ) + retain_dropout_modules: Optional[List[str]] = field( + default=None, + metadata={ + "help": "if set, only retain dropout for the specified modules; " + "if not set, then dropout will be retained for all modules" + }, + ) + # special decoding format for advanced decoding. 
+ decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( + default=None, + metadata={"help": "special decoding format for advanced decoding."}, + ) + no_seed_provided: bool = field( + default=False, + metadata={"help": "if set, dont use seed for initializing random generators"}, + ) + + +@dataclass +class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, metadata={"help": "path(s) to model file(s), colon separated"} + default=None, + metadata={"help": "path(s) to model file(s), colon separated"}, ) remove_bpe: Optional[str] = field( default=None, @@ -541,7 +740,7 @@ class CommonEvalParams(FairseqDataclass): @dataclass -class EvalLMParams(FairseqDataclass): +class EvalLMConfig(FairseqDataclass): output_word_probs: bool = field( default=False, metadata={ @@ -569,37 +768,31 @@ class EvalLMParams(FairseqDataclass): @dataclass -class TrainingConfig(FairseqDataclass): - """Config for training, a composition of training params""" - - common: CommonParams = CommonParams() - distributed_training: DistributedTrainingParams = DistributedTrainingParams() - dataset: DatasetParams = DatasetParams() - optimization: OptimizationParams = OptimizationParams() - checkpoint: CheckpointParams = CheckpointParams() - bmuf: FairseqBMUFConfig = FairseqBMUFConfig() - - -@dataclass -class EvalLMConfig(FairseqDataclass): - """Config for eval lm, a composition of eval_lm params""" - - common: CommonParams = CommonParams() - distributed_training: DistributedTrainingParams = DistributedTrainingParams() - dataset: DatasetParams = DatasetParams() - optimization: OptimizationParams = OptimizationParams() - checkpoint: CheckpointParams = CheckpointParams() - bmuf: FairseqBMUFConfig = FairseqBMUFConfig() - common_eval: CommonEvalParams = CommonEvalParams() - eval_lm: EvalLMParams = EvalLMParams() +class InteractiveConfig(FairseqDataclass): + buffer_size: int = field( + default=0, + metadata={ + "help": "read this many sentences into a buffer before 
processing them" + }, + ) + input: str = field( + default="-", + metadata={"help": "file to read from; use - for stdin"}, + ) -def register_params_dataclass( - cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass] -) -> None: - """register params dataclass in config store""" - node_ = data_class(_name=data_class.name()) - cs.store(name=name, group=group, node=node_) +CONFIGS = { + "common": CommonConfig, + "common_eval": CommonEvalConfig, + "distributed_training": DistributedTrainingConfig, + "dataset": DatasetConfig, + "optimization": OptimizationConfig, + "checkpoint": CheckpointConfig, + "bmuf": FairseqBMUFConfig, + "generation": GenerationConfig, + "eval_lm": EvalLMConfig, + "interactive": InteractiveConfig, +} def register_module_dataclass( @@ -608,100 +801,67 @@ def register_module_dataclass( """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" # note that if `group == model`, we register all model archs, not the model name. 
for k, v in registry.items(): - if v is not None: - node_ = v(_name=k) - cs.store(name=k, group=group, node=node_) + node_ = v() + node_._name = k + cs.store(name=k, group=group, node=node_, provider="fairseq") -def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: +def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: """cs: config store instance, register common training configs""" - register_params_dataclass( - cs, name="training_params", group="params", data_class=TrainingConfig - ) + for k, v in CONFIGS.items(): + try: + cs.store(name=k, node=v()) + except BaseException: + logger.error(f"{k} - {v()}") + raise register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") - register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") - register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") - register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") - -def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: - """cs: config store instance, register common training configs""" - - register_params_dataclass( - cs, name="eval_lm_params", group="params", data_class=EvalLMConfig - ) - - register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") - register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") - register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") - register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") + for k, v in REGISTRIES.items(): + register_module_dataclass(cs, v["dataclass_registry"], k) def _override_attr( sub_node: str, data_class: Type[FairseqDataclass], args: Namespace ) -> List[str]: overrides = [] - for k in data_class.__dataclass_fields__.keys(): - if k == "_name": + + def get_default(f): + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, 
v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): # private member, skip continue - if not hasattr(args, k): - # print(f"cannot override {sub_node}.{k} since args does not have attribute {k}") - continue - if getattr(args, k) is None: + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + if val is None: overrides.append("{}.{}=null".format(sub_node, k)) - elif getattr(args, k) == "": + elif val == "": overrides.append("{}.{}=''".format(sub_node, k)) - elif isinstance(getattr(args, k), str): - if ( - getattr(args, k).startswith("[") - or getattr(args, k).startswith("(") - or getattr(args, k).startswith("{") - or ("," in getattr(args, k)) - ): - overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k))) - else: - overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + elif isinstance(val, str): + overrides.append("{}.{}='{}'".format(sub_node, k, val)) else: - overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + overrides.append("{}.{}={}".format(sub_node, k, val)) return overrides -def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]: - overrides = [] - - overrides.extend(_override_attr("params.common", CommonParams, args)) - overrides.extend(_override_attr("params.dataset", DatasetParams, args)) - overrides.extend( - _override_attr("params.distributed_training", DistributedTrainingParams, args) - ) - overrides.extend(_override_attr("params.optimization", OptimizationParams, args)) - overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args)) - overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) - module_overrides, module_deletes = override_module_args(args) - overrides.extend(module_overrides) - - return overrides, module_deletes - - -def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]: - overrides = [] - - overrides.extend(_override_attr("params.common", CommonParams, args)) - 
overrides.extend(_override_attr("params.dataset", DatasetParams, args)) - overrides.extend( - _override_attr("params.distributed_training", DistributedTrainingParams, args) - ) - overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args)) - overrides.extend(_override_attr("params.eval_lm", EvalLMParams, args)) - overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) - module_overrides, module_deletes = override_module_args(args) - overrides.extend(module_overrides) - - return overrides, module_deletes +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: @@ -709,53 +869,34 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: overrides = [] deletes = [] + for k, v in CONFIGS.items(): + overrides.extend(_override_attr(k, v, args)) + if args is not None: - assert ( - hasattr(args, "task") - and hasattr(args, "criterion") - and hasattr(args, "optimizer") - and hasattr(args, "lr_scheduler") - ) - if args.task in TASK_DATACLASS_REGISTRY: - overrides.append("task={}".format(args.task)) - overrides.append("task._name={}".format(args.task)) - overrides.extend( - _override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args) + if hasattr(args, "task"): + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes ) else: deletes.append("task") - if args.criterion in CRITERION_DATACLASS_REGISTRY: - overrides.append("criterion={}".format(args.criterion)) - overrides.append("criterion._name={}".format(args.criterion)) - overrides.extend( - _override_attr( - 
"criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args - ) - ) - else: - deletes.append("criterion") - if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY: - overrides.append("optimizer={}".format(args.optimizer)) - overrides.append("optimizer._name={}".format(args.optimizer)) - overrides.extend( - _override_attr( - "optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args - ) - ) - else: - deletes.append("optimizer") - if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY: - overrides.append("lr_scheduler={}".format(args.lr_scheduler)) - overrides.append("lr_scheduler._name={}".format(args.lr_scheduler)) - overrides.extend( - _override_attr( - "lr_scheduler", - LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler], + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], args, + overrides, + deletes, + use_name_as_val=k not in CORE_REGISTRIES, ) - ) - else: - deletes.append("lr_scheduler") + else: + deletes.append(k) no_dc = True if hasattr(args, "arch"): diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 599cc2b4c2..bcfe23294a 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -3,17 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from argparse import ArgumentParser -from dataclasses import MISSING, dataclass +import ast +from argparse import ArgumentParser, Namespace +from dataclasses import _MISSING_TYPE, MISSING, dataclass from enum import Enum from typing import Any, Dict, List, Optional +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict + def eval_str_list(x, x_type=float): if x is None: return None if isinstance(x, str): - x = eval(x) + if len(x) == 0: + return [] + x = ast.literal_eval(x) try: return list(map(x_type, x)) except TypeError: @@ -70,22 +77,11 @@ def _get_default(self, attribute_name: str) -> Any: != self.__dataclass_fields__[attribute_name].default ): return getattr(self, attribute_name) - return self.__dataclass_fields__[attribute_name].default - def _get_default_factory(self, attribute_name: str) -> Any: - if hasattr(self, attribute_name): - if str(getattr(self, attribute_name)).startswith("${"): - return str(getattr(self, attribute_name)) - elif str(self.__dataclass_fields__[attribute_name].default).startswith( - "${" - ): - return str(self.__dataclass_fields__[attribute_name].default) - elif ( - getattr(self, attribute_name) - != self.__dataclass_fields__[attribute_name].default_factory() - ): - return getattr(self, attribute_name) - return self.__dataclass_fields__[attribute_name].default_factory() + f = self.__dataclass_fields__[attribute_name] + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default def _get_type(self, attribute_name: str) -> Any: return self.__dataclass_fields__[attribute_name].type @@ -119,7 +115,7 @@ def argparse_name(name: str): def interpret_dc_type(field_type): if isinstance(field_type, str): - raise RuntimeError() + raise RuntimeError("field should be a type") typestring = str(field_type) if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): return field_type.__args__[0] @@ 
-129,12 +125,13 @@ def get_kwargs_from_dc( dataclass_instance: FairseqDataclass, k: str ) -> Dict[str, Any]: """k: dataclass attributes""" + + kwargs = {} + field_type = dataclass_instance._get_type(k) inter_type = interpret_dc_type(field_type) - if isinstance(inter_type, type) and issubclass(inter_type, List): - field_default = dataclass_instance._get_default_factory(k) - else: - field_default = dataclass_instance._get_default(k) + + field_default = dataclass_instance._get_default(k) if isinstance(inter_type, type) and issubclass(inter_type, Enum): field_choices = [t.value for t in list(inter_type)] @@ -143,7 +140,7 @@ def get_kwargs_from_dc( field_help = dataclass_instance._get_help(k) field_const = dataclass_instance._get_argparse_const(k) - kwargs = {} + if isinstance(field_default, str) and field_default.startswith("${"): kwargs["default"] = field_default else: @@ -163,7 +160,11 @@ def get_kwargs_from_dc( else: raise NotImplementedError() if field_default is not MISSING: - kwargs["default"] = ",".join(map(str, field_default)) + kwargs["default"] = ( + ",".join(map(str, field_default)) + if field_default is not None + else None + ) elif ( isinstance(inter_type, type) and issubclass(inter_type, Enum) ) or "Enum" in str(inter_type): @@ -187,6 +188,7 @@ def get_kwargs_from_dc( if field_const is not None: kwargs["const"] = field_const kwargs["nargs"] = "?" 
+ return kwargs for k in dataclass_instance._get_all_attributes(): @@ -194,8 +196,122 @@ def get_kwargs_from_dc( if field_name is None: continue kwargs = get_kwargs_from_dc(dataclass_instance, k) - if isinstance(kwargs["default"], str) and kwargs["default"].startswith("${"): - continue - if delete_default: - del kwargs["default"] + + if "default" in kwargs: + if isinstance(kwargs["default"], str) and kwargs["default"].startswith( + "${" + ): + continue + if delete_default: + del kwargs["default"] parser.add_argument(field_name, **kwargs) + + +def _set_legacy_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, "add_args"): + return + + import argparse + + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) + + +def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + from fairseq.dataclass.data_class import override_module_args + + # Here we are using field values provided in args to override counterparts inside config object + overrides, deletes = override_module_args(args) + + cfg_name = "config" + cfg_path = f"../../{cfg_name}" + + if not GlobalHydra().is_initialized(): + initialize(config_path=cfg_path) + + composed_cfg = compose(cfg_name, overrides=overrides, strict=False) + for k in deletes: + composed_cfg[k] = None + + cfg = OmegaConf.create( + OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) + ) + + # hack to be able to set Namespace in dict config. 
this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(cfg, True) + return cfg + + +def populate_dataclass( + args: Namespace, dataclass: FairseqDataclass +) -> FairseqDataclass: + for k in dataclass.__dataclass_fields__.keys(): + if k.startswith("_"): + # private member, skip + continue + if hasattr(args, k): + setattr(dataclass, k, getattr(args, k)) + + return dataclass + + +def overwrite_args_by_name(cfg: 
DictConfig, overrides: Dict[str, any]): + # this will be deprecated when we get rid of argparse and model_overrides logic + + with open_dict(cfg): + for k in cfg.keys(): + if isinstance(cfg[k], DictConfig): + overwrite_args_by_name(cfg[k], overrides) + elif k in overrides: + cfg[k] = overrides[k] diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index bcb0595e6e..23cdfc6938 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -11,35 +11,38 @@ import struct import subprocess import warnings +from argparse import Namespace from collections import OrderedDict from typing import Any, Dict, Mapping import torch import torch.distributed as dist from fairseq import utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from omegaconf import DictConfig, open_dict logger = logging.getLogger(__name__) -def is_master(args): - return args.distributed_rank == 0 +def is_master(cfg: DictConfig): + return cfg.distributed_rank == 0 -def infer_init_method(args, force_distributed=False): - if args.distributed_init_method is not None or getattr(args, "tpu", False): +def infer_init_method(cfg: DictConfig, force_distributed=False): + if cfg.distributed_init_method is not None or cfg.tpu: return - if args.pipeline_model_parallel: + if cfg.pipeline_model_parallel: balance_exists = ( - args.pipeline_balance is not None - or args.pipeline_encoder_balance is not None - or args.pipeline_decoder_balance is not None + cfg.pipeline_balance is not None + or cfg.pipeline_encoder_balance is not None + or cfg.pipeline_decoder_balance is not None ) devices_exist = ( - args.pipeline_devices is not None - or args.pipeline_encoder_devices is not None - or args.pipeline_decoder_devices is not None + cfg.pipeline_devices is not None + or cfg.pipeline_encoder_devices is not None + or cfg.pipeline_decoder_devices is not None ) if not balance_exists: raise ValueError( @@ -50,19 +53,19 @@ def infer_init_method(args, 
force_distributed=False): "--pipeline-devices is currently required for pipeline model parallelism" ) - args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int) - if args.pipeline_devices is not None: - args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) - num_pipeline_devices = len(set(args.pipeline_devices)) + cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) + if cfg.pipeline_devices is not None: + cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) + num_pipeline_devices = len(set(cfg.pipeline_devices)) else: - args.pipeline_encoder_devices = utils.eval_str_list( - args.pipeline_encoder_devices, type=int + cfg.pipeline_encoder_devices = utils.eval_str_list( + cfg.pipeline_encoder_devices, type=int ) - args.pipeline_decoder_devices = utils.eval_str_list( - args.pipeline_decoder_devices, type=int + cfg.pipeline_decoder_devices = utils.eval_str_list( + cfg.pipeline_decoder_devices, type=int ) num_pipeline_devices = len( - set(args.pipeline_encoder_devices + args.pipeline_decoder_devices) + set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) ) gpus_per_node = torch.cuda.device_count() assert ( @@ -79,14 +82,14 @@ def infer_init_method(args, force_distributed=False): key in os.environ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] ): - args.distributed_init_method = "env://" - args.distributed_world_size = int(os.environ["WORLD_SIZE"]) - args.distributed_rank = int(os.environ["RANK"]) + cfg.distributed_init_method = "env://" + cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) + cfg.distributed_rank = int(os.environ["RANK"]) # processes are created by torch.distributed.launch - args.distributed_no_spawn = True + cfg.distributed_no_spawn = True # we can determine the init method automatically for Slurm - elif args.distributed_port > 0: + elif cfg.distributed_port > 0: node_list = os.environ.get("SLURM_STEP_NODELIST") if node_list is None: 
node_list = os.environ.get("SLURM_JOB_NODELIST") @@ -95,9 +98,9 @@ def infer_init_method(args, force_distributed=False): hostnames = subprocess.check_output( ["scontrol", "show", "hostnames", node_list] ) - args.distributed_init_method = "tcp://{host}:{port}".format( + cfg.distributed_init_method = "tcp://{host}:{port}".format( host=hostnames.split()[0].decode("utf-8"), - port=args.distributed_port, + port=cfg.distributed_port, ) nnodes = int(os.environ.get("SLURM_NNODES")) ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") @@ -111,88 +114,94 @@ def infer_init_method(args, force_distributed=False): if ntasks_per_node == 1: gpus_per_node = torch.cuda.device_count() node_id = int(os.environ.get("SLURM_NODEID")) - args.distributed_rank = node_id * gpus_per_node - args.distributed_world_size = nnodes * gpus_per_node - elif args.pipeline_model_parallel: + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: assert ntasks_per_node == num_pipelines_per_node, ( "SLURM --ntasks-per-node must match number of pipelines per " "node (={})".format(num_pipelines_per_node) ) - args.distributed_no_spawn = True + cfg.distributed_no_spawn = True # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on # the first node, [1, 2] on the second node, etc. This # matches torch.distributed.launch. node_id = int(os.environ.get("SLURM_NODEID")) local_id = int(os.environ.get("SLURM_LOCALID")) - args.distributed_rank = node_id * num_pipelines_per_node + local_id + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id # In the above example, device_id will always be in [0, 1], # which also matches torch.distributed.launch. - args.device_id = local_id + cfg.device_id = local_id # We also want to set distributed_world_size to be the total # number of pipelines across all nodes. 
- args.distributed_world_size = nnodes * num_pipelines_per_node + cfg.distributed_world_size = nnodes * num_pipelines_per_node else: - assert ntasks_per_node == args.distributed_world_size // nnodes - args.distributed_no_spawn = True - args.distributed_rank = int(os.environ.get("SLURM_PROCID")) - args.device_id = int(os.environ.get("SLURM_LOCALID")) + assert ntasks_per_node == cfg.distributed_world_size // nnodes + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.device_id = int(os.environ.get("SLURM_LOCALID")) except subprocess.CalledProcessError as e: # scontrol failed raise e except FileNotFoundError: # Slurm is not installed pass - elif args.distributed_world_size > 1 or force_distributed: + elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs - assert args.distributed_world_size <= torch.cuda.device_count() + assert cfg.distributed_world_size <= torch.cuda.device_count() port = random.randint(10000, 20000) - args.distributed_init_method = "tcp://localhost:{port}".format(port=port) + cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) - if args.pipeline_model_parallel: - if not args.distributed_no_spawn: + if cfg.pipeline_model_parallel: + if not cfg.distributed_no_spawn: # When distributed_no_spawn is False, we expect distributed_rank and # distributed_world_size to be based on the total number of GPUs, so # we need to correct them to be based on the number of pipelines. - assert args.distributed_world_size % num_pipeline_devices == 0 - args.distributed_world_size = ( - args.distributed_world_size // num_pipeline_devices + assert cfg.distributed_world_size % num_pipeline_devices == 0 + cfg.distributed_world_size = ( + cfg.distributed_world_size // num_pipeline_devices ) # In the case of 4-way MP on nodes with 8 GPUs, we want # distributed_rank to be the starting GPU index for each pipeline # i.e., 0, 2, ... 
- assert args.distributed_rank % gpus_per_node == 0 - assert args.distributed_rank % num_pipeline_devices == 0 - args.distributed_rank = args.distributed_rank // num_pipeline_devices - # launch one process per pipeline - args.distributed_num_procs = num_pipelines_per_node + assert cfg.distributed_rank % gpus_per_node == 0 + assert cfg.distributed_rank % num_pipeline_devices == 0 + + with open_dict(cfg): + cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices + # launch one process per pipeline + cfg.distributed_num_procs = num_pipelines_per_node # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 # and 4, indicating the starting device IDs for each pipeline - args.device_id *= num_pipeline_devices + cfg.device_id *= num_pipeline_devices - if args.device_id > 0: + if cfg.device_id > 0: # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 # GPU node), we need to adjust pipeline_devices accordingly logger.debug( "setting CUDA device={} on rank {}".format( - args.device_id, args.distributed_rank + cfg.device_id, cfg.distributed_rank ) ) - torch.cuda.set_device(args.device_id) - args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices] + torch.cuda.set_device(cfg.device_id) + with open_dict(cfg): + cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices] logger.info( "setting pipeline_devices={} on rank {}".format( - args.pipeline_devices, args.distributed_rank - ), + cfg.pipeline_devices, cfg.distributed_rank + ) ) - elif not args.distributed_no_spawn: - args.distributed_num_procs = min( - torch.cuda.device_count(), - args.distributed_world_size, - ) + elif not cfg.distributed_no_spawn: + with open_dict(cfg): + cfg.distributed_num_procs = min( + torch.cuda.device_count(), cfg.distributed_world_size + ) + +def distributed_init(cfg: DictConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) -def distributed_init(args): - if not getattr(args, "tpu", False): + if 
not cfg.common.tpu: if torch.distributed.is_initialized(): warnings.warn( "Distributed is already initialized, cannot initialize twice!" @@ -200,20 +209,20 @@ def distributed_init(args): else: logger.info( "distributed init (rank {}): {}".format( - args.distributed_rank, - args.distributed_init_method, + cfg.distributed_training.distributed_rank, + cfg.distributed_training.distributed_init_method, ) ) dist.init_process_group( - backend=args.distributed_backend, - init_method=args.distributed_init_method, - world_size=args.distributed_world_size, - rank=args.distributed_rank, + backend=cfg.distributed_training.distributed_backend, + init_method=cfg.distributed_training.distributed_init_method, + world_size=cfg.distributed_training.distributed_world_size, + rank=cfg.distributed_training.distributed_rank, ) logger.info( "initialized host {} as rank {}".format( socket.gethostname(), - args.distributed_rank, + cfg.distributed_training.distributed_rank, ) ) @@ -221,20 +230,22 @@ def distributed_init(args): if torch.cuda.is_available(): dist.all_reduce(torch.zeros(1).cuda()) - args.distributed_rank = torch.distributed.get_rank() + cfg.distributed_training.distributed_rank = torch.distributed.get_rank() else: import torch_xla.core.xla_model as xm - assert xm.xrt_world_size() == args.distributed_world_size - args.device_id = xm.get_local_ordinal() - args.distributed_rank = xm.get_ordinal() + assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size + cfg.distributed_training.device_id = xm.get_local_ordinal() + cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers xm.mark_step() - if not is_master(args): + if is_master(cfg.distributed_training): + logging.getLogger().setLevel(logging.INFO) + else: logging.getLogger().setLevel(logging.WARNING) - if args.model_parallel_size > 1: + if cfg.common.model_parallel_size > 1: try: from fairseq.model_parallel.megatron.mpu import ( 
get_model_parallel_rank, @@ -247,58 +258,61 @@ def distributed_init(args): "\n\n git submodule update --init " "fairseq/model_parallel/megatron" ) - initialize_model_parallel(args.model_parallel_size) - model_parallel_cuda_manual_seed(args.seed) + initialize_model_parallel(cfg.common.model_parallel_size) + model_parallel_cuda_manual_seed(cfg.common.seed) model_part_number = get_model_parallel_rank() - args.checkpoint_suffix += "-model_part-{0}".format(model_part_number) - return args.distributed_rank + cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + return cfg.distributed_training.distributed_rank -def distributed_main(i, main, args, kwargs): - args.device_id = i - if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): - torch.cuda.set_device(args.device_id) - if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = kwargs.pop("start_rank", 0) + i +def distributed_main(i, main, cfg: DictConfig, kwargs): + cfg.distributed_training.device_id = i + if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu: + torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn + cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i - args.distributed_rank = distributed_init(args) + cfg.distributed_training.distributed_rank = distributed_init(cfg) after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) if after_distributed_init_fn: - args = after_distributed_init_fn(args) + cfg = after_distributed_init_fn(cfg) - main(args, **kwargs) + main(cfg, **kwargs) -def call_main(args, main, **kwargs): - if args.distributed_init_method is None: - infer_init_method(args) +def call_main(cfg: DictConfig, main, **kwargs): + if cfg.distributed_training.distributed_init_method is None: + infer_init_method(cfg.distributed_training) - if args.distributed_init_method 
is not None: + if cfg.distributed_training.distributed_init_method is not None: # distributed training - if not args.distributed_no_spawn: - start_rank = args.distributed_rank - args.distributed_rank = None # assign automatically + if not cfg.distributed_training.distributed_no_spawn: + start_rank = cfg.distributed_training.distributed_rank + cfg.distributed_training.distributed_rank = None # assign automatically kwargs["start_rank"] = start_rank torch.multiprocessing.spawn( fn=distributed_main, - args=(main, args, kwargs), - nprocs=args.distributed_num_procs, + args=(main, cfg, kwargs), + nprocs=min( + torch.cuda.device_count(), + cfg.distributed_training.distributed_world_size, + ), ) else: - distributed_main(args.device_id, main, args, kwargs) - elif getattr(args, "tpu", False) and args.distributed_world_size > 1: + distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) + elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1: import torch_xla.distributed.xla_multiprocessing as xmp torch.multiprocessing.set_sharing_strategy("file_system") xmp.spawn( fn=distributed_main, - args=(main, args, kwargs), + args=(main, cfg, kwargs), nprocs=8, # use all 8 TPU cores ) else: # single GPU main - main(args, **kwargs) + main(cfg, **kwargs) def get_rank(): @@ -392,11 +406,7 @@ def all_gather_list(data, group=None, max_size=16384): ) -def all_reduce_dict( - data: Mapping[str, Any], - device, - group=None, -) -> Dict[str, Any]: +def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]: """ AllReduce a dictionary of values across workers. 
We separately reduce items that are already on the device and items on CPU for diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index b293e54e2a..3be7078b7a 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -8,11 +8,12 @@ import copy import logging import os -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List import torch from fairseq import utils from fairseq.data import encoders +from omegaconf import open_dict from torch import nn @@ -85,9 +86,9 @@ class GeneratorHubInterface(nn.Module): translation or language model. """ - def __init__(self, args, task, models): + def __init__(self, cfg, task, models): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.models = nn.ModuleList(models) self.src_dict = task.source_dictionary @@ -95,14 +96,14 @@ def __init__(self, args, task, models): # optimize model for generation for model in self.models: - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None)) + self.align_dict = utils.load_align_dict(cfg.generation.replace_unk) - self.tokenizer = encoders.build_tokenizer(args) - self.bpe = encoders.build_bpe(args) + self.tokenizer = encoders.build_tokenizer(cfg.tokenizer) + self.bpe = encoders.build_bpe(cfg.bpe) self.max_positions = utils.resolve_max_positions( self.task.max_positions(), *[model.max_positions() for model in models] @@ -156,10 +157,11 @@ def generate( )[0] # build generator using current args as well as any kwargs - gen_args = copy.copy(self.args) - gen_args.beam = beam - for k, v in kwargs.items(): - setattr(gen_args, k, v) + gen_args = copy.copy(self.cfg) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) generator = 
self.task.build_generator(self.models, gen_args) inference_step_args = inference_step_args or {} @@ -253,8 +255,8 @@ def _build_batches( lengths = torch.LongTensor([t.numel() for t in tokens]) batch_iterator = self.task.get_batch_iterator( dataset=self.task.build_dataset_for_inference(tokens, lengths), - max_tokens=self.args.max_tokens, - max_sentences=self.args.batch_size, + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, max_positions=self.max_positions, ignore_invalid_inputs=skip_invalid_size_inputs, disable_iterator_cache=True, diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 761ffc8e61..258551c933 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -9,6 +9,7 @@ from fairseq import distributed_utils from fairseq.trainer import Trainer +from omegaconf import DictConfig try: @@ -28,14 +29,14 @@ class MegatronTrainer(Trainer): """Main class for model parallel with data parallel training.""" - def __init__(self, args, task, model, criterion): + def __init__(self, cfg: DictConfig, task, model, criterion, **kwargs): if not has_megatron_submodule: raise ImportError( "\n\nPlease install the megatron submodule:" "\n\n git submodule update --init " "fairseq/model_parallel/megatron" ) - super().__init__(args, task, model, criterion) + super().__init__(cfg, task, model, criterion, **kwargs) @property def data_parallel_world_size(self): diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index cbfc6ae4a0..76cfe3b0b4 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -96,7 +96,7 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): encoder_output_tuple = self.encoder(input) return 
self.decoder(encoder_output_tuple) - def prepare_for_inference_(self, args): + def prepare_for_inference_(self, cfg): if self.encoder is not None and self.decoder is not None: logger.info("Encoder and Decoder already initialized") return @@ -111,9 +111,9 @@ def prepare_for_inference_(self, args): decoder_module_list.append(module) module_count += 1 self.model = None - self.encoder = TransformerEncoder(args, None, None, encoder_module_list) + self.encoder = TransformerEncoder(cfg.model, None, None, encoder_module_list) self.decoder = TransformerDecoder( - args, None, None, decoder_module_list=decoder_module_list + cfg.model, None, None, decoder_module_list=decoder_module_list ) @staticmethod @@ -320,7 +320,7 @@ def max_decoder_positions(self): """Maximum length supported by the decoder.""" return self.decoder_max_positions - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict(self, state_dict, strict=True, cfg=None): """Copies parameters and buffers from *state_dict* into this module and its descendants. 
diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 5db6efb7b1..dc52f6e8dd 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -72,6 +72,10 @@ def build_model(cls, args, task): ) return cls(decoder) + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): def _vocab_init(tensor, **kwargs): diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 7ff9442711..3b4fd51d6c 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -7,8 +7,6 @@ import argparse import importlib import os -from argparse import Namespace -from typing import Union import fairseq from fairseq.dataclass import FairseqDataclass @@ -52,10 +50,10 @@ ] -def build_model(model_cfg: Union[DictConfig, Namespace], task): - if isinstance(model_cfg, DictConfig): - return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task) - return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task) +def build_model(cfg: DictConfig, task): + if isinstance(cfg, DictConfig): + return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task) + return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task) def register_model(name, dataclass=None): @@ -92,7 +90,8 @@ def register_model_cls(cls): ) cls.__dataclass = dataclass - MODEL_DATACLASS_REGISTRY[name] = dataclass + if dataclass is not None: + MODEL_DATACLASS_REGISTRY[name] = dataclass return cls return register_model_cls @@ -108,14 +107,13 @@ def register_model_architecture(model_name, arch_name): For example:: @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') - def lstm_luong_wmt_en_de(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) + def lstm_luong_wmt_en_de(cfg): + args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000) (...) 
- The decorated function should take a single argument *args*, which is a - :class:`argparse.Namespace` of arguments parsed from the command-line. The - decorated function should modify these arguments in-place to match the - desired architecture. + The decorated function should take a single argument *cfg*, which is a + :class:`omegaconf.DictConfig`. The decorated function should modify these + arguments in-place to match the desired architecture. Args: model_name (str): the name of the Model (Model must already be diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index cdabe36010..6a520cb980 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -13,6 +13,7 @@ import torch.nn.functional as F from fairseq import utils from fairseq.data import encoders +from omegaconf import open_dict logger = logging.getLogger(__name__) @@ -24,13 +25,13 @@ class BARTHubInterface(nn.Module): Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart """ - def __init__(self, args, task, model): + def __init__(self, cfg, task, model): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.model = model - self.bpe = encoders.build_bpe(args) + self.bpe = encoders.build_bpe(cfg.bpe) self.max_positions = min( utils.resolve_max_positions( @@ -120,10 +121,11 @@ def generate( sample = self._build_sample(tokens) # build generator using current args as well as any kwargs - gen_args = copy.copy(self.args) - gen_args.beam = beam - for k, v in kwargs.items(): - setattr(gen_args, k, v) + gen_args = copy.copy(self.cfg) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) generator = self.task.build_generator([self.model], gen_args) translations = self.task.inference_step( generator, diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 0f22352b68..7263a78dc2 100644 --- a/fairseq/models/bart/model.py +++ 
b/fairseq/models/bart/model.py @@ -144,7 +144,9 @@ def register_classification_head( num_classes=num_classes, activation_fn=self.args.pooler_activation_fn, pooler_dropout=self.args.pooler_dropout, - do_spectral_norm=self.args.spectral_norm_classification_head, + do_spectral_norm=getattr( + self.args, "spectral_norm_classification_head", False + ), ) def upgrade_state_dict_named(self, state_dict, name): diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index bfd41777b2..3ebb30e3ad 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -7,6 +7,7 @@ """ import logging +from argparse import Namespace from typing import Dict, List, Optional, Tuple import torch @@ -15,8 +16,12 @@ from fairseq import utils from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary -from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + gen_parser_from_dataclass, +) from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig from torch import Tensor @@ -86,15 +91,26 @@ def max_positions(self): """Maximum length supported by the model.""" return None - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): """Copies parameters and buffers from *state_dict* into this module and its descendants. Overrides the method in :class:`nn.Module`. Compared with that method this additionally "upgrades" *state_dicts* from old checkpoints. 
""" + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + self.upgrade_state_dict(state_dict) - new_state_dict = prune_state_dict(state_dict, args) + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) def upgrade_state_dict(self, state_dict): @@ -133,18 +149,18 @@ def _apply(m): self.apply(_apply) - def prepare_for_inference_(self, args): + def prepare_for_inference_(self, cfg: DictConfig): """Prepare model for inference.""" kwargs = {} kwargs["beamable_mm_beam_size"] = ( - None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5) + None + if getattr(cfg.generation, "no_beamable_mm", False) + else getattr(cfg.generation, "beam", 5) ) - kwargs["need_attn"] = getattr(args, "print_alignment", False) - if hasattr(args, "retain_dropout"): - kwargs["retain_dropout"] = args.retain_dropout - kwargs["retain_dropout_modules"] = getattr( - args, "retain_dropout_modules", None - ) + kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False) + if getattr(cfg.generation, "retain_dropout", False): + kwargs["retain_dropout"] = cfg.generation.retain_dropout + kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules self.make_generation_fast_(**kwargs) def make_generation_fast_(self, **kwargs): @@ -437,15 +453,26 @@ def decoder(self): def forward_decoder(self, prev_output_tokens, **kwargs): return self.decoder(prev_output_tokens, **kwargs) - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): """Copies parameters and buffers from *state_dict* into this module and its descendants. Overrides the method in :class:`nn.Module`. 
Compared with that method this additionally "upgrades" *state_dicts* from old checkpoints. """ + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + self.upgrade_state_dict(state_dict) - new_state_dict = prune_state_dict(state_dict, args) + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py index e3fbbd5710..2e1f86f36e 100644 --- a/fairseq/models/multilingual_transformer.py +++ b/fairseq/models/multilingual_transformer.py @@ -194,14 +194,14 @@ def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): module_class = TransformerEncoder if is_encoder else TransformerDecoder return module_class(args, lang_dict, embed_tokens) - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict(self, state_dict, strict=True, model_cfg=None): state_dict_subset = state_dict.copy() for k, _ in state_dict.items(): assert k.startswith("models.") lang_pair = k.split(".")[1] if lang_pair not in self.models: del state_dict_subset[k] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg) @register_model_architecture("multilingual_transformer", "multilingual_transformer") diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index 526823bd1f..0c723f06dd 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -17,13 +17,13 @@ class RobertaHubInterface(nn.Module): Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta """ - def __init__(self, args, task, model): + def __init__(self, cfg, task, model): super().__init__() - self.args = args + 
self.cfg = cfg self.task = task self.model = model - self.bpe = encoders.build_bpe(args) + self.bpe = encoders.build_bpe(cfg.bpe) # this is useful for determining the device self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 6ce216a6bf..d1a6319630 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -494,7 +494,7 @@ def base_architecture(args): args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.spectral_norm_classification_head = getattr( - args, "spectral_nrom_classification_head", False + args, "spectral_norm_classification_head", False ) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index fbb7ce2338..f87fa50d29 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -578,10 +578,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): if embed_dim != input_embed_dim else None ) - self.embed_positions = ( PositionalEmbedding( - args.max_target_positions, + self.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, @@ -963,6 +962,14 @@ def base_architecture(args): args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + @register_model_architecture("transformer", 
"transformer_iwslt_de_en") def transformer_iwslt_de_en(args): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 22b17f06ee..df809bdb19 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -159,7 +159,7 @@ class TransformerLanguageModelConfig(FairseqDataclass): add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") - tpu: bool = II("params.common.tpu") + tpu: bool = II("common.tpu") @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 48cd4c7314..8775aa7766 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -32,20 +32,20 @@ class TransformerEncoderLayer(nn.Module): def __init__(self, args): super().__init__() self.embed_dim = args.encoder_embed_dim - self.quant_noise = getattr(args, "quant_noise_pq", 0) - self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + self.quant_noise = getattr(args, 'quant_noise_pq', 0) + self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) self.activation_fn = utils.get_activation_fn( - activation=getattr(args, "activation_fn", "relu") + activation=getattr(args, 'activation_fn', 'relu') or "relu" ) - activation_dropout_p = getattr(args, "activation_dropout", 0) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) 
or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) @@ -197,10 +197,10 @@ def __init__( if getattr(args, "activation_fn", None) is not None else "relu" ) - activation_dropout_p = getattr(args, "activation_dropout", 0) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 94eb2c7ee9..d8e581729e 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.optim.bmuf import FairseqBMUF # noqa @@ -19,7 +17,6 @@ from fairseq.optim.shard import shard_ from omegaconf import DictConfig - __all__ = [ "FairseqOptimizer", "FP16Optimizer", @@ -27,7 +24,6 @@ "shard_", ] - ( _build_optimizer, register_optimizer, @@ -37,12 +33,12 @@ def build_optimizer( - optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs + cfg: DictConfig, params, *extra_args, **extra_kwargs ): if all(isinstance(p, dict) for p in params): params = [t for p in params for t in p.values()] params = list(filter(lambda p: p.requires_grad, params)) - return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) + return _build_optimizer(cfg, params, *extra_args, **extra_kwargs) # automatically import any Python files in the optim/ directory diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index f678a9f56c..9b8ddffd7e 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -5,6 +5,7 @@ import logging import math +from 
collections import Collection from dataclasses import dataclass, field from typing import List @@ -14,7 +15,7 @@ from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from fairseq.optim.fused_adam import get_fused_adam_class -from omegaconf import II +from omegaconf import II, DictConfig logger = logging.getLogger(__name__) @@ -33,8 +34,8 @@ class FairseqAdamConfig(FairseqDataclass): default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} ) # TODO common vars below in parent - tpu: bool = II("params.common.tpu") - lr: List[float] = II("params.optimization.lr") + tpu: bool = II("common.tpu") + lr: List[float] = II("optimization.lr") @register_optimizer("adam", dataclass=FairseqAdamConfig) @@ -46,15 +47,15 @@ class FairseqAdam(FairseqOptimizer): analogous to torch.optim.AdamW from PyTorch. """ - def __init__(self, args, params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) fused_adam_cls = get_fused_adam_class() use_fused_adam = ( - not getattr(args, "use_old_adam", False) + not getattr(cfg, "use_old_adam", False) and fused_adam_cls is not None and torch.cuda.is_available() ) - if getattr(args, "tpu", False): + if getattr(cfg, "tpu", False): # on TPUs we use the Adam defined here, since it # automatically casts gradients to FP32 self._optimizer = Adam(params, **self.optimizer_config) @@ -73,10 +74,12 @@ def optimizer_config(self): different learning rate. 
""" return { - "lr": self.args.lr[0], - "betas": eval(self.args.adam_betas), - "eps": self.args.adam_eps, - "weight_decay": self.args.weight_decay, + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, } def average_params(self): diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py index 3312f81103..55f225ba6a 100644 --- a/fairseq/optim/bmuf.py +++ b/fairseq/optim/bmuf.py @@ -10,7 +10,7 @@ from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.fairseq_optimizer import FairseqOptimizer -from omegaconf import II +from omegaconf import II, DictConfig @dataclass @@ -38,7 +38,7 @@ class FairseqBMUFConfig(FairseqDataclass): }, ) distributed_world_size: int = II( - "params.distributed_training.distributed_world_size" + "distributed_training.distributed_world_size" ) @@ -52,20 +52,19 @@ class FairseqBMUF(FairseqOptimizer): model-update filtering """ - def __init__(self, args, optimizer): - - super().__init__(args) + def __init__(self, cfg: DictConfig, optimizer): + super().__init__(cfg) self._optimizer = optimizer self._num_updates = 0 - self.sync_iter = self.args.global_sync_iter - self.block_momentum = self.args.block_momentum - self.block_lr = self.args.block_lr + self.sync_iter = cfg.global_sync_iter + self.block_momentum = cfg.block_momentum + self.block_lr = cfg.block_lr self._reset_local_data() - self.warmup_iteration = self.args.warmup_iterations - self.use_nbm = self.args.use_nbm + self.warmup_iteration = cfg.warmup_iterations + self.use_nbm = cfg.use_nbm self.initial_state = self._optimizer.state_dict() - self.average_sync = self.args.average_sync - self.world_size = self.args.distributed_world_size + self.average_sync = self.cfg.average_sync + self.world_size = self.cfg.distributed_world_size @staticmethod def add_args(parser): diff --git 
a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 8a10399a8b..9c0938331d 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -9,9 +9,9 @@ class FairseqOptimizer(object): - def __init__(self, args): + def __init__(self, cfg): super().__init__() - self.args = args + self.cfg = cfg @classmethod def add_args(cls, parser): diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index b622fbde44..b08a7237a9 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -7,7 +7,8 @@ from itertools import chain import torch -from fairseq import optim, utils +from fairseq import optim +from omegaconf import DictConfig from .dynamic_loss_scaler import DynamicLossScaler @@ -211,7 +212,7 @@ def zero_grad(self): for fp32_params in self.fp32_params.values(): fp32_params.grad.zero_() else: - raise ("self.fp32_params must be a tensor or dict") + raise RuntimeError("self.fp32_params must be a tensor or dict") else: for p32 in self.fp32_params: p32.grad.zero_() @@ -226,58 +227,60 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): Wrap an *optimizer* to support FP16 (mixed precision) training. 
""" - def __init__(self, args, params, fp32_optimizer, fp32_params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs): + super().__init__(cfg.optimizer) self.fp16_params = params self.fp32_optimizer = fp32_optimizer self.fp32_params = fp32_params - if getattr(args, "fp16_scale_window", None) is None: - if len(args.update_freq) > 1: + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: raise ValueError( "--fp16-scale-window must be given explicitly when using a " "custom --update-freq schedule" ) data_parallel_size = int( - args.distributed_world_size / args.model_parallel_size + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] ) - scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0]) else: - scale_window = args.fp16_scale_window + scale_window = cfg.common.fp16_scale_window - if not getattr(args, "bf16", False): + if not getattr(cfg.common, "bf16", False): self.scaler = DynamicLossScaler( - init_scale=args.fp16_init_scale, + init_scale=cfg.common.fp16_init_scale, scale_window=scale_window, - tolerance=args.fp16_scale_tolerance, - threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, ) else: # disable loss scaling for bfloat16 self.scaler = None @classmethod - def build_optimizer(cls, args, params): + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): """ Args: - args (argparse.Namespace): fairseq args + cfg (omegaconf.DictConfig): fairseq args params (iterable): iterable of parameters to optimize """ - flatten = not getattr(args, "fp16_no_flatten_grads", False) - if getattr(args, "bf16", False): + flatten = not getattr(cfg.common, 
"fp16_no_flatten_grads", False) + if getattr(cfg.common, "bf16", False): flatten = False # mixed precision is faster on TPUs without flat grads - fp32_params = cls.build_fp32_params(args, params, flatten=flatten) + fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten) if flatten: - fp32_optimizer = optim.build_optimizer(args, [fp32_params]) + fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params]) else: - fp32_optimizer = optim.build_optimizer(args, fp32_params) + fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params) if flatten and not fp32_optimizer.supports_flat_params: raise RuntimeError( - "chosen optimizer does not support flat params, " - "please set --fp16-no-flatten-grads" + f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads" ) - return cls(args, params, fp32_optimizer, fp32_params) + return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs) @property def optimizer(self): @@ -427,49 +430,52 @@ class MemoryEfficientFP16Optimizer( *supports_memory_efficient_fp16* property. 
""" - def __init__(self, args, params, optimizer): + def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): if not optimizer.supports_memory_efficient_fp16: raise ValueError( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) - super().__init__(args) + super().__init__(cfg.optimizer) self.wrapped_optimizer = optimizer - if getattr(args, "fp16_scale_window", None) is None: - if len(args.update_freq) > 1: + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: raise ValueError( "--fp16-scale-window must be given explicitly when using a " "custom --update-freq schedule" ) data_parallel_size = int( - args.distributed_world_size / args.model_parallel_size + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = ( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] ) - scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0] else: - scale_window = args.fp16_scale_window + scale_window = cfg.common.fp16_scale_window - if not getattr(args, "bf16", False): + if not getattr(cfg.common, "bf16", False): self.scaler = DynamicLossScaler( - init_scale=args.fp16_init_scale, + init_scale=cfg.common.fp16_init_scale, scale_window=scale_window, - tolerance=args.fp16_scale_tolerance, - threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, ) else: # disable loss scaling for bfloat16 self.scaler = None @classmethod - def build_optimizer(cls, args, params): + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): """ Args: args (argparse.Namespace): fairseq args params (iterable): iterable of parameters to optimize """ - fp16_optimizer = optim.build_optimizer(args, params) - return cls(args, params, fp16_optimizer) + fp16_optimizer = optim.build_optimizer(cfg.optimizer, params) + return 
cls(cfg, params, fp16_optimizer, **kwargs) @property def optimizer(self): diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index 7b72c25784..f07d43c7c3 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa @@ -27,8 +25,8 @@ ) -def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer): - return build_lr_scheduler_(lr_scheduler_cfg, optimizer) +def build_lr_scheduler(cfg: DictConfig, optimizer): + return build_lr_scheduler_(cfg, optimizer) # automatically import any Python files in the optim/lr_scheduler/ directory diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 98d557504f..c3c6663ece 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -4,11 +4,12 @@ # LICENSE file in the root directory of this source tree. import math +from collections import Collection from dataclasses import dataclass, field from typing import List from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from . import FairseqLRScheduler, register_lr_scheduler @@ -38,8 +39,8 @@ class CosineConfig(FairseqDataclass): default=0.1, metadata={"help": "shrink factor for annealing"} ) # TODO common var for parent class - lr: List[float] = II("params.optimization.lr") - max_update: int = II("params.optimization.max_update") + lr: List[float] = II("optimization.lr") + max_update: int = II("optimization.max_update") @register_lr_scheduler("cosine", dataclass=CosineConfig) @@ -66,43 +67,51 @@ class CosineSchedule(FairseqLRScheduler): after every iteration. 
""" - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__( + self, cfg: DictConfig, fairseq_optimizer + ): + super().__init__(cfg, fairseq_optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with cosine." " Consider --lr-scheduler=fixed instead." ) - warmup_end_lr = args.max_lr - if args.warmup_init_lr < 0: - args.warmup_init_lr = args.lr[0] - - self.min_lr = args.lr[0] - self.max_lr = args.max_lr - + warmup_end_lr = cfg.max_lr + lr = ( + cfg.lr[0] + if isinstance(cfg.lr, Collection) + else cfg.lr + ) + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = lr + + self.min_lr = lr + self.max_lr = cfg.max_lr assert self.max_lr > self.min_lr, "max_lr must be more than lr" - self.t_mult = args.t_mult - self.period = args.lr_period_updates + self.t_mult = cfg.t_mult + self.period = cfg.lr_period_updates if self.period <= 0: assert ( - args.max_update >= 0 + cfg.max_update >= 0 ), "Either --max_update or --lr-period-updates must be set" - self.period = args.max_update - args.warmup_updates + self.period = cfg.max_update - cfg.warmup_updates - if args.warmup_updates > 0: + if cfg.warmup_updates > 0: # linearly warmup for the first args.warmup_updates - self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + self.lr_step = ( + warmup_end_lr - cfg.warmup_init_lr + ) / cfg.warmup_updates else: self.lr_step = 1 - self.warmup_updates = args.warmup_updates - self.lr_shrink = args.lr_shrink + self.warmup_updates = cfg.warmup_updates + self.lr_shrink = cfg.lr_shrink # initial learning rate - self.lr = args.warmup_init_lr + self.lr = cfg.warmup_init_lr self.optimizer.set_lr(self.lr) def step(self, epoch, val_loss=None): @@ -113,10 +122,10 @@ def step(self, epoch, val_loss=None): def step_update(self, num_updates): """Update the learning rate after each update.""" - if num_updates < self.args.warmup_updates: - self.lr = 
self.args.warmup_init_lr + num_updates * self.lr_step + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step else: - curr_updates = num_updates - self.args.warmup_updates + curr_updates = num_updates - self.cfg.warmup_updates if self.t_mult != 1: i = math.floor( math.log( diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index 8fde0713aa..569e448262 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -11,11 +11,11 @@ class FairseqLRScheduler(object): - def __init__(self, args, optimizer): + def __init__(self, cfg, optimizer): super().__init__() if not isinstance(optimizer, FairseqOptimizer): raise ValueError("optimizer must be an instance of FairseqOptimizer") - self.args = args + self.cfg = cfg self.optimizer = optimizer self.best = None diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index d27261ad48..c42e090677 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import Collection from dataclasses import dataclass, field from typing import List from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from . 
import FairseqLRScheduler, register_lr_scheduler @@ -25,7 +26,7 @@ class InverseSquareRootScheduleConfig(FairseqDataclass): }, ) # TODO common vars at parent class - lr: List[float] = II("params.optimization.lr") + lr: List[float] = II("optimization.lr") @register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig) @@ -48,25 +49,33 @@ class InverseSquareRootSchedule(FairseqLRScheduler): lr = decay_factor / sqrt(update_num) """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: DictConfig, optimizer): + super().__init__(cfg, optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with inverse_sqrt." " Consider --lr-scheduler=fixed instead." ) - warmup_end_lr = args.lr[0] - if args.warmup_init_lr < 0: - args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + warmup_end_lr = ( + cfg.lr[0] + if isinstance(cfg.lr, Collection) + else cfg.lr + ) + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = ( + 0 if cfg.warmup_updates > 0 else warmup_end_lr + ) # linearly warmup for the first args.warmup_updates - self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + self.lr_step = ( + warmup_end_lr - cfg.warmup_init_lr + ) / cfg.warmup_updates # then, decay prop. 
to the inverse square root of the update number - self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5 + self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5 # initial learning rate - self.lr = args.warmup_init_lr + self.lr = cfg.warmup_init_lr self.optimizer.set_lr(self.lr) def step(self, epoch, val_loss=None): @@ -77,8 +86,8 @@ def step(self, epoch, val_loss=None): def step_update(self, num_updates): """Update the learning rate after each update.""" - if num_updates < self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step else: self.lr = self.decay_factor * num_updates ** -0.5 self.optimizer.set_lr(self.lr) diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 58d2f3560f..3982a8271d 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -3,12 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import Collection from dataclasses import dataclass, field from typing import List import torch from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from torch.optim.optimizer import Optimizer, required from . 
import FairseqOptimizer, register_optimizer @@ -19,13 +20,13 @@ class FairseqNAGConfig(FairseqDataclass): momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) # TODO common vars in parent class - lr: List[float] = II("params.optimization.lr") + lr: List[float] = II("optimization.lr") @register_optimizer("nag", dataclass=FairseqNAGConfig) class FairseqNAG(FairseqOptimizer): - def __init__(self, args, params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) self._optimizer = NAG(params, **self.optimizer_config) @property @@ -37,9 +38,11 @@ def optimizer_config(self): different learning rate. """ return { - "lr": self.args.lr[0], - "momentum": self.args.momentum, - "weight_decay": self.args.weight_decay, + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "momentum": self.cfg.momentum, + "weight_decay": self.cfg.weight_decay, } diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index a035a1c1f9..ecef05b442 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -12,7 +12,7 @@ _has_fairscale = False -def shard_(args, optimizer, group): +def shard_(optimizer, group): if not _has_fairscale: raise ImportError( "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" diff --git a/fairseq/options.py b/fairseq/options.py index 1a24fccaec..6bc526ce0e 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -10,13 +10,15 @@ from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.dataclass.data_class import ( - CheckpointParams, - CommonEvalParams, - CommonParams, - DatasetParams, - DistributedTrainingParams, - EvalLMParams, - OptimizationParams, + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + EvalLMConfig, + GenerationConfig, + InteractiveConfig, + 
OptimizationConfig, ) from fairseq.dataclass.utils import gen_parser_from_dataclass @@ -45,6 +47,7 @@ def get_generation_parser(interactive=False, default_task="translation"): add_dataset_args(parser, gen=True) add_distributed_training_args(parser, default_world_size=1) add_generation_args(parser) + add_checkpoint_args(parser) if interactive: add_interactive_args(parser) return parser @@ -67,7 +70,7 @@ def get_validation_parser(default_task=None): add_dataset_args(parser, train=True) add_distributed_training_args(parser, default_world_size=1) group = parser.add_argument_group("Evaluation") - gen_parser_from_dataclass(group, CommonEvalParams()) + gen_parser_from_dataclass(group, CommonEvalConfig()) return parser @@ -210,7 +213,7 @@ def get_parser(desc, default_task="translation"): utils.import_user_module(usr_args) parser = argparse.ArgumentParser(allow_abbrev=False) - gen_parser_from_dataclass(parser, CommonParams()) + gen_parser_from_dataclass(parser, CommonConfig()) from fairseq.registry import REGISTRIES @@ -283,7 +286,7 @@ def add_preprocess_args(parser): def add_dataset_args(parser, train=False, gen=False): group = parser.add_argument_group("dataset_data_loading") - gen_parser_from_dataclass(group, DatasetParams()) + gen_parser_from_dataclass(group, DatasetConfig()) # fmt: on return group @@ -293,7 +296,7 @@ def add_distributed_training_args(parser, default_world_size=None): if default_world_size is None: default_world_size = max(1, torch.cuda.device_count()) gen_parser_from_dataclass( - group, DistributedTrainingParams(distributed_world_size=default_world_size) + group, DistributedTrainingConfig(distributed_world_size=default_world_size) ) return group @@ -301,7 +304,7 @@ def add_distributed_training_args(parser, default_world_size=None): def add_optimization_args(parser): group = parser.add_argument_group("optimization") # fmt: off - gen_parser_from_dataclass(group, OptimizationParams()) + gen_parser_from_dataclass(group, OptimizationConfig()) # fmt: on 
return group @@ -309,117 +312,31 @@ def add_optimization_args(parser): def add_checkpoint_args(parser): group = parser.add_argument_group("checkpoint") # fmt: off - gen_parser_from_dataclass(group, CheckpointParams()) + gen_parser_from_dataclass(group, CheckpointConfig()) # fmt: on return group def add_common_eval_args(group): - gen_parser_from_dataclass(group, CommonEvalParams()) + gen_parser_from_dataclass(group, CommonEvalConfig()) def add_eval_lm_args(parser): group = parser.add_argument_group("LM Evaluation") add_common_eval_args(group) - gen_parser_from_dataclass(group, EvalLMParams()) + gen_parser_from_dataclass(group, EvalLMConfig()) def add_generation_args(parser): group = parser.add_argument_group("Generation") add_common_eval_args(group) - # fmt: off - group.add_argument('--beam', default=5, type=int, metavar='N', - help='beam size') - group.add_argument('--nbest', default=1, type=int, metavar='N', - help='number of hypotheses to output') - group.add_argument('--max-len-a', default=0, type=float, metavar='N', - help=('generate sequences of maximum length ax + b, ' - 'where x is the source length')) - group.add_argument('--max-len-b', default=200, type=int, metavar='N', - help=('generate sequences of maximum length ax + b, ' - 'where x is the source length')) - group.add_argument('--min-len', default=1, type=float, metavar='N', - help=('minimum generation length')) - group.add_argument('--match-source-len', default=False, action='store_true', - help=('generations should match the source length')) - group.add_argument('--no-early-stop', action='store_true', - help='deprecated') - group.add_argument('--unnormalized', action='store_true', - help='compare unnormalized hypothesis scores') - group.add_argument('--no-beamable-mm', action='store_true', - help='don\'t use BeamableMM in attention layers') - group.add_argument('--lenpen', default=1, type=float, - help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences') - 
group.add_argument('--unkpen', default=0, type=float, - help='unknown word penalty: <0 produces more unks, >0 produces fewer') - group.add_argument('--replace-unk', nargs='?', const=True, default=None, - help='perform unknown replacement (optionally with alignment dictionary)') - group.add_argument('--sacrebleu', action='store_true', - help='score with sacrebleu') - group.add_argument('--score-reference', action='store_true', - help='just score the reference translation') - group.add_argument('--prefix-size', default=0, type=int, metavar='PS', - help='initialize generation by target prefix of given length') - group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N', - help='ngram blocking such that this size ngram cannot be repeated in the generation') - group.add_argument('--sampling', action='store_true', - help='sample hypotheses instead of using beam search') - group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS', - help='sample from top K likely next words instead of all words') - group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS', - help='sample from the smallest set whose cumulative probability mass exceeds p for next words') - group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"], - help='enables lexically constrained decoding') - group.add_argument('--temperature', default=1., type=float, metavar='N', - help='temperature for generation') - group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N', - help='number of groups for Diverse Beam Search') - group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N', - help='strength of diversity penalty for Diverse Beam Search') - group.add_argument('--diversity-rate', default=-1.0, type=float, metavar='N', - help='strength of diversity penalty for Diverse Siblings Search') - group.add_argument('--print-alignment', action='store_true', - help='if set, uses 
attention feedback to compute and print alignment to source tokens') - group.add_argument('--print-step', action='store_true') - - group.add_argument('--lm-path', default=None, type=str, metavar='PATH', - help='path to lm checkpoint for lm fusion') - group.add_argument('--lm-weight', default=0.0, type=float, metavar='N', - help='weight for lm probs for lm fusion') - - # arguments for iterative refinement generator - group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N', - help='if > 0.0, it penalized early-stopping in decoding.') - group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N', - help='maximum iterations for iterative refinement.') - group.add_argument('--iter-decode-force-max-iter', action='store_true', - help='if set, run exact the maximum number of iterations without early stop') - group.add_argument('--iter-decode-with-beam', default=1, type=int, metavar='N', - help='if > 1, model will generate translations varying by the lengths.') - group.add_argument('--iter-decode-with-external-reranker', action='store_true', - help='if set, the last checkpoint are assumed to be a reranker to rescore the translations'), - group.add_argument('--retain-iter-history', action='store_true', - help='if set, decoding returns the whole history of iterative refinement') - group.add_argument('--retain-dropout', action='store_true', - help='Use dropout at inference time') - group.add_argument('--retain-dropout-modules', default=None, nargs='+', type=str, - help='if set, only retain dropout for the specified modules; ' - 'if not set, then dropout will be retained for all modules') - - # special decoding format for advanced decoding. 
- group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs']) - # fmt: on + gen_parser_from_dataclass(group, GenerationConfig()) return group def add_interactive_args(parser): group = parser.add_argument_group("Interactive") - # fmt: off - group.add_argument('--buffer-size', default=0, type=int, metavar='N', - help='read this many sentences into a buffer before processing them') - group.add_argument('--input', default='-', type=str, metavar='FILE', - help='file to read from; use - for stdin') - # fmt: on + gen_parser_from_dataclass(group, InteractiveConfig()) def add_model_args(parser): diff --git a/fairseq/quantization_utils.py b/fairseq/quantization_utils.py index 69dd61d785..11fc414c85 100644 --- a/fairseq/quantization_utils.py +++ b/fairseq/quantization_utils.py @@ -6,13 +6,14 @@ import logging from fairseq.modules.quantization import pq, quantization_options, scalar +from omegaconf import DictConfig logger = logging.getLogger(__name__) -def quantize_model_scalar(model, args): - quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) +def quantize_model_scalar(model, model_cfg: DictConfig): + quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0 if quant_noise_scalar > 0: # quantize_model edits the model in place scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) diff --git a/fairseq/registry.py b/fairseq/registry.py index 382dec22a8..4446084d4a 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -3,14 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import argparse from argparse import Namespace -from typing import Union +from typing import Union from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import populate_dataclass from omegaconf import DictConfig - REGISTRIES = {} @@ -25,33 +24,30 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F # maintain a registry of all registries if registry_name in REGISTRIES: return # registry already exists - REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default} - - def build_x(args: Union[DictConfig, Namespace], *extra_args, **extra_kwargs): - if isinstance(args, DictConfig): - if getattr(args, "_name", None) is not None: - choice = args._name - elif hasattr(args, registry_name): - choice = args.registry_name - else: - raise RuntimeError( - f"Neither _name nor {registry_name} in args, args = {args}" - ) + REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY} + + def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): + if isinstance(cfg, DictConfig): + choice = cfg._name + elif isinstance(cfg, str): + choice = cfg else: - choice = getattr(args, registry_name, None) + choice = getattr(cfg, registry_name, None) + if choice in DATACLASS_REGISTRY: + cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]()) if choice is None: if required: - raise ValueError("--{} is required!".format(registry_name)) + raise ValueError('{} is required!'.format(registry_name)) return None + cls = REGISTRY[choice] if hasattr(cls, "build_" + registry_name): builder = getattr(cls, "build_" + registry_name) else: builder = cls - if isinstance(args, Namespace): - set_defaults(args, cls) - return builder(args, *extra_args, **extra_kwargs) + + return builder(cfg, *extra_args, **extra_kwargs) def register_x(name, dataclass=None): def register_x_cls(cls): @@ -77,30 +73,10 @@ def register_x_cls(cls): cls.__dataclass = dataclass REGISTRY[name] = 
cls - DATACLASS_REGISTRY[name] = cls.__dataclass - REGISTRY_CLASS_NAMES.add(cls.__name__) + if cls.__dataclass is not None: + DATACLASS_REGISTRY[name] = cls.__dataclass return cls return register_x_cls return build_x, register_x, REGISTRY, DATACLASS_REGISTRY - - -def set_defaults(args: Namespace, cls): - """Helper to set default arguments based on *add_args*.""" - if not hasattr(cls, "add_args"): - return - parser = argparse.ArgumentParser( - argument_default=argparse.SUPPRESS, allow_abbrev=False - ) - cls.add_args(parser) - # copied from argparse.py: - defaults = argparse.Namespace() - for action in parser._actions: - if action.dest is not argparse.SUPPRESS: - if not hasattr(defaults, action.dest): - if action.default is not argparse.SUPPRESS: - setattr(defaults, action.dest, action.default) - for key, default_value in vars(defaults).items(): - if not hasattr(args, key): - setattr(args, key, default_value) diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 4be0cb5188..8c706cb585 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -9,11 +9,12 @@ from abc import ABC, abstractmethod from fairseq import registry +from omegaconf import DictConfig class BaseScorer(ABC): - def __init__(self, args): - self.args = args + def __init__(self, cfg): + self.cfg = cfg self.ref = [] self.pred = [] @@ -39,19 +40,17 @@ def result_string(self) -> str: ) -def build_scorer(args, tgt_dict): - from fairseq import utils +def build_scorer(choice, tgt_dict): + if isinstance(choice, DictConfig): + choice = choice._name - if args.sacrebleu: - utils.deprecation_warning( - "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." 
- ) - args.scoring = "sacrebleu" - if args.scoring == "bleu": + if choice == "bleu": from fairseq.scoring import bleu - return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) - return _build_scorer(args) + return bleu.Scorer( + bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()) + ) + return _build_scorer(choice) # automatically import any Python files in the current directory diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index 7f8bd73bf5..97de5f966e 100644 --- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -6,8 +6,10 @@ import ctypes import math import sys +from dataclasses import dataclass, field import torch +from fairseq.dataclass import FairseqDataclass from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer @@ -27,31 +29,32 @@ class BleuStat(ctypes.Structure): ] -@register_scorer("sacrebleu") +@dataclass +class SacrebleuConfig(FairseqDataclass): + sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="13a", metadata={"help": "tokenizer"} + ) + sacrebleu_lowercase: bool = field( + default=False, metadata={"help": "apply lowercasing"} + ) + sacrebleu_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + + +@register_scorer("sacrebleu", dataclass=SacrebleuConfig) class SacrebleuScorer(BaseScorer): - def __init__(self, args): - super(SacrebleuScorer, self).__init__(args) + def __init__(self, cfg): + super(SacrebleuScorer, self).__init__(cfg) import sacrebleu self.sacrebleu = sacrebleu self.tokenizer = EvaluationTokenizer( - tokenizer_type=self.args.sacrebleu_tokenizer, - lowercase=self.args.sacrebleu_lowercase, - character_tokenization=self.args.sacrebleu_char_level, + tokenizer_type=cfg.sacrebleu_tokenizer, + lowercase=cfg.sacrebleu_lowercase, + character_tokenization=cfg.sacrebleu_char_level, ) - @staticmethod - def add_args(parser): - # fmt: off - 
parser.add_argument('--sacrebleu-tokenizer', type=str, default='13a', - choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, - help='tokenizer') - parser.add_argument('--sacrebleu-lowercase', type=str, default=False, - help='apply lowercasing') - parser.add_argument('--sacrebleu-char-level', action='store_true', - help='evaluate at character level') - # fmt: on - def add_string(self, ref, pred): self.ref.append(self.tokenizer.tokenize(ref)) self.pred.append(self.tokenizer.tokenize(pred)) @@ -68,13 +71,20 @@ def result_string(self, order=4): ).format() -@register_scorer("bleu") +@dataclass +class BleuConfig(FairseqDataclass): + pad: int = field(default=1, metadata={"help": "padding index"}) + eos: int = field(default=2, metadata={"help": "eos index"}) + unk: int = field(default=3, metadata={"help": "unk index"}) + + +@register_scorer("bleu", dataclass=BleuConfig) class Scorer(object): - def __init__(self, pad, eos, unk): + def __init__(self, cfg): self.stat = BleuStat() - self.pad = pad - self.eos = eos - self.unk = unk + self.pad = cfg.pad + self.eos = cfg.eos + self.unk = cfg.unk try: from fairseq import libbleu diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py index dbcc6e4d10..0d0702bf15 100644 --- a/fairseq/scoring/tokenizer.py +++ b/fairseq/scoring/tokenizer.py @@ -5,6 +5,8 @@ import unicodedata +from fairseq.dataclass.utils import ChoiceEnum + class EvaluationTokenizer(object): """A generic evaluation-time tokenizer, which leverages built-in tokenizers @@ -22,7 +24,7 @@ class EvaluationTokenizer(object): SPACE = chr(32) SPACE_ESCAPE = chr(9601) - ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"] + ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"]) def __init__( self, @@ -33,7 +35,7 @@ def __init__( ): from sacrebleu.tokenizers import TOKENIZERS - assert tokenizer_type in self.ALL_TOKENIZER_TYPES + assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}" self.lowercase = lowercase 
self.punctuation_removal = punctuation_removal self.character_tokenization = character_tokenization diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 21efefd9b8..633dc47c24 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -3,14 +3,31 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer -@register_scorer("wer") +@dataclass +class WerScorerConfig(FairseqDataclass): + wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"} + ) + wer_remove_punct: bool = field( + default=False, metadata={"help": "remove punctuation"} + ) + wer_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"}) + + +@register_scorer("wer", dataclass=WerScorerConfig) class WerScorer(BaseScorer): - def __init__(self, args): - super().__init__(args) + def __init__(self, cfg): + super().__init__(cfg) self.reset() try: import editdistance as ed @@ -18,26 +35,12 @@ def __init__(self, args): raise ImportError("Please install editdistance to use WER scorer") self.ed = ed self.tokenizer = EvaluationTokenizer( - tokenizer_type=self.args.wer_tokenizer, - lowercase=self.args.wer_lowercase, - punctuation_removal=self.args.wer_remove_punct, - character_tokenization=self.args.wer_char_level, + tokenizer_type=self.cfg.wer_tokenizer, + lowercase=self.cfg.wer_lowercase, + punctuation_removal=self.cfg.wer_remove_punct, + character_tokenization=self.cfg.wer_char_level, ) - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--wer-tokenizer', type=str, default='none', - 
choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, - help='sacreBLEU tokenizer to use for evaluation') - parser.add_argument('--wer-remove-punct', action='store_true', - help='remove punctuation') - parser.add_argument('--wer-char-level', action='store_true', - help='evaluate at character level') - parser.add_argument('--wer-lowercase', action='store_true', - help='lowercasing') - # fmt: on - def reset(self): self.distance = 0 self.ref_length = 0 diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index e0abce253c..41f461f802 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -7,8 +7,6 @@ import argparse import importlib import os -from argparse import Namespace -from typing import Union from fairseq.dataclass import FairseqDataclass from omegaconf import DictConfig @@ -22,10 +20,10 @@ TASK_CLASS_NAMES = set() -def setup_task(task_cfg: Union[DictConfig, Namespace], **kwargs): - if isinstance(task_cfg, DictConfig): - return TASK_REGISTRY[task_cfg._name].setup_task(task_cfg, **kwargs) - return TASK_REGISTRY[task_cfg.task].setup_task(task_cfg, **kwargs) +def setup_task(cfg: DictConfig, **kwargs): + if isinstance(cfg, DictConfig): + return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs) + return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs) def register_task(name, dataclass=None): @@ -70,7 +68,8 @@ def register_task_cls(cls): ) cls.__dataclass = dataclass - TASK_DATACLASS_REGISTRY[name] = dataclass + if dataclass is not None: + TASK_DATACLASS_REGISTRY[name] = dataclass return cls diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index ff2342afa9..a831ad6ee8 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -79,7 +79,7 @@ def setup_task(cls, args, **kwargs): """Setup the task (e.g., load dictionaries). 
Args: - args (argparse.Namespace): parsed command-line arguments + args (omegaconf.DictConfig): parsed command-line arguments """ return cls(args) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 0a96aeb1ea..3cdb64cfae 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -12,6 +12,7 @@ from fairseq import metrics, search, tokenizer, utils from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -39,8 +40,8 @@ def logging_outputs_can_be_summed(criterion) -> bool: """ return criterion.logging_outputs_can_be_summed() - def __init__(self, args): - self.args = args + def __init__(self, cfg: DictConfig, **kwargs): + self.cfg = cfg self.datasets = {} self.dataset_to_epoch_iter = {} @@ -78,16 +79,16 @@ def build_dictionary( return d @classmethod - def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: DictConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): parsed command-line arguments """ - return cls(args, **kwargs) + return cls(cfg, **kwargs) def has_sharded_data(self, split): - return os.pathsep in getattr(self.args, "data", "") + return os.pathsep in getattr(self.cfg, "data", "") def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split. @@ -254,39 +255,39 @@ def get_batch_iterator( return epoch_iter - def build_model(self, args): + def build_model(self, cfg: DictConfig): """ Build the :class:`~fairseq.models.BaseFairseqModel` instance for this task. 
Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): configuration object Returns: a :class:`~fairseq.models.BaseFairseqModel` instance """ from fairseq import models, quantization_utils - model = models.build_model(args, self) - if getattr(args, "tpu", False): + model = models.build_model(cfg, self) + if getattr(cfg, "tpu", False): model.prepare_for_tpu_() - model = quantization_utils.quantize_model_scalar(model, args) + model = quantization_utils.quantize_model_scalar(model, cfg) return model - def build_criterion(self, args): + def build_criterion(self, cfg: DictConfig): """ Build the :class:`~fairseq.criterions.FairseqCriterion` instance for this task. Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): configration object Returns: a :class:`~fairseq.criterions.FairseqCriterion` instance """ from fairseq import criterions - return criterions.build_criterion(args, self) + return criterions.build_criterion(cfg, self) def build_generator( self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 8792c6481c..6e85417ff5 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -28,7 +28,7 @@ from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import LegacyFairseqTask, register_task from omegaconf import II @@ -85,16 +85,16 @@ class LanguageModelingConfig(FairseqDataclass): }, ) # TODO common vars below add to parent - seed: int = II("params.common.seed") + seed: int = II("common.seed") dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( - "params.dataset.dataset_impl" + "dataset.dataset_impl" ) - data_buffer_size: int = 
II("params.dataset.data_buffer_size") - tpu: bool = II("params.common.tpu") + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") @register_task("language_modeling", dataclass=LanguageModelingConfig) -class LanguageModelingTask(FairseqTask): +class LanguageModelingTask(LegacyFairseqTask): """ Train a language model. diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index f6cb17f12a..26e0b529d5 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -117,7 +117,7 @@ def setup_task(cls, args, **kwargs): return cls(args, dicts, training) @classmethod - def prepare(cls, args, **kargs): + def update_args(cls, args): args.left_pad_source = utils.eval_bool(args.left_pad_source) args.left_pad_target = utils.eval_bool(args.left_pad_target) @@ -127,6 +127,10 @@ def prepare(cls, args, **kargs): ) if isinstance(args.lang_pairs, str): args.lang_pairs = args.lang_pairs.split(",") + + @classmethod + def prepare(cls, args, **kargs): + cls.update_args(args) sorted_langs = sorted( list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}) ) @@ -298,6 +302,10 @@ def check_args(): if len(messages) > 0: raise ValueError(" ".join(messages)) + # Update args -> the fact that the constructor here + # changes the args object doesn't mean you get the same one here + self.update_args(args) + # Check if task args are consistant with model args check_args() diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 6d222f0de3..c200bb1407 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -13,7 +13,7 @@ SpeechToTextDataset, SpeechToTextDatasetCreator, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import LegacyFairseqTask, register_task logging.basicConfig( @@ -25,7 +25,7 @@ @register_task("speech_to_text") -class SpeechToTextTask(FairseqTask): +class 
SpeechToTextTask(LegacyFairseqTask): @staticmethod def add_args(parser): parser.add_argument("data", help="manifest root path") diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 0069b79425..8b00e8b431 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -11,15 +11,18 @@ import logging import sys import time +from argparse import Namespace from itertools import chain from typing import Any, Dict, List import torch from fairseq import checkpoint_utils, distributed_utils, models, optim, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.file_io import PathManager from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -35,19 +38,25 @@ class Trainer(object): communication of the gradients across workers. """ - def __init__(self, args, task, model, criterion, quantizer=None): - self.args = args + def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None): + + if isinstance(cfg, Namespace): + logger.warning( + "argparse.Namespace configuration is deprecated! 
Automatically converting to OmegaConf" + ) + cfg = convert_namespace_to_omegaconf(cfg) + + self.cfg = cfg self.task = task # catalog shared parameters shared_params = _catalog_shared_params(model) - - self.tpu = getattr(args, "tpu", False) - self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu + self.tpu = cfg.common.tpu + self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu if self.cuda: self.device = torch.device("cuda") elif self.tpu: - self.device = utils.get_tpu_device(args) + self.device = utils.get_tpu_device() else: self.device = torch.device("cpu") @@ -58,19 +67,21 @@ def __init__(self, args, task, model, criterion, quantizer=None): import torch_xla.core.xla_model as xm self._model = xm.send_cpu_data_to_device(self._model, self.device) - if args.fp16: + if cfg.common.fp16: self._criterion = self._criterion.half() self._model = self._model.half() - elif args.bf16: + elif cfg.common.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) - if not args.pipeline_model_parallel: + if not cfg.distributed_training.pipeline_model_parallel: self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) - self.pipeline_model_parallel = args.pipeline_model_parallel + self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel self.last_device = None if self.cuda and self.pipeline_model_parallel: - self.last_device = torch.device(args.pipeline_devices[-1]) + self.last_device = torch.device( + cfg.distributed_training.pipeline_devices[-1] + ) # check that shared parameters are preserved after device transfer for shared_param in shared_params: @@ -129,7 +140,7 @@ def reinitialize(self): @property def data_parallel_world_size(self): - return self.args.distributed_world_size + return self.cfg.distributed_training.distributed_world_size @property def data_parallel_process_group(self): @@ -140,11 +151,11 @@ 
def data_parallel_process_group(self): @property def data_parallel_rank(self): - return self.args.distributed_rank + return self.cfg.distributed_training.distributed_rank @property def is_data_parallel_master(self): - return distributed_utils.is_master(self.args) + return distributed_utils.is_master(self.cfg.distributed_training) @property def criterion(self): @@ -152,11 +163,11 @@ def criterion(self): if ( utils.has_parameters(self._criterion) and self.data_parallel_world_size > 1 - and not self.args.use_bmuf + and not self.cfg.optimization.use_bmuf and not self.tpu ): self._wrapped_criterion = models.DistributedFairseqModel( - self.args, + self.cfg.distributed_training, self._criterion, process_group=self.data_parallel_process_group, ) @@ -169,11 +180,11 @@ def model(self): if self._wrapped_model is None: if ( self.data_parallel_world_size > 1 - and not self.args.use_bmuf + and not self.cfg.optimization.use_bmuf and not self.tpu ): self._wrapped_model = models.DistributedFairseqModel( - self.args, + self.cfg.distributed_training, self._model, process_group=self.data_parallel_process_group, ) @@ -201,44 +212,51 @@ def _build_optimizer(self): ) ) - if self.args.fp16 or self.args.bf16: + if self.cfg.common.fp16 or self.cfg.common.bf16: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: logger.info( "NOTE: your device does NOT support faster training with --fp16, " "please switch to FP32 which is likely to be faster" ) - if self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16: + if ( + self.cfg.common.memory_efficient_fp16 + or self.cfg.common.memory_efficient_bf16 + ): self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( - self.args, params + self.cfg, params ) else: - self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) + self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) else: if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: logger.info("NOTE: your device may support 
faster training with --fp16") - self._optimizer = optim.build_optimizer(self.args, params) + self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) - if self.args.use_bmuf: - self._optimizer = optim.FairseqBMUF(self.args, self._optimizer) + if self.cfg.optimization.use_bmuf: + self._optimizer = optim.FairseqBMUF( + self.cfg.bmuf, + self._optimizer, + ) - if self.args.zero_sharding == "os": + if self.cfg.distributed_training.zero_sharding == "os": if ( - self.args.fp16 - and not self.args.memory_efficient_fp16 - and not self.args.memory_efficient_bf16 - ) and not self.args.fp16_no_flatten_grads: + self.cfg.common.fp16 + and not self.cfg.common.memory_efficient_fp16 + and not self.cfg.common.memory_efficient_bf16 + ) and not self.cfg.common.fp16_no_flatten_grads: raise ValueError( "ZeRO is incomptabile with fp16 and flattened grads. " "Please use --fp16-no-flatten-grads" ) else: - optim.shard_( - self.args, self._optimizer, self.data_parallel_process_group - ) + optim.shard_(self._optimizer, self.data_parallel_process_group) # We should initialize the learning rate scheduler immediately after # building the optimizer, so that the initial learning rate is set. 
- self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) + self._lr_scheduler = lr_scheduler.build_lr_scheduler( + self.cfg.lr_scheduler, + self.optimizer, + ) self._lr_scheduler.step_update(0) def consolidate_optimizer(self): @@ -253,7 +271,7 @@ def save_checkpoint(self, filename, extra_state): extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( filename, - self.args, + self.cfg, self.get_model().state_dict(), self.get_criterion(), self.optimizer, @@ -277,11 +295,10 @@ def load_checkpoint( bexists = PathManager.isfile(filename) if bexists: state = checkpoint_utils.load_checkpoint_to_cpu(filename) - # load model parameters try: self.get_model().load_state_dict( - state["model"], strict=True, args=self.args + state["model"], strict=True, model_cfg=self.cfg.model ) if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict( @@ -355,28 +372,28 @@ def get_train_iterator( if load_dataset: logger.info("loading train data for epoch {}".format(epoch)) self.task.load_dataset( - self.args.train_subset, + self.cfg.dataset.train_subset, epoch=epoch, combine=combine, data_selector=data_selector, ) batch_iterator = self.task.get_batch_iterator( - dataset=self.task.dataset(self.args.train_subset), - max_tokens=self.args.max_tokens, - max_sentences=self.args.batch_size, + dataset=self.task.dataset(self.cfg.dataset.train_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), - self.args.max_tokens, + self.cfg.dataset.max_tokens, ), ignore_invalid_inputs=True, - required_batch_size_multiple=self.args.required_batch_size_multiple, - seed=self.args.seed, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, num_shards=self.data_parallel_world_size if shard_batch_itr else 1, 
shard_id=self.data_parallel_rank if shard_batch_itr else 0, - num_workers=self.args.num_workers, + num_workers=self.cfg.dataset.num_workers, epoch=epoch, - data_buffer_size=self.args.data_buffer_size, + data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) self.reset_dummy_batch(batch_iterator.first_batch) @@ -390,19 +407,19 @@ def get_valid_iterator( """Return an EpochBatchIterator over given validation subset for a given epoch.""" batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(subset), - max_tokens=self.args.max_tokens_valid, - max_sentences=self.args.batch_size_valid, + max_tokens=self.cfg.dataset.max_tokens_valid, + max_sentences=self.cfg.dataset.batch_size_valid, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), ), - ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=self.args.required_batch_size_multiple, - seed=self.args.seed, + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, - num_workers=self.args.num_workers, - data_buffer_size=self.args.data_buffer_size, + num_workers=self.cfg.dataset.num_workers, + data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) self.reset_dummy_batch(batch_iterator.first_batch) @@ -504,7 +521,7 @@ def maybe_no_sync(): self.zero_grad() if self.cuda: torch.cuda.empty_cache() - if self.args.distributed_world_size == 1: + if self.cfg.distributed_training.distributed_world_size == 1: return None else: raise e @@ -565,7 +582,7 @@ def maybe_no_sync(): # multiply gradients by (# GPUs / sample_size) since DDP # already normalizes by the number of GPUs. Thus we get # (sum_of_gradients / sample_size). 
- if not self.args.use_bmuf: + if not self.cfg.optimization.use_bmuf: self.optimizer.multiply_grads( self.data_parallel_world_size / sample_size ) @@ -575,12 +592,12 @@ def maybe_no_sync(): with torch.autograd.profiler.record_function("clip-grads"): # clip grads - grad_norm = self.clip_grad_norm(self.args.clip_norm) + grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) # check that grad norms are consistent across workers if ( - not self.args.use_bmuf - and self.args.distributed_wrapper != "SlowMo" + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.distributed_wrapper != "SlowMo" and not self.tpu ): self._check_grad_norms(grad_norm) @@ -624,7 +641,10 @@ def maybe_no_sync(): self.optimizer.optimizer ) - if not overflow or self.args.distributed_wrapper == "SlowMo": + if ( + not overflow + or self.cfg.distributed_training.distributed_wrapper == "SlowMo" + ): self.set_num_updates(self.get_num_updates() + 1) if self.tpu: @@ -636,7 +656,7 @@ def maybe_no_sync(): # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 logging_output = {} - if self.get_num_updates() % self.args.log_interval == 0: + if self.get_num_updates() % self.cfg.common.log_interval == 0: # log memory usage mem_info = xm.get_memory_info(self.device) gb_free = mem_info["kb_free"] / 1024 / 1024 @@ -677,16 +697,16 @@ def maybe_no_sync(): # clear CUDA cache to reduce memory fragmentation if ( self.cuda - and self.args.empty_cache_freq > 0 + and self.cfg.common.empty_cache_freq > 0 and ( - (self.get_num_updates() + self.args.empty_cache_freq - 1) - % self.args.empty_cache_freq + (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1) + % self.cfg.common.empty_cache_freq ) == 0 ): torch.cuda.empty_cache() - if self.args.fp16: + if self.cfg.common.fp16: metrics.log_scalar( "loss_scale", self.optimizer.scaler.loss_scale, @@ -883,10 +903,10 @@ def apply_bfloat16(t): return t.to(dtype=torch.bfloat16) return t - if 
self.args.fp16: + if self.cfg.common.fp16: sample = utils.apply_to_sample(apply_half, sample) - if self.args.bf16: + if self.cfg.common.bf16: sample = utils.apply_to_sample(apply_bfloat16, sample) return sample @@ -894,7 +914,7 @@ def apply_bfloat16(t): def _set_seed(self): # Set seed based on args.seed and the update number so that we get # reproducible results when resuming from checkpoints - seed = self.args.seed + self.get_num_updates() + seed = self.cfg.common.seed + self.get_num_updates() utils.set_torch_seed(seed) def _sync_stats(self): @@ -902,10 +922,12 @@ def _sync_stats(self): # BMUF and it's a bmuf sync with warmup iterations completed before. if self.data_parallel_world_size == 1: return False - elif self.args.use_bmuf: - return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and ( + elif self.cfg.optimization.use_bmuf: + return ( + self.get_num_updates() + 1 + ) % self.cfg.bmuf.global_sync_iter == 0 and ( self.get_num_updates() + 1 - ) > self.args.warmup_iterations + ) > self.cfg.bmuf.warmup_iterations else: return True @@ -950,7 +972,7 @@ def _all_gather_list_sync( zip( *distributed_utils.all_gather_list( [logging_outputs] + list(extra_stats_to_sum), - max_size=getattr(self.args, "all_gather_list_size", 16384), + max_size=getattr(self.cfg.common, "all_gather_list_size", 16384), group=self.data_parallel_process_group, ) ) @@ -1038,11 +1060,11 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): if grad_norm is not None: metrics.log_speed("ups", 1.0, priority=100, round=2) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) - if self.args.clip_norm > 0: + if self.cfg.optimization.clip_norm > 0: metrics.log_scalar( "clip", torch.where( - grad_norm > self.args.clip_norm, + grad_norm > self.cfg.optimization.clip_norm, grad_norm.new_tensor(100), grad_norm.new_tensor(0), ), @@ -1087,7 +1109,7 @@ def _check_xla_compilation(self): logger.warning( "XLA compilation detected on device #{}; too many of these 
can lead " "to slow training, but we expect a few in the beginning".format( - self.args.distributed_rank + self.cfg.distributed_training.distributed_rank ) ) self._num_xla_compiles = num_xla_compiles diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 9a4ff8ee39..4621a66acd 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -11,13 +11,19 @@ import logging import math import os +from argparse import Namespace import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -60,65 +66,60 @@ def __str__(self): ) -def main(parsed_args, **unused_kwargs): - assert parsed_args.path is not None, "--path required for evaluation!" 
+def main(cfg: DictConfig, override_args=None, **unused_kwargs): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - if torch.cuda.is_available() and not parsed_args.cpu: - torch.cuda.set_device(parsed_args.device_id) + utils.import_user_module(cfg.common) - utils.import_user_module(parsed_args) + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu - logger.info(parsed_args) + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) - use_cuda = torch.cuda.is_available() and not parsed_args.cpu + if override_args is not None: + overrides = vars(override_args) + overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) + else: + overrides = None - task = tasks.setup_task(parsed_args) + logger.info(cfg) # Load ensemble - logger.info("loading model(s) from {}".format(parsed_args.path)) - models, args = checkpoint_utils.load_model_ensemble( - parsed_args.path.split(os.pathsep), - arg_overrides=eval(parsed_args.model_overrides), - task=task, - suffix=getattr(parsed_args, "checkpoint_suffix", ""), - strict=(parsed_args.checkpoint_shard_count == 1), - num_shards=parsed_args.checkpoint_shard_count, - ) - - for arg in vars(parsed_args).keys(): - if arg not in { - "self_target", - "future_target", - "past_target", - "tokens_per_sample", - "output_size_dictionary", - "add_bos_token", - }: - setattr(args, arg, getattr(parsed_args, arg)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) # reduce tokens per sample by the required context window size - args.tokens_per_sample -= args.context_window - task = tasks.setup_task(args) + cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) 
# Load dataset splits - task.load_dataset(args.gen_subset) - dataset = task.dataset(args.gen_subset) - if args.context_window > 0: + gen_subset = cfg.dataset.gen_subset + task.load_dataset(gen_subset) + dataset = task.dataset(gen_subset) + if cfg.eval_lm.context_window > 0: dataset = LMContextWindowDataset( dataset=dataset, - tokens_per_sample=args.tokens_per_sample, - context_window=args.context_window, + tokens_per_sample=cfg.task.tokens_per_sample, + context_window=cfg.eval_lm.context_window, pad_idx=task.source_dictionary.pad(), ) - logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset))) + logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset))) # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) for model in models: - if args.fp16: + if use_fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) assert len(models) > 0 @@ -128,35 +129,41 @@ def main(parsed_args, **unused_kwargs): itr = task.get_batch_iterator( dataset=dataset, - max_tokens=args.max_tokens or 36000, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens or 36000, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( *[model.max_positions() for model in models] ), ignore_invalid_inputs=True, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + num_shards=max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ), + shard_id=max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ), + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - 
log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) gen_timer = StopwatchMeter() - scorer = SequenceScorer(task.target_dictionary, args.softmax_batch) + scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch) score_sum = 0.0 count = 0 - if args.remove_bpe is not None: - if args.remove_bpe == "sentencepiece": + if cfg.common_eval.remove_bpe is not None: + if cfg.common_eval.remove_bpe == "sentencepiece": raise NotImplementedError else: - bpe_cont = args.remove_bpe.rstrip() + bpe_cont = cfg.common_eval.remove_bpe.rstrip() bpe_toks = { i for i in range(len(task.source_dictionary)) @@ -189,7 +196,7 @@ def main(parsed_args, **unused_kwargs): tgt_len = tokens.numel() pos_scores = hypo["positional_scores"].float() - if getattr(args, "add_bos_token", False): + if cfg.task.add_bos_token: assert hypo["tokens"][0].item() == task.target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] @@ -212,7 +219,7 @@ def main(parsed_args, **unused_kwargs): score_sum += pos_scores.sum().cpu() count += pos_scores.numel() - skipped_toks - if args.output_word_probs or args.output_word_stats: + if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats: w = "" word_prob = [] is_bpe = False @@ -238,7 +245,7 @@ def main(parsed_args, **unused_kwargs): ) is_bpe = False w = "" - if args.output_word_probs: + if cfg.eval_lm.output_word_probs: logger.info( str(int(sample_id)) + " " @@ -264,7 +271,7 @@ def main(parsed_args, **unused_kwargs): ) ) - if args.output_word_stats: + if cfg.eval_lm.output_word_stats: for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): logger.info(ws) @@ -272,8 +279,16 @@ def main(parsed_args, **unused_kwargs): def cli_main(): parser = options.get_eval_lm_parser() args = 
options.parse_args_and_arch(parser) - distributed_utils.call_main(args, main) + + # only override args that are explicitly given on the command line + override_parser = options.get_validation_parser() + override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + + distributed_utils.call_main(args, main, override_args=override_args) if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 8ddf981cc3..6a6f7465cb 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -12,33 +12,45 @@ import math import os import sys +from argparse import Namespace from itertools import chain import numpy as np import torch from fairseq import checkpoint_utils, options, scoring, tasks, utils +from fairseq.data import encoders +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig -def main(args): - assert args.path is not None, "--path required for generation!" +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for generation!" 
assert ( - not args.sampling or args.nbest == args.beam + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert ( - args.replace_unk is None or args.dataset_impl == "raw" + cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" - if args.results_path is not None: - os.makedirs(args.results_path, exist_ok=True) + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) output_path = os.path.join( - args.results_path, "generate-{}.txt".format(args.gen_subset) + cfg.common_eval.results_path, + "generate-{}.txt".format(cfg.dataset.gen_subset), ) with open(output_path, "w", buffering=1, encoding="utf-8") as h: - return _main(args, h) + return _main(cfg, h) else: - return _main(args, sys.stdout) + return _main(cfg, sys.stdout) def get_symbols_to_strip_from_output(generator): @@ -48,7 +60,7 @@ def get_symbols_to_strip_from_output(generator): return {generator.eos} -def _main(args, output_file): +def _main(cfg: DictConfig, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -57,22 +69,22 @@ def _main(args, output_file): ) logger = logging.getLogger("fairseq_cli.generate") - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.max_tokens is None and args.batch_size is None: - args.max_tokens = 12000 - logger.info(args) + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = 
torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset splits - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) + task = tasks.setup_task(cfg.task) + task.load_dataset(cfg.dataset.gen_subset) # Set dictionaries try: @@ -81,32 +93,30 @@ def _main(args, output_file): src_dict = None tgt_dict = task.target_dictionary - overrides = ast.literal_eval(args.model_overrides) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.path), + utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) - if args.lm_path is not None: - overrides["data"] = args.data + if cfg.generation.lm_path is not None: + overrides["data"] = cfg.task.data try: lms, _ = checkpoint_utils.load_model_ensemble( - [args.lm_path], - arg_overrides=overrides, - task=None, + [cfg.generation.lm_path], arg_overrides=overrides, task=None ) except: logger.warning( f"Failed to load language model! 
Please make sure that the language model dict is the same " - f"as target dict and is located in the data dir ({args.data})" + f"as target dict and is located in the data dir ({cfg.task.data})" ) raise @@ -118,49 +128,50 @@ def _main(args, output_file): for model in chain(models, lms): if model is None: continue - if args.fp16: + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - align_dict = utils.load_align_dict(args.replace_unk) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( - dataset=task.dataset(args.gen_subset), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( - task.max_positions(), *[model.max_positions() for model in models] + task.max_positions(), *[m.max_positions() for m in models] ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = 
progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator gen_timer = StopwatchMeter() - extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight} + extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} generator = task.build_generator( - models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, cfg.task, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) # Handle tokenization and BPE - tokenizer = task.build_tokenizer(args) - bpe = task.build_bpe(args) + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) def decode_fn(x): if bpe is not None: @@ -169,7 +180,7 @@ def decode_fn(x): x = tokenizer.decode(x) return x - scorer = scoring.build_scorer(args, tgt_dict) + scorer = scoring.build_scorer(cfg.scoring, tgt_dict) num_sentences = 0 has_target = True @@ -180,8 +191,8 @@ def decode_fn(x): continue prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample["target"][:, : args.prefix_size] + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] constraints = None if "constraints" in sample: @@ -217,19 +228,21 @@ def decode_fn(x): # Either retrieve the original sentences or regenerate them from tokens. 
if align_dict is not None: - src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) - target_str = task.dataset(args.gen_subset).tgt.get_original_text( + src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text( + sample_id + ) + target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text( sample_id ) else: if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) else: src_str = "" if has_target: target_str = tgt_dict.string( target_tokens, - args.remove_bpe, + cfg.common_eval.remove_bpe, escape_unk=True, extra_symbols_to_ignore=get_symbols_to_strip_from_output( generator @@ -240,25 +253,25 @@ def decode_fn(x): if has_target: target_str = decode_fn(target_str) - if not args.quiet: + if not cfg.common_eval.quiet: if src_dict is not None: print("S-{}\t{}".format(sample_id, src_str), file=output_file) if has_target: print("T-{}\t{}".format(sample_id, target_str), file=output_file) # Process top predictions - for j, hypo in enumerate(hypos[i][: args.nbest]): + for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - remove_bpe=args.remove_bpe, + remove_bpe=cfg.common_eval.remove_bpe, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) - if not args.quiet: + if not cfg.common_eval.quiet: score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) print( @@ -286,7 +299,7 @@ def decode_fn(x): file=output_file, ) - if args.print_alignment: + if cfg.generation.print_alignment: print( "A-{}\t{}".format( sample_id, @@ -300,13 +313,13 @@ def decode_fn(x): file=output_file, ) - if args.print_step: + if cfg.generation.print_step: print( 
"I-{}\t{}".format(sample_id, hypo["steps"]), file=output_file, ) - if getattr(args, "retain_iter_history", False): + if cfg.generation.retain_iter_history: for step, h in enumerate(hypo["history"]): _, h_str, _ = utils.post_process_prediction( hypo_tokens=h["tokens"].int().cpu(), @@ -323,7 +336,7 @@ def decode_fn(x): # Score only the top hypothesis if has_target and j == 0: - if align_dict is not None or args.remove_bpe is not None: + if align_dict is not None or cfg.common_eval.remove_bpe is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line( target_str, add_if_not_exist=True @@ -353,8 +366,8 @@ def decode_fn(x): ) ) if has_target: - if args.bpe and not args.sacrebleu: - if args.remove_bpe: + if cfg.bpe and not cfg.generation.sacrebleu: + if cfg.common_eval.remove_bpe: logger.warning( "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" ) @@ -365,7 +378,7 @@ def decode_fn(x): # use print to be consistent with other main outputs: S-, H-, T-, D- and so on print( "Generate {} with beam={}: {}".format( - args.gen_subset, args.beam, scorer.result_string() + cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() ), file=output_file, ) @@ -380,4 +393,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index de3893a385..ddd2617c3d 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -7,20 +7,27 @@ Translate raw text with a trained model. Batches data on-the-fly. 
""" +import ast import fileinput import logging import math import os import sys import time +from argparse import Namespace from collections import namedtuple import numpy as np import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -49,11 +56,11 @@ def buffered_read(input, buffer_size): yield buffer -def make_batches(lines, args, task, max_positions, encode_fn): +def make_batches(lines, cfg, task, max_positions, encode_fn): def encode_fn_target(x): return encode_fn(x) - if args.constraints: + if cfg.generation.constraints: # Strip (tab-delimited) contraints, if present, from input lines, # store them in batch_constraints batch_constraints = [list() for _ in lines] @@ -79,7 +86,7 @@ def encode_fn_target(x): for src_str in lines ] - if args.constraints: + if cfg.generation.constraints: constraints_tensor = pack_constraints(batch_constraints) else: constraints_tensor = None @@ -89,10 +96,10 @@ def encode_fn_target(x): dataset=task.build_dataset_for_inference( tokens, lengths, constraints=constraints_tensor ), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=max_positions, - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, ).next_epoch_itr(shuffle=False) for batch in itr: ids = batch["id"] @@ -108,45 +115,50 @@ def encode_fn_target(x): ) -def main(args): +def main(cfg: DictConfig): + 
if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + start_time = time.time() total_translate_time = 0 - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.buffer_size < 1: - args.buffer_size = 1 - if args.max_tokens is None and args.batch_size is None: - args.batch_size = 1 + if cfg.interactive.buffer_size < 1: + cfg.interactive.buffer_size = 1 + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.batch_size = 1 assert ( - not args.sampling or args.nbest == args.beam + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert ( - not args.batch_size or args.batch_size <= args.buffer_size + not cfg.dataset.batch_size + or cfg.dataset.batch_size <= cfg.interactive.buffer_size ), "--batch-size cannot be larger than --buffer-size" - logger.info(args) + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Setup task, e.g., translation - task = tasks.setup_task(args) + task = tasks.setup_task(cfg.task) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - args.path.split(os.pathsep), - arg_overrides=eval(args.model_overrides), + utils.split_paths(cfg.common_eval.path), + arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - 
num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Set dictionaries @@ -155,18 +167,20 @@ def main(args): # Optimize ensemble for generation for model in models: - if args.fp16: + if model is None: + continue + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Initialize generator - generator = task.build_generator(models, args) + generator = task.build_generator(models, cfg.task) # Handle tokenization and BPE - tokenizer = encoders.build_tokenizer(args) - bpe = encoders.build_bpe(args) + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) def encode_fn(x): if tokenizer is not None: @@ -184,25 +198,25 @@ def decode_fn(x): # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - align_dict = utils.load_align_dict(args.replace_unk) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) max_positions = utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models] ) - if args.constraints: + if cfg.generation.constraints: logger.warning( "NOTE: Constrained decoding currently assumes a shared subword vocabulary." 
) - if args.buffer_size > 1: - logger.info("Sentence buffer size: %s", args.buffer_size) + if cfg.interactive.buffer_size > 1: + logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Type the input sentence and press return:") start_id = 0 - for inputs in buffered_read(args.input, args.buffer_size): + for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size): results = [] - for batch in make_batches(inputs, args, task, max_positions, encode_fn): + for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): bsz = batch.src_tokens.size(0) src_tokens = batch.src_tokens src_lengths = batch.src_lengths @@ -226,7 +240,7 @@ def decode_fn(x): translate_time = time.time() - translate_start_time total_translate_time += translate_time list_constraints = [[] for _ in range(bsz)] - if args.constraints: + if cfg.generation.constraints: list_constraints = [unpack_constraints(c) for c in constraints] for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) @@ -246,25 +260,25 @@ def decode_fn(x): # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) print("S-{}\t{}".format(id_, src_str)) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) for constraint in info["constraints"]: print( "C-{}\t{}".format( - id_, tgt_dict.string(constraint, args.remove_bpe) + id_, tgt_dict.string(constraint, cfg.common_eval.remove_bpe) ) ) # Process top predictions - for hypo in hypos[: min(len(hypos), args.nbest)]: + for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]: hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), 
src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - remove_bpe=args.remove_bpe, + remove_bpe=cfg.common_eval.remove_bpe, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) @@ -285,7 +299,7 @@ def decode_fn(x): ), ) ) - if args.print_alignment: + if cfg.generation.print_alignment: alignment_str = " ".join( ["{}-{}".format(src, tgt) for src, tgt in alignment] ) @@ -308,4 +322,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py index b8354eb95a..e06d67259d 100644 --- a/fairseq_cli/score.py +++ b/fairseq_cli/score.py @@ -78,7 +78,13 @@ def score(fdsys): def score(fdsys): with open(args.ref) as fdref: - scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): sys_tok = dict.encode_line(sys_tok) ref_tok = dict.encode_line(ref_tok) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index cd3a93b13e..4c00761060 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -11,11 +11,13 @@ import logging import math import os -import random import sys +from typing import Dict, Optional, Any, List, Tuple, Callable import numpy as np import torch +from hydra.core.config_store import ConfigStore + from fairseq import ( checkpoint_utils, distributed_utils, @@ -25,8 +27,12 @@ utils, ) from fairseq.data import iterators +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from omegaconf import DictConfig +from hydra.experimental import initialize +from fairseq.dataclass.data_class import register_hydra_cfg 
from fairseq.trainer import Trainer @@ -39,90 +45,86 @@ logger = logging.getLogger("fairseq_cli.train") -def main(args): - utils.import_user_module(args) +def main(cfg: DictConfig) -> None: + if isinstance(cfg, argparse.Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - assert ( - args.max_tokens is not None or args.batch_size is not None - ), "Must specify batch size either with --max-tokens or --batch-size" + utils.import_user_module(cfg.common) + assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ + 'Must specify batch size either with --max-tokens or --batch-size' metrics.reset() - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - if distributed_utils.is_master(args): - checkpoint_utils.verify_checkpoint_directory(args.save_dir) + if distributed_utils.is_master(cfg.distributed_training): + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) # Print args - logger.info(args) + logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. 
-    task = tasks.setup_task(args)
-
+    task = tasks.setup_task(cfg.task)
     # Load valid dataset (we load training data below, based on the latest checkpoint)
-    for valid_sub_split in args.valid_subset.split(","):
+    for valid_sub_split in cfg.dataset.valid_subset.split(','):
         task.load_dataset(valid_sub_split, combine=False, epoch=1)

     # Build model and criterion
-    model = task.build_model(args)
-    criterion = task.build_criterion(args)
+    model = task.build_model(cfg.model)
+    criterion = task.build_criterion(cfg.criterion)
     logger.info(model)
-    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
-    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
+    logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__))
+    logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__))
     logger.info(
-        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
-    )
-    logger.info(
-        "num. model params: {} (num. trained: {})".format(
-            sum(p.numel() for p in model.parameters()),
-            sum(p.numel() for p in model.parameters() if p.requires_grad),
-        )
+        "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__)
     )
+    logger.info("num. model params: {} (num. trained: {})".format(
+        sum(p.numel() for p in model.parameters()),
+        sum(p.numel() for p in model.parameters() if p.requires_grad),
+    ))

     # (optionally) Configure quantization
-    if args.quantization_config_path is not None:
+    if cfg.common.quantization_config_path is not None:
         quantizer = quantization_utils.Quantizer(
-            config_path=args.quantization_config_path,
-            max_epoch=args.max_epoch,
-            max_update=args.max_update,
+            config_path=cfg.common.quantization_config_path,
+            max_epoch=cfg.optimization.max_epoch,
+            max_update=cfg.optimization.max_update,
         )
     else:
         quantizer = None

     # Build trainer
-    if args.model_parallel_size == 1:
-        trainer = Trainer(args, task, model, criterion, quantizer)
+    if cfg.common.model_parallel_size == 1:
+        trainer = Trainer(cfg, task, model, criterion, quantizer)
     else:
-        trainer = MegatronTrainer(args, task, model, criterion)
+        trainer = MegatronTrainer(cfg, task, model, criterion)

-    logger.info(
-        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
-    )
-    logger.info(
-        "max tokens per GPU = {} and max sentences per GPU = {}".format(
-            args.max_tokens, args.batch_size
-        )
-    )
+    logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size))
+    logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format(
+        cfg.dataset.max_tokens,
+        cfg.dataset.batch_size,
+    ))

     # Load the latest checkpoint if one is available and restore the
     # corresponding train iterator
     extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
-        args,
+        cfg.checkpoint,
         trainer,
         # don't cache epoch iterators for sharded datasets
         disable_iterator_cache=task.has_sharded_data("train"),
     )

-    # Train until the learning rate gets too small
-    max_epoch = args.max_epoch or math.inf
+    max_epoch = cfg.optimization.max_epoch or math.inf
     lr = trainer.get_lr()
     train_meter = meters.StopwatchMeter()
     train_meter.start()
-
-    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
+    while (
+        lr > cfg.optimization.min_lr
+        and epoch_itr.next_epoch_idx <= max_epoch
+    ):
         # train for one epoch
-        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
+        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
         if should_stop:
             break

@@ -140,15 +142,15 @@ def main(args):
     logger.info("done training in {:.1f} seconds".format(train_meter.sum))


-def should_stop_early(args, valid_loss):
+def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
     # skip check if no validation was done in the current epoch
     if valid_loss is None:
         return False
-    if args.patience <= 0:
+    if cfg.checkpoint.patience <= 0:
         return False

     def is_better(a, b):
-        return a > b if args.maximize_best_checkpoint_metric else a < b
+        return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b

     prev_best = getattr(should_stop_early, "best", None)
     if prev_best is None or is_better(valid_loss, prev_best):
@@ -157,48 +159,43 @@ def is_better(a, b):
         return False
     else:
         should_stop_early.num_runs += 1
-        if should_stop_early.num_runs >= args.patience:
-            logger.info(
-                "early stop since valid performance hasn't improved for last {} runs".format(
-                    args.patience
-                )
-            )
+        if should_stop_early.num_runs >= cfg.checkpoint.patience:
+            logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience))
             return True
         else:
             return False


 @metrics.aggregate("train")
-def train(args, trainer, task, epoch_itr):
+def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]:
     """Train the model for one epoch and return validation losses."""
     # Initialize data iterator
     itr = epoch_itr.next_epoch_itr(
-        fix_batches_to_gpus=args.fix_batches_to_gpus,
-        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
+        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
+        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
     )
     update_freq = (
-        args.update_freq[epoch_itr.epoch - 1]
-        if epoch_itr.epoch <= len(args.update_freq)
-        else args.update_freq[-1]
+        cfg.optimization.update_freq[epoch_itr.epoch - 1]
+        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
+        else cfg.optimization.update_freq[-1]
     )
     itr = iterators.GroupedIterator(itr, update_freq)
-    if getattr(args, "tpu", False):
+    if getattr(cfg.common, "tpu", False):
         itr = utils.tpu_data_loader(itr)
     progress = progress_bar.progress_bar(
         itr,
-        log_format=args.log_format,
-        log_interval=args.log_interval,
+        log_format=cfg.common.log_format,
+        log_interval=cfg.common.log_interval,
         epoch=epoch_itr.epoch,
         tensorboard_logdir=(
-            args.tensorboard_logdir if distributed_utils.is_master(args) else None
+            cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
         ),
-        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
+        default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
     )

     trainer.begin_epoch(epoch_itr.epoch)

-    valid_losses = [None]
-    valid_subsets = args.valid_subset.split(",")
+    valid_subsets = cfg.dataset.valid_subset.split(',')
     should_stop = False
     num_updates = trainer.get_num_updates()
     for i, samples in enumerate(progress):
@@ -210,7 +207,7 @@ def train(args, trainer, task, epoch_itr):
         if log_output is not None:  # not OOM, overflow, ...
             # log mid-epoch stats
             num_updates = trainer.get_num_updates()
-            if num_updates % args.log_interval == 0:
+            if num_updates % cfg.common.log_interval == 0:
                 stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                 progress.log(stats, tag="train_inner", step=num_updates)

@@ -220,7 +217,7 @@ def train(args, trainer, task, epoch_itr):

         end_of_epoch = not itr.has_next()
         valid_losses, should_stop = validate_and_save(
-            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
+            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
         )

         if should_stop:
@@ -236,64 +233,64 @@ def train(args, trainer, task, epoch_itr):
     return valid_losses, should_stop


-def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
+def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]:
     num_updates = trainer.get_num_updates()
-    max_update = args.max_update or math.inf
+    max_update = cfg.optimization.max_update or math.inf
     do_save = (
-        (end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
+        (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
         or num_updates >= max_update
         or (
-            args.save_interval_updates > 0
+            cfg.checkpoint.save_interval_updates > 0
             and num_updates > 0
-            and num_updates % args.save_interval_updates == 0
-            and num_updates >= args.validate_after_updates
+            and num_updates % cfg.checkpoint.save_interval_updates == 0
+            and num_updates >= cfg.dataset.validate_after_updates
         )
     )
     do_validate = (
         (not end_of_epoch and do_save)  # validate during mid-epoch saves
-        or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
+        or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
         or num_updates >= max_update
         or (
-            args.validate_interval_updates > 0
+            cfg.dataset.validate_interval_updates > 0
             and num_updates > 0
-            and num_updates % args.validate_interval_updates == 0
+            and num_updates % cfg.dataset.validate_interval_updates == 0
         )
-    ) and not args.disable_validation
+    ) and not cfg.dataset.disable_validation

     # Validate
     valid_losses = [None]
     if do_validate:
-        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
+        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

     # Stopping conditions
     should_stop = (
-        should_stop_early(args, valid_losses[0])
+        should_stop_early(cfg, valid_losses[0])
         or num_updates >= max_update
         or (
-            args.stop_time_hours > 0
-            and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
+            cfg.optimization.stop_time_hours > 0
+            and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours
         )
     )

     # Save checkpoint
     if do_save or should_stop:
         logger.info("begin save checkpoint")
-        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
+        checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0])

     return valid_losses, should_stop


-def get_training_stats(stats):
+def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
     stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
     return stats


-def validate(args, trainer, task, epoch_itr, subsets):
+def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]:
     """Evaluate the model on the validation set(s) and return the losses."""

-    if args.fixed_validation_seed is not None:
+    if cfg.dataset.fixed_validation_seed is not None:
         # set fixed seed for every validation
-        utils.set_torch_seed(args.fixed_validation_seed)
+        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

     trainer.begin_valid_epoch(epoch_itr.epoch)
     valid_losses = []
@@ -302,18 +299,18 @@ def validate(args, trainer, task, epoch_itr, subsets):

         # Initialize data iterator
         itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
-        if getattr(args, "tpu", False):
+        if cfg.common.tpu:
             itr = utils.tpu_data_loader(itr)
         progress = progress_bar.progress_bar(
             itr,
-            log_format=args.log_format,
-            log_interval=args.log_interval,
+            log_format=cfg.common.log_format,
+            log_interval=cfg.common.log_interval,
             epoch=epoch_itr.epoch,
             prefix=f"valid on '{subset}' subset",
             tensorboard_logdir=(
-                args.tensorboard_logdir if distributed_utils.is_master(args) else None
+                cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
             ),
-            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
+            default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
         )

         # create a new root metrics aggregator so validation metrics
@@ -323,34 +320,40 @@ def validate(args, trainer, task, epoch_itr, subsets):
                 trainer.valid_step(sample)

         # log validation stats
-        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
+        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
         progress.print(stats, tag=subset, step=trainer.get_num_updates())

-        valid_losses.append(stats[args.best_checkpoint_metric])
+        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
     return valid_losses


-def get_valid_stats(args, trainer, stats):
+def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
     stats["num_updates"] = trainer.get_num_updates()
     if hasattr(checkpoint_utils.save_checkpoint, "best"):
-        key = "best_{0}".format(args.best_checkpoint_metric)
-        best_function = max if args.maximize_best_checkpoint_metric else min
+        key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
+        best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
         stats[key] = best_function(
-            checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric]
+            checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric]
         )
     return stats


-def cli_main(modify_parser=None):
+def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None:
     parser = options.get_training_parser()
     args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
+
+    cfg = convert_namespace_to_omegaconf(args)
+
     if args.profile:
         with torch.cuda.profiler.profile():
             with torch.autograd.profiler.emit_nvtx():
-                distributed_utils.call_main(args, main)
+                distributed_utils.call_main(cfg, main)
     else:
-        distributed_utils.call_main(args, main)
+        distributed_utils.call_main(cfg, main)


-if __name__ == "__main__":
+if __name__ == '__main__':
+    cs = ConfigStore.instance()
+    register_hydra_cfg(cs)
+    initialize(config_path="../config", strict=True)
     cli_main()
diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py
index df857550d1..368c9cb581 100644
--- a/fairseq_cli/validate.py
+++ b/fairseq_cli/validate.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3 -u
-#!/usr/bin/env python3 -u
+# !/usr/bin/env python3 -u
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
@@ -8,11 +8,17 @@
 import logging
 import os
 import sys
+from argparse import Namespace
 from itertools import chain

 import torch
 from fairseq import checkpoint_utils, distributed_utils, options, utils
+from fairseq.dataclass.data_class import register_hydra_cfg
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.logging import metrics, progress_bar
+from hydra.core.config_store import ConfigStore
+from hydra.experimental import initialize
+from omegaconf import DictConfig


 logging.basicConfig(
@@ -24,18 +30,21 @@
 logger = logging.getLogger("fairseq_cli.validate")


-def main(args, override_args=None):
-    utils.import_user_module(args)
+def main(cfg: DictConfig, override_args=None):
+    if isinstance(cfg, Namespace):
+        cfg = convert_namespace_to_omegaconf(cfg)
+
+    utils.import_user_module(cfg.common)

     assert (
-        args.max_tokens is not None or args.batch_size is not None
+        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
     ), "Must specify batch size either with --max-tokens or --batch-size"

-    use_fp16 = args.fp16
-    use_cuda = torch.cuda.is_available() and not args.cpu
+    use_fp16 = cfg.common.fp16
+    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

     if use_cuda:
-        torch.cuda.set_device(args.device_id)
+        torch.cuda.set_device(cfg.distributed_training.device_id)

     if override_args is not None:
         overrides = vars(override_args)
@@ -44,11 +53,11 @@ def main(args, override_args=None):
         overrides = None

     # Load ensemble
-    logger.info("loading model(s) from {}".format(args.path))
+    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
     models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
-        [args.path],
+        [cfg.common_eval.path],
         arg_overrides=overrides,
-        suffix=getattr(args, "checkpoint_suffix", ""),
+        suffix=cfg.checkpoint.checkpoint_suffix,
     )
     model = models[0]

@@ -63,10 +72,10 @@ def main(args, override_args=None):
     logger.info(model_args)

     # Build criterion
-    criterion = task.build_criterion(model_args)
+    criterion = task.build_criterion(model_args.criterion)
     criterion.eval()

-    for subset in args.valid_subset.split(","):
+    for subset in cfg.dataset.valid_subset.split(","):
         try:
             task.load_dataset(subset, combine=False, epoch=1)
             dataset = task.dataset(subset)
@@ -76,26 +85,26 @@ def main(args, override_args=None):
         # Initialize data iterator
         itr = task.get_batch_iterator(
             dataset=dataset,
-            max_tokens=args.max_tokens,
-            max_sentences=args.batch_size,
+            max_tokens=cfg.dataset.max_tokens,
+            max_sentences=cfg.dataset.batch_size,
             max_positions=utils.resolve_max_positions(
                 task.max_positions(),
                 *[m.max_positions() for m in models],
             ),
-            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
-            required_batch_size_multiple=args.required_batch_size_multiple,
-            seed=args.seed,
-            num_shards=args.distributed_world_size,
-            shard_id=args.distributed_rank,
-            num_workers=args.num_workers,
-            data_buffer_size=args.data_buffer_size,
+            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
+            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
+            seed=cfg.common.seed,
+            num_shards=cfg.distributed_training.distributed_world_size,
+            shard_id=cfg.distributed_training.distributed_rank,
+            num_workers=cfg.dataset.num_workers,
+            data_buffer_size=cfg.dataset.data_buffer_size,
         ).next_epoch_itr(shuffle=False)
         progress = progress_bar.progress_bar(
             itr,
-            log_format=args.log_format,
-            log_interval=args.log_interval,
+            log_format=cfg.common.log_format,
+            log_interval=cfg.common.log_interval,
             prefix=f"valid on '{subset}' subset",
-            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
+            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
         )

         log_outputs = []
@@ -105,10 +114,10 @@ def main(args, override_args=None):
             progress.log(log_output, step=i)
             log_outputs.append(log_output)

-        if args.distributed_world_size > 1:
+        if cfg.distributed_training.distributed_world_size > 1:
             log_outputs = distributed_utils.all_gather_list(
                 log_outputs,
-                max_size=getattr(args, "all_gather_list_size", 16384),
+                max_size=cfg.common.all_gather_list_size,
             )
             log_outputs = list(chain.from_iterable(log_outputs))

@@ -131,4 +140,7 @@ def cli_main():


 if __name__ == "__main__":
+    cs = ConfigStore.instance()
+    register_hydra_cfg(cs)
+    initialize(config_path="../config", strict=True)
     cli_main()
diff --git a/tests/speech_recognition/asr_test_base.py b/tests/speech_recognition/asr_test_base.py
index 0341031394..8c5d414e7b 100644
--- a/tests/speech_recognition/asr_test_base.py
+++ b/tests/speech_recognition/asr_test_base.py
@@ -272,6 +272,7 @@ def setUpModel(self, model_cls, extra_args_setters=None):
         model_cls.add_args(parser)

         args = parser.parse_args([])
+
         if extra_args_setters is not None:
             for args_setter in extra_args_setters:
                 args_setter(args)
@@ -515,9 +516,7 @@ def setUpArgs(self):
     def setUp(self):
         args = self.setUpArgs()
         self.model = DummyEncoderModel(encoder=DummyEncoder())
-        self.criterion = self.criterion_cls.build_criterion(
-            args=args, task=DummyTask(args)
-        )
+        self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))

     def get_src_tokens(self, correct_prediction, aggregate):
         """
diff --git a/tests/test_bmuf.py b/tests/test_bmuf.py
index 0165b2955b..e7aa6da1ca 100644
--- a/tests/test_bmuf.py
+++ b/tests/test_bmuf.py
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 from fairseq import distributed_utils, optim
-
+from omegaconf import OmegaConf

 class Model(nn.Module):
     def __init__(self, input_size, output_size):
@@ -23,13 +23,14 @@ def forward(self, input):
         return output


-def setup_model_loss_criterion(args, rank, is_cuda):
+def setup_model_loss_criterion(cfg, args, rank, is_cuda):
     """
     setup model, criterion and optimizer based on input args
     """
     args.distributed_rank = rank
-    if args.distributed_world_size > 1:
-        distributed_utils.distributed_init(args)
+    cfg.distributed_training.distributed_rank = args.distributed_rank
+    if cfg.distributed_training.distributed_world_size > 1:
+        distributed_utils.distributed_init(cfg)
     torch.manual_seed(1)
     model = Model(args.input_size, args.nb_classes)
     loss_fn = nn.CrossEntropyLoss()
@@ -38,7 +39,10 @@ def setup_model_loss_criterion(args, rank, is_cuda):
         loss_fn = loss_fn.cuda()

     optimizer = optim.sgd.SGD(args, model.parameters())
-    optimizer = optim.FairseqBMUF(args, optimizer)
+    optimizer = optim.FairseqBMUF(
+        cfg=cfg.bmuf,
+        optimizer=optimizer
+    )

     return model, loss_fn, optimizer

@@ -52,13 +56,13 @@ def train_step(input, target, model, loss_fn, optimizer, **unused):
     optimizer.step()


-def single_gpu_training(args, rank, iterations, shared_results):
+def single_gpu_training(cfg, args, rank, iterations, shared_results):

     is_cuda = torch.cuda.is_available()
     if is_cuda:
         torch.cuda.set_device(rank)

-    model, loss_fn, optimizer = setup_model_loss_criterion(args, rank, is_cuda)
+    model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)

     for _ in range(iterations):
         input = torch.randn(1, args.input_size)
@@ -103,18 +107,44 @@ def setup_args():
     args.distributed_init_host = "localhost"
     args.distributed_port = port + 1
     args.local_world_size = args.distributed_world_size
-    return args
+
+    cfg = OmegaConf.create()
+    cfg.optimization = OmegaConf.create()
+    cfg.common = OmegaConf.create()
+    cfg.distributed_training = OmegaConf.create()
+    cfg.dataset = OmegaConf.create()
+    cfg.bmuf = OmegaConf.create()
+    cfg.optimizer = OmegaConf.create()
+
+    cfg.bmuf.global_sync_iter = args.global_sync_iter
+    cfg.bmuf.block_momentum = args.block_momentum
+    cfg.bmuf.block_lr = args.block_lr
+    cfg.dataset.batch_size = args.batch_size
+    cfg.optimization.lr = args.lr
+    cfg.optimizer.momentum = args.momentum
+    cfg.optimizer.weight_decay = args.weight_decay
+    cfg.bmuf.warmup_iterations = args.warmup_iterations
+    cfg.bmuf.use_nbm = args.use_nbm
+    cfg.bmuf.average_sync = args.average_sync
+    cfg.common.model_parallel_size = args.model_parallel_size
+    cfg.distributed_training.distributed_backend = args.distributed_backend
+    cfg.distributed_training.distributed_world_size = args.distributed_world_size
+    cfg.bmuf.distributed_world_size = args.distributed_world_size
+    cfg.distributed_training.distributed_init_method = args.distributed_init_method
+    cfg.distributed_training.distributed_port = args.distributed_port
+
+    return cfg, args


 @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
 class TestBMUF(unittest.TestCase):
-    def bmuf_process(self, args, iterations):
+    def bmuf_process(self, cfg, args, iterations):
         processes = []
         results = Manager().dict()
         ctx = torch.multiprocessing.get_context("spawn")
         for rank in range(args.distributed_world_size):
             p = ctx.Process(
-                target=single_gpu_training, args=(args, rank, iterations, results)
+                target=single_gpu_training, args=(cfg, args, rank, iterations, results)
             )
             p.start()
             processes.append(p)
@@ -125,19 +155,20 @@ def bmuf_process(self, args, iterations):

     def test_bmuf_sync(self):
         # Train model for 1 iteration and do bmuf sync without doing warmup
-        args = setup_args()
+        cfg, args = setup_args()
         iterations = 1
-        results = self.bmuf_process(args, iterations)
+        results = self.bmuf_process(cfg, args, iterations)
         # Make sure params in both machines are same
         assert len(results) == 2
         self.assertAlmostEqual(results[0], results[1])

     def test_warmup_sync(self):
         # Train model for 20 iteration and do warmup sync without doing bmuf sync
-        args = setup_args()
+        cfg, args = setup_args()
         args.warmup_iterations = 20
+        cfg.bmuf.warmup_iterations = args.warmup_iterations
         iterations = 20
-        results = self.bmuf_process(args, iterations)
+        results = self.bmuf_process(cfg, args, iterations)
         # Make sure params in both machines are same
         assert len(results) == 2
         self.assertAlmostEqual(results[0], results[1])
@@ -145,22 +176,27 @@ def test_warmup_sync(self):
     def test_warmup_sync_bmuf_sync(self):
         # Train model for 25 iteration and do warmup sync after 20 iteration
         # and bmuf sync after 25 iteration
-        args = setup_args()
+        cfg, args = setup_args()
         args.warmup_iterations = 20
         args.global_sync_iter = 5
+        cfg.bmuf.warmup_iterations = args.warmup_iterations
+        cfg.bmuf.global_sync_iter = args.global_sync_iter
         iterations = 25
-        results = self.bmuf_process(args, iterations)
+        results = self.bmuf_process(cfg, args, iterations)
         # Make sure params in both machines are same
         assert len(results) == 2
         self.assertAlmostEqual(results[0], results[1])

     def test_single_gpu_bmuf(self):
         # Train model for 5 iterations and use GPU 1
-        args = setup_args()
+        cfg, args = setup_args()
         args.distributed_world_size = 1
         args.warmup_iterations = 5
+        cfg.distributed_training.distributed_world_size = args.distributed_world_size
+        cfg.bmuf.distributed_world_size = args.distributed_world_size
+        cfg.bmuf.warmup_iterations = args.warmup_iterations
         iterations = 20
-        results = self.bmuf_process(args, iterations)
+        results = self.bmuf_process(cfg, args, iterations)
         assert len(results) == 1

     def assertAlmostEqual(self, t1, t2):
diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py
index c4195273e3..aa6a863d32 100644
--- a/tests/test_fp16_optimizer.py
+++ b/tests/test_fp16_optimizer.py
@@ -9,6 +9,7 @@

 import torch
 from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
+from omegaconf import OmegaConf


 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@@ -27,17 +28,23 @@ def setUp(self):
         self.model.cuda().half()
         self.params = list(self.model.parameters())

-        self.namespace_dls = argparse.Namespace(
-            optimizer="adam",
-            lr=[0.1],
-            adam_betas="(0.9, 0.999)",
-            adam_eps=1e-8,
-            weight_decay=0.0,
-            fp16_init_scale=1,
-            fp16_scale_window=1,
-            fp16_scale_tolerance=1,
-            threshold_loss_scale=1,
-            min_loss_scale=1e-4,
+        self.cfg_dls = OmegaConf.create(
+            {
+                "optimizer": {
+                    "_name": "adam",
+                    "lr": [0.1],
+                    "adam_betas": "(0.9, 0.999)",
+                    "adam_eps": 1e-8,
+                    "weight_decay": 0.0,
+                },
+                "common": {
+                    "fp16_init_scale": 1,
+                    "fp16_scale_window": 1,
+                    "fp16_scale_tolerance": 1,
+                    "threshold_loss_scale": 1,
+                    "min_loss_scale": 1e-4,
+                },
+            }
         )

     def run_iter(self, model, params, optimizer):
@@ -68,7 +75,7 @@ def run_iter(self, model, params, optimizer):
     def test_mixed_precision(self):
         model = copy.deepcopy(self.model)
         params = list(model.parameters())
-        optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params)
+        optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params)

         self.run_iter(model, params, optimizer)
         self.assertTrue(
@@ -87,9 +94,7 @@ def test_mixed_precision(self):
     def test_memory_efficient(self):
         model = copy.deepcopy(self.model)
         params = list(model.parameters())
-        optimizer = MemoryEfficientFP16Optimizer.build_optimizer(
-            self.namespace_dls, params
-        )
+        optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params)

         self.run_iter(model, params, optimizer)

diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py
index fd5edd43d6..353ac67478 100644
--- a/tests/test_inference_dropout.py
+++ b/tests/test_inference_dropout.py
@@ -6,6 +6,7 @@
 import logging
 import unittest

+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.models.transformer import TransformerModel
 from tests.test_sequence_generator import get_dummy_task_and_parser

@@ -25,7 +26,8 @@ def tearDown(self):
     def test_sets_inference_dropout_to_true(self):
         self.args.retain_dropout = True
         self.transformer_model = TransformerModel.build_model(self.args, self.task)
-        self.transformer_model.prepare_for_inference_(self.args)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
         assert self.transformer_model.encoder.dropout_module.apply_during_inference
         assert self.transformer_model.decoder.dropout_module.apply_during_inference
         for layer in self.transformer_model.encoder.layers:
@@ -33,7 +35,8 @@ def test_sets_inference_dropout_to_true(self):

     def test_inference_dropout_false_by_default(self):
         self.transformer_model = TransformerModel.build_model(self.args, self.task)
-        self.transformer_model.prepare_for_inference_(self.args)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
         assert not self.transformer_model.encoder.dropout_module.apply_during_inference
         assert not self.transformer_model.decoder.dropout_module.apply_during_inference
         for layer in self.transformer_model.encoder.layers:
@@ -59,7 +62,8 @@ def test_retain_modules(self):
             "TransformerEncoderLayer",
         ]
         self.transformer_model = TransformerModel.build_model(self.args, self.task)
-        self.transformer_model.prepare_for_inference_(self.args)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
         assert self.transformer_model.encoder.dropout_module.apply_during_inference
         assert not self.transformer_model.decoder.dropout_module.apply_during_inference
         for layer in self.transformer_model.decoder.layers:
diff --git a/tests/test_memory_efficient_fp16.py b/tests/test_memory_efficient_fp16.py
index e10636d96a..2bf2f29888 100644
--- a/tests/test_memory_efficient_fp16.py
+++ b/tests/test_memory_efficient_fp16.py
@@ -10,6 +10,7 @@
 import torch
 from fairseq.optim.adam import FairseqAdam
 from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
+from omegaconf import OmegaConf


 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@@ -26,25 +27,36 @@ def test_load_state_dict(self):
         params = list(model.parameters())

         # initialize memory efficient FP16 optimizer
+        # with pseudo DictConfigs
         optimizer = FairseqAdam(
-            argparse.Namespace(
-                lr=[0.00001],
-                adam_betas="(0.9, 0.999)",
-                adam_eps=1e-8,
-                weight_decay=0.0,
+            cfg=OmegaConf.create(
+                vars(
+                    argparse.Namespace(
+                        adam_betas="(0.9, 0.999)",
+                        adam_eps=1e-8,
+                        weight_decay=0.0,
+                        lr=[0.00001],
+                    )
+                )
             ),
-            params,
+            params=params,
         )
         me_optimizer = MemoryEfficientFP16Optimizer(
-            argparse.Namespace(
-                fp16_init_scale=1,
-                fp16_scale_window=1,
-                fp16_scale_tolerance=1,
-                threshold_loss_scale=1,
-                min_loss_scale=1e-4,
+            cfg=OmegaConf.create(
+                {
+                    "common": vars(
+                        argparse.Namespace(
+                            fp16_init_scale=1,
+                            fp16_scale_window=1,
+                            fp16_scale_tolerance=1,
+                            threshold_loss_scale=1,
+                            min_loss_scale=1e-4,
+                        )
+                    )
+                }
             ),
-            params,
-            optimizer,
+            params=params,
+            optimizer=optimizer,
         )

         # optimizer state is created in the first step
diff --git a/tests/test_train.py b/tests/test_train.py
index 1b7e027c0c..57daa194b2 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -11,6 +11,7 @@

 import torch
 from fairseq import checkpoint_utils, data
+from omegaconf import OmegaConf


 def mock_trainer(epoch, num_updates, iterations_in_epoch):
@@ -56,21 +57,29 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
     return trainer, epoch_itr


-def get_mock_args(finetune_from_model=None):
-    args_mock = MagicMock()
-    args_mock.optimizer_overrides = "{}"
-    args_mock.reset_dataloader = False
-    args_mock.reset_meters = False
-    args_mock.reset_optimizer = False
-    args_mock.reset_lr_scheduler = False
-    args_mock.finetune_from_model = finetune_from_model
-    args_mock.model_parallel_size = 1
-    return args_mock
+def get_mock_cfg(finetune_from_model):
+    cfg_mock = OmegaConf.create(
+        {
+            "checkpoint": {
+                "optimizer_overrides": "{}",
+                "reset_dataloader": False,
+                "reset_meters": False,
+                "reset_optimizer": False,
+                "reset_lr_scheduler": False,
+                "finetune_from_model": finetune_from_model,
+                "model_parallel_size": 1,
+            },
+            "common": {
+                "model_parallel_size": 1,
+            },
+        }
+    )
+    return cfg_mock


 class TestLoadCheckpoint(unittest.TestCase):
     def setUp(self):
-        self.args_mock = get_mock_args()
+        self.cfg_mock = get_mock_cfg(None)
         self.patches = {
             "os.makedirs": MagicMock(),
             "os.path.join": MagicMock(),
@@ -91,7 +100,9 @@ def test_load_partial_checkpoint(self):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )

             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
@@ -120,7 +131,9 @@ def test_load_full_checkpoint(self):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 3)
@@ -133,7 +146,9 @@ def test_load_no_checkpoint(self):
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             self.patches["os.path.isfile"].return_value = False

-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 1)
@@ -152,10 +167,12 @@ def test_finetune_from_model_args_conflict(self):
                 "reset_dataloader",
             ]:
                 with self.subTest(arg=arg):
-                    args_mock = get_mock_args("/temp/checkpoint_pretrained.pt")
-                    setattr(args_mock, arg, True)
+                    cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
+                    cfg_mock["checkpoint"][arg] = True
                     with self.assertRaises(Exception) as context:
-                        _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+                        _, _ = checkpoint_utils.load_checkpoint(
+                            cfg_mock.checkpoint, trainer
+                        )

                     self.assertTrue(
                         "--finetune-from-model can not be set together with either --reset-optimizer"
@@ -168,8 +185,6 @@ def test_finetune_from_model(self):
             trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             from_model_path = "/temp/checkpoint_pretrained.pt"
-            args_mock = get_mock_args(from_model_path)
-            args_mock.restore_file = "checkpoint_last.pt"

             def mock_finetune_exist(path):
                 if path == from_model_path:
@@ -180,7 +195,9 @@ def mock_finetune_exist(path):
             self.patches[
                 "fairseq.file_io.PathManager.exists"
             ].side_effect = mock_finetune_exist
-            _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+            cfg_mock = get_mock_cfg(from_model_path)
+            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
+            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
             (
                 checkpoint_path,
                 reset_optimizer,
@@ -197,8 +214,6 @@ def test_finetune_from_model_resume(self):
             trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             from_model_path = "/temp/checkpoint_pretrained.pt"
-            args_mock = get_mock_args(from_model_path)
-            args_mock.restore_file = "checkpoint_last.pt"

             # launch second time
             # both restore_file=checkpoint_last.pt and finetune_from_model are set
@@ -211,7 +226,9 @@ def mock_finetune_exist(path):
             self.patches[
                 "fairseq.file_io.PathManager.exists"
             ].side_effect = mock_finetune_exist
-            _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+            cfg_mock = get_mock_cfg(from_model_path)
+            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
+            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
             (
                 checkpoint_path,
                 reset_optimizer,
diff --git a/tests/utils.py b/tests/utils.py
index 91feca6b2a..a145aa587d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,7 +20,7 @@
     FairseqIncrementalDecoder,
 )
 from fairseq.models.fairseq_encoder import EncoderOut
-from fairseq.tasks import FairseqTask, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask
 from fairseq_cli import generate, interactive, preprocess, train, validate



From f6677b675524d22a4df9f2304f63ee382594c9e3 Mon Sep 17 00:00:00 2001
From: Alexei Baevski
Date: Tue, 20 Oct 2020 13:43:15 -0700
Subject: [PATCH 229/707] fix #2761, #2760

Summary:
Fixes issue #2761 and #2760

args from registries were not added to argparse

Reviewed By: myleott

Differential Revision: D24422792

fbshipit-source-id: c8a8e835965da5c4f527bd589bd621371441e7fe
---
 fairseq/dataclass/utils.py | 7 +++++--
 fairseq/options.py         | 3 +++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py
index bcfe23294a..f4431db82a 100644
--- a/fairseq/dataclass/utils.py
+++ b/fairseq/dataclass/utils.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
import ast -from argparse import ArgumentParser, Namespace +from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING, dataclass from enum import Enum from typing import Any, Dict, List, Optional @@ -204,7 +204,10 @@ def get_kwargs_from_dc( continue if delete_default: del kwargs["default"] - parser.add_argument(field_name, **kwargs) + try: + parser.add_argument(field_name, **kwargs) + except ArgumentError: + pass def _set_legacy_defaults(args, cls): diff --git a/fairseq/options.py b/fairseq/options.py index 6bc526ce0e..0315d8cd3e 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -150,6 +150,9 @@ def parse_args_and_arch( cls = REGISTRY["registry"][choice] if hasattr(cls, "add_args"): cls.add_args(parser) + elif hasattr(cls, "__dataclass"): + gen_parser_from_dataclass(parser, cls.__dataclass()) + if hasattr(args, "task"): from fairseq.tasks import TASK_REGISTRY From 8248a12a6433f45b1757fac206f453f24b88403a Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 20 Oct 2020 15:41:33 -0700 Subject: [PATCH 230/707] Upgrade args: max_sentences to batch_size (#2754) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Upgrade args: `max_sentences` to `batch_size` ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2754 Reviewed By: alexeib Differential Revision: D24418980 Pulled By: myleott fbshipit-source-id: 5269c2fc8c434513cc5114f7e9d2eccd0c553fbd --- fairseq/models/roberta/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index d1a6319630..1d934b4fff 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -250,6 +250,9 @@ def from_pretrained( load_checkpoint_heads=True, **kwargs, ) + cls.upgrade_args(x["args"]) + + logger.info(x["args"]) return RobertaHubInterface(x["args"], x["task"], x["models"][0]) def upgrade_state_dict_named(self, state_dict, name): From d6f2c907be8f7351195184981e4f3a9e003a4258 Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 20 Oct 2020 15:43:26 -0700 Subject: [PATCH 231/707] remove unnecessary logging configs (#2733) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? It's sufficient to set logging.basicConfig in the outermost calling code, such as train.py or generate.py. In fact, the setting of logging.basicConfig() (like [here](https://github.com/pytorch/fairseq/blob/master/fairseq_cli/generate.py#L54)) will be overridden if logging.basicConfig is also set in the inner part of the code. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2733 Reviewed By: alexeib Differential Revision: D24418987 Pulled By: myleott fbshipit-source-id: 862d200023357de8947799f380e513f4c411b143 --- fairseq/data/audio/speech_to_text_dataset.py | 5 ----- fairseq/tasks/speech_to_text.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index aefe95658d..6e5fd70e3c 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -23,11 +23,6 @@ from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform -logging.basicConfig( - format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, -) logger = logging.getLogger(__name__) diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index c200bb1407..8fb341b0c5 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -16,11 +16,6 @@ from fairseq.tasks import LegacyFairseqTask, register_task -logging.basicConfig( - format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, -) logger = logging.getLogger(__name__) From 9b0611e6786a048b6c4a70e36051027671d951a7 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 20 Oct 2020 15:43:36 -0700 Subject: [PATCH 232/707] Fix torch.hub (fixes #2756) (#2762) Summary: Typically `torch.hub.load(...)` doesn't call `pip install`, so our Cython components never get built. We have a hack in our hubconf that builds these components by running the equivalent of `python setup.py build_ext --inplace` using the setuptools sandbox: https://github.com/pytorch/fairseq/blob/f6677b675524d22a4df9f2304f63ee382594c9e3/hubconf.py#L52-L55. 
Unfortunately, this sandbox gets mad if you modify the filesystem, which is what this recent change does: https://github.com/pytorch/fairseq/blob/f6677b675524d22a4df9f2304f63ee382594c9e3/setup.py#L203-L205. Combined, this breaks torch.hub. The solution is to skip setting up the symlinks when we're running `build_ext`. This is fine, since `build_ext` doesn't actually build a package, so we don't care about including config or examples. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2762 Reviewed By: alexeib Differential Revision: D24430228 Pulled By: myleott fbshipit-source-id: e05d075a003ddfde196cb8a86b32882d73808015 --- setup.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 54c752d257..edc0196714 100644 --- a/setup.py +++ b/setup.py @@ -201,12 +201,14 @@ def get_files(path, relative_to="fairseq"): try: # symlink config and examples into fairseq package so package_data accepts them - os.symlink(os.path.join("..", "config"), "fairseq/config") - os.symlink(os.path.join("..", "examples"), "fairseq/examples") + if "build_ext" not in sys.argv[1:]: + os.symlink(os.path.join("..", "config"), "fairseq/config") + os.symlink(os.path.join("..", "examples"), "fairseq/examples") package_data = { "fairseq": get_files("fairseq/config") + get_files("fairseq/examples"), } do_setup(package_data) finally: - os.unlink("fairseq/config") - os.unlink("fairseq/examples") + if "build_ext" not in sys.argv[1:]: + os.unlink("fairseq/config") + os.unlink("fairseq/examples") From eece1d7082caf84a957c6b9685a43ee5d2beefe2 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 21 Oct 2020 07:45:57 -0700 Subject: [PATCH 233/707] More detailed error message for data iterator size mismatch (#2768) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2768 Reviewed By: vimalmanohar Differential Revision: D24446804 Pulled By: myleott fbshipit-source-id: 19220f2fd3e3db49f7528f6fb17188834b09646f --- 
fairseq/data/iterators.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 1c3ffe94b1..15796234db 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -60,7 +60,11 @@ def __iter__(self): if self.n >= self.total: raise RuntimeError( "Mismatch between actual and expected iterable length. " - "Please report this to the fairseq developers." + "This may be caused by resuming training from a checkpoint using " + "a different number of GPUs, in which case you can try the " + "--reset-dataloader option. Alternatively you may have a train or " + "validation set that is smaller than the number of GPUs. If none " + "of these apply, please report this to the fairseq developers." ) self.n += 1 yield x From ee450dde198e404ae897acd8854665ed8719801e Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Wed, 21 Oct 2020 08:09:17 -0700 Subject: [PATCH 234/707] S2T multilingual example + bug fix Summary: * S2T multilingual example on MuST-C * A bug fix for `speech_to_text_dataset` (for multilingual setting) Reviewed By: jmp84 Differential Revision: D24339394 fbshipit-source-id: ef0c0be08137884897b532e45ebc56551d20be48 --- examples/speech_to_text/README.md | 143 +++++++++++++----- examples/speech_to_text/data_utils.py | 67 ++++++-- examples/speech_to_text/prep_mustc_data.py | 113 ++++++++------ fairseq/data/audio/speech_to_text_dataset.py | 6 +- .../models/speech_to_text/s2t_transformer.py | 7 +- 5 files changed, 240 insertions(+), 96 deletions(-) diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 62fd2700b8..4030af0144 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -1,11 +1,15 @@ # Speech-to-Text (S2T) Modeling +[https://arxiv.org/abs/2010.05171](https://arxiv.org/abs/2010.05171) + +Examples for speech recognition (ASR) and speech-to-text translation (ST) with fairseq. 
+ ## Data Preparation S2T modeling data consists of source speech features, target text and other optional information (source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files to store these information. Each data field is represented by a column in the TSV file. -Unlike text token embeddings, speech features (e.g. log mel-filter banks) are usually fixed +Unlike text token embeddings, speech features (e.g. log mel-scale filter banks) are usually fixed during model training and can be pre-computed. The manifest file contains the path to either the feature file in NumPy format or the WAV/FLAC audio file. For the latter, features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed @@ -23,10 +27,9 @@ It requires arguments `--task speech_to_text` and `--arch `. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. -You may want to update it accordingly when using more than 1 GPU. +Example for multilingual models: +```bash +fairseq-train ${MUSTC_ROOT} \ + --train-subset train_de_st,train_nl_st,train_es_st,train_fr_st,train_it_st,train_pt_st,train_ro_st,train_ru_st \ + --valid-subset dev_de_st,dev_nl_st,dev_es_st,dev_fr_st,dev_it_st,dev_pt_st,dev_ro_st,dev_ru_st \ + --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --task speech_to_text \ + --arch s2t_transformer_s --criterion label_smoothed_cross_entropy --report-accuracy --ignore-prefix-size 1 \ + --max-update 100000 --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 \ + --seed 1 --update-freq 8 --load-pretrained-encoder-from ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} +``` +where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR +for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set +`--update-freq 8` to simulate 8 GPUs with 1 GPU. 
You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend target language ID token as target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. ###### Inference & Evaluation Average the last 10 checkpoints and evaluate on the `tst-COMMON` split: ```bash CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py \ - --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" +python scripts/average_checkpoints.py --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_st --task speech_to_text \ --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu + +# For multilingual models +python scripts/average_checkpoints.py --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" +for LANG in de nl es fr it pt ro ru; do + fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_${LANG}_st --task speech_to_text --prefix-size 1 \ + --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu +done ``` +For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`. 
###### Result -| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | -|---|---|---|---|---|---|---|---|---|---| -| s2t_transformer_s | 31M | 22.7 | 27.3 | 27.2 | 32.9 | 22.7 | 28.1 | 21.9 | 15.3 | +| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | +|---|---|---|---|---|---|---|---|---|---|---| +| Bilingual | s2t_transformer_s | 31M | 22.7 | 27.3 | 27.2 | 32.9 | 22.7 | 28.1 | 21.9 | 15.3 | +| Multilingual | s2t_transformer_m | 76M | 24.5 | 28.6 | 28.2 | 34.9 | 24.6 | 31.1 | 23.8 | 16.0 | ## Example 3: ST on CoVoST +We replicate the experiments in +[CoVoST 2 and Massively Multilingual Speech-to-Text Translation (Wang et al., 2020)](https://arxiv.org/abs/2007.10310). + #### Data Preparation -Download and preprocess CoVoST data with +Download and preprocess [CoVoST (version 2)](https://arxiv.org/abs/2007.10310) data with ```bash # En ASR -python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \ - --vocab-type char --src-lang en +python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} --vocab-type char --src-lang en # ST -python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \ - --vocab-type char --src-lang fr --tgt-lang en +python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} --vocab-type char \ + --src-lang fr --tgt-lang en ``` where `COVOST_ROOT` is the root path for downloaded data as well as generated manifest and feature files. @@ -158,8 +216,8 @@ You may want to update it accordingly when using more than 1 GPU. 
###### Inference & Evaluation ```bash CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py \ - --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" +python scripts/average_checkpoints.py --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" fairseq-generate ${COVOST_ROOT} --gen-subset test_asr_en --task speech_to_text \ --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct @@ -186,8 +244,8 @@ You may want to update it accordingly when using more than 1 GPU. Average the last 10 checkpoints and evaluate on test split: ```bash CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py \ - --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" +python scripts/average_checkpoints.py --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" fairseq-generate ${COVOST_ROOT} --gen-subset test_st_fr_en --task speech_to_text \ --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu ``` @@ -214,3 +272,12 @@ Please cite as: year = {2019}, } ``` + +## More Paper Code +The following papers also base their experiments on fairseq S2T. We are adding more examples for replication. 
+ +- [Improving Cross-Lingual Transfer Learning for End-to-End Speech Recognition with Speech Translation (Wang et al., 2020)](https://arxiv.org/abs/2006.05474) +- [Self-Supervised Representations Improve End-to-End Speech Translation (Wu et al., 2020)](https://arxiv.org/abs/2006.12124) +- [Self-Training for End-to-End Speech Translation (Pino et al., 2020)](https://arxiv.org/abs/2006.02490) +- [CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus (Wang et al., 2020)](https://arxiv.org/abs/2002.01320) +- [Harnessing Indirect Training Data for End-to-End Automatic Speech Translation: Tricks of the Trade (Pino et al., 2019)](https://arxiv.org/abs/1909.06515) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 1efeff4df1..083d7316cd 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -14,6 +14,7 @@ from typing import Any, Dict, List import numpy as np +import pandas as pd import sentencepiece as sp from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN @@ -27,10 +28,7 @@ def gen_vocab( - input_path: str, - output_path_prefix: str, - model_type="bpe", - vocab_size=1000, + input_path: str, output_path_prefix: str, model_type="bpe", vocab_size=1000, ): # Train SentencePiece Model arguments = [ @@ -128,29 +126,52 @@ def get_zip_manifest(zip_root, zip_filename): def gen_config_yaml( - data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb" + data_root, + spm_filename, + yaml_filename="config.yaml", + specaugment_policy="lb", + prepend_tgt_lang_tag=False, + sampling_alpha=1.0, ): - assert specaugment_policy in {"lb", "ld"} data_root = op.abspath(data_root) writer = S2TDataConfigWriter(op.join(data_root, yaml_filename)) writer.set_audio_root(op.abspath(data_root)) writer.set_vocab_filename(spm_filename.replace(".model", ".txt")) 
writer.set_input_channels(1) writer.set_input_feat_per_channel(80) - if specaugment_policy == "lb": - writer.set_specaugment_lb_policy() - else: - writer.set_specaugment_ld_policy() + specaugment_setters = { + "lb": writer.set_specaugment_lb_policy, + "ld": writer.set_specaugment_ld_policy, + "sm": writer.set_specaugment_sm_policy, + "ss": writer.set_specaugment_ss_policy, + } + assert specaugment_policy in specaugment_setters + specaugment_setters[specaugment_policy]() writer.set_bpe_tokenizer( { "bpe": "sentencepiece", "sentencepiece_model": op.join(data_root, spm_filename), } ) + if prepend_tgt_lang_tag: + writer.set_prepend_tgt_lang_tag(True) + writer.set_sampling_alpha(sampling_alpha) writer.set_feature_transforms("_train", ["specaugment"]) writer.flush() +def load_df_from_tsv(path: str): + return pd.read_csv( + path, + sep="\t", + header=0, + encoding="utf-8", + escapechar="\\", + quoting=csv.QUOTE_NONE, + na_filter=False, + ) + + def save_df_to_tsv(dataframe, path): dataframe.to_csv( path, @@ -247,6 +268,26 @@ def set_specaugment_ld_policy(self): time_mask_p=1.0, ) + def set_specaugment_sm_policy(self): + self.set_specaugment( + time_wrap_w=0, + freq_mask_n=2, + freq_mask_f=15, + time_mask_n=2, + time_mask_t=70, + time_mask_p=0.2, + ) + + def set_specaugment_ss_policy(self): + self.set_specaugment( + time_wrap_w=0, + freq_mask_n=2, + freq_mask_f=27, + time_mask_n=2, + time_mask_t=70, + time_mask_p=0.2, + ) + def set_input_channels(self, input_channels=1): self.config["input_channels"] = input_channels @@ -260,3 +301,9 @@ def set_feature_transforms(self, split, transforms: List[str]): if "transforms" not in self.config: self.config["transforms"] = {} self.config["transforms"][split] = transforms + + def set_prepend_tgt_lang_tag(self, flag=True): + self.config["prepend_tgt_lang_tag"] = flag + + def set_sampling_alpha(self, sampling_alpha=1.0): + self.config["sampling_alpha"] = sampling_alpha diff --git a/examples/speech_to_text/prep_mustc_data.py 
b/examples/speech_to_text/prep_mustc_data.py index 5593d2e7e2..59a42803f9 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -22,6 +22,7 @@ gen_config_yaml, gen_vocab, get_zip_manifest, + load_df_from_tsv, save_df_to_tsv, ) from torch import Tensor @@ -33,7 +34,6 @@ MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"] -TASKS = ["asr", "st"] class MUSTC(Dataset): @@ -123,77 +123,102 @@ def process(args): zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}") # Generate TSV manifest print("Generating manifest...") - train_text = {task: [] for task in TASKS} + train_text = [] for split in MUSTC.SPLITS: is_train_split = split.startswith("train") manifest = {c: [] for c in MANIFEST_COLUMNS} - text = {task: [] for task in TASKS} dataset = MUSTC(args.data_root, lang, split) for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): manifest["id"].append(utt_id) manifest["audio"].append(zip_manifest[utt_id]) duration_ms = int(wav.size(1) / sr * 1000) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) - text["asr"].append(src_utt) - text["st"].append(tgt_utt) + manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt) manifest["speaker"].append(speaker_id) if is_train_split: - for task in TASKS: - train_text[task].extend(text[task]) - for task in TASKS: - manifest["tgt_text"] = text[task] - df = pd.DataFrame.from_dict(manifest) - df = filter_manifest_df(df, is_train_split=is_train_split) - save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv")) + train_text.extend(manifest["tgt_text"]) + df = pd.DataFrame.from_dict(manifest) + df = filter_manifest_df(df, is_train_split=is_train_split) + save_df_to_tsv(df, op.join(cur_root, f"{split}_{args.task}.tsv")) # Generate vocab - for task in TASKS: - vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size - if task == "st": - vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size - 
vocab_size_str = "" if vocab_type == "char" else str(vocab_size) - spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}" - with NamedTemporaryFile(mode="w") as f: - for t in train_text[task]: - f.write(t + "\n") - gen_vocab( - f.name, - op.join(cur_root, spm_filename_prefix), - vocab_type, - vocab_size, - ) - # Generate config YAML - gen_config_yaml( - cur_root, - spm_filename_prefix + ".model", - yaml_filename=f"config_{task}.yaml", - specaugment_policy="lb", + v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for t in train_text: + f.write(t + "\n") + gen_vocab( + f.name, + op.join(cur_root, spm_filename_prefix), + args.vocab_type, + args.vocab_size, ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="lb", + ) # Clean up shutil.rmtree(feature_root) +def process_joint(args): + assert all( + op.isdir(op.join(args.data_root, f"en-{lang}")) for lang in MUSTC.LANGUAGES + ), "do not have downloaded data available for all 8 languages" + cur_root = args.data_root + # Generate vocab + vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for lang in MUSTC.LANGUAGES: + tsv_path = op.join(cur_root, f"en-{lang}", f"train_{args.task}.tsv") + df = load_df_from_tsv(tsv_path) + for t in df["tgt_text"]: + f.write(t + "\n") + gen_vocab( + f.name, + op.join(cur_root, spm_filename_prefix), + args.vocab_type, + args.vocab_size, + ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="lb", + prepend_tgt_lang_tag=(args.task == "st"), + ) + # Make symbolic links to manifests + for lang in 
MUSTC.LANGUAGES: + for split in MUSTC.SPLITS: + src_path = op.join(cur_root, f"en-{lang}", f"{split}_{args.task}.tsv") + desc_path = op.join(cur_root, f"{split}_{lang}_{args.task}.tsv") + if not op.islink(desc_path): + os.symlink(src_path, desc_path) + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--data-root", "-d", required=True, type=str) parser.add_argument( - "--asr-vocab-type", - default="unigram", - required=True, - type=str, - choices=["bpe", "unigram", "char"], - ), - parser.add_argument( - "--st-vocab-type", + "--vocab-type", default="unigram", required=True, type=str, choices=["bpe", "unigram", "char"], ), - parser.add_argument("--asr-vocab-size", default=5000, type=int) - parser.add_argument("--st-vocab-size", default=8000, type=int) + parser.add_argument("--vocab-size", default=8000, type=int) + parser.add_argument("--task", type=str, choices=["asr", "st"]) + parser.add_argument("--joint", action="store_true", help="") args = parser.parse_args() - process(args) + if args.joint: + process_joint(args) + else: + process(args) if __name__ == "__main__": diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 6e5fd70e3c..39d22c7a5e 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -68,7 +68,7 @@ def bpe_tokenizer(self) -> Dict: a dictionary with `bpe` providing the tokenizer name and the other items providing the tokenizer-specific arguments. 
Tokenizers are defined in `fairseq.data.encoders.*`""" - return self.config.get("bpe_tokenizer", None) + return self.config.get("bpe_tokenizer", {"bpe": None}) @property def prepend_tgt_lang_tag(self) -> bool: @@ -246,10 +246,10 @@ def __init__( assert (tgt_dict is None and tgt_texts is None) or ( tgt_dict is not None and tgt_texts is not None ) - self.tgt_dict = tgt_dict - self.check_tgt_lang_tag() self.src_texts, self.tgt_texts = src_texts, tgt_texts self.src_langs, self.tgt_langs = src_langs, tgt_langs + self.tgt_dict = tgt_dict + self.check_tgt_lang_tag() self.ids = ids self.shuffle = data_cfg.shuffle if is_train_split else False diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 8e48964f79..fc2f14fea6 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -180,6 +180,11 @@ def add_args(parser): action="store_true", help="apply layernorm before each decoder block", ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) parser.add_argument( "--layernorm-embedding", action="store_true", @@ -410,7 +415,7 @@ def base_architecture(args): args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.share_decoder_input_output_embed = getattr( - args, "share_decoder_input_output_embed", False + args, "share_decoder_input_output_embed", True ) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False From 0f44e89c383f45ea455bddc3c44ec950b3df91f1 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 21 Oct 2020 15:30:18 -0700 Subject: [PATCH 235/707] Fix Latent Depth args (#1365) Summary: Args should be registered in the Model rather than modules Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1365 
Reviewed By: pipibjc Differential Revision: D24453007 Pulled By: myleott fbshipit-source-id: d22b0d86a3c940456b394b005acab4bb6a3f5bed --- .../models/latent_multilingual_transformer.py | 16 ++++++++++++++++ .../src/models/latent_transformer.py | 14 ++++++++++++-- .../latent_depth/src/modules/latent_layers.py | 17 +++-------------- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/examples/latent_depth/src/models/latent_multilingual_transformer.py b/examples/latent_depth/src/models/latent_multilingual_transformer.py index 9e075fcc47..12b7e67d03 100644 --- a/examples/latent_depth/src/models/latent_multilingual_transformer.py +++ b/examples/latent_depth/src/models/latent_multilingual_transformer.py @@ -21,6 +21,22 @@ class LatentMultilingualTransformerModel(MultilingualTransformerModel): (https://arxiv.org/abs/2009.13102). """ + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + MultilingualTransformerModel.add_args(parser) + parser.add_argument( + '--soft-select', + action='store_true', + help='use soft samples in training an inference', + ) + parser.add_argument( + '--sampling-tau', + type=float, + default=5., + help='sampling temperature', + ) + @classmethod def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): if is_encoder: diff --git a/examples/latent_depth/src/models/latent_transformer.py b/examples/latent_depth/src/models/latent_transformer.py index db30239bff..6a825301a4 100644 --- a/examples/latent_depth/src/models/latent_transformer.py +++ b/examples/latent_depth/src/models/latent_transformer.py @@ -23,7 +23,12 @@ def __init__(self, args, dictionary, embed_tokens, num_logits=1): self.num_logits = num_logits self.num_layers = args.encoder_layers super().__init__(args, dictionary, embed_tokens) - self.layer_select = LayerSelect(self.num_layers, self.num_logits, args) + self.layer_select = LayerSelect( + num_layers=self.num_layers, + num_logits=self.num_logits, + 
soft_select=getattr(args, "soft_select", False), + sampling_tau=getattr(args, "sampling_tau", 5.), + ) self.lang_idx = None self.layers = nn.ModuleList( [self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)] @@ -74,7 +79,12 @@ def __init__( super().__init__( args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn ) - self.layer_select = LayerSelect(self.num_layers, self.num_logits, args) + self.layer_select = LayerSelect( + num_layers=self.num_layers, + num_logits=self.num_logits, + soft_select=getattr(args, "soft_select", False), + sampling_tau=getattr(args, "sampling_tau", 5.), + ) self.lang_idx = None self.layers = nn.ModuleList( [ diff --git a/examples/latent_depth/src/modules/latent_layers.py b/examples/latent_depth/src/modules/latent_layers.py index a2b8ab4476..2be05d5535 100644 --- a/examples/latent_depth/src/modules/latent_layers.py +++ b/examples/latent_depth/src/modules/latent_layers.py @@ -12,28 +12,17 @@ class LayerSelect(nn.Module): either (soft) weighting or (hard) selection of residual connection. 
https://arxiv.org/abs/2009.13102 """ - - def __init__(self, num_layers, num_logits, args): + def __init__(self, num_layers, num_logits, soft_select=False, sampling_tau=5.): super(LayerSelect, self).__init__() - self.args = args self.layer_logits = torch.nn.Parameter( torch.Tensor(num_logits, num_layers), requires_grad=True, ) - self.hard_select = not (hasattr(args, "soft_select") and args.soft_select) - self.tau = getattr(args, "sampling_tau", 5) + self.hard_select = not soft_select + self.tau = sampling_tau self.detach_grad = False self.layer_samples = [None] * num_logits - @staticmethod - def add_args(parser): - parser.add_argument( - "--soft-select", - action="store_true", - help="use soft samples in training an inference", - ) - parser.add_argument("--sampling-tau", type=float, help="sampling temperature") - def sample(self, logit_idx): """To leverage the efficiency of distributed training, samples for all layers are computed at once for each logit_idx. Logits are parameters From 43c69a7666b59d1d0c8a30c7acefec9822fedcaf Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 22 Oct 2020 06:24:54 -0700 Subject: [PATCH 236/707] Fix deprecated usage of nonzero() (#1364) Summary: PyTorch requires the `as_tuple` argument now, otherwise it prints warnings. 
Let's just fix this everywhere Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1364 Reviewed By: edunov Differential Revision: D24452587 Pulled By: myleott fbshipit-source-id: 7e6d424792ffec74a6197b2a266600cb13f24770 --- fairseq/data/mask_tokens_dataset.py | 2 +- fairseq/models/bart/hub_interface.py | 2 +- fairseq/models/roberta/hub_interface.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 8ea86245f7..9e2c7119d8 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -112,7 +112,7 @@ def __getitem__(self, index: int): if self.mask_whole_words is not None: word_begins_mask = self.mask_whole_words.gather(0, item) - word_begins_idx = word_begins_mask.nonzero().view(-1) + word_begins_idx = word_begins_mask.nonzero(as_tuple=False).view(-1) sz = len(word_begins_idx) words = np.split(word_begins_mask, word_begins_idx)[1:] assert len(words) == sz diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 6a520cb980..8ed91e651d 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -86,7 +86,7 @@ def decode(self, tokens: torch.LongTensor): tokens = tokens[1:] # remove eos_mask = tokens == self.task.source_dictionary.eos() doc_mask = eos_mask[1:] & eos_mask[:-1] - sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) + sentences = np.split(tokens, doc_mask.nonzero(as_tuple=False)[0] + 1) sentences = [ self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences ] diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index 0c723f06dd..d6322c30e8 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -71,7 +71,7 @@ def decode(self, tokens: torch.LongTensor): tokens = tokens[1:] # remove eos_mask = tokens == 
self.task.source_dictionary.eos() doc_mask = eos_mask[1:] & eos_mask[:-1] - sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) + sentences = np.split(tokens, doc_mask.nonzero(as_tuple=False)[0] + 1) sentences = [ self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences ] @@ -173,7 +173,7 @@ def fill_mask(self, masked_input: str, topk: int = 5): add_if_not_exist=False, ) - masked_index = (tokens == self.task.mask_idx).nonzero() + masked_index = (tokens == self.task.mask_idx).nonzero(as_tuple=False) if tokens.dim() == 1: tokens = tokens.unsqueeze(0) From 751bcbfcb939b777e61af251d1fce5d4a4dc1f12 Mon Sep 17 00:00:00 2001 From: Pavel Soriano Date: Thu, 22 Oct 2020 06:25:06 -0700 Subject: [PATCH 237/707] Changed EnvironmentError to RuntimeError in get_from_cache (#2767) Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? No need I believe - [x] Did you write any new necessary tests? No ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2724 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Yes! 
It is not a big PR at all but it allowed me to familiarize with the caching/downloading logic used in fairseq (which is very similar to that used in pytorch/transformers) Pull Request resolved: https://github.com/pytorch/fairseq/pull/2767 Reviewed By: edunov Differential Revision: D24456055 Pulled By: myleott fbshipit-source-id: bc634a9b97f957ecc5a8da57b112ff892e492107 --- fairseq/file_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/file_utils.py b/fairseq/file_utils.py index 0a94ac7112..ec6de37f77 100644 --- a/fairseq/file_utils.py +++ b/fairseq/file_utils.py @@ -286,7 +286,7 @@ def get_from_cache(url, cache_dir=None): etag = None else: etag = response.headers.get("ETag") - except EnvironmentError: + except RuntimeError: etag = None filename = url_to_filename(url, etag) From 18cadab1d0fc6a98988a17e92683f8b83b03a177 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 22 Oct 2020 07:07:27 -0700 Subject: [PATCH 238/707] =?UTF-8?q?support=20new=20cfg=20based=20models;?= =?UTF-8?q?=20make=20sure=20--normalize=20is=20consistent=20in=20=E2=80=A6?= =?UTF-8?q?=20(#1370)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: support new cfg based models; make sure --normalize is consistent in infer with the model Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1370 Reviewed By: myleott Differential Revision: D24467698 Pulled By: alexeib fbshipit-source-id: 056b3608e3c1fe8acdb3e45e0306de5d874cb4d1 --- examples/speech_recognition/infer.py | 45 +++++++++++++++------------- fairseq/tasks/audio_pretraining.py | 15 ++++++---- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index a197ab5a63..cc04225035 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -18,6 +18,7 @@ import torch from fairseq import checkpoint_utils, options, progress_bar, tasks, 
utils from fairseq.data.data_utils import post_process +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging.meters import StopwatchMeter, TimeMeter @@ -203,18 +204,25 @@ def load_models_and_criterions( else: state = model_state - args = state["args"] + if "cfg" in state: + cfg = state["cfg"] + else: + cfg = convert_namespace_to_omegaconf(state["args"]) + if task is None: - task = tasks.setup_task(args) - model = task.build_model(args) + if hasattr(cfg.task, 'data'): + cfg.task.data = data_path + task = tasks.setup_task(cfg.task) + + model = task.build_model(cfg.model) model.load_state_dict(state["model"], strict=True) models.append(model) - criterion = task.build_criterion(args) + criterion = task.build_criterion(cfg.criterion) if "criterion" in state: criterion.load_state_dict(state["criterion"], strict=True) criterions.append(criterion) - return models, criterions, args + return models, criterions, task def optimize_models(args, use_cuda, models): @@ -255,29 +263,15 @@ def main(args, task=None, model_state=None): use_cuda = torch.cuda.is_available() and not args.cpu - if task is None: - # Load dataset splits - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) - - logger.info( - "| {} {} {} examples".format( - args.data, args.gen_subset, len(task.dataset(args.gen_subset)) - ) - ) - - # Set dictionary - tgt_dict = task.target_dictionary logger.info("| decoding with criterion {}".format(args.criterion)) # Load ensemble - if args.load_emissions: models, criterions = [], [] else: logger.info("| loading model(s) from {}".format(args.path)) - models, criterions, _ = load_models_and_criterions( + models, criterions, task = load_models_and_criterions( args.path, data_path=args.data, arg_overrides=eval(args.model_overrides), # noqa @@ -286,6 +280,17 @@ def main(args, task=None, model_state=None): ) optimize_models(args, use_cuda, models) + # Load dataset splits + task.load_dataset(args.gen_subset) + # Set dictionary + 
tgt_dict = task.target_dictionary + + logger.info( + "| {} {} {} examples".format( + args.data, args.gen_subset, len(task.dataset(args.gen_subset)) + ) + ) + # hack to pass transitions to W2lDecoder if args.criterion == "asg_loss": trans = criterions[0].asg.trans.data diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index a831ad6ee8..298bdbe938 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -68,9 +68,9 @@ def add_args(parser): help="extension of the label file to load, if any", ) - def __init__(self, args, source_dictionary=None): + def __init__(self, args, source_dictionary=None, target_dictionary=None): super().__init__(args) - self._target_dictionary = None + self._target_dictionary = target_dictionary self._source_dictionary = source_dictionary self.is_ctc = args.criterion == "ctc" @@ -81,7 +81,14 @@ def setup_task(cls, args, **kwargs): Args: args (omegaconf.DictConfig): parsed command-line arguments """ - return cls(args) + + if args.labels: + dict_path = os.path.join(args.data, f"dict.{args.labels}.txt") + target_dictionary = Dictionary.load(dict_path) + else: + target_dictionary = None + + return cls(args, target_dictionary=target_dictionary) def load_dataset(self, split, **kwargs): """Load a given dataset split. 
@@ -101,8 +108,6 @@ def load_dataset(self, split, **kwargs): ) if self.args.labels: - dict_path = os.path.join(self.args.data, f"dict.{self.args.labels}.txt") - self._target_dictionary = Dictionary.load(dict_path) label_path = os.path.join(self.args.data, f"{split}.{self.args.labels}") labels = [] with open(label_path, "r") as f: From 31c23baafcac94733f31613c7431997d787a204d Mon Sep 17 00:00:00 2001 From: Chau Tran Date: Thu, 22 Oct 2020 11:29:08 -0700 Subject: [PATCH 239/707] Fix fairseq/criss README Summary: Add requirements, fix wrong command Reviewed By: tangyuq Differential Revision: D24452748 fbshipit-source-id: 4837610ea7e5b5df8caecc685226080cafddb3e0 --- examples/criss/README.md | 16 +++++++++++++--- .../criss/download_and_preprocess_tatoeba.sh | 11 ++++++++++- examples/criss/mining/mine_example.sh | 2 +- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/examples/criss/README.md b/examples/criss/README.md index a534056254..4689ed7c10 100644 --- a/examples/criss/README.md +++ b/examples/criss/README.md @@ -6,14 +6,22 @@ https://arxiv.org/pdf/2006.09526.pdf CRISS is a multilingual sequence-to-sequnce pretraining method where mining and training processes are applied iteratively, improving cross-lingual alignment and translation ability at the same time. +## Requirements: + +* faiss: https://github.com/facebookresearch/faiss +* mosesdecoder: https://github.com/moses-smt/mosesdecoder +* flores: https://github.com/facebookresearch/flores +* LASER: https://github.com/facebookresearch/LASER + ## Unsupervised Machine Translation ##### 1. Download and decompress CRISS checkpoints ``` cd examples/criss -wget https://dl.fbaipublicfiles.com/fairseq/models/criss/criss_checkpoints.tar.gz +wget https://dl.fbaipublicfiles.com/criss/criss_3rd_checkpoints.tar.gz tar -xf criss_checkpoints.tar.gz ``` ##### 2. 
Download and preprocess Flores test dataset +Make sure to run all scripts from examples/criss directory ``` bash download_and_preprocess_flores_test.sh ``` @@ -35,9 +43,11 @@ bash sentence_retrieval/sentence_retrieval_tatoeba.sh ``` ## Mining -##### 1. Mine pseudo-parallel +##### 1. Install faiss +Follow instructions on https://github.com/facebookresearch/faiss/blob/master/INSTALL.md +##### 2. Mine pseudo-parallel data between Kazakh and English ``` -bash sentence_retrieval/sentence_retrieval_tatoeba.sh +bash mining/mine_example.sh ``` ## Citation diff --git a/examples/criss/download_and_preprocess_tatoeba.sh b/examples/criss/download_and_preprocess_tatoeba.sh index 4579d65aba..7ed64f017d 100644 --- a/examples/criss/download_and_preprocess_tatoeba.sh +++ b/examples/criss/download_and_preprocess_tatoeba.sh @@ -10,7 +10,16 @@ DATA=data_tmp SPM_MODEL=criss_checkpoints/sentence.bpe.model DICT=criss_checkpoints/dict.txt -git clone https://github.com/facebookresearch/LASER +if [[ -f flores ]]; then + echo "flores already cloned" +else + git clone https://github.com/facebookresearch/flores +fi +if [[ -f LASER ]]; then + echo "LASER already cloned" +else + git clone https://github.com/facebookresearch/LASER +fi mkdir -p data_tmp declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn") for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do diff --git a/examples/criss/mining/mine_example.sh b/examples/criss/mining/mine_example.sh index 92b5291338..ace995ac44 100644 --- a/examples/criss/mining/mine_example.sh +++ b/examples/criss/mining/mine_example.sh @@ -7,7 +7,7 @@ # source_lang=kk_KZ target_lang=en_XX -MODEL=criss_checkpoints/criss.2nd.pt +MODEL=criss_checkpoints/criss.3rd.pt 
SPM=criss_checkpoints/sentence.bpe.model SPLIT=test LANG_DICT=criss_checkpoints/lang_dict.txt From 11aaffdd18ab610e91ea0f4d394271602853ef04 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 22 Oct 2020 12:03:15 -0700 Subject: [PATCH 240/707] rm FairseqModel::upgrade_args, it's not needed anymore (#1363) Summary: Tests seem to pass without it, so let's remove it Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1363 Reviewed By: alexeib Differential Revision: D24452369 Pulled By: myleott fbshipit-source-id: 186933ff3ee16be61c77a9581658db8e853c1baa --- fairseq/models/fairseq_model.py | 8 -------- fairseq/models/roberta/model.py | 1 - 2 files changed, 9 deletions(-) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 3ebb30e3ad..15c2c4ab2e 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -237,11 +237,6 @@ def apply_prepare_for_tpu_(module): self.apply(apply_prepare_for_tpu_) - @classmethod - def upgrade_args(cls, args): - if hasattr(args, "max_sentences") and not hasattr(args, "batch_size"): - args.batch_size = args.max_sentences - @classmethod def from_pretrained( cls, @@ -280,9 +275,6 @@ def from_pretrained( archive_map=cls.hub_models(), **kwargs, ) - - cls.upgrade_args(x["args"]) - logger.info(x["args"]) return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"]) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 1d934b4fff..5c9f92a149 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -250,7 +250,6 @@ def from_pretrained( load_checkpoint_heads=True, **kwargs, ) - cls.upgrade_args(x["args"]) logger.info(x["args"]) return RobertaHubInterface(x["args"], x["task"], x["models"][0]) From f0fcb55d5b2617371cd1cf5c2d3712ea4bd79122 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 22 Oct 2020 12:18:07 -0700 Subject: [PATCH 241/707] fix #2764 (#1368) Summary: fix interactive.py + add args from tasks
before registries (where we catch errors) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1368 Reviewed By: myleott Differential Revision: D24462871 Pulled By: alexeib fbshipit-source-id: 307b829c935aa5061bdd79d8cc339eaf87fd8845 --- fairseq/options.py | 20 ++++++++++---------- fairseq_cli/interactive.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/fairseq/options.py b/fairseq/options.py index 0315d8cd3e..58e5e46190 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -141,6 +141,16 @@ def parse_args_and_arch( else: raise RuntimeError() + if hasattr(args, "task"): + from fairseq.tasks import TASK_REGISTRY + + TASK_REGISTRY[args.task].add_args(parser) + if getattr(args, "use_bmuf", False): + # hack to support extra args for block distributed data parallelism + from fairseq.optim.bmuf import FairseqBMUF + + FairseqBMUF.add_args(parser) + # Add *-specific args to parser. from fairseq.registry import REGISTRIES @@ -153,16 +163,6 @@ def parse_args_and_arch( elif hasattr(cls, "__dataclass"): gen_parser_from_dataclass(parser, cls.__dataclass()) - if hasattr(args, "task"): - from fairseq.tasks import TASK_REGISTRY - - TASK_REGISTRY[args.task].add_args(parser) - if getattr(args, "use_bmuf", False): - # hack to support extra args for block distributed data parallelism - from fairseq.optim.bmuf import FairseqBMUF - - FairseqBMUF.add_args(parser) - # Modify the parser a second time, since defaults may have been reset if modify_parser is not None: modify_parser(parser) diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index ddd2617c3d..2feff950ab 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -318,7 +318,7 @@ def decode_fn(x): def cli_main(): parser = options.get_interactive_generation_parser() args = options.parse_args_and_arch(parser) - distributed_utils.call_main(args, main) + distributed_utils.call_main(convert_namespace_to_omegaconf(args), main) if __name__ == "__main__": 
From b8a938e96e08e5b39deb585d6cc6690de062dd4d Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 22 Oct 2020 12:44:02 -0700 Subject: [PATCH 242/707] BART hub fixes + improvements (#1342) Summary: - Make BART hub interface extend from GeneratorHubInterface (fixes #1748) - Add mask filling interface for BART Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1342 Reviewed By: ngoyal2707 Differential Revision: D24264195 Pulled By: myleott fbshipit-source-id: 0885f90a54fabe1672b1bfe137dfbccbc5d25d0e --- examples/bart/README.md | 17 ++++ fairseq/models/bart/hub_interface.py | 124 +++++++++++++-------------- fairseq/sequence_generator.py | 5 +- fairseq/tasks/denoising.py | 40 +++++++++ 4 files changed, 119 insertions(+), 67 deletions(-) diff --git a/examples/bart/README.md b/examples/bart/README.md index 394503f29f..76857a99a2 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -131,6 +131,23 @@ bart.cuda() bart.predict('new_task', tokens) ``` +#### Filling masks: + +BART can be used to fill multiple `<mask>` tokens in the input. +```python +bart = torch.hub.load('pytorch/fairseq', 'bart.base') +bart.eval() +bart.fill_mask('The cat <mask> on the <mask>.', topk=3, beam=10) +# [('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))] +``` + +Note that by default we enforce the output length to match the input length. +This can be disabled by setting ``match_source_len=False``: +``` +bart.fill_mask('The cat <mask> on the <mask>.', topk=3, beam=10, match_source_len=False) +# [('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))] +``` + #### Evaluating the `bart.large.mnli` model: Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set. 
diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 8ed91e651d..819ea8eeda 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -5,7 +5,7 @@ import copy import logging -from typing import List +from typing import Dict, List import numpy as np import torch @@ -13,39 +13,22 @@ import torch.nn.functional as F from fairseq import utils from fairseq.data import encoders +from fairseq.hub_utils import GeneratorHubInterface from omegaconf import open_dict logger = logging.getLogger(__name__) -class BARTHubInterface(nn.Module): +class BARTHubInterface(GeneratorHubInterface): """A simple PyTorch Hub interface to BART. Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart """ def __init__(self, cfg, task, model): - super().__init__() - self.cfg = cfg - self.task = task - self.model = model - - self.bpe = encoders.build_bpe(cfg.bpe) - - self.max_positions = min( - utils.resolve_max_positions( - self.task.max_positions(), - self.model.max_positions(), - ) - ) - - # this is useful for determining the device - self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) - - @property - def device(self): - return self._float_tensor.device + super().__init__(cfg, task, [model]) + self.model = self.models[0] def encode( self, sentence: str, *addl_sentences, no_separator=True @@ -70,8 +53,8 @@ def encode( [0, 8331, 2] """ tokens = self.bpe.encode(sentence) - if len(tokens.split(" ")) > self.max_positions - 2: - tokens = " ".join(tokens.split(" ")[: self.max_positions - 2]) + if len(tokens.split(" ")) > min(self.max_positions) - 2: + tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2]) bpe_sentence = "<s> " + tokens + " </s>" for s in addl_sentences: bpe_sentence += " </s>" if not no_separator else "" @@ -104,50 +87,28 @@ def _build_sample(self, src_tokens: List[torch.LongTensor]): sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample) return 
sample - def sample( - self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs - ) -> str: - input = [self.encode(sentence) for sentence in sentences] - hypos = self.generate(input, beam, verbose, **kwargs) - return [self.decode(x["tokens"]) for x in hypos] - def generate( self, - tokens: List[torch.LongTensor], - beam: int = 5, - verbose: bool = False, + tokenized_sentences: List[torch.LongTensor], + *args, + inference_step_args=None, **kwargs - ) -> torch.LongTensor: - sample = self._build_sample(tokens) - - # build generator using current args as well as any kwargs - gen_args = copy.copy(self.cfg) - with open_dict(gen_args): - gen_args.beam = beam - for k, v in kwargs.items(): - setattr(gen_args, k, v) - generator = self.task.build_generator([self.model], gen_args) - translations = self.task.inference_step( - generator, - [self.model], - sample, - prefix_tokens=sample["net_input"]["src_tokens"] - .new_zeros((len(tokens), 1)) - .fill_(self.task.source_dictionary.bos()), + ) -> List[List[Dict[str, torch.Tensor]]]: + inference_step_args = inference_step_args or {} + if "prefix_tokens" in inference_step_args: + raise NotImplementedError("prefix generation not implemented for BART") + else: + bsz = len(tokenized_sentences) + inference_step_args["prefix_tokens"] = tokenized_sentences[0].new_full( + (bsz, 1), fill_value=self.task.source_dictionary.bos() + ).to(device=self.device) + return super().generate( + tokenized_sentences, + *args, + inference_step_args=inference_step_args, + **kwargs ) - if verbose: - src_str_with_unk = self.string(tokens) - logger.info("S\t{}".format(src_str_with_unk)) - - def getarg(name, default): - return getattr(gen_args, name, getattr(self.args, name, default)) - - # Process top predictions - hypos = [x[0] for x in translations] - hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))] - return hypos - def extract_features( self, tokens: torch.LongTensor, return_all_hiddens: bool = False ) -> torch.Tensor: @@ 
-201,3 +162,40 @@ def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = Fal if return_logits: return logits return F.log_softmax(logits, dim=-1) + + def fill_mask( + self, + masked_input: str, + topk: int = 5, + match_source_len: bool = True, + **generate_kwargs + ): + masked_token = '<mask>' + assert masked_token in masked_input, \ + "please add one {} token for the input".format(masked_token) + + text_spans = masked_input.split(masked_token) + text_spans_bpe = (' {0} '.format(masked_token)).join( + [self.bpe.encode(text_span.rstrip()) for text_span in text_spans] + ).strip() + tokens = self.task.source_dictionary.encode_line( + '<s> ' + text_spans_bpe + ' </s>', + append_eos=False, + add_if_not_exist=False, + ).long() + + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + + # ensure beam size is at least as big as topk + generate_kwargs['beam'] = max( + topk, + generate_kwargs.get('beam', -1), + ) + generate_kwargs['match_source_len'] = match_source_len + hypos = self.generate(tokens, **generate_kwargs)[0] + + return [ + (self.decode(hypo['tokens']), hypo['score']) + for hypo in hypos[:topk] + ] diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index ddfb67853f..9c5423e2b1 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -293,7 +293,6 @@ def _generate( for step in range(max_len + 1): # one extra step for EOS marker # reorder decoder internal states based on the prev choice of beams - # print(f'step: {step}') if reorder_state is not None: if batch_idxs is not None: # update beam indices to take into account removed sentences @@ -635,12 +634,11 @@ def finalize_hypos( else: cum_unfin.append(prev) - # set() is not supported in script export - # The keys here are of the form "{sent}_{unfin_idx}", where # "unfin_idx" is the index in the current (possibly reduced) # list of sentences, and "sent" is the index in the original, # unreduced batch + # set() is not supported in script export sents_seen: 
Dict[str, Optional[Tensor]] = {} # For every finished beam item @@ -651,7 +649,6 @@ def finalize_hypos( unfin_idx = idx // beam_size # sentence index in the original (unreduced) batch sent = unfin_idx + cum_unfin[unfin_idx] - # print(f"{step} FINISHED {idx} {score} {sent}={unfin_idx} {cum_unfin}") # Cannot create dict for key type '(int, int)' in torchscript. # The workaround is to cast int to string seen = str(sent.item()) + "_" + str(unfin_idx.item()) diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index 3e88bf0ed0..41bddc1a05 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -11,6 +11,10 @@ AppendTokenDataset, DenoisingDataset, Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, PrependTokenDataset, StripTokenDataset, TokenBlockDataset, @@ -18,6 +22,7 @@ ) from fairseq.data.encoders.utils import get_whole_word_mask from fairseq.tasks import LegacyFairseqTask, register_task +import numpy as np logger = logging.getLogger(__name__) @@ -195,6 +200,41 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): ) ) + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We assume that the input begins with a + bos symbol (`<s>`) and ends with an eos symbol (`</s>`). 
+ """ + pad = self.source_dictionary.pad() + eos = self.source_dictionary.eos() + src_dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=self.args.tokens_per_sample - 2, # for <s> and </s> + pad=pad, + eos=eos, + break_mode=self.args.sample_break_mode, + document_sep_len=0, + ) + prev_output_tokens = PrependTokenDataset( + StripTokenDataset(src_dataset, eos), eos + ) + src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + "prev_output_tokens": PadDataset( + prev_output_tokens, pad_idx=pad, left_pad=False + ), + }, + "target": src_dataset, + }, + sizes=[np.array(src_lengths)], + ) + def max_positions(self): """Return the max sentence length allowed by the task.""" return (self.args.max_source_positions, self.args.max_target_positions) From e0737c3c2985d2a71f0a30bb29f6d8741b4f87f3 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 22 Oct 2020 12:45:19 -0700 Subject: [PATCH 243/707] Dynamically generate versions based on commit hash (#2774) Summary: This will produce version strings like `1.0.0a0+3065963`, similar to PyTorch version strings. 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2774 Reviewed By: alexeib Differential Revision: D24453517 Pulled By: myleott fbshipit-source-id: 03a0c324ed6124bbc513ba7edc954abd71d63a0f --- .gitignore | 1 + docs/conf.py | 7 ++++--- examples/__init__.py | 4 +--- fairseq/__init__.py | 13 ++++++++++--- fairseq/version.txt | 1 + setup.py | 24 +++++++++++++++++++++++- 6 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 fairseq/version.txt diff --git a/.gitignore b/.gitignore index a7c4577149..9546cffd90 100644 --- a/.gitignore +++ b/.gitignore @@ -113,6 +113,7 @@ ENV/ /fairseq/temporal_convolution_tbc /fairseq/modules/*_layer/*_forward.cu /fairseq/modules/*_layer/*_backward.cu +/fairseq/version.py # data data-bin/ diff --git a/docs/conf.py b/docs/conf.py index 52971a27e7..440784bfae 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,6 +19,7 @@ import os import sys +from fairseq import __version__ # source code directory, relative to this file, for sphinx-autobuild @@ -51,7 +52,7 @@ # General information about the project. project = "fairseq" -copyright = "2019, Facebook AI Research (FAIR)" +copyright = "Facebook AI Research (FAIR)" author = "Facebook AI Research (FAIR)" github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/" @@ -61,9 +62,9 @@ # built documents. # # The short X.Y version. -version = "0.9.0" +version = __version__ # The full version, including alpha/beta/rc tags. -release = "0.9.0" +release = __version__ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/examples/__init__.py b/examples/__init__.py index 9a6b08a75b..80d95f5fe7 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -3,6 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-__version__ = "0.9.0" - -import examples.noisychannel # noqa +from fairseq.version import __version__ # noqa diff --git a/fairseq/__init__.py b/fairseq/__init__.py index cac3d0e43b..4ccfc90257 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -4,11 +4,18 @@ # LICENSE file in the root directory of this source tree. """isort:skip_file""" -__all__ = ["pdb"] -__version__ = "1.0.0a0" - +import os import sys +try: + from .version import __version__ # noqa +except ImportError: + version_txt = os.path.join(os.path.dirname(__file__), "version.txt") + with open(version_txt) as f: + __version__ = f.read().strip() + +__all__ = ["pdb"] + # backwards compatibility to support `from fairseq.meters import AverageMeter` from fairseq.logging import meters, metrics, progress_bar # noqa diff --git a/fairseq/version.txt b/fairseq/version.txt new file mode 100644 index 0000000000..41432f00d9 --- /dev/null +++ b/fairseq/version.txt @@ -0,0 +1 @@ +1.0.0a0 diff --git a/setup.py b/setup.py index edc0196714..7b13f13e4c 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. 
import os +import subprocess import sys +from setuptools import setup, find_packages, Extension from setuptools import Extension, find_packages, setup @@ -14,6 +16,26 @@ sys.exit("Sorry, Python >= 3.6 is required for fairseq.") +def write_version_py(): + with open(os.path.join("fairseq", "version.txt")) as f: + version = f.read().strip() + + # append latest commit hash to version string + try: + sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + version += "+" + sha[:7] + except Exception: + pass + + # write version info to fairseq/version.py + with open(os.path.join("fairseq", "version.py"), "w") as f: + f.write("__version__ = \"{}\"\n".format(version)) + return version + + +version = write_version_py() + + with open("README.md") as f: readme = f.read() @@ -130,7 +152,7 @@ def include_dirs(self, dirs): def do_setup(package_data): setup( name="fairseq", - version="0.9.0", + version=version, description="Facebook AI Research Sequence-to-Sequence Toolkit", url="https://github.com/pytorch/fairseq", classifiers=[ From cd2bba4419629ffc17eb83c669e88b0bd3af6eb9 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 22 Oct 2020 16:30:07 -0700 Subject: [PATCH 244/707] rename remove_bpe to post_process; add aliasing (#1369) Summary: some binaries (e.g. speech based ones) used --post-process, some used --remove-bpe. --post-process seems more appropriate as it does more than just remove bpe at the moment. this renames remove_bpe to post_process, adds alias so existing command lines would work and adds checkpoint upgrades so they continue to work also. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1369 Reviewed By: myleott Differential Revision: D24465040 Pulled By: alexeib fbshipit-source-id: 1b3e388291ccc403e76e069ef6606b80ead863a7 --- examples/criss/save_encoder.py | 4 ++-- examples/noisychannel/rerank.py | 8 ++++---- examples/noisychannel/rerank_generate.py | 12 ++++++------ examples/noisychannel/rerank_options.py | 2 +- examples/noisychannel/rerank_score_lm.py | 2 +- examples/speech_recognition/infer.py | 4 ++-- fairseq/checkpoint_utils.py | 3 +++ fairseq/criterions/ctc.py | 4 ++-- fairseq/dataclass/data_class.py | 15 +++++++++++---- fairseq/dataclass/utils.py | 10 +++++++++- fairseq_cli/eval_lm.py | 6 +++--- fairseq_cli/generate.py | 10 +++++----- fairseq_cli/interactive.py | 6 +++--- 13 files changed, 52 insertions(+), 34 deletions(-) diff --git a/examples/criss/save_encoder.py b/examples/criss/save_encoder.py index 4d0f17f0f2..d911d066e3 100644 --- a/examples/criss/save_encoder.py +++ b/examples/criss/save_encoder.py @@ -133,7 +133,7 @@ def main(args): sample, prefix_tokens, src_dict, - args.remove_bpe, + args.post_process, has_langtok=encoder_has_langtok, ) if all_avg_pool is not None: @@ -158,7 +158,7 @@ def main(args): ) else: if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) + src_str = src_dict.string(src_tokens, args.post_process) else: src_str = "" diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py index 13036926e0..b5ffd1ca34 100644 --- a/examples/noisychannel/rerank.py +++ b/examples/noisychannel/rerank.py @@ -295,7 +295,7 @@ def load_score_files(args): predictions_bpe_file = args.nbest_list gen_output = rerank_utils.BitextOutputFromGen( predictions_bpe_file, - bpe_symbol=args.remove_bpe, + bpe_symbol=args.post_process, nbest=using_nbest, prefix_len=args.prefix_len, target_prefix_frac=args.target_prefix_frac, @@ -308,7 +308,7 @@ def load_score_files(args): score1_file, args.backwards1, args.right_to_left1, - 
args.remove_bpe, + args.post_process, args.prefix_len, args.target_prefix_frac, args.source_prefix_frac, @@ -322,7 +322,7 @@ def load_score_files(args): score2_file, args.backwards2, args.right_to_left2, - args.remove_bpe, + args.post_process, args.prefix_len, args.target_prefix_frac, args.source_prefix_frac, @@ -346,7 +346,7 @@ def load_score_files(args): lm_score_file, args.lm_dict, args.prefix_len, - args.remove_bpe, + args.post_process, args.target_prefix_frac, ) else: diff --git a/examples/noisychannel/rerank_generate.py b/examples/noisychannel/rerank_generate.py index 4356b3387e..d512088de8 100644 --- a/examples/noisychannel/rerank_generate.py +++ b/examples/noisychannel/rerank_generate.py @@ -163,7 +163,7 @@ def gen_and_reprocess_nbest(args): gen_output = rerank_utils.BitextOutputFromGen( predictions_bpe_file, - bpe_symbol=args.remove_bpe, + bpe_symbol=args.post_process, nbest=using_nbest, prefix_len=args.prefix_len, target_prefix_frac=args.target_prefix_frac, @@ -248,7 +248,7 @@ def gen_and_reprocess_nbest(args): pre_gen + rescore_file + "." + args.source_lang, pre_gen + rescore_file + "." + args.target_lang, pre_gen + "/reference_file", - bpe_symbol=args.remove_bpe, + bpe_symbol=args.post_process, ) if args.prefix_len is not None: bw_rescore_file = prefix_len_rescore_file @@ -260,7 +260,7 @@ def gen_and_reprocess_nbest(args): pre_gen + prefix_len_rescore_file + "." + args.target_lang, pre_gen + "/reference_file", prefix_len=args.prefix_len, - bpe_symbol=args.remove_bpe, + bpe_symbol=args.post_process, ) elif args.target_prefix_frac is not None: bw_rescore_file = target_prefix_frac_rescore_file @@ -277,7 +277,7 @@ def gen_and_reprocess_nbest(args): + "." + args.target_lang, pre_gen + "/reference_file", - bpe_symbol=args.remove_bpe, + bpe_symbol=args.post_process, target_prefix_frac=args.target_prefix_frac, ) else: @@ -298,7 +298,7 @@ def gen_and_reprocess_nbest(args): + "." 
+ args.target_lang, pre_gen + "/reference_file", - bpe_symbol=args.remove_bpe, + bpe_symbol=args.post_process, source_prefix_frac=args.source_prefix_frac, ) else: @@ -313,7 +313,7 @@ def gen_and_reprocess_nbest(args): pre_gen + "/right_to_left_rescore_data." + args.target_lang, pre_gen + "/right_to_left_reference_file", right_to_left=True, - bpe_symbol=args.remove_bpe, + bpe_symbol=args.post_process, ) print("STEP 3: binarize the translations") diff --git a/examples/noisychannel/rerank_options.py b/examples/noisychannel/rerank_options.py index ca7a2e0a61..de91939e66 100644 --- a/examples/noisychannel/rerank_options.py +++ b/examples/noisychannel/rerank_options.py @@ -64,7 +64,7 @@ def add_reranking_args(parser): help='whether the first model group is a right to left model') group.add_argument('--right-to-left2', action='store_true', help='whether the second model group is a right to left model') - group.add_argument('--remove-bpe', '--post-process', default='@@ ', + group.add_argument('--post-process', '--remove-bpe', default='@@ ', help='the bpe symbol, used for the bitext and LM') group.add_argument('--prefix-len', default=None, type=int, help='the length of the target prefix to use in rescoring (in terms of words wo bpe)') diff --git a/examples/noisychannel/rerank_score_lm.py b/examples/noisychannel/rerank_score_lm.py index fa3aa64462..89ebf61cce 100644 --- a/examples/noisychannel/rerank_score_lm.py +++ b/examples/noisychannel/rerank_score_lm.py @@ -37,7 +37,7 @@ def score_lm(args): predictions_bpe_file = args.nbest_list gen_output = rerank_utils.BitextOutputFromGen( - predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest + predictions_bpe_file, bpe_symbol=args.post_process, nbest=using_nbest ) if args.language_model is not None: diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index cc04225035..1570177cc6 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -121,7 
+121,7 @@ def process_predictions( if "words" in hypo: hyp_words = " ".join(hypo["words"]) else: - hyp_words = post_process(hyp_pieces, args.remove_bpe) + hyp_words = post_process(hyp_pieces, args.post_process) if res_files is not None: print( @@ -134,7 +134,7 @@ def process_predictions( ) tgt_pieces = tgt_dict.string(target_tokens) - tgt_words = post_process(tgt_pieces, args.remove_bpe) + tgt_words = post_process(tgt_pieces, args.post_process) if res_files is not None: print( diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index c036e12966..f8a5855622 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -451,6 +451,9 @@ def _upgrade_state_dict(state): state["extra_state"]["train_iterator"].get("epoch", 1), 1 ) + if hasattr(state["args"], "remove_bpe"): + state["args"].post_process = state["args"].remove_bpe + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 9310024f29..6b77ce47eb 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -22,7 +22,7 @@ def __init__(self, args, task): self.blank_idx = task.target_dictionary.bos() self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() - self.post_process = args.remove_bpe if args.remove_bpe else "letter" + self.post_process = args.post_process if args.post_process else "letter" if args.wer_args is not None: from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder @@ -57,8 +57,8 @@ def add_args(parser): ) try: parser.add_argument( - "--remove-bpe", "--post-process", + "--remove-bpe", default="letter", help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)", ) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index b0c17ba0ad..bcb802e651 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ 
-296,7 +296,11 @@ class DatasetConfig(FairseqDataclass): default=None, metadata={"help": "maximum number of tokens in a batch"} ) batch_size: Optional[int] = field( - default=None, metadata={"help": "number of examples in a batch"} + default=None, + metadata={ + "help": "number of examples in a batch", + "argparse_alias": "--max-sentences", + }, ) required_batch_size_multiple: int = field( default=8, metadata={"help": "batch size will be a multiplier of this value"} @@ -349,7 +353,8 @@ class DatasetConfig(FairseqDataclass): batch_size_valid: Optional[int] = field( default=None, metadata={ - "help": "batch size of the validation batch" " (defaults to --batch-size)" + "help": "batch size of the validation batch" " (defaults to --batch-size)", + "argparse_alias": "--max-sentences-valid", }, ) curriculum: int = field( @@ -720,11 +725,13 @@ class CommonEvalConfig(FairseqDataclass): default=None, metadata={"help": "path(s) to model file(s), colon separated"}, ) - remove_bpe: Optional[str] = field( + post_process: Optional[str] = field( default=None, metadata={ - "help": "remove BPE tokens before scoring (can be set to sentencepiece)", + "help": "post-process text by removing pre-processing such as BPE, letter segmentation, etc " + "(valid options are: sentencepiece, wordpiece, letter, _EOW, none, otherwise treated as BPE symbol)", "argparse_const": "@@ ", + "argparse_alias": "--remove-bpe", }, ) quiet: bool = field(default=False, metadata={"help": "only print final scores"}) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index f4431db82a..9c501c5b00 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -92,6 +92,9 @@ def _get_help(self, attribute_name: str) -> Any: def _get_argparse_const(self, attribute_name: str) -> Any: return self._get_meta(attribute_name, "argparse_const") + def _get_argparse_alias(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_alias") + def _get_choices(self, 
attribute_name: str) -> Any: return self._get_meta(attribute_name, "choices") @@ -197,6 +200,11 @@ def get_kwargs_from_dc( continue kwargs = get_kwargs_from_dc(dataclass_instance, k) + field_args = [field_name] + alias = dataclass_instance._get_argparse_alias(k) + if alias is not None: + field_args.append(alias) + if "default" in kwargs: if isinstance(kwargs["default"], str) and kwargs["default"].startswith( "${" @@ -205,7 +213,7 @@ def get_kwargs_from_dc( if delete_default: del kwargs["default"] try: - parser.add_argument(field_name, **kwargs) + parser.add_argument(*field_args, **kwargs) except ArgumentError: pass diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 4621a66acd..b70c0d3a77 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -159,11 +159,11 @@ def main(cfg: DictConfig, override_args=None, **unused_kwargs): score_sum = 0.0 count = 0 - if cfg.common_eval.remove_bpe is not None: - if cfg.common_eval.remove_bpe == "sentencepiece": + if cfg.common_eval.post_process is not None: + if cfg.common_eval.post_process == "sentencepiece": raise NotImplementedError else: - bpe_cont = cfg.common_eval.remove_bpe.rstrip() + bpe_cont = cfg.common_eval.post_process.rstrip() bpe_toks = { i for i in range(len(task.source_dictionary)) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 6a6f7465cb..f7260e125e 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -236,13 +236,13 @@ def decode_fn(x): ) else: if src_dict is not None: - src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) else: src_str = "" if has_target: target_str = tgt_dict.string( target_tokens, - cfg.common_eval.remove_bpe, + cfg.common_eval.post_process, escape_unk=True, extra_symbols_to_ignore=get_symbols_to_strip_from_output( generator @@ -267,7 +267,7 @@ def decode_fn(x): alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - 
remove_bpe=cfg.common_eval.remove_bpe, + remove_bpe=cfg.common_eval.post_process, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) @@ -336,7 +336,7 @@ def decode_fn(x): # Score only the top hypothesis if has_target and j == 0: - if align_dict is not None or cfg.common_eval.remove_bpe is not None: + if align_dict is not None or cfg.common_eval.post_process is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line( target_str, add_if_not_exist=True @@ -367,7 +367,7 @@ def decode_fn(x): ) if has_target: if cfg.bpe and not cfg.generation.sacrebleu: - if cfg.common_eval.remove_bpe: + if cfg.common_eval.post_process: logger.warning( "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" ) diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 2feff950ab..85607d8f44 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -260,13 +260,13 @@ def decode_fn(x): # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): if src_dict is not None: - src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) print("S-{}\t{}".format(id_, src_str)) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) for constraint in info["constraints"]: print( "C-{}\t{}".format( - id_, tgt_dict.string(constraint, cfg.common_eval.remove_bpe) + id_, tgt_dict.string(constraint, cfg.common_eval.post_process) ) ) @@ -278,7 +278,7 @@ def decode_fn(x): alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - remove_bpe=cfg.common_eval.remove_bpe, + remove_bpe=cfg.common_eval.post_process, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = 
decode_fn(hypo_str) From 2409d5a36e074fe237a734fa2053867fe62b5e01 Mon Sep 17 00:00:00 2001 From: alexeib Date: Fri, 23 Oct 2020 00:05:52 -0700 Subject: [PATCH 245/707] refactor dataclass related files, add proper types for static checkin… (#1371) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: - refactor dataclass/ hierarchy to make it a bit more sane (while avoiding circular references) - add top level FairseqConfig - change typehints to reflect the correct config type if it is known Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1371 Reviewed By: myleott Differential Revision: D24469026 Pulled By: alexeib fbshipit-source-id: 01f68918f761d51ec5216286b8959ad35f41a7b2 --- fairseq/checkpoint_utils.py | 7 +- fairseq/criterions/adaptive_loss.py | 4 +- fairseq/data/indexed_dataset.py | 4 +- fairseq/dataclass/__init__.py | 3 +- .../dataclass/{data_class.py => configs.py} | 242 +++++++----------- fairseq/dataclass/constants.py | 24 +- fairseq/dataclass/initialize.py | 48 ++++ fairseq/dataclass/utils.py | 179 +++++++------ fairseq/distributed_utils.py | 16 +- fairseq/model_parallel/megatron_trainer.py | 4 +- fairseq/optim/bmuf.py | 34 +-- fairseq/options.py | 2 +- fairseq/scoring/tokenizer.py | 2 +- fairseq/trainer.py | 4 +- fairseq_cli/eval_lm.py | 2 +- fairseq_cli/generate.py | 2 +- fairseq_cli/interactive.py | 6 +- fairseq_cli/train.py | 2 +- fairseq_cli/validate.py | 2 +- 19 files changed, 306 insertions(+), 281 deletions(-) rename fairseq/dataclass/{data_class.py => configs.py} (84%) create mode 100644 fairseq/dataclass/initialize.py diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index f8a5855622..fdee84c181 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -13,6 +13,7 @@ from typing import Optional, Union import torch +from fairseq.dataclass.configs import CheckpointConfig,
FairseqConfig from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, overwrite_args_by_name, @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) -def save_checkpoint(cfg: DictConfig, trainer, epoch_itr, val_loss): +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): from fairseq import meters # only one worker should attempt to create the required dir @@ -130,7 +131,7 @@ def is_better(a, b): os.remove(old_chk) -def load_checkpoint(cfg: DictConfig, trainer, **passthrough_args): +def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): """ Load a checkpoint and restore the training iterator. @@ -339,7 +340,7 @@ def torch_persistent_save(obj, f): def save_state( filename, - cfg: DictConfig, + cfg: FairseqConfig, model_state_dict, criterion, optimizer, diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 04832295ec..15ad9a15bf 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -11,7 +11,7 @@ from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.constants import DDP_BACKEND_CHOICES -from omegaconf import II, DictConfig +from omegaconf import II @dataclass @@ -31,7 +31,7 @@ def __init__(self, task, sentence_avg): self.sentence_avg = sentence_avg @classmethod - def build_criterion(cls, cfg: DictConfig, task): + def build_criterion(cls, cfg: AdaptiveLossConfig, task): if cfg.ddp_backend == "c10d": raise Exception( "AdaptiveLoss is not compatible with the c10d " diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 3efecab3a6..827754d848 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -3,13 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import os import shutil import struct from functools import lru_cache import numpy as np import torch +from fairseq.dataclass.constants import DATASET_IMPL_CHOICES from fairseq.data.fasta_dataset import FastaDataset from fairseq.file_io import PathManager @@ -24,7 +24,7 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ["raw", "lazy", "cached", "mmap", "fasta"] + return list(map(str, DATASET_IMPL_CHOICES)) def infer_dataset_impl(path): diff --git a/fairseq/dataclass/__init__.py b/fairseq/dataclass/__init__.py index 32870814d5..5c9004d3ba 100644 --- a/fairseq/dataclass/__init__.py +++ b/fairseq/dataclass/__init__.py @@ -3,7 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .utils import ChoiceEnum, FairseqDataclass +from .configs import FairseqDataclass +from .constants import ChoiceEnum __all__ = ["FairseqDataclass", "ChoiceEnum"] diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/configs.py similarity index 84% rename from fairseq/dataclass/data_class.py rename to fairseq/dataclass/configs.py index bcb802e651..abcb9c4c48 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/configs.py @@ -3,15 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import logging import sys -from argparse import Namespace from dataclasses import _MISSING_TYPE, dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, List, Optional import torch -from fairseq.data.indexed_dataset import get_available_dataset_impl + from fairseq.dataclass.constants import ( + DATASET_IMPL_CHOICES, DDP_BACKEND_CHOICES, DISTRIBUTED_WRAPPER_CHOICES, GENERATION_CONSTRAINTS_CHOICES, @@ -20,16 +19,64 @@ PIPELINE_CHECKPOINT_CHOICES, ZERO_SHARDING_CHOICES, ) -from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass -from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY -from fairseq.optim.bmuf import FairseqBMUFConfig -from fairseq.registry import REGISTRIES -from fairseq.tasks import TASK_DATACLASS_REGISTRY -from hydra.core.config_store import ConfigStore + from omegaconf import II -logger = logging.getLogger(__name__) +@dataclass +class FairseqDataclass: + """fairseq base dataclass that supported fetching attributes and metas""" + + _name: Optional[str] = None + + @staticmethod + def name(): + return None + + def _get_all_attributes(self) -> List[str]: + return [k for k in self.__dataclass_fields__.keys()] + + def _get_meta( + self, attribute_name: str, meta: str, default: Optional[Any] = None + ) -> Any: + return self.__dataclass_fields__[attribute_name].metadata.get(meta, default) + + def _get_name(self, attribute_name: str) -> str: + return self.__dataclass_fields__[attribute_name].name + + def _get_default(self, attribute_name: str) -> Any: + if hasattr(self, attribute_name): + if str(getattr(self, attribute_name)).startswith("${"): + return str(getattr(self, attribute_name)) + elif str(self.__dataclass_fields__[attribute_name].default).startswith( + "${" + ): + return str(self.__dataclass_fields__[attribute_name].default) + elif ( + getattr(self, attribute_name) + != self.__dataclass_fields__[attribute_name].default + ): + return getattr(self, attribute_name) + + f = 
self.__dataclass_fields__[attribute_name] + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + def _get_type(self, attribute_name: str) -> Any: + return self.__dataclass_fields__[attribute_name].type + + def _get_help(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "help") + + def _get_argparse_const(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_const") + + def _get_argparse_alias(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_alias") + + def _get_choices(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "choices") @dataclass @@ -311,7 +358,7 @@ class DatasetConfig(FairseqDataclass): "help": "maximum sequence length in batch will be a multiplier of this value" }, ) - dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = field( + dataset_impl: Optional[DATASET_IMPL_CHOICES] = field( default=None, metadata={"help": "output dataset implementation"} ) data_buffer_size: int = field( @@ -529,6 +576,33 @@ class CheckpointConfig(FairseqDataclass): distributed_rank: int = II("distributed_training.distributed_rank") +@dataclass +class FairseqBMUFConfig(FairseqDataclass): + block_lr: float = field( + default=1, metadata={"help": "block learning rate for bmuf"} + ) + block_momentum: float = field( + default=0.875, metadata={"help": "block momentum for bmuf"} + ) + global_sync_iter: int = field( + default=50, metadata={"help": "Iteration for syncing global model"} + ) + warmup_iterations: int = field( + default=500, metadata={"help": "warmup iterations for model to broadcast"} + ) + use_nbm: bool = field( + default=False, + metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"}, + ) + average_sync: bool = field( + default=False, + metadata={ + "help": "Specify whether you want to average the local momentum after each sync" + }, + ) + distributed_world_size: int 
= II("distributed_training.distributed_world_size") + + @dataclass class GenerationConfig(FairseqDataclass): beam: int = field( @@ -788,135 +862,15 @@ class InteractiveConfig(FairseqDataclass): ) -CONFIGS = { - "common": CommonConfig, - "common_eval": CommonEvalConfig, - "distributed_training": DistributedTrainingConfig, - "dataset": DatasetConfig, - "optimization": OptimizationConfig, - "checkpoint": CheckpointConfig, - "bmuf": FairseqBMUFConfig, - "generation": GenerationConfig, - "eval_lm": EvalLMConfig, - "interactive": InteractiveConfig, -} - - -def register_module_dataclass( - cs: ConfigStore, registry: Dict[str, Any], group: str -) -> None: - """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" - # note that if `group == model`, we register all model archs, not the model name. - for k, v in registry.items(): - node_ = v() - node_._name = k - cs.store(name=k, group=group, node=node_, provider="fairseq") - - -def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: - """cs: config store instance, register common training configs""" - - for k, v in CONFIGS.items(): - try: - cs.store(name=k, node=v()) - except BaseException: - logger.error(f"{k} - {v()}") - raise - - register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") - register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") - - for k, v in REGISTRIES.items(): - register_module_dataclass(cs, v["dataclass_registry"], k) - - -def _override_attr( - sub_node: str, data_class: Type[FairseqDataclass], args: Namespace -) -> List[str]: - overrides = [] - - def get_default(f): - if not isinstance(f.default_factory, _MISSING_TYPE): - return f.default_factory() - return f.default - - for k, v in data_class.__dataclass_fields__.items(): - if k.startswith("_"): - # private member, skip - continue - - val = get_default(v) if not hasattr(args, k) else getattr(args, k) - - if val is None: - overrides.append("{}.{}=null".format(sub_node, 
k)) - elif val == "": - overrides.append("{}.{}=''".format(sub_node, k)) - elif isinstance(val, str): - overrides.append("{}.{}='{}'".format(sub_node, k, val)) - else: - overrides.append("{}.{}={}".format(sub_node, k, val)) - return overrides - - -def migrate_registry( - name, value, registry, args, overrides, deletes, use_name_as_val=False -): - if value in registry: - overrides.append("{}={}".format(name, value)) - overrides.append("{}._name={}".format(name, value)) - overrides.extend(_override_attr(name, registry[value], args)) - elif use_name_as_val and value is not None: - overrides.append("{}={}".format(name, value)) - else: - deletes.append(name) - - -def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: - """use the field in args to overrides those in cfg""" - overrides = [] - deletes = [] - - for k, v in CONFIGS.items(): - overrides.extend(_override_attr(k, v, args)) - - if args is not None: - if hasattr(args, "task"): - migrate_registry( - "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes - ) - else: - deletes.append("task") - - # these options will be set to "None" if they have not yet been migrated - # so we can populate them with the entire flat args - CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} - - for k, v in REGISTRIES.items(): - if hasattr(args, k): - migrate_registry( - k, - getattr(args, k), - v["dataclass_registry"], - args, - overrides, - deletes, - use_name_as_val=k not in CORE_REGISTRIES, - ) - else: - deletes.append(k) - - no_dc = True - if hasattr(args, "arch"): - if args.arch in ARCH_MODEL_REGISTRY: - m_cls = ARCH_MODEL_REGISTRY[args.arch] - dc = getattr(m_cls, "__dataclass", None) - if dc is not None: - overrides.append("model={}".format(args.arch)) - overrides.append("model._name={}".format(args.arch)) - # override model params with those exist in args - overrides.extend(_override_attr("model", dc, args)) - no_dc = False - if no_dc: - deletes.append("model") - - return overrides, 
deletes +@dataclass +class FairseqConfig(object): + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + optimization: OptimizationConfig = OptimizationConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + bmuf: FairseqBMUFConfig = FairseqBMUFConfig() + generation: GenerationConfig = GenerationConfig() + eval_lm: EvalLMConfig = EvalLMConfig() + interactive: InteractiveConfig = InteractiveConfig() diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 2fd87f5fc4..3eb63ec609 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -3,13 +3,33 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq.dataclass.utils import ChoiceEnum +from enum import Enum +from typing import List + + +class StrEnum(Enum): + def __str__(self): + return self.value + + def __eq__(self, other: str): + return self.value == other + + def __repr__(self): + return self.value + + +def ChoiceEnum(choices: List[str]): + """return the Enum class used to enforce list of choices""" + return StrEnum("Choices", {k: k for k in choices}) LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) +DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"]) DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) -GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(["unigram", "ensemble", "vote", "dp", "bs"]) +GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum( + ["unigram", "ensemble", "vote", "dp", "bs"] +) ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) diff --git 
a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py new file mode 100644 index 0000000000..1f755d9807 --- /dev/null +++ b/fairseq/dataclass/initialize.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import Dict, Any + +from hydra.core.config_store import ConfigStore + +from fairseq.dataclass.configs import FairseqConfig + +from fairseq.models import MODEL_DATACLASS_REGISTRY +from fairseq.tasks import TASK_DATACLASS_REGISTRY +from fairseq.registry import REGISTRIES + + +logger = logging.getLogger(__name__) + + +def register_module_dataclass( + cs: ConfigStore, registry: Dict[str, Any], group: str +) -> None: + """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" + # note that if `group == model`, we register all model archs, not the model name. 
+ for k, v in registry.items(): + node_ = v() + node_._name = k + cs.store(name=k, group=group, node=node_, provider="fairseq") + + +def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: + """cs: config store instance, register common training configs""" + + for k in FairseqConfig.__dataclass_fields__: + v = FairseqConfig.__dataclass_fields__[k].default + try: + cs.store(name=k, node=v) + except BaseException: + logger.error(f"{k} - {v}") + raise + + register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") + register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") + + for k, v in REGISTRIES.items(): + register_module_dataclass(cs, v["dataclass_registry"], k) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 9c501c5b00..8dc51c01f5 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -5,10 +5,12 @@ import ast from argparse import ArgumentError, ArgumentParser, Namespace -from dataclasses import _MISSING_TYPE, MISSING, dataclass +from dataclasses import _MISSING_TYPE, MISSING from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Tuple, Type +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import FairseqConfig from hydra.core.global_hydra import GlobalHydra from hydra.experimental import compose, initialize from omegaconf import DictConfig, OmegaConf, open_dict @@ -27,78 +29,6 @@ def eval_str_list(x, x_type=float): return [x_type(x)] -class StrEnum(Enum): - def __str__(self): - return self.value - - def __eq__(self, other: str): - return self.value == other - - def __repr__(self): - return self.value - - -def ChoiceEnum(choices: List[str]): - """return the Enum class used to enforce list of choices""" - return StrEnum("Choices", {k: k for k in choices}) - - -@dataclass -class FairseqDataclass: - """fairseq base dataclass that supported fetching attributes and metas""" - - _name: Optional[str] = None - - 
@staticmethod - def name(): - return None - - def _get_all_attributes(self) -> List[str]: - return [k for k in self.__dataclass_fields__.keys()] - - def _get_meta( - self, attribute_name: str, meta: str, default: Optional[Any] = None - ) -> Any: - return self.__dataclass_fields__[attribute_name].metadata.get(meta, default) - - def _get_name(self, attribute_name: str) -> str: - return self.__dataclass_fields__[attribute_name].name - - def _get_default(self, attribute_name: str) -> Any: - if hasattr(self, attribute_name): - if str(getattr(self, attribute_name)).startswith("${"): - return str(getattr(self, attribute_name)) - elif str(self.__dataclass_fields__[attribute_name].default).startswith( - "${" - ): - return str(self.__dataclass_fields__[attribute_name].default) - elif ( - getattr(self, attribute_name) - != self.__dataclass_fields__[attribute_name].default - ): - return getattr(self, attribute_name) - - f = self.__dataclass_fields__[attribute_name] - if not isinstance(f.default_factory, _MISSING_TYPE): - return f.default_factory() - return f.default - - def _get_type(self, attribute_name: str) -> Any: - return self.__dataclass_fields__[attribute_name].type - - def _get_help(self, attribute_name: str) -> Any: - return self._get_meta(attribute_name, "help") - - def _get_argparse_const(self, attribute_name: str) -> Any: - return self._get_meta(attribute_name, "argparse_const") - - def _get_argparse_alias(self, attribute_name: str) -> Any: - return self._get_meta(attribute_name, "argparse_alias") - - def _get_choices(self, attribute_name: str) -> Any: - return self._get_meta(attribute_name, "choices") - - def gen_parser_from_dataclass( parser: ArgumentParser, dataclass_instance: FairseqDataclass, @@ -241,8 +171,107 @@ def _set_legacy_defaults(args, cls): setattr(args, key, default_value) +def _override_attr( + sub_node: str, data_class: Type[FairseqDataclass], args: Namespace +) -> List[str]: + overrides = [] + + def get_default(f): + if not 
isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): + # private member, skip + continue + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + if val is None: + overrides.append("{}.{}=null".format(sub_node, k)) + elif val == "": + overrides.append("{}.{}=''".format(sub_node, k)) + elif isinstance(val, str): + overrides.append("{}.{}='{}'".format(sub_node, k, val)) + else: + overrides.append("{}.{}={}".format(sub_node, k, val)) + return overrides + + +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) + + +def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: + """use the field in args to overrides those in cfg""" + overrides = [] + deletes = [] + + for k in FairseqConfig.__dataclass_fields__.keys(): + overrides.extend( + _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args) + ) + + if args is not None: + if hasattr(args, "task"): + from fairseq.tasks import TASK_DATACLASS_REGISTRY + + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes + ) + else: + deletes.append("task") + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + from fairseq.registry import REGISTRIES + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], + args, + overrides, + deletes, + use_name_as_val=k not in 
CORE_REGISTRIES, + ) + else: + deletes.append(k) + + no_dc = True + if hasattr(args, "arch"): + from fairseq.models import ARCH_MODEL_REGISTRY + + if args.arch in ARCH_MODEL_REGISTRY: + m_cls = ARCH_MODEL_REGISTRY[args.arch] + dc = getattr(m_cls, "__dataclass", None) + if dc is not None: + overrides.append("model={}".format(args.arch)) + overrides.append("model._name={}".format(args.arch)) + # override model params with those exist in args + overrides.extend(_override_attr("model", dc, args)) + no_dc = False + if no_dc: + deletes.append("model") + + return overrides, deletes + + def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: - from fairseq.dataclass.data_class import override_module_args # Here we are using field values provided in args to override counterparts inside config object overrides, deletes = override_module_args(args) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 23cdfc6938..83b6d4d9d6 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -18,18 +18,18 @@ import torch import torch.distributed as dist from fairseq import utils -from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from omegaconf import DictConfig, open_dict +from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig +from omegaconf import open_dict logger = logging.getLogger(__name__) -def is_master(cfg: DictConfig): +def is_master(cfg: DistributedTrainingConfig): return cfg.distributed_rank == 0 -def infer_init_method(cfg: DictConfig, force_distributed=False): +def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): if cfg.distributed_init_method is not None or cfg.tpu: return @@ -197,8 +197,10 @@ def infer_init_method(cfg: DictConfig, force_distributed=False): ) -def distributed_init(cfg: DictConfig): +def distributed_init(cfg: FairseqConfig): if isinstance(cfg, Namespace): + from fairseq.dataclass.utils import convert_namespace_to_omegaconf + cfg = 
convert_namespace_to_omegaconf(cfg) if not cfg.common.tpu: @@ -265,7 +267,7 @@ def distributed_init(cfg: DictConfig): return cfg.distributed_training.distributed_rank -def distributed_main(i, main, cfg: DictConfig, kwargs): +def distributed_main(i, main, cfg: FairseqConfig, kwargs): cfg.distributed_training.device_id = i if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu: torch.cuda.set_device(cfg.distributed_training.device_id) @@ -281,7 +283,7 @@ def distributed_main(i, main, cfg: DictConfig, kwargs): main(cfg, **kwargs) -def call_main(cfg: DictConfig, main, **kwargs): +def call_main(cfg: FairseqConfig, main, **kwargs): if cfg.distributed_training.distributed_init_method is None: infer_init_method(cfg.distributed_training) diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 258551c933..cf83685862 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -9,7 +9,7 @@ from fairseq import distributed_utils from fairseq.trainer import Trainer -from omegaconf import DictConfig +from fairseq.dataclass.configs import FairseqConfig try: @@ -29,7 +29,7 @@ class MegatronTrainer(Trainer): """Main class for model parallel with data parallel training.""" - def __init__(self, cfg: DictConfig, task, model, criterion, **kwargs): + def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs): if not has_megatron_submodule: raise ImportError( "\n\nPlease install the megatron submodule:" diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py index 55f225ba6a..d6d0e04e86 100644 --- a/fairseq/optim/bmuf.py +++ b/fairseq/optim/bmuf.py @@ -7,39 +7,9 @@ import torch import torch.distributed as dist -from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import FairseqBMUFConfig from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.fairseq_optimizer import FairseqOptimizer -from omegaconf import 
II, DictConfig - - -@dataclass -class FairseqBMUFConfig(FairseqDataclass): - block_lr: float = field( - default=1, metadata={"help": "block learning rate for bmuf"} - ) - block_momentum: float = field( - default=0.875, metadata={"help": "block momentum for bmuf"} - ) - global_sync_iter: int = field( - default=50, metadata={"help": "Iteration for syncing global model"} - ) - warmup_iterations: int = field( - default=500, metadata={"help": "warmup iterations for model to broadcast"} - ) - use_nbm: bool = field( - default=False, - metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"}, - ) - average_sync: bool = field( - default=False, - metadata={ - "help": "Specify whether you want to average the local momentum after each sync" - }, - ) - distributed_world_size: int = II( - "distributed_training.distributed_world_size" - ) class FairseqBMUF(FairseqOptimizer): @@ -52,7 +22,7 @@ class FairseqBMUF(FairseqOptimizer): model-update filtering """ - def __init__(self, cfg: DictConfig, optimizer): + def __init__(self, cfg: FairseqBMUFConfig, optimizer): super().__init__(cfg) self._optimizer = optimizer self._num_updates = 0 diff --git a/fairseq/options.py b/fairseq/options.py index 58e5e46190..f2a3e7cfb1 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -9,7 +9,7 @@ import torch from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl -from fairseq.dataclass.data_class import ( +from fairseq.dataclass.configs import ( CheckpointConfig, CommonConfig, CommonEvalConfig, diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py index 0d0702bf15..61cf6d4a7c 100644 --- a/fairseq/scoring/tokenizer.py +++ b/fairseq/scoring/tokenizer.py @@ -5,7 +5,7 @@ import unicodedata -from fairseq.dataclass.utils import ChoiceEnum +from fairseq.dataclass import ChoiceEnum class EvaluationTokenizer(object): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 8b00e8b431..a4d273ca67 100644 --- 
a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -17,12 +17,12 @@ import torch from fairseq import checkpoint_utils, distributed_utils, models, optim, utils +from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.file_io import PathManager from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler -from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ class Trainer(object): communication of the gradients across workers. """ - def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None): + def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): if isinstance(cfg, Namespace): logger.warning( diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index b70c0d3a77..a34e5e096e 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -16,7 +16,7 @@ import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset -from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index f7260e125e..82c23a3776 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -19,7 +19,7 @@ import torch from fairseq import checkpoint_utils, options, scoring, tasks, utils from fairseq.data import encoders -from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, 
TimeMeter diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 85607d8f44..6921f551ca 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -21,13 +21,13 @@ import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders -from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output from hydra.core.config_store import ConfigStore from hydra.experimental import initialize -from omegaconf import DictConfig logging.basicConfig( @@ -115,7 +115,7 @@ def encode_fn_target(x): ) -def main(cfg: DictConfig): +def main(cfg: FairseqConfig): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 4c00761060..7e4e4e3071 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -32,7 +32,7 @@ from fairseq.model_parallel.megatron_trainer import MegatronTrainer from omegaconf import DictConfig from hydra.experimental import initialize -from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.trainer import Trainer diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 368c9cb581..90c2b84f73 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -13,7 +13,7 @@ import torch from fairseq import checkpoint_utils, distributed_utils, options, utils -from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import metrics, 
progress_bar from hydra.core.config_store import ConfigStore From 4b0cf6649bc65093fe2d091ecdd4150bc00ec64f Mon Sep 17 00:00:00 2001 From: Shashank Jain Date: Fri, 23 Oct 2020 14:43:59 -0700 Subject: [PATCH 246/707] Revert "Fix deprecated usage of nonzero()" Summary: Reverting the diff because it has already been fixed in https://github.com/pytorch/pytorch/pull/45413 Reviewed By: myleott Differential Revision: D24511658 fbshipit-source-id: a5561dae50d69a03443ca8a60bebe2cd064e3ee0 --- fairseq/data/mask_tokens_dataset.py | 2 +- fairseq/models/bart/hub_interface.py | 2 +- fairseq/models/roberta/hub_interface.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 9e2c7119d8..8ea86245f7 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -112,7 +112,7 @@ def __getitem__(self, index: int): if self.mask_whole_words is not None: word_begins_mask = self.mask_whole_words.gather(0, item) - word_begins_idx = word_begins_mask.nonzero(as_tuple=False).view(-1) + word_begins_idx = word_begins_mask.nonzero().view(-1) sz = len(word_begins_idx) words = np.split(word_begins_mask, word_begins_idx)[1:] assert len(words) == sz diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 819ea8eeda..4c5fd0b585 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -69,7 +69,7 @@ def decode(self, tokens: torch.LongTensor): tokens = tokens[1:] # remove eos_mask = tokens == self.task.source_dictionary.eos() doc_mask = eos_mask[1:] & eos_mask[:-1] - sentences = np.split(tokens, doc_mask.nonzero(as_tuple=False)[0] + 1) + sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) sentences = [ self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences ] diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index 
d6322c30e8..0c723f06dd 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -71,7 +71,7 @@ def decode(self, tokens: torch.LongTensor): tokens = tokens[1:] # remove eos_mask = tokens == self.task.source_dictionary.eos() doc_mask = eos_mask[1:] & eos_mask[:-1] - sentences = np.split(tokens, doc_mask.nonzero(as_tuple=False)[0] + 1) + sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) sentences = [ self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences ] @@ -173,7 +173,7 @@ def fill_mask(self, masked_input: str, topk: int = 5): add_if_not_exist=False, ) - masked_index = (tokens == self.task.mask_idx).nonzero(as_tuple=False) + masked_index = (tokens == self.task.mask_idx).nonzero() if tokens.dim() == 1: tokens = tokens.unsqueeze(0) From c147060598f69385a1c2c05bc97dd43b56d73575 Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 24 Oct 2020 10:20:07 -0700 Subject: [PATCH 247/707] add new w2v models (#1373) Summary: update readme to add new wav2vec models (incl w/ self training) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1373 Reviewed By: michaelauli Differential Revision: D24524182 Pulled By: alexeib fbshipit-source-id: c918971f8009b11855908e71bfcc247cf6776a8f --- examples/wav2vec/README.md | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 518d8f86cb..22d2225fcd 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -2,6 +2,8 @@ wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). +We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). 
+ ## Pre-trained models Model | Finetuning split | Dataset | Model @@ -14,10 +16,15 @@ Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) -Wav2Vec 2.0 Large (LV-60) | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox.pt) -Wav2Vec 2.0 Large (LV-60) | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m.pt) -Wav2Vec 2.0 Large (LV-60) | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h.pt) -Wav2Vec 2.0 Large (LV-60) | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h.pt) +Wav2Vec 2.0 Large (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 100 hours | 
[Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_pl.pt) + +\* updated (Oct. 
24, 2020) ## Training a new model with the CLI tools From 6ee0364685fca0ac5cc2721b193f396166a32646 Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 24 Oct 2020 21:18:44 -0700 Subject: [PATCH 248/707] fix building components when no configuration is provided (#1374) Summary: see title, in particular fixes evaluating generate.py with --scoring wer Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1374 Reviewed By: kahne Differential Revision: D24527059 Pulled By: alexeib fbshipit-source-id: b01994441fda12eafd4e465d147047c6e84a8335 --- fairseq/registry.py | 2 ++ fairseq/scoring/__init__.py | 4 ---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fairseq/registry.py b/fairseq/registry.py index 4446084d4a..96994cb8d4 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -31,6 +31,8 @@ def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs) choice = cfg._name elif isinstance(cfg, str): choice = cfg + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice]() else: choice = getattr(cfg, registry_name, None) if choice in DATACLASS_REGISTRY: diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 8c706cb585..9163be87e7 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -18,10 +18,6 @@ def __init__(self, cfg): self.ref = [] self.pred = [] - @staticmethod - def add_args(parser): - pass - def add_string(self, ref, pred): self.ref.append(ref) self.pred.append(pred) From 3c414780837dd3506ea82a868ea92628d1fdd576 Mon Sep 17 00:00:00 2001 From: alexeib Date: Sun, 25 Oct 2020 12:53:07 -0700 Subject: [PATCH 249/707] fix loading emissions (#1375) Summary: broken in last change to infer.py Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1375 Reviewed By: xuqiantong Differential Revision: D24531499 Pulled By: alexeib fbshipit-source-id: fab60abf67a05c48e1ff750fac3ab6d4c0fa2770 --- examples/speech_recognition/infer.py | 2 ++ 1 file 
changed, 2 insertions(+) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 1570177cc6..68889463f4 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -269,6 +269,7 @@ def main(args, task=None, model_state=None): # Load ensemble if args.load_emissions: models, criterions = [], [] + task = tasks.setup_task(args) else: logger.info("| loading model(s) from {}".format(args.path)) models, criterions, task = load_models_and_criterions( @@ -282,6 +283,7 @@ def main(args, task=None, model_state=None): # Load dataset splits task.load_dataset(args.gen_subset) + # Set dictionary tgt_dict = task.target_dictionary From 81677d751de120f69eef0c3eb36e849c977f7814 Mon Sep 17 00:00:00 2001 From: Vladimir Smirnov Date: Mon, 26 Oct 2020 08:17:12 -0700 Subject: [PATCH 250/707] Update README.md (#2796) Summary: Fixed link. # Before submitting - [-] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [+] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [+] Did you make sure to update the docs? - [-] Did you write any new necessary tests? ## What does this PR do? Fixes link. 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2796 Reviewed By: nlaptev Differential Revision: D24538759 Pulled By: myleott fbshipit-source-id: af947f432c34ca2aec35c9fe59dd1214e363450b --- examples/wav2vec/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 22d2225fcd..1da42f388a 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -22,7 +22,7 @@ Wav2Vec 2.0 Large (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebo Wav2Vec 2.0 Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt) Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) -Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) \* updated (Oct. 
24, 2020) From beeac0ad68594b07594f565bfd5cb6f4f46cd816 Mon Sep 17 00:00:00 2001 From: Shruti Bhosale Date: Tue, 27 Oct 2020 02:13:47 -0700 Subject: [PATCH 251/707] Get 12B M2M-100 model generation to work correctly on exactly 2 32gb gpus (#1366) Summary: # What does this PR do? Addresses https://github.com/pytorch/fairseq/issues/2772 where external users can't generate using the model because the README is currently not accurate. This PR fixes the issues in the README Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1366 Reviewed By: edunov Differential Revision: D24455634 Pulled By: shruti-bh fbshipit-source-id: 480a11f8b95d1278162d585700e58d467a35d35a --- examples/m2m_100/README.md | 39 ++++++++++++----- .../pipeline_parallel_transformer/model.py | 42 +++++++------------ 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/examples/m2m_100/README.md b/examples/m2m_100/README.md index a87c0f5748..0bacd4c8b1 100644 --- a/examples/m2m_100/README.md +++ b/examples/m2m_100/README.md @@ -116,11 +116,25 @@ If you use any of the resources listed here, please cite: ## Trained Models -Looking for other trained models? Check back soon. +More models coming up soon. -Model | Description | Download ----|---|--- -`12b_last_checkpoint` | 12B parameter model trained on many-to-many training data for 100 languages | [12b_last_checkpoint](https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt) +### 12B Model +12B parameter model trained on many-to-many training data for 100 languages. We include the last checkpoint, average of last 5 checkpoints, average of last 10 checkpoints. There isn't a universally best choice out of these three, but all three versions are pretty close in accuracy. You can either sweep over the 3 checkpoints on a dev test and use the best performing checkpoint for final testing. Or the last checkpoint can be a good default choice. 
+ +**Model Download Links** +Configuration | 2 32GB GPUs | 4 16GB GPUs | 6 12GB GPUs | 8 8GB GPUs +:--|:--|:--|:--|:-- +Last Checkpoint | [12b_last_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_2_gpus.pt) | [12b_last_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_4_gpus.pt) | [12b_last_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_6_gpus.pt) | [12b_last_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_8_gpus.pt) +Average of last 5 checkpoints | [12b_avg5_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_2_gpus.pt) | [12b_avg5_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_4_gpus.pt) | [12b_avg5_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_6_gpus.pt) | [12b_avg5_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_8_gpus.pt) +Average of last 10 checkpoints | [12b_avg10_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_2_gpus.pt) | [12b_avg10_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_4_gpus.pt) | [12b_avg10_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_6_gpus.pt) | [12b_avg10_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_8_gpus.pt) + +**Generation Arguments** +Configuration | 2 32GB GPUs | 4 16GB GPUs | 6 12GB GPUs | 8 8GB GPUs +:--|:--|:--|:--|:-- +`--pipeline-encoder-balance` | `[26]` | `[1,15,10]` | `[1,9,9,7]` | `[1,6,6,6,7]` +`--pipeline-encoder-devices` | `[0]` | `[0,1,0]` | `[0,1,2,0]` | `[0,4,5,1,0]` +`--pipeline-decoder-balance` | `[3,22,1]` | `[3,11,11,1]` | `[3,7,7,8,1]` | `[1,6,6,6,6,1]` +`--pipeline-decoder-devices` | `[0,1,0]` | `[0,2,3,0]` | `[0,3,4,5,0]` | `[0,2,6,7,3,0]` ## SentencePiece Model @@ -162,16 +176,19 @@ fairseq-preprocess \ --srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt ``` -### Generation on a V100 GPU +### Generation for the 12B model + +Note that generation can currently be run using 2 
32GB / 4 16GB / 6 12GB / 8 8GB GPUs, and the corresponding model checkpoints and pipeline arguments can be found in the [12B Model Section](#12b-model). +Generation on CPUs will be added in the future. ```bash wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs.txt -wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt +wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_4_gpus.pt fairseq-generate \ data_bin \ --batch-size 1 \ - --path 12b_last_checkpoint.pt \ + --path 12b_last_chk_4_gpus.pt \ --fixed-dictionary model_dict.128k.txt \ -s de -t fr \ --remove-bpe 'sentencepiece' \ @@ -185,10 +202,10 @@ fairseq-generate \ --distributed-world-size 1 --distributed-no-spawn \ --pipeline-model-parallel \ --pipeline-chunks 1 \ - --pipeline-encoder-balance '[26]' \ - --pipeline-encoder-devices '[0]' \ - --pipeline-decoder-balance '[1,24,1]' \ - --pipeline-decoder-devices '[0,1,0]' > gen_out + --pipeline-encoder-balance '[1,15,10]' \ + --pipeline-encoder-devices '[0,1,0]' \ + --pipeline-decoder-balance '[3,11,11,1]' \ + --pipeline-decoder-devices '[0,2,3,0]' > gen_out ``` ## Evaluation with M2M-100 diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index 76cfe3b0b4..7873611214 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -111,9 +111,9 @@ def prepare_for_inference_(self, cfg): decoder_module_list.append(module) module_count += 1 self.model = None - self.encoder = TransformerEncoder(cfg.model, None, None, encoder_module_list) + self.encoder = TransformerEncoder(cfg.distributed_training, None, None, encoder_module_list) self.decoder = TransformerDecoder( - cfg.model, None, None, decoder_module_list=decoder_module_list + cfg.distributed_training, None, None, 
decoder_module_list=decoder_module_list ) @staticmethod @@ -320,7 +320,7 @@ def max_decoder_positions(self): """Maximum length supported by the decoder.""" return self.decoder_max_positions - def load_state_dict(self, state_dict, strict=True, cfg=None): + def load_state_dict(self, state_dict, strict=True, model_cfg=None): """Copies parameters and buffers from *state_dict* into this module and its descendants. @@ -429,17 +429,16 @@ def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): from fairscale.nn import Pipe except ImportError: raise ImportError("Please install fairscale with: pip install fairscale") - if encoder_module_list is None: - embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) - layers = [TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + self.use_pipeline = encoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) + self.encoder_layers = nn.Sequential(*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]) if isinstance(embed_tokens, nn.ModuleList): emb_dim = sum(e.embedding_dim for e in embed_tokens) else: emb_dim = embed_tokens.embedding_dim - final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim) - encoder_module_list = [embedding_layer] + layers + [final_layer_norm] - self.use_pipeline = getattr(args, "pipeline_encoder_balance", None) is not None - if self.use_pipeline: + self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim) + else: encoder_balance = utils.eval_str_list( args.pipeline_encoder_balance, type=int ) @@ -457,10 +456,6 @@ def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): chunks=args.pipeline_chunks, checkpoint=args.pipeline_checkpoint, ) - else: - self.embedding_layer = encoder_module_list[0] - self.encoder_layers = nn.Sequential(*encoder_module_list[1:-1]) - self.final_layer_norm = encoder_module_list[-1] def forward(self, src_tokens, 
src_lengths): """ @@ -570,18 +565,17 @@ def __init__( from fairscale.nn import Pipe except ImportError: raise ImportError("Please install fairscale with: pip install fairscale") - if decoder_module_list is None: - embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) - layers = [ + self.use_pipeline = decoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) + self.decoder_layers = nn.Sequential(*[ TransformerDecoderLayer(args, no_encoder_attn) for _ in range(args.decoder_layers) - ] - decoder_output_layer = TransformerDecoderOutputLayer( + ]) + self.decoder_output_layer = TransformerDecoderOutputLayer( args, embed_tokens, dictionary ) - decoder_module_list = [embedding_layer] + layers + [decoder_output_layer] - self.use_pipeline = getattr(args, "pipeline_decoder_balance", None) is not None - if self.use_pipeline: + else: decoder_balance = utils.eval_str_list( args.pipeline_decoder_balance, type=int ) @@ -599,10 +593,6 @@ def __init__( chunks=args.pipeline_chunks, checkpoint=args.pipeline_checkpoint, ) - else: - self.embedding_layer = decoder_module_list[0] - self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1]) - self.decoder_output_layer = decoder_module_list[-1] def forward( self, From 01be083e46d2e4614dc274b0edf29d0ddd516186 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 27 Oct 2020 07:45:13 -0700 Subject: [PATCH 252/707] Centralize hydra init (and support packaged location of configs) (#2784) Summary: Configs can either be in `/fairseq/configs` (once the package is installed) or `/configs` (if using an editable installation). This centralizes the hydra init and supports these two possible config locations. 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2784 Reviewed By: alexeib Differential Revision: D24513586 Pulled By: myleott fbshipit-source-id: 8e10a88177ebcf809d5d37d448d2b384142febef --- fairseq/__init__.py | 4 ++++ fairseq/dataclass/initialize.py | 4 ++-- fairseq/dataclass/utils.py | 19 +++++++++++-------- fairseq_cli/eval_lm.py | 6 ------ fairseq_cli/generate.py | 6 ------ fairseq_cli/interactive.py | 6 ------ fairseq_cli/train.py | 6 ------ fairseq_cli/validate.py | 6 ------ 8 files changed, 17 insertions(+), 40 deletions(-) diff --git a/fairseq/__init__.py b/fairseq/__init__.py index 4ccfc90257..ccd45add79 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -23,6 +23,10 @@ sys.modules["fairseq.metrics"] = metrics sys.modules["fairseq.progress_bar"] = progress_bar +# initialize hydra +from fairseq.dataclass.initialize import hydra_init +hydra_init() + import fairseq.criterions # noqa import fairseq.models # noqa import fairseq.modules # noqa diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index 1f755d9807..b762af990f 100644 --- a/fairseq/dataclass/initialize.py +++ b/fairseq/dataclass/initialize.py @@ -30,8 +30,8 @@ def register_module_dataclass( cs.store(name=k, group=group, node=node_, provider="fairseq") -def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: - """cs: config store instance, register common training configs""" +def hydra_init() -> None: + cs = ConfigStore.instance() for k in FairseqConfig.__dataclass_fields__: v = FairseqConfig.__dataclass_fields__[k].default diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 8dc51c01f5..5ce017d765 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import ast +import os from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING from enum import Enum @@ -272,19 +273,21 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + """Convert a flat argparse.Namespace to a structured DictConfig.""" # Here we are using field values provided in args to override counterparts inside config object overrides, deletes = override_module_args(args) - cfg_name = "config" - cfg_path = f"../../{cfg_name}" + # configs will be in fairseq/config after installation + config_path = os.path.join("..", "config") + if not os.path.exists(config_path): + # in case of "--editable" installs we need to go one dir up + config_path = os.path.join("..", "..", "config") - if not GlobalHydra().is_initialized(): - initialize(config_path=cfg_path) - - composed_cfg = compose(cfg_name, overrides=overrides, strict=False) - for k in deletes: - composed_cfg[k] = None + with initialize(config_path=config_path, strict=True): + composed_cfg = compose("config", overrides=overrides, strict=False) + for k in deletes: + composed_cfg[k] = None cfg = OmegaConf.create( OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index a34e5e096e..efc9c4b5aa 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -16,13 +16,10 @@ import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset -from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer -from hydra.core.config_store import ConfigStore -from hydra.experimental import initialize from omegaconf 
import DictConfig @@ -288,7 +285,4 @@ def cli_main(): if __name__ == "__main__": - cs = ConfigStore.instance() - register_hydra_cfg(cs) - initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 82c23a3776..79b9ed8bf0 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -19,12 +19,9 @@ import torch from fairseq import checkpoint_utils, options, scoring, tasks, utils from fairseq.data import encoders -from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter -from hydra.core.config_store import ConfigStore -from hydra.experimental import initialize from omegaconf import DictConfig @@ -393,7 +390,4 @@ def cli_main(): if __name__ == "__main__": - cs = ConfigStore.instance() - register_hydra_cfg(cs) - initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 6921f551ca..530830d6b0 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -22,12 +22,9 @@ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders from fairseq.dataclass.configs import FairseqConfig -from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output -from hydra.core.config_store import ConfigStore -from hydra.experimental import initialize logging.basicConfig( @@ -322,7 +319,4 @@ def cli_main(): if __name__ == "__main__": - cs = ConfigStore.instance() - register_hydra_cfg(cs) - initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/train.py 
b/fairseq_cli/train.py index 7e4e4e3071..ec10028f03 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -16,7 +16,6 @@ import numpy as np import torch -from hydra.core.config_store import ConfigStore from fairseq import ( checkpoint_utils, @@ -31,8 +30,6 @@ from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from omegaconf import DictConfig -from hydra.experimental import initialize -from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.trainer import Trainer @@ -353,7 +350,4 @@ def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] if __name__ == '__main__': - cs = ConfigStore.instance() - register_hydra_cfg(cs) - initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 90c2b84f73..7315b14e2a 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -13,11 +13,8 @@ import torch from fairseq import checkpoint_utils, distributed_utils, options, utils -from fairseq.dataclass.initialize import register_hydra_cfg from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import metrics, progress_bar -from hydra.core.config_store import ConfigStore -from hydra.experimental import initialize from omegaconf import DictConfig @@ -140,7 +137,4 @@ def cli_main(): if __name__ == "__main__": - cs = ConfigStore.instance() - register_hydra_cfg(cs) - initialize(config_path="../config", strict=True) cli_main() From 1bc83c703ad70d7f62c1e54b197e29b95d07b1f0 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 27 Oct 2020 11:24:58 -0700 Subject: [PATCH 253/707] Misc fixes (#2786) Summary: - Rename type -> key in fairseq/tasks/sentence_prediction.py (fixes https://github.com/pytorch/fairseq/issues/2746) - Update preprocessing docs (fixes https://github.com/pytorch/fairseq/issues/2565) - Turn off logging in test_fp16_optimizer.TestGradientScaling - 
Documentation updates - Remove some unused code - Fix noisychannel example (fixes https://github.com/pytorch/fairseq/issues/2213) Pull Request resolved: https://github.com/pytorch/fairseq/pull/2786 Reviewed By: shruti-bh Differential Revision: D24515146 Pulled By: myleott fbshipit-source-id: 86b0f5516c57610fdca801c60e58158ef052fc3a --- docs/getting_started.rst | 15 ++++++++++++--- examples/noisychannel/rerank.py | 2 +- examples/noisychannel/rerank_generate.py | 2 +- examples/noisychannel/rerank_score_bw.py | 2 +- examples/noisychannel/rerank_score_lm.py | 2 +- examples/noisychannel/rerank_tune.py | 2 +- examples/roberta/README.md | 1 - fairseq/dataclass/configs.py | 2 +- fairseq/models/roberta/model.py | 5 ++++- fairseq/modules/transformer_layer.py | 9 --------- fairseq/options.py | 8 +++++--- fairseq/tasks/sentence_prediction.py | 10 +++++----- tests/test_fp16_optimizer.py | 5 +++++ 13 files changed, 37 insertions(+), 28 deletions(-) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index fa5971dd31..d227b95544 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -170,13 +170,14 @@ The easiest way to launch jobs is with the `torch.distributed.launch For example, to train a large English-German Transformer model on 2 nodes each with 8 GPUs (in total 16 GPUs), run the following command on each node, -replacing ``node_rank=0`` with ``node_rank=1`` on the second node: +replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making +sure to update ``--master_addr`` to the IP address of the first node: .. 
code-block:: console > python -m torch.distributed.launch --nproc_per_node=8 \ --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \ - --master_port=1234 \ + --master_port=12345 \ $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \ --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ @@ -184,7 +185,15 @@ replacing ``node_rank=0`` with ``node_rank=1`` on the second node: --lr 0.0005 --min-lr 1e-09 \ --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --max-tokens 3584 \ - --fp16 --distributed-no-spawn + --fp16 + +On SLURM clusters, fairseq will automatically detect the number of nodes and +GPUs, but a port number must be provided: + +.. code-block:: console + + > salloc --gpus=16 --nodes 2 (...) + > srun fairseq-train --distributed-port 12345 (...). Sharding very large datasets ---------------------------- diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py index b5ffd1ca34..bb80d11a67 100644 --- a/examples/noisychannel/rerank.py +++ b/examples/noisychannel/rerank.py @@ -11,7 +11,7 @@ from fairseq.data import dictionary from fairseq.scoring import bleu -from . import ( +from examples.noisychannel import ( rerank_generate, rerank_options, rerank_score_bw, diff --git a/examples/noisychannel/rerank_generate.py b/examples/noisychannel/rerank_generate.py index d512088de8..daeeae059a 100644 --- a/examples/noisychannel/rerank_generate.py +++ b/examples/noisychannel/rerank_generate.py @@ -15,7 +15,7 @@ from fairseq import options from fairseq_cli import generate, preprocess -from . 
import rerank_options, rerank_utils +from examples.noisychannel import rerank_options, rerank_utils def gen_and_reprocess_nbest(args): diff --git a/examples/noisychannel/rerank_score_bw.py b/examples/noisychannel/rerank_score_bw.py index 895673b1cc..b0bc913651 100644 --- a/examples/noisychannel/rerank_score_bw.py +++ b/examples/noisychannel/rerank_score_bw.py @@ -9,7 +9,7 @@ from fairseq import options from fairseq_cli import generate -from . import rerank_options, rerank_utils +from examples.noisychannel import rerank_options, rerank_utils def score_bw(args): diff --git a/examples/noisychannel/rerank_score_lm.py b/examples/noisychannel/rerank_score_lm.py index 89ebf61cce..e80948d78b 100644 --- a/examples/noisychannel/rerank_score_lm.py +++ b/examples/noisychannel/rerank_score_lm.py @@ -7,7 +7,7 @@ from fairseq import options -from . import rerank_options, rerank_utils +from examples.noisychannel import rerank_options, rerank_utils def score_lm(args): diff --git a/examples/noisychannel/rerank_tune.py b/examples/noisychannel/rerank_tune.py index 1be71744a3..b2e8b7594a 100644 --- a/examples/noisychannel/rerank_tune.py +++ b/examples/noisychannel/rerank_tune.py @@ -9,7 +9,7 @@ import numpy as np from fairseq import options -from . 
import rerank, rerank_options +from examples.noisychannel import rerank, rerank_options def random_search(args): diff --git a/examples/roberta/README.md b/examples/roberta/README.md index fdddd5b8d2..ca86131eea 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -276,7 +276,6 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples)) - [Finetuning on custom classification tasks (e.g., IMDB)](README.custom_classification.md) - [Finetuning on Winograd Schema Challenge (WSC)](wsc/README.md) - [Finetuning on Commonsense QA (CQA)](commonsense_qa/README.md) -- Finetuning on SQuAD: coming soon ## Pretraining using your own data diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index abcb9c4c48..484d2526d7 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -400,7 +400,7 @@ class DatasetConfig(FairseqDataclass): batch_size_valid: Optional[int] = field( default=None, metadata={ - "help": "batch size of the validation batch" " (defaults to --batch-size)", + "help": "batch size of the validation batch (defaults to --batch-size)", "argparse_alias": "--max-sentences-valid", }, ) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 5c9f92a149..0f6efe5b33 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -393,6 +393,9 @@ class RobertaEncoder(FairseqEncoder): def __init__(self, args, dictionary): super().__init__(dictionary) + + # set any missing default values + base_architecture(args) self.args = args if args.encoder_layers_to_keep: @@ -417,7 +420,6 @@ def __init__(self, args, dictionary): q_noise=args.quant_noise_pq, qn_block_size=args.quant_noise_pq_block_size, ) - args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) self.lm_head = RobertaLMHead( embed_dim=args.encoder_embed_dim, @@ -495,6 +497,7 @@ def base_architecture(args): args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) 
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) args.spectral_norm_classification_head = getattr( args, "spectral_norm_classification_head", False ) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 8775aa7766..6f3c79de7c 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -144,7 +144,6 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): residual = x if self.normalize_before: x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) x = self.activation_dropout_module(x) x = self.fc2(x) @@ -413,11 +412,3 @@ def forward( def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn - - -def Linear(in_features, out_features, bias=True): - m = nn.Linear(in_features, out_features, bias) - nn.init.xavier_uniform_(m.weight) - if bias: - nn.init.constant_(m.bias, 0.0) - return m diff --git a/fairseq/options.py b/fairseq/options.py index f2a3e7cfb1..b79443a177 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -249,11 +249,13 @@ def add_preprocess_args(parser): group.add_argument("-t", "--target-lang", default=None, metavar="TARGET", help="target language") group.add_argument("--trainpref", metavar="FP", default=None, - help="train file prefix") + help="train file prefix (also used to build dictionaries)") group.add_argument("--validpref", metavar="FP", default=None, - help="comma separated, valid file prefixes") + help="comma separated, valid file prefixes " + "(words missing from train set are replaced with <unk>)") group.add_argument("--testpref", metavar="FP", default=None, - help="comma separated, test file prefixes") + help="comma separated, test file prefixes " + "(words missing from train set are replaced with <unk>)") 
group.add_argument("--align-suffix", metavar="FP", default=None, help="alignment file suffix") group.add_argument("--destdir", metavar="DIR", default="data-bin", diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 69dc996e6a..0ec3824d04 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -135,11 +135,11 @@ def setup_task(cls, args, **kwargs): def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" - def get_path(type, split): - return os.path.join(self.args.data, type, split) + def get_path(key, split): + return os.path.join(self.args.data, key, split) - def make_dataset(type, dictionary): - split_path = get_path(type, split) + def make_dataset(key, dictionary): + split_path = get_path(key, split) dataset = data_utils.load_indexed_dataset( split_path, @@ -151,7 +151,7 @@ def make_dataset(type, dictionary): input0 = make_dataset("input0", self.source_dictionary) assert input0 is not None, "could not find dataset: {}".format( - get_path(type, split) + get_path("input0", split) ) input1 = make_dataset("input1", self.source_dictionary) diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py index aa6a863d32..8de8e28ce0 100644 --- a/tests/test_fp16_optimizer.py +++ b/tests/test_fp16_optimizer.py @@ -5,6 +5,7 @@ import argparse import copy +import logging import unittest import torch @@ -46,6 +47,10 @@ def setUp(self): }, } ) + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) def run_iter(self, model, params, optimizer): optimizer.zero_grad() From 3c726544d240f610cd35ea264d893d6a6ada074a Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Wed, 28 Oct 2020 14:52:54 -0700 Subject: [PATCH 254/707] fix issue where is_initialized is not available in single-worker paradigm (#2801) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/1205 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2801 Reviewed By: alexeib Differential Revision: D24579193 Pulled By: myleott fbshipit-source-id: bcb14bb588d4538398bff4114e0a387fd29818c5 --- fairseq/distributed_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 83b6d4d9d6..0d5804c8f7 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -204,7 +204,7 @@ def distributed_init(cfg: FairseqConfig): cfg = convert_namespace_to_omegaconf(cfg) if not cfg.common.tpu: - if torch.distributed.is_initialized(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): warnings.warn( "Distributed is already initialized, cannot initialize twice!" 
) From f6d9313092cf3bc5fa289123b6062b22e463a7da Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 28 Oct 2020 14:58:37 -0700 Subject: [PATCH 255/707] fix eval lm (#1380) Summary: fixes eval lm that wasn't parsing arguments correctly Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1380 Reviewed By: myleott Differential Revision: D24600415 Pulled By: alexeib fbshipit-source-id: eb56575bef4d20a3cd5cee3dcd279046f085d938 --- fairseq_cli/eval_lm.py | 8 ++------ fairseq_cli/validate.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index efc9c4b5aa..1197d6987b 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -14,7 +14,7 @@ from argparse import Namespace import torch -from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils +from fairseq import checkpoint_utils, distributed_utils, options, utils from fairseq.data import LMContextWindowDataset from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar @@ -277,11 +277,7 @@ def cli_main(): parser = options.get_eval_lm_parser() args = options.parse_args_and_arch(parser) - # only override args that are explicitly given on the command line - override_parser = options.get_validation_parser() - override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) - - distributed_utils.call_main(args, main, override_args=override_args) + distributed_utils.call_main(convert_namespace_to_omegaconf(args), main) if __name__ == "__main__": diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 7315b14e2a..a1e577ed7a 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -133,7 +133,7 @@ def cli_main(): override_parser = options.get_validation_parser() override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) - distributed_utils.call_main(args, main, override_args=override_args) + 
distributed_utils.call_main(convert_namespace_to_omegaconf(args), main, override_args=override_args) if __name__ == "__main__": From 65b02d529a45f687da8bbc6ec37611b8a9c96297 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 28 Oct 2020 17:16:56 -0700 Subject: [PATCH 256/707] fix wav2vec infer and finetuning (#1384) Summary: Fixes #2807, #2810, #2519 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1384 Reviewed By: myleott Differential Revision: D24605451 Pulled By: alexeib fbshipit-source-id: 46ec8f273ac2fab86bd444461e2706c35608b250 --- examples/speech_recognition/w2l_decoder.py | 2 +- fairseq/models/wav2vec/wav2vec2_asr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 2a1d8a779d..f760cd6df2 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -34,7 +34,6 @@ SmearingMode, Trie, LexiconDecoder, - LexiconFreeDecoder, ) except: warnings.warn( @@ -404,6 +403,7 @@ def __init__(self, args, tgt_dict): self.unit_lm, ) else: + from wav2letter.decoder import LexiconFreeDecoder self.decoder = LexiconFreeDecoder( self.decoder_opts, self.lm, self.silence, self.blank, [] ) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 52ca9a8007..1cbc6374fb 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -328,7 +328,7 @@ def __init__(self, args, tgt_dict=None): state = checkpoint_utils.load_checkpoint_to_cpu( args.w2v_path, arg_overrides ) - w2v_args = state["args"] + w2v_args = state.get("args", None) or state["cfg"].model else: state = None w2v_args = args.w2v_args From e4e01780f8a087f4a215199ddb83caca2dea16e7 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 28 Oct 2020 18:17:17 -0700 Subject: [PATCH 257/707] Fix dummy LM task (#1381) Summary: Pull Request resolved: 
https://github.com/fairinternal/fairseq-py/pull/1381 Reviewed By: alexeib Differential Revision: D24603479 Pulled By: myleott fbshipit-source-id: 5aae8da9c0f20d6526c98b0b37bf9b32a8c78393 --- fairseq/benchmark/dummy_lm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index 6429d04de3..f3146b3581 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -28,6 +28,12 @@ def add_args(parser): help="max number of total tokens over all segments " "per sample for BERT dataset", ) + parser.add_argument("--add-bos-token", action="store_true", help="unused") + parser.add_argument( + "--max-target-positions", + default=None, + help="max number of tokens in the target sequence", + ) def __init__(self, args, dictionary): super().__init__(args) From 9c66ff54c4acd8fa3280a9a5ab6d5fe58d1a2cf3 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 28 Oct 2020 18:19:37 -0700 Subject: [PATCH 258/707] =?UTF-8?q?build=5Fgenerator()=20in=20generator.py?= =?UTF-8?q?=20should=20accept=20cfg.generation=20instea=E2=80=A6=20(#2813)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …d of cfg.task # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2813 Reviewed By: alexeib Differential Revision: D24604698 Pulled By: myleott fbshipit-source-id: e41996147203ec47274ded803bab910460a19eb3 --- fairseq_cli/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 79b9ed8bf0..021f819ed7 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -163,7 +163,7 @@ def _main(cfg: DictConfig, output_file): extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} generator = task.build_generator( - models, cfg.task, extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) # Handle tokenization and BPE From b7d8b9dce2dd5ca6a76e1c6f540945da20922478 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 28 Oct 2020 18:27:53 -0700 Subject: [PATCH 259/707] fix architecture params (#1382) Summary: fixes architectures not getting applied to migrated models Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1382 Reviewed By: myleott Differential Revision: D24603110 Pulled By: alexeib fbshipit-source-id: 18f44d3736853282466feed5e8896db95338b097 --- fairseq/models/fairseq_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 15c2c4ab2e..0c8d106be5 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -40,7 +40,8 @@ def add_args(cls, parser): """Add model-specific arguments to the parser.""" dc = getattr(cls, "__dataclass", None) if dc is not None: - gen_parser_from_dataclass(parser, dc()) + # do not set defaults so that settings defaults from various architectures still works + gen_parser_from_dataclass(parser, dc(), delete_default=True) @classmethod def build_model(cls, args, task): From 4cdc81f6f1b16cbb6e1016e3d06a6e4962edcec0 Mon
Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 28 Oct 2020 18:34:13 -0700 Subject: [PATCH 260/707] Support activation checkpointing in Transformer (#1378) Summary: Without activation checkpointing (peak GPU memory usage: 7138MiB) ``` $ python train.py --task dummy_mt --arch transformer --dropout 0.1 --max-tokens 4096 --optimizer adam --lr 0.00001 --log-format simple --log-interval 25 --fp16 (...) 2020-10-28 08:03:03 | INFO | train_inner | epoch 001: 25 / 92 loss=12.67, ppl=6517.2, wps=281380, ups=8.61, wpb=32640, bsz=1088, num_updates=25, lr=1e-05, gnorm=8.541, clip=0, loss_scale=128, train_wall=5, wall=10 2020-10-28 08:03:05 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-10-28 08:03:06 | INFO | train_inner | epoch 001: 51 / 92 loss=8.938, ppl=490.52, wps=302975, ups=9.28, wpb=32640, bsz=1088, num_updates=50, lr=1e-05, gnorm=6.395, clip=0, loss_scale=64, train_wall=3, wall=12 2020-10-28 08:03:08 | INFO | train_inner | epoch 001: 76 / 92 loss=3.855, ppl=14.47, wps=316039, ups=9.68, wpb=32640, bsz=1088, num_updates=75, lr=1e-05, gnorm=9.078, clip=0, loss_scale=64, train_wall=3, wall=15 2020-10-28 08:03:10 | INFO | fairseq_cli.train | begin validation on "valid" subset 2020-10-28 08:03:17 | INFO | valid | epoch 001 | valid on 'valid' subset | loss 0.048 | ppl 1.03 | wps 1.09646e+06 | wpb 32640 | bsz 1088 | num_updates 91 ``` With activation checkpointing (peak GPU memory usage: 6466MiB) ``` $ python train.py --checkpoint-activations --task dummy_mt --arch transformer --dropout 0.1 --max-tokens 4096 --optimizer adam --lr 0.00001 --log-format simple --log-interval 25 --fp16 (...) 
2020-10-28 08:01:50 | INFO | train_inner | epoch 001: 25 / 92 loss=12.67, ppl=6517.22, wps=291110, ups=8.91, wpb=32640, bsz=1088, num_updates=25, lr=1e-05, gnorm=8.541, clip=0, loss_scale=128, train_wall=4, wall=9 2020-10-28 08:01:51 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-10-28 08:01:52 | INFO | train_inner | epoch 001: 51 / 92 loss=8.938, ppl=490.54, wps=295438, ups=9.05, wpb=32640, bsz=1088, num_updates=50, lr=1e-05, gnorm=6.394, clip=0, loss_scale=64, train_wall=3, wall=12 2020-10-28 08:01:55 | INFO | train_inner | epoch 001: 76 / 92 loss=3.855, ppl=14.47, wps=308351, ups=9.45, wpb=32640, bsz=1088, num_updates=75, lr=1e-05, gnorm=9.082, clip=0, loss_scale=64, train_wall=3, wall=14 2020-10-28 08:01:57 | INFO | fairseq_cli.train | begin validation on "valid" subset ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1378 Reviewed By: min-xu-ai Differential Revision: D24593170 Pulled By: myleott fbshipit-source-id: 701254e603a2277d22f8b3bcc3ebbade54bb7479 --- fairseq/models/transformer.py | 15 +- fairseq/models/transformer_lm.py | 4 + fairseq/modules/checkpoint_activations.py | 205 ++++++++++++++++++++++ fairseq/utils.py | 53 ++++-- 4 files changed, 258 insertions(+), 19 deletions(-) create mode 100644 fairseq/modules/checkpoint_activations.py diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f87fa50d29..7614c33f74 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -27,6 +27,7 @@ TransformerDecoderLayer, TransformerEncoderLayer, ) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from torch import Tensor @@ -151,6 +152,9 @@ def add_args(parser): help='add layernorm to embedding') parser.add_argument('--no-scale-embedding', action='store_true', help='if True, dont scale embeddings') + parser.add_argument('--checkpoint-activations', 
action='store_true', + help='checkpoint activations at each layer, which saves GPU ' + 'memory usage at the cost of some additional compute') # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) parser.add_argument('--no-cross-attention', default=False, action='store_true', help='do not perform cross-attention') @@ -362,7 +366,10 @@ def __init__(self, args, dictionary, embed_tokens): self.layer_norm = None def build_encoder_layer(self, args): - return TransformerEncoderLayer(args) + layer = TransformerEncoderLayer(args) + if getattr(args, "checkpoint_activations", False): + layer = checkpoint_wrapper(layer) + return layer def forward_embedding( self, src_tokens, token_embedding: Optional[torch.Tensor] = None @@ -649,7 +656,10 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): ) def build_decoder_layer(self, args, no_encoder_attn=False): - return TransformerDecoderLayer(args, no_encoder_attn) + layer = TransformerDecoderLayer(args, no_encoder_attn) + if getattr(args, "checkpoint_activations", False): + layer = checkpoint_wrapper(layer) + return layer def forward( self, @@ -961,6 +971,7 @@ def base_architecture(args): args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index df809bdb19..9467b25efd 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -141,6 +141,9 @@ class TransformerLanguageModelConfig(FairseqDataclass): no_scale_embedding: bool = field( default=False, metadata={"help": "if True, dont scale 
embeddings"} ) + checkpoint_activations: bool = field( + default=False, metadata={"help": "checkpoint activations at each layer"} + ) quant_noise_pq: float = field( default=0.0, metadata={"help": "iterative PQ quantization noise at training time"}, @@ -304,6 +307,7 @@ def base_lm_architecture(args): args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) @register_model_architecture("transformer_lm", "transformer_lm_big") diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py new file mode 100644 index 0000000000..a4341fe742 --- /dev/null +++ b/fairseq/modules/checkpoint_activations.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Tuple, Union + +import torch + +from fairseq import utils + + +def checkpoint_wrapper(m): + """ + A friendlier wrapper for performing activation checkpointing. + + Compared to the PyTorch version, this version: + - wraps an nn.Module, so that all subsequent calls will use checkpointing + - handles keyword arguments in the forward + - handles non-Tensor outputs from the forward + + Usage:: + + checkpointed_module = checkpoint_wrapper(my_module) + a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) + """ + original_forward = m.forward + + def _checkpointed_forward(*args, **kwargs): + # Autograd Functions in PyTorch work best with positional args, since + # the backward must return gradients (or None) for every input argument. + # We can flatten keyword arguments to make this easier. 
+ kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + parent_ctx_dict = {} + output = CheckpointFunction.apply( + original_forward, parent_ctx_dict, kwarg_keys, *flat_args + ) + if isinstance(output, torch.Tensor): + return output + else: + packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] + if packed_non_tensor_outputs: + output = unpack_non_tensors(output, packed_non_tensor_outputs) + return output + + m.forward = _checkpointed_forward + return m + + +def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: + """ + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == [1, 2] + assert kwargs == {"a": 3, "b": 4} + """ + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + return kwarg_keys, flat_args + + +def unpack_kwargs( + kwarg_keys: List[str], flat_args: List[Any] +) -> Tuple[List[Any], Dict[str, Any]]: + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} + return args, kwargs + + +def split_non_tensors( + mixed: Union[torch.Tensor, Tuple[Any]] +) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]: + """ + Usage:: + + x = torch.Tensor([1]) + y = torch.Tensor([2]) + tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y, None, 3) + """ + if isinstance(mixed, torch.Tensor): + return (mixed,), None + assert isinstance(mixed, tuple) + tensors = [] + packed_non_tensors = {"is_tensor": [], "objects": []} + for o in mixed: + if isinstance(o, torch.Tensor): + packed_non_tensors["is_tensor"].append(True) + tensors.append(o) + else: + packed_non_tensors["is_tensor"].append(False) + packed_non_tensors["objects"].append(o) + return tuple(tensors), packed_non_tensors + + +def 
unpack_non_tensors( + tensors: Tuple[torch.Tensor], + packed_non_tensors: Dict[str, List[Any]], +) -> Tuple[Any]: + if packed_non_tensors is None: + return tensors + assert isinstance(packed_non_tensors, dict) + mixed = [] + is_tensor_list = packed_non_tensors["is_tensor"] + objects = packed_non_tensors["objects"] + assert len(tensors) + len(objects) == len(is_tensor_list) + obj_i = tnsr_i = 0 + for is_tensor in is_tensor_list: + if is_tensor: + mixed.append(tensors[tnsr_i]) + tnsr_i += 1 + else: + mixed.append(objects[obj_i]) + obj_i += 1 + return tuple(mixed) + + +class CheckpointFunction(torch.autograd.Function): + """Similar to the torch version, but support non-Tensor outputs. + + The caller is expected to provide a dict (*parent_ctx_dict*) that will hold + the non-Tensor outputs. These should be combined with the Tensor *outputs* + by calling ``unpack_non_tensors``. + """ + + @staticmethod + def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): + if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation + torch.utils.checkpoint.check_backward_validity(args) + + ctx.run_function = run_function + ctx.kwarg_keys = kwarg_keys + ctx.fwd_rng_state = utils.get_rng_state() + + tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) + ctx.save_for_backward(*tensor_inputs) + ctx.packed_non_tensor_inputs = packed_non_tensor_inputs + + with torch.no_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) + outputs = run_function(*unpacked_args, **unpacked_kwargs) + + if isinstance(outputs, torch.Tensor): + return outputs + else: + assert isinstance(outputs, tuple) + # Autograd Functions don't like non-Tensor outputs. We can split the + # non-Tensor and Tensor outputs, returning the former by reference + # through *parent_ctx_dict* and returning the latter directly. 
+ outputs, packed_non_tensor_outputs = split_non_tensors(outputs) + parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), please use .backward() if possible" + ) + + tensor_inputs = ctx.saved_tensors + tensor_inputs = torch.utils.checkpoint.detach_variable(tensor_inputs) + inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) + + # Store the current states. + bwd_rng_state = utils.get_rng_state() + + # Set the states to what it used to be before the forward pass. + utils.set_rng_state(ctx.fwd_rng_state) + + with torch.enable_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) + outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) + tensor_outputs, _ = split_non_tensors(outputs) + + # Set the states back to what it was at the start of this function. 
+ utils.set_rng_state(bwd_rng_state) + + # Run backward() with only Tensors that require grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(tensor_outputs)): + if tensor_outputs[i].requires_grad: + outputs_with_grad.append(tensor_outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "None of the outputs have requires_grad=True, " + "this checkpoint() is not necessary" + ) + + torch.autograd.backward(outputs_with_grad, args_with_grad) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs + ) + return (None, None, None) + grads diff --git a/fairseq/utils.py b/fairseq/utils.py index 0044d76f3d..a0d8f89b6a 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -32,6 +32,11 @@ except ImportError: multi_tensor_l2norm_available = False +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + logger = logging.getLogger(__name__) @@ -535,23 +540,39 @@ def has_parameters(module): return False -def set_torch_seed(seed): - # Set seed based on args.seed and the update number so that we get - # reproducible results when resuming from checkpoints - assert isinstance(seed, int) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) +def get_rng_state(): + state = {"torch_rng_state": torch.get_rng_state()} + if xm is not None: + state["xla_rng_state"] = xm.get_rng_state() + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + return state -@contextlib.contextmanager -def with_torch_seed(seed): - assert isinstance(seed, int) - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() - set_torch_seed(seed) - yield - torch.set_rng_state(rng_state) - torch.cuda.set_rng_state(cuda_rng_state) +def set_rng_state(state): + torch.set_rng_state(state["torch_rng_state"]) + if xm is not None: + xm.set_rng_state(state["xla_rng_state"]) + if torch.cuda.is_available(): + 
torch.cuda.set_rng_state(state["cuda_rng_state"]) + + +class set_torch_seed(object): + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = get_rng_state() + + torch.manual_seed(seed) + if xm is not None: + xm.set_rng_state(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def __enter__(self): + return self + + def __exit__(self, *exc): + set_rng_state(self.rng_state) def parse_alignment(line): @@ -618,8 +639,6 @@ def new_arange(x, *size): def get_tpu_device(args): - import torch_xla.core.xla_model as xm - return xm.xla_device() From 6debe29150204a3a98e61057cebf55e160ccb8b7 Mon Sep 17 00:00:00 2001 From: Anuroop Sriram Date: Thu, 29 Oct 2020 11:44:46 -0700 Subject: [PATCH 261/707] Compute WER for Wav2Vec 2.0 Seq2Seq models (#1376) Summary: # Before submitting - [X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? ## What does this PR do? Adds support to compute WER for wav2vec2.0 seq2seq models. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1376 Reviewed By: alexeib Differential Revision: D24611516 Pulled By: anuroopsriram fbshipit-source-id: dd7daab73ebccc21367dd51f41a11e89c404977b --- fairseq/tasks/audio_pretraining.py | 102 ++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 298bdbe938..90eb7ca2d6 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -5,12 +5,17 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. +import editdistance import os import sys +import torch -from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset +from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders +from fairseq.data.data_utils import post_process from . 
import LegacyFairseqTask, register_task +from .. import utils +from ..logging import metrics class LabelEncoder(object): @@ -68,11 +73,26 @@ def add_args(parser): help="extension of the label file to load, if any", ) + # Options for reporting WER metrics during validation. Only applicable to + # Seq2Seq models during fine-tuning + parser.add_argument( + "--eval-wer", + action="store_true", + help="compute WER for Seq2Seq models", + ) + parser.add_argument( + "--eval-wer-remove-bpe", + default="letter", + help="remove BPE tokens before scoring (can be sentencepiece, letter, and more)", + ) + def __init__(self, args, source_dictionary=None, target_dictionary=None): super().__init__(args) self._target_dictionary = target_dictionary self._source_dictionary = source_dictionary self.is_ctc = args.criterion == "ctc" + if getattr(self.args, "eval_wer", False): + assert args.labels is not None, "eval_wer can only be set during fine-tuning" @classmethod def setup_task(cls, args, **kwargs): @@ -149,3 +169,83 @@ def filter_indices_by_size( ): # we do not need to filter by size in this task as dataloaders take care of this return indices + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + if getattr(self.args, "eval_wer", False) and not self.is_ctc: + metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + return loss, sample_size, logging_output + + def build_model(self, args): + model = super().build_model(args) + + if getattr(args, 'eval_wer', False) and not self.is_ctc: + self.sequence_generator = self.build_generator([model], args, ) + self.tokenizer = encoders.build_tokenizer(args) + return model + + def _inference_with_wer(self, 
generator, sample, model): + def decode(toks, escape_unk=True): + s = self.target_dictionary.string( + toks.int().cpu(), + self.args.eval_wer_remove_bpe, + escape_unk=escape_unk, + extra_symbols_to_ignore={generator.eos}, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + escape_unk=True, + ) + hyp = post_process(hyp, self.args.eval_wer_remove_bpe).strip("_") + ref = post_process(ref, self.args.eval_wer_remove_bpe).strip("_") + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split("_") + ref_words = ref.split("_") + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + zero = torch.scalar_tensor(0.) 
+ num_char_errors = sum(log.get("_num_char_errors", zero) for log in logging_outputs) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum(log.get("_num_word_errors", zero) for log in logging_outputs) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_words > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum * 100.0 / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 else float("nan") + ) + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum * 100.0 / meters["_num_words"].sum + if meters["_num_words"].sum > 0 else float("nan") + ) From a4356b1da2b19ebd2e1be5c12ff882026ea4d7d2 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 29 Oct 2020 17:07:12 -0700 Subject: [PATCH 262/707] Simplify --user-dir and require user-dir module name to be globally unique (#2815) Summary: This PR reverts recent changes that attempted to make `--user-dir` work with non-unique module names. But that new approach introduced other issues (e.g., poor compatibility with multiprocessing and Windows), so let's revert to the previous simpler implementation. 
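For illustration, the restored behavior can be sketched roughly as below. This is a simplified stand-in rather than the patched function itself: `import_user_dir` and `_imported_user_dirs` are hypothetical names used only in this sketch, and it assumes the user directory is a regular Python package containing an `__init__.py`.

```python
import importlib
import os
import sys

# Hypothetical memo of user directories that were already imported.
_imported_user_dirs = set()


def import_user_dir(user_dir):
    """Import the package at *user_dir* once, under its own directory name."""
    module_path = os.path.abspath(user_dir)
    if module_path in _imported_user_dirs:
        return  # this exact directory was imported before; nothing to do
    _imported_user_dirs.add(module_path)

    module_parent, module_name = os.path.split(module_path)
    if module_name in sys.modules:
        # A different module already owns this name, e.g. two example
        # directories that are both called "src".
        raise ImportError(
            "Failed to import --user-dir={} because the module name ({}) "
            "is not globally unique".format(module_path, module_name)
        )

    # Make the parent directory importable, then import by directory name.
    sys.path.insert(0, module_parent)
    importlib.import_module(module_name)
```

Under this scheme, two example directories that are both named `src` would collide on the second import, which appears to be why the example packages in this PR are renamed to unique names such as `latent_depth_src` and `linformer_src`.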
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2815 Reviewed By: alexeib Differential Revision: D24611571 Pulled By: myleott fbshipit-source-id: cecfe28395585ca0401f844f10bd0d49d014c4d8 --- examples/latent_depth/README.md | 2 +- .../{src => latent_depth_src}/__init__.py | 0 .../loss/__init__.py | 0 .../loss/latent_depth.py | 0 .../models/__init__.py | 0 .../models/latent_multilingual_transformer.py | 0 .../models/latent_transformer.py | 0 .../modules/__init__.py | 0 .../modules/latent_layers.py | 0 .../multilingual_translation_latent_depth.py | 0 examples/linformer/README.md | 2 +- .../{src => linformer_src}/__init__.py | 0 .../{src => linformer_src}/models/__init__.py | 0 .../models/linformer_roberta.py | 0 .../modules/__init__.py | 0 .../modules/linformer_sentence_encoder.py | 0 .../linformer_sentence_encoder_layer.py | 0 .../modules/multihead_linear_attention.py | 0 examples/pointer_generator/README.xsum.md | 4 +-- .../__init__.py | 0 .../transformer_pg.py | 0 examples/rxf/README.md | 2 +- examples/rxf/__init__.py | 2 +- examples/rxf/{src => rxf_src}/__init__.py | 0 .../label_smoothed_cross_entropy_r3f.py | 0 .../sentence_prediction_r3f.py | 0 examples/translation_moe/README.md | 6 ++--- .../{src => translation_moe_src}/__init__.py | 0 .../logsumexp_moe.py | 0 .../mean_pool_gating_network.py | 0 .../translation_moe.py | 0 fairseq/utils.py | 27 ++++++++++--------- tests/test_binaries.py | 26 +++++++++--------- 33 files changed, 37 insertions(+), 34 deletions(-) rename examples/latent_depth/{src => latent_depth_src}/__init__.py (100%) rename examples/latent_depth/{src => latent_depth_src}/loss/__init__.py (100%) rename examples/latent_depth/{src => latent_depth_src}/loss/latent_depth.py (100%) rename examples/latent_depth/{src => latent_depth_src}/models/__init__.py (100%) rename examples/latent_depth/{src => latent_depth_src}/models/latent_multilingual_transformer.py (100%) rename examples/latent_depth/{src => 
latent_depth_src}/models/latent_transformer.py (100%) rename examples/latent_depth/{src => latent_depth_src}/modules/__init__.py (100%) rename examples/latent_depth/{src => latent_depth_src}/modules/latent_layers.py (100%) rename examples/latent_depth/{src => latent_depth_src}/multilingual_translation_latent_depth.py (100%) rename examples/linformer/{src => linformer_src}/__init__.py (100%) rename examples/linformer/{src => linformer_src}/models/__init__.py (100%) rename examples/linformer/{src => linformer_src}/models/linformer_roberta.py (100%) rename examples/linformer/{src => linformer_src}/modules/__init__.py (100%) rename examples/linformer/{src => linformer_src}/modules/linformer_sentence_encoder.py (100%) rename examples/linformer/{src => linformer_src}/modules/linformer_sentence_encoder_layer.py (100%) rename examples/linformer/{src => linformer_src}/modules/multihead_linear_attention.py (100%) rename examples/pointer_generator/{src => pointer_generator_src}/__init__.py (100%) rename examples/pointer_generator/{src => pointer_generator_src}/transformer_pg.py (100%) rename examples/rxf/{src => rxf_src}/__init__.py (100%) rename examples/rxf/{src => rxf_src}/label_smoothed_cross_entropy_r3f.py (100%) rename examples/rxf/{src => rxf_src}/sentence_prediction_r3f.py (100%) rename examples/translation_moe/{src => translation_moe_src}/__init__.py (100%) rename examples/translation_moe/{src => translation_moe_src}/logsumexp_moe.py (100%) rename examples/translation_moe/{src => translation_moe_src}/mean_pool_gating_network.py (100%) rename examples/translation_moe/{src => translation_moe_src}/translation_moe.py (100%) diff --git a/examples/latent_depth/README.md b/examples/latent_depth/README.md index bc78ca8055..a0ec55a3f6 100644 --- a/examples/latent_depth/README.md +++ b/examples/latent_depth/README.md @@ -14,7 +14,7 @@ lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur" databin_dir= fairseq-train ${databin_dir} \ - --user-dir, 
examples/latent_depth/src \ + --user-dir examples/latent_depth/latent_depth_src \ --lang-pairs "${lang_pairs_str}" \ --arch multilingual_transformer_iwslt_de_en \ --task multilingual_translation_latent_depth \ diff --git a/examples/latent_depth/src/__init__.py b/examples/latent_depth/latent_depth_src/__init__.py similarity index 100% rename from examples/latent_depth/src/__init__.py rename to examples/latent_depth/latent_depth_src/__init__.py diff --git a/examples/latent_depth/src/loss/__init__.py b/examples/latent_depth/latent_depth_src/loss/__init__.py similarity index 100% rename from examples/latent_depth/src/loss/__init__.py rename to examples/latent_depth/latent_depth_src/loss/__init__.py diff --git a/examples/latent_depth/src/loss/latent_depth.py b/examples/latent_depth/latent_depth_src/loss/latent_depth.py similarity index 100% rename from examples/latent_depth/src/loss/latent_depth.py rename to examples/latent_depth/latent_depth_src/loss/latent_depth.py diff --git a/examples/latent_depth/src/models/__init__.py b/examples/latent_depth/latent_depth_src/models/__init__.py similarity index 100% rename from examples/latent_depth/src/models/__init__.py rename to examples/latent_depth/latent_depth_src/models/__init__.py diff --git a/examples/latent_depth/src/models/latent_multilingual_transformer.py b/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py similarity index 100% rename from examples/latent_depth/src/models/latent_multilingual_transformer.py rename to examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py diff --git a/examples/latent_depth/src/models/latent_transformer.py b/examples/latent_depth/latent_depth_src/models/latent_transformer.py similarity index 100% rename from examples/latent_depth/src/models/latent_transformer.py rename to examples/latent_depth/latent_depth_src/models/latent_transformer.py diff --git a/examples/latent_depth/src/modules/__init__.py 
b/examples/latent_depth/latent_depth_src/modules/__init__.py similarity index 100% rename from examples/latent_depth/src/modules/__init__.py rename to examples/latent_depth/latent_depth_src/modules/__init__.py diff --git a/examples/latent_depth/src/modules/latent_layers.py b/examples/latent_depth/latent_depth_src/modules/latent_layers.py similarity index 100% rename from examples/latent_depth/src/modules/latent_layers.py rename to examples/latent_depth/latent_depth_src/modules/latent_layers.py diff --git a/examples/latent_depth/src/multilingual_translation_latent_depth.py b/examples/latent_depth/latent_depth_src/multilingual_translation_latent_depth.py similarity index 100% rename from examples/latent_depth/src/multilingual_translation_latent_depth.py rename to examples/latent_depth/latent_depth_src/multilingual_translation_latent_depth.py diff --git a/examples/linformer/README.md b/examples/linformer/README.md index cedd667835..f8b36bc691 100644 --- a/examples/linformer/README.md +++ b/examples/linformer/README.md @@ -6,7 +6,7 @@ This example contains code to train Linformer models as described in our paper ## Training a new Linformer RoBERTa model You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md), -updating your training command with `--user-dir examples/linformer/src --arch linformer_roberta_base`. +updating your training command with `--user-dir examples/linformer/linformer_src --arch linformer_roberta_base`. 
## Citation diff --git a/examples/linformer/src/__init__.py b/examples/linformer/linformer_src/__init__.py similarity index 100% rename from examples/linformer/src/__init__.py rename to examples/linformer/linformer_src/__init__.py diff --git a/examples/linformer/src/models/__init__.py b/examples/linformer/linformer_src/models/__init__.py similarity index 100% rename from examples/linformer/src/models/__init__.py rename to examples/linformer/linformer_src/models/__init__.py diff --git a/examples/linformer/src/models/linformer_roberta.py b/examples/linformer/linformer_src/models/linformer_roberta.py similarity index 100% rename from examples/linformer/src/models/linformer_roberta.py rename to examples/linformer/linformer_src/models/linformer_roberta.py diff --git a/examples/linformer/src/modules/__init__.py b/examples/linformer/linformer_src/modules/__init__.py similarity index 100% rename from examples/linformer/src/modules/__init__.py rename to examples/linformer/linformer_src/modules/__init__.py diff --git a/examples/linformer/src/modules/linformer_sentence_encoder.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py similarity index 100% rename from examples/linformer/src/modules/linformer_sentence_encoder.py rename to examples/linformer/linformer_src/modules/linformer_sentence_encoder.py diff --git a/examples/linformer/src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py similarity index 100% rename from examples/linformer/src/modules/linformer_sentence_encoder_layer.py rename to examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py diff --git a/examples/linformer/src/modules/multihead_linear_attention.py b/examples/linformer/linformer_src/modules/multihead_linear_attention.py similarity index 100% rename from examples/linformer/src/modules/multihead_linear_attention.py rename to 
examples/linformer/linformer_src/modules/multihead_linear_attention.py diff --git a/examples/pointer_generator/README.xsum.md b/examples/pointer_generator/README.xsum.md index ab288afc0c..ac3a8c3ddc 100644 --- a/examples/pointer_generator/README.xsum.md +++ b/examples/pointer_generator/README.xsum.md @@ -77,7 +77,7 @@ update_freq=4 pointer_layer=-2 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train bin \ - --user-dir examples/pointer_generator/src \ + --user-dir examples/pointer_generator/pointer_generator_src \ --max-tokens "$max_tokens" \ --task translation \ --source-lang src --target-lang tgt \ @@ -125,7 +125,7 @@ max_length=60 length_penalty=1.0 fairseq-interactive bin \ - --user-dir examples/pointer_generator/src \ + --user-dir examples/pointer_generator/pointer_generator_src \ --batch-size "$batch_size" \ --task translation \ --source-lang src --target-lang tgt \ diff --git a/examples/pointer_generator/src/__init__.py b/examples/pointer_generator/pointer_generator_src/__init__.py similarity index 100% rename from examples/pointer_generator/src/__init__.py rename to examples/pointer_generator/pointer_generator_src/__init__.py diff --git a/examples/pointer_generator/src/transformer_pg.py b/examples/pointer_generator/pointer_generator_src/transformer_pg.py similarity index 100% rename from examples/pointer_generator/src/transformer_pg.py rename to examples/pointer_generator/pointer_generator_src/transformer_pg.py diff --git a/examples/rxf/README.md b/examples/rxf/README.md index a09de63d33..22a1cc47df 100644 --- a/examples/rxf/README.md +++ b/examples/rxf/README.md @@ -38,7 +38,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin \ --find-unused-parameters \ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ --noise-type uniform --r3f-lambda 0.7 \ - --user-dir examples/rxf; + --user-dir examples/rxf/rxf_src ``` ## Citation diff --git a/examples/rxf/__init__.py b/examples/rxf/__init__.py index 63453f9333..b24cb6b797 100644 --- 
a/examples/rxf/__init__.py +++ b/examples/rxf/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import src # noqa +from . import rxf_src # noqa diff --git a/examples/rxf/src/__init__.py b/examples/rxf/rxf_src/__init__.py similarity index 100% rename from examples/rxf/src/__init__.py rename to examples/rxf/rxf_src/__init__.py diff --git a/examples/rxf/src/label_smoothed_cross_entropy_r3f.py b/examples/rxf/rxf_src/label_smoothed_cross_entropy_r3f.py similarity index 100% rename from examples/rxf/src/label_smoothed_cross_entropy_r3f.py rename to examples/rxf/rxf_src/label_smoothed_cross_entropy_r3f.py diff --git a/examples/rxf/src/sentence_prediction_r3f.py b/examples/rxf/rxf_src/sentence_prediction_r3f.py similarity index 100% rename from examples/rxf/src/sentence_prediction_r3f.py rename to examples/rxf/rxf_src/sentence_prediction_r3f.py diff --git a/examples/translation_moe/README.md b/examples/translation_moe/README.md index 33f1bee5cb..ef7abdb44b 100644 --- a/examples/translation_moe/README.md +++ b/examples/translation_moe/README.md @@ -18,7 +18,7 @@ The following command will train a `hMoElp` model with `3` experts: fairseq-train --ddp-backend='no_c10d' \ data-bin/wmt17_en_de \ --max-update 100000 \ - --task translation_moe --user-dir examples/translation_moe/src \ + --task translation_moe --user-dir examples/translation_moe/translation_moe_src \ --method hMoElp --mean-pool-gating-network \ --num-experts 3 \ --arch transformer_wmt_en_de --share-all-embeddings \ @@ -37,7 +37,7 @@ For example, to generate from expert 0: fairseq-generate data-bin/wmt17_en_de \ --path checkpoints/checkpoint_best.pt \ --beam 1 --remove-bpe \ - --task translation_moe --user-dir examples/translation_moe/src \ + --task translation_moe --user-dir examples/translation_moe/translation_moe_src \ --method hMoElp --mean-pool-gating-network \ --num-experts 3 \ --gen-expert 0 @@ -61,7 
+61,7 @@ for EXPERT in $(seq 0 2); do \ --beam 1 \ --bpe subword_nmt --bpe-codes $BPE_CODE \ --buffer-size 500 --max-tokens 6000 \ - --task translation_moe --user-dir examples/translation_moe/src \ + --task translation_moe --user-dir examples/translation_moe/translation_moe_src \ --method hMoElp --mean-pool-gating-network \ --num-experts 3 \ --gen-expert $EXPERT ; \ diff --git a/examples/translation_moe/src/__init__.py b/examples/translation_moe/translation_moe_src/__init__.py similarity index 100% rename from examples/translation_moe/src/__init__.py rename to examples/translation_moe/translation_moe_src/__init__.py diff --git a/examples/translation_moe/src/logsumexp_moe.py b/examples/translation_moe/translation_moe_src/logsumexp_moe.py similarity index 100% rename from examples/translation_moe/src/logsumexp_moe.py rename to examples/translation_moe/translation_moe_src/logsumexp_moe.py diff --git a/examples/translation_moe/src/mean_pool_gating_network.py b/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py similarity index 100% rename from examples/translation_moe/src/mean_pool_gating_network.py rename to examples/translation_moe/translation_moe_src/mean_pool_gating_network.py diff --git a/examples/translation_moe/src/translation_moe.py b/examples/translation_moe/translation_moe_src/translation_moe.py similarity index 100% rename from examples/translation_moe/src/translation_moe.py rename to examples/translation_moe/translation_moe_src/translation_moe.py diff --git a/fairseq/utils.py b/fairseq/utils.py index a0d8f89b6a..87c124b736 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -450,18 +450,21 @@ def import_user_module(args): else: raise FileNotFoundError(module_path) - # We want to import the module under a unique name so that it doesn't - # collide with existing modules. At the same time we don't want to - # import the module multiple times. 
The solution is to create a - # temporary directory and symlink the user_dir under a new name, which is - # a deterministic hash of the original module_path. - with tempfile.TemporaryDirectory() as tmpdirname: - unique_mod_name = "fairseq_user_dir_{}".format(hash(module_path) % 100000) - os.symlink(module_path, os.path.join(tmpdirname, unique_mod_name)) - - sys.path.insert(0, tmpdirname) - importlib.import_module(unique_mod_name) - sys.path.remove(tmpdirname) + # ensure that user modules are only imported once + import_user_module.memo = getattr(import_user_module, "memo", set()) + if module_path not in import_user_module.memo: + import_user_module.memo.add(module_path) + + module_parent, module_name = os.path.split(module_path) + if module_name not in sys.modules: + sys.path.insert(0, module_parent) + importlib.import_module(module_name) + else: + raise ImportError( + "Failed to import --user-dir={} because the corresponding module name " + "({}) is not globally unique. Please rename the directory to " + "something unique and try again.".format(module_path, module_name) + ) def softmax(x, dim: int, onnx_trace: bool = False): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index c6722402a1..ca18adea04 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -320,7 +320,7 @@ def test_multilingual_translation_latent_depth(self): task="multilingual_translation_latent_depth", extra_flags=[ "--user-dir", - "examples/latent_depth/src", + "examples/latent_depth/latent_depth_src", "--encoder-layers", "2", "--decoder-layers", @@ -340,7 +340,7 @@ def test_multilingual_translation_latent_depth(self): run_validation=True, extra_valid_flags=[ "--user-dir", - "examples/latent_depth/src", + "examples/latent_depth/latent_depth_src", ] + enc_ll_flag + dec_ll_flag, @@ -349,7 +349,7 @@ def test_multilingual_translation_latent_depth(self): data_dir, extra_flags=[ "--user-dir", - "examples/latent_depth/src", + "examples/latent_depth/latent_depth_src", "--task", 
"multilingual_translation_latent_depth", "--lang-pairs", @@ -465,7 +465,7 @@ def test_transformer_pointer_generator(self): "transformer_pointer_generator", extra_flags=[ "--user-dir", - "examples/pointer_generator/src", + "examples/pointer_generator/pointer_generator_src", "--encoder-layers", "2", "--decoder-layers", @@ -482,11 +482,11 @@ def test_transformer_pointer_generator(self): "0", ], run_validation=True, - extra_valid_flags=["--user-dir", "examples/pointer_generator/src"], + extra_valid_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"], ) generate_main( data_dir, - extra_flags=["--user-dir", "examples/pointer_generator/src"], + extra_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"], ) def test_lightconv(self): @@ -700,7 +700,7 @@ def test_mixture_of_experts(self): "--task", "translation_moe", "--user-dir", - "examples/translation_moe/src", + "examples/translation_moe/translation_moe_src", "--method", "hMoElp", "--mean-pool-gating-network", @@ -722,7 +722,7 @@ def test_mixture_of_experts(self): "--task", "translation_moe", "--user-dir", - "examples/translation_moe/src", + "examples/translation_moe/translation_moe_src", "--method", "hMoElp", "--mean-pool-gating-network", @@ -1058,7 +1058,7 @@ def test_linformer_roberta_masked_lm(self): "linformer_roberta_base", extra_flags=[ "--user-dir", - "examples/linformer/src", + "examples/linformer/linformer_src", "--encoder-layers", "2", ], @@ -1075,7 +1075,7 @@ def test_linformer_roberta_sentence_prediction(self): data_dir, "linformer_roberta_base", num_classes=num_classes, - extra_flags=["--user-dir", "examples/linformer/src"], + extra_flags=["--user-dir", "examples/linformer/linformer_src"], ) def test_linformer_roberta_regression_single(self): @@ -1095,7 +1095,7 @@ def test_linformer_roberta_regression_single(self): extra_flags=[ "--regression-target", "--user-dir", - "examples/linformer/src", + "examples/linformer/linformer_src", ], ) @@ -1116,7 +1116,7 @@ def 
test_linformer_roberta_regression_multiple(self): extra_flags=[ "--regression-target", "--user-dir", - "examples/linformer/src", + "examples/linformer/linformer_src", ], ) @@ -1198,7 +1198,7 @@ def test_r4f_roberta(self): num_classes=num_classes, extra_flags=[ "--user-dir", - "examples/rxf/src", + "examples/rxf/rxf_src", "--criterion", "sentence_prediction_r3f", "--spectral-norm-classification-head", From de859692ff39cff1ecfd65e8e6860c621fb0e58a Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 30 Oct 2020 18:23:14 -0700 Subject: [PATCH 263/707] Enable translation_multi_simple_epoch to have different source and target dictionaries Summary: In the past, we always used a shared dictionary for multilingual experiments. This diff re-enables different dictionaries for source and target languages by changing the assertion criteria, and reverts to using specific languages to return source_dict and target_dict. Reviewed By: chtran Differential Revision: D24637682 fbshipit-source-id: a982e4f1e48395cc5bf10dc03b98fbe970062f8d --- .../tasks/translation_multi_simple_epoch.py | 32 +++++++---- tests/test_binaries.py | 54 +++++++++++++++++++ 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 95a2d162c0..d871502a2c 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -96,25 +96,35 @@ def __init__(self, args, langs, dicts, training): # models.build_model().
This allows multitask type of sub-class can # build models other than the input lang_pairs self.model_lang_pairs = self.lang_pairs + self.source_langs = [d.split("-")[0] for d in self.lang_pairs] + self.target_langs = [d.split("-")[1] for d in self.lang_pairs] + self.check_dicts(self.dicts, self.source_langs, self.target_langs) + self.sampling_method = SamplingMethod.build_sampler(args, self) self.data_manager = MultilingualDatasetManager.setup_data_manager( args, self.lang_pairs, langs, dicts, self.sampling_method ) + @classmethod + def check_dicts(cls, dicts, source_langs, target_langs): + src_dict = dicts[source_langs[0]] + tgt_dict = dicts[target_langs[0]] + for src_lang in source_langs: + assert ( + src_dict == dicts[src_lang] + ), "Diffrent dictionary are specified for different source languages; " + "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages" + for tgt_lang in target_langs: + assert ( + tgt_dict == dicts[tgt_lang] + ), "Diffrent dictionary are specified for different target languages; " + "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages" + @classmethod def setup_task(cls, args, **kwargs): langs, dicts, training = MultilingualDatasetManager.prepare( cls.load_dictionary, args, **kwargs ) - dict0 = None - for _, lang_dict in dicts.items(): - if dict0 is None: - dict0 = lang_dict - else: - assert ( - dict0 == lang_dict - ), "Diffrent dictionary are specified for different languages; " - "TranslationMultiSimpleEpochTask only supports one shared dictionary across all languages" return cls(args, langs, dicts, training) def has_sharded_data(self, split): @@ -249,11 +259,11 @@ def max_positions(self): @property def source_dictionary(self): - return next(iter(self.dicts.values())) + return self.dicts[self.source_langs[0]] @property def target_dictionary(self): - return next(iter(self.dicts.values())) + return self.dicts[self.target_langs[0]] def 
create_batch_sampler_func( self, diff --git a/tests/test_binaries.py b/tests/test_binaries.py index ca18adea04..dae38dda0c 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -425,6 +425,60 @@ def test_translation_multi_simple_epoch(self): + dec_ltok_flag, ) + def test_translation_multi_simple_epoch_dicts(self): + # test with all combinations of encoder/decoder lang tokens + with contextlib.redirect_stdout(StringIO()): + enc_ltok_flag = ["--encoder-langtok", "src"] + dec_ltok_flag = ["--decoder-langtok"] + with tempfile.TemporaryDirectory( + "test_translation_multi_simple_epoch_dict" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data( + data_dir, extra_flags=[] + ) + train_translation_model( + data_dir, + arch="transformer", + task="translation_multi_simple_epoch", + extra_flags=[ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + "--virtual-epoch-size", + "1000", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + def test_transformer_cross_self_attention(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory( From de977736f91d23c53e6a60c45822973a615daa15 Mon Sep 17 00:00:00 2001 From: Shashank Jain Date: Mon, 2 Nov 2020 17:16:03 -0800 Subject: [PATCH 264/707] Support running batch of sentences together on GPU with BART fill_mask (#2833) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2833 Add support for filling masks using BART on a batch of sentences. 
This will be helpful when running on GPU Reviewed By: myleott Differential Revision: D24687773 fbshipit-source-id: 1b8005c18a09be526f40e9e2b99207afa38e0f1a --- examples/bart/README.md | 18 +++++++++---- fairseq/models/bart/hub_interface.py | 38 ++++++++++++++-------------- fairseq/models/bart/model.py | 2 ++ 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/examples/bart/README.md b/examples/bart/README.md index 76857a99a2..e891894a84 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -100,7 +100,7 @@ bart.predict('mnli', tokens).argmax() # 2: entailment ##### Register a new (randomly initialized) classification head: ```python bart.register_classification_head('new_task', num_classes=3) -logprobs = bart.predict('new_task', tokens) +logprobs = bart.predict('new_task', tokens) ``` ##### Batched prediction: @@ -137,15 +137,23 @@ BART can be used to fill multiple `` tokens in the input. ```python bart = torch.hub.load('pytorch/fairseq', 'bart.base') bart.eval() -bart.fill_mask('The cat on the .', topk=3, beam=10) -# [('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))] +bart.fill_mask(['The cat on the .'], topk=3, beam=10) +# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]] ``` Note that by default we enforce the output length to match the input length. 
This can be disabled by setting ``match_source_len=False``: ``` -bart.fill_mask('The cat on the .', topk=3, beam=10, match_source_len=False) -# [('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))] +bart.fill_mask(['The cat on the .'], topk=3, beam=10, match_source_len=False) +# [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]] +``` + +Example code to fill masks for a batch of sentences using GPU +``` +bart.cuda() +bart.fill_mask(['The cat on the .', 'The dog on the .'], topk=3, beam=10) +# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)), +('The dog was asleep on the couch', tensor(-0.6796))]] ``` #### Evaluating the `bart.large.mnli` model: diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 4c5fd0b585..1ff170a782 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -165,27 +165,27 @@ def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = Fal def fill_mask( self, - masked_input: str, + masked_inputs: List[str], topk: int = 5, match_source_len: bool = True, **generate_kwargs ): masked_token = '' - assert masked_token in masked_input, \ - "please add one {} token for the input".format(masked_token) - - text_spans = masked_input.split(masked_token) - text_spans_bpe = (' {0} '.format(masked_token)).join( - [self.bpe.encode(text_span.rstrip()) for text_span in text_spans] - ).strip() - tokens = self.task.source_dictionary.encode_line( - ' ' + text_spans_bpe + ' ', - append_eos=False, - add_if_not_exist=False, - ).long() - - if tokens.dim() == 1: - tokens = 
tokens.unsqueeze(0) + batch_tokens = [] + for masked_input in masked_inputs: + assert masked_token in masked_input, \ + "please add one {} token for the input".format(masked_token) + + text_spans = masked_input.split(masked_token) + text_spans_bpe = (' {0} '.format(masked_token)).join( + [self.bpe.encode(text_span.rstrip()) for text_span in text_spans] + ).strip() + tokens = self.task.source_dictionary.encode_line( + ' ' + text_spans_bpe + ' ', + append_eos=False, + add_if_not_exist=False, + ).long() + batch_tokens.append(tokens) # ensure beam size is at least as big as topk generate_kwargs['beam'] = max( @@ -193,9 +193,9 @@ def fill_mask( generate_kwargs.get('beam', -1), ) generate_kwargs['match_source_len'] = match_source_len - hypos = self.generate(tokens, **generate_kwargs)[0] + batch_hypos = self.generate(batch_tokens, **generate_kwargs) return [ - (self.decode(hypo['tokens']), hypo['score']) - for hypo in hypos[:topk] + [(self.decode(hypo['tokens']), hypo['score']) for hypo in hypos[:topk]] + for hypos in batch_hypos ] diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 7263a78dc2..e105d6fc46 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -108,6 +108,7 @@ def from_pretrained( checkpoint_file="model.pt", data_name_or_path=".", bpe="gpt2", + sample_break_mode="eos", **kwargs, ): from fairseq import hub_utils @@ -119,6 +120,7 @@ def from_pretrained( archive_map=cls.hub_models(), bpe=bpe, load_checkpoint_heads=True, + sample_break_mode=sample_break_mode, **kwargs, ) return BARTHubInterface(x["args"], x["task"], x["models"][0]) From b120fbbe8fdb6fc8412149916fe09c54757bdaf6 Mon Sep 17 00:00:00 2001 From: Joshua Meier Date: Tue, 3 Nov 2020 14:05:30 -0800 Subject: [PATCH 265/707] Fix correctness issue with megatron save/load checkpoints (#1386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github 
issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2681. Proof that it's working now: ``` python fairseq_train.py --task masked_lm /checkpoint/bioseq_nonsecure/model-parallel-data/tiny_sample_valid_ur50-bin --dataset-impl fasta --save-dir checkpoints/mp-fix4 --dropout 0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 --tokens-per-sample 128 --sample-break-mode none --max-tokens 128 --no-progress-bar --log-interval 1 --seed 4 --max-epoch 1 --max-update 50 --encoder-layers 4 --arch model_parallel_roberta_large --model-parallel-size 2 --update-freq 2 --save-interval-updates 10 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 11 / 78 loss=0.939, ppl=1.92, wps=116.7, ups=0.11, wpb=1024, bsz=8, num_updates=11, lr=1.47473e-06, gnorm=2.276, train_wall=0, wall=15 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 12 / 78 loss=0.938, ppl=1.92, wps=15769.2, ups=15.38, wpb=1024, bsz=8, num_updates=12, lr=1.5997e-06, gnorm=2.612, train_wall=0, wall=15 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 13 / 78 loss=0.877, ppl=1.84, wps=18658.8, ups=18.2, wpb=1024, bsz=8, num_updates=13, lr=1.72468e-06, gnorm=2.798, train_wall=0, wall=15 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 14 / 78 loss=0.887, ppl=1.85, wps=18324.5, ups=17.88, wpb=1024, bsz=8, num_updates=14, lr=1.84965e-06, gnorm=2.326, train_wall=0, wall=15 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 15 / 78 loss=0.867, ppl=1.82, wps=17616.5, ups=17.19, wpb=1024, bsz=8, num_updates=15, lr=1.97463e-06, gnorm=2.112, train_wall=0, wall=15 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 16 / 78 loss=0.891, 
ppl=1.85, wps=18624.5, ups=18.17, wpb=1024, bsz=8, num_updates=16, lr=2.0996e-06, gnorm=2.123, train_wall=0, wall=16 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 17 / 78 loss=0.887, ppl=1.85, wps=17972.5, ups=17.53, wpb=1024, bsz=8, num_updates=17, lr=2.22458e-06, gnorm=2.061, train_wall=0, wall=16 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 18 / 78 loss=0.862, ppl=1.82, wps=14672.4, ups=14.32, wpb=1024, bsz=8, num_updates=18, lr=2.34955e-06, gnorm=2.282, train_wall=0, wall=16 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 19 / 78 loss=0.876, ppl=1.83, wps=14398.6, ups=14.05, wpb=1024, bsz=8, num_updates=19, lr=2.47453e-06, gnorm=2.261, train_wall=0, wall=16 2020-10-29 18:42:08 | INFO | train_inner | epoch 001: 20 / 78 loss=0.818, ppl=1.76, wps=18652.2, ups=18.2, wpb=1024, bsz=8, num_updates=20, lr=2.5995e-06, gnorm=1.969, train_wall=0, wall=16 ...relaunch... 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 11 / 78 loss=0.939, ppl=1.92, wps=98.2, ups=0.1, wpb=1024, bsz=8, num_updates=11, lr=1.47473e-06, gnorm=2.276, train_wall=1, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 12 / 78 loss=0.938, ppl=1.92, wps=17137.8, ups=16.72, wpb=1024, bsz=8, num_updates=12, lr=1.5997e-06, gnorm=2.612, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 13 / 78 loss=0.877, ppl=1.84, wps=17239.6, ups=16.82, wpb=1024, bsz=8, num_updates=13, lr=1.72468e-06, gnorm=2.798, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 14 / 78 loss=0.887, ppl=1.85, wps=18132, ups=17.69, wpb=1024, bsz=8, num_updates=14, lr=1.84965e-06, gnorm=2.326, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 15 / 78 loss=0.867, ppl=1.82, wps=17795.1, ups=17.36, wpb=1024, bsz=8, num_updates=15, lr=1.97463e-06, gnorm=2.112, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 16 / 78 loss=0.891, ppl=1.85, wps=18021.3, ups=17.58, wpb=1024, bsz=8, num_updates=16, 
lr=2.0996e-06, gnorm=2.123, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 17 / 78 loss=0.887, ppl=1.85, wps=16452.9, ups=16.05, wpb=1024, bsz=8, num_updates=17, lr=2.22458e-06, gnorm=2.061, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 18 / 78 loss=0.862, ppl=1.82, wps=17563.3, ups=17.14, wpb=1024, bsz=8, num_updates=18, lr=2.34955e-06, gnorm=2.282, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 19 / 78 loss=0.876, ppl=1.83, wps=16770.3, ups=16.36, wpb=1024, bsz=8, num_updates=19, lr=2.47453e-06, gnorm=2.261, train_wall=0, wall=0 2020-10-29 18:47:20 | INFO | train_inner | epoch 001: 20 / 78 loss=0.818, ppl=1.76, wps=16808.2, ups=16.4, wpb=1024, bsz=8, num_updates=20, lr=2.5995e-06, gnorm=1.969, train_wall=0, wall=0 ``` ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1386 Reviewed By: myleott Differential Revision: D24640946 Pulled By: joshim5 fbshipit-source-id: cb141d92496b289a04d53f080ecd4d5ac6941672 --- fairseq/model_parallel/megatron_trainer.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index cf83685862..a8c7bc9d98 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -11,7 +11,6 @@ from fairseq.trainer import Trainer from fairseq.dataclass.configs import FairseqConfig - try: from fairseq.model_parallel.megatron.mpu import ( get_data_parallel_group, @@ -19,6 +18,7 @@ get_data_parallel_world_size, get_model_parallel_group, get_model_parallel_src_rank, + get_cuda_rng_tracker, ) has_megatron_submodule = True @@ -65,3 +65,23 @@ def _aggregate_model_parallel_grad_norm(total_norm): clip_norm, aggregate_norm_fn=_aggregate_model_parallel_grad_norm, ) + + def save_checkpoint(self, filename, extra_state): + """Save all training state in a checkpoint file.""" + extra_state['rng_tracker_states'] \ + = get_cuda_rng_tracker().get_states() + super().save_checkpoint(filename, extra_state) + + def load_checkpoint( + self, + filename, + reset_optimizer=False, + reset_lr_scheduler=False, + optimizer_overrides=None, + reset_meters=False, + ): + extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters) + if extra_state is not None and 'rng_tracker_states' in extra_state: + get_cuda_rng_tracker().set_states( + extra_state['rng_tracker_states']) + return extra_state \ No newline at end of file From dd52ed0f3896639b3c04aa67c44775f689faf1a5 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 3 Nov 2020 20:44:09 -0800 Subject: [PATCH
266/707] Small fixes (#1392) Summary: - Set default value of clip-norm back to 0.0 (disabled) - Add comment explaining that we divide loss by log(2) to convert the base - Fix `--zero-optimizer=os` (fixes #2811) - Update requirements to PyTorch >= 1.5 - Fix bug in fixed LR schedule Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1392 Reviewed By: alexeib Differential Revision: D24714231 Pulled By: myleott fbshipit-source-id: 63dc8cfc74683bbccbf05b44228014eb12ddbfc7 --- README.md | 2 +- config/config.yaml | 2 +- fairseq/criterions/cross_entropy.py | 1 + fairseq/dataclass/configs.py | 2 +- fairseq/optim/fp16_optimizer.py | 3 ++- fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py | 4 ++++ fairseq/optim/lr_scheduler/fixed_schedule.py | 9 ++++----- fairseq/optim/lr_scheduler/polynomial_decay_schedule.py | 5 ++--- fairseq/trainer.py | 8 ++++++++ 9 files changed, 24 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 56ec16cdab..70e98fe395 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example # Requirements and Installation -* [PyTorch](http://pytorch.org/) version >= 1.4.0 +* [PyTorch](http://pytorch.org/) version >= 1.5.0 * Python version >= 3.6 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) * **To install fairseq** and develop locally: diff --git a/config/config.yaml b/config/config.yaml index b9ee6c74ac..dc0ca0fa60 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -63,7 +63,7 @@ dataset: optimization: max_epoch: 0 max_update: 0 - clip_norm: 25.0 + clip_norm: 0.0 sentence_avg: false update_freq: [ 1 ] lr: [ 0.25 ] diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index 758e727660..fe46106471 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -64,6 +64,7 @@ def reduce_metrics(logging_outputs) -> None:
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 metrics.log_scalar( "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 ) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 484d2526d7..ce07282422 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -434,7 +434,7 @@ class OptimizationConfig(FairseqDataclass): }, ) clip_norm: float = field( - default=25.0, metadata={"help": "clip threshold of gradients"} + default=0.0, metadata={"help": "clip threshold of gradients"} ) sentence_avg: bool = field( default=False, diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index b08a7237a9..aacd3e1d94 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -215,7 +215,8 @@ def zero_grad(self): raise RuntimeError("self.fp32_params must be a tensor or dict") else: for p32 in self.fp32_params: - p32.grad.zero_() + if p32.grad: + p32.grad.zero_() self._needs_sync = False if self.scaler is not None: diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index 569e448262..d0ac115829 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -34,6 +34,10 @@ def load_state_dict(self, state_dict): """Load an LR scheduler state dict.""" self.best = state_dict["best"] + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + pass + def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" if val_loss is not None: diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index 7ca7826ed2..e91ba86f8c 100644 --- 
a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -27,7 +27,7 @@ def add_args(parser): """Add arguments to the parser for this LR scheduler.""" # fmt: off parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', - help='force annealing at specified epoch') + help='force annealing at specified epoch (epochs start at 1)') parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', help='shrink factor for annealing, lr_new = (lr * lr_shrink)') parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', @@ -45,7 +45,7 @@ def get_next_lr(self, epoch): lrs = self.args.lr if self.args.force_anneal is None or epoch < self.args.force_anneal: # use fixed LR schedule - next_lr = lrs[min(epoch, len(lrs) - 1)] + next_lr = lrs[min(epoch - 1, len(lrs) - 1)] else: # annneal based on lr_shrink next_lr = lrs[-1] * self.args.lr_shrink ** ( @@ -53,9 +53,8 @@ def get_next_lr(self, epoch): ) return next_lr - def step(self, epoch, val_loss=None): - """Update the learning rate at the end of the given epoch.""" - super().step(epoch, val_loss) + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" self.lr = self.get_next_lr(epoch) self.optimizer.set_lr(self.warmup_factor * self.lr) return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index ea8e647668..63adc740a9 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -57,9 +57,8 @@ def get_next_lr(self, epoch): next_lr = self.optimizer.get_lr() return next_lr - def step(self, epoch, val_loss=None): - """Update the learning rate at the end of the given epoch.""" - super().step(epoch, val_loss) + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" self.lr = 
self.get_next_lr(epoch) self.optimizer.set_lr(self.warmup_factor * self.lr) return self.optimizer.get_lr() diff --git a/fairseq/trainer.py b/fairseq/trainer.py index a4d273ca67..5daf2e2e5b 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -429,6 +429,8 @@ def begin_epoch(self, epoch): """Called at the beginning of each epoch.""" logger.info("begin training epoch {}".format(epoch)) + self.lr_step_begin_epoch(epoch) + if self.quantizer is not None: self.quantizer.begin_epoch(epoch) @@ -782,6 +784,12 @@ def valid_step(self, sample, raise_oom=False): def zero_grad(self): self.optimizer.zero_grad() + def lr_step_begin_epoch(self, epoch): + """Adjust the learning rate at the beginning of the epoch.""" + self.lr_scheduler.step_begin_epoch(epoch) + # prefer updating the LR based on the number of steps + return self.lr_step_update() + def lr_step(self, epoch, val_loss=None): """Adjust the learning rate at the end of the epoch.""" self.lr_scheduler.step(epoch, val_loss) From 1a709b2a401ac8bd6d805c8a6a5f4d7f03b923ff Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 3 Nov 2020 20:46:45 -0800 Subject: [PATCH 267/707] Reproduce #1781. Add Weights and Biases support Summary: Fixes https://github.com/pytorch/fairseq/issues/1790. 
Reviewed By: alexeib Differential Revision: D24579153 fbshipit-source-id: 74a30effa164db9d6376554376e36b1f47618899 Co-authored-by: Nikolay Korolev Co-authored-by: Vlad Lyalin --- .gitignore | 3 ++ config/config.yaml | 1 + fairseq/dataclass/configs.py | 6 ++++ fairseq/logging/progress_bar.py | 51 +++++++++++++++++++++++++++++++++ fairseq_cli/train.py | 10 +++++-- 5 files changed, 69 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 9546cffd90..4112804793 100644 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,6 @@ data-bin/ # Experimental Folder experimental/* + +# Weights and Biases logs +wandb/ diff --git a/config/config.yaml b/config/config.yaml index dc0ca0fa60..e3d26089f7 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,6 +4,7 @@ common: log_interval: 100 log_format: null tensorboard_logdir: null + wandb_project: null seed: 1 cpu: false tpu: false diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index ce07282422..3bdc6d16d4 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -102,6 +102,12 @@ class CommonConfig(FairseqDataclass): "of running tensorboard (default: no tensorboard logging)" }, ) + wandb_project: Optional[str] = field( + default=None, + metadata={ + "help": "Weights and Biases project name to use for logging" + }, + ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} ) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 63e5394815..3183d2f476 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -33,6 +33,7 @@ def progress_bar( prefix: Optional[str] = None, tensorboard_logdir: Optional[str] = None, default_log_format: str = "tqdm", + wandb_project: Optional[str] = None, ): if log_format is None: log_format = default_log_format @@ -60,6 +61,9 @@ def progress_bar( except ImportError: bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) + if wandb_project: 
+ bar = WandBProgressBarWrapper(bar, wandb_project) + return bar @@ -353,3 +357,50 @@ def _log_to_tensorboard(self, stats, tag=None, step=None): elif isinstance(stats[key], Number): writer.add_scalar(key, stats[key], step) writer.flush() + + +try: + import wandb +except ImportError: + wandb = None + + +class WandBProgressBarWrapper(BaseProgressBar): + """Log to Weights & Biases.""" + + def __init__(self, wrapped_bar, wandb_project): + self.wrapped_bar = wrapped_bar + if wandb is None: + logger.warning('wandb not found, pip install wandb') + return + + # reinit=False to ensure if wandb.init() is called multiple times + # within one process it still references the same run + wandb.init(project=wandb_project, reinit=False) + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to tensorboard.""" + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def _log_to_wandb(self, stats, tag=None, step=None): + if wandb is None: + return + if step is None: + step = stats['num_updates'] + + prefix = '' if tag is None else tag + '/' + + for key in stats.keys() - {'num_updates'}: + if isinstance(stats[key], AverageMeter): + wandb.log({prefix + key: stats[key].val}, step=step) + elif isinstance(stats[key], Number): + wandb.log({prefix + key: stats[key]}, step=step) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index ec10028f03..9eeca18e92 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -187,7 +187,10 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) tensorboard_logdir=( cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=('tqdm' if not cfg.common.no_progress_bar else 
'simple'), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + ), ) trainer.begin_epoch(epoch_itr.epoch) @@ -307,7 +310,10 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i tensorboard_logdir=( cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + ), ) # create a new root metrics aggregator so validation metrics From ea4ccd94de131d6b39163836418696369dd1d034 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Wed, 4 Nov 2020 12:56:13 -0800 Subject: [PATCH 268/707] Load and broadcast fairseq checkpoints instead of having each rank load them individually Summary: This diff is based on feedback in D24379649. Before, when loading checkpoints, each rank loaded the checkpoint from Manifold. Now, rank 0 loads the checkpoint from Manifold and broadcasts it to all other ranks, which saves IO. Furthermore, when doing zero-sharding, we only broadcast the relevant parts of the optimizer state to each node. This makes checkpoint loading more memory-efficient and should enable loading models beyond 2-3B parameters.
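The sharding step in the summary can be sketched in plain Python. This is a minimal illustration, not the patch's implementation: it assumes the global optimizer state dict has the `partition`, `param_groups`, and `state` keys that `broadcast_global_state_dict` reads in this diff, `shard_for_rank` is a hypothetical helper name, and no `torch.distributed` broadcast is performed.

```python
from typing import Any, Dict


def shard_for_rank(state_dict: Dict[str, Any], dst_rank: int) -> Dict[str, Any]:
    """Return only the part of a global optimizer state that dst_rank needs.

    Hypothetical helper: mirrors the slicing done on rank 0 before each
    per-rank shard is sent out, but without any communication.
    """
    # Keys unrelated to sharding are copied to every rank unchanged.
    shard = {
        key: value
        for key, value in state_dict.items()
        if key not in ("param_groups", "state")
    }
    shard["local_state_dict"] = True

    # partition[dst_rank] is assumed to be a (start, end) range into the
    # global list of param_groups.
    start, end = state_dict["partition"][dst_rank]
    shard["param_groups"] = state_dict["param_groups"][start:end]

    # The optimizer state is assumed to be stored as one entry per rank.
    shard["state"] = state_dict["state"][dst_rank]
    return shard


# Toy global state for two ranks (values are made up for illustration).
global_state = {
    "partition": [(0, 1), (1, 3)],
    "param_groups": [{"lr": 0.1}, {"lr": 0.2}, {"lr": 0.3}],
    "state": [{"step": 10}, {"step": 10}],
}

for rank in range(2):
    print(rank, shard_for_rank(global_state, rank)["param_groups"])
```

Each rank then holds only its own slice of `param_groups` and `state`, which is where the memory saving comes from: the full optimizer state never has to be materialized on every rank.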
Reviewed By: myleott Differential Revision: D24660791 fbshipit-source-id: e30b2ea5990083375e4549f0427a112346ba170d --- fairseq/distributed_utils.py | 42 +++++++++++++++++++++- fairseq/optim/fairseq_optimizer.py | 10 ++++++ fairseq/optim/shard.py | 58 +++++++++++++++++++++++++++++- fairseq/trainer.py | 45 ++++++++++++++++++++--- 4 files changed, 149 insertions(+), 6 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 0d5804c8f7..cbe6c6de5d 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import io import logging import os import pickle @@ -13,7 +14,7 @@ import warnings from argparse import Namespace from collections import OrderedDict -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional import torch import torch.distributed as dist @@ -455,3 +456,42 @@ def get_from_stack(key): raise KeyError return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) + + +# From fairscale/optim/utils.py +def broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + dist_device: Optional[torch.device] = None, +) -> Any: + """ + Either broadcast from master to the fleet (default), + or use the src setting as the original rank. 
+ """ + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + if dist.get_rank() == src_rank: + # Emit data + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.LongTensor([len(data)]).to(dist_device) + data_send_tensor = torch.ByteTensor(data).to(dist_device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Fetch from the source + length_tensor = torch.LongTensor([0]).to(dist_device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = torch.empty( + [int(length_tensor.item())], dtype=torch.uint8, device=dist_device + ) + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = torch.load(buffer, map_location=dist_device) + return obj diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 9c0938331d..e91e9d3204 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -144,6 +144,16 @@ def supports_flat_params(self): def average_params(self): pass + def broadcast_global_state_dict(self, state_dict): + """ + Broadcasts a global state dict to all ranks. + Useful for optimizers that shard state between ranks. 
+ """ + if hasattr(self.optimizer, "broadcast_global_state_dict"): + return self.optimizer.broadcast_global_state_dict(state_dict) + else: + return state_dict + class LegacyFairseqOptimizer(FairseqOptimizer): def __init__(self, args): diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index ecef05b442..3d025a23ca 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -3,9 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import Any, Dict + +import torch + try: - from fairscale.optim import OSS + from fairscale.optim import OSS, utils _has_fairscale = True except ImportError: @@ -30,6 +34,58 @@ def __getattr__(self, name): "'FairseqOSS' object has no attribute {0!r}".format(name) ) + def broadcast_global_state_dict( + self, state_dict: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Broadcasts the relevant parts of a global state dict from rank 0 to + all other ranks. + """ + if self.rank == 0: + + # Create template state dict for all other keys not related to sharding + template_state_dict = { + key: state_dict[key] + for key in state_dict + if key not in ("param_groups", "state") + } + template_state_dict["local_state_dict"] = True + + for dst_rank in range(self.world_size): + # Get the dst_rank's param_groups shard + send_state = { + "param_groups": state_dict["param_groups"][ + state_dict["partition"][dst_rank][0] : state_dict[ + "partition" + ][dst_rank][1] + ], + "state": state_dict["state"][dst_rank], + } + send_state.update(template_state_dict) + + if dst_rank == 0: + recv_state = send_state + else: + utils.broadcast_object( + send_state, + src_rank=0, + group=self.group, + dist_device=self._device, + ) + else: + empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device) + for dst_rank in range(1, self.world_size): + state = utils.broadcast_object( + empty_buffer, + src_rank=0, + group=self.group, + dist_device=self._device, + ) 
+ if dst_rank == self.rank: + recv_state = state + + return recv_state + torch_optimizer = optimizer.optimizer optim_cls = type(torch_optimizer) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5daf2e2e5b..657374aab7 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -289,12 +289,46 @@ def load_checkpoint( optimizer_overrides=None, reset_meters=False, ): - """Load all training state from a checkpoint file.""" + """ + Load all training state from a checkpoint file. + rank = 0 will load the checkpoint, and then broadcast it to all + other ranks. + """ extra_state, self._optim_history, last_optim_state = None, [], None bexists = PathManager.isfile(filename) if bexists: - state = checkpoint_utils.load_checkpoint_to_cpu(filename) + if self.data_parallel_rank == 0: + state = checkpoint_utils.load_checkpoint_to_cpu(filename) + last_optim_state = state.get("last_optimizer_state", None) + + # If doing zero_sharding, do not broadcast global optimizer + # state. Later we will broadcast sharded states to each rank + # to avoid memory from exploding. 
+ if ( + self.cfg.distributed_training.zero_sharding == "os" + and "last_optimizer_state" in state + and self.data_parallel_world_size > 1 + ): + state["last_optimizer_state"] = "SHARDED" + else: + last_optim_state = None + state = None + + if self.data_parallel_world_size > 1: + group = ( + self.data_parallel_process_group + if self.data_parallel_process_group is not None + else torch.distributed.group.WORLD + ) + state = distributed_utils.broadcast_object( + state, + src_rank=0, + group=group, + ) + if self.data_parallel_rank > 0: + last_optim_state = state.get("last_optimizer_state", None) + # load model parameters try: self.get_model().load_state_dict( @@ -309,10 +343,8 @@ def load_checkpoint( "Cannot load model parameters from checkpoint {}; " "please ensure that the architectures match.".format(filename) ) - extra_state = state["extra_state"] self._optim_history = state["optimizer_history"] - last_optim_state = state.get("last_optimizer_state", None) if last_optim_state is not None and not reset_optimizer: # rebuild optimizer after loading model, since params may have changed @@ -329,6 +361,11 @@ def load_checkpoint( if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) + + if self.data_parallel_world_size > 1: + last_optim_state = self.optimizer.broadcast_global_state_dict( + last_optim_state + ) self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) self.set_num_updates(last_optim["num_updates"]) From b58f4f017ed275aff327046943857b4259f64a47 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 4 Nov 2020 18:19:10 -0800 Subject: [PATCH 269/707] end to end hydra configs (#1393) Summary: this adds a hydra_train binary that uses hydra configs/command line overrides instead of argparse use case 1: built in configs + overrides from command line ``` python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 
task.data=/private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/ model=transformer_lm/transformer_lm_gpt task=language_modeling optimization.max_update=5000 ``` use case 2: use an external config that is used instead of bundled configs (but dataclass defaults still work) ``` python fairseq_cli/hydra_train.py --config-path ~/fairseq-py-dev/lm --config-name wiki103 ``` the config file contains this: ``` # package _group_ model: _name: transformer_lm distributed_training: distributed_world_size: 1 dataset: batch_size: 2 task: _name: language_modeling data: /private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/ add_bos_token: false max_target_positions: 1024 optimization: max_update: 50000 lr: [ 0.25 ] criterion: cross_entropy optimizer: adam lr_scheduler: _name: cosine ``` use case 3: use an external config directory that provides additional configs for e.g. models python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 task.data=/private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/ model=transformer_lm/2_layers task=language_modeling optimization.max_update=5000 --config-dir ~/fairseq-py-dev/lm/hydra where ~/fairseq-py-dev/lm/hydra has the following structure: - model -- transformer_lm --- 2_layers.yaml and inside 2_layers.yaml is a copy of transformer_lm_gpt.yaml but with decoder_layers set to 2 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1393 Reviewed By: myleott Differential Revision: D24722252 Pulled By: alexeib fbshipit-source-id: 758ea431fa099cd7c0e4daf41eff680df1d3b841 --- config/config.yaml | 105 +----------------- config/criterion/adaptive_loss.yaml | 3 - config/criterion/cross_entropy.yaml | 2 - config/lr_scheduler/cosine.yaml | 7 -- config/lr_scheduler/inverse_sqrt.yaml | 3 - config/model/transformer_lm.yaml | 36 ------ .../transformer_lm_baevski_gbw.yaml | 0 .../transformer_lm_baevski_wiki103.yaml | 0 .../transformer_lm_big.yaml | 0 
.../transformer_lm_gbw.yaml | 0 .../transformer_lm_gpt.yaml | 0 .../transformer_lm_gpt2_big.yaml | 0 .../transformer_lm_gpt2_medium.yaml | 0 .../transformer_lm_gpt2_small.yaml | 0 .../transformer_lm_wiki103.yaml | 0 config/optimizer/adam.yaml | 5 - config/optimizer/nag.yaml | 3 - config/task/language_modeling.yaml | 10 -- fairseq/dataclass/__init__.py | 5 +- fairseq/dataclass/configs.py | 14 ++- fairseq/dataclass/constants.py | 3 + fairseq/dataclass/initialize.py | 15 ++- fairseq/dataclass/utils.py | 24 +++- fairseq/distributed_utils.py | 3 +- fairseq/models/__init__.py | 34 +++++- fairseq/optim/fp16_optimizer.py | 2 +- .../optim/lr_scheduler/cosine_lr_scheduler.py | 2 +- fairseq/registry.py | 14 ++- fairseq/tasks/__init__.py | 21 +++- fairseq/tasks/language_modeling.py | 5 +- fairseq_cli/hydra_train.py | 47 ++++++++ fairseq_cli/train.py | 101 +++++++++++------ tests/test_fp16_optimizer.py | 4 + 33 files changed, 240 insertions(+), 228 deletions(-) delete mode 100644 config/criterion/adaptive_loss.yaml delete mode 100644 config/criterion/cross_entropy.yaml delete mode 100644 config/lr_scheduler/cosine.yaml delete mode 100644 config/lr_scheduler/inverse_sqrt.yaml delete mode 100644 config/model/transformer_lm.yaml rename config/model/{ => transformer_lm}/transformer_lm_baevski_gbw.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_baevski_wiki103.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_big.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_gbw.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_gpt.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_gpt2_big.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_gpt2_medium.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_gpt2_small.yaml (100%) rename config/model/{ => transformer_lm}/transformer_lm_wiki103.yaml (100%) delete mode 100644 config/optimizer/adam.yaml delete mode 100644 
config/optimizer/nag.yaml delete mode 100644 config/task/language_modeling.yaml create mode 100644 fairseq_cli/hydra_train.py diff --git a/config/config.yaml b/config/config.yaml index e3d26089f7..039609aece 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,109 +1,10 @@ # @package _group_ -common: - no_progress_bar: false - log_interval: 100 - log_format: null - tensorboard_logdir: null - wandb_project: null - seed: 1 - cpu: false - tpu: false - bf16: false - fp16: false - memory_efficient_fp16: false - memory_efficient_bf16: false - fp16_no_flatten_grads: false - fp16_init_scale: 128 - fp16_scale_window: null - fp16_scale_tolerance: 0.0 - min_loss_scale: 1.0e-4 - threshold_loss_scale: null - user_dir: null - empty_cache_freq: 0 - all_gather_list_size: 16384 - model_parallel_size: 1 - quantization_config_path: null - profile: false -distributed_training: - distributed_rank: 0 - distributed_backend: "nccl" - distributed_init_method: null - distributed_port: -1 - device_id: 0 - local_rank: 0 - distributed_no_spawn: false - ddp_backend: "c10d" - bucket_cap_mb: 25 - fix_batches_to_gpus: false - find_unused_parameters: false - fast_stat_sync: false - broadcast_buffers: false - distributed_wrapper: "DDP" - slowmo_momentum: null - slowmo_algorithm: "LocalSGD" - localsgd_frequency: 3 -dataset: - num_workers: 1 - skip_invalid_size_inputs_valid_test: false - max_tokens: null - batch_size: null - required_batch_size_multiple: 8 - dataset_impl: null - data_buffer_size: 10 - train_subset: "train" - valid_subset: "valid" - validate_interval: 1 - fixed_validation_seed: null - disable_validation: false - curriculum: 0 - gen_subset: "test" - num_shards: 1 - shard_id: 0 - max_tokens_valid: ${dataset.max_tokens} - batch_size_valid: ${dataset.batch_size} -optimization: - max_epoch: 0 - max_update: 0 - clip_norm: 0.0 - sentence_avg: false - update_freq: [ 1 ] - lr: [ 0.25 ] - min_lr: -1.0 - use_bmuf: false -checkpoint: - save_dir: "checkpoints" - restore_file: 
"checkpoint_last.pt" - reset_dataloader: false - reset_lr_scheduler: false - reset_meters: false - reset_optimizer: false - optimizer_overrides: "{}" - save_interval: 1 - save_interval_updates: 0 - keep_interval_updates: -1 - keep_last_epochs: -1 - keep_best_checkpoints: -1 - no_save: false - no_epoch_checkpoints: false - no_last_checkpoints: false - no_save_optimizer_state: false - best_checkpoint_metric: "loss" - maximize_best_checkpoint_metric: false - patience: -1 - checkpoint_suffix: "" -bmuf: - block_lr: 1 - block_momentum: 0.875 - global_sync_iter: 50 - warmup_iterations: 500 - use_nbm: false - average_sync: false defaults: - task: language_modeling - model: null - - criterion: null - - optimizer: null - - lr_scheduler: null + - criterion: cross_entropy + - optimizer: adam + - lr_scheduler: cosine - bpe: null - tokenizer: null - scoring: null diff --git a/config/criterion/adaptive_loss.yaml b/config/criterion/adaptive_loss.yaml deleted file mode 100644 index 7997b0766e..0000000000 --- a/config/criterion/adaptive_loss.yaml +++ /dev/null @@ -1,3 +0,0 @@ -# @package _group_ -sentence_avg: ${optimization.sentence_avg} -ddp_backend: ${distributed_training.ddp_backend} diff --git a/config/criterion/cross_entropy.yaml b/config/criterion/cross_entropy.yaml deleted file mode 100644 index ad3d4148c2..0000000000 --- a/config/criterion/cross_entropy.yaml +++ /dev/null @@ -1,2 +0,0 @@ -# @package _group_ -sentence_avg: ${optimization.sentence_avg} diff --git a/config/lr_scheduler/cosine.yaml b/config/lr_scheduler/cosine.yaml deleted file mode 100644 index 0f91e0d240..0000000000 --- a/config/lr_scheduler/cosine.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# @package _group_ -warmup_updates: 0 -warmup_init_lr: -1 -max_lr: 1.0 -t_mult: 1.0 -lr_period_updates: -1 -lr_shrink: 0.1 diff --git a/config/lr_scheduler/inverse_sqrt.yaml b/config/lr_scheduler/inverse_sqrt.yaml deleted file mode 100644 index 0eac7d88eb..0000000000 --- a/config/lr_scheduler/inverse_sqrt.yaml +++ /dev/null @@ 
-1,3 +0,0 @@ -# @package _group_ -warmup_updates: 4000 -warmup_init_lr: -1 diff --git a/config/model/transformer_lm.yaml b/config/model/transformer_lm.yaml deleted file mode 100644 index 3837ea54e1..0000000000 --- a/config/model/transformer_lm.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# @package _group_ -activation_fn: "relu" -dropout: 0.1 -attention_dropout: 0.0 -activation_dropout: 0.0 -relu_dropout: 0.0 -decoder_embed_dim: 512 -decoder_output_dim: 512 -decoder_input_dim: 512 -decoder_ffn_embed_dim: 2048 -decoder_layers: 6 -decoder_attention_heads: 8 -decoder_normalize_before: true -no_decoder_final_norm: false -adaptive_softmax_cutoff: null -adaptive_softmax_dropout: 0 -adaptive_softmax_factor: 4 -no_token_positional_embeddings: false -share_decoder_input_output_embed: false -character_embeddings: false -character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" -character_embedding_dim: 4 -char_embedder_highway_layers: 2 -adaptive_input: false -adaptive_input_factor: 4 -adaptive_input_cutoff: null -tie_adaptive_weights: false -tie_adaptive_proj: false -decoder_learned_pos: false -decoder_layerdrop: 0 -decoder_layers_to_keep: null -layernorm_embedding: false -no_scale_embedding: false -quant_noise_pq: 0 -quant_noise_pq_block_size: 8 -quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_baevski_gbw.yaml b/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml similarity index 100% rename from config/model/transformer_lm_baevski_gbw.yaml rename to config/model/transformer_lm/transformer_lm_baevski_gbw.yaml diff --git a/config/model/transformer_lm_baevski_wiki103.yaml b/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml similarity index 100% rename from config/model/transformer_lm_baevski_wiki103.yaml rename to config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml diff --git a/config/model/transformer_lm_big.yaml b/config/model/transformer_lm/transformer_lm_big.yaml similarity index 100% rename from 
config/model/transformer_lm_big.yaml rename to config/model/transformer_lm/transformer_lm_big.yaml diff --git a/config/model/transformer_lm_gbw.yaml b/config/model/transformer_lm/transformer_lm_gbw.yaml similarity index 100% rename from config/model/transformer_lm_gbw.yaml rename to config/model/transformer_lm/transformer_lm_gbw.yaml diff --git a/config/model/transformer_lm_gpt.yaml b/config/model/transformer_lm/transformer_lm_gpt.yaml similarity index 100% rename from config/model/transformer_lm_gpt.yaml rename to config/model/transformer_lm/transformer_lm_gpt.yaml diff --git a/config/model/transformer_lm_gpt2_big.yaml b/config/model/transformer_lm/transformer_lm_gpt2_big.yaml similarity index 100% rename from config/model/transformer_lm_gpt2_big.yaml rename to config/model/transformer_lm/transformer_lm_gpt2_big.yaml diff --git a/config/model/transformer_lm_gpt2_medium.yaml b/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml similarity index 100% rename from config/model/transformer_lm_gpt2_medium.yaml rename to config/model/transformer_lm/transformer_lm_gpt2_medium.yaml diff --git a/config/model/transformer_lm_gpt2_small.yaml b/config/model/transformer_lm/transformer_lm_gpt2_small.yaml similarity index 100% rename from config/model/transformer_lm_gpt2_small.yaml rename to config/model/transformer_lm/transformer_lm_gpt2_small.yaml diff --git a/config/model/transformer_lm_wiki103.yaml b/config/model/transformer_lm/transformer_lm_wiki103.yaml similarity index 100% rename from config/model/transformer_lm_wiki103.yaml rename to config/model/transformer_lm/transformer_lm_wiki103.yaml diff --git a/config/optimizer/adam.yaml b/config/optimizer/adam.yaml deleted file mode 100644 index e5264f895e..0000000000 --- a/config/optimizer/adam.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package _group_ -adam_betas: "(0.9, 0.999)" -adam_eps: 1.0e-8 -weight_decay: 0 -use_old_adam: false diff --git a/config/optimizer/nag.yaml b/config/optimizer/nag.yaml deleted file mode 100644 
index 4ab2745686..0000000000 --- a/config/optimizer/nag.yaml +++ /dev/null @@ -1,3 +0,0 @@ -# @package _group_ -momentum: 0.99 -weight_decay: 0.0 diff --git a/config/task/language_modeling.yaml b/config/task/language_modeling.yaml deleted file mode 100644 index 58a2ad1358..0000000000 --- a/config/task/language_modeling.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package _group_ -data: ??? -sample_break_mode: "none" -tokens_per_sample: 1024 -output_dictionary_size: -1 -self_target: false -future_target: false -past_target: false -add_bos_token: false -max_target_positions: null diff --git a/fairseq/dataclass/__init__.py b/fairseq/dataclass/__init__.py index 5c9004d3ba..25408d28ec 100644 --- a/fairseq/dataclass/__init__.py +++ b/fairseq/dataclass/__init__.py @@ -7,4 +7,7 @@ from .constants import ChoiceEnum -__all__ = ["FairseqDataclass", "ChoiceEnum"] +__all__ = [ + "FairseqDataclass", + "ChoiceEnum", +] diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 3bdc6d16d4..a3c0d06a39 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -20,7 +20,7 @@ ZERO_SHARDING_CHOICES, ) -from omegaconf import II +from omegaconf import II, MISSING @dataclass @@ -781,7 +781,9 @@ class GenerationConfig(FairseqDataclass): default=False, metadata={"help": "Use dropout at inference time"}, ) - retain_dropout_modules: Optional[List[str]] = field( + # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed + # retain_dropout_modules: Optional[List[str]] = field( + retain_dropout_modules: Any = field( default=None, metadata={ "help": "if set, only retain dropout for the specified modules; " @@ -880,3 +882,11 @@ class FairseqConfig(object): generation: GenerationConfig = GenerationConfig() eval_lm: EvalLMConfig = EvalLMConfig() interactive: InteractiveConfig = InteractiveConfig() + model: Any = MISSING + task: Any = None + criterion: Any = None + optimizer: Any = None + lr_scheduler: Any = None + scoring: Any 
= None + bpe: Any = None + tokenizer: Any = None diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 3eb63ec609..fad04f3482 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -17,6 +17,9 @@ def __eq__(self, other: str): def __repr__(self): return self.value + def __hash__(self): + return hash(str(self)) + def ChoiceEnum(choices: List[str]): """return the Enum class used to enforce list of choices""" diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index b762af990f..24fedd52bf 100644 --- a/fairseq/dataclass/initialize.py +++ b/fairseq/dataclass/initialize.py @@ -2,15 +2,20 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" import logging - from typing import Dict, Any - from hydra.core.config_store import ConfigStore - from fairseq.dataclass.configs import FairseqConfig +# the imports below are necessary so that "REGISTRIES" is correctly populated with all components +from fairseq.criterions import CRITERION_REGISTRY # noqa +from fairseq.optim import OPTIMIZER_REGISTRY # noqa +from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY # noqa +from fairseq.scoring import SCORER_REGISTRY # noqa +from fairseq.data.encoders import BPE_REGISTRY, TOKENIZER_REGISTRY # noqa + from fairseq.models import MODEL_DATACLASS_REGISTRY from fairseq.tasks import TASK_DATACLASS_REGISTRY from fairseq.registry import REGISTRIES @@ -30,8 +35,10 @@ def register_module_dataclass( cs.store(name=k, group=group, node=node_, provider="fairseq") -def hydra_init() -> None: +def hydra_init(cfg_name="config") -> None: + cs = ConfigStore.instance() + cs.store(name=cfg_name, node=FairseqConfig) for k in FairseqConfig.__dataclass_fields__: v = FairseqConfig.__dataclass_fields__[k].default diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 5ce017d765..477a198d0f 100644 --- 
a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -8,11 +8,11 @@ from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING from enum import Enum +import inspect from typing import Any, Dict, List, Tuple, Type from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import FairseqConfig -from hydra.core.global_hydra import GlobalHydra from hydra.experimental import compose, initialize from omegaconf import DictConfig, OmegaConf, open_dict @@ -177,6 +177,9 @@ def _override_attr( ) -> List[str]: overrides = [] + if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass): + return overrides + def get_default(f): if not isinstance(f.default_factory, _MISSING_TYPE): return f.default_factory() @@ -189,6 +192,12 @@ def get_default(f): val = get_default(v) if not hasattr(args, k) else getattr(args, k) + if getattr(v.type, "__origin__", None) is List: + # if type is int but val is float, then we will crash later - try to convert here + t_args = v.type.__args__ + if len(t_args) == 1: + val = list(map(t_args[0], val)) + if val is None: overrides.append("{}.{}=null".format(sub_node, k)) elif val == "": @@ -255,13 +264,14 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: no_dc = True if hasattr(args, "arch"): - from fairseq.models import ARCH_MODEL_REGISTRY + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY if args.arch in ARCH_MODEL_REGISTRY: m_cls = ARCH_MODEL_REGISTRY[args.arch] dc = getattr(m_cls, "__dataclass", None) if dc is not None: - overrides.append("model={}".format(args.arch)) + m_name = ARCH_MODEL_NAME_REGISTRY[args.arch] + overrides.append("model={}".format(m_name)) overrides.append("model._name={}".format(args.arch)) # override model params with those exist in args overrides.extend(_override_attr("model", dc, args)) @@ -358,3 +368,11 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, 
any]): overwrite_args_by_name(cfg[k], overrides) elif k in overrides: cfg[k] = overrides[k] + + +def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig): + dc_instance = DictConfig(dc) + dc_instance.__dict__["_parent"] = cfg.__dict__["_parent"] + cfg = OmegaConf.merge(dc_instance, cfg) + OmegaConf.set_struct(cfg, True) + return cfg diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index cbe6c6de5d..3439508c94 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -147,7 +147,8 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs - assert cfg.distributed_world_size <= torch.cuda.device_count() + assert cfg.distributed_world_size <= torch.cuda.device_count(), \ + f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" port = random.randint(10000, 20000) cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 3b4fd51d6c..e8af024795 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -10,6 +10,7 @@ import fairseq from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent from omegaconf import DictConfig, OmegaConf from .composite_encoder import CompositeEncoder @@ -51,9 +52,36 @@ def build_model(cfg: DictConfig, task): - if isinstance(cfg, DictConfig): - return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task) - return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task) + + model = None + model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None) + + if not model_type and len(cfg) == 1: + # this is hit if config object is nested in directory that is named after model type + + model_type = next(iter(cfg)) + if model_type in MODEL_DATACLASS_REGISTRY: + cfg = 
cfg[model_type] + else: + raise Exception( + "Could not infer model type from directory. Please add _name field to indicate model type" + ) + + if model_type in ARCH_MODEL_REGISTRY: + # case 1: legacy models + model = ARCH_MODEL_REGISTRY[model_type] + elif model_type in MODEL_DATACLASS_REGISTRY: + # case 2: config-driven models + model = MODEL_REGISTRY[model_type] + + if model_type in MODEL_DATACLASS_REGISTRY: + # set defaults from dataclass. note that arch name and model name can be the same + dc = MODEL_DATACLASS_REGISTRY[model_type] + cfg = merge_with_parent(dc(), cfg) + + assert model is not None, f"Could not infer model type from {cfg}" + + return model.build_model(cfg, task) def register_model(name, dataclass=None): diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index aacd3e1d94..8ef61a6a7e 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -215,7 +215,7 @@ def zero_grad(self): raise RuntimeError("self.fp32_params must be a tensor or dict") else: for p32 in self.fp32_params: - if p32.grad: + if p32.grad is not None: p32.grad.zero_() self._needs_sync = False diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index c3c6663ece..646ac66be9 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -74,7 +74,7 @@ def __init__( if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with cosine." - " Consider --lr-scheduler=fixed instead." + f" Consider --lr-scheduler=fixed instead. 
({cfg.lr})" ) warmup_end_lr = cfg.max_lr diff --git a/fairseq/registry.py b/fairseq/registry.py index 96994cb8d4..29631bb326 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -7,7 +7,7 @@ from typing import Union from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import populate_dataclass +from fairseq.dataclass.utils import populate_dataclass, merge_with_parent from omegaconf import DictConfig REGISTRIES = {} @@ -24,11 +24,19 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F # maintain a registry of all registries if registry_name in REGISTRIES: return # registry already exists - REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY} + REGISTRIES[registry_name] = { + "registry": REGISTRY, + "default": default, + "dataclass_registry": DATACLASS_REGISTRY, + } def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): if isinstance(cfg, DictConfig): choice = cfg._name + + if choice and choice in DATACLASS_REGISTRY: + dc = DATACLASS_REGISTRY[choice] + cfg = merge_with_parent(dc(), cfg) elif isinstance(cfg, str): choice = cfg if choice in DATACLASS_REGISTRY: @@ -40,7 +48,7 @@ def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs) if choice is None: if required: - raise ValueError('{} is required!'.format(registry_name)) + raise ValueError("{} is required!".format(registry_name)) return None cls = REGISTRY[choice] diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 41f461f802..7575ba429e 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -9,6 +9,7 @@ import os from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent from omegaconf import DictConfig from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa @@ -21,9 +22,23 @@ def setup_task(cfg: DictConfig, **kwargs): - if isinstance(cfg, DictConfig): - 
return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs) - return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs) + task = None + task_name = getattr(cfg, "task", None) + + if isinstance(task_name, str): + # legacy tasks + task = TASK_REGISTRY[task_name] + else: + task_name = getattr(cfg, "_name", None) + + if task_name and task_name in TASK_DATACLASS_REGISTRY: + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = merge_with_parent(dc(), cfg) + task = TASK_REGISTRY[task_name] + + assert task is not None, f"Could not infer task type from {cfg}" + + return task.setup_task(cfg, **kwargs) def register_task(name, dataclass=None): diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 6e85417ff5..79c225de6f 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -158,8 +158,8 @@ def setup_task(cls, args, **kwargs): dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs) # upgrade old checkpoints - if hasattr(args, "exclude_self_target"): - args.self_target = not args.exclude_self_target + if getattr(args, "exclude_self_target", False): + args.self_target = False targets = [] if getattr(args, "self_target", False): @@ -176,7 +176,6 @@ def setup_task(cls, args, **kwargs): def build_model(self, args): model = super().build_model(args) - for target in self.targets: if target not in model.supported_targets: raise ValueError( diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py new file mode 100644 index 0000000000..24728c507f --- /dev/null +++ b/fairseq_cli/hydra_train.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import hydra +from omegaconf import OmegaConf + +from fairseq.dataclass.initialize import hydra_init +from fairseq_cli.train import main as pre_main +from fairseq import distributed_utils +from fairseq.dataclass.configs import FairseqConfig + +import logging +import torch + + +logger = logging.getLogger(__name__) + + +@hydra.main(config_path="../config", config_name="config") +def hydra_main(cfg: FairseqConfig) -> None: + + cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) + + OmegaConf.set_struct(cfg, True) + + if cfg.common.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, pre_main) + else: + distributed_utils.call_main(cfg, pre_main) + + +if __name__ == "__main__": + try: + from hydra._internal.utils import get_args + + cfg_name = get_args().config_name or "config" + except: + logger.warning("Failed to get config name from hydra args") + cfg_name = "config" + + hydra_init(cfg_name) + hydra_main() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 9eeca18e92..e1af605348 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -48,8 +48,9 @@ def main(cfg: DictConfig) -> None: utils.import_user_module(cfg.common) - assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ - 'Must specify batch size either with --max-tokens or --batch-size' + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() np.random.seed(cfg.common.seed) @@ -64,22 +65,24 @@ def main(cfg: DictConfig) -> None: # Setup task, e.g., translation, language modeling, etc. 
task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in cfg.dataset.valid_subset.split(','): + for valid_sub_split in cfg.dataset.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) + assert cfg.criterion, "Please specify criterion to train a model" + # Build model and criterion model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) - logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__)) - logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__)) + logger.info("task: {}".format(task.__class__.__name__)) + logger.info("model: {}".format(model.__class__.__name__)) + logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( - "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__) + "num. model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + ) ) - logger.info("num. model params: {} (num. 
trained: {})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - )) # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: @@ -97,11 +100,17 @@ def main(cfg: DictConfig) -> None: else: trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size)) - logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format( - cfg.dataset.max_tokens, - cfg.dataset.batch_size, - )) + logger.info( + "training on {} devices (GPUs/TPUs)".format( + cfg.distributed_training.distributed_world_size + ) + ) + logger.info( + "max tokens per GPU = {} and batch size per GPU = {}".format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + ) + ) # Load the latest checkpoint if one is available and restore the # corresponding train iterator @@ -116,10 +125,7 @@ def main(cfg: DictConfig) -> None: lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - while ( - lr > cfg.optimization.min_lr - and epoch_itr.next_epoch_idx <= max_epoch - ): + while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch: # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: @@ -157,14 +163,20 @@ def is_better(a, b): else: should_stop_early.num_runs += 1 if should_stop_early.num_runs >= cfg.checkpoint.patience: - logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience)) + logger.info( + "early stop since valid performance hasn't improved for last {} runs".format( + cfg.checkpoint.patience + ) + ) return True else: return False @metrics.aggregate("train") -def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]: +def train( + cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr 
+) -> Tuple[List[Optional[float]], bool]: """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( @@ -185,7 +197,9 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( - cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( @@ -195,7 +209,7 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) trainer.begin_epoch(epoch_itr.epoch) - valid_subsets = cfg.dataset.valid_subset.split(',') + valid_subsets = cfg.dataset.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): @@ -233,7 +247,14 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) return valid_losses, should_stop -def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]: +def validate_and_save( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + valid_subsets: List[str], + end_of_epoch: bool, +) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() max_update = cfg.optimization.max_update or math.inf do_save = ( @@ -268,14 +289,17 @@ def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask or num_updates >= max_update or ( cfg.optimization.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours + and trainer.cumulative_training_time() / (60 * 60) + > cfg.optimization.stop_time_hours ) ) # Save checkpoint if do_save or 
should_stop: logger.info("begin save checkpoint") - checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0]) + checkpoint_utils.save_checkpoint( + cfg.checkpoint, trainer, epoch_itr, valid_losses[0] + ) return valid_losses, should_stop @@ -285,7 +309,13 @@ def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: return stats -def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]: +def validate( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + subsets: List[str], +) -> List[Optional[float]]: """Evaluate the model on the validation set(s) and return the losses.""" if cfg.dataset.fixed_validation_seed is not None: @@ -308,7 +338,9 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i epoch=epoch_itr.epoch, prefix=f"valid on '{subset}' subset", tensorboard_logdir=( - cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( @@ -330,18 +362,23 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i return valid_losses -def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]: +def get_valid_stats( + cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any] +) -> Dict[str, Any]: stats["num_updates"] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, "best"): key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric] + checkpoint_utils.save_checkpoint.best, + 
stats[cfg.checkpoint.best_checkpoint_metric], ) return stats -def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None: +def cli_main( + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None +) -> None: parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) @@ -355,5 +392,5 @@ def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] distributed_utils.call_main(cfg, main) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py index 8de8e28ce0..ce4f1c055c 100644 --- a/tests/test_fp16_optimizer.py +++ b/tests/test_fp16_optimizer.py @@ -31,6 +31,9 @@ def setUp(self): self.cfg_dls = OmegaConf.create( { + "optimization": { + "lr": [0.1], + }, "optimizer": { "_name": "adam", "lr": [0.1], @@ -44,6 +47,7 @@ def setUp(self): "fp16_scale_tolerance": 1, "threshold_loss_scale": 1, "min_loss_scale": 1e-4, + "tpu": False, }, } ) From f57b14893837716bdaab4cb9a1430b19d4a6ccf7 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 5 Nov 2020 09:43:02 -0800 Subject: [PATCH 270/707] Require process group for all helpers in distributed_utils (#1395) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1395 Data parallel command: `python train.py --task dummy_lm --arch transformer_lm --tokens-per-sample 512 --max-sentences 8 --decoder-attention-heads 8 --dropout 0.0 --activation-dropout 0.0 --optimizer adam --lr 0.0001 --log-format simple --log-interval 1 --no-save --clip-norm 0.0` Data parallel before: ``` 2020-11-04 07:14:16 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 07:14:16 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 07:14:16 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last.pt 2020-11-04 07:14:16 | INFO | fairseq.trainer | 
loading train data for epoch 1 2020-11-04 07:14:16 | INFO | fairseq.trainer | NOTE: your device may support faster training with --fp16 2020-11-04 07:14:16 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 07:14:16 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 07:14:21 | INFO | train_inner | epoch 001: 1 / 1563 loss=16.297, ppl=80495, wps=0, ups=0, wpb=32768, bsz=64, num_updates=1, lr=0.0001, gnorm=2.501, train_wall=2, wall=5 2020-11-04 07:14:21 | INFO | train_inner | epoch 001: 2 / 1563 loss=15.399, ppl=43203.8, wps=101398, ups=3.09, wpb=32768, bsz=64, num_updates=2, lr=0.0001, gnorm=2.101, train_wall=0, wall=6 2020-11-04 07:14:21 | INFO | train_inner | epoch 001: 3 / 1563 loss=14.742, ppl=27411.2, wps=217567, ups=6.63, wpb=32768, bsz=64, num_updates=3, lr=0.0001, gnorm=1.888, train_wall=0, wall=6 2020-11-04 07:14:21 | INFO | train_inner | epoch 001: 4 / 1563 loss=14.206, ppl=18899.3, wps=219413, ups=6.69, wpb=32768, bsz=64, num_updates=4, lr=0.0001, gnorm=1.91, train_wall=0, wall=6 2020-11-04 07:14:22 | INFO | train_inner | epoch 001: 5 / 1563 loss=13.697, ppl=13282.1, wps=219446, ups=6.69, wpb=32768, bsz=64, num_updates=5, lr=0.0001, gnorm=1.98, train_wall=0, wall=6 2020-11-04 07:14:22 | INFO | train_inner | epoch 001: 6 / 1563 loss=13.179, ppl=9274.18, wps=220131, ups=6.71, wpb=32768, bsz=64, num_updates=6, lr=0.0001, gnorm=2.08, train_wall=0, wall=6 2020-11-04 07:14:22 | INFO | train_inner | epoch 001: 7 / 1563 loss=12.634, ppl=6358.37, wps=220236, ups=6.72, wpb=32768, bsz=64, num_updates=7, lr=0.0001, gnorm=2.195, train_wall=0, wall=6 2020-11-04 07:14:22 | INFO | train_inner | epoch 001: 8 / 1563 loss=12.056, ppl=4256.86, wps=220392, ups=6.72, wpb=32768, bsz=64, num_updates=8, lr=0.0001, gnorm=2.259, train_wall=0, wall=6 2020-11-04 07:14:22 | INFO | train_inner | epoch 001: 9 / 1563 loss=11.453, ppl=2804.05, wps=225842, ups=6.89, wpb=32768, bsz=64, num_updates=9, lr=0.0001, gnorm=2.287, train_wall=0, wall=7 2020-11-04 07:14:22 | INFO | 
train_inner | epoch 001: 10 / 1563 loss=10.842, ppl=1835, wps=238808, ups=7.28, wpb=32768, bsz=64, num_updates=10, lr=0.0001, gnorm=2.311, train_wall=0, wall=7 ``` Data parallel after: ``` 2020-11-04 07:14:47 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 07:14:47 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 07:14:47 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last.pt 2020-11-04 07:14:47 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 07:14:47 | INFO | fairseq.trainer | NOTE: your device may support faster training with --fp16 2020-11-04 07:14:47 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 07:14:47 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 07:14:52 | INFO | train_inner | epoch 001: 1 / 1563 loss=16.297, ppl=80495, wps=0, ups=0, wpb=32768, bsz=64, num_updates=1, lr=0.0001, gnorm=2.501, train_wall=2, wall=5 2020-11-04 07:14:52 | INFO | train_inner | epoch 001: 2 / 1563 loss=15.399, ppl=43203.8, wps=96089.4, ups=2.93, wpb=32768, bsz=64, num_updates=2, lr=0.0001, gnorm=2.101, train_wall=0, wall=5 2020-11-04 07:14:52 | INFO | train_inner | epoch 001: 3 / 1563 loss=14.742, ppl=27411.2, wps=239285, ups=7.3, wpb=32768, bsz=64, num_updates=3, lr=0.0001, gnorm=1.888, train_wall=0, wall=6 2020-11-04 07:14:53 | INFO | train_inner | epoch 001: 4 / 1563 loss=14.206, ppl=18899.3, wps=233039, ups=7.11, wpb=32768, bsz=64, num_updates=4, lr=0.0001, gnorm=1.91, train_wall=0, wall=6 2020-11-04 07:14:53 | INFO | train_inner | epoch 001: 5 / 1563 loss=13.697, ppl=13282.1, wps=237484, ups=7.24, wpb=32768, bsz=64, num_updates=5, lr=0.0001, gnorm=1.98, train_wall=0, wall=6 2020-11-04 07:14:53 | INFO | train_inner | epoch 001: 6 / 1563 loss=13.179, ppl=9274.18, wps=231683, ups=7.07, wpb=32768, bsz=64, num_updates=6, lr=0.0001, gnorm=2.08, train_wall=0, wall=6 2020-11-04 07:14:53 | INFO | train_inner | epoch 001: 7 / 1563 
loss=12.634, ppl=6358.37, wps=233804, ups=7.13, wpb=32768, bsz=64, num_updates=7, lr=0.0001, gnorm=2.195, train_wall=0, wall=6 2020-11-04 07:14:53 | INFO | train_inner | epoch 001: 8 / 1563 loss=12.056, ppl=4256.86, wps=234025, ups=7.14, wpb=32768, bsz=64, num_updates=8, lr=0.0001, gnorm=2.259, train_wall=0, wall=6 2020-11-04 07:14:53 | INFO | train_inner | epoch 001: 9 / 1563 loss=11.453, ppl=2804.05, wps=238426, ups=7.27, wpb=32768, bsz=64, num_updates=9, lr=0.0001, gnorm=2.287, train_wall=0, wall=6 2020-11-04 07:14:53 | INFO | train_inner | epoch 001: 10 / 1563 loss=10.842, ppl=1835, wps=240069, ups=7.32, wpb=32768, bsz=64, num_updates=10, lr=0.0001, gnorm=2.311, train_wall=0, wall=6 ``` Model parallel command: `python train.py --task dummy_lm --arch transformer_lm_megatron --decoder-layers 2 --batch-size 2 --tokens-per-sample 512 --log-format simple --log-interval 1 --fp16 --optimizer adam --model-parallel-size 2 --share-decoder-input-output-embed --lr 0.0001` Model parallel before: ``` 2020-11-04 07:12:22 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 07:12:22 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 2 2020-11-04 07:12:22 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last-model_part-0.pt 2020-11-04 07:12:22 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 07:12:23 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 07:12:23 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 07:12:28 | INFO | train_inner | epoch 001: 1 / 12500 loss=60.017, ppl=1.16627e+18, wps=0, ups=0, wpb=4096, bsz=8, num_updates=1, lr=0.0001, gnorm=8.531, loss_scale=128, train_wall=2, wall=6 2020-11-04 07:12:28 | INFO | train_inner | epoch 001: 2 / 12500 loss=46.473, ppl=9.77028e+13, wps=48996.6, ups=11.95, wpb=4096, bsz=8, num_updates=2, lr=0.0001, gnorm=15.019, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:28 | INFO | train_inner | epoch 001: 3 / 
12500 loss=30.525, ppl=1.54543e+09, wps=58424.2, ups=14.25, wpb=4096, bsz=8, num_updates=3, lr=0.0001, gnorm=13.936, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:28 | INFO | train_inner | epoch 001: 4 / 12500 loss=18.561, ppl=386799, wps=58399.5, ups=14.24, wpb=4096, bsz=8, num_updates=4, lr=0.0001, gnorm=7.251, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:28 | INFO | train_inner | epoch 001: 5 / 12500 loss=15.145, ppl=36230, wps=58275.6, ups=14.21, wpb=4096, bsz=8, num_updates=5, lr=0.0001, gnorm=2.392, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:28 | INFO | train_inner | epoch 001: 6 / 12500 loss=14.683, ppl=26304.2, wps=58704.8, ups=14.32, wpb=4096, bsz=8, num_updates=6, lr=0.0001, gnorm=2.487, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:28 | INFO | train_inner | epoch 001: 7 / 12500 loss=14.169, ppl=18418.9, wps=58449.2, ups=14.26, wpb=4096, bsz=8, num_updates=7, lr=0.0001, gnorm=2.45, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:29 | INFO | train_inner | epoch 001: 8 / 12500 loss=13.574, ppl=12197.4, wps=59106.5, ups=14.42, wpb=4096, bsz=8, num_updates=8, lr=0.0001, gnorm=2.393, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:29 | INFO | train_inner | epoch 001: 9 / 12500 loss=12.974, ppl=8047.87, wps=58619.6, ups=14.3, wpb=4096, bsz=8, num_updates=9, lr=0.0001, gnorm=2.317, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:12:29 | INFO | train_inner | epoch 001: 10 / 12500 loss=12.341, ppl=5187.55, wps=58166.5, ups=14.19, wpb=4096, bsz=8, num_updates=10, lr=0.0001, gnorm=2.213, loss_scale=128, train_wall=0, wall=6 ``` Model parallel after: ``` 2020-11-04 07:11:07 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 07:11:07 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 2 2020-11-04 07:11:07 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last-model_part-0.pt 2020-11-04 07:11:07 | INFO | fairseq.trainer | loading train 
data for epoch 1 2020-11-04 07:11:08 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 07:11:08 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 07:11:13 | INFO | train_inner | epoch 001: 1 / 12500 loss=60.017, ppl=1.16627e+18, wps=0, ups=0, wpb=4096, bsz=8, num_updates=1, lr=0.0001, gnorm=8.531, loss_scale=128, train_wall=2, wall=6 2020-11-04 07:11:13 | INFO | train_inner | epoch 001: 2 / 12500 loss=46.473, ppl=9.77028e+13, wps=47018.1, ups=11.47, wpb=4096, bsz=8, num_updates=2, lr=0.0001, gnorm=15.019, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:11:13 | INFO | train_inner | epoch 001: 3 / 12500 loss=30.525, ppl=1.54543e+09, wps=59292.6, ups=14.46, wpb=4096, bsz=8, num_updates=3, lr=0.0001, gnorm=13.936, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:11:13 | INFO | train_inner | epoch 001: 4 / 12500 loss=18.561, ppl=386799, wps=57708.9, ups=14.08, wpb=4096, bsz=8, num_updates=4, lr=0.0001, gnorm=7.251, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:11:14 | INFO | train_inner | epoch 001: 5 / 12500 loss=15.145, ppl=36230, wps=57427.4, ups=14.01, wpb=4096, bsz=8, num_updates=5, lr=0.0001, gnorm=2.392, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:11:14 | INFO | train_inner | epoch 001: 6 / 12500 loss=14.683, ppl=26304.2, wps=58730.2, ups=14.33, wpb=4096, bsz=8, num_updates=6, lr=0.0001, gnorm=2.487, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:11:14 | INFO | train_inner | epoch 001: 7 / 12500 loss=14.169, ppl=18418.9, wps=59523.2, ups=14.52, wpb=4096, bsz=8, num_updates=7, lr=0.0001, gnorm=2.45, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:11:14 | INFO | train_inner | epoch 001: 8 / 12500 loss=13.574, ppl=12197.4, wps=58945.2, ups=14.38, wpb=4096, bsz=8, num_updates=8, lr=0.0001, gnorm=2.393, loss_scale=128, train_wall=0, wall=6 2020-11-04 07:11:14 | INFO | train_inner | epoch 001: 9 / 12500 loss=12.974, ppl=8047.87, wps=59659.2, ups=14.55, wpb=4096, bsz=8, num_updates=9, lr=0.0001, gnorm=2.317, loss_scale=128, 
train_wall=0, wall=7 2020-11-04 07:11:14 | INFO | train_inner | epoch 001: 10 / 12500 loss=12.341, ppl=5187.55, wps=59681.4, ups=14.56, wpb=4096, bsz=8, num_updates=10, lr=0.0001, gnorm=2.213, loss_scale=128, train_wall=0, wall=7 ``` Test Plan: Imported from OSS Reviewed By: ngoyal2707 Differential Revision: D24728687 Pulled By: myleott fbshipit-source-id: 2d387d022ee889494f429b98df1942167896e306 --- .../multilingual/sampled_multi_dataset.py | 8 ++- fairseq/distributed_utils.py | 72 +++++++++++++------ fairseq/model_parallel/megatron_trainer.py | 12 ++-- fairseq/optim/fp16_optimizer.py | 5 +- fairseq/trainer.py | 7 +- fairseq_cli/validate.py | 1 + 6 files changed, 74 insertions(+), 31 deletions(-) diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index 3f544b099f..599f3a862b 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -160,9 +160,13 @@ def _sync_sample_ratios(self, ratios): ratios = torch.DoubleTensor(ratios) if torch.distributed.is_initialized(): if torch.cuda.is_available(): - distributed_utils.all_reduce(ratios.cuda()) + distributed_utils.all_reduce( + ratios.cuda(), group=distributed_utils.get_data_parallel_group() + ) else: - distributed_utils.all_reduce(ratios) + distributed_utils.all_reduce( + ratios, group=distributed_utils.get_data_parallel_group() + ) ret = ratios.cpu() ret = ret.numpy() return ret diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 3439508c94..9285f71e35 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -23,6 +23,11 @@ from omegaconf import open_dict +# Flag to indicate if we're using Megatron +# NOTE: this is a temporary hack until we move away from Megatron's model parallel init +_USE_MEGATRON = False + + logger = logging.getLogger(__name__) @@ -262,6 +267,8 @@ def distributed_init(cfg: FairseqConfig): "\n\n git submodule update --init 
" "fairseq/model_parallel/megatron" ) + global _USE_MEGATRON + _USE_MEGATRON = True initialize_model_parallel(cfg.common.model_parallel_size) model_parallel_cuda_manual_seed(cfg.common.seed) model_part_number = get_model_parallel_rank() @@ -319,26 +326,49 @@ def call_main(cfg: FairseqConfig, main, **kwargs): main(cfg, **kwargs) -def get_rank(): - return dist.get_rank() +def get_rank(group): + return dist.get_rank(group=group) + +def get_world_size(group): + return dist.get_world_size(group=group) -def get_world_size(): - return dist.get_world_size() + +def get_global_group(): + if torch.distributed.is_initialized(): + if not hasattr(get_global_group, "_global_group"): + # ideally we could use torch.distributed.group.WORLD, but it seems + # to cause random NCCL hangs in some cases + get_global_group._global_group = dist.new_group() + return get_global_group._global_group + else: + return None -def get_default_group(): - return dist.group.WORLD +def get_data_parallel_group(): + global _USE_MEGATRON + if _USE_MEGATRON: + from fairseq.model_parallel.megatron import mpu + return mpu.get_data_parallel_group() + else: + return get_global_group() + + +def get_model_parallel_group(): + global _USE_MEGATRON + if _USE_MEGATRON: + from fairseq.model_parallel.megatron import mpu + return mpu.get_model_parallel_group() + else: + return None -def all_reduce(tensor, group=None): +def all_reduce(tensor, group): if isinstance(group, tuple) and group[0] == "tpu": import torch_xla.core.xla_model as xm return xm.all_reduce("sum", [tensor], groups=group[1]) else: - if group is None: - group = get_default_group() return dist.all_reduce(tensor, group=group) @@ -350,12 +380,14 @@ def all_gather_list(data, group=None, max_size=16384): Args: data (Any): data from the local worker to be gathered on other workers - group (optional): group of the collective + group: group of the collective max_size (int, optional): maximum size of the data to be gathered across workers """ - rank = get_rank() - 
world_size = get_world_size() + if group is None: + group = get_global_group() + rank = get_rank(group=group) + world_size = get_world_size(group=group) buffer_size = max_size * world_size if ( @@ -410,7 +442,7 @@ def all_gather_list(data, group=None, max_size=16384): ) -def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]: +def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]: """ AllReduce a dictionary of values across workers. We separately reduce items that are already on the device and items on CPU for @@ -420,7 +452,7 @@ def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, An data (Mapping[str, Any]): dictionary of data to all-reduce, but cannot be a nested dictionary device (torch.device): device for the reduction - group (optional): group of the collective + group: group of the collective """ data_keys = list(data.keys()) @@ -463,7 +495,7 @@ def get_from_stack(key): def broadcast_object( obj: Any, src_rank: int, - group: object = dist.group.WORLD, + group: object, dist_device: Optional[torch.device] = None, ) -> Any: """ @@ -476,23 +508,23 @@ def broadcast_object( else: dist_device = torch.device("cpu") - if dist.get_rank() == src_rank: + if get_rank(group) == src_rank: # Emit data buffer = io.BytesIO() torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) - length_tensor = torch.LongTensor([len(data)]).to(dist_device) - data_send_tensor = torch.ByteTensor(data).to(dist_device) + length_tensor = torch.tensor([len(data)], dtype=torch.long, device=dist_device) + data_send_tensor = torch.tensor(data, dtype=torch.uint8, device=dist_device) dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) else: # Fetch from the source - length_tensor = torch.LongTensor([0]).to(dist_device) + length_tensor = torch.tensor([0], dtype=torch.long, device=dist_device) dist.broadcast(length_tensor, 
src=src_rank, group=group, async_op=False) data_recv_tensor = torch.empty( [int(length_tensor.item())], dtype=torch.uint8, device=dist_device ) dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) - obj = torch.load(buffer, map_location=dist_device) + obj = torch.load(buffer, map_location="cpu") return obj diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index a8c7bc9d98..b86b3d14de 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -13,10 +13,8 @@ try: from fairseq.model_parallel.megatron.mpu import ( - get_data_parallel_group, get_data_parallel_rank, get_data_parallel_world_size, - get_model_parallel_group, get_model_parallel_src_rank, get_cuda_rng_tracker, ) @@ -44,7 +42,7 @@ def data_parallel_world_size(self): @property def data_parallel_process_group(self): - return get_data_parallel_group() + return distributed_utils.get_data_parallel_group() @property def data_parallel_rank(self): @@ -57,7 +55,9 @@ def is_data_parallel_master(self): def clip_grad_norm(self, clip_norm): def _aggregate_model_parallel_grad_norm(total_norm): total_norm = total_norm ** 2 - distributed_utils.all_reduce(total_norm, group=get_model_parallel_group()) + distributed_utils.all_reduce( + total_norm, group=distributed_utils.get_model_parallel_group() + ) total_norm = total_norm ** 0.5 return total_norm @@ -71,7 +71,7 @@ def save_checkpoint(self, filename, extra_state): extra_state['rng_tracker_states'] \ = get_cuda_rng_tracker().get_states() super().save_checkpoint(filename, extra_state) - + def load_checkpoint( self, filename, @@ -84,4 +84,4 @@ def load_checkpoint( if extra_state is not None and 'rng_tracker_states' in extra_state: get_cuda_rng_tracker().set_states( extra_state['rng_tracker_states']) - return extra_state \ No newline at end of file + return extra_state diff --git 
a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 8ef61a6a7e..2341f47077 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -127,7 +127,10 @@ def _sync_fp16_grads_to_fp32(self): if not p.requires_grad: continue if p.grad is not None: - p32.grad.data.copy_(p.grad.data) + if p32.grad is None: + p32.grad = p.grad.data.float() + else: + p32.grad.data.copy_(p.grad.data) else: p32.grad = torch.zeros_like(p.data, dtype=torch.float) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 657374aab7..c37ea5cbee 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -116,7 +116,9 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): if self.cuda: self.cuda_env = utils.CudaEnvironment() if self.data_parallel_world_size > 1: - self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env) + self.cuda_env_arr = distributed_utils.all_gather_list( + self.cuda_env, group=distributed_utils.get_global_group() + ) else: self.cuda_env_arr = [self.cuda_env] if self.data_parallel_rank == 0: @@ -147,7 +149,7 @@ def data_parallel_process_group(self): if self.tpu: return ("tpu", None) else: - return None + return distributed_utils.get_data_parallel_group() @property def data_parallel_rank(self): @@ -325,6 +327,7 @@ def load_checkpoint( state, src_rank=0, group=group, + dist_device=self.device, ) if self.data_parallel_rank > 0: last_optim_state = state.get("last_optimizer_state", None) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index a1e577ed7a..36e8bd16ca 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -115,6 +115,7 @@ def main(cfg: DictConfig, override_args=None): log_outputs = distributed_utils.all_gather_list( log_outputs, max_size=cfg.common.all_gather_list_size, + group=distributed_utils.get_data_parallel_group(), ) log_outputs = list(chain.from_iterable(log_outputs)) From b4d57c6d49682094efe22fbe2c03fa2c4973869f Mon Sep 17 00:00:00 2001 From: 
Myle Ott Date: Thu, 5 Nov 2020 15:17:32 -0800 Subject: [PATCH 271/707] Move TPU grad reductions out of Trainer into TPUDistributedDataParallel (#1397) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1397 Data parallel command: `python train.py ~/data/data-bin/wikitext-103-roberta-bpe-bin/ --task language_modeling --arch transformer_lm --batch-size 8 --tokens-per-sample 512 --log-format simple --log-interval 1 --fp16 --optimizer adam --share-decoder-input-output-embed --lr 0.0001` Data parallel before: ``` 2020-11-04 08:20:13 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 08:20:13 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 08:20:13 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last.pt 2020-11-04 08:20:13 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 08:20:14 | INFO | fairseq.data.data_utils | loaded 1801350 examples from: /private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/train 2020-11-04 08:20:14 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 08:20:14 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 08:20:19 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-11-04 08:20:19 | INFO | train_inner | epoch 001: 2 / 3587 loss=19.682, ppl=841142, wps=0, ups=0, wpb=32768, bsz=64, num_updates=1, lr=0.0001, gnorm=13.17, loss_scale=64, train_wall=0, wall=5 2020-11-04 08:20:19 | INFO | train_inner | epoch 001: 3 / 3587 loss=16.721, ppl=108002, wps=160870, ups=4.91, wpb=32768, bsz=64, num_updates=2, lr=0.0001, gnorm=4.507, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:20:19 | INFO | train_inner | epoch 001: 4 / 3587 loss=16.07, ppl=68785.8, wps=517232, ups=15.77, wpb=32768, bsz=64, num_updates=3, lr=0.0001, gnorm=2.737, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:20:19 | INFO | train_inner | epoch 001: 5 / 3587 loss=15.714, ppl=53741.4, 
wps=537322, ups=16.38, wpb=32768, bsz=64, num_updates=4, lr=0.0001, gnorm=2.542, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:20:19 | INFO | train_inner | epoch 001: 6 / 3587 loss=15.441, ppl=44492.1, wps=540488, ups=16.48, wpb=32768, bsz=64, num_updates=5, lr=0.0001, gnorm=2.485, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:20:19 | INFO | train_inner | epoch 001: 7 / 3587 loss=15.199, ppl=37603.2, wps=543411, ups=16.57, wpb=32768, bsz=64, num_updates=6, lr=0.0001, gnorm=2.382, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:20:19 | INFO | train_inner | epoch 001: 8 / 3587 loss=14.984, ppl=32414, wps=540359, ups=16.47, wpb=32768, bsz=64, num_updates=7, lr=0.0001, gnorm=2.274, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:20:20 | INFO | train_inner | epoch 001: 9 / 3587 loss=14.7, ppl=26622.2, wps=533446, ups=16.26, wpb=32768, bsz=64, num_updates=8, lr=0.0001, gnorm=2.16, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:20:20 | INFO | train_inner | epoch 001: 10 / 3587 loss=14.482, ppl=22875.4, wps=539734, ups=16.46, wpb=32768, bsz=64, num_updates=9, lr=0.0001, gnorm=2.055, loss_scale=64, train_wall=0, wall=6 ``` Data parallel after: ``` 2020-11-04 08:14:02 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 08:14:02 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 08:14:02 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last.pt 2020-11-04 08:14:02 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 08:14:03 | INFO | fairseq.data.data_utils | loaded 1801350 examples from: /private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/train 2020-11-04 08:14:03 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 08:14:03 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 08:14:08 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-11-04 08:14:08 | INFO | train_inner | epoch 001: 2 / 3587 
loss=19.682, ppl=841142, wps=0, ups=0, wpb=32768, bsz=64, num_updates=1, lr=0.0001, gnorm=13.17, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:08 | INFO | train_inner | epoch 001: 3 / 3587 loss=16.721, ppl=108002, wps=157099, ups=4.79, wpb=32768, bsz=64, num_updates=2, lr=0.0001, gnorm=4.507, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:08 | INFO | train_inner | epoch 001: 4 / 3587 loss=16.07, ppl=68785.8, wps=560049, ups=17.08, wpb=32768, bsz=64, num_updates=3, lr=0.0001, gnorm=2.737, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:08 | INFO | train_inner | epoch 001: 5 / 3587 loss=15.714, ppl=53741.4, wps=558507, ups=17.03, wpb=32768, bsz=64, num_updates=4, lr=0.0001, gnorm=2.542, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:08 | INFO | train_inner | epoch 001: 6 / 3587 loss=15.441, ppl=44492.1, wps=514194, ups=15.68, wpb=32768, bsz=64, num_updates=5, lr=0.0001, gnorm=2.485, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:08 | INFO | train_inner | epoch 001: 7 / 3587 loss=15.199, ppl=37603.2, wps=552676, ups=16.85, wpb=32768, bsz=64, num_updates=6, lr=0.0001, gnorm=2.382, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:09 | INFO | train_inner | epoch 001: 8 / 3587 loss=14.984, ppl=32414, wps=546402, ups=16.66, wpb=32768, bsz=64, num_updates=7, lr=0.0001, gnorm=2.274, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:09 | INFO | train_inner | epoch 001: 9 / 3587 loss=14.7, ppl=26622.2, wps=508472, ups=15.5, wpb=32768, bsz=64, num_updates=8, lr=0.0001, gnorm=2.16, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:09 | INFO | train_inner | epoch 001: 10 / 3587 loss=14.482, ppl=22875.4, wps=552493, ups=16.84, wpb=32768, bsz=64, num_updates=9, lr=0.0001, gnorm=2.055, loss_scale=64, train_wall=0, wall=6 ``` Data parallel command (no_c10d): `python train.py ~/data/data-bin/wikitext-103-roberta-bpe-bin/ --task language_modeling --arch transformer_lm --batch-size 8 --tokens-per-sample 512 --log-format simple --log-interval 1 
--fp16 --optimizer adam --share-decoder-input-output-embed --lr 0.0001 --dp-backend no_c10d` Data parallel before: ``` 2020-11-04 08:19:25 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 08:19:25 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 08:19:25 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last.pt 2020-11-04 08:19:25 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 08:19:25 | INFO | fairseq.data.data_utils | loaded 1801350 examples from: /private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/train 2020-11-04 08:19:26 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 08:19:26 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 08:19:31 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-11-04 08:19:31 | INFO | train_inner | epoch 001: 2 / 3587 loss=19.682, ppl=841142, wps=0, ups=0, wpb=32768, bsz=64, num_updates=1, lr=0.0001, gnorm=13.17, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 3 / 3587 loss=16.721, ppl=108001, wps=141659, ups=4.32, wpb=32768, bsz=64, num_updates=2, lr=0.0001, gnorm=4.507, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 4 / 3587 loss=16.07, ppl=68785.9, wps=503762, ups=15.36, wpb=32768, bsz=64, num_updates=3, lr=0.0001, gnorm=2.737, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 5 / 3587 loss=15.714, ppl=53741.5, wps=488599, ups=14.9, wpb=32768, bsz=64, num_updates=4, lr=0.0001, gnorm=2.542, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 6 / 3587 loss=15.441, ppl=44492, wps=507855, ups=15.48, wpb=32768, bsz=64, num_updates=5, lr=0.0001, gnorm=2.485, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 7 / 3587 loss=15.199, ppl=37603, wps=503270, 
ups=15.34, wpb=32768, bsz=64, num_updates=6, lr=0.0001, gnorm=2.382, loss_scale=64, train_wall=0, wall=7 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 8 / 3587 loss=14.984, ppl=32414, wps=467778, ups=14.26, wpb=32768, bsz=64, num_updates=7, lr=0.0001, gnorm=2.274, loss_scale=64, train_wall=0, wall=7 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 9 / 3587 loss=14.7, ppl=26622.2, wps=503800, ups=15.36, wpb=32768, bsz=64, num_updates=8, lr=0.0001, gnorm=2.16, loss_scale=64, train_wall=0, wall=7 2020-11-04 08:19:32 | INFO | train_inner | epoch 001: 10 / 3587 loss=14.482, ppl=22875.3, wps=468486, ups=14.28, wpb=32768, bsz=64, num_updates=9, lr=0.0001, gnorm=2.055, loss_scale=64, train_wall=0, wall=7 ``` Data parallel after: ``` 2020-11-04 08:14:50 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 08:14:50 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 08:14:50 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last.pt 2020-11-04 08:14:50 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 08:14:50 | INFO | fairseq.data.data_utils | loaded 1801350 examples from: /private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/train 2020-11-04 08:14:51 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 08:14:51 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 08:14:56 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-11-04 08:14:56 | INFO | train_inner | epoch 001: 2 / 3587 loss=19.682, ppl=841142, wps=0, ups=0, wpb=32768, bsz=64, num_updates=1, lr=0.0001, gnorm=13.17, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:56 | INFO | train_inner | epoch 001: 3 / 3587 loss=16.721, ppl=108001, wps=137677, ups=4.2, wpb=32768, bsz=64, num_updates=2, lr=0.0001, gnorm=4.507, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:56 | INFO | train_inner | epoch 001: 4 / 3587 loss=16.07, ppl=68785.9, 
wps=519541, ups=15.84, wpb=32768, bsz=64, num_updates=3, lr=0.0001, gnorm=2.737, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:56 | INFO | train_inner | epoch 001: 5 / 3587 loss=15.714, ppl=53741.5, wps=517063, ups=15.76, wpb=32768, bsz=64, num_updates=4, lr=0.0001, gnorm=2.542, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:56 | INFO | train_inner | epoch 001: 6 / 3587 loss=15.441, ppl=44492, wps=490728, ups=14.95, wpb=32768, bsz=64, num_updates=5, lr=0.0001, gnorm=2.485, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:56 | INFO | train_inner | epoch 001: 7 / 3587 loss=15.199, ppl=37603, wps=505262, ups=15.41, wpb=32768, bsz=64, num_updates=6, lr=0.0001, gnorm=2.382, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:56 | INFO | train_inner | epoch 001: 8 / 3587 loss=14.984, ppl=32414, wps=508874, ups=15.52, wpb=32768, bsz=64, num_updates=7, lr=0.0001, gnorm=2.274, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:57 | INFO | train_inner | epoch 001: 9 / 3587 loss=14.7, ppl=26622.2, wps=518028, ups=15.79, wpb=32768, bsz=64, num_updates=8, lr=0.0001, gnorm=2.16, loss_scale=64, train_wall=0, wall=6 2020-11-04 08:14:57 | INFO | train_inner | epoch 001: 10 / 3587 loss=14.482, ppl=22875.3, wps=515996, ups=15.73, wpb=32768, bsz=64, num_updates=9, lr=0.0001, gnorm=2.055, loss_scale=64, train_wall=0, wall=7 ``` Model parallel command: `python train.py ~/data/data-bin/wikitext-103-roberta-bpe-bin/ --task language_modeling --arch transformer_lm_megatron --decoder-layers 4 --batch-size 8 --tokens-per-sample 512 --log-format simple --log-interval 1 --fp16 --optimizer adam --model-parallel-size 2 --share-decoder-input-output-embed --lr 0.0001` Model parallel before: ``` 2020-11-04 08:18:38 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 08:18:38 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 08:18:38 | INFO | fairseq.trainer | no existing checkpoint found 
checkpoints/checkpoint_last-model_part-0.pt 2020-11-04 08:18:38 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 08:18:38 | INFO | fairseq.data.data_utils | loaded 1801350 examples from: /private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/train 2020-11-04 08:18:39 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 08:18:39 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 08:18:44 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-11-04 08:18:45 | INFO | train_inner | epoch 001: 2 / 7173 loss=55.997, ppl=7.19017e+16, wps=0, ups=0, wpb=16384, bsz=32, num_updates=1, lr=0.0001, gnorm=14.03, loss_scale=64, train_wall=1, wall=7 2020-11-04 08:18:45 | INFO | train_inner | epoch 001: 3 / 7173 loss=28.372, ppl=3.47501e+08, wps=48371.7, ups=2.95, wpb=16384, bsz=32, num_updates=2, lr=0.0001, gnorm=15.339, loss_scale=64, train_wall=0, wall=8 2020-11-04 08:18:46 | INFO | train_inner | epoch 001: 4 / 7173 loss=15.855, ppl=59276.8, wps=72422.5, ups=4.42, wpb=16384, bsz=32, num_updates=3, lr=0.0001, gnorm=4.189, loss_scale=64, train_wall=0, wall=8 2020-11-04 08:18:46 | INFO | train_inner | epoch 001: 5 / 7173 loss=14.713, ppl=26858.7, wps=72933.5, ups=4.45, wpb=16384, bsz=32, num_updates=4, lr=0.0001, gnorm=4.751, loss_scale=64, train_wall=0, wall=8 2020-11-04 08:18:46 | INFO | train_inner | epoch 001: 6 / 7173 loss=13.901, ppl=15299.7, wps=71974.8, ups=4.39, wpb=16384, bsz=32, num_updates=5, lr=0.0001, gnorm=4.361, loss_scale=64, train_wall=0, wall=8 2020-11-04 08:18:46 | INFO | train_inner | epoch 001: 7 / 7173 loss=13.312, ppl=10169.5, wps=72897.8, ups=4.45, wpb=16384, bsz=32, num_updates=6, lr=0.0001, gnorm=3.307, loss_scale=64, train_wall=0, wall=9 2020-11-04 08:18:47 | INFO | train_inner | epoch 001: 8 / 7173 loss=12.914, ppl=7720.21, wps=73044.6, ups=4.46, wpb=16384, bsz=32, num_updates=7, lr=0.0001, gnorm=5.473, loss_scale=64, train_wall=0, wall=9 2020-11-04 08:18:47 | INFO | train_inner | 
epoch 001: 9 / 7173 loss=12.56, ppl=6036.72, wps=73453.1, ups=4.48, wpb=16384, bsz=32, num_updates=8, lr=0.0001, gnorm=6.112, loss_scale=64, train_wall=0, wall=9 2020-11-04 08:18:47 | INFO | train_inner | epoch 001: 10 / 7173 loss=12.116, ppl=4437.77, wps=73442.6, ups=4.48, wpb=16384, bsz=32, num_updates=9, lr=0.0001, gnorm=4.415, loss_scale=64, train_wall=0, wall=9 ``` Model parallel after: ``` 2020-11-04 08:12:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) 2020-11-04 08:12:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 2020-11-04 08:12:09 | INFO | fairseq.trainer | no existing checkpoint found checkpoints/checkpoint_last-model_part-0.pt 2020-11-04 08:12:09 | INFO | fairseq.trainer | loading train data for epoch 1 2020-11-04 08:12:09 | INFO | fairseq.data.data_utils | loaded 1801350 examples from: /private/home/myleott/data/data-bin/wikitext-103-roberta-bpe-bin/train 2020-11-04 08:12:10 | INFO | fairseq.optim.adam | using FusedAdam 2020-11-04 08:12:10 | INFO | fairseq.trainer | begin training epoch 1 2020-11-04 08:12:16 | INFO | fairseq.trainer | NOTE: overflow detected, setting loss scale to: 64.0 2020-11-04 08:12:17 | INFO | train_inner | epoch 001: 2 / 7173 loss=55.997, ppl=7.19017e+16, wps=0, ups=0, wpb=16384, bsz=32, num_updates=1, lr=0.0001, gnorm=14.03, loss_scale=64, train_wall=1, wall=8 2020-11-04 08:12:17 | INFO | train_inner | epoch 001: 3 / 7173 loss=28.372, ppl=3.47501e+08, wps=53097, ups=3.24, wpb=16384, bsz=32, num_updates=2, lr=0.0001, gnorm=15.339, loss_scale=64, train_wall=0, wall=8 2020-11-04 08:12:17 | INFO | train_inner | epoch 001: 4 / 7173 loss=15.855, ppl=59276.8, wps=72355.5, ups=4.42, wpb=16384, bsz=32, num_updates=3, lr=0.0001, gnorm=4.189, loss_scale=64, train_wall=0, wall=8 2020-11-04 08:12:17 | INFO | train_inner | epoch 001: 5 / 7173 loss=14.713, ppl=26858.7, wps=70526.4, ups=4.3, wpb=16384, bsz=32, num_updates=4, lr=0.0001, gnorm=4.751, loss_scale=64, train_wall=0, wall=9 
2020-11-04 08:12:18 | INFO | train_inner | epoch 001: 6 / 7173 loss=13.901, ppl=15299.7, wps=73063.5, ups=4.46, wpb=16384, bsz=32, num_updates=5, lr=0.0001, gnorm=4.361, loss_scale=64, train_wall=0, wall=9 2020-11-04 08:12:18 | INFO | train_inner | epoch 001: 7 / 7173 loss=13.312, ppl=10169.5, wps=73559.4, ups=4.49, wpb=16384, bsz=32, num_updates=6, lr=0.0001, gnorm=3.307, loss_scale=64, train_wall=0, wall=9 2020-11-04 08:12:18 | INFO | train_inner | epoch 001: 8 / 7173 loss=12.914, ppl=7720.21, wps=72693.2, ups=4.44, wpb=16384, bsz=32, num_updates=7, lr=0.0001, gnorm=5.473, loss_scale=64, train_wall=0, wall=9 2020-11-04 08:12:18 | INFO | train_inner | epoch 001: 9 / 7173 loss=12.56, ppl=6036.72, wps=73531.2, ups=4.49, wpb=16384, bsz=32, num_updates=8, lr=0.0001, gnorm=6.112, loss_scale=64, train_wall=0, wall=9 2020-11-04 08:12:19 | INFO | train_inner | epoch 001: 10 / 7173 loss=12.116, ppl=4437.77, wps=73187.6, ups=4.47, wpb=16384, bsz=32, num_updates=9, lr=0.0001, gnorm=4.415, loss_scale=64, train_wall=0, wall=10 ``` Test Plan: Imported from OSS Reviewed By: ngoyal2707 Differential Revision: D24729295 Pulled By: myleott fbshipit-source-id: beee8bdece3eaa0419a2e813990420411e507c75 --- README.md | 4 +- fairseq/checkpoint_utils.py | 8 +- fairseq/distributed_utils.py | 192 +++++++++++++++++--- fairseq/legacy_distributed_data_parallel.py | 12 +- fairseq/model_parallel/megatron_trainer.py | 16 -- fairseq/models/distributed_fairseq_model.py | 50 ++++- fairseq/optim/fairseq_optimizer.py | 5 + fairseq/optim/fp16_optimizer.py | 6 + fairseq/trainer.py | 64 +++---- fairseq/utils.py | 2 +- 10 files changed, 272 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index 70e98fe395..13b8223959 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ We provide reference implementations of various sequence modeling papers: - Attention Is All You Need (Vaswani et al., 2017) - [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - 
[Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) - - [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/transformer_lm/README.md) + - [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) - [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) @@ -171,7 +171,7 @@ We also have more detailed READMEs to reproduce results from specific papers: - [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) - [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) -- [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) +- [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) # Join the fairseq community diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index fdee84c181..25d3e1e705 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -13,6 +13,7 @@ from typing import Optional, Union import torch +from fairseq import utils from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, @@ -225,9 +226,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): def load_checkpoint_to_cpu(path, arg_overrides=None): """Loads a checkpoint to CPU (with upgrading for 
backward compatibility).""" with open(PathManager.get_local_path(path), "rb") as f: - state = torch.load( - f, map_location=lambda s, l: default_restore_location(s, "cpu") - ) + state = torch.load(f, map_location=torch.device("cpu")) if "args" in state and state["args"] is not None and arg_overrides is not None: args = state["args"] @@ -385,6 +384,9 @@ def save_state( if not no_save_optimizer_state: state_dict["last_optimizer_state"] = optimizer.state_dict() + # keep everything on CPU + state_dict = utils.move_to_cpu(state_dict) + with PathManager.open(filename, "wb") as f: torch_persistent_save(state_dict, f) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 9285f71e35..9059d8aa2b 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -14,7 +14,7 @@ import warnings from argparse import Namespace from collections import OrderedDict -from typing import Any, Dict, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional import torch import torch.distributed as dist @@ -22,11 +22,19 @@ from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig from omegaconf import open_dict +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + # Flag to indicate if we're using Megatron # NOTE: this is a temporary hack until we move away from Megatron's model parallel init _USE_MEGATRON = False +# Whether to use XLA ops (e.g., on TPUs) instead of CUDA ops. 
+_USE_XLA = False + logger = logging.getLogger(__name__) @@ -241,9 +249,9 @@ def distributed_init(cfg: FairseqConfig): cfg.distributed_training.distributed_rank = torch.distributed.get_rank() else: - import torch_xla.core.xla_model as xm - assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size + global _USE_XLA + _USE_XLA = True cfg.distributed_training.device_id = xm.get_local_ordinal() cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers @@ -257,7 +265,6 @@ def distributed_init(cfg: FairseqConfig): if cfg.common.model_parallel_size > 1: try: from fairseq.model_parallel.megatron.mpu import ( - get_model_parallel_rank, initialize_model_parallel, model_parallel_cuda_manual_seed, ) @@ -273,6 +280,7 @@ def distributed_init(cfg: FairseqConfig): model_parallel_cuda_manual_seed(cfg.common.seed) model_part_number = get_model_parallel_rank() cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + return cfg.distributed_training.distributed_rank @@ -326,16 +334,55 @@ def call_main(cfg: FairseqConfig, main, **kwargs): main(cfg, **kwargs) +def use_xla(): + global _USE_XLA + return _USE_XLA + + +def new_groups(grouped_ranks: List[List[int]]): + if use_xla(): + return ("tpu", grouped_ranks) + else: + groups = [dist.new_group(g) for g in grouped_ranks] + my_group_idx = _find_my_group_index(grouped_ranks) + return groups[my_group_idx] + + +def _find_my_group_index(grouped_ranks): + my_rank = get_global_rank() + for i, group in enumerate(grouped_ranks): + if my_rank in group: + return i + raise RuntimeError + + +def _find_my_group(grouped_ranks): + index = _find_my_group_index(grouped_ranks) + return grouped_ranks[index] + + def get_rank(group): - return dist.get_rank(group=group) + if use_xla(): + assert group[0] == "tpu" + my_group = _find_my_group(group[1]) + return my_group.index(get_global_rank()) + else: + return dist.get_rank(group=group) def 
get_world_size(group): - return dist.get_world_size(group=group) + if use_xla(): + assert group[0] == "tpu" + my_group = _find_my_group(group[1]) + return len(my_group) + else: + return dist.get_world_size(group=group) def get_global_group(): - if torch.distributed.is_initialized(): + if use_xla(): + return new_groups([list(range(get_global_world_size()))]) + elif torch.distributed.is_initialized(): if not hasattr(get_global_group, "_global_group"): # ideally we could use torch.distributed.group.WORLD, but it seems # to cause random NCCL hangs in some cases @@ -345,7 +392,26 @@ def get_global_group(): return None +def get_global_rank(): + if use_xla(): + return xm.get_ordinal() + elif torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def get_global_world_size(): + if use_xla(): + return xm.xrt_world_size() + elif torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" global _USE_MEGATRON if _USE_MEGATRON: from fairseq.model_parallel.megatron import mpu @@ -354,6 +420,16 @@ def get_data_parallel_group(): return get_global_group() +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return get_rank(get_data_parallel_group()) + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return get_world_size(get_data_parallel_group()) + + def get_model_parallel_group(): global _USE_MEGATRON if _USE_MEGATRON: @@ -363,13 +439,83 @@ def get_model_parallel_group(): return None -def all_reduce(tensor, group): - if isinstance(group, tuple) and group[0] == "tpu": - import torch_xla.core.xla_model as xm +def get_model_parallel_rank(): + """Return my rank for the model parallel group.""" + return get_rank(get_model_parallel_group()) + + +def get_model_parallel_world_size(): + """Return world size for the model parallel 
group.""" + return get_world_size(get_model_parallel_group()) - return xm.all_reduce("sum", [tensor], groups=group[1]) + +def all_reduce(tensor, group, op="sum"): + if use_xla(): + assert isinstance(group, tuple) and group[0] == "tpu" + tensor = [tensor] # wrap in a list to make xm.all_reduce in-place + return xm.all_reduce(op, tensor, groups=group[1])[0] else: - return dist.all_reduce(tensor, group=group) + if op == "sum": + op = dist.ReduceOp.SUM + elif op == "max": + op = dist.ReduceOp.MAX + else: + raise NotImplementedError + dist.all_reduce(tensor, op=op, group=group) + return tensor + + +def broadcast(tensor, src, group): + if use_xla(): + # XLA doesn't support broadcast, hack it with all_reduce + if get_rank(group) != src: + tensor.zero_() + all_reduce(tensor, group) + else: + dist.broadcast(tensor, src=src, group=group) + + +def all_to_all(tensor, group): + """Perform an all-to-all operation on a 1D Tensor.""" + assert tensor.dim() == 1 + split_count = get_world_size(group=group) + assert tensor.numel() % split_count == 0 + if use_xla(): + assert isinstance(group, tuple) and group[0] == "tpu" + return xm.all_to_all( + tensor, + split_dimension=0, + concat_dimension=0, + split_count=split_count, + groups=group[1], + ) + else: + output = torch.zeros_like(tensor) + dist.all_to_all_single(output, tensor, group=group) + return output + + +def all_gather(tensor, group, return_tensor=False): + """Perform an all-gather operation.""" + if use_xla(): + result = xm.all_gather(tensor, groups=group[1]) + world_size = get_world_size(group=group) + result = result.view(world_size, *tensor.size()) + if return_tensor: + return result + else: + return [result[i] for i in range(world_size)] + else: + world_size = get_world_size(group=group) + rank = get_rank(group=group) + tensor_list = [ + tensor if i == rank else torch.empty_like(tensor) for i in range(world_size) + ] + dist.all_gather(tensor_list, tensor, group=group) + if return_tensor: + return torch.stack(tensor_list, 
dim=0) + else: + return tensor_list def all_gather_list(data, group=None, max_size=16384): @@ -497,6 +643,8 @@ def broadcast_object( src_rank: int, group: object, dist_device: Optional[torch.device] = None, + dist_length_dtype: Optional[torch.dtype] = torch.long, + dist_dtype: Optional[torch.dtype] = torch.uint8, ) -> Any: """ Either broadcast from master to the fleet (default), @@ -513,18 +661,20 @@ def broadcast_object( buffer = io.BytesIO() torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor([len(data)], dtype=torch.long, device=dist_device) - data_send_tensor = torch.tensor(data, dtype=torch.uint8, device=dist_device) - dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) - dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + length_tensor = torch.tensor( + [len(data)], dtype=dist_length_dtype, device=dist_device + ) + broadcast(length_tensor, src=src_rank, group=group) + data_send_tensor = torch.tensor(data, dtype=dist_dtype, device=dist_device) + broadcast(data_send_tensor, src=src_rank, group=group) else: # Fetch from the source - length_tensor = torch.tensor([0], dtype=torch.long, device=dist_device) - dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) - data_recv_tensor = torch.empty( - [int(length_tensor.item())], dtype=torch.uint8, device=dist_device + length_tensor = torch.tensor([0], dtype=dist_length_dtype, device=dist_device) + broadcast(length_tensor, src=src_rank, group=group) + data_recv_tensor = torch.zeros( + [int(length_tensor.item())], dtype=dist_dtype, device=dist_device ) - dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + broadcast(data_recv_tensor, src=src_rank, group=group) buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) obj = torch.load(buffer, map_location="cpu") return obj diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/legacy_distributed_data_parallel.py index 
44f87c7c42..7e176eaf3d 100644 --- a/fairseq/legacy_distributed_data_parallel.py +++ b/fairseq/legacy_distributed_data_parallel.py @@ -34,20 +34,18 @@ class LegacyDistributedDataParallel(nn.Module): Args: module (~torch.nn.Module): module to be parallelized - world_size (int): number of parallel workers - process_group (optional): the c10d process group to be used for - distributed data all-reduction. If None, the default process group - will be used. + process_group: the c10d process group to be used for distributed data + parallel all-reduction. buffer_size (int, optional): number of elements to buffer before performing all-reduce (default: 256M). """ - def __init__(self, module, world_size, process_group=None, buffer_size=2 ** 28): + def __init__(self, module, process_group, buffer_size=2 ** 28): super().__init__() self.module = module - self.world_size = world_size self.process_group = process_group + self.world_size = distributed_utils.get_world_size(self.process_group) # Never use a bigger buffer than the number of model params self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) @@ -84,7 +82,7 @@ def no_sync(self): def forward(self, *inputs, **kwargs): return self.module(*inputs, **kwargs) - def all_reduce(self): + def all_reduce_grads(self): """ This function must be called explicitly after backward to reduce gradients. There is no automatic hook like c10d. 
diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index b86b3d14de..1a6e844aee 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -36,22 +36,6 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs): ) super().__init__(cfg, task, model, criterion, **kwargs) - @property - def data_parallel_world_size(self): - return get_data_parallel_world_size() - - @property - def data_parallel_process_group(self): - return distributed_utils.get_data_parallel_group() - - @property - def data_parallel_rank(self): - return get_data_parallel_rank() - - @property - def is_data_parallel_master(self): - return get_model_parallel_src_rank() == 0 - def clip_grad_norm(self, clip_norm): def _aggregate_model_parallel_grad_norm(total_norm): total_norm = total_norm ** 2 diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index ece10c6333..b78a0125e3 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -5,7 +5,10 @@ import inspect +import torch import torch.nn as nn + +from fairseq import distributed_utils from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel @@ -16,7 +19,7 @@ _GOSSIP_DISABLED = True -def DistributedFairseqModel(args, model, process_group=None): +def DistributedFairseqModel(args, model, process_group): """ Wrap a *model* to support distributed data parallel training. @@ -28,10 +31,18 @@ def DistributedFairseqModel(args, model, process_group=None): Args: args (argparse.Namespace): fairseq args model (BaseFairseqModel): model to wrap + process_group: the c10d process group to be used for distributed data + parallel all-reduction. 
""" # determine which DDP class to extend assert isinstance(model, nn.Module) - if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d": + if args.tpu: + ddp_class = TPUDistributedDataParallel + init_kwargs = dict( + module=model, + process_group=process_group, + ) + elif args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d": ddp_class = nn.parallel.DistributedDataParallel init_kwargs = dict( module=model, @@ -50,7 +61,6 @@ def DistributedFairseqModel(args, model, process_group=None): ddp_class = LegacyDistributedDataParallel init_kwargs = dict( module=model, - world_size=args.distributed_world_size, buffer_size=2 ** 28, process_group=process_group, ) @@ -101,3 +111,37 @@ def __getattr__(self, name): return super().__getattr__(name) return _DistributedFairseqModel(**init_kwargs) + + +class TPUDistributedDataParallel(nn.Module): + + def __init__(self, module, process_group): + super().__init__() + self.module = module + self.process_group = process_group + self.world_size = distributed_utils.get_world_size(self.process_group) + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce_grads(self): + gradients = [] + for p in self.parameters(): + if not p.requires_grad: + continue + if p.grad is None: + p.grad = torch.zeros_like(p) + if p.grad.requires_grad: + raise RuntimeError( + "TPUDistributedDataParallel only works with gradients that don't " + "require grad" + ) + gradients.append(p.grad) + + import torch_xla.core.xla_model as xm + xm.all_reduce( + 'sum', + gradients, + scale=1. / self.world_size, + groups=self.process_group[1], + ) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index e91e9d3204..f9864533b6 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -94,6 +94,11 @@ def backward(self, loss): """Computes the sum of gradients of the given tensor w.r.t. 
graph leaves.""" loss.backward() + def all_reduce_grads(self, module): + """Manually all-reduce gradients (if required).""" + if hasattr(module, "all_reduce_grads"): + module.all_reduce_grads() + def multiply_grads(self, c): """Multiplies grads by a constant *c*.""" for p in self.params: diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 2341f47077..7fe3bdd5e5 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -304,6 +304,9 @@ def get_lr(self): def set_lr(self, lr): self.fp32_optimizer.set_lr(lr) + def all_reduce_grads(self, module): + self.fp32_optimizer.all_reduce_grads(module) + class _MemoryEfficientFP16OptimizerMixin(object): def __init__(self, *args, **kwargs): @@ -498,3 +501,6 @@ def get_lr(self): def set_lr(self, lr): self.wrapped_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.wrapped_optimizer.all_reduce_grads(module) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index c37ea5cbee..33207a4cc0 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -63,10 +63,6 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): # copy model and criterion to current device/dtype self._criterion = criterion self._model = model - if self.tpu: - import torch_xla.core.xla_model as xm - - self._model = xm.send_cpu_data_to_device(self._model, self.device) if cfg.common.fp16: self._criterion = self._criterion.half() self._model = self._model.half() @@ -142,22 +138,25 @@ def reinitialize(self): @property def data_parallel_world_size(self): - return self.cfg.distributed_training.distributed_world_size + if self.cfg.distributed_training.distributed_world_size == 1: + return 1 + return distributed_utils.get_data_parallel_world_size() @property def data_parallel_process_group(self): - if self.tpu: - return ("tpu", None) - else: - return distributed_utils.get_data_parallel_group() + return distributed_utils.get_data_parallel_group() @property def 
data_parallel_rank(self): - return self.cfg.distributed_training.distributed_rank + if self.cfg.distributed_training.distributed_world_size == 1: + return 0 + return distributed_utils.get_data_parallel_rank() @property def is_data_parallel_master(self): - return distributed_utils.is_master(self.cfg.distributed_training) + # NOTE: this returns true for all model parallel replicas with data + # parallel rank 0 + return self.data_parallel_rank == 0 @property def criterion(self): @@ -166,7 +165,6 @@ def criterion(self): utils.has_parameters(self._criterion) and self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf - and not self.tpu ): self._wrapped_criterion = models.DistributedFairseqModel( self.cfg.distributed_training, @@ -183,7 +181,6 @@ def model(self): if ( self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf - and not self.tpu ): self._wrapped_model = models.DistributedFairseqModel( self.cfg.distributed_training, @@ -300,7 +297,12 @@ def load_checkpoint( bexists = PathManager.isfile(filename) if bexists: - if self.data_parallel_rank == 0: + if ( + self.data_parallel_rank == 0 + # TPUs don't support broadcast yet, so load checkpoints + # on every worker for now + or self.tpu + ): state = checkpoint_utils.load_checkpoint_to_cpu(filename) last_optim_state = state.get("last_optimizer_state", None) @@ -317,16 +319,15 @@ def load_checkpoint( last_optim_state = None state = None - if self.data_parallel_world_size > 1: - group = ( - self.data_parallel_process_group - if self.data_parallel_process_group is not None - else torch.distributed.group.WORLD - ) + if ( + self.data_parallel_world_size > 1 + # disable on TPUs until they support broadcast + and not self.tpu + ): state = distributed_utils.broadcast_object( state, src_rank=0, - group=group, + group=self.data_parallel_process_group, dist_device=self.device, ) if self.data_parallel_rank > 0: @@ -607,23 +608,17 @@ def maybe_no_sync(): total_train_time / 
self.data_parallel_world_size ) - if hasattr(self.model, "all_reduce"): - self.model.all_reduce() - overflow = False try: - if self.tpu and self.data_parallel_world_size > 1: - import torch_xla.core.xla_model as xm - - gradients = xm._fetch_gradients(self.optimizer.optimizer) - xm.all_reduce( - "sum", gradients, scale=1.0 / self.data_parallel_world_size - ) + with torch.autograd.profiler.record_function("reduce-grads"): + self.optimizer.all_reduce_grads(self.model) + if utils.has_parameters(self.criterion): + self.optimizer.all_reduce_grads(self.criterion) with torch.autograd.profiler.record_function("multiply-grads"): - # multiply gradients by (# GPUs / sample_size) since DDP - # already normalizes by the number of GPUs. Thus we get - # (sum_of_gradients / sample_size). + # multiply gradients by (data_parallel_size / sample_size) since + # DDP already normalizes by the number of data parallel workers. + # Thus we get (sum_of_gradients / sample_size) at the end. if not self.cfg.optimization.use_bmuf: self.optimizer.multiply_grads( self.data_parallel_world_size / sample_size @@ -683,6 +678,7 @@ def maybe_no_sync(): self.optimizer.optimizer ) + logging_output = None if ( not overflow or self.cfg.distributed_training.distributed_wrapper == "SlowMo" diff --git a/fairseq/utils.py b/fairseq/utils.py index 87c124b736..8e9119124d 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -641,7 +641,7 @@ def new_arange(x, *size): return torch.arange(size[-1], device=x.device).expand(*size).contiguous() -def get_tpu_device(args): +def get_tpu_device(): return xm.xla_device() From 83c39c41388f2e7ba37647d2e8a0cbc97f6f8032 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Thu, 5 Nov 2020 16:12:24 -0800 Subject: [PATCH 272/707] fix fd issues for ce training with extra splits Summary: Some folks are still reporting errors like in P145050739, ``` ValueError: too many fds ``` This diff follows up with D24234470
(https://github.com/pytorch/fairseq/commit/a9baca376616bed56e5df5115d7adf8059c0d296), where we add support for `supports_fetch_outside_dataloader` to the rest of the datasets we use, including the sampling datasets enabled by --extra-splits. For why we need to add `supports_fetch_outside_dataloader` see D24234470 (https://github.com/pytorch/fairseq/commit/a9baca376616bed56e5df5115d7adf8059c0d296) Differential Revision: D24767506 fbshipit-source-id: 4e0252f70a9aa36155843677734f186fe03508c4 --- fairseq/data/multi_corpus_dataset.py | 7 +++++++ fairseq/data/multi_corpus_sampled_dataset.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index d2457666d6..9c7f1cb976 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -157,3 +157,10 @@ def set_epoch(self, epoch, **unused): @property def supports_prefetch(self): return False + + @property + def supports_fetch_outside_dataloader(self): + return all( + self.datasets[key].supports_fetch_outside_dataloader + for key in self.datasets + ) diff --git a/fairseq/data/multi_corpus_sampled_dataset.py b/fairseq/data/multi_corpus_sampled_dataset.py index ad8e951cc9..e2e9fdf004 100644 --- a/fairseq/data/multi_corpus_sampled_dataset.py +++ b/fairseq/data/multi_corpus_sampled_dataset.py @@ -143,3 +143,10 @@ def prefetch(self, indices): dataset.prefetch( [self._map_index_to_dataset(key, index) for index in indices] ) + + @property + def supports_fetch_outside_dataloader(self): + return all( + self.datasets[key].supports_fetch_outside_dataloader + for key in self.datasets + ) From 77c704dc866c1e85259153ec98917f5acc7c9d90 Mon Sep 17 00:00:00 2001 From: alexeib Date: Fri, 6 Nov 2020 10:23:11 -0800 Subject: [PATCH 273/707] Misc fixes2 (#1402) Summary: Fixes #2855 Fixes #2847 Fixes #2841 Fixes #2783 make omegaconf strict warning not show up Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1402
Reviewed By: myleott Differential Revision: D24777582 Pulled By: alexeib fbshipit-source-id: 389e110c9de90c4a0744d01982f8071a7a867f09 --- examples/speech_recognition/w2l_decoder.py | 18 +++++++++++++----- fairseq/dataclass/utils.py | 14 +++++++++++--- fairseq/models/wav2vec/wav2vec2_asr.py | 2 +- fairseq_cli/score.py | 2 +- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index f760cd6df2..e2870df6a7 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -20,6 +20,8 @@ from examples.speech_recognition.data.replabels import unpack_replabels from fairseq import tasks from fairseq.utils import apply_to_sample +from omegaconf import open_dict +from fairseq.dataclass.utils import convert_namespace_to_omegaconf try: @@ -347,11 +349,17 @@ def __init__(self, args, tgt_dict): self.idx_to_wrd = {} checkpoint = torch.load(args.kenlm_model, map_location="cpu") - lm_args = checkpoint["args"] - lm_args.data = osp.dirname(args.kenlm_model) - print(lm_args) - task = tasks.setup_task(lm_args) - model = task.build_model(lm_args) + + if "cfg" in checkpoint and checkpoint["cfg"] is not None: + lm_args = checkpoint["cfg"] + else: + lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) + + with open_dict(lm_args.task): + lm_args.task.data = osp.dirname(args.kenlm_model) + + task = tasks.setup_task(lm_args.task) + model = task.build_model(lm_args.model) model.load_state_dict(checkpoint["model"], strict=False) self.trie = Trie(self.vocab_size, self.silence) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 477a198d0f..d73977edaa 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -294,7 +294,7 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: # in case of "--editable" installs we need to go one dir up config_path = os.path.join("..", "..", "config") - with 
initialize(config_path=config_path, strict=True): + with initialize(config_path=config_path): composed_cfg = compose("config", overrides=overrides, strict=False) for k in deletes: composed_cfg[k] = None @@ -362,12 +362,20 @@ def populate_dataclass( def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): # this will be deprecated when we get rid of argparse and model_overrides logic + from fairseq.registry import REGISTRIES + with open_dict(cfg): for k in cfg.keys(): - if isinstance(cfg[k], DictConfig): + # "k in cfg" will return false if its a "mandatory value (e.g. ???)" + if k in cfg and isinstance(cfg[k], DictConfig): overwrite_args_by_name(cfg[k], overrides) elif k in overrides: - cfg[k] = overrides[k] + if k in REGISTRIES and overrides[k] in REGISTRIES[k]["dataclass_registry"]: + cfg[k] = DictConfig(REGISTRIES[k]["dataclass_registry"][overrides[k]]) + overwrite_args_by_name(cfg[k], overrides) + cfg[k]._name = overrides[k] + else: + cfg[k] = overrides[k] def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig): diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 1cbc6374fb..14fa8ea5ed 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -328,7 +328,7 @@ def __init__(self, args, tgt_dict=None): state = checkpoint_utils.load_checkpoint_to_cpu( args.w2v_path, arg_overrides ) - w2v_args = state.get("args", None) or state["cfg"].model + args.w2v_args = w2v_args = state.get("args", None) or state["cfg"].model else: state = None w2v_args = args.w2v_args diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py index e06d67259d..0b207be959 100644 --- a/fairseq_cli/score.py +++ b/fairseq_cli/score.py @@ -58,7 +58,7 @@ def readlines(fd): def score(fdsys): with open(args.ref) as fdref: - print(sacrebleu.corpus_bleu(fdsys, [fdref])) + print(sacrebleu.corpus_bleu(fdsys, [fdref]).format()) elif args.sentence_bleu: From b7a2e00958f647497d2980fd7980ae9fcd8a513e Mon 
Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 6 Nov 2020 15:37:42 -0800 Subject: [PATCH 274/707] Avoid some device-to-host transfers (#1400) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1400 Reviewed By: msbaines Differential Revision: D24765749 Pulled By: myleott fbshipit-source-id: c242f59c88b0d8cb691948f0495af40ba415faff --- fairseq/optim/fp16_optimizer.py | 22 ++++++++++++++++++++-- fairseq/trainer.py | 4 +++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 7fe3bdd5e5..4457023527 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -162,7 +162,16 @@ def _sync_fp32_params_to_fp16(self): def _unscale_grads(self): self._sync_fp16_grads_to_fp32() - if self._multiply_factor != 1.0: + if ( + # Skip the multiplication if it's a no-op (i.e., if _multiply_factor + # is 1.0). At the same time, we want to avoid the device-to-host + # transfer by comparing it to 1.0. Since _multiply_factor starts as + # a Python float, we roughly assume that if it's a tensor then it's + # probably not =1.0 anymore and we do the multiplication. Otherwise + # we can safely check the value without a D2H transfer. + torch.is_tensor(self._multiply_factor) + or self._multiply_factor != 1.0 + ): self.fp32_optimizer.multiply_grads(self._multiply_factor) self._multiply_factor = 1.0 @@ -370,7 +379,16 @@ def backward(self, loss): loss.backward() def _unscale_grads(self): - if self._multiply_factor != 1.0: + if ( + # Skip the multiplication if it's a no-op (i.e., if _multiply_factor + # is 1.0). At the same time, we want to avoid the device-to-host + # transfer by comparing it to 1.0. Since _multiply_factor starts as + # a Python float, we roughly assume that if it's a tensor then it's + # probably not =1.0 anymore and we do the multiplication. Otherwise + # we can safely check the value without a D2H transfer. 
+ torch.is_tensor(self._multiply_factor) + or self._multiply_factor != 1.0 + ): self.wrapped_optimizer.multiply_grads(self._multiply_factor) self._multiply_factor = 1.0 diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 33207a4cc0..19ca213d55 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1101,7 +1101,9 @@ def is_consistent(tensor): ) def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): - if grad_norm is not None: + if grad_norm is not None and ( + not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm) + ): metrics.log_speed("ups", 1.0, priority=100, round=2) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) if self.cfg.optimization.clip_norm > 0: From 09a5d864fc5a79a0ec6fdf09fa0825f197060683 Mon Sep 17 00:00:00 2001 From: alexeib Date: Fri, 6 Nov 2020 22:51:29 -0800 Subject: [PATCH 275/707] move configs into fairseq dir (#1403) Summary: this way they get shipped together with fairseq package Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1403 Reviewed By: myleott Differential Revision: D24803076 Pulled By: alexeib fbshipit-source-id: a9aa6e47a8ef26fae4d54691f1616a721b8f6112 --- {config => fairseq/config}/config.yaml | 0 .../model/transformer_lm/transformer_lm_baevski_gbw.yaml | 0 .../transformer_lm/transformer_lm_baevski_wiki103.yaml | 0 .../config}/model/transformer_lm/transformer_lm_big.yaml | 0 .../config}/model/transformer_lm/transformer_lm_gbw.yaml | 0 .../config}/model/transformer_lm/transformer_lm_gpt.yaml | 0 .../model/transformer_lm/transformer_lm_gpt2_big.yaml | 0 .../model/transformer_lm/transformer_lm_gpt2_medium.yaml | 0 .../model/transformer_lm/transformer_lm_gpt2_small.yaml | 0 .../model/transformer_lm/transformer_lm_wiki103.yaml | 0 fairseq/dataclass/utils.py | 3 --- fairseq_cli/hydra_train.py | 3 ++- setup.py | 6 ++---- 13 files changed, 4 insertions(+), 8 deletions(-) rename {config => fairseq/config}/config.yaml (100%) rename {config => 
fairseq/config}/model/transformer_lm/transformer_lm_baevski_gbw.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_baevski_wiki103.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_big.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_gbw.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_gpt.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_gpt2_big.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_gpt2_medium.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_gpt2_small.yaml (100%) rename {config => fairseq/config}/model/transformer_lm/transformer_lm_wiki103.yaml (100%) diff --git a/config/config.yaml b/fairseq/config/config.yaml similarity index 100% rename from config/config.yaml rename to fairseq/config/config.yaml diff --git a/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml b/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_baevski_gbw.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml diff --git a/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml b/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml diff --git a/config/model/transformer_lm/transformer_lm_big.yaml b/fairseq/config/model/transformer_lm/transformer_lm_big.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_big.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_big.yaml diff --git a/config/model/transformer_lm/transformer_lm_gbw.yaml 
b/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_gbw.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml diff --git a/config/model/transformer_lm/transformer_lm_gpt.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_gpt.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml diff --git a/config/model/transformer_lm/transformer_lm_gpt2_big.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_gpt2_big.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml diff --git a/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_gpt2_medium.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml diff --git a/config/model/transformer_lm/transformer_lm_gpt2_small.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_gpt2_small.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml diff --git a/config/model/transformer_lm/transformer_lm_wiki103.yaml b/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml similarity index 100% rename from config/model/transformer_lm/transformer_lm_wiki103.yaml rename to fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index d73977edaa..1efcc5dca9 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -290,9 +290,6 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: # configs 
will be in fairseq/config after installation config_path = os.path.join("..", "config") - if not os.path.exists(config_path): - # in case of "--editable" installs we need to go one dir up - config_path = os.path.join("..", "..", "config") with initialize(config_path=config_path): composed_cfg = compose("config", overrides=overrides, strict=False) diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index 24728c507f..ffd3c5cd07 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -6,6 +6,7 @@ import hydra from omegaconf import OmegaConf +import os from fairseq.dataclass.initialize import hydra_init from fairseq_cli.train import main as pre_main @@ -19,7 +20,7 @@ logger = logging.getLogger(__name__) -@hydra.main(config_path="../config", config_name="config") +@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") def hydra_main(cfg: FairseqConfig) -> None: cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) diff --git a/setup.py b/setup.py index 7b13f13e4c..572d2b50de 100644 --- a/setup.py +++ b/setup.py @@ -222,15 +222,13 @@ def get_files(path, relative_to="fairseq"): try: - # symlink config and examples into fairseq package so package_data accepts them + # symlink examples into fairseq package so package_data accepts them if "build_ext" not in sys.argv[1:]: - os.symlink(os.path.join("..", "config"), "fairseq/config") os.symlink(os.path.join("..", "examples"), "fairseq/examples") package_data = { - "fairseq": get_files("fairseq/config") + get_files("fairseq/examples"), + "fairseq": get_files("fairseq/examples"), } do_setup(package_data) finally: if "build_ext" not in sys.argv[1:]: - os.unlink("fairseq/config") os.unlink("fairseq/examples") From 50422496ac7c3ef2e7dd3818a5126ab9ab37ae29 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 7 Nov 2020 15:00:00 -0800 Subject: [PATCH 276/707] Automatically register components in ConfigStore via register_* functions (#1406) 
Summary: We can automatically register everything in ConfigStore on-the-fly. This avoids needing to worry about importing everything in the right order. It also better supports `--user-dir`, since those typically get imported later. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1406 Reviewed By: alexeib Differential Revision: D24814072 Pulled By: myleott fbshipit-source-id: 21cfc1a6c497fe98bf4429bfed138030d9999b6a --- fairseq/dataclass/initialize.py | 28 ---------------------------- fairseq/models/__init__.py | 6 ++++++ fairseq/optim/__init__.py | 4 +--- fairseq/registry.py | 10 +++++++++- fairseq/tasks/__init__.py | 6 ++++++ 5 files changed, 22 insertions(+), 32 deletions(-) diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index 24fedd52bf..7a1ebeff1c 100644 --- a/fairseq/dataclass/initialize.py +++ b/fairseq/dataclass/initialize.py @@ -9,32 +9,10 @@ from hydra.core.config_store import ConfigStore from fairseq.dataclass.configs import FairseqConfig -# the imports below are necessary so that "REGISTRIES" is correctly populated with all components -from fairseq.criterions import CRITERION_REGISTRY # noqa -from fairseq.optim import OPTIMIZER_REGISTRY # noqa -from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY # noqa -from fairseq.scoring import SCORER_REGISTRY # noqa -from fairseq.data.encoders import BPE_REGISTRY, TOKENIZER_REGISTRY # noqa - -from fairseq.models import MODEL_DATACLASS_REGISTRY -from fairseq.tasks import TASK_DATACLASS_REGISTRY -from fairseq.registry import REGISTRIES - logger = logging.getLogger(__name__) -def register_module_dataclass( - cs: ConfigStore, registry: Dict[str, Any], group: str -) -> None: - """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" - # note that if `group == model`, we register all model archs, not the model name. 
- for k, v in registry.items(): - node_ = v() - node_._name = k - cs.store(name=k, group=group, node=node_, provider="fairseq") - - def hydra_init(cfg_name="config") -> None: cs = ConfigStore.instance() @@ -47,9 +25,3 @@ def hydra_init(cfg_name="config") -> None: except BaseException: logger.error(f"{k} - {v}") raise - - register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") - register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") - - for k, v in REGISTRIES.items(): - register_module_dataclass(cs, v["dataclass_registry"], k) diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index e8af024795..5336b0a052 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -11,6 +11,7 @@ import fairseq from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore from omegaconf import DictConfig, OmegaConf from .composite_encoder import CompositeEncoder @@ -120,6 +121,11 @@ def register_model_cls(cls): cls.__dataclass = dataclass if dataclass is not None: MODEL_DATACLASS_REGISTRY[name] = dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group="model", node=node, provider="fairseq") return cls return register_model_cls diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index d8e581729e..112c8ad10f 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -32,9 +32,7 @@ ) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True) -def build_optimizer( - cfg: DictConfig, params, *extra_args, **extra_kwargs -): +def build_optimizer(cfg: DictConfig, params, *extra_args, **extra_kwargs): if all(isinstance(p, dict) for p in params): params = [t for p in params for t in p.values()] params = list(filter(lambda p: p.requires_grad, params)) diff --git a/fairseq/registry.py b/fairseq/registry.py index 29631bb326..7a3dd1d1bf 100644 --- 
a/fairseq/registry.py +++ b/fairseq/registry.py @@ -8,6 +8,7 @@ from typing import Union from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import populate_dataclass, merge_with_parent +from hydra.core.config_store import ConfigStore from omegaconf import DictConfig REGISTRIES = {} @@ -82,9 +83,16 @@ def register_x_cls(cls): ) cls.__dataclass = dataclass - REGISTRY[name] = cls if cls.__dataclass is not None: DATACLASS_REGISTRY[name] = cls.__dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group=registry_name, node=node, provider="fairseq") + + REGISTRY[name] = cls + return cls return register_x_cls diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 7575ba429e..415f15e708 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -10,6 +10,7 @@ from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore from omegaconf import DictConfig from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa @@ -86,6 +87,11 @@ def register_task_cls(cls): if dataclass is not None: TASK_DATACLASS_REGISTRY[name] = dataclass + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group="task", node=node, provider="fairseq") + return cls return register_task_cls From bd2e804b9c2ff1fae202c00e227f1afece12420b Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 7 Nov 2020 16:50:15 -0800 Subject: [PATCH 277/707] add and link hydra docs (#1405) Summary: updates hydra integration doc and links to it Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1405 Reviewed By: myleott Differential Revision: D24808779 Pulled By: alexeib fbshipit-source-id: a50160e196e469e30e39d6ee47440a569c0154bd --- README.md | 209 +++++++++-------- docs/hydra_integration.md | 285 ++++++++++++++++------- fairseq/data/encoders/moses_tokenizer.py | 2 +- 3 
files changed, 318 insertions(+), 178 deletions(-) diff --git a/README.md b/README.md index 13b8223959..0648da15f7 100644 --- a/README.md +++ b/README.md @@ -13,100 +13,107 @@ Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks. + We provide reference implementations of various sequence modeling papers:
List of implemented papers

-- **Convolutional Neural Networks (CNN)** - - [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) - - [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) - - [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) - - [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) - - [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -- **LightConv and DynamicConv models** - - [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -- **Long Short-Term Memory (LSTM) networks** - - Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) -- **Transformer (self-attention) networks** - - Attention Is All You Need (Vaswani et al., 2017) - - [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - - [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) - - [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) - - [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - - 
[Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - - [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) - - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - - [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) - - [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) - - [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) - - [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) -- **Non-autoregressive Transformers** - - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) - - Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) - - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) - - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -- **Finetuning** - - [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 
2020)](examples/rxf/README.md) +* **Convolutional Neural Networks (CNN)** + + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) + + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) + + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) + + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) + + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* **LightConv and DynamicConv models** + + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* **Long Short-Term Memory (LSTM) networks** + + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) +* **Transformer (self-attention) networks** + + Attention Is All You Need (Vaswani et al., 2017) + + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) + + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 
2019)](examples/joint_alignment_translation/README.md ) + + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) + + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) + + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) + + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) + + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) + + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) +* **Non-autoregressive Transformers** + + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) + + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) + + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) + + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* **Finetuning** + + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)

### What's New: -- October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) -- October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) -- October 2020: [Added CRISS models and code](examples/criss/README.md) -- September 2020: [Added Linformer code](examples/linformer/README.md) -- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) -- August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) -- August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) -- July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) -- May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) -- April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) -- April 2020: [Quant-Noise code released](examples/quant_noise/README.md) -- April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +* November 2020: Adopted [Hydra](https://github.com/facebookresearch/hydra) as a configuration framework; +[added documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) +* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) +* October 2020: [Added CRISS models and code](examples/criss/README.md) +* September 2020: [Added Linformer code](examples/linformer/README.md) +* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) +* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) +* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) +* July 2020: [Unsupervised Quality Estimation code 
released](examples/unsupervised_quality_estimation/README.md) +* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) +* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) +* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) +* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +
Previous updates

-- March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) -- February 2020: [mBART model and code released](examples/mbart/README.md) -- February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) -- December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) -- November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) -- November 2019: [CamemBERT model and code released](examples/camembert/README.md) -- November 2019: [BART model and code released](examples/bart/README.md) -- November 2019: [XLM-R models and code released](examples/xlmr/README.md) -- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) -- August 2019: [WMT'19 models released](examples/wmt19/README.md) -- July 2019: fairseq relicensed under MIT license -- July 2019: [RoBERTa models and code released](examples/roberta/README.md) -- June 2019: [wav2vec models and code released](examples/wav2vec/README.md) +* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) +* February 2020: [mBART model and code released](examples/mbart/README.md) +* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) +* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) +* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) +* November 2019: [CamemBERT model and code released](examples/camembert/README.md) +* November 2019: [BART model and code released](examples/bart/README.md) +* November 
2019: [XLM-R models and code released](examples/xlmr/README.md) +* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) +* August 2019: [WMT'19 models released](examples/wmt19/README.md) +* July 2019: fairseq relicensed under MIT license +* July 2019: [RoBERTa models and code released](examples/roberta/README.md) +* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)

### Features: -- multi-GPU training on one machine or across multiple machines (data and model parallel) -- fast generation on both CPU and GPU with multiple search algorithms implemented: - - beam search - - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) - - sampling (unconstrained, top-k and top-p/nucleus) - - lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md)) -- large mini-batch training even on a single GPU via delayed updates -- mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) -- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers +* multi-GPU training on one machine or across multiple machines (data and model parallel) +* fast generation on both CPU and GPU with multiple search algorithms implemented: + + beam search + + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + + sampling (unconstrained, top-k and top-p/nucleus) + + lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md)) +* large mini-batch training even on a single GPU via delayed updates +* mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) +* extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers +* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) with a convenient `torch.hub` interface: -```python + +``` python en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') en2de.translate('Hello world', beam=5) # 'Hallo Welt' ``` + See the 
PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. @@ -116,7 +123,8 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example * Python version >= 3.6 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) * **To install fairseq** and develop locally: -```bash + +``` bash git clone https://github.com/pytorch/fairseq cd fairseq pip install --editable ./ @@ -124,18 +132,20 @@ pip install --editable ./ # on MacOS: # CFLAGS="-stdlib=libc++" pip install --editable ./ ``` + * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: -```bash + +``` bash git clone https://github.com/NVIDIA/apex cd apex pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ --global-option="--fast_multihead_attn" ./ ``` -* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` -* If you use Docker make sure to increase the shared memory size either with -`--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`. +* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` +* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` + as command line options to `nvidia-docker run` . # Getting Started @@ -148,30 +158,31 @@ types and tasks. We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, as well as example training and evaluation commands. 
-- [Translation](examples/translation/README.md): convolutional and transformer models are available -- [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available +* [Translation](examples/translation/README.md): convolutional and transformer models are available +* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available We also have more detailed READMEs to reproduce results from specific papers: -- [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) -- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) -- [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) -- [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) -- [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) -- [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) -- [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) -- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) -- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) -- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) -- [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -- [Mixture Models for Diverse Machine Translation: Tricks of the 
Trade (Shen et al., 2019)](examples/translation_moe/README.md) -- [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -- [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) -- [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) -- [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) -- [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) -- [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) -- [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) + +* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) +* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) +* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) +* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) +* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) +* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) +* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) +* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) +* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* [Facebook FAIR's WMT19 News 
Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) +* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) +* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) +* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) +* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) +* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) +* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) +* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) +* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) # Join the fairseq community @@ -188,7 +199,7 @@ The license applies to the pre-trained models as well. Please cite as: -```bibtex +``` bibtex @inproceedings{ott2019fairseq, title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 0973cd279e..f924de961b 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -1,111 +1,240 @@ +## Hydra +[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of +research and other complex applications. 
The key feature is the ability to dynamically create a hierarchical
+configuration by composition and override it through config files and the command line. The name Hydra comes from its
+ability to run multiple similar jobs - much like a Hydra with multiple heads.
+
+## Motivation
+
+Until recently, all components in fairseq were configured through a shared "args" namespace that was created at
+application startup. Components declared their own "add_args" method to update the argparse parser, hoping that
+the names would not clash with arguments from other components. While this model works for smaller applications,
+as fairseq grew and became integrated into other applications, this became problematic.
+In order to determine how to configure each component, one needed to a) examine what args were added by this component, and
+b) read the code to figure out what shared arguments it is using that were added in other places. Reproducing
+models involved sharing commands that often contained dozens of command line switches.
+
+The model described above is still supported by fairseq for backward compatibility, but will be deprecated some time
+in the future.
+
+New components in fairseq should now create a dataclass that encapsulates all parameters required to configure this
+component. The dataclass is registered along with the component, and fairseq takes care of constructing and
+providing this configuration object to the component's constructor. Note that sharing parameters can optionally
+still work, but one has to explicitly point to the "source of truth" (see inheritance example below).
+These changes make components in fairseq
+more independent and re-usable by other applications: all that is needed to create a component is to initialize its
+dataclass and overwrite some of the defaults.
+
+While configuring fairseq through command line (using either the legacy argparse based or the new Hydra based entry points) is still
+fully supported, you can now take advantage of configuring fairseq completely or piece-by-piece through
+hierarchical YAML configuration files. These files can also be shipped as examples that others can use to run
+an identically configured job.
+
+Additionally, Hydra has a rich and growing
+[library of plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that provide functionality such as
+hyperparameter sweeping (including using Bayesian optimization through the [Ax](https://github.com/facebook/Ax) library),
+job launching across various platforms, and more.
+
+## Creating or migrating components
+
+In general, each new (or updated) component should provide a companion [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are typically located in the same
+file as the component and are passed as arguments to the register_*() functions. Top-level configs that should be
+present in every fairseq application are placed in the [global](fairseq/dataclass/configs.py) config file and added
+to the FairseqConfig object.
+
+Each dataclass is a plain-old-data object, similar to a NamedTuple. These classes are decorated with a @dataclass
+decorator, and typically inherit from `FairseqDataclass` (which adds some functionality for backward compatibility).
+Each field must have a type, and generally has metadata (such as a help string) and a default value. Only primitive types or other config objects are allowed as
+data types for each field.
+
+ Example:
+
+
+``` python
+from dataclasses import dataclass, field
+from fairseq.dataclass import FairseqDataclass
+
+@dataclass
+class InteractiveConfig(FairseqDataclass):
+    buffer_size: int = field(
+        default=0,
+        metadata={
+            "help": "read this many sentences into a buffer before processing them"
+        },
+    )
+    input: str = field(
+        default="-",
+        metadata={"help": "file to read from; use - for stdin"},
+    )
+```
-## Hydra
+### Inheriting values
-Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads.
+Some components require sharing a value. For example, a learning rate scheduler and an optimizer may both need to
+know the initial learning rate value. One can declare a field that, by default, will
+inherit its value from another config node in the same hierarchy:
-## Train models with hydra interface
+``` python
+@dataclass
+class FairseqAdamConfig(FairseqDataclass):
+    ...
+    lr: List[float] = II("optimization.lr")
+    ...
+```
-#### Provide parameters in `.yaml` files
-For example, if we'd like to train a language model with transformer, we could provide parameters in yaml files. Note that the modules used (task, model, criterion, optimizer, lr scheduler) in training must be migrated with hydra interface already (See session below).
+`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is the value one can use in a YAML config file or through
+command line to achieve the same effect. Note that this assumes that there is an "optimization" config object
+in the root config and it has a field called "lr".
-- Provide top level choices on which generic parameter file, and which modules to use: `config/config.yaml`, this will look like for example: +### Tasks and Models -``` -defaults: - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt -``` +Creating Tasks and Models works same as before, except that legacy implementations now inherit from Legacy* base classes, +while new components inherit from FairseqTask and FairseqModel and provide a dataclass to the register_*() functions. -- Provide generic parameters common across different jobs: `config.yaml` -- Provide task parameters: `config/task/language_modeling.yaml` -- Provide model parameters: `config/model/transformer_lm.yaml` -- Provide criterion parameters: `config/criterion/cross_entropy.yaml` -- Provide optimizer parameters: `config/optimizer/adam.yaml` -- Provide lr_scheduler parameters `config/lr_scheduler/inverse_sqrt.yaml` +Task example: -#### Command line overriding -`train_hydra.py` is the main entry point for training with hydra interface. If we specify all parameters we want in `.yaml` files, then we could simply use command: +``` python +@dataclass +class LanguageModelingConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + ... +@register_task("language_modeling", dataclass=LanguageModelingConfig) +class LanguageModelingTask(LegacyFairseqTask): + ... + @classmethod + def setup_task(cls, cfg: LanguageModelingConfig): + ... 
 ```
-# task.data is requested field marked by `???` in yaml
-python fairseq_cli/train_hydra.py \
-task.data=/private/home/abaevski/data/wiki103 \
+
+Model example:
+
+``` python
+@dataclass
+class TransformerLanguageModelConfig(FairseqDataclass):
+    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
+        default="relu", metadata={"help": "activation function to use"}
+    )
+    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
+    ...
+
+@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
+class TransformerLanguageModel(FairseqLanguageModel):
+    ...
+    @classmethod
+    def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
+        ...
 ```
-Alternatively, if we need to override certain params from the command line, we could do so as below (note the structure of where each parameter sits)
+### Other components
+
+Other components work as before, but they now take their configuration dataclass as the only constructor argument:
+``` python
+@dataclass
+class MosesTokenizerConfig(FairseqDataclass):
+    source_lang: str = field(default="en", metadata={"help": "source language"})
+    ...
+
+@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
+class MosesTokenizer(object):
+    def __init__(self, cfg: MosesTokenizerConfig):
+        ...
 ```
-python fairseq_cli/train_hydra.py
-task=language_modeling \
-task.data=/private/home/abaevski/data/wiki103 \
-task.tokens_per_sample=512 \
-task.sample_break_mode=none \
-model=transformer_lm \
-model.share_decoder_input_output_embed=true \
-model.dropout=0.1 \
-optimizer=adam \
-optimizer.adam_betas="'(0.9, 0.98)'" \
-optimizer.weight_decay=0.01 \
-lr_scheduler=inverse_sqrt \
-lr_scheduler.warmup_updates=4000 \
-lr_scheduler.warmup_init_lr=1e-07 \
-criterion=cross_entropy \
-common.fp16=true \
-common.log_format=json \
-common.log_interval=1 \
-dataset.max_tokens=1024 \
-dataset.num_workers=4 \
-optimization.update_freq=[16] \
-optimization.max_update=50000 \
-optimization.clip_norm=0.0 \
-optimization.lr=[0.0005] \
-checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
-checkpoint.save_interval_updates=10
+
+Note that if you are adding a new registry for a new set of components, you need to add it to the FairseqConfig object in
+fairseq/dataclass/configs.py:
+
+``` python
+@dataclass
+class FairseqConfig(object):
+    ...
+    my_new_registry: Any = None
 ```
-## Migrate existing/Creating new modules to hydra interface
+## Training with hydra_train.py
-In each of the modules we want to migrated/create with hydra interface, fundamentally we need to
+To fully take advantage of configuration flexibility offered by Hydra, you may want to train new models using the
+hydra_train.py entry point located in the fairseq_cli directory. Legacy CLI tools such as train.py
+will remain supported for the foreseeable future but will be deprecated eventually.
-- Provide a dataclass that layouts the parameters used in the module.
+On startup, Hydra will create a configuration object that contains a hierarchy of all the necessary dataclasses
+populated with their default values in the code.
The default values are overwritten by values found in YAML files in the
+fairseq/config directory (which currently only set the default task, optimizer, etc.) and then further overwritten by values
+provided through command line arguments.
-- Modify the builder and/or constructor that previously takes `argparse.Namespace` argument `args`, into taking `omegaconf.DictConfig` config objects. At this moment we allow `Union[omegaconf.DictConfig, argparse.Namespace]` to support compatibility.
+Some of the most common use cases are shown below:
-- For `add_args()`, we need to extract argument from the dataclass defined in the same file, and append them into `parser`. This is also to support compatibility. This is simply supported with `gen_parser_from_dataclass` API, see examples files below.
+### 1. Overwrite default values through command line:
-#### Migrated examples:
+```shell script
+python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 task.data=data-bin \
+model=transformer_lm/transformer_lm_gpt task=language_modeling optimization.max_update=5000
-- Task: `fairseq/tasks/language_modeling.py`
+```
-- Model: `fairseq/models/transformer_lm.py`
+Note that along with explicitly providing values for parameters such as dataset.batch_size, this also tells Hydra to overlay configuration found in `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml`
+over the default values in the dataclass. If you want to train a model without specifying a particular architecture,
+you can simply specify model=transformer_lm. This only works for migrated tasks and models.
-- Criterion: `fairseq/criterions/adaptive_loss.py` and `fairseq/criterions/cross_entropy.py`
+### 2.
Replace bundled configs with an external config:
-- Optimizer: `fairseq/optim/adam.py` and `fairseq/optim/nag.py`
+```shell script
+python fairseq_cli/hydra_train.py --config-path /path/to/external/configs --config-name wiki103
+```
+
+where /path/to/external/configs/wiki103.yaml contains:
+
+``` yaml
+# @package _group_
+
+model:
+  _name: transformer_lm
+distributed_training:
+  distributed_world_size: 1
+dataset:
+  batch_size: 2
+task:
+  _name: language_modeling
+  data: /path/to/data
+  add_bos_token: false
+  max_target_positions: 1024
+optimization:
+  max_update: 50000
+  lr: [ 0.25 ]
+criterion: cross_entropy
+optimizer: adam
+lr_scheduler:
+  _name: cosine
+```
-- LR scheduler: `fairseq/optim/lr_scheduler/cosine_lr_scheduler.py` and `fairseq/optim/lr_scheduler/inverse_square_root_schedule.py`
+Note that the bundled configs from the `fairseq/config` directory are not used here; however, the defaults from each dataclass will still be used (unless overwritten by your external config).
+Additionally, you can choose to break up your configs by creating a directory structure in the same location as your main config file, with the names of the top-level fields
+(such as "model", "dataset", etc.), and placing config files with meaningful names that would populate that specific section of your
+top-level config file (for example, you might have model/small_transformer_lm.yaml, model/big_transformer_lm.yaml, etc.). You can then specify the correct configuration via command line, defaults in the main config, or even launch all of them as a sweep (see Hydra documentation on how to do this).
-## Interpolate parameters across different places
+### 3. Add an external config directory to Hydra search path:
-## Support of legacy interface
-If you still like to pass legacy style arguments in command line, `fairseq_cli/train.py` can support this. Internally it coverted `args` into hydra config objects whenever there are migrated modules aligned.
+This allows combining default configuration (including using any bundled config files), while specifying your own config files for some parts of the configuration. +```shell script +python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 \ +task.data=/path/to/data/ model=transformer_lm/2_layers task=language_modeling optimization.max_update=5000 \ +--config-dir /path/to/external/configs + +``` + +where /path/to/external/configs has the following structure: ``` -python fairseq_cli/train.py --task language_modeling \ -/private/home/abaevski/data/wiki103 \ ---save-dir /checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ ---arch transformer_lm --share-decoder-input-output-embed \ ---dropout 0.1 \ ---optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \ ---lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ ---tokens-per-sample 512 --sample-break-mode none \ ---max-tokens 1024 --update-freq 16 \ ---fp16 \ ---max-update 50000 --log-format json --log-interval 1 --num-workers 4 \ ---save-interval-updates 10 +. ++-- model +| +-- transformer_lm +| | +-- 2_layers.yaml ``` + +and 2_layers.yaml contains a copy of transformer_lm_gpt.yaml but with decoder_layers set to 2. You can add +other configs to configure other components as well. 
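Reviewer note on the override order described in the README hunk above: values resolve as dataclass defaults, then YAML files, then command-line arguments. The following is a minimal sketch of that precedence using only the standard library. It is not how Hydra or fairseq implement it; `DatasetConfig`, `parse_cli`, `merge_layers` and all values are hypothetical names chosen for illustration.

```python
from dataclasses import asdict, dataclass


@dataclass
class DatasetConfig:
    # Hypothetical stand-in for a config dataclass; not fairseq's real class.
    batch_size: int = 8
    num_workers: int = 1


def parse_cli(args, group):
    """Parse `group.key=value` strings, keeping only keys that belong to `group`."""
    overrides = {}
    for arg in args:
        key, _, value = arg.partition("=")
        prefix, _, name = key.partition(".")
        if prefix == group and name:
            # Assumption for this sketch: every value in the group is an integer.
            overrides[name] = int(value)
    return overrides


def merge_layers(defaults, yaml_values, cli_overrides):
    """Later layers win: dataclass defaults < YAML file < command line."""
    merged = dict(defaults)
    merged.update(yaml_values)
    merged.update(cli_overrides)
    return merged


defaults = asdict(DatasetConfig())
yaml_values = {"batch_size": 4}  # as if read from a bundled or external YAML file
cli = parse_cli(["dataset.batch_size=2", "optimization.max_update=5000"], "dataset")
config = merge_layers(defaults, yaml_values, cli)
```

In this sketch `batch_size` takes the command-line value, while `num_workers` keeps its dataclass default because neither the YAML layer nor the command line sets it.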
diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py index fa004dd4af..e236dad167 100644 --- a/fairseq/data/encoders/moses_tokenizer.py +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -24,7 +24,7 @@ class MosesTokenizerConfig(FairseqDataclass): @register_tokenizer("moses", dataclass=MosesTokenizerConfig) class MosesTokenizer(object): - def __init__(self, cfg): + def __init__(self, cfg: MosesTokenizerConfig): self.cfg = cfg try: From 108f7204f6ccddb676e6d52006da219ce96a02dc Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 7 Nov 2020 23:13:14 -0800 Subject: [PATCH 278/707] add local_rank alias (#1408) Summary: Fixes #2859 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1408 Reviewed By: myleott Differential Revision: D24817281 Pulled By: alexeib fbshipit-source-id: 4c1a3c7d6b3b940e1293d316253b57e101f3f862 --- fairseq/dataclass/configs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index a3c0d06a39..ec921a41d7 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -208,7 +208,10 @@ class DistributedTrainingConfig(FairseqDataclass): ) local_rank: int = field( default=0, - metadata={"help": "which GPU to use (usually configured automatically)"}, + metadata={ + "help": "which GPU to use (usually configured automatically)", + "argparse_alias": "--local_rank", + }, ) distributed_no_spawn: bool = field( default=False, From 18d3b5c8b0d71e0b828b5a0f5c54ee6769583669 Mon Sep 17 00:00:00 2001 From: UriSha Date: Mon, 9 Nov 2020 10:55:29 -0800 Subject: [PATCH 279/707] Update wikitext url (#2871) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Update WikiText-103 url ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2871 Reviewed By: myleott Differential Revision: D24835953 Pulled By: alexeib fbshipit-source-id: 890e911d528c04de0dc056e55866afb46a2bd87f --- examples/language_model/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language_model/README.md b/examples/language_model/README.md index dc84d8c761..e78ea48e08 100644 --- a/examples/language_model/README.md +++ b/examples/language_model/README.md @@ -5,7 +5,7 @@ Model | Description | Dataset | Download ---|---|---|--- `transformer_lm.gbw.adaptive_huge` | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853))
1026M params | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2) -`transformer_lm.wiki103.adaptive` | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853))
247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2) +`transformer_lm.wiki103.adaptive` | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853))
247M params | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2) `transformer_lm.wmt19.en` | English LM
([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz) `transformer_lm.wmt19.de` | German LM
([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz) `transformer_lm.wmt19.ru` | Russian LM
([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz) From d10fabd6971f51f59e3039accc248eae8945d6ff Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 9 Nov 2020 12:25:01 -0800 Subject: [PATCH 280/707] Make it easier to use non-FairseqDatasets (#1411) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1411 Test Plan: Imported from OSS Reviewed By: huihuifan Differential Revision: D24833475 Pulled By: myleott fbshipit-source-id: 5be599bd2b7d820a208321da53d594d5ae67bf2b --- fairseq/data/iterators.py | 18 +++++++++++---- fairseq/models/fairseq_model.py | 2 +- fairseq/models/huggingface/hf_gpt2.py | 33 ++++----------------------- 3 files changed, 19 insertions(+), 34 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 15796234db..ef41fed739 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -164,7 +164,8 @@ def next_epoch_idx(self): def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): self.epoch = self.next_epoch_idx - self.dataset.set_epoch(self.epoch) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) self._current_epoch_iterator = CountingIterator( iterable=ShardedIterator( iterable=self.dataset, @@ -225,7 +226,9 @@ class EpochBatchIterator(EpochBatchIterating): queue. Helps speeding up dataloading. When buffer_size is zero, the default torch.utils.data.DataLoader preloading is used. timeout (int, optional): if positive, the timeout value for collecting a batch - from workers. Should always be non-negative. (default: ``0``) + from workers. Should always be non-negative (default: ``0``). + disable_shuffling (bool, optional): force disable shuffling + (default: ``False``). 
""" def __init__( @@ -240,6 +243,7 @@ def __init__( epoch=1, buffer_size=0, timeout=0, + disable_shuffling=False, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset @@ -256,9 +260,10 @@ def __init__( # in a shared computing environment. self.buffer_size = min(buffer_size, 20) self.timeout = timeout + self.disable_shuffling = disable_shuffling self.epoch = max(epoch, 1) # we use 1-based indexing for epochs - self.shuffle = True + self.shuffle = not disable_shuffling self._cur_epoch_itr = None self._next_epoch_itr = None self._supports_prefetch = getattr(dataset, "supports_prefetch", False) @@ -279,7 +284,7 @@ def first_batch(self): "a larger dataset." ) - if self.dataset.supports_fetch_outside_dataloader: + if getattr(self.dataset, "supports_fetch_outside_dataloader", True): return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) else: return "DUMMY" @@ -311,8 +316,11 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): allocated to the same shards across epochs. Requires that :attr:`dataset` supports prefetching (default: False). 
""" + if self.disable_shuffling: + shuffle = False self.epoch = self.next_epoch_idx - self.dataset.set_epoch(self.epoch) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) if self._next_epoch_itr is not None: self._cur_epoch_itr = self._next_epoch_itr self._next_epoch_itr = None diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 0c8d106be5..926d952f77 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -177,7 +177,7 @@ def make_generation_fast_(self, **kwargs): def apply_remove_weight_norm(module): try: nn.utils.remove_weight_norm(module) - except ValueError: # this module didn't have weight norm + except (AttributeError, ValueError): # this module didn't have weight norm return self.apply(apply_remove_weight_norm) diff --git a/fairseq/models/huggingface/hf_gpt2.py b/fairseq/models/huggingface/hf_gpt2.py index a823453794..3a8eb78198 100644 --- a/fairseq/models/huggingface/hf_gpt2.py +++ b/fairseq/models/huggingface/hf_gpt2.py @@ -17,20 +17,6 @@ ) -try: - # Prepend the transformers submodule to the path, so that - # it's prioritized over other installations. This allows - # making local changes in the submodule. 
- hf_path = os.path.join(os.path.dirname(__file__), "transformers", "src") - sys.path.insert(0, hf_path) - from transformers import GPT2Config, GPT2LMHeadModel - - sys.path.remove(hf_path) - has_hf = True -except ImportError: - has_hf = False - - logger = logging.getLogger(__name__) @@ -41,14 +27,6 @@ class HuggingFaceGPT2LanguageModel(FairseqLanguageModel): def __init__(self, decoder): super().__init__(decoder) - if not has_hf: - raise ImportError( - "\n\nPlease install huggingface/transformers with:" - "\n\n pip install transformers" - "\n\nOr to make local edits, install the submodule:" - "\n\n git submodule update --init " - "fairseq/models/huggingface/transformers" - ) @staticmethod def add_args(parser): @@ -76,17 +54,16 @@ def build_model(cls, args, task): class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder): def __init__(self, args, task): - super().__init__(task.target_dictionary) - - if not has_hf: + try: + from transformers import GPT2Config, GPT2LMHeadModel + except ImportError: raise ImportError( "\n\nPlease install huggingface/transformers with:" "\n\n pip install transformers" - "\n\nOr to make local edits, install the submodule:" - "\n\n git submodule update --init " - "fairseq/models/huggingface/transformers" ) + super().__init__(task.target_dictionary) + config = GPT2Config( vocab_size=len(task.target_dictionary), n_positions=args.max_target_positions + 1, From 5cfc50627788fb517e29ccc14ea8f3f12b8068a6 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 9 Nov 2020 12:25:01 -0800 Subject: [PATCH 281/707] Fix eval_lm.py to use current cfg for task, also fix --model-overrides (#1412) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1412 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D24833478 Pulled By: myleott fbshipit-source-id: 4d0720a875541c016a00b28a4f0a9ad77e77e7a8 --- fairseq_cli/eval_lm.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git 
a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 1197d6987b..e8fd98c325 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -14,7 +14,7 @@ from argparse import Namespace import torch -from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar @@ -63,7 +63,7 @@ def __str__(self): ) -def main(cfg: DictConfig, override_args=None, **unused_kwargs): +def main(cfg: DictConfig, **unused_kwargs): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) @@ -75,12 +75,6 @@ def main(cfg: DictConfig, override_args=None, **unused_kwargs): if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) - if override_args is not None: - overrides = vars(override_args) - overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) - else: - overrides = None - logger.info(cfg) # Load ensemble @@ -89,12 +83,17 @@ def main(cfg: DictConfig, override_args=None, **unused_kwargs): # reduce tokens per sample by the required context window size cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + # Initialize the task using the current *cfg* + task = tasks.setup_task(cfg.task) + + # Initialize the model (but not the task) using the checkpoint's *cfg* models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( [cfg.common_eval.path], - arg_overrides=overrides, + arg_overrides=eval(cfg.common_eval.model_overrides), suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, + task=task, ) # Load dataset splits @@ -193,7 +192,7 @@ def main(cfg: DictConfig, override_args=None, **unused_kwargs): tgt_len = tokens.numel() pos_scores = hypo["positional_scores"].float() - if 
cfg.task.add_bos_token: + if getattr(cfg.task, "add_bos_token", False): assert hypo["tokens"][0].item() == task.target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] From c19dfe26160c6cee768b31eb6bb149781d9c6eac Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 9 Nov 2020 12:25:01 -0800 Subject: [PATCH 282/707] Make activation checkpointing interface less restrictive (#1413) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1413 Test Plan: Imported from OSS Reviewed By: ngoyal2707 Differential Revision: D24833476 Pulled By: myleott fbshipit-source-id: 380ea7e05c7b188086b2b10c15120ea6636e0a3e --- fairseq/modules/checkpoint_activations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index a4341fe742..1f99c24ca1 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -88,7 +88,6 @@ def split_non_tensors( """ if isinstance(mixed, torch.Tensor): return (mixed,), None - assert isinstance(mixed, tuple) tensors = [] packed_non_tensors = {"is_tensor": [], "objects": []} for o in mixed: @@ -151,7 +150,6 @@ def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): if isinstance(outputs, torch.Tensor): return outputs else: - assert isinstance(outputs, tuple) # Autograd Functions don't like non-Tensor outputs. We can split the # non-Tensor and Tensor outputs, returning the former by reference # through *parent_ctx_dict* and returning the latter directly. 
From 74a59ada8c882b6a43eac190cb0608b3258ce165 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 9 Nov 2020 12:25:01 -0800 Subject: [PATCH 283/707] Upgrade DummyLMTask to Hydra (#1415) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1415 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D24833480 Pulled By: myleott fbshipit-source-id: 007623168467d18166b20ef99f54388eb9d8008a --- fairseq/benchmark/dummy_lm.py | 79 ++++++++++++++++------------------- 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index f3146b3581..d917e28837 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -4,82 +4,75 @@ # LICENSE file in the root directory of this source tree. import logging +from dataclasses import dataclass, field +from typing import Optional import numpy as np import torch from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from omegaconf import II logger = logging.getLogger(__name__) -@register_task("dummy_lm") -class DummyLMTask(LegacyFairseqTask): - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - parser.add_argument("--dict-size", default=49996, type=int) - parser.add_argument("--dataset-size", default=100000, type=int) - parser.add_argument( - "--tokens-per-sample", - default=512, - type=int, - help="max number of total tokens over all segments " - "per sample for BERT dataset", - ) - parser.add_argument("--add-bos-token", action="store_true", help="unused") - parser.add_argument( - "--max-target-positions", - default=None, - help="max number of tokens in the target sequence", - ) +@dataclass +class DummyLMConfig(FairseqDataclass): + dict_size: int = 49996 + dataset_size: int = 100000 + tokens_per_sample: int = 
field( + default=512, metadata={"help": "max sequence length"} + ) + add_bos_token: bool = False + batch_size: Optional[int] = II("dataset.batch_size") + max_tokens: Optional[int] = II("dataset.max_tokens") + max_target_positions: int = II("task.tokens_per_sample") + - def __init__(self, args, dictionary): - super().__init__(args) - self.dictionary = dictionary - self.seed = args.seed +@register_task("dummy_lm", dataclass=DummyLMConfig) +class DummyLMTask(FairseqTask): - dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + def __init__(self, cfg: DummyLMConfig): + super().__init__(cfg) - seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1 + # load dictionary + self.dictionary = Dictionary() + for i in range(cfg.dict_size): + self.dictionary.add_symbol("word{}".format(i)) + self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + logger.info("dictionary: {} types".format(len(self.dictionary))) + + seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1 self.dummy_src = seq[:-1] self.dummy_tgt = seq[1:] - @classmethod - def setup_task(cls, args, **kwargs): - """Setup the task. """ - dictionary = Dictionary() - for i in range(args.dict_size): - dictionary.add_symbol("word{}".format(i)) - logger.info("dictionary: {} types".format(len(dictionary))) - return cls(args, dictionary) - def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
Args: split (str): name of the split (e.g., train, valid, test) """ - if self.args.batch_size is not None: - bsz = self.args.batch_size + if self.cfg.batch_size is not None: + bsz = self.cfg.batch_size else: - bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) + bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample) self.datasets[split] = DummyDataset( { "id": 1, "net_input": { "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), "src_lengths": torch.full( - (bsz,), self.args.tokens_per_sample, dtype=torch.long + (bsz,), self.cfg.tokens_per_sample, dtype=torch.long ), }, "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), "nsentences": bsz, - "ntokens": bsz * self.args.tokens_per_sample, + "ntokens": bsz * self.cfg.tokens_per_sample, }, - num_items=self.args.dataset_size, - item_size=self.args.tokens_per_sample, + num_items=self.cfg.dataset_size, + item_size=self.cfg.tokens_per_sample, ) @property From b418e46c8b4fedeaf80ef41b4235af33d496ddd4 Mon Sep 17 00:00:00 2001 From: alexeib Date: Mon, 9 Nov 2020 15:44:32 -0800 Subject: [PATCH 284/707] migrate audio_pretraining task to hydra (#1407) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1407 Reviewed By: myleott Differential Revision: D24821909 Pulled By: alexeib fbshipit-source-id: a58afdd17afab00062bef43cadae380998b23f29 --- fairseq/dataclass/utils.py | 12 +- fairseq/models/__init__.py | 2 +- fairseq/models/wav2vec/wav2vec2_asr.py | 22 ++- fairseq/tasks/audio_pretraining.py | 213 ++++++++++++++----------- fairseq/tasks/fairseq_task.py | 7 +- fairseq/tasks/language_modeling.py | 1 - 6 files changed, 151 insertions(+), 106 deletions(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 1efcc5dca9..5f4d200dfe 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -50,6 +50,10 @@ def argparse_name(name: str): def interpret_dc_type(field_type): if isinstance(field_type, str): raise 
RuntimeError("field should be a type") + + if field_type == Any: + return str + typestring = str(field_type) if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): return field_type.__args__[0] @@ -127,8 +131,13 @@ def get_kwargs_from_dc( for k in dataclass_instance._get_all_attributes(): field_name = argparse_name(dataclass_instance._get_name(k)) + field_type = dataclass_instance._get_type(k) if field_name is None: continue + elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass): + gen_parser_from_dataclass(parser, field_type(), delete_default) + continue + kwargs = get_kwargs_from_dc(dataclass_instance, k) field_args = [field_name] @@ -197,13 +206,14 @@ def get_default(f): t_args = v.type.__args__ if len(t_args) == 1: val = list(map(t_args[0], val)) - if val is None: overrides.append("{}.{}=null".format(sub_node, k)) elif val == "": overrides.append("{}.{}=''".format(sub_node, k)) elif isinstance(val, str): overrides.append("{}.{}='{}'".format(sub_node, k, val)) + elif isinstance(val, FairseqDataclass): + overrides += _override_attr(f"{sub_node}.{k}", type(val), args) else: overrides.append("{}.{}={}".format(sub_node, k, val)) return overrides diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 5336b0a052..d76e391499 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -52,7 +52,7 @@ ] -def build_model(cfg: DictConfig, task): +def build_model(cfg: FairseqDataclass, task): model = None model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 14fa8ea5ed..f62ec633b4 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from argparse import Namespace import contextlib import copy import math @@ -12,6 +13,7 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.models import ( BaseFairseqModel, FairseqEncoder, @@ -328,18 +330,24 @@ def __init__(self, args, tgt_dict=None): state = checkpoint_utils.load_checkpoint_to_cpu( args.w2v_path, arg_overrides ) - args.w2v_args = w2v_args = state.get("args", None) or state["cfg"].model + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + args.w2v_args = w2v_args else: state = None w2v_args = args.w2v_args + if isinstance(w2v_args, Namespace): + args.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) assert ( - args.normalize == w2v_args.normalize - ), "Fine-tuning works best when data normalization is the same" + args.normalize == w2v_args.task.normalize + ), "Fine-tuning works best when data normalization is the same. 
" \ + "Please check that --normalize is set or unset for both" - w2v_args.data = args.data - task = tasks.setup_task(w2v_args) - model = task.build_model(w2v_args) + w2v_args.task.data = args.data + task = tasks.setup_task(w2v_args.task) + model = task.build_model(w2v_args.model) if state is not None and not args.no_pretrained_weights: model.load_state_dict(state["model"], strict=True) @@ -348,7 +356,7 @@ def __init__(self, args, tgt_dict=None): super().__init__(task.source_dictionary) - d = w2v_args.encoder_embed_dim + d = w2v_args.model.encoder_embed_dim self.w2v_model = model diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 90eb7ca2d6..d1b6bf1c14 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -10,10 +10,16 @@ import sys import torch +from dataclasses import dataclass, field +from typing import Optional, Any +from omegaconf import MISSING + from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders from fairseq.data.data_utils import post_process +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig -from . import LegacyFairseqTask, register_task +from . import FairseqTask, register_task from .. import utils from ..logging import metrics @@ -28,87 +34,94 @@ def __call__(self, label): ) -@register_task("audio_pretraining") -class AudioPretrainingTask(LegacyFairseqTask): +@dataclass +class AudioPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + labels: Optional[str] = field( + default=None, + metadata={"help": "extension of the label file to load, used for fine-tuning"}, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. 
audio files will be up/down sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, metadata={"help": "pad shorter samples instead of cropping"} + ) + max_sample_size: Optional[int] = field( + default=None, metadata={"help": "max sample size to crop to for batching"} + ) + min_sample_size: Optional[int] = field( + default=None, metadata={"help": "min sample size to crop to for batching"} + ) + + # Options for reporting WER metrics during validation. Only applicable to + # Seq2Seq models during fine-tuning + eval_wer: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_config: GenerationConfig = field( + default_factory=lambda: GenerationConfig(), + metadata={"help": "beam search config for evaluating wer during training"}, + ) + eval_wer_tokenizer: Any = field( + default="space", + metadata={"help": "tokenizer config for evaluating wer during training"}, + ) + eval_wer_post_process: str = field( + default="letter", + metadata={ + "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" + }, + ) + autoregressive: bool = field( + default=False, + metadata={ + "help": "required for autoregressive decoders (like seq2seq models); " + "adds 'prev_output_tokens' to input and appends eos to target" + }, + ) + + +@register_task("audio_pretraining", dataclass=AudioPretrainingConfig) +class AudioPretrainingTask(FairseqTask): """""" - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - parser.add_argument("data", help="path to data directory") - parser.add_argument( - "--sample-rate", - default=16000, - type=int, - help="target sample rate. 
audio files will be up/down sampled to this rate", - ) - parser.add_argument( - "--normalize", - action="store_true", - help="if set, normalizes input to have 0 mean and unit variance", - ) - parser.add_argument( - "--max-sample-size", - default=None, - type=int, - help="max sample size to crop to for batching. default = min sample length", - ) - parser.add_argument( - "--min-sample-size", - default=None, - type=int, - help="min sample size to crop to for batching. default = same as --max-sample-size", - ) - - parser.add_argument( - "--enable-padding", - action="store_true", - help="pad shorter samples instead of cropping", - ) - - parser.add_argument( - "--labels", - type=str, - default=None, - help="extension of the label file to load, if any", - ) - - # Options for reporting WER metrics during validation. Only applicable to - # Seq2Seq models during fine-tuning - parser.add_argument( - "--eval-wer", - action="store_true", - help="compute WER for Seq2Seq models", - ) - parser.add_argument( - "--eval-wer-remove-bpe", - default="letter", - help="remove BPE tokens before scoring (can be sentencepiece, letter, and more)", - ) + cfg: AudioPretrainingConfig - def __init__(self, args, source_dictionary=None, target_dictionary=None): - super().__init__(args) + def __init__( + self, + cfg: AudioPretrainingConfig, + source_dictionary=None, + target_dictionary=None, + ): + super().__init__(cfg) self._target_dictionary = target_dictionary self._source_dictionary = source_dictionary - self.is_ctc = args.criterion == "ctc" - if getattr(self.args, "eval_wer", False): - assert args.labels is not None, "eval_wer can only be set during fine-tuning" + if cfg.eval_wer: + assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" @classmethod - def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): """Setup the task (e.g., load dictionaries). 
Args: - args (omegaconf.DictConfig): parsed command-line arguments + cfg (AudioPretrainingConfig): configuration of this task """ - if args.labels: - dict_path = os.path.join(args.data, f"dict.{args.labels}.txt") + if cfg.labels: + dict_path = os.path.join(cfg.data, f"dict.{cfg.labels}.txt") target_dictionary = Dictionary.load(dict_path) else: target_dictionary = None - return cls(args, target_dictionary=target_dictionary) + return cls(cfg, target_dictionary=target_dictionary) def load_dataset(self, split, **kwargs): """Load a given dataset split. @@ -116,19 +129,19 @@ def load_dataset(self, split, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - manifest = os.path.join(self.args.data, "{}.tsv".format(split)) + manifest = os.path.join(self.cfg.data, "{}.tsv".format(split)) self.datasets[split] = FileAudioDataset( manifest, - sample_rate=self.args.sample_rate, - max_sample_size=self.args.max_sample_size, - min_sample_size=self.args.max_sample_size, - min_length=self.args.min_sample_size, - pad=self.args.labels is not None or self.args.enable_padding, - normalize=self.args.normalize, + sample_rate=self.cfg.sample_rate, + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.max_sample_size, + min_length=self.cfg.min_sample_size, + pad=self.cfg.labels is not None or self.cfg.enable_padding, + normalize=self.cfg.normalize, ) - if self.args.labels: - label_path = os.path.join(self.args.data, f"{split}.{self.args.labels}") + if self.cfg.labels: + label_path = os.path.join(self.cfg.data, f"{split}.{self.cfg.labels}") labels = [] with open(label_path, "r") as f: for line in f: @@ -143,7 +156,7 @@ def load_dataset(self, split, **kwargs): eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, - add_to_input=not self.is_ctc, + add_to_input=self.cfg.autoregressive, ) @property @@ -173,7 +186,7 @@ def filter_indices_by_size( def valid_step(self, sample, model, criterion): loss, sample_size, 
logging_output = super().valid_step(sample, model, criterion) - if getattr(self.args, "eval_wer", False) and not self.is_ctc: + if self.cfg.eval_wer and self.cfg.autoregressive: metrics = self._inference_with_wer(self.sequence_generator, sample, model) logging_output["_num_char_errors"] = metrics["num_char_errors"] logging_output["_num_chars"] = metrics["num_chars"] @@ -181,19 +194,23 @@ def valid_step(self, sample, model, criterion): logging_output["_num_words"] = metrics["num_words"] return loss, sample_size, logging_output - def build_model(self, args): - model = super().build_model(args) + def build_model(self, model_cfg: FairseqDataclass): + model = super().build_model(model_cfg) - if getattr(args, 'eval_wer', False) and not self.is_ctc: - self.sequence_generator = self.build_generator([model], args, ) - self.tokenizer = encoders.build_tokenizer(args) + if self.cfg.eval_wer and self.cfg.autoregressive: + self.sequence_generator = self.build_generator( + [model], + self.cfg.eval_wer_config, + ) + if self.cfg.eval_wer_tokenizer: + self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) return model def _inference_with_wer(self, generator, sample, model): def decode(toks, escape_unk=True): s = self.target_dictionary.string( toks.int().cpu(), - self.args.eval_wer_remove_bpe, + self.cfg.eval_wer_post_process, escape_unk=escape_unk, extra_symbols_to_ignore={generator.eos}, ) @@ -210,15 +227,15 @@ def decode(toks, escape_unk=True): utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), escape_unk=True, ) - hyp = post_process(hyp, self.args.eval_wer_remove_bpe).strip("_") - ref = post_process(ref, self.args.eval_wer_remove_bpe).strip("_") + hyp = post_process(hyp, self.cfg.eval_wer_post_process).strip("_") + ref = post_process(ref, self.cfg.eval_wer_post_process).strip("_") num_char_errors += editdistance.eval(hyp, ref) num_chars += len(ref) hyp_words = hyp.split("_") ref_words = ref.split("_") num_word_errors += 
editdistance.eval(hyp_words, ref_words) num_words += len(ref_words) - + return { "num_char_errors": num_char_errors, "num_chars": num_chars, @@ -229,10 +246,14 @@ def decode(toks, escape_unk=True): def reduce_metrics(self, logging_outputs, criterion): super().reduce_metrics(logging_outputs, criterion) - zero = torch.scalar_tensor(0.) - num_char_errors = sum(log.get("_num_char_errors", zero) for log in logging_outputs) + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) - num_word_errors = sum(log.get("_num_word_errors", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) num_words = sum(log.get("_num_words", zero) for log in logging_outputs) metrics.log_scalar("_num_char_errors", num_char_errors) metrics.log_scalar("_num_chars", num_chars) @@ -241,11 +262,17 @@ def reduce_metrics(self, logging_outputs, criterion): if num_words > 0: metrics.log_derived( "uer", - lambda meters: meters["_num_char_errors"].sum * 100.0 / meters["_num_chars"].sum - if meters["_num_chars"].sum > 0 else float("nan") + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), ) metrics.log_derived( "wer", - lambda meters: meters["_num_word_errors"].sum * 100.0 / meters["_num_words"].sum - if meters["_num_words"].sum > 0 else float("nan") + lambda meters: meters["_num_word_errors"].sum + * 100.0 + / meters["_num_words"].sum + if meters["_num_words"].sum > 0 + else float("nan"), ) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 3cdb64cfae..c47f9c4200 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -11,6 +11,7 @@ import torch from fairseq import metrics, search, tokenizer, utils from fairseq.data import Dictionary, FairseqDataset, 
data_utils, encoders, iterators +from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass from omegaconf import DictConfig @@ -40,7 +41,7 @@ def logging_outputs_can_be_summed(criterion) -> bool: """ return criterion.logging_outputs_can_be_summed() - def __init__(self, cfg: DictConfig, **kwargs): + def __init__(self, cfg: FairseqDataclass, **kwargs): self.cfg = cfg self.datasets = {} self.dataset_to_epoch_iter = {} @@ -255,13 +256,13 @@ def get_batch_iterator( return epoch_iter - def build_model(self, cfg: DictConfig): + def build_model(self, cfg: FairseqDataclass): """ Build the :class:`~fairseq.models.BaseFairseqModel` instance for this task. Args: - cfg (omegaconf.DictConfig): configuration object + cfg (FairseqDataclass): configuration object Returns: a :class:`~fairseq.models.BaseFairseqModel` instance diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 79c225de6f..e0bf1f9b2b 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -39,7 +39,6 @@ @dataclass class LanguageModelingConfig(FairseqDataclass): - # TODO common var add to parent data: Optional[str] = field( default=None, metadata={"help": "path to data directory"} ) From 4ea1c1eee077cbf85b1110e6f25d691e53270a7b Mon Sep 17 00:00:00 2001 From: alexeib Date: Mon, 9 Nov 2020 15:44:32 -0800 Subject: [PATCH 285/707] migrate wav2vec2 model (#1409) Summary: see title also includes some minor bug fixes Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1409 Reviewed By: myleott Differential Revision: D24822219 Pulled By: alexeib fbshipit-source-id: b18f9a8af42ced37880c23dd6ad1ec4df3dfc040 --- fairseq/checkpoint_utils.py | 2 - fairseq/dataclass/utils.py | 21 +- fairseq/models/__init__.py | 24 +- fairseq/models/transformer_lm.py | 1 - fairseq/models/wav2vec/wav2vec2.py | 629 +++++++++++------------------ 5 files changed, 260 insertions(+), 417 deletions(-) diff --git 
a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 25d3e1e705..3038a1ebcc 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -13,7 +13,6 @@ from typing import Optional, Union import torch -from fairseq import utils from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, @@ -22,7 +21,6 @@ from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder from omegaconf import DictConfig, open_dict -from torch.serialization import default_restore_location logger = logging.getLogger(__name__) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 5f4d200dfe..e817005bf3 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -86,9 +86,10 @@ def get_kwargs_from_dc( kwargs["required"] = True if field_choices is not None: kwargs["choices"] = field_choices - if (isinstance(inter_type, type) and issubclass(inter_type, List)) or ( - "List" in str(inter_type) - ): + if ( + isinstance(inter_type, type) + and (issubclass(inter_type, List) or issubclass(inter_type, Tuple)) + ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)): if "int" in str(inter_type): kwargs["type"] = lambda x: eval_str_list(x, int) elif "float" in str(inter_type): @@ -96,7 +97,9 @@ def get_kwargs_from_dc( elif "str" in str(inter_type): kwargs["type"] = lambda x: eval_str_list(x, str) else: - raise NotImplementedError() + raise NotImplementedError( + "parsing of type " + str(inter_type) + " is not implemented" + ) if field_default is not MISSING: kwargs["default"] = ( ",".join(map(str, field_default)) @@ -216,6 +219,7 @@ def get_default(f): overrides += _override_attr(f"{sub_node}.{k}", type(val), args) else: overrides.append("{}.{}={}".format(sub_node, k, val)) + return overrides @@ -377,8 +381,13 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): if k in cfg and 
isinstance(cfg[k], DictConfig): overwrite_args_by_name(cfg[k], overrides) elif k in overrides: - if k in REGISTRIES and overrides[k] in REGISTRIES[k]["dataclass_registry"]: - cfg[k] = DictConfig(REGISTRIES[k]["dataclass_registry"][overrides[k]]) + if ( + k in REGISTRIES + and overrides[k] in REGISTRIES[k]["dataclass_registry"] + ): + cfg[k] = DictConfig( + REGISTRIES[k]["dataclass_registry"][overrides[k]] + ) overwrite_args_by_name(cfg[k], overrides) cfg[k]._name = overrides[k] else: diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index d76e391499..b987966749 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -126,6 +126,11 @@ def register_model_cls(cls): node = dataclass() node._name = name cs.store(name=name, group="model", node=node, provider="fairseq") + + @register_model_architecture(name, name) + def noop(_): + pass + return cls return register_model_cls @@ -155,15 +160,6 @@ def lstm_luong_wmt_en_de(cfg): arch_name (str): the name of the model architecture (``--arch``) """ - def arch_override_from_yaml(args, arch): - root_dir = os.path.dirname(os.path.dirname(fairseq.__file__)) - yaml_path = os.path.join(root_dir, "config/model/{}.yaml".format(arch)) - if not os.path.exists(yaml_path): - raise RuntimeError(f"yaml file {yaml_path} does not exist!") - arch_cfg = OmegaConf.load(yaml_path) - for k, v in arch_cfg.items(): - setattr(args, k, getattr(args, k, v)) - def register_model_arch_fn(fn): if model_name not in MODEL_REGISTRY: raise ValueError( @@ -182,15 +178,7 @@ def register_model_arch_fn(fn): ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) - if type(fn) is type and issubclass(fn, BaseFairseqModel): - # for model classes migrated with hydra - # in this case, we are using this decorator directly on model class since - # we do not need arch overriding functions. 
- ARCH_CONFIG_REGISTRY[arch_name] = lambda args: arch_override_from_yaml( - args, arch=arch_name - ) - else: - ARCH_CONFIG_REGISTRY[arch_name] = fn + ARCH_CONFIG_REGISTRY[arch_name] = fn return fn return register_model_arch_fn diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 9467b25efd..35bfa6eb6f 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -249,7 +249,6 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): return embed_tokens -@register_model_architecture("transformer_lm", "transformer_lm") def base_lm_architecture(args): # backward compatibility for older model checkpoints if hasattr(args, "no_tie_adaptive_proj"): diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 6a0f787601..6ad59085f0 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -3,8 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import math +from dataclasses import dataclass, field from typing import List, Tuple import numpy as np @@ -13,6 +13,7 @@ import torch.nn.functional as F from fairseq import utils from fairseq.data.data_utils import compute_mask_indices +from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.models import BaseFairseqModel, register_model, register_model_architecture from fairseq.modules import ( Fp32GroupNorm, @@ -28,333 +29,256 @@ from fairseq.utils import buffered_arange -@register_model("wav2vec2") +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) + + +@dataclass +class Wav2Vec2Config(FairseqDataclass): + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. 
default has a single group norm with d " + "groups in the first conv block, whereas layer_norm has layer norms in " + "every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + + # dropouts + dropout: float = field( + default=0.1, metadata={"help": "dropout probability for the transformer"} + ) + attention_dropout: float = field( + default=0.1, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN"} + ) + encoder_layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a transformer layer"} + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many dimensions." 
+ " set to encoder_embed_dim if <= 0" + }, + ) + layer_norm_first: bool = field( + default=False, metadata={"help": "apply layernorm first in the transformer"} + ) + conv_feature_layers: str = field( + default="[(512, 10, 5), (512, 8, 4)] + [(512, 4, 2)] * 3 + [(512, 1, 1)]", + metadata={ + "help": "string describing convolutional feature extraction layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + quantize_targets: bool = field( + default=False, metadata={"help": "use quantized targets"} + ) + quantize_input: bool = field( + default=False, metadata={"help": "use quantized inputs"} + ) + same_quantizer: bool = field( + default=False, metadata={"help": "use same quantizer for inputs and targets"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, metadata={"help": "multiply feature extractor var grads by this"} + ) + latent_vars: int = field( + default=320, + metadata={"help": "number of latent variables V in each group of the codebook"}, + ) + latent_groups: int = field( + default=2, + metadata={"help": "number of groups G of latent variables in the codebook"}, + ) + latent_dim: int = field( + default=0, + metadata={ + "help": "if > 0, uses this dimensionality for latent variables. 
" + "otherwise uses final_dim / latent_groups" + }, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, metadata={"help": "probability of replacing a token with mask"} + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} + ) + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, metadata={"help": "whether to allow channel masks to overlap"} + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # negative selection + num_negatives: int = field( + default=100, + metadata={"help": "number of negative examples from the same sample"}, + ) + negatives_from_everywhere: bool = field( + default=False, + metadata={"help": "sample negatives from everywhere, not just masked states"}, + ) + cross_sample_negatives: int = field( + default=0, 
metadata={"help": "number of negative examples from any sample"} + ) + codebook_negatives: int = field( + default=0, metadata={"help": "number of negative examples from the codebook"} + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={ + "help": "temperature for latent variable sampling. " + "can be tuple of 3 values (start, end, decay)" + }, + ) + + +@register_model("wav2vec2", dataclass=Wav2Vec2Config) class Wav2Vec2Model(BaseFairseqModel): - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - - parser.add_argument( - "--extractor-mode", - choices=["default", "layer_norm"], - help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)", - ) - - parser.add_argument( - "--encoder-layers", - type=int, - metavar="L", - help="num encoder layers in the transformer", - ) - parser.add_argument( - "--encoder-embed-dim", - type=int, - metavar="H", - help="encoder embedding dimension", - ) - parser.add_argument( - "--encoder-ffn-embed-dim", - type=int, - metavar="F", - help="encoder embedding dimension for FFN", - ) - parser.add_argument( - "--encoder-attention-heads", - type=int, - metavar="A", - help="num encoder attention heads", - ) - parser.add_argument( - "--activation-fn", - choices=utils.get_available_activation_fns(), - help="activation function to use", - ) - - parser.add_argument( - "--dropout", - type=float, - metavar="D", - help="dropout probability for the transformer", - ) - - parser.add_argument( - "--attention-dropout", - type=float, - metavar="D", - help="dropout 
probability for attention weights", - ) - - parser.add_argument( - "--activation-dropout", - type=float, - metavar="D", - help="dropout probability after activation in FFN", - ) - - parser.add_argument( - "--final-dim", - type=int, - metavar="D", - help="project final representations and targets to this many dimensions", - ) - - parser.add_argument( - "--layer-norm-first", - action="store_true", - help="apply layernorm first in the transformer", - ) - - parser.add_argument( - "--encoder-layerdrop", - type=float, - help="probability of dropping a tarnsformer layer", - ) - - parser.add_argument( - "--conv-feature-layers", - type=str, - metavar="EXPR", - help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", - ) - - parser.add_argument( - "--logit-temp", type=float, help="temperature to divide logits by" - ) - - parser.add_argument( - "--quantize-targets", action="store_true", help="use quantized targets" - ) - - parser.add_argument( - "--quantize-input", action="store_true", help="use quantized inputs" - ) - - parser.add_argument( - "--same-quantizer", - action="store_true", - help="use same quantizer for inputs and targets", - ) - - parser.add_argument( - "--feature-grad-mult", - type=float, - help="multiply feature extractor var grads by this", - ) - - parser.add_argument( - "--latent-vars", - type=int, - metavar="N", - help="number of latent variables V in each group of the codebook", - ) - - parser.add_argument( - "--latent-groups", - type=int, - metavar="N", - help="number of groups G of latent variables in the codebook", - ) - - parser.add_argument( - "--latent-dim", - type=int, - metavar="N", - help="if set, uses this dimensionality for latent variables. 
otherwise uses final_dim / latent_groups", - ) - - parser.add_argument("--mask-length", type=int, help="mask length") - - parser.add_argument( - "--mask-prob", type=float, help="probability of replacing a token with mask" - ) - - parser.add_argument( - "--mask-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", - ) - - parser.add_argument( - "--mask-other", - type=float, - help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", - ) - - parser.add_argument( - "--no-mask-overlap", - action="store_true", - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-min-space", - type=int, - help="min space between spans (if no overlap is enabled)", - ) - - parser.add_argument( - "--mask-channel-length", - type=int, - help="repeat the mask indices multiple times", - ) - - parser.add_argument( - "--mask-channel-prob", - type=float, - help="probability of replacing a token with mask", - ) - - parser.add_argument( - "--mask-channel-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", - ) - - parser.add_argument( - "--mask-channel-other", - type=float, - help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", - ) - - parser.add_argument( - "--no-mask-channel-overlap", - action="store_true", - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-channel-min-space", - type=int, - help="min space between spans (if no overlap is enabled)", - ) - - parser.add_argument( - "--dropout-input", - type=float, - metavar="D", - help="dropout to apply to the input (after feat extr)", - ) - - parser.add_argument( - "--dropout-features", - type=float, - metavar="D", - help="dropout to apply to the features (after feat extr)", - ) - - parser.add_argument( - "--num-negatives", type=int, metavar="N", help="number of negative examples" - ) - 
- parser.add_argument( - "--negatives-from-everywhere", - action="store_true", - help="sample negatives from everywhere, not just masked states", - ) - - parser.add_argument( - "--cross-sample-negatives", - type=int, - metavar="N", - help="num of cross sampled negatives", - ) - - parser.add_argument( - "--codebook-negatives", - type=int, - metavar="N", - help="num of codebook sampled negatives", - ) - - parser.add_argument( - "--conv-pos", - type=int, - metavar="N", - help="number of filters for convolutional positional embeddings", - ) - - parser.add_argument( - "--conv-pos-groups", - type=int, - metavar="N", - help="number of groups for convolutional positional embedding", - ) - - parser.add_argument( - "--latent-temp", - type=str, - metavar="D", - help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)", - ) - - parser.add_argument( - "--target-glu", action="store_true", help="adds projection + glu to targets" - ) - - parser.add_argument( - "--conv-bias", action="store_true", help="include bias in conv encoder" - ) - - def __init__(self, args): + def __init__(self, cfg: Wav2Vec2Config): super().__init__() - self.args = args + self.cfg = cfg - feature_enc_layers = eval(args.conv_feature_layers) + feature_enc_layers = eval(cfg.conv_feature_layers) self.embed = feature_enc_layers[-1][0] self.feature_extractor = ConvFeatureExtractionModel( conv_layers=feature_enc_layers, dropout=0.0, - mode=args.extractor_mode, - conv_bias=args.conv_bias, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, ) self.post_extract_proj = ( - nn.Linear(self.embed, args.encoder_embed_dim) - if self.embed != args.encoder_embed_dim and not args.quantize_input + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input else None ) - self.mask_prob = args.mask_prob - self.mask_selection = args.mask_selection - self.mask_other = args.mask_other - self.mask_length = args.mask_length - 
self.no_mask_overlap = args.no_mask_overlap - self.mask_min_space = args.mask_min_space + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space - self.mask_channel_prob = args.mask_channel_prob - self.mask_channel_selection = args.mask_channel_selection - self.mask_channel_other = args.mask_channel_other - self.mask_channel_length = args.mask_channel_length - self.no_mask_channel_overlap = args.no_mask_channel_overlap - self.mask_channel_min_space = args.mask_channel_min_space + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space - self.dropout_input = nn.Dropout(args.dropout_input) - self.dropout_features = nn.Dropout(args.dropout_features) + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) - self.feature_grad_mult = args.feature_grad_mult + self.feature_grad_mult = cfg.feature_grad_mult self.quantizer = None self.input_quantizer = None - self.n_negatives = args.num_negatives - self.cross_sample_negatives = args.cross_sample_negatives - self.codebook_negatives = args.codebook_negatives - self.negatives_from_everywhere = args.negatives_from_everywhere + self.n_negatives = cfg.num_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere - self.logit_temp = args.logit_temp + self.logit_temp = cfg.logit_temp - final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim + final_dim = cfg.final_dim if cfg.final_dim > 0 
else cfg.encoder_embed_dim - if args.quantize_targets: - vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim + if cfg.quantize_targets: + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim self.quantizer = GumbelVectorQuantizer( dim=self.embed, - num_vars=args.latent_vars, - temp=eval(args.latent_temp), - groups=args.latent_groups, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, combine_groups=False, vq_dim=vq_dim, time_first=True, @@ -363,39 +287,37 @@ def __init__(self, args): else: self.project_q = nn.Linear(self.embed, final_dim) - if args.quantize_input: - if args.same_quantizer and self.quantizer is not None: + if cfg.quantize_input: + if cfg.same_quantizer and self.quantizer is not None: vq_dim = final_dim self.input_quantizer = self.quantizer else: - vq_dim = ( - args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim - ) + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim self.input_quantizer = GumbelVectorQuantizer( dim=self.embed, - num_vars=args.latent_vars, - temp=eval(args.latent_temp), - groups=args.latent_groups, + num_vars=cfg.latent_vars, + temp=eval(cfg.latent_temp), + groups=cfg.latent_groups, combine_groups=False, vq_dim=vq_dim, time_first=True, ) - self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim) + self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) self.mask_emb = nn.Parameter( - torch.FloatTensor(args.encoder_embed_dim).uniform_() + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() ) - self.encoder = TransformerEncoder(args) + self.encoder = TransformerEncoder(cfg) self.layer_norm = LayerNorm(self.embed) self.target_glu = None - if args.target_glu: + if cfg.target_glu: self.target_glu = nn.Sequential( nn.Linear(final_dim, final_dim * 2), nn.GLU() ) - self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim) + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) def upgrade_state_dict_named(self, state_dict, name): 
super().upgrade_state_dict_named(state_dict, name) @@ -403,13 +325,10 @@ def upgrade_state_dict_named(self, state_dict, name): return state_dict @classmethod - def build_model(cls, args, task=None): + def build_model(cls, cfg: Wav2Vec2Config, task=None): """Build a new model instance.""" - # make sure all arguments are present - base_architecture(args) - - return cls(args) + return cls(cfg) def apply_mask(self, x, padding_mask): B, T, C = x.shape @@ -957,73 +876,3 @@ def forward( x = self.final_layer_norm(x) return x, attn - - -@register_model_architecture("wav2vec2", "wav2vec2") -def base_architecture(args): - args.extractor_mode = getattr(args, "extractor_mode", "default") - - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - - args.final_dim = getattr(args, "final_dim", 0) - - args.layer_norm_first = getattr(args, "layer_norm_first", False) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - - conv_feature_layers = "[(512, 10, 5)]" - conv_feature_layers += " + [(512, 8, 4)]" - conv_feature_layers += " + [(512, 4, 2)] * 3" - conv_feature_layers += " + [(512, 1, 1)]" - args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers) - - args.logit_temp = getattr(args, "logit_temp", 0.1) - - args.quantize_targets = getattr(args, "quantize_targets", False) - args.quantize_input = getattr(args, "quantize_input", False) - args.same_quantizer = getattr(args, "same_quantizer", False) - - args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) - - 
args.latent_vars = getattr(args, "latent_vars", 320) - args.latent_groups = getattr(args, "latent_groups", 2) - args.latent_dim = getattr(args, "latent_dim", 0) - - args.mask_length = getattr(args, "mask_length", 10) - args.mask_prob = getattr(args, "mask_prob", 0.65) - args.mask_selection = getattr(args, "mask_selection", "static") - args.mask_other = getattr(args, "mask_other", 0) - args.no_mask_overlap = getattr(args, "no_mask_overlap", False) - args.mask_min_space = getattr(args, "mask_min_space", 1) - - args.mask_channel_length = getattr(args, "mask_channel_length", 10) - args.mask_channel_prob = getattr(args, "mask_channel_prob", 0) - args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") - args.mask_channel_other = getattr(args, "mask_channel_other", 0) - args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) - args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1) - - args.dropout_input = getattr(args, "dropout_input", 0) - args.dropout_features = getattr(args, "dropout_features", 0) - - args.num_negatives = getattr(args, "num_negatives", 100) - args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False) - args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) - args.codebook_negatives = getattr(args, "codebook_negatives", 0) - - args.conv_pos = getattr(args, "conv_pos", 128) - args.conv_pos_groups = getattr(args, "conv_pos_groups", 16) - - args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)") - - args.target_glu = getattr(args, "target_glu", False) - - args.conv_bias = getattr(args, "conv_bias", False) From 6815772651fd639ed16360074aa23e238b29c6ce Mon Sep 17 00:00:00 2001 From: alexeib Date: Mon, 9 Nov 2020 19:21:13 -0800 Subject: [PATCH 286/707] fix wav2vec inference (#1418) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1418 Reviewed By: michaelauli Differential Revision: D24847525 Pulled By: 
alexeib fbshipit-source-id: e9f5d562ad2ac2904a65852cb9a05af775bebab0 --- fairseq/dataclass/utils.py | 41 ++++++++++++++++++++---------- fairseq/models/wav2vec/wav2vec2.py | 2 +- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index e817005bf3..d73af240b9 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -5,6 +5,7 @@ import ast import os +import re from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING from enum import Enum @@ -30,13 +31,25 @@ def eval_str_list(x, x_type=float): return [x_type(x)] +def interpret_dc_type(field_type): + if isinstance(field_type, str): + raise RuntimeError("field should be a type") + + if field_type == Any: + return str + + typestring = str(field_type) + if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): + return field_type.__args__[0] + return field_type + + def gen_parser_from_dataclass( parser: ArgumentParser, dataclass_instance: FairseqDataclass, delete_default: bool = False, ) -> None: """convert a dataclass instance to tailing parser arguments""" - import re def argparse_name(name: str): if name == "data": @@ -47,18 +60,6 @@ def argparse_name(name: str): return None return "--" + name.replace("_", "-") - def interpret_dc_type(field_type): - if isinstance(field_type, str): - raise RuntimeError("field should be a type") - - if field_type == Any: - return str - - typestring = str(field_type) - if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): - return field_type.__args__[0] - return field_type - def get_kwargs_from_dc( dataclass_instance: FairseqDataclass, k: str ) -> Dict[str, Any]: @@ -204,11 +205,25 @@ def get_default(f): val = get_default(v) if not hasattr(args, k) else getattr(args, k) + field_type = interpret_dc_type(v.type) + if ( + isinstance(val, str) + and not val.startswith("${") # not interpolation + and field_type != str + and not 
issubclass(field_type, Enum) # not choices enum + ): + # upgrade old models that stored complex parameters as string + val = ast.literal_eval(val) + + if isinstance(val, tuple): + val = list(val) + if getattr(v.type, "__origin__", None) is List: # if type is int but val is float, then we will crash later - try to convert here t_args = v.type.__args__ if len(t_args) == 1: val = list(map(t_args[0], val)) + if val is None: overrides.append("{}.{}=null".format(sub_node, k)) elif val == "": diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 6ad59085f0..e6fecdd4fe 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -296,7 +296,7 @@ def __init__(self, cfg: Wav2Vec2Config): self.input_quantizer = GumbelVectorQuantizer( dim=self.embed, num_vars=cfg.latent_vars, - temp=eval(cfg.latent_temp), + temp=cfg.latent_temp, groups=cfg.latent_groups, combine_groups=False, vq_dim=vq_dim, From a66cc28b14ca8ff95e3aadfd3ca77bfb5b00c136 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 11 Nov 2020 10:06:10 -0800 Subject: [PATCH 287/707] fix more bugs incl generating from w2v models (#1419) Summary: fixes several bugs: - populating dataclasses from arg objects - generating from w2v seq2seq models -> fix post processing, and make sure that generate uses the "task" args saved in the model that contain important info about dataset (e.g. 
whether to normalize it or not) - use task's config object if it exists (so any new fields are picked up) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1419 Reviewed By: myleott Differential Revision: D24853592 Pulled By: alexeib fbshipit-source-id: 463762fa4c0de30e5bcbfca51df84714e4d1f464 --- fairseq/dataclass/utils.py | 12 ++++++++---- fairseq/models/__init__.py | 9 +++++---- fairseq/registry.py | 2 +- fairseq/sequence_generator.py | 2 +- fairseq/tasks/__init__.py | 8 +++++--- fairseq/tasks/audio_pretraining.py | 17 +++++++---------- fairseq_cli/eval_lm.py | 4 ++-- 7 files changed, 29 insertions(+), 25 deletions(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index d73af240b9..f8ed8f667f 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -210,7 +210,7 @@ def get_default(f): isinstance(val, str) and not val.startswith("${") # not interpolation and field_type != str - and not issubclass(field_type, Enum) # not choices enum + and inspect.isclass(field_type) and not issubclass(field_type, Enum) # not choices enum ): # upgrade old models that stored complex parameters as string val = ast.literal_eval(val) @@ -229,6 +229,7 @@ def get_default(f): elif val == "": overrides.append("{}.{}=''".format(sub_node, k)) elif isinstance(val, str): + val = val.replace("'", r"\'") overrides.append("{}.{}='{}'".format(sub_node, k, val)) elif isinstance(val, FairseqDataclass): overrides += _override_attr(f"{sub_node}.{k}", type(val), args) @@ -373,7 +374,7 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: def populate_dataclass( - args: Namespace, dataclass: FairseqDataclass + dataclass: FairseqDataclass, args: Namespace, ) -> FairseqDataclass: for k in dataclass.__dataclass_fields__.keys(): if k.startswith("_"): @@ -382,7 +383,7 @@ def populate_dataclass( if hasattr(args, k): setattr(dataclass, k, getattr(args, k)) - return dataclass + return dataclass def overwrite_args_by_name(cfg: 
DictConfig, overrides: Dict[str, any]): @@ -395,6 +396,9 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): # "k in cfg" will return false if its a "mandatory value (e.g. ???)" if k in cfg and isinstance(cfg[k], DictConfig): overwrite_args_by_name(cfg[k], overrides) + elif k in cfg and isinstance(cfg[k], Namespace): + for override_key, val in overrides.items(): + setattr(cfg[k], override_key, val) elif k in overrides: if ( k in REGISTRIES @@ -409,7 +413,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): cfg[k] = overrides[k] -def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig): +def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass): dc_instance = DictConfig(dc) dc_instance.__dict__["_parent"] = cfg.__dict__["_parent"] cfg = OmegaConf.merge(dc_instance, cfg) diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index b987966749..600ca27c6a 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -8,11 +8,9 @@ import importlib import os -import fairseq from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import merge_with_parent +from fairseq.dataclass.utils import merge_with_parent, populate_dataclass from hydra.core.config_store import ConfigStore -from omegaconf import DictConfig, OmegaConf from .composite_encoder import CompositeEncoder from .distributed_fairseq_model import DistributedFairseqModel @@ -78,7 +76,10 @@ def build_model(cfg: FairseqDataclass, task): if model_type in MODEL_DATACLASS_REGISTRY: # set defaults from dataclass. 
note that arch name and model name can be the same dc = MODEL_DATACLASS_REGISTRY[model_type] - cfg = merge_with_parent(dc(), cfg) + if isinstance(cfg, argparse.Namespace): + cfg = populate_dataclass(dc(), cfg) + else: + cfg = merge_with_parent(dc(), cfg) assert model is not None, f"Could not infer model type from {cfg}" diff --git a/fairseq/registry.py b/fairseq/registry.py index 7a3dd1d1bf..3fbaeac301 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -45,7 +45,7 @@ def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs) else: choice = getattr(cfg, registry_name, None) if choice in DATACLASS_REGISTRY: - cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]()) + cfg = populate_dataclass(DATACLASS_REGISTRY[choice](), cfg) if choice is None: if required: diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 9c5423e2b1..603c5b6821 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -420,7 +420,7 @@ def _generate( break if self.search.stop_on_max_len and step >= max_len: break - assert step < max_len + assert step < max_len, f"{step} < {max_len}" # Remove finalized sentences (ones for which {beam_size} # finished hypotheses have been generated) from the batch. 
diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 415f15e708..0e55d093b1 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -9,9 +9,8 @@ import os from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import merge_with_parent +from fairseq.dataclass.utils import merge_with_parent, populate_dataclass from hydra.core.config_store import ConfigStore -from omegaconf import DictConfig from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa @@ -22,13 +21,16 @@ TASK_CLASS_NAMES = set() -def setup_task(cfg: DictConfig, **kwargs): +def setup_task(cfg: FairseqDataclass, **kwargs): task = None task_name = getattr(cfg, "task", None) if isinstance(task_name, str): # legacy tasks task = TASK_REGISTRY[task_name] + if task_name in TASK_DATACLASS_REGISTRY: + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = populate_dataclass(dc(), cfg) else: task_name = getattr(cfg, "_name", None) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index d1b6bf1c14..0f891f1199 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -71,7 +71,7 @@ class AudioPretrainingConfig(FairseqDataclass): metadata={"help": "beam search config for evaluating wer during training"}, ) eval_wer_tokenizer: Any = field( - default="space", + default=None, metadata={"help": "tokenizer config for evaluating wer during training"}, ) eval_wer_post_process: str = field( @@ -185,7 +185,6 @@ def filter_indices_by_size( def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) - if self.cfg.eval_wer and self.cfg.autoregressive: metrics = self._inference_with_wer(self.sequence_generator, sample, model) logging_output["_num_char_errors"] = metrics["num_char_errors"] @@ -204,15 +203,16 @@ def build_model(self, model_cfg: FairseqDataclass): ) if self.cfg.eval_wer_tokenizer: self.tokenizer = 
encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None return model def _inference_with_wer(self, generator, sample, model): - def decode(toks, escape_unk=True): + def decode(toks): s = self.target_dictionary.string( toks.int().cpu(), self.cfg.eval_wer_post_process, - escape_unk=escape_unk, - extra_symbols_to_ignore={generator.eos}, + escape_unk=True, ) if self.tokenizer: s = self.tokenizer.decode(s) @@ -225,14 +225,11 @@ def decode(toks, escape_unk=True): hyp = decode(gen_out[i][0]["tokens"]) ref = decode( utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), - escape_unk=True, ) - hyp = post_process(hyp, self.cfg.eval_wer_post_process).strip("_") - ref = post_process(ref, self.cfg.eval_wer_post_process).strip("_") num_char_errors += editdistance.eval(hyp, ref) num_chars += len(ref) - hyp_words = hyp.split("_") - ref_words = ref.split("_") + hyp_words = hyp.split() + ref_words = ref.split() num_word_errors += editdistance.eval(hyp_words, ref_words) num_words += len(ref_words) diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index e8fd98c325..d962a8145b 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -255,10 +255,10 @@ def main(cfg: DictConfig, **unused_kwargs): wps_meter.update(sample["ntokens"]) progress.log({"wps": round(wps_meter.avg)}) - avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2 + avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0 # convert to base 2 logger.info( "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format( - gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg + gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 ) ) logger.info( From 11ea91a33a5788b3b9e7a02cab4bcb158cac8778 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 11 Nov 2020 10:14:04 -0800 Subject: [PATCH 288/707] load dataset with saved task config (optionally) (#1423) Summary: this adds an argument to load_dataset that provides task configuration 
from the checkpoint. different tasks can decide what to do with it afterwards. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1423 Reviewed By: myleott Differential Revision: D24875706 Pulled By: alexeib fbshipit-source-id: 5bb1e2b7495520c456024dc7b0751b65cb05b473 --- examples/speech_recognition/infer.py | 73 ++++++---------------------- fairseq/checkpoint_utils.py | 10 ++-- fairseq/tasks/audio_pretraining.py | 28 ++++++----- fairseq/tasks/fairseq_task.py | 11 ++++- fairseq_cli/generate.py | 7 ++- fairseq_cli/validate.py | 8 +-- 6 files changed, 57 insertions(+), 80 deletions(-) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 68889463f4..ddd3fd6340 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -8,6 +8,7 @@ Run inference for pre-processed data with a trained model. """ +import ast import logging import math import os @@ -18,7 +19,6 @@ import torch from fairseq import checkpoint_utils, options, progress_bar, tasks, utils from fairseq.data.data_utils import post_process -from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging.meters import StopwatchMeter, TimeMeter @@ -178,53 +178,6 @@ def get_res_file(file_prefix): } -def load_models_and_criterions( - filenames, data_path, arg_overrides=None, task=None, model_state=None -): - models = [] - criterions = [] - - if arg_overrides is None: - arg_overrides = {} - - arg_overrides["wer_args"] = None - arg_overrides["data"] = data_path - - if filenames is None: - assert model_state is not None - filenames = [0] - else: - filenames = filenames.split(":") - - for filename in filenames: - if model_state is None: - if not os.path.exists(filename): - raise IOError("Model file not found: {}".format(filename)) - state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides) - else: - state = model_state - - if "cfg" in state: - cfg = state["cfg"] - else: - cfg = 
convert_namespace_to_omegaconf(state["args"]) - - if task is None: - if hasattr(cfg.task, 'data'): - cfg.task.data = data_path - task = tasks.setup_task(cfg.task) - - model = task.build_model(cfg.model) - model.load_state_dict(state["model"], strict=True) - models.append(model) - - criterion = task.build_criterion(cfg.criterion) - if "criterion" in state: - criterion.load_state_dict(state["criterion"], strict=True) - criterions.append(criterion) - return models, criterions, task - - def optimize_models(args, use_cuda, models): """Optimize ensemble for generation""" for model in models: @@ -266,23 +219,26 @@ def main(args, task=None, model_state=None): logger.info("| decoding with criterion {}".format(args.criterion)) + task = tasks.setup_task(args) + # Load ensemble if args.load_emissions: models, criterions = [], [] - task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) else: logger.info("| loading model(s) from {}".format(args.path)) - models, criterions, task = load_models_and_criterions( - args.path, - data_path=args.data, - arg_overrides=eval(args.model_overrides), # noqa + models, saved_cfg = checkpoint_utils.load_model_ensemble( + utils.split_paths(args.path), + arg_overrides=ast.literal_eval(args.model_overrides), task=task, - model_state=model_state, + suffix=args.checkpoint_suffix, + strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, + state=model_state, ) optimize_models(args, use_cuda, models) + task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task) - # Load dataset splits - task.load_dataset(args.gen_subset) # Set dictionary tgt_dict = task.target_dictionary @@ -295,8 +251,9 @@ def main(args, task=None, model_state=None): # hack to pass transitions to W2lDecoder if args.criterion == "asg_loss": - trans = criterions[0].asg.trans.data - args.asg_transitions = torch.flatten(trans).tolist() + raise NotImplementedError("asg_loss is currently not supported") + # trans = criterions[0].asg.trans.data + # 
args.asg_transitions = torch.flatten(trans).tolist() # Load dataset (possibly sharded) itr = get_dataset_itr(args, task, models) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 3038a1ebcc..5a0dc099b2 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -239,7 +239,7 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): def load_model_ensemble( - filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 + filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None ): """Loads an ensemble of models. @@ -259,12 +259,13 @@ def load_model_ensemble( strict, suffix, num_shards, + state, ) return ensemble, args def load_model_ensemble_and_task( - filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 + filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None ): from fairseq import tasks @@ -272,8 +273,10 @@ def load_model_ensemble_and_task( strict and num_shards > 1 ), "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble = [] + cfg = None for filename in filenames: orig_filename = filename + assert num_shards > 0 for shard_idx in range(num_shards): if num_shards == 1: filename = filename.replace(".pt", suffix + ".pt") @@ -282,7 +285,8 @@ def load_model_ensemble_and_task( if not PathManager.exists(filename): raise IOError("Model file not found: {}".format(filename)) - state = load_checkpoint_to_cpu(filename, arg_overrides) + if state is None: + state = load_checkpoint_to_cpu(filename, arg_overrides) if "args" in state and state["args"] is not None: cfg = convert_namespace_to_omegaconf(state["args"]) elif "cfg" in state and state["cfg"] is not None: diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 0f891f1199..90e667c80d 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -10,6 +10,7 @@ import sys import 
torch +from argparse import Namespace from dataclasses import dataclass, field from typing import Optional, Any from omegaconf import MISSING @@ -123,25 +124,28 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): return cls(cfg, target_dictionary=target_dictionary) - def load_dataset(self, split, **kwargs): - """Load a given dataset split. + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + task_cfg = task_cfg or self.cfg - Args: - split (str): name of the split (e.g., train, valid, test) - """ - manifest = os.path.join(self.cfg.data, "{}.tsv".format(split)) + # upgrade old task + if isinstance(task_cfg, Namespace): + if not hasattr(task_cfg, "autoregressive"): + task_cfg.autoregressive = not task_cfg.criterion == 'ctc' + + manifest = os.path.join(data_path, "{}.tsv".format(split)) self.datasets[split] = FileAudioDataset( manifest, - sample_rate=self.cfg.sample_rate, + sample_rate=task_cfg.sample_rate, max_sample_size=self.cfg.max_sample_size, min_sample_size=self.cfg.max_sample_size, min_length=self.cfg.min_sample_size, - pad=self.cfg.labels is not None or self.cfg.enable_padding, - normalize=self.cfg.normalize, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, ) - if self.cfg.labels: - label_path = os.path.join(self.cfg.data, f"{split}.{self.cfg.labels}") + if task_cfg.labels: + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") labels = [] with open(label_path, "r") as f: for line in f: @@ -156,7 +160,7 @@ def load_dataset(self, split, **kwargs): eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, - add_to_input=self.cfg.autoregressive, + add_to_input=task_cfg.autoregressive, ) @property diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index c47f9c4200..f62973a534 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -91,11 +91,20 @@ def 
setup_task(cls, cfg: DictConfig, **kwargs): def has_sharded_data(self, split): return os.pathsep in getattr(self.cfg, "data", "") - def load_dataset(self, split, combine=False, **kwargs): + def load_dataset( + self, + split: str, + combine: bool = False, + task_cfg: FairseqDataclass = None, + **kwargs + ): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) + combine (bool): combines a split segmented into pieces into one dataset + task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used + to load datasets """ raise NotImplementedError diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 021f819ed7..6be8150cda 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -81,7 +81,7 @@ def _main(cfg: DictConfig, output_file): # Load dataset splits task = tasks.setup_task(cfg.task) - task.load_dataset(cfg.dataset.gen_subset) + # Set dictionaries try: @@ -94,7 +94,7 @@ def _main(cfg: DictConfig, output_file): # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - models, _model_args = checkpoint_utils.load_model_ensemble( + models, saved_cfg = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, @@ -103,6 +103,9 @@ def _main(cfg: DictConfig, output_file): num_shards=cfg.checkpoint.checkpoint_shard_count, ) + # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config + task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) + if cfg.generation.lm_path is not None: overrides["data"] = cfg.task.data diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 36e8bd16ca..f6f0c9265c 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -51,7 +51,7 @@ def main(cfg: DictConfig, override_args=None): # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - 
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [cfg.common_eval.path], arg_overrides=overrides, suffix=cfg.checkpoint.checkpoint_suffix, @@ -66,15 +66,15 @@ def main(cfg: DictConfig, override_args=None): model.cuda() # Print args - logger.info(model_args) + logger.info(saved_cfg) # Build criterion - criterion = task.build_criterion(model_args.criterion) + criterion = task.build_criterion(saved_cfg.criterion) criterion.eval() for subset in cfg.dataset.valid_subset.split(","): try: - task.load_dataset(subset, combine=False, epoch=1) + task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg) dataset = task.dataset(subset) except KeyError: raise Exception("Cannot find dataset: " + subset) From e607911dde205e2188d3e62dcde592a6d84b4c46 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 11 Nov 2020 17:32:30 -0800 Subject: [PATCH 289/707] fix passing task config in validate.py (#1426) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1426 Reviewed By: aconneau Differential Revision: D24895299 Pulled By: alexeib fbshipit-source-id: 7af96952b857fa4616cdafd0268d8ab6cb94c61d --- fairseq_cli/validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index f6f0c9265c..c69bb94142 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -74,7 +74,7 @@ def main(cfg: DictConfig, override_args=None): for subset in cfg.dataset.valid_subset.split(","): try: - task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg) + task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task) dataset = task.dataset(subset) except KeyError: raise Exception("Cannot find dataset: " + subset) From b55053373fb8678361a45a1d2c1b462befd9ab1a Mon Sep 17 00:00:00 2001 From: Angela Fan Date: Fri, 13 Nov 2020 09:49:40 -0800 Subject: [PATCH 290/707] update m2m 
readme (#2890) Summary: adding smaller models Pull Request resolved: https://github.com/pytorch/fairseq/pull/2890 Reviewed By: ngoyal2707 Differential Revision: D24935146 Pulled By: huihuifan fbshipit-source-id: 2ba8e4083b9805d336154e3cc0d6d7bed71cca04 --- examples/m2m_100/README.md | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/m2m_100/README.md b/examples/m2m_100/README.md index 0bacd4c8b1..f1b465c7b9 100644 --- a/examples/m2m_100/README.md +++ b/examples/m2m_100/README.md @@ -14,8 +14,8 @@ sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en # WAT -wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2019.my-en.zip -unzip wat2019.my-en.zip +wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip +unzip wat2020.my-en.zip # FLORES # download from: https://github.com/facebookresearch/flores @@ -116,7 +116,22 @@ If you use any of the resources listed here, please cite: ## Trained Models -More models coming up soon. +### 418M and 1.2B Model +We include the last checkpoint for both of these models. + +```bash +wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt +wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs_small_models.txt + +# 418M parameter model +wget https://dl.fbaipublicfiles.com/m2m_100/418M_last_checkpoint.pt + +# 1.2B parameter model +wget https://dl.fbaipublicfiles.com/m2m_100/1.2B_last_checkpoint.pt + +# Generation: +fairseq-generate $binarized_data_path --batch-size 32 --path $path_to_model -s en -t fr --remove-bpe 'sentencepiece' --beam 5 --task translation_multi_simple_epoch --lang-pairs language_pairs_small_models --decoder-langtok --encoder-langtok src --gen-subset test > gen_out +``` ### 12B Model 12B parameter model trained on many-to-many training data for 100 languages. We include the last checkpoint, average of last 5 checkpoints, average of last 10 checkpoints. 
There isn't a universally best choice out of these three, but all three versions are pretty close in accuracy. You can either sweep over the 3 checkpoints on a dev test and use the best performing checkpoint for final testing. Or the last checkpoint can be a good default choice. From dc6a84f1433beaf3f6332ea181231055249be684 Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Fri, 13 Nov 2020 12:22:46 -0800 Subject: [PATCH 291/707] Make BART models compatible with JIT Summary: Bart models are not compatible with JIT. This diff makes minor changes to enable its compatibility Reviewed By: myleott Differential Revision: D24824963 fbshipit-source-id: 41cbcc46c14b0439f5763478b8efe98e5516dc95 --- fairseq/models/bart/model.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index e105d6fc46..44f03b0162 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -6,6 +6,7 @@ BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension """ +from typing import Optional import logging @@ -24,6 +25,8 @@ @register_model("bart") class BARTModel(TransformerModel): + __jit_unused_properties__ = ["supported_targets"] + @classmethod def hub_models(cls): return { @@ -41,6 +44,8 @@ def __init__(self, args, encoder, decoder): self.apply(init_bert_params) self.classification_heads = nn.ModuleDict() + if hasattr(self.encoder, "dictionary"): + self.eos: int = self.encoder.dictionary.eos() @staticmethod def add_args(parser): @@ -71,10 +76,12 @@ def forward( src_tokens, src_lengths, prev_output_tokens, - features_only=False, - classification_head_name=None, - token_embeddings=None, - **kwargs, + features_only: bool = False, + classification_head_name: Optional[str] = None, + token_embeddings: Optional[torch.Tensor] = None, + return_all_hiddens: bool = True, + alignment_layer: Optional[int] = None, +
alignment_heads: Optional[int] = None, ): if classification_head_name is not None: features_only = True @@ -83,22 +90,27 @@ def forward( src_tokens, src_lengths=src_lengths, token_embeddings=token_embeddings, - **kwargs, + return_all_hiddens=return_all_hiddens ) x, extra = self.decoder( prev_output_tokens, encoder_out=encoder_out, features_only=features_only, - **kwargs, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, ) - + eos: int = self.eos if classification_head_name is not None: sentence_representation = x[ - src_tokens.eq(self.encoder.dictionary.eos()), : + src_tokens.eq(eos), : ].view(x.size(0), -1, x.size(-1))[:, -1, :] - x = self.classification_heads[classification_head_name]( - sentence_representation - ) + for k, head in self.classification_heads.items(): + # for torch script only supports iteration + if k == classification_head_name: + x = head(sentence_representation) + break return x, extra @classmethod From 3c5647cebf454c07b52a0fb899c920789381ebda Mon Sep 17 00:00:00 2001 From: Weiyi Zheng Date: Fri, 13 Nov 2020 16:11:44 -0800 Subject: [PATCH 292/707] add grad_norm infinity check Summary: add grad_norm check for fp32 cases and single node training as well. Triggers nan detector when the grad_norm check fails, should help debug nan/inf cases. also fixing a bug (i think) in the original check_grad_norm() where [float('inf'), float('inf')] can pass the check. 
Reviewed By: myleott Differential Revision: D24849271 fbshipit-source-id: 2382342cd549717f3ff178b9aa29933f486327c8 --- fairseq/trainer.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 19ca213d55..d7ba0be874 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -632,12 +632,16 @@ def maybe_no_sync(): grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) # check that grad norms are consistent across workers - if ( - not self.cfg.optimization.use_bmuf - and self.cfg.distributed_training.distributed_wrapper != "SlowMo" - and not self.tpu - ): - self._check_grad_norms(grad_norm) + # on tpu check tensor is slow + if not self.tpu: + if ( + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.distributed_wrapper != "SlowMo" + ): + self._check_grad_norms(grad_norm) + if not torch.isfinite(grad_norm).all(): + # check local gradnorm single GPU case, trigger NanDetector + raise FloatingPointError("gradients are Nan/Inf") with torch.autograd.profiler.record_function("optimizer"): # take an optimization step @@ -1078,7 +1082,7 @@ def _check_grad_norms(self, grad_norm): def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) return ( - not torch.isfinite(tensor).any() + torch.isfinite(tensor).all() or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() ) @@ -1090,7 +1094,8 @@ def is_consistent(tensor): error_detail = "grad_norm across the workers:\n{}\n".format( pretty_detail ) - raise RuntimeError( + # use FloatingPointError to trigger NanDetector + raise FloatingPointError( "Fatal error: gradients are inconsistent between workers. " "Try --ddp-backend=no_c10d. " "Or are you mixing up different generation of GPUs in training?" 
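Note on PATCH 292: the summary says that `[float('inf'), float('inf')]` could pass the original `_check_grad_norms()` check. The following is a minimal plain-Python sketch of that case, not the fairseq code itself. It assumes the per-worker norms are given as a list of floats instead of a tensor, and the names `old_is_consistent`, `new_is_consistent` and `_nan_max` are hypothetical helpers introduced only for illustration.

```python
import math


def _nan_max(values):
    """Maximum that propagates NaN, as a tensor max reduction would."""
    values = list(values)
    if any(math.isnan(v) for v in values):
        return math.nan
    return max(values)


def old_is_consistent(norms):
    """Predicate as it read before the patch, rendered over a list of floats."""
    max_abs_diff = _nan_max(abs(n - norms[0]) for n in norms)
    return (
        not any(math.isfinite(n) for n in norms)
        or max_abs_diff / (norms[0] + 1e-6) < 1e-6
    )


def new_is_consistent(norms):
    """Predicate as it reads after the patch, rendered over a list of floats."""
    max_abs_diff = _nan_max(abs(n - norms[0]) for n in norms)
    return (
        all(math.isfinite(n) for n in norms)
        or max_abs_diff / (norms[0] + 1e-6) < 1e-6
    )


inf = float("inf")
# Every worker reports an infinite norm: the old predicate accepts this,
# while the new one rejects it (inf - inf is NaN, and NaN < 1e-6 is False).
print(old_is_consistent([inf, inf]), new_is_consistent([inf, inf]))
```

With the old predicate the all-infinite case short-circuits on the first clause and is reported as consistent; with the new one both clauses are false, so the error path that raises `FloatingPointError` is reached.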
From b987d30c69d773ccbf7d432e1cd0878aa9e50196 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Sat, 14 Nov 2020 08:45:03 -0800 Subject: [PATCH 293/707] Delete duplicate code in RoBERTa model (#2891) Summary: Simple fix. Lines 498 and 499 are the same. Pull Request resolved: https://github.com/pytorch/fairseq/pull/2891 Reviewed By: alexeib Differential Revision: D24953450 Pulled By: myleott fbshipit-source-id: 7745d066ed1e431edc39e99dd72ec8937235f752 --- fairseq/models/roberta/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 0f6efe5b33..96a7b9c8a2 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -496,7 +496,6 @@ def base_architecture(args): args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) args.spectral_norm_classification_head = getattr( args, "spectral_norm_classification_head", False From 0d03fbedcf79b63901b8718b4c61c525464cb198 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 14 Nov 2020 08:46:07 -0800 Subject: [PATCH 294/707] deprecation warning fixes (#2881) Summary: ## What does this PR do? Fixes: - 2x `DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working` - 1x `/fairseq/optim/adam.py:98: DeprecationWarning: invalid escape sequence \:` This is with py38. 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2881 Reviewed By: alexeib Differential Revision: D24959633 Pulled By: myleott fbshipit-source-id: ac563e194d5f07e3817de55729b0448366a6dc23 --- fairseq/optim/adam.py | 4 ++-- fairseq/optim/nag.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index 9b8ddffd7e..1a4f213707 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -5,7 +5,7 @@ import logging import math -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List @@ -95,7 +95,7 @@ def average_params(self): class Adam(torch.optim.Optimizer): - """Implements Adam algorithm. + r"""Implements Adam algorithm. This implementation is modified from torch.optim.Adam based on: `Fixed Weight Decay Regularization in Adam` diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 3982a8271d..c612d812c9 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List From 0a848245f3e00ee39a68fddf54f738de11dd8cc8 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sun, 15 Nov 2020 19:46:48 -0800 Subject: [PATCH 295/707] Add Truncated BPTT example + TransformerXL (#1410) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1410 Test Plan: - reproduced Transformer-XL results (see README) - added integration test Reviewed By: jingfeidu Differential Revision: D24928966 Pulled By: myleott fbshipit-source-id: 86376c17ab24d37e72e7c097b6dcec71b1a087a7 --- README.md | 23 +- examples/__init__.py | 5 +- examples/criss/mining/mine.py | 9 +- examples/truncated_bptt/README.md | 70 +++++ examples/truncated_bptt/__init__.py | 6 + .../truncated_bptt/transformer_xl_model.py | 146 +++++++++++ .../truncated_bptt/truncated_bptt_lm_task.py | 246 ++++++++++++++++++ fairseq/tasks/fairseq_task.py | 6 + tests/test_binaries.py | 86 +++++- 9 files changed, 576 insertions(+), 21 deletions(-) create mode 100644 examples/truncated_bptt/README.md create mode 100644 examples/truncated_bptt/__init__.py create mode 100644 examples/truncated_bptt/transformer_xl_model.py create mode 100644 examples/truncated_bptt/truncated_bptt_lm_task.py diff --git a/README.md b/README.md index 0648da15f7..3ae332b350 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ We provide reference implementations of various sequence modeling papers: + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](examples/truncated_bptt/README.md) + 
[Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) @@ -59,8 +60,9 @@ We provide reference implementations of various sequence modeling papers: ### What's New: -* November 2020: Adopted [Hydra](https://github.com/facebookresearch/hydra) as a configuration framework; -[added documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework + * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) * October 2020: [Added CRISS models and code](examples/criss/README.md) @@ -69,13 +71,13 @@ We provide reference implementations of various sequence modeling papers: * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) + +
Previous updates

+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) * April 2020: [Quant-Noise code released](examples/quant_noise/README.md) * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) - -

Previous updates

- * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) * February 2020: [mBART model and code released](examples/mbart/README.md) * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) @@ -99,10 +101,10 @@ We provide reference implementations of various sequence modeling papers: + beam search + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + sampling (unconstrained, top-k and top-p/nucleus) - + lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md)) -* large mini-batch training even on a single GPU via delayed updates -* mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) -* extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers + + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) +* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU +* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) +* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) @@ -131,6 +133,9 
@@ pip install --editable ./ # on MacOS: # CFLAGS="-stdlib=libc++" pip install --editable ./ + +# to install the latest stable release (0.10.0) +# pip install fairseq==0.10.0 ``` * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: diff --git a/examples/__init__.py b/examples/__init__.py index 80d95f5fe7..44bb24ae61 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -3,4 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq.version import __version__ # noqa +try: + from fairseq.version import __version__ # noqa +except ImportError: + pass diff --git a/examples/criss/mining/mine.py b/examples/criss/mining/mine.py index c86f73ae87..c872da196f 100644 --- a/examples/criss/mining/mine.py +++ b/examples/criss/mining/mine.py @@ -7,7 +7,12 @@ import glob from subprocess import check_call -import faiss +try: + import faiss + + has_faiss = True +except ImportError: + has_faiss = False import numpy as np @@ -40,6 +45,8 @@ def load_batch(emb_file, dim): def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"): + if not has_faiss: + raise ImportError("Please install Faiss") sims = [] inds = [] xfrom = 0 diff --git a/examples/truncated_bptt/README.md b/examples/truncated_bptt/README.md new file mode 100644 index 0000000000..f5c6447f1c --- /dev/null +++ b/examples/truncated_bptt/README.md @@ -0,0 +1,70 @@ +# Truncated Backpropagation Through Time (BPTT) + +Truncated BPTT is a useful technique for training language models on very long +sequences. Typically a long sequences is split into chunks and a language model +is trained over the chunks sequentially. The LM may condition on previous +chunks, but gradients only flow through the current chunk. 
This technique was +the basis for the paper: [Transformer-XL: Attentive Language Models Beyond a +Fixed-Length Context](https://arxiv.org/abs/1901.02860), which achieved +state-of-the-art language modeling results at the time of publication. + +It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since +we need to iterate over the data sequentially and disable any batch shuffling +logic. The code provided in this example illustrates how to implement Truncated +BPTT in fairseq by overriding ``FairseqTask::get_batch_iterator`` to iterate +over the data sequentially. Crucially, this example supports batching and +multi-GPU (data parallel) training. + +##### 0. Setup + +First, see the general [language modeling README](README.md) for instructions on +preprocessing the WikiText-103 data. + +##### 1. Train a Transformer-XL model on WikiText-103 + +We will train a 16-layer Transformer-XL model following the [hyperparameters +used in the original +paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh). + +The following command assumes 4 GPUs, so that the total batch size is 60 +sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs: +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/truncated_bptt \ + data-bin/wikitext-103/ \ + --task truncated_bptt_lm --tokens-per-sample 150 \ + --batch-size 15 --max-update 200000 \ + --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \ + --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \ + --optimizer adam --clip-norm 0.25 \ + --lr-scheduler cosine --warmup-updates 0 --lr 0.0 --max-lr 0.00025 \ + --log-format json --log-interval 25 \ + --fp16 +``` + +If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients +and simulate training on 4 GPUs. + +##### 2. 
Evaluate + +```bash +fairseq-eval-lm data-bin/wikitext-103/ \ + --path checkpoints/checkpoint_best.pt \ + --user-dir examples/truncated_bptt/ \ + --task truncated_bptt_lm \ + --batch-size 1 --required-batch-size-multiple 1 \ + --model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \ + --tokens-per-sample 64 +# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537 +# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s) +# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70 +# Compare to 24.0 test perplexity from the paper +``` + +*Note:* During training the model saw 150 tokens of context +(``--tokens-per-sample=150``) and 150 extra memory tokens (``--mem-len=150``). +During evaluation we measure perplexity on sequences of 64 tokens +(``--tokens-per-sample=64``) and increase the memory length +(``--model-overrides='{"mem_len":640}'``). These settings match the evaluation +settings from [the original +paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh). diff --git a/examples/truncated_bptt/__init__.py b/examples/truncated_bptt/__init__.py new file mode 100644 index 0000000000..eee484d427 --- /dev/null +++ b/examples/truncated_bptt/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import transformer_xl_model, truncated_bptt_lm_task # noqa diff --git a/examples/truncated_bptt/transformer_xl_model.py b/examples/truncated_bptt/transformer_xl_model.py new file mode 100644 index 0000000000..7466c951ab --- /dev/null +++ b/examples/truncated_bptt/transformer_xl_model.py @@ -0,0 +1,146 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from omegaconf import II + + +logger = logging.getLogger(__name__) + + +@dataclass +class TransformerXLConfig(FairseqDataclass): + # defaults come from the original Transformer-XL code + cutoffs: List[int] = field(default_factory=lambda: [20000, 40000, 200000]) + d_model: int = 500 + n_head: int = 10 + d_head: int = 50 + d_inner: int = 1000 + div_val: int = 1 + n_layer: int = 12 + mem_len: int = 0 + clamp_len: int = -1 + same_length: bool = False + dropout: float = 0.0 + dropatt: float = 0.0 + checkpoint_activations: bool = False + max_target_positions: int = II("task.max_target_positions") + + +@register_model("transformer_xl", dataclass=TransformerXLConfig) +class TransformerXLLanguageModel(FairseqLanguageModel): + @classmethod + def build_model(cls, cfg: TransformerXLConfig, task): + return cls(TransformerXLDecoder(cfg, task)) + + +class TransformerXLDecoder(FairseqIncrementalDecoder): + def __init__(self, cfg, task): + from transformers.configuration_transfo_xl import TransfoXLConfig + from transformers.modeling_transfo_xl import TransfoXLLMHeadModel + + super().__init__(task.target_dictionary) + self.cfg = cfg + + # remove any cutoffs larger than the vocab size + cutoffs = [ + cutoff for cutoff in cfg.cutoffs if cutoff < len(task.target_dictionary) + ] + + config = TransfoXLConfig( + vocab_size=len(task.target_dictionary), + cutoffs=cutoffs, + d_model=cfg.d_model, + d_embed=cfg.d_model, + n_head=cfg.n_head, + d_head=cfg.d_head, + d_inner=cfg.d_inner, + div_val=cfg.div_val, + n_layer=cfg.n_layer, + mem_len=cfg.mem_len, + clamp_len=cfg.clamp_len, + same_length=cfg.same_length, + dropout=cfg.dropout, + dropatt=cfg.dropatt, + ) + 
logger.info(config) + self.model = TransfoXLLMHeadModel(config) + + # Workaround a bug in huggingface's ``ProjectedAdaptiveLogSoftmax`` + # which adds ``None`` values to an ``nn.ParameterList``, which is not + # supported in PyTorch. Instead we can replace this with an + # ``nn.ModuleList``, which does support ``None`` values. + try: + if all(p is None for p in self.model.crit.out_projs._parameters.values()): + self.model.crit.out_projs = torch.nn.ModuleList( + [None] * len(self.model.crit.out_projs._parameters) + ) + except Exception: + pass + + if cfg.checkpoint_activations: + for i in range(len(self.model.transformer.layers)): + self.model.transformer.layers[i] = checkpoint_wrapper( + self.model.transformer.layers[i] + ) + + self._mems = None + + def forward( + self, + src_tokens, + src_lengths=None, # unused + incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None, + encoder_out=None, + ): + if incremental_state is not None: # used during inference + mems = self.get_incremental_state(incremental_state, "mems") + src_tokens = src_tokens[:, -1:] # only keep the most recent token + else: + mems = self._mems + + output = self.model( + input_ids=src_tokens, + mems=mems, + return_dict=False, + ) + + if len(output) >= 2: + if incremental_state is not None: + self.set_incremental_state(incremental_state, "mems", output[1]) + else: + self._mems = output[1] + + return (output[0],) + + def max_positions(self): + return self.cfg.max_target_positions + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], + new_order: torch.Tensor, + ): + """Reorder incremental state. + + This will be called when the order of the input has changed from the + previous time step. A typical use case is beam search, where the input + order changes between time steps based on the selection of beams. 
+ """ + mems = self.get_incremental_state(incremental_state, "mems") + if mems is not None: + new_mems = [mems_i.index_select(1, new_order) for mems_i in mems] + self.set_incremental_state(incremental_state, "mems", new_mems) diff --git a/examples/truncated_bptt/truncated_bptt_lm_task.py b/examples/truncated_bptt/truncated_bptt_lm_task.py new file mode 100644 index 0000000000..5f81ec4b84 --- /dev/null +++ b/examples/truncated_bptt/truncated_bptt_lm_task.py @@ -0,0 +1,246 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +import torch +from fairseq import distributed_utils as dist_utils, utils +from fairseq.data import Dictionary, TokenBlockDataset, data_utils, iterators +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from omegaconf import II + + +logger = logging.getLogger(__name__) + + +@dataclass +class TruncatedBPTTLMConfig(FairseqDataclass): + data: str = field(default="???", metadata={"help": "path to data directory"}) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sequence"}, + ) + batch_size: int = II("dataset.batch_size") + # Some models use *max_target_positions* to know how many positional + # embeddings to learn. We use II(...) to make it default to + # *tokens_per_sample*, but in principle there could be more positional + # embeddings than tokens in a single batch. This may also be irrelevant for + # custom model implementations. 
+ max_target_positions: int = II("task.tokens_per_sample") + # these will be populated automatically if not provided + data_parallel_rank: Optional[int] = None + data_parallel_size: Optional[int] = None + + +@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig) +class TruncatedBPTTLMTask(FairseqTask): + def __init__(self, cfg: TruncatedBPTTLMConfig): + super().__init__(cfg) + + if cfg.data_parallel_rank is None or cfg.data_parallel_size is None: + if torch.distributed.is_initialized(): + cfg.data_parallel_rank = dist_utils.get_data_parallel_rank() + cfg.data_parallel_size = dist_utils.get_data_parallel_world_size() + else: + cfg.data_parallel_rank = 0 + cfg.data_parallel_size = 1 + + # load the dictionary + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(self.dictionary))) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test)""" + + # support sharded datasets + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + # each element of *data* will be a tensorized line from the original + # text dataset, similar to ``open(split_path).readlines()`` + data = data_utils.load_indexed_dataset( + split_path, self.dictionary, combine=combine + ) + if data is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + # this is similar to ``data.view(-1).split(tokens_per_sample)`` + data = TokenBlockDataset( + data, + data.sizes, + block_size=self.cfg.tokens_per_sample, + pad=None, # unused + eos=None, # unused + break_mode="none", + ) + + self.datasets[split] = TruncatedBPTTDataset( + data=data, + bsz_per_shard=self.cfg.batch_size, + shard_id=self.cfg.data_parallel_rank, + num_shards=self.cfg.data_parallel_size, + ) + 
+ def dataset(self, split): + return self.datasets[split] + + def get_batch_iterator( + self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs + ): + return iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=self._collate_fn, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + # we don't use the batching functionality from EpochBatchIterator; + # instead every item in *dataset* is a whole batch + batch_sampler=[[i] for i in range(len(dataset))], + disable_shuffling=True, + ) + + def _collate_fn(self, items: List[List[torch.Tensor]]): + # we don't use fairseq's batching functionality, so we expect a single + # Tensor of type List[torch.Tensor] + assert len(items) == 1 + + # item will have shape B x T (the last batch may have length < T) + id, item = items[0] + item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad()) + B, T = item.size() + + # shift item one position over and append a padding token for the target + target = torch.nn.functional.pad( + item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad() + ) + + # fairseq expects batches to have the following structure + return { + "id": torch.tensor([id]*item.size(0)), + "net_input": { + "src_tokens": item, + }, + "target": target, + "nsentences": item.size(0), + "ntokens": item.numel(), + } + + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + eos = self.source_dictionary.eos() + dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=eos, + break_mode="eos", + ) + + class Dataset(torch.utils.data.Dataset): + def __getitem__(self, i): + item = dataset[i] + if item[-1] == eos: + # remove eos to support generating with a prefix + item = item[:-1] + return (i, [item]) + + def __len__(self): + return len(dataset) + + return Dataset() + + def inference_step( + 
self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + if constraints is not None: + raise NotImplementedError + + # SequenceGenerator doesn't use *src_tokens* directly, we need to + # pass the *prefix_tokens* argument instead. + if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): + prefix_tokens = sample["net_input"]["src_tokens"] + + # begin generation with the end-of-sentence token + bos_token = self.source_dictionary.eos() + + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class TruncatedBPTTDataset(torch.utils.data.Dataset): + def __init__( + self, + data: List[torch.Tensor], # ordered list of items + bsz_per_shard, # number of items processed per GPUs per forward + shard_id, # current GPU ID + num_shards, # number of GPUs + ): + super().__init__() + self.data = data + + def batchify(data, bsz): + # Work out how cleanly we can divide the dataset into bsz parts. + nbatch = data.size(0) // bsz + # Trim off any extra elements that wouldn't cleanly fit (remainders). + data = data.narrow(0, 0, nbatch * bsz) + # Evenly divide the data across the bsz batches. 
+ data = data.view(bsz, -1).contiguous() + return data + + # total number of sequences processed by all GPUs in each forward pass + global_batch_size = bsz_per_shard * num_shards + + """ + With a 16 item dataset, bsz_per_shard=2 and num_shards=3, + *indices* might look like: + + indices = [[0, 1], + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11]] + + The size of the TruncatedBPTTDataset instance will be 2, + and shard 1 will see items: + + [(0, [data[4], data[6]]), + (1, [data[5], data[7]])] + """ + indices = batchify(torch.arange(len(data)), global_batch_size) + assert indices.size(0) == global_batch_size + + self.my_indices = indices[ + shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard + ] + assert self.my_indices.size(0) == bsz_per_shard + + def __len__(self): + return self.my_indices.size(1) + + def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]: + return (i, [self.data[idx] for idx in self.my_indices[:, i]]) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index f62973a534..d34f09d1d7 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -7,6 +7,7 @@ import os import warnings from argparse import Namespace +from typing import List import torch from fairseq import metrics, search, tokenizer, utils @@ -437,6 +438,11 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = criterion(model, sample) return loss, sample_size, logging_output + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + raise NotImplementedError + def inference_step( self, generator, models, sample, prefix_tokens=None, constraints=None ): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index dae38dda0c..4d3393ae40 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -25,6 +25,14 @@ ) +try: + import transformers # noqa + + has_hf_transformers = True +except ImportError: + 
has_hf_transformers = False + + class TestTranslation(unittest.TestCase): def setUp(self): logging.disable(logging.CRITICAL) @@ -963,6 +971,36 @@ def test_transformer_lm(self): ], ) + def test_transformer_lm_with_adaptive_softmax(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_lm_with_adaptive_softmax") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + [ + "--add-bos-token", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + ], + run_validation=True, + ) + eval_lm_main(data_dir) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + def test_lightconv_lm(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_lightconv_lm") as data_dir: @@ -1035,6 +1073,35 @@ def test_lstm_lm_residuals(self): ], ) + @unittest.skipIf(not has_hf_transformers, "skip test if transformers is missing") + def test_transformer_xl_bptt_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_xl_bptt_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + task_flags = [ + "--user-dir", + "examples/truncated_bptt", + "--task", + "truncated_bptt_lm", + "--batch-size", + "2", + "--tokens-per-sample", + "50", + ] + train_language_model( + data_dir=data_dir, + arch="transformer_xl", + extra_flags=task_flags + [ + "--n-layer", + "2", + ], + task="truncated_bptt_lm", + run_validation=True, + extra_valid_flags=task_flags, + ) + eval_lm_main(data_dir, extra_flags=task_flags) + class TestMaskedLanguageModel(unittest.TestCase): def setUp(self): @@ -1478,13 +1545,15 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): train.main(train_args) -def train_language_model(data_dir, arch, extra_flags=None, 
run_validation=False): +def train_language_model( + data_dir, arch, extra_flags=None, run_validation=False, extra_valid_flags=None, task="language_modeling" +): train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( train_parser, [ "--task", - "language_modeling", + task, data_dir, "--arch", arch, @@ -1492,10 +1561,6 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) "adam", "--lr", "0.0001", - "--criterion", - "adaptive_loss", - "--adaptive-softmax-cutoff", - "5,10,15", "--max-tokens", "500", "--tokens-per-sample", @@ -1523,7 +1588,7 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) validate_parser, [ "--task", - "language_modeling", + task, data_dir, "--path", os.path.join(data_dir, "checkpoint_last.pt"), @@ -1534,12 +1599,13 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) "--no-progress-bar", "--num-workers", "0", - ], + ] + + (extra_valid_flags or []), ) validate.main(validate_args) -def eval_lm_main(data_dir): +def eval_lm_main(data_dir, extra_flags=None): eval_lm_parser = options.get_eval_lm_parser() eval_lm_args = options.parse_args_and_arch( eval_lm_parser, @@ -1550,7 +1616,7 @@ def eval_lm_main(data_dir): "--no-progress-bar", "--num-workers", "0", - ], + ] + (extra_flags or []), ) eval_lm.main(eval_lm_args) From dc1eaf3dde83494037a8727de5897a43c46e0b46 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 16 Nov 2020 09:10:56 -0800 Subject: [PATCH 296/707] Remove unused hf/transformers submodule (#1435) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1435 Reviewed By: huihuifan Differential Revision: D24973816 Pulled By: myleott fbshipit-source-id: 1565dfc3f7e8db65ded4af92d1afd7aff8d19294 --- .gitmodules | 4 ---- fairseq/models/huggingface/transformers | 1 - 2 files changed, 5 deletions(-) delete mode 160000 fairseq/models/huggingface/transformers diff --git a/.gitmodules b/.gitmodules index 
df0d3d3071..07a55d45d4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,3 @@ -[submodule "fairseq/models/huggingface/transformers"] - path = fairseq/models/huggingface/transformers - url = https://github.com/myleott/transformers.git - branch = fairseq [submodule "fairseq/model_parallel/megatron"] path = fairseq/model_parallel/megatron url = https://github.com/ngoyal2707/Megatron-LM diff --git a/fairseq/models/huggingface/transformers b/fairseq/models/huggingface/transformers deleted file mode 160000 index 839f8a563c..0000000000 --- a/fairseq/models/huggingface/transformers +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 839f8a563cefcb7f2048b310024c217e7829a198 From 52d774cda9e8926ab42210da05715789fc567d8e Mon Sep 17 00:00:00 2001 From: alexeib Date: Mon, 16 Nov 2020 11:08:13 -0800 Subject: [PATCH 297/707] fix gumbel temp arg (#1438) Summary: Fix #2897 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1438 Reviewed By: myleott Differential Revision: D24992106 Pulled By: alexeib fbshipit-source-id: 0cb15c2e865c3e8f7950e8f5e6c54c5000637af2 --- fairseq/modules/gumbel_vector_quantizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py index 47657bb0ab..7113438888 100644 --- a/fairseq/modules/gumbel_vector_quantizer.py +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -73,7 +73,10 @@ def block(input_dim, output_dim): nn.init.normal_(self.weight_proj.weight, mean=0, std=1) nn.init.zeros_(self.weight_proj.bias) - assert len(temp) == 3, temp + if isinstance(temp, str): + import ast + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" self.max_temp, self.min_temp, self.temp_decay = temp self.curr_temp = self.max_temp From add65adcc53a927f99a717d90a9672765237d937 Mon Sep 17 00:00:00 2001 From: Juan Miguel Pino Date: Mon, 16 Nov 2020 12:41:26 -0800 Subject: [PATCH 298/707] Replace encoder output type (#1281) 
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1281 The PyTorch Mobile lite interpreter does not support NamedTuple creation in forward. One workaround is to replace NamedTuple with a custom class that inherits nn.Module. This class could be initialized in `__init__` and updated in forward. However, lite interpreter does not support list construction with custom classes. So the final solution is to replace the NamedTuple with a dictionary. We cannot have mixed value types in that dictionary, otherwise, this breaks TorchScript export. So the type is List[Tensor] and an empty list corresponds to having a value of None. Reviewed By: myleott Differential Revision: D23752010 fbshipit-source-id: 0b152a534a165ce4f84bd4f580d7f29145cfd264 --- .../pointer_generator_src/transformer_pg.py | 18 +-- .../mean_pool_gating_network.py | 10 +- fairseq/models/nat/levenshtein_transformer.py | 13 +- .../nat/nonautoregressive_transformer.py | 34 +++-- fairseq/models/transformer.py | 116 +++++++++--------- fairseq/sequence_generator.py | 9 +- 6 files changed, 114 insertions(+), 86 deletions(-) diff --git a/examples/pointer_generator/pointer_generator_src/transformer_pg.py b/examples/pointer_generator/pointer_generator_src/transformer_pg.py index 079fdda581..fb40a80836 100644 --- a/examples/pointer_generator/pointer_generator_src/transformer_pg.py +++ b/examples/pointer_generator/pointer_generator_src/transformer_pg.py @@ -185,14 +185,14 @@ def forward(self, src_tokens, src_lengths, **kwargs): `(batch, src_len)` """ encoder_out = super().forward(src_tokens, src_lengths, **kwargs) - return EncoderOut( - encoder_out=encoder_out.encoder_out, # T x B x C - encoder_padding_mask=encoder_out.encoder_padding_mask, # B x T - encoder_embedding=encoder_out.encoder_embedding, # B x T x C - encoder_states=encoder_out.encoder_states, # List[T x B x C] - src_tokens=src_tokens, # B x T - src_lengths=None, - ) + return { + "encoder_out": encoder_out["encoder_out"], # T x B 
x C + "encoder_padding_mask": encoder_out["encoder_padding_mask"], # B x T + "encoder_embedding": encoder_out["encoder_embedding"], # B x T x C + "encoder_states": encoder_out["encoder_states"], # List[T x B x C] + "src_tokens": [src_tokens], # B x T + "src_lengths": [], + } class TransformerPointerGeneratorDecoder(TransformerDecoder): @@ -284,7 +284,7 @@ def forward( predictors = torch.cat((prev_output_embed, x), 2) p_gens = self.project_p_gens(predictors) p_gens = torch.sigmoid(p_gens) - x = self.output_layer(x, extra["attn"][0], encoder_out.src_tokens, p_gens) + x = self.output_layer(x, extra["attn"][0], encoder_out["src_tokens"][0], p_gens) return x, extra def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): diff --git a/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py b/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py index 484b6ac912..efc7ae40bf 100644 --- a/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py +++ b/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py @@ -26,15 +26,15 @@ def __init__(self, embed_dim, num_experts, dropout=None): def forward(self, encoder_out): if not ( - hasattr(encoder_out, "encoder_out") - and hasattr(encoder_out, "encoder_padding_mask") - and encoder_out.encoder_out.size(2) == self.embed_dim + "encoder_out" in encoder_out + and "encoder_padding_mask" in encoder_out + and encoder_out["encoder_out"][0].size(2) == self.embed_dim ): raise ValueError("Unexpected format for encoder_out") # mean pooling over time - encoder_padding_mask = encoder_out.encoder_padding_mask # B x T - encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C + encoder_padding_mask = encoder_out["encoder_padding_mask"][0] # B x T + encoder_out = encoder_out["encoder_out"][0].transpose(0, 1) # B x T x C if encoder_padding_mask is not None: encoder_out = encoder_out.clone() # required because of transpose above encoder_out[encoder_padding_mask] 
= 0 diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py index f7a3f003ca..17f1ee99be 100644 --- a/fairseq/models/nat/levenshtein_transformer.py +++ b/fairseq/models/nat/levenshtein_transformer.py @@ -256,7 +256,7 @@ def initialize_output_tokens(self, encoder_out, src_tokens): initial_output_scores = initial_output_tokens.new_zeros( *initial_output_tokens.size() - ).type_as(encoder_out.encoder_out) + ).type_as(encoder_out["encoder_out"][0]) return DecoderOut( output_tokens=initial_output_tokens, @@ -357,8 +357,15 @@ def extract_features( for _, layer in enumerate(layers[:early_exit]): x, attn, _ = layer( x, - encoder_out.encoder_out if encoder_out is not None else None, - encoder_out.encoder_padding_mask if encoder_out is not None else None, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, self_attn_mask=None, self_attn_padding_mask=decoder_padding_mask, ) diff --git a/fairseq/models/nat/nonautoregressive_transformer.py b/fairseq/models/nat/nonautoregressive_transformer.py index 735297fc29..d114202d25 100644 --- a/fairseq/models/nat/nonautoregressive_transformer.py +++ b/fairseq/models/nat/nonautoregressive_transformer.py @@ -163,7 +163,7 @@ def initialize_output_tokens(self, encoder_out, src_tokens): initial_output_scores = initial_output_tokens.new_zeros( *initial_output_tokens.size() - ).type_as(encoder_out.encoder_out) + ).type_as(encoder_out["encoder_out"][0]) return DecoderOut( output_tokens=initial_output_tokens, @@ -233,8 +233,11 @@ def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused): @ensemble_decoder def forward_length(self, normalize, encoder_out): - enc_feats = encoder_out.encoder_out # T x B x C - src_masks = encoder_out.encoder_padding_mask # B x T or None + 
enc_feats = encoder_out["encoder_out"][0] # T x B x C + if len(encoder_out["encoder_padding_mask"]) > 0: + src_masks = encoder_out["encoder_padding_mask"][0] # B x T + else: + src_masks = None enc_feats = _mean_pooling(enc_feats, src_masks) if self.sg_length_pred: enc_feats = enc_feats.detach() @@ -264,8 +267,11 @@ def extract_features( """ # embedding if embedding_copy: - src_embd = encoder_out.encoder_embedding - src_mask = encoder_out.encoder_padding_mask + src_embd = encoder_out["encoder_embedding"][0] + if len(encoder_out["encoder_padding_mask"]) > 0: + src_mask = encoder_out["encoder_padding_mask"][0] + else: + src_mask = None src_mask = ( ~src_mask if src_mask is not None @@ -297,8 +303,15 @@ def extract_features( x, attn, _ = layer( x, - encoder_out.encoder_out if encoder_out is not None else None, - encoder_out.encoder_padding_mask if encoder_out is not None else None, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, self_attn_mask=None, self_attn_padding_mask=decoder_padding_mask, ) @@ -353,8 +366,11 @@ def forward_copying_source(self, src_embeds, src_masks, tgt_masks): return copied_embedding def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None): - enc_feats = encoder_out.encoder_out # T x B x C - src_masks = encoder_out.encoder_padding_mask # B x T or None + enc_feats = encoder_out["encoder_out"][0] # T x B x C + if len(encoder_out["encoder_padding_mask"]) > 0: + src_masks = encoder_out["encoder_padding_mask"][0] # B x T + else: + src_masks = None if self.pred_length_offset: if src_masks is None: src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_( diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 7614c33f74..70920ed779 100644 --- a/fairseq/models/transformer.py +++ 
b/fairseq/models/transformer.py @@ -16,7 +16,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.modules import ( AdaptiveSoftmax, FairseqDropout, @@ -425,7 +424,7 @@ def forward( # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) - encoder_states = [] if return_all_hiddens else None + encoder_states = [] # encoder layers for layer in self.layers: @@ -437,17 +436,21 @@ def forward( if self.layer_norm is not None: x = self.layer_norm(x) - return EncoderOut( - encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask, # B x T - encoder_embedding=encoder_embedding, # B x T x C - encoder_states=encoder_states, # List[T x B x C] - src_tokens=None, - src_lengths=None, - ) + # The PyTorch Mobile lite interpreter does not support returning NamedTuple in + # `forward`, so we use a dictionary instead. + # TorchScript does not support mixed value types, so the values are all lists. + # The empty list is equivalent to None. + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } @torch.jit.export - def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. 
@@ -458,50 +461,46 @@ def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): Returns: *encoder_out* rearranged according to *new_order* """ - """ - Since encoder_padding_mask and encoder_embedding are both of type - Optional[Tensor] in EncoderOut, they need to be copied as local - variables for Torchscript Optional refinement - """ - encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask - encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + encoder_out["encoder_embedding"][0].index_select(0, new_order) + ] - new_encoder_out = ( - encoder_out.encoder_out - if encoder_out.encoder_out is None - else encoder_out.encoder_out.index_select(1, new_order) - ) - new_encoder_padding_mask = ( - encoder_padding_mask - if encoder_padding_mask is None - else encoder_padding_mask.index_select(0, new_order) - ) - new_encoder_embedding = ( - encoder_embedding - if encoder_embedding is None - else encoder_embedding.index_select(0, new_order) - ) - src_tokens = encoder_out.src_tokens - if src_tokens is not None: - src_tokens = src_tokens.index_select(0, new_order) + if len(encoder_out["src_tokens"]) == 0: + src_tokens = [] + else: + src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] - src_lengths = encoder_out.src_lengths - if src_lengths is not None: - src_lengths = src_lengths.index_select(0, new_order) + if len(encoder_out["src_lengths"]) == 0: + src_lengths = [] + else: + src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] - 
encoder_states = encoder_out.encoder_states - if encoder_states is not None: + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) - return EncoderOut( - encoder_out=new_encoder_out, # T x B x C - encoder_padding_mask=new_encoder_padding_mask, # B x T - encoder_embedding=new_encoder_embedding, # B x T x C - encoder_states=encoder_states, # List[T x B x C] - src_tokens=src_tokens, # B x T - src_lengths=src_lengths, # B x 1 - ) + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": src_tokens, # B x T + "src_lengths": src_lengths, # B x 1 + } def max_positions(self): """Maximum input length supported by the encoder.""" @@ -664,7 +663,7 @@ def build_decoder_layer(self, args, no_encoder_attn=False): def forward( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, @@ -706,7 +705,7 @@ def forward( def extract_features( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, @@ -723,14 +722,14 @@ def extract_features( """ A scriptable subclass of this class has an extract_features method and calls - super().extract_features, but super() is not supported in torchscript. Aa copy of + super().extract_features, but super() is not supported in torchscript. A copy of this function is made to be used in the subclass instead. 
""" def extract_features_scriptable( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, @@ -807,8 +806,15 @@ def extract_features_scriptable( x, layer_attn, _ = layer( x, - encoder_out.encoder_out if encoder_out is not None else None, - encoder_out.encoder_padding_mask if encoder_out is not None else None, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 603c5b6821..47a20296cf 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -11,7 +11,6 @@ from fairseq import search, utils from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder -from fairseq.models.fairseq_encoder import EncoderOut from torch import Tensor @@ -806,13 +805,13 @@ def forward_encoder(self, net_input: Dict[str, Tensor]): def forward_decoder( self, tokens, - encoder_outs: List[EncoderOut], + encoder_outs: List[Dict[str, List[Tensor]]], incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], temperature: float = 1.0, ): log_probs = [] avg_attn: Optional[Tensor] = None - encoder_out: Optional[EncoderOut] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None for i, model in enumerate(self.models): if self.has_encoder(): encoder_out = encoder_outs[i] @@ -868,7 +867,7 @@ def forward_decoder( return avg_probs, avg_attn @torch.jit.export - def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order): + def 
reorder_encoder_out(self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order): """ Reorder encoder output according to *new_order*. @@ -879,7 +878,7 @@ def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_orde Returns: *encoder_out* rearranged according to *new_order* """ - new_outs: List[EncoderOut] = [] + new_outs: List[Dict[str, List[Tensor]]] = [] if not self.has_encoder(): return new_outs for i, model in enumerate(self.models): From d7dd683b3bebcd3b3249db44d7faa7b670e44b8f Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Mon, 16 Nov 2020 14:03:25 -0800 Subject: [PATCH 299/707] Add option to skip virtual epoch Summary: The current translation_multi_simple_epoch adds an extra layer of virtual epoch abstraction to load part of the data and start training earlier. However, for smaller datasets this is not necessary. This diff skips the virtual epoch layer if --virtual-epoch-size is not specified. Reviewed By: pipibjc Differential Revision: D24962835 fbshipit-source-id: 7de4293a6996ed075a1ed0c1ff2de94c8ae3df14 --- .../multilingual/multilingual_data_manager.py | 37 ++++++++++++- .../tasks/translation_multi_simple_epoch.py | 18 ++++--- tests/test_binaries.py | 52 +++++++++++++++++++ 3 files changed, 99 insertions(+), 8 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 8c14f4e3ad..21fb23c047 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -236,7 +236,7 @@ def add_args(parser): ) parser.add_argument( "--virtual-epoch-size", - default=1000000, + default=None, type=int, help="virtual epoch size to speed up data loading", ) @@ -1040,3 +1040,38 @@ def load_sampled_multi_epoch_dataset( ) else: return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_sampled_multi_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, 
**kwargs + ): + datasets, data_param_list = self.load_split_datasets( + split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs + ) + if training and split == getattr(self.args, "train_subset", None): + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiDataset( + OrderedDict(datasets), + epoch=epoch, + # valid and test datasets will degenerate to concatenating datasets: + sampling_ratios=sample_ratios, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), + ) + else: + return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + if self.args.virtual_epoch_size is None: + return self.load_sampled_multi_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) + else: + return self.load_sampled_multi_epoch_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index d871502a2c..34af9bf4a3 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -138,12 +138,16 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): """ if split in self.datasets: dataset = self.datasets[split] - if self.has_sharded_data(split) and dataset.load_next_shard: - shard_epoch = dataset.shard_epoch - else: - # no need to load next shard so skip loading - # also this avoid always loading from beginning of the data - return + if self.has_sharded_data(split): + if self.args.virtual_epoch_size is not None: + if dataset.load_next_shard: + shard_epoch = dataset.shard_epoch + else: + # no need to load next shard so skip loading + # also this avoid always 
loading from beginning of the data + return + else: + shard_epoch = epoch else: # estimate the shard epoch from virtual data size and virtual epoch size shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch) @@ -153,7 +157,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): del self.datasets[split] logger.info("old dataset deleted manually") logger.info(f"mem usage: {data_utils.get_mem_usage()}") - self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset( + self.datasets[split] = self.data_manager.load_dataset( split, self.training, epoch=epoch, diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 4d3393ae40..6dd95cb4a5 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -433,6 +433,58 @@ def test_translation_multi_simple_epoch(self): + dec_ltok_flag, ) + def test_translation_multi_simple_epoch_no_vepoch(self): + # test with all combinations of encoder/decoder lang tokens + with contextlib.redirect_stdout(StringIO()): + enc_ltok_flag = ["--encoder-langtok", "src"] + dec_ltok_flag = ["--decoder-langtok"] + with tempfile.TemporaryDirectory( + "test_translation_multi_simple_epoch_dict" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data( + data_dir, extra_flags=[] + ) + train_translation_model( + data_dir, + arch="transformer", + task="translation_multi_simple_epoch", + extra_flags=[ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + def 
test_translation_multi_simple_epoch_dicts(self): # test with all combinations of encoder/decoder lang tokens with contextlib.redirect_stdout(StringIO()): From 0e13e2fddedbe569f33167c3aa090cc1aa28a499 Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 17 Nov 2020 12:52:02 -0800 Subject: [PATCH 300/707] Wav2vec hydra (#1439) Summary: convert wav2vec 1.0 model to hydra Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1439 Reviewed By: myleott Differential Revision: D25010596 Pulled By: alexeib fbshipit-source-id: eb3ae81e7dad4789b217fca9bb4c6413835d75ab --- .../model/wav2vec/vq_wav2vec_gumbel.yaml | 5 + fairseq/criterions/wav2vec_criterion.py | 39 +- fairseq/models/wav2vec/wav2vec.py | 491 +++++++----------- fairseq/models/wav2vec/wav2vec2.py | 2 +- 4 files changed, 223 insertions(+), 314 deletions(-) create mode 100644 fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml diff --git a/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml b/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml new file mode 100644 index 0000000000..ee1329bf46 --- /dev/null +++ b/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml @@ -0,0 +1,5 @@ +# @package _group_ +activation: gelu +vq_type: gumbel +vq_depth: 2 +combine_groups: true diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 6ac7557dcc..3a58390088 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -4,33 +4,42 @@ # LICENSE file in the root directory of this source tree. 
import math +from dataclasses import dataclass, field +from typing import List, Optional import torch import torch.nn.functional as F from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass from fairseq.logging.meters import safe_round -@register_criterion("wav2vec") +@dataclass +class Wav2VecCriterionConfig(FairseqDataclass): + infonce: bool = field( + default=False, + metadata={ + "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)" + }, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) class Wav2vecCriterion(FairseqCriterion): def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): super().__init__(task) self.infonce = infonce - self.loss_weights = None if loss_weights is None else eval(loss_weights) - self.log_keys = [] if log_keys is None else eval(log_keys) - - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--infonce', action='store_true', - help='if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)') - parser.add_argument('--loss-weights', type=str, default=None, - help='weights for additional loss terms (not first one)') - parser.add_argument('--log-keys', type=str, default=None, - help='output keys to log') - # fmt: on + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys def forward(self, model, sample, reduce=True, log_pred=False): """Compute the loss for the given sample. 
diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py index 772995b526..83b6461129 100644 --- a/fairseq/models/wav2vec/wav2vec.py +++ b/fairseq/models/wav2vec/wav2vec.py @@ -3,14 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field import logging import math +from typing import Optional, Tuple +from omegaconf import II import sys import torch import torch.nn as nn import torch.nn.functional as F -from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model from fairseq.modules import ( Fp32GroupNorm, Fp32LayerNorm, @@ -18,264 +22,208 @@ KmeansVectorQuantizer, TransposeLast, ) +from fairseq.tasks import FairseqTask from fairseq.utils import buffered_arange logger = logging.getLogger(__name__) -@register_model("wav2vec") -class Wav2VecModel(BaseFairseqModel): - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - parser.add_argument( - "--prediction-steps", - type=int, - metavar="N", - help="number of steps ahead to predict", - ) - parser.add_argument( - "--sample-distance", - type=int, - metavar="N", - help="sample distance from target. 
does not work properly with cross-sampling", - ) - parser.add_argument( - "--cross-sample-negatives", - type=int, - metavar="N", - help="num of cross sampled negatives", - ) - parser.add_argument( - "--num-negatives", type=int, metavar="N", help="number of negative examples" - ) - parser.add_argument( - "--conv-feature-layers", - type=str, - metavar="EXPR", - help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", - ) - parser.add_argument( - "--conv-aggregator-layers", - type=str, - metavar="EXPR", - help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", - ) - parser.add_argument( - "--dropout", - type=float, - metavar="D", - help="dropout to apply within the model", - ) - parser.add_argument( - "--dropout-features", - type=float, - metavar="D", - help="dropout to apply to the features", - ) - parser.add_argument( - "--dropout-agg", - type=float, - metavar="D", - help="dropout to apply after aggregation step", - ) - parser.add_argument( - "--encoder", type=str, choices=["cnn"], help="type of encoder to use" - ) - parser.add_argument( - "--aggregator", - type=str, - choices=["cnn", "gru"], - help="type of aggregator to use", - ) - parser.add_argument( - "--gru-dim", type=int, metavar="N", help="GRU dimensionality" - ) - - parser.add_argument( - "--no-conv-bias", - action="store_true", - help="if set, does not learn bias for conv layers", - ) - parser.add_argument( - "--agg-zero-pad", - action="store_true", - help="if set, zero pads in aggregator instead of repl pad", - ) +AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"]) +PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"]) +ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"]) +VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"]) - parser.add_argument( - "--skip-connections-feat", - action="store_true", - help="if set, adds skip connections to the feature extractor", - ) - parser.add_argument( - "--skip-connections-agg", - action="store_true", - 
help="if set, adds skip connections to the aggregator", - ) - parser.add_argument( - "--residual-scale", - type=float, - metavar="D", - help="scales residual by sqrt(value)", - ) - - parser.add_argument( - "--log-compression", - action="store_true", - help="if set, adds a log compression to feature extractor", - ) - - parser.add_argument( - "--balanced-classes", - action="store_true", - help="if set, loss is scaled to balance for number of negatives", - ) - parser.add_argument( - "--project-features", - choices=["none", "same", "new"], - help="if not none, features are projected using the (same or new) aggregator", - ) - - parser.add_argument( - "--non-affine-group-norm", - action="store_true", - help="if set, group norm is not affine", - ) - - parser.add_argument( - "--offset", - help="if set, introduces an offset from target to predictions. " - 'if set to "auto", it is computed automatically from the receptive field', - ) - - parser.add_argument( - "--activation", - type=str, - choices=["relu", "gelu"], - help="which activation function to use", - ) +@dataclass +class Wav2VecConfig(FairseqDataclass): + prediction_steps: int = field( + default=12, metadata={"help": "number of steps ahead to predict"} + ) + sample_distance: Optional[int] = field( + default=None, + metadata={ + "help": "sample distance from target. 
does not work properly with cross-sampling" + }, + ) + cross_sample_negatives: int = field( + default=0, metadata={"help": "num of cross sampled negatives"} + ) + num_negatives: int = field( + default=10, metadata={"help": "number of negative examples"} + ) + conv_feature_layers: str = field( + default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]", + metadata={ + "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]" + }, + ) + conv_aggregator_layers: str = field( + default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]", + metadata={ + "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]" + }, + ) + dropout: float = field( + default=0.0, metadata={"help": "dropout to apply within the model"} + ) + dropout_features: float = field( + default=0.0, metadata={"help": "dropout to apply to the features"} + ) + dropout_agg: float = field( + default=0.0, metadata={"help": "dropout to apply after aggregation step"} + ) + aggregator: AGGREGATOR_CHOICES = field( + default="cnn", metadata={"help": "type of aggregator to use"} + ) + gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"}) + no_conv_bias: bool = field( + default=False, metadata={"help": "if set, does not learn bias for conv layers"} + ) + agg_zero_pad: bool = field( + default=False, + metadata={"help": "if set, zero pads in aggregator instead of repl pad"}, + ) + skip_connections_feat: bool = field( + default=False, + metadata={"help": "if set, adds skip connections to the feature extractor"}, + ) + skip_connections_agg: bool = field( + default=True, + metadata={"help": "if set, adds skip connections to the aggregator"}, + ) + residual_scale: float = field( + default=0.5, metadata={"help": "scales residual by sqrt(value)"} + ) + log_compression: bool = field( + 
default=True, + metadata={"help": "if set, adds a log compression to feature extractor"}, + ) + balanced_classes: bool = field( + default=False, + metadata={"help": "if set, loss is scaled to balance for number of negatives"}, + ) + project_features: PROJECT_FEATURES_CHOICES = field( + default="none", + metadata={ + "help": "if not none, features are projected using the (same or new) aggregator" + }, + ) + non_affine_group_norm: bool = field( + default=False, metadata={"help": "if set, group norm is not affine"} + ) + offset: str = field( + default="auto", + metadata={ + "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" + }, + ) + activation: ACTIVATION_CHOICES = field( + default="relu", + metadata={ + "help": "which activation function to use" + }, + ) + vq_type: VQ_TYPE_CHOICES = field( + default="none", metadata={"help": "which type of quantizer to use"} + ) + vq_vars: int = field( + default=320, + metadata={"help": "project to this many vector quantized variables per group"}, + ) + vq_groups: int = field( + default=2, metadata={"help": "number of groups of latent variables"} + ) + vq_dim: int = field( + default=0, + metadata={ + "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups" + }, + ) + vq_depth: int = field( + default=1, metadata={"help": "number of layers for vq weight projection"} + ) + combine_groups: bool = field( + default=False, metadata={"help": "if set, variables are shared among groups"} + ) + vq_temp: Tuple[float, float, float] = field( + default=(2.0, 0.5, 0.999995), + metadata={ + "help": "temperature for latent variable sampling with gumbel softmax. 
should be a tuple of 3 values (start, end, decay)" + }, + ) + vq_gamma: float = field( + default=0.25, + metadata={"help": "gamma parameter for kmeans style vector quantization"}, + ) + infonce: bool = II("criterion.infonce") - parser.add_argument( - "--vq-type", - type=str, - choices=["none", "gumbel", "kmeans"], - help="which type of quantizer to use", - ) - parser.add_argument( - "--vq-vars", - type=int, - metavar="N", - help="if set, project to this many vector quantized variables per group", - ) - parser.add_argument( - "--vq-groups", - type=int, - metavar="N", - help="number of groups of latent variables", - ) - parser.add_argument( - "--vq-dim", - type=int, - metavar="N", - help="uses this dimensionality for quantized vectors", - ) - parser.add_argument( - "--vq-depth", - type=int, - metavar="N", - help="number of layers for vq weight projection", - ) - parser.add_argument( - "--combine-groups", - action="store_true", - help="if set, variables are shared among groups", - ) - parser.add_argument( - "--vq-temp", - type=str, - metavar="TEMP", - help="temperature for latent variable sampling with gumbel softmax. 
should be a tuple of 3 values (start, end, decay)", - ) - parser.add_argument( - "--vq-gamma", - type=float, - metavar="D", - help="gamma parameter for kmeans style vector quantization", - ) +@register_model("wav2vec", dataclass=Wav2VecConfig) +class Wav2VecModel(BaseFairseqModel): @classmethod - def build_model(cls, args, task): + def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask): """Build a new model instance.""" - # make sure all arguments are present in older models - base_wav2vec_architecture(args) - - model = Wav2VecModel(args) + model = Wav2VecModel(cfg) logger.info(model) return model - def __init__(self, args): + def __init__(self, cfg: Wav2VecConfig): super().__init__() - self.prediction_steps = args.prediction_steps - offset = args.offset + self.prediction_steps = cfg.prediction_steps + offset = cfg.offset - if args.activation == "relu": + if cfg.activation == "relu": activation = nn.ReLU() - elif args.activation == "gelu": + elif cfg.activation == "gelu": activation = nn.GELU() else: - raise Exception("unknown activation " + args.activation) - - if args.encoder == "cnn": - feature_enc_layers = eval(args.conv_feature_layers) - self.feature_extractor = ConvFeatureExtractionModel( - conv_layers=feature_enc_layers, - dropout=0.0, - log_compression=args.log_compression, - skip_connections=args.skip_connections_feat, - residual_scale=args.residual_scale, - non_affine_group_norm=args.non_affine_group_norm, - activation=activation, - ) - embed = feature_enc_layers[-1][0] - else: - raise Exception("unknown encoder type " + args.encoder) + raise Exception("unknown activation " + cfg.activation) + + feature_enc_layers = eval(cfg.conv_feature_layers) + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + log_compression=cfg.log_compression, + skip_connections=cfg.skip_connections_feat, + residual_scale=cfg.residual_scale, + non_affine_group_norm=cfg.non_affine_group_norm, + activation=activation, + ) + 
embed = feature_enc_layers[-1][0] self.vector_quantizer = None - if args.vq_type == "gumbel": + if cfg.vq_type == "gumbel": self.vector_quantizer = GumbelVectorQuantizer( dim=embed, - num_vars=args.vq_vars, - temp=eval(args.vq_temp), - groups=args.vq_groups, - combine_groups=args.combine_groups, - vq_dim=args.vq_dim if args.vq_dim > 0 else embed, + num_vars=cfg.vq_vars, + temp=cfg.vq_temp, + groups=cfg.vq_groups, + combine_groups=cfg.combine_groups, + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, time_first=False, activation=activation, - weight_proj_depth=args.vq_depth, + weight_proj_depth=cfg.vq_depth, weight_proj_factor=2, ) - elif args.vq_type == "kmeans": + elif cfg.vq_type == "kmeans": self.vector_quantizer = KmeansVectorQuantizer( dim=embed, - num_vars=args.vq_vars, - groups=args.vq_groups, - combine_groups=args.combine_groups, - vq_dim=args.vq_dim if args.vq_dim > 0 else embed, + num_vars=cfg.vq_vars, + groups=cfg.vq_groups, + combine_groups=cfg.combine_groups, + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, time_first=False, - gamma=args.vq_gamma, + gamma=cfg.vq_gamma, ) else: assert ( - args.vq_type == "none" or args.vq_type is None + cfg.vq_type == "none" or cfg.vq_type is None ), "Unknown quantizer type" - if args.offset == "auto": - assert args.encoder == "cnn" + if cfg.offset == "auto": jin = 0 rin = 0 for _, k, stride in feature_enc_layers: @@ -291,34 +239,34 @@ def __init__(self, args): offset = int(offset) def make_aggregator(): - if args.aggregator == "cnn": - agg_layers = eval(args.conv_aggregator_layers) + if cfg.aggregator == "cnn": + agg_layers = eval(cfg.conv_aggregator_layers) agg_dim = agg_layers[-1][0] feature_aggregator = ConvAggegator( conv_layers=agg_layers, embed=embed, - dropout=args.dropout, - skip_connections=args.skip_connections_agg, - residual_scale=args.residual_scale, - non_affine_group_norm=args.non_affine_group_norm, - conv_bias=not args.no_conv_bias, - zero_pad=args.agg_zero_pad, + dropout=cfg.dropout, + 
skip_connections=cfg.skip_connections_agg, + residual_scale=cfg.residual_scale, + non_affine_group_norm=cfg.non_affine_group_norm, + conv_bias=not cfg.no_conv_bias, + zero_pad=cfg.agg_zero_pad, activation=activation, ) - elif args.aggregator == "gru": - agg_dim = args.gru_dim + elif cfg.aggregator == "gru": + agg_dim = cfg.gru_dim feature_aggregator = nn.Sequential( TransposeLast(), nn.GRU( input_size=embed, hidden_size=agg_dim, num_layers=1, - dropout=args.dropout, + dropout=cfg.dropout, ), TransposeLast(deconstruct_idx=0), ) else: - raise Exception("unknown aggregator type " + args.aggregator) + raise Exception("unknown aggregator type " + cfg.aggregator) return feature_aggregator, agg_dim @@ -327,24 +275,24 @@ def make_aggregator(): self.wav2vec_predictions = Wav2VecPredictionsModel( in_dim=agg_dim, out_dim=embed, - prediction_steps=args.prediction_steps, - n_negatives=args.num_negatives, - cross_sample_negatives=args.cross_sample_negatives, - sample_distance=args.sample_distance, - dropout=args.dropout, + prediction_steps=cfg.prediction_steps, + n_negatives=cfg.num_negatives, + cross_sample_negatives=cfg.cross_sample_negatives, + sample_distance=cfg.sample_distance, + dropout=cfg.dropout, offset=offset, - balanced_classes=args.balanced_classes, - infonce=args.infonce, + balanced_classes=cfg.balanced_classes, + infonce=cfg.infonce, ) - self.dropout_feats = nn.Dropout(p=args.dropout_features) - self.dropout_agg = nn.Dropout(p=args.dropout_agg) + self.dropout_feats = nn.Dropout(p=cfg.dropout_features) + self.dropout_agg = nn.Dropout(p=cfg.dropout_agg) - if args.project_features == "none": + if cfg.project_features == "none": self.project_features = None - elif args.project_features == "same": + elif cfg.project_features == "same": self.project_features = self.feature_aggregator - elif args.project_features == "new": + elif cfg.project_features == "new": self.project_features, _ = make_aggregator() def forward(self, source): @@ -680,56 +628,3 @@ def forward(self, 
x, y): labels = (labels, weights) return predictions, labels - - -@register_model_architecture("wav2vec", "wav2vec") -def base_wav2vec_architecture(args): - conv_feature_layers = "[(512, 10, 5)]" - conv_feature_layers += " + [(512, 8, 4)]" - conv_feature_layers += " + [(512, 4, 2)] * 3" - args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers) - - args.conv_aggregator_layers = getattr( - args, "conv_aggregator_layers", "[(512, 3, 1)] * 9" - ) - - args.prediction_steps = getattr(args, "prediction_steps", 12) - args.num_negatives = getattr(args, "num_negatives", 1) - args.sample_distance = getattr(args, "sample_distance", None) - args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) - - args.dropout = getattr(args, "dropout", 0.0) - args.dropout_features = getattr(args, "dropout_features", 0.0) - args.dropout_agg = getattr(args, "dropout_agg", 0.0) - args.encoder = getattr(args, "encoder", "cnn") - args.aggregator = getattr(args, "aggregator", "cnn") - - args.skip_connections_feat = getattr(args, "skip_connections_feat", False) - args.skip_connections_agg = getattr(args, "skip_connections_agg", False) - args.residual_scale = getattr(args, "residual_scale", 0.5) - - args.gru_dim = getattr(args, "gru_dim", 512) - - args.no_conv_bias = getattr(args, "no_conv_bias", False) - args.agg_zero_pad = getattr(args, "agg_zero_pad", False) - - args.log_compression = getattr(args, "log_compression", False) - - args.balanced_classes = getattr(args, "balanced_classes", False) - args.infonce = getattr(args, "infonce", False) - args.project_features = getattr(args, "project_features", "none") - - args.non_affine_group_norm = getattr(args, "non_affine_group_norm", False) - - args.offset = getattr(args, "offset", "auto") - - args.activation = getattr(args, "activation", "relu") - - args.vq_type = getattr(args, "vq_type", "none") - args.vq_vars = getattr(args, "vq_vars", 320) - args.vq_groups = getattr(args, "vq_groups", 2) - args.vq_dim = 
getattr(args, "vq_dim", 0) - args.vq_depth = getattr(args, "vq_depth", 1) - args.combine_groups = getattr(args, "combine_groups", False) - args.vq_temp = getattr(args, "vq_temp", "(2.0, 0.5, 0.999995)") - args.vq_gamma = getattr(args, "vq_gamma", 0.25) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index e6fecdd4fe..a00dc4d915 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -14,7 +14,7 @@ from fairseq import utils from fairseq.data.data_utils import compute_mask_indices from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.models import BaseFairseqModel, register_model from fairseq.modules import ( Fp32GroupNorm, Fp32LayerNorm, From 43e379e59068fa472c847400b6c653c88b7ffd95 Mon Sep 17 00:00:00 2001 From: Shruti Bhosale Date: Tue, 17 Nov 2020 14:14:12 -0800 Subject: [PATCH 301/707] Fast Noisy Channel Online Decoding for Neural Machine Translation (#1436) Summary: This PR adds logic to generate translations using noisy channel decoding (i.e. with a channel model `P(source|target)` and language model `P(target)`, in addition to a direct model `P(target|source)` It also includes additional logic to make noisy channel decoding very fast, without much loss in accuracy. 
Most of the logic resides within `examples/fast_noisy_channel` - - `noisy_channel_translation.py`: Fairseq Task for noisy channel translation - `noisy_channel_sequence_generator.py`: Sequence Generator for noisy channel decoding - this contains the main logic for scoring the direct, channel and LM models at each step of beam search - `noisy_channel_beam_search.py`: A variant of beam search that chooses the top-K candidates based on the combined scores from the direct, channel and LM models TODO: add an integration test to ensure changes in the core fairseq files don't break the logic in fast_noisy_channel Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1436 Reviewed By: myleott, edunov Differential Revision: D24986498 Pulled By: shruti-bh fbshipit-source-id: 2ae3b7d68fe4a1cfb61493c363134ab7a16c8647 --- examples/fast_noisy_channel/README.md | 345 +++++++ examples/fast_noisy_channel/__init__.py | 8 + .../noisy_channel_beam_search.py | 71 ++ .../noisy_channel_sequence_generator.py | 842 ++++++++++++++++++ .../noisy_channel_translation.py | 127 +++ fairseq/data/dictionary.py | 3 +- 6 files changed, 1395 insertions(+), 1 deletion(-) create mode 100644 examples/fast_noisy_channel/README.md create mode 100644 examples/fast_noisy_channel/__init__.py create mode 100644 examples/fast_noisy_channel/noisy_channel_beam_search.py create mode 100644 examples/fast_noisy_channel/noisy_channel_sequence_generator.py create mode 100644 examples/fast_noisy_channel/noisy_channel_translation.py diff --git a/examples/fast_noisy_channel/README.md b/examples/fast_noisy_channel/README.md new file mode 100644 index 0000000000..a04151a796 --- /dev/null +++ b/examples/fast_noisy_channel/README.md @@ -0,0 +1,345 @@ +# Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling + +## Introduction +- [Yee et al. 
(2019)](https://www.aclweb.org/anthology/D19-1571.pdf) introduce a simple and effective noisy channel modeling approach for neural machine translation. However, the noisy channel online decoding approach introduced in this paper is too slow to be practical. +- To address this, [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduce 3 simple approximations to make this approach very fast and practical without much loss in accuracy. +- This README provides instructions on how to run online decoding or generation with the noisy channel modeling approach, including ways to make it very fast without much loss in accuracy. + +## Noisy Channel Modeling + +[Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) apply Bayes' rule to predict `P(y|x)`, the probability of the target `y` given the source `x`. +```P(y|x) = P(x|y) * P(y) / P(x)``` +- `P(x|y)` predicts the source `x` given the target `y` and is referred to as the **channel model** +- `P(y)` is a **language model** over the target `y` +- `P(x)` is generally not modeled since it is constant for all `y`. + +We use Transformer models to parameterize the direct model `P(y|x)`, the channel model `P(x|y)` and the language model `P(y)`. + +During online decoding with beam search, we generate the top `K2` candidates per beam and score them with the following linear combination of the channel model, the language model as well as the direct model scores. + +```(1 / t) * log(P(y|x)) + (1 / s) * ( λ1 * log(P(x|y)) + λ2 * log(P(y)) )``` +- `t` - Target Prefix Length +- `s` - Source Length +- `λ1` - Channel Model Weight +- `λ2` - Language Model Weight + +The top `beam_size` candidates based on the above combined scores are chosen to continue the beams in beam search. In beam search with a direct model alone, the scores from the direct model `P(y|x)` are used to choose the top candidates in beam search. 
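The combined score described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the generator's actual code: the function name `combined_score`, its argument names and the example log-probabilities are hypothetical, and only the formula itself comes from the README text.

```python
def combined_score(lp_direct, lp_channel, lp_lm, tgt_prefix_len, src_len, ch_wt, lm_wt):
    """Combine log-probabilities following the README formula:

    (1 / t) * log P(y|x) + (1 / s) * (ch_wt * log P(x|y) + lm_wt * log P(y))
    """
    # Direct model score, normalized by the target prefix length t.
    direct_term = lp_direct / tgt_prefix_len
    # Weighted channel + LM scores, normalized by the source length s.
    channel_lm_term = (ch_wt * lp_channel + lm_wt * lp_lm) / src_len
    return direct_term + channel_lm_term


# Hypothetical log-probabilities for two candidate target prefixes.
candidates = {
    "cand_a": dict(lp_direct=-2.0, lp_channel=-6.0, lp_lm=-4.0),
    "cand_b": dict(lp_direct=-1.5, lp_channel=-12.0, lp_lm=-3.0),
}
scores = {
    name: combined_score(tgt_prefix_len=4, src_len=5, ch_wt=0.30, lm_wt=0.50, **lps)
    for name, lps in candidates.items()
}
# The candidate with the highest combined score continues the beam.
best = max(scores, key=scores.get)
```

With these made-up numbers the direct model alone would prefer `cand_b`, but its poor channel score (it does not translate back to the source well) makes `cand_a` the winner under the combined score.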
+ +This framework provides a great way to utilize strong target language models trained on large amounts of unlabeled data. Language models can prefer targets unrelated to the source, so we also need a channel model whose role is to ensure that the target preferred by the language model also translates back to the source. + +### Training Translation Models and Language Models + +For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/translation) + +For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model) + +### Generation with Language Model for German-English translation with fairseq + +Here are instructions to generate using a direct model and a target-side language model. + +Note: +- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq) +- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing) + +```sh +binarized_data=data_dir/binarized +direct_model=de_en_seed4.pt +lm_model=en_lm.pt +lm_data=lm_data +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model} +mkdir -p ${lm_data} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt + +k2=10 +lenpen=0.16 +lm_wt=0.14 +fairseq-generate ${binarized_data} \ + --user-dir examples/fast_noisy_channel \ + --beam 5 \ + --path ${direct_model} \ + --lm-model ${lm_model} \ + --lm-data ${lm_data} \ + --k2 ${k2} \ + --combine-method lm_only \ + --task noisy_channel_translation \ + --lenpen ${lenpen} \ + --lm-wt ${lm_wt} \ + --gen-subset valid \ + --remove-bpe \ + --fp16 \ + --batch-size 10 +``` +### Noisy Channel 
Generation for German-English translation with fairseq + +Here are instructions for noisy channel generation with a direct model, channel model and language model as explained in section [Noisy Channel Modeling](#noisy-channel-modeling). + +Note: +- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq) +- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing) + +```sh +binarized_data=data_dir/binarized +direct_model=de_en_seed4.pt +lm_model=en_lm.pt +lm_data=lm_data +ch_model=en_de.big.seed4.pt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model} +mkdir -p ${lm_data} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt -O ${ch_model} + +k2=10 +lenpen=0.21 +lm_wt=0.50 +bw_wt=0.30 +fairseq-generate ${binarized_data} \ + --user-dir examples/fast_noisy_channel \ + --beam 5 \ + --path ${direct_model} \ + --lm-model ${lm_model} \ + --lm-data ${lm_data} \ + --channel-model ${ch_model} \ + --k2 ${k2} \ + --combine-method noisy_channel \ + --task noisy_channel_translation \ + --lenpen ${lenpen} \ + --lm-wt ${lm_wt} \ + --ch-wt ${bw_wt} \ + --gen-subset test \ + --remove-bpe \ + --fp16 \ + --batch-size 1 +``` +## Fast Noisy Channel Modeling + +[Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduce 3 approximations that speed up online noisy channel decoding - +- Smaller channel models (`Transformer Base` with 1 encoder and decoder layer each vs. `Transformer Big`) + - This involves training a channel model that is possibly smaller and less accurate in terms of BLEU than a channel model of the same size as the direct model. 
+ - Since the role of the channel model is mainly to assign low scores to generations from the language model if they don't translate back to the source, we may not need the most accurate channel model for this purpose. +- Smaller output vocabulary size for the channel model (~30,000 -> ~1000) + - The channel model doesn't need to score the full output vocabulary, it just needs to score the source tokens, which are completely known. + - This is specified using the arguments `--channel-scoring-type src_vocab --top-k-vocab 500` + - This means that the output vocabulary for the channel model will be the source tokens for all examples in the batch and the top-K most frequent tokens in the vocabulary + - This reduces the memory consumption needed to store channel model scores significantly +- Smaller number of candidates (`k2`) scored per beam + - This is specified by reducing the argument `--k2` + + +### Fast Noisy Channel Generation for German-English translation with fairseq + +Here are instructions for **fast** noisy channel generation with a direct model, channel model and language model as explained in section [Fast Noisy Channel Modeling](#fast-noisy-channel-modeling). The main differences are that we use a smaller channel model, reduce `--k2`, set `--channel-scoring-type src_vocab --top-k-vocab 500` and increase the `--batch-size`. 
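The reduced channel output vocabulary selected by `--channel-scoring-type src_vocab --top-k-vocab 500` can be pictured with a small plain-Python sketch. It is a hedged illustration of the idea stated in the list above (source tokens of the batch plus the top-K most frequent tokens), not the code used by the generator: the helper `channel_output_vocab`, the `token_counts` mapping and the toy token ids are all hypothetical.

```python
def channel_output_vocab(batch_src_tokens, token_counts, top_k):
    """Return the sorted token ids the channel model would need to score."""
    # Top-K most frequent token ids (ties broken by smaller id for determinism).
    most_frequent = sorted(token_counts, key=lambda tok: (-token_counts[tok], tok))[:top_k]
    # Every source token that appears in any example of the batch.
    in_batch = {tok for sentence in batch_src_tokens for tok in sentence}
    return sorted(in_batch.union(most_frequent))


# Hypothetical toy vocabulary: token id -> corpus frequency.
token_counts = {0: 900, 1: 850, 2: 40, 3: 30, 4: 5, 5: 2}
# Hypothetical batch of two source sentences, given as token ids.
batch_src_tokens = [[4, 1, 0], [5, 0]]
reduced_vocab = channel_output_vocab(batch_src_tokens, token_counts, top_k=2)
```

Here the channel model would score 4 token ids instead of the full 6, and the saving grows with vocabulary size since the reduced set depends only on the batch contents and `top_k`.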
+ +Note: +- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq) +- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing) + +```sh +binarized_data=data_dir/binarized +direct_model=de_en_seed4.pt +lm_model=en_lm.pt +lm_data=lm_data +small_ch_model=en_de.base_1_1.seed4.pt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model} +mkdir -p ${lm_data} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt -O ${small_ch_model} + +k2=3 +lenpen=0.23 +lm_wt=0.58 +bw_wt=0.26 +fairseq-generate ${binarized_data} \ + --user-dir examples/fast_noisy_channel \ + --beam 5 \ + --path ${direct_model} \ + --lm-model ${lm_model} \ + --lm-data ${lm_data} \ + --channel-model ${small_ch_model} \ + --k2 ${k2} \ + --combine-method noisy_channel \ + --task noisy_channel_translation \ + --lenpen ${lenpen} \ + --lm-wt ${lm_wt} \ + --ch-wt ${bw_wt} \ + --gen-subset test \ + --remove-bpe \ + --fp16 \ + --batch-size 50 \ + --channel-scoring-type src_vocab --top-k-vocab 500 +``` + +## Test Data Preprocessing + +For preprocessing and binarizing the test sets for Romanian-English and German-English translation, we use the following script - + +```sh +FAIRSEQ=/path/to/fairseq +cd $FAIRSEQ +SCRIPTS=$FAIRSEQ/mosesdecoder/scripts +if [ ! -d "${SCRIPTS}" ]; then + echo 'Cloning Moses github repository (for tokenization scripts)...' 
+ git clone https://github.com/moses-smt/mosesdecoder.git +fi +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +NORMALIZE=$SCRIPTS/tokenizer/normalize-punctuation.perl + +s=de +t=en +test=wmt18 + +mkdir -p data_dir + +# Tokenization +if [ $s == "ro" ] ; then + # Note: Get normalise-romanian.py and remove-diacritics.py from + # https://github.com/rsennrich/wmt16-scripts/tree/master/preprocess + sacrebleu -t $test -l $s-$t --echo src | \ + $NORMALIZE -l $s | \ + python normalise-romanian.py | \ + python remove-diacritics.py | \ + $TOKENIZER -l $s -a -q > data_dir/$test.$s-$t.$s +else + sacrebleu -t $test -l $s-$t --echo src | perl $NORMALIZE -l $s | perl $TOKENIZER -threads 8 -a -l $s > data_dir/$test.$s-$t.$s +fi + +sacrebleu -t $test -l $s-$t --echo ref | perl $NORMALIZE -l $t | perl $TOKENIZER -threads 8 -a -l $t > data_dir/$test.$s-$t.$t + + +# Applying BPE +src_bpe_code=/path/to/source/language/bpe/code +tgt_bpe_code=/path/to/target/language/bpe/code +src_dict=/path/to/source/language/dict +tgt_dict=/path/to/target/language/dict + +FASTBPE=$FAIRSEQ/fastBPE +if [ ! 
-d "${FASTBPE}" ] ; then + git clone https://github.com/glample/fastBPE.git + # Follow compilation instructions at https://github.com/glample/fastBPE + g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fastBPE/fast +fi + +${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${src_bpe_code} +${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$t data_dir/$test.$s-$t.$t ${tgt_bpe_code} + +fairseq-preprocess -s $s -t $t \ + --testpref data_dir/bpe.$test.$s-$t \ + --destdir data_dir/binarized \ + --srcdict ${src_dict} \ + --tgtdict ${tgt_dict} +``` + +## Calculating BLEU + +```sh +DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl +cat ${generation_output} | grep -P "^H" | sort -V | cut -f 3- | $DETOKENIZER -l $t -q -a | sacrebleu -t $test -l $s-$t +``` + + +## Romanian-English Translation + +The direct and channel models are trained using bitext data (WMT16) combined with backtranslated data (The monolingual data used for backtranslation comes from http://data.statmt.org/rsennrich/wmt16_backtranslations/ (Sennrich et al., 2016c)) + +The backtranslated data is generated using an ensemble of 3 English-Romanian models trained on bitext training data (WMT16) with unrestricted sampling. + +### BPE Codes and Dictionary + +We learn a joint BPE vocabulary of 18K types on the bitext training data which is used for both the source and target. +||Path| +|----------|------| +| BPE Code | [joint_bpe_18k](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k) | +| Dictionary | [dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict) | + +### Direct Models +For Ro-En with backtranslation, the direct and channel models use a Transformer-Big architecture. 
+ +| Seed | Model | +|----|----| +| 2 | [ro_en_seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed2.pt) +| 4 | [ro_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed4.pt) +| 6 | [ro_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed6.pt) + +### Channel Models +For channel models, we follow the same steps as for the direct models. But backtranslated data is generated in the opposite direction using [this Romanian monolingual data](http://data.statmt.org/rsennrich/wmt16_backtranslations/). +The best lenpen, LM weight and CH weight are obtained by sweeping over the validation set (wmt16/dev) using beam 5. +| Model Size | Lenpen | LM Weight | CH Weight | Seed 2 | Seed 4 | Seed 6 | +|----|----|----|----|----|----|----| +| `big` | 0.84 | 0.64 | 0.56 | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | +| `base_1_1` | 0.63 | 0.40 | 0.37 | [base_1_1.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed2.pt) | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed6.pt) | + +### Language Model +The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization. 
+| | Path | +|----|----| +| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/transformer_lm.pt) | +| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/lm_dict) + +## German-English Translation + +### BPE Codes and Dictionaries + +| | Path| +|----------|------| +| Source BPE Code | [de_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K) | +| Target BPE Code | [en_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K) +| Source Dictionary | [de_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict) | +| Target Dictionary | [en_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict) | + +### Direct Models +We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs. +We use the Transformer-Big architecture for the direct model. + +| Seed | Model | +|:----:|----| +| 4 | [de_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt) +| 5 | [de_en_seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed5.pt) +| 6 | [de_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed6.pt) + +### Channel Models + +We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs. 
+ +| Model Size | Seed 4 | Seed 5 | Seed 6 | +|----|----|----|----| +| `big` | [big.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt) | [big.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed5.pt) | [big.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed6.pt) | +| `big_1_1` | [big_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed4.pt) | [big_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed5.pt) | [big_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed6.pt) | +| `base` | [base.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed4.pt) | [base.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed5.pt) | [base.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed6.pt) | +| `base_1_1` | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed5.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed6.pt) | +| `half` | [half.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed4.pt) | [half.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed5.pt) | [half.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed6.pt) | +| `half_1_1` | [half_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed4.pt) | [half_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed5.pt) | 
[half_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed6.pt) | +| `quarter` | [quarter.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed4.pt) | [quarter.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed5.pt) | [quarter.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed6.pt) | +| `quarter_1_1` | [quarter_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed4.pt) | [quarter_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed5.pt) | [quarter_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed6.pt) | +| `8th` | [8th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed4.pt) | [8th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed5.pt) | [8th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed6.pt) | +| `8th_1_1` | [8th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt) | [8th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed5.pt) | [8th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed6.pt) | +| `16th` | [16th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed4.pt) | [16th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed5.pt) | [16th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed6.pt) | +| `16th_1_1` | [16th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed4.pt) | 
[16th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed5.pt) | [16th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed6.pt) | + +### Language Model +The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization. +| | Path | +|----|----| +| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt) | +| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/) + + +## Citation + +```bibtex +@inproceedings{bhosale2020language, + title={Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling}, + author={Shruti Bhosale and Kyra Yee and Sergey Edunov and Michael Auli}, + booktitle={Proceedings of the Fifth Conference on Machine Translation (WMT)}, + year={2020}, +} + +@inproceedings{yee2019simple, + title={Simple and Effective Noisy Channel Modeling for Neural Machine Translation}, + author={Yee, Kyra and Dauphin, Yann and Auli, Michael}, + booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)}, + pages={5700--5705}, + year={2019} +} +``` diff --git a/examples/fast_noisy_channel/__init__.py b/examples/fast_noisy_channel/__init__.py new file mode 100644 index 0000000000..9b248c3a24 --- /dev/null +++ b/examples/fast_noisy_channel/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import noisy_channel_translation # noqa +from . import noisy_channel_sequence_generator # noqa +from . 
import noisy_channel_beam_search # noqa diff --git a/examples/fast_noisy_channel/noisy_channel_beam_search.py b/examples/fast_noisy_channel/noisy_channel_beam_search.py new file mode 100644 index 0000000000..23869ebcd0 --- /dev/null +++ b/examples/fast_noisy_channel/noisy_channel_beam_search.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.search import Search + + +class NoisyChannelBeamSearch(Search): + + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.fw_scores_buf = None + self.lm_scores_buf = None + + def _init_buffers(self, t): + # super()._init_buffers(t) + if self.fw_scores_buf is None: + self.scores_buf = t.new() + self.indices_buf = torch.LongTensor().to(device=t.device) + self.beams_buf = torch.LongTensor().to(device=t.device) + self.fw_scores_buf = t.new() + self.lm_scores_buf = t.new() + + def combine_fw_bw(self, combine_method, fw_cum, bw, step): + if combine_method == "noisy_channel": + fw_norm = fw_cum.div(step + 1) + lprobs = bw + fw_norm + elif combine_method == "lm_only": + lprobs = bw + fw_cum + + return lprobs + + def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method): + self._init_buffers(fw_lprobs) + bsz, beam_size, vocab_size = fw_lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous() + bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous() + # nothing to add since we are at the first step + fw_lprobs_cum = fw_lprobs + + else: + # make probs contain cumulative scores for each hypothesis + raw_scores = (scores[:, :, step - 1].unsqueeze(-1)) + fw_lprobs_cum = (fw_lprobs.add(raw_scores)) + + combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step) + + # choose the top k according 
to the combined noisy channel model score + torch.topk( + combined_lprobs.view(bsz, -1), + k=min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + combined_lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + out=(self.scores_buf, self.indices_buf), + ) + # save corresponding fw and lm scores + self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf) + self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf) + # Project back into relative indices and beams + self.beams_buf = self.indices_buf // vocab_size + self.indices_buf.fmod_(vocab_size) + return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf diff --git a/examples/fast_noisy_channel/noisy_channel_sequence_generator.py b/examples/fast_noisy_channel/noisy_channel_sequence_generator.py new file mode 100644 index 0000000000..ea8fae98e8 --- /dev/null +++ b/examples/fast_noisy_channel/noisy_channel_sequence_generator.py @@ -0,0 +1,842 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict, List, Optional + +import math +import numpy as np + +import torch +import torch.nn.functional as F +from torch import Tensor + +from .noisy_channel_beam_search import NoisyChannelBeamSearch +from fairseq.sequence_generator import EnsembleModel + + +class NoisyChannelSequenceGenerator(object): + def __init__( + self, + combine_method, + tgt_dict, + src_dict=None, + beam_size=1, + max_len_a=0, + max_len_b=200, + min_len=1, + len_penalty=1.0, + unk_penalty=0.0, + retain_dropout=False, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + normalize_scores=True, + channel_models=None, + k2=10, + ch_weight=1.0, + channel_scoring_type='log_norm', + top_k_vocab=0, + lm_models=None, + lm_dict=None, + lm_weight=1.0, + normalize_lm_scores_by_tgt_len=False, + ): + """Generates translations of a given source sentence, + using beam search with noisy channel decoding. + + Args: + combine_method (string, optional): Method to combine direct, LM and + channel model scores (default: None) + tgt_dict (~fairseq.data.Dictionary): target dictionary + src_dict (~fairseq.data.Dictionary): source dictionary + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + retain_dropout (bool, optional): use dropout when generating + (default: False) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + 
no_repeat_ngram_size (int, optional): Size of n-grams that we avoid + repeating in the generation (default: 0) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + channel_models (List[~fairseq.models.FairseqModel]): ensemble of models + translating from the target to the source + k2 (int, optional): Top K2 candidates to score per beam at each step (default: 10) + ch_weight (float, optional): Weight associated with the channel model score + assuming that the direct model score has weight 1.0 (default: 1.0) + channel_scoring_type (str, optional): String specifying how to score + the channel model (default: 'log_norm') + top_k_vocab (int, optional): If `channel_scoring_type` is `'src_vocab'` or + `'src_vocab_batched'`, then this parameter specifies the number of + most frequent tokens to include in the channel model output vocabulary, + in addition to the source tokens in the input batch (default: 0) + lm_models (List[~fairseq.models.FairseqModel]): ensemble of models + generating text in the target language + lm_dict (~fairseq.data.Dictionary): LM Model dictionary + lm_weight (float, optional): Weight associated with the LM model score + assuming that the direct model score has weight 1.0 (default: 1.0) + normalize_lm_scores_by_tgt_len (bool, optional): Should we normalize LM scores + by the target length? 
By default, we normalize the combination of + LM and channel model scores by the source length + """ + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.retain_dropout = retain_dropout + self.temperature = temperature + self.match_source_len = match_source_len + self.no_repeat_ngram_size = no_repeat_ngram_size + self.channel_models = channel_models + self.src_dict = src_dict + self.tgt_dict = tgt_dict + self.combine_method = combine_method + self.k2 = k2 + self.ch_weight = ch_weight + self.channel_scoring_type = channel_scoring_type + self.top_k_vocab = top_k_vocab + self.lm_models = lm_models + self.lm_dict = lm_dict + self.lm_weight = lm_weight + self.log_softmax_fn = torch.nn.LogSoftmax(dim=1) + self.normalize_lm_scores_by_tgt_len = normalize_lm_scores_by_tgt_len + + self.share_tgt_dict = (self.lm_dict == self.tgt_dict) + self.tgt_to_lm = make_dict2dict(tgt_dict, lm_dict) + + self.ch_scoring_bsz = 3072 + + assert temperature > 0, '--temperature must be greater than 0' + + self.search = NoisyChannelBeamSearch(tgt_dict) + + @torch.no_grad() + def generate( + self, + models, + sample, + prefix_tokens=None, + bos_token=None, + **kwargs + ): + """Generate a batch of translations. 
+ Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + """ + model = EnsembleModel(models) + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(model.models_size) + ], + ) + if not self.retain_dropout: + model.eval() + + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample['net_input'].items() + if k != 'prev_output_tokens' + } + src_tokens = encoder_input['src_tokens'] + src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + input_size = src_tokens.size() + # batch dimension goes first followed by source lengths + bsz = input_size[0] + src_len = input_size[1] + beam_size = self.beam_size + + if self.match_source_len: + max_len = src_lengths_no_eos.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + # exclude the EOS marker + model.max_decoder_positions() - 1, + ) + + # compute the encoder output for each beam + encoder_outs = model.forward_encoder(encoder_input) + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) + + src_lengths = encoder_input['src_lengths'] + # initialize buffers + scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0) + lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0) + + scores_buf = scores.clone() + tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad) + tokens_buf = tokens.clone() + tokens[:, 0] = self.eos if bos_token is None else bos_token + + # reorder source tokens so they may be used 
as a reference in generating P(S|T) + src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index) + + src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len) + src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(bsz*beam_size, -1) + + attn, attn_buf = None, None + nonpad_idxs = None + + # The cands_to_ignore indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then the cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask + + # list of completed sentences + finalized = [[] for i in range(bsz)] + finished = [False for i in range(bsz)] + num_remaining_sent = bsz + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) + cand_offsets = torch.arange(0, cand_size).type_as(tokens) + + # helper function for allocating buffers on the fly + buffers = {} + + def buffer(name, type_of=tokens): # noqa + if name not in buffers: + buffers[name] = type_of.new() + return buffers[name] + + def is_finished(sent, step, unfin_idx): + """ + Check whether we've finished generation for a given sentence, by + comparing the worst score among finalized hypotheses to the best + possible score among unfinalized hypotheses. + """ + assert len(finalized[sent]) <= beam_size + if len(finalized[sent]) == beam_size: + return True + return False + + def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores): + """ + Finalize the given hypotheses at this step, while keeping the total + number of finalized hypotheses per sentence <= beam_size. 
+ + Note: the input must be in the desired finalization order, so that + hypotheses that appear earlier in the input are preferred to those + that appear later. + + Args: + step: current time step + bbsz_idx: A vector of indices in the range [0, bsz*beam_size), + indicating which hypotheses to finalize + eos_scores: A vector of the same size as bbsz_idx containing + fw scores for each hypothesis + combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing + combined noisy channel scores for each hypothesis + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors + tokens_clone = tokens.index_select(0, bbsz_idx) + tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS + assert not tokens_clone.eq(self.eos).any() + tokens_clone[:, step] = self.eos + attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + combined_noisy_channel_eos_scores /= (step + 1) ** self.len_penalty + + cum_unfin = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + + sents_seen = set() + for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())): + unfin_idx = idx // beam_size + sent = unfin_idx + cum_unfin[unfin_idx] + + sents_seen.add((sent, unfin_idx)) + + if self.match_source_len and step > src_lengths_no_eos[unfin_idx]: + score = -math.inf + + def get_hypo(): + + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i][nonpad_idxs[sent]] + _, alignment = hypo_attn.max(dim=0) + else: + hypo_attn = None + alignment = None + + return { + 
'tokens': tokens_clone[i], + 'score': score, + 'attention': hypo_attn, # src_len x tgt_len + 'alignment': alignment, + 'positional_scores': pos_scores[i], + } + + if len(finalized[sent]) < beam_size: + finalized[sent].append(get_hypo()) + + newly_finished = [] + for sent, unfin_idx in sents_seen: + # check termination conditions for this sentence + if not finished[sent] and is_finished(sent, step, unfin_idx): + finished[sent] = True + newly_finished.append(unfin_idx) + return newly_finished + + def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k): + """Rescore the top k hypothesis from each beam using noisy channel modeling + Returns: + new_fw_lprobs: the direct model probabilities after pruning the top k + new_ch_lm_lprobs: the combined channel and language model probabilities + new_lm_lprobs: the language model probabilities after pruning the top k + """ + with torch.no_grad(): + lprobs_size = lprobs.size() + if prefix_tokens is not None and step < prefix_tokens.size(1): + probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :] + cand_scores = torch.gather( + probs_slice, dim=1, + index=prefix_tokens[:, step].view(-1, 1).data + ).expand(-1, beam_size).contiguous().view(bsz*beam_size, 1) + cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, beam_size).data.contiguous().view(bsz*beam_size, 1) + + # need to calculate and save fw and lm probs for prefix tokens + fw_top_k = cand_scores + fw_top_k_idx = cand_indices + k = 1 + else: + # take the top k best words for every sentence in batch*beam + fw_top_k, fw_top_k_idx = torch.topk(lprobs.view(beam_size*bsz, -1), k=k) + eos_idx = torch.nonzero(fw_top_k_idx.view(bsz*beam_size*k, -1) == self.eos)[:, 0] + ch_scores = fw_top_k.new_full((beam_size*bsz*k, ), 0) + src_size = torch.sum(src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype) + + if self.combine_method != "lm_only": + temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(bsz*beam_size*k, -1) 
+ not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index + cur_tgt_size = step+2 + + # add eos to all candidate sentences except those that already end in eos + eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1) + eos_tokens[eos_idx] = self.tgt_dict.pad_index + + if step == 0: + channel_input = torch.cat((fw_top_k_idx.view(-1, 1), eos_tokens), 1) + else: + # move eos from beginning to end of target sentence + channel_input = torch.cat((tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1) + + ch_input_lengths = torch.tensor(np.full(channel_input.size(0), cur_tgt_size)) + ch_input_lengths[eos_idx] = cur_tgt_size-1 + if self.channel_scoring_type == "unnormalized": + ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths) + ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True) + del ch_encoder_output + ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:]) + ch_intermed_scores = ch_intermed_scores.float() + ch_intermed_scores *= not_padding.float() + ch_scores = torch.sum(ch_intermed_scores, dim=1) + elif self.channel_scoring_type == "k2_separate": + for k_idx in range(k): + k_eos_tokens = eos_tokens[k_idx::k, :] + if step == 0: + k_ch_input = torch.cat((fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1) + else: + # move eos from beginning to end of target sentence + k_ch_input = torch.cat((tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1) + k_ch_input_lengths = ch_input_lengths[k_idx::k] + k_ch_output = channel_model(k_ch_input, k_ch_input_lengths, src_tokens) + k_ch_lprobs = channel_model.get_normalized_probs(k_ch_output, log_probs=True) + k_ch_intermed_scores = torch.gather(k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2) + k_ch_intermed_scores *= not_padding.float() + ch_scores[k_idx::k] = 
torch.sum(k_ch_intermed_scores, dim=1) + elif self.channel_scoring_type == "src_vocab": + ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths) + ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True) + + del ch_encoder_output + ch_lprobs = normalized_scores_with_batch_vocab( + channel_model.decoder, + ch_decoder_output, src_tokens, k, bsz, beam_size, + self.src_dict.pad_index, top_k=self.top_k_vocab) + ch_scores = torch.sum(ch_lprobs, dim=1) + elif self.channel_scoring_type == "src_vocab_batched": + ch_bsz_size = temp_src_tokens_full.shape[0] + ch_lprobs_list = [None] * len(range(0, ch_bsz_size, self.ch_scoring_bsz)) + for i, start_idx in enumerate(range(0, ch_bsz_size, self.ch_scoring_bsz)): + end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size) + temp_src_tokens_full_batch = temp_src_tokens_full[start_idx:end_idx, :] + channel_input_batch = channel_input[start_idx:end_idx, :] + ch_input_lengths_batch = ch_input_lengths[start_idx:end_idx] + ch_encoder_output_batch = channel_model.encoder(channel_input_batch, src_lengths=ch_input_lengths_batch) + ch_decoder_output_batch, _ = channel_model.decoder(temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True) + ch_lprobs_list[i] = normalized_scores_with_batch_vocab( + channel_model.decoder, + ch_decoder_output_batch, src_tokens, k, bsz, beam_size, + self.src_dict.pad_index, top_k=self.top_k_vocab, + start_idx=start_idx, end_idx=end_idx) + ch_lprobs = torch.cat(ch_lprobs_list, dim=0) + ch_scores = torch.sum(ch_lprobs, dim=1) + else: + ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full) + ch_lprobs = channel_model.get_normalized_probs(ch_output, log_probs=True) + ch_intermed_scores = torch.gather(ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze(2)).squeeze().view(bsz*beam_size*k, -1) + ch_intermed_scores *= not_padding.float() + ch_scores = 
torch.sum(ch_intermed_scores, dim=1) + + else: + cur_tgt_size = 0 + ch_scores = ch_scores.view(bsz*beam_size, k) + expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(1).expand(-1, k).flatten() + + if self.share_tgt_dict: + lm_scores = get_lm_scores(lm, tokens[:, :step + 1].view(-1, step+1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step+1)), k) + else: + new_lm_input = dict2dict(tokens[:, :step + 1].view(-1, step+1), self.tgt_to_lm) + new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm) + lm_scores = get_lm_scores(lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step+1)), k) + + lm_scores.add_(expanded_lm_prefix_scores) + ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size) + # initialize all as min value + new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1) + new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1) + new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1) + new_fw_lprobs[:, self.pad] = -math.inf + new_ch_lm_lprobs[:, self.pad] = -math.inf + new_lm_lprobs[:, self.pad] = -math.inf + + new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k) + new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores) + new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k)) + return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs + + def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size): + if self.channel_scoring_type == "unnormalized": + ch_scores = self.log_softmax_fn( + ch_scores.view(-1, self.beam_size * self.k2) + ).view(ch_scores.shape) + ch_scores = ch_scores * self.ch_weight + lm_scores1 = lm_scores1 * self.lm_weight + + if combine_type == "lm_only": + # log P(T|S) + log P(T) + ch_scores = lm_scores1.view(ch_scores.size()) + elif combine_type == "noisy_channel": + # 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T) + if 
self.normalize_lm_scores_by_tgt_len: + ch_scores.div_(src_size) + lm_scores_norm = lm_scores1.view(ch_scores.size()).div(tgt_size) + ch_scores.add_(lm_scores_norm) + # 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T) + else: + ch_scores.add_(lm_scores1.view(ch_scores.size())) + ch_scores.div_(src_size) + + return ch_scores + + if self.channel_models is not None: + channel_model = self.channel_models[0] # assume only one channel_model model + else: + channel_model = None + + lm = EnsembleModel(self.lm_models) + lm_incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(lm.models_size) + ], + ) + + reorder_state = None + batch_idxs = None + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs) + reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size) + model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state) + + lm.reorder_incremental_state(lm_incremental_states, reorder_state) + + fw_lprobs, avg_attn_scores = model.forward_decoder( + tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature, + ) + + fw_lprobs[:, self.pad] = -math.inf # never select pad + fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2) + + # handle min and max length constraints + if step >= max_len: + fw_lprobs[:, :self.eos] = -math.inf + fw_lprobs[:, self.eos + 1:] = -math.inf + elif step < self.min_len: + fw_lprobs[:, self.eos] = -math.inf + + # handle 
prefix tokens (possibly with different lengths) + if prefix_tokens is not None and step < prefix_tokens.size(1): + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_mask = prefix_toks.ne(self.pad) + + prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + fw_lprobs[prefix_mask] = -math.inf + fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs + ) + + prefix_ch_lm_lprobs = ch_lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + ch_lm_lprobs[prefix_mask] = -math.inf + ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs + ) + + prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + lm_lprobs[prefix_mask] = -math.inf + lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs + ) + + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + def replicate_first_beam(tensor, mask): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = replicate_first_beam(tokens, eos_mask_batch_dim) + scores = replicate_first_beam(scores, eos_mask_batch_dim) + + fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim) + ch_lm_lprobs = replicate_first_beam(ch_lm_lprobs, eos_mask_batch_dim) + lm_lprobs = replicate_first_beam(lm_lprobs, 
eos_mask_batch_dim) + + if self.no_repeat_ngram_size > 0: + # for each beam and batch sentence, generate a list of previous ngrams + gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)] + for bbsz_idx in range(bsz * beam_size): + gen_tokens = tokens[bbsz_idx].tolist() + for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]): + gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \ + gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]] + + # Record attention scores + if avg_attn_scores is not None: + if attn is None: + attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) + attn_buf = attn.clone() + nonpad_idxs = src_tokens.ne(self.pad) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(fw_lprobs) + scores_buf = scores_buf.type_as(fw_lprobs) + + self.search.set_src_lengths(src_lengths_no_eos) + + if self.no_repeat_ngram_size > 0: + def calculate_banned_tokens(bbsz_idx): + # before decoding the next token, prevent decoding of ngrams that have already appeared + ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist()) + return gen_ngrams[bbsz_idx].get(ngram_index, []) + + if step + 2 - self.no_repeat_ngram_size >= 0: + # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)] + else: + banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)] + + for bbsz_idx in range(bsz * beam_size): + fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf + + combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step( + step, + fw_lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size), + lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of 
values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos (except for candidates to be ignored) + eos_mask = cand_indices.eq(self.eos) + eos_mask[:, :beam_size] &= ~cands_to_ignore + + # only consider eos when it's among the top beam_size indices + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = set() + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + combined_noisy_channel_eos_scores = torch.masked_select( + combined_noisy_channel_scores[:, :beam_size], + mask=eos_mask[:, :beam_size], + ) + + # finalize hypo using channel model score + finalized_sents = finalize_hypos( + step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores) + + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = cand_indices.new_ones(bsz) + batch_mask[cand_indices.new(finalized_sents)] = 0 + batch_idxs = torch.nonzero(batch_mask).squeeze(-1) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs] + + fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs] + cand_indices = cand_indices[batch_idxs] + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths_no_eos = src_lengths_no_eos[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + scores_buf.resize_as_(scores) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + 
tokens_buf.resize_as_(tokens) + src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + lm_prefix_scores = lm_prefix_scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze() + + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1) + attn_buf.resize_as_(attn) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos or + # ignored hypos and values < cand_size indicate candidate + # active hypos. After this, the min values per row are the top + # candidate active hypos. + eos_mask[:, :beam_size] |= cands_to_ignore + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just the hypos + # with the smallest values in active_mask + active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer('new_cands_to_ignore') + torch.topk( + active_mask, k=beam_size, dim=1, largest=False, + out=(new_cands_to_ignore, active_hypos) + ) + + # update cands_to_ignore to ignore any finalized hypos + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + assert (~cands_to_ignore).any(dim=1).all() + + active_bbsz_idx = buffer('active_bbsz_idx') + torch.gather( + cand_bbsz_idx, dim=1, index=active_hypos, + out=active_bbsz_idx, + ) + active_scores = torch.gather( + fw_lprobs_top_k, dim=1, index=active_hypos, + out=scores[:, step].view(bsz, beam_size), + ) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + torch.index_select( + tokens[:, :step + 1], dim=0, index=active_bbsz_idx, + out=tokens_buf[:, :step + 1], + ) + torch.gather( + cand_indices, dim=1, index=active_hypos, + out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1], + ) + if step > 0: + 
torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx, + out=scores_buf[:, :step], + ) + torch.gather( + fw_lprobs_top_k, dim=1, index=active_hypos, + out=scores_buf.view(bsz, beam_size, -1)[:, :, step], + ) + torch.gather( + lm_lprobs_top_k, dim=1, index=active_hypos, + out=lm_prefix_scores.view(bsz, beam_size) + ) + + # copy attention for active hypotheses + if attn is not None: + torch.index_select( + attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, + out=attn_buf[:, :, :step + 2], + ) + + # swap buffers + tokens, tokens_buf = tokens_buf, tokens + scores, scores_buf = scores_buf, scores + if attn is not None: + attn, attn_buf = attn_buf, attn + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) + + return finalized + + +def get_lm_scores(model, input_tokens, incremental_states, cand_tokens, input_len, k): + with torch.no_grad(): + lm_lprobs, avg_attn_scores = model.forward_decoder( + input_tokens, encoder_outs=None, incremental_states=incremental_states, + ) + + lm_lprobs_size = lm_lprobs.size(0) + probs_next_wrd = torch.gather(lm_lprobs.repeat(1, k).view(lm_lprobs_size*k, -1), 1, cand_tokens).squeeze().view(-1) + + return probs_next_wrd + + +def make_dict2dict(old_dict, new_dict): + dict2dict_map = {} + for sym in old_dict.symbols: + dict2dict_map[old_dict.index(sym)] = new_dict.index(sym) + return dict2dict_map + + +def dict2dict(tokens, dict2dict_map): + if tokens.device == torch.device('cpu'): + tokens_tmp = tokens + else: + tokens_tmp = tokens.cpu() + return tokens_tmp.map_( + tokens_tmp, + lambda _, val, dict2dict_map=dict2dict_map : dict2dict_map[float(val)] + ).to(tokens.device) + + +def reorder_tokens(tokens, lengths, eos): + # reorder source tokens so they may be used as reference for P(S|T) + return torch.cat((tokens.new([eos]), tokens[-lengths:-1], 
tokens[:-lengths]), 0) + + +def reorder_all_tokens(tokens, lengths, eos): + # used to reorder src tokens from [<pad> <w1> <w2> .. <eos>] to [<eos> <w1> <w2>...<pad>] + # so source tokens can be used to predict P(S|T) + return torch.stack([reorder_tokens(token, length, eos) for token, length in zip(tokens, lengths)]) + + +def normalized_scores_with_batch_vocab( + model_decoder, features, target_ids, k, bsz, beam_size, + pad_idx, top_k=0, vocab_size_meter=None, start_idx=None, + end_idx=None, **kwargs): + """ + Get normalized probabilities (or log probs) from a net's output + w.r.t. vocab consisting of target IDs in the batch + """ + if model_decoder.adaptive_softmax is None: + weight = model_decoder.output_projection.weight + vocab_ids = torch.unique( + torch.cat( + (torch.unique(target_ids), torch.arange(top_k, device=target_ids.device)) + ) + ) + id_map = dict(zip(vocab_ids.tolist(), range(len(vocab_ids)))) + mapped_target_ids = target_ids.cpu().apply_( + lambda x, id_map=id_map: id_map[x] + ).to(target_ids.device) + expanded_target_ids = mapped_target_ids[:, :].repeat(1, k).view(bsz*beam_size*k, -1) + if start_idx is not None and end_idx is not None: + expanded_target_ids = expanded_target_ids[start_idx:end_idx, :] + logits = F.linear(features, weight[vocab_ids, :]) + log_softmax = F.log_softmax(logits, dim=-1, dtype=torch.float32) + intermed_scores = torch.gather( + log_softmax[:, :-1, :], + 2, + expanded_target_ids[:, 1:].unsqueeze(2), + ).squeeze() + not_padding = expanded_target_ids[:, 1:] != pad_idx + intermed_scores *= not_padding.float() + return intermed_scores + else: + raise ValueError("adaptive softmax doesn't work with " + + "`normalized_scores_with_batch_vocab()`") diff --git a/examples/fast_noisy_channel/noisy_channel_translation.py b/examples/fast_noisy_channel/noisy_channel_translation.py new file mode 100644 index 0000000000..b74bdfd456 --- /dev/null +++ b/examples/fast_noisy_channel/noisy_channel_translation.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates.
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.tasks.translation import TranslationTask +from fairseq.tasks.language_modeling import LanguageModelingTask +from fairseq import checkpoint_utils +import argparse +from fairseq.tasks import register_task +import torch + + +@register_task("noisy_channel_translation") +class NoisyChannelTranslation(TranslationTask): + """ + Rescore the top k candidates from each beam using noisy channel modeling + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + TranslationTask.add_args(parser) + # fmt: off + parser.add_argument('--channel-model', metavar='FILE', + help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.') + parser.add_argument('--combine-method', default='lm_only', + choices=['lm_only', 'noisy_channel'], + help="""method for combining direct and channel model scores. + lm_only: decode with P(T|S)P(T) + noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""") + parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False, + help='normalize lm score by target length instead of source length') + parser.add_argument('--channel-scoring-type', default='log_norm', choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'], + help="Normalize bw scores with log softmax or return bw scores without log softmax") + parser.add_argument('--top-k-vocab', default=0, type=int, + help='top k vocab IDs to use with `src_vocab` in channel model scoring') + parser.add_argument('--k2', default=50, type=int, + help='the top k2 candidates to rescore with the noisy channel model for each beam') + parser.add_argument('--ch-wt', default=1, type=float, + help='weight for the channel model') + parser.add_argument('--lm-model', metavar='FILE', + help='path to lm model file, to model P(T). 
P(T) must share the same vocab as the direct model on the target side') + parser.add_argument('--lm-data', metavar='FILE', + help='path to lm model training data for target language, used to properly load LM with correct dictionary') + parser.add_argument('--lm-wt', default=1, type=float, + help='the weight of the lm in joint decoding') + # fmt: on + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None + ): + if getattr(args, "score_reference", False): + raise NotImplementedError() + else: + from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator + use_cuda = torch.cuda.is_available() and not self.args.cpu + assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!' + assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs' + if self.args.channel_model is not None: + import copy + ch_args_task = copy.deepcopy(self.args) + tmp = ch_args_task.source_lang + ch_args_task.source_lang = ch_args_task.target_lang + ch_args_task.target_lang = tmp + ch_args_task._name = 'translation' + channel_task = TranslationTask.setup_task(ch_args_task) + + arg_dict = {} + arg_dict['task'] = 'language_modeling' + arg_dict['sample_break_mode'] = 'eos' + arg_dict['data'] = self.args.lm_data + arg_dict['output_dictionary_size'] = -1 + lm_args = argparse.Namespace(**arg_dict) + lm_task = LanguageModelingTask.setup_task(lm_args) + lm_dict = lm_task.output_dictionary + + if self.args.channel_model is not None: + channel_models, _ = checkpoint_utils.load_model_ensemble(self.args.channel_model.split(':'), task=channel_task) + + for model in channel_models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if self.args.fp16: + model.half() + if use_cuda: + model.cuda() + else: + channel_models = None + + lm_models, _ = 
checkpoint_utils.load_model_ensemble(self.args.lm_model.split(':'), task=lm_task) + + for model in lm_models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if self.args.fp16: + model.half() + if use_cuda: + model.cuda() + return NoisyChannelSequenceGenerator( + combine_method=self.args.combine_method, + tgt_dict=self.target_dictionary, + src_dict=self.source_dictionary, + beam_size=getattr(args, 'beam', 5), + max_len_a=getattr(args, 'max_len_a', 0), + max_len_b=getattr(args, 'max_len_b', 200), + min_len=getattr(args, 'min_len', 1), + len_penalty=getattr(args, 'lenpen', 1), + unk_penalty=getattr(args, 'unkpen', 0), + temperature=getattr(args, 'temperature', 1.), + match_source_len=getattr(args, 'match_source_len', False), + no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), + normalize_scores=(not getattr(args, 'unnormalized', False)), + channel_models=channel_models, + k2=getattr(self.args, 'k2', 50), + ch_weight=getattr(self.args, 'ch_wt', 1), + channel_scoring_type=self.args.channel_scoring_type, + top_k_vocab=self.args.top_k_vocab, + lm_models=lm_models, + lm_dict=lm_dict, + lm_weight=getattr(self.args, 'lm_wt', 1), + normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False), + ) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index e2df08e092..efb5f1542c 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -69,6 +69,7 @@ def string( escape_unk=False, extra_symbols_to_ignore=None, unk_string=None, + include_eos=False, ): """Helper for converting a tensor of token indices to a string. 
@@ -76,7 +77,7 @@ def string( """ if torch.is_tensor(tensor) and tensor.dim() == 2: return "\n".join( - self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore) + self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore, include_eos=include_eos) for t in tensor ) From 265791b727b664d4d7da3abd918a3f6fb70d7337 Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 17 Nov 2020 17:07:23 -0800 Subject: [PATCH 302/707] fix loading ensembles (#1442) Summary: fixes loading ensembles. previous change used the state of the first model for all models in the ensemble Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1442 Reviewed By: chtran Differential Revision: D25035706 Pulled By: alexeib fbshipit-source-id: 9029999be0f1703efb1df20bec2890de59449f1f --- fairseq/checkpoint_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 5a0dc099b2..2bb055056e 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -267,6 +267,8 @@ def load_model_ensemble( def load_model_ensemble_and_task( filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None ): + assert state is None or len(filenames) == 1 + from fairseq import tasks assert not ( @@ -303,6 +305,10 @@ def load_model_ensemble_and_task( model = task.build_model(cfg.model) model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) + + # reset state so it gets loaded for the next model in ensemble + state = None + ensemble.append(model) return ensemble, cfg, task From e931009a91c430a66583e80a91d1de9cea656bd2 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 18 Nov 2020 10:31:19 -0800 Subject: [PATCH 303/707] Fix boundary condition in token_block_utils_fast.pyx (#1445) Summary: In cases where the item size in the underlying dataset is 0, it's possible that `remaining` is initialized to 0. We can update the assert to reflect this. 
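For illustration only, the boundary condition described above can be reproduced with a small pure-Python sketch. It assumes the searcher walks a list of item sizes and advances to the next item once the current one is exhausted. The class name `SearcherSketch`, the `allow_empty` switch, and the control flow around the assert are assumptions made for this example, not the actual Cython implementation.

```python
# Hypothetical pure-Python sketch of a "flat index -> (item, offset)" searcher.
# Only the assert mirrors the change described above; the rest is assumed.
class SearcherSketch:
    def __init__(self, sizes, allow_empty=True):
        self.sizes = list(sizes)
        # True behaves like `assert remaining >= 0`, False like `assert remaining > 0`.
        self.allow_empty = allow_empty
        self.reset()

    def reset(self):
        self.current_offset = 0  # offset within the current item
        self.current_i = 0       # flat index reached so far
        self.current_index = 0   # index of the current item

    def step(self, i):
        """Advance towards flat index i; return True if another step is needed."""
        if i < self.current_i:
            self.reset()
        if i > self.current_i:
            to_consume = i - self.current_i
            remaining = self.sizes[self.current_index] - self.current_offset
            if remaining > to_consume:
                self.current_offset += to_consume
                self.current_i += to_consume
            else:
                # A zero-sized item yields remaining == 0 here.
                if self.allow_empty:
                    assert remaining >= 0
                else:
                    assert remaining > 0
                self.current_i += remaining
                self.current_index += 1
                self.current_offset = 0
                return True
        return False

    def seek(self, i):
        while self.step(i):
            pass
        assert self.current_i == i
        return self.current_index, self.current_offset
```

With sizes `[3, 0, 2]`, seeking to flat index 4 has to pass over the empty middle item: the strict variant (`allow_empty=False`) raises `AssertionError` at that point, while the relaxed variant skips the item and lands in the third one.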
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1445 Reviewed By: alexeib Differential Revision: D25054723 Pulled By: myleott fbshipit-source-id: 1bb73cce34e973f407436c442b698ce706d97359 --- fairseq/data/token_block_utils_fast.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/token_block_utils_fast.pyx b/fairseq/data/token_block_utils_fast.pyx index 5a2f16ec34..08af4f3061 100644 --- a/fairseq/data/token_block_utils_fast.pyx +++ b/fairseq/data/token_block_utils_fast.pyx @@ -170,7 +170,7 @@ cdef class DatasetSearcher(object): self.current_offset += to_consume self.current_i += to_consume else: - assert remaining > 0 + assert remaining >= 0 self.current_i += remaining self.current_index += 1 self.current_offset = 0 From 41a61bd4e2835c7bed25cc9f52fe65714379322e Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 18 Nov 2020 14:30:02 -0800 Subject: [PATCH 304/707] Add GitHub Action to build Python wheels (+ minor cleanup in build scripts) (#1447) Summary: Here's an example run in a forked repo: https://github.com/fairseq/fairseq/runs/1419699104 We can upload the wheels to PyPI to make `pip install fairseq` easier for folks. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1447 Reviewed By: lematt1991 Differential Revision: D25060753 Pulled By: myleott fbshipit-source-id: 9fdc28cc7c8a172daac668dd09684ec43e2ff11a --- .github/workflows/build.yml | 14 +++++++--- .github/workflows/build_wheels.yml | 41 ++++++++++++++++++++++++++++++ fairseq/tasks/audio_pretraining.py | 3 ++- setup.py | 19 +++++++++----- 4 files changed, 65 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/build_wheels.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6ae8093a8a..a2d44dd57f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,26 +19,32 @@ jobs: runs-on: ${{ matrix.platform }} steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: Conditionally install pytorch if: matrix.platform == 'windows-latest' run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html + - name: Install locally run: | python -m pip install --upgrade pip + git submodule update --init --recursive python setup.py build_ext --inplace python -m pip install --editable . + - name: Lint with flake8 run: | pip install flake8 # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron + - name: Run tests run: | python setup.py test diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml new file mode 100644 index 0000000000..7261708596 --- /dev/null +++ b/.github/workflows/build_wheels.yml @@ -0,0 +1,41 @@ +name: build_wheels + +on: + push: + branches: + - v[0-9]+.[0-9]+.[x0-9]+ + tags: + - v* + +jobs: + build_wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v2 + + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: '3.7' + + - name: Install cibuildwheel + run: | + python -m pip install cibuildwheel + + - name: Build wheels for CPython + run: | + python -m cibuildwheel --output-dir dist + env: + CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64" + CIBW_MANYLINUX_X86_64_IMAGE: manylinux1 + CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install . + + - uses: actions/upload-artifact@v2 + with: + name: wheels + path: ./dist/*.whl diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 90e667c80d..a2f7edc34d 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -5,7 +5,6 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
-import editdistance import os import sys import torch @@ -212,6 +211,8 @@ def build_model(self, model_cfg: FairseqDataclass): return model def _inference_with_wer(self, generator, sample, model): + import editdistance + def decode(toks): s = self.target_dictionary.string( toks.int().cpu(), diff --git a/setup.py b/setup.py index 572d2b50de..2aae720d7e 100644 --- a/setup.py +++ b/setup.py @@ -132,7 +132,7 @@ def include_dirs(self, dirs): # use CPU build of PyTorch dependency_links = [ - "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl" + "https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp36-cp36m-linux_x86_64.whl" ] else: dependency_links = [] @@ -149,6 +149,11 @@ def include_dirs(self, dirs): ) +extra_packages = [] +if os.path.exists(os.path.join("fairseq", "model_parallel", "megatron", "mpu")): + extra_packages.append("fairseq.model_parallel.megatron.mpu") + + def do_setup(package_data): setup( name="fairseq", @@ -172,7 +177,6 @@ def do_setup(package_data): "cffi", "cython", "dataclasses", - "editdistance", "hydra-core", "numpy", "regex", @@ -190,7 +194,7 @@ def do_setup(package_data): "tests", "tests.*", ] - ), + ) + extra_packages, package_data=package_data, ext_modules=extensions, test_suite="tests", @@ -223,12 +227,13 @@ def get_files(path, relative_to="fairseq"): try: # symlink examples into fairseq package so package_data accepts them - if "build_ext" not in sys.argv[1:]: - os.symlink(os.path.join("..", "examples"), "fairseq/examples") + fairseq_examples = os.path.join("fairseq", "examples") + if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples): + os.symlink(os.path.join("..", "examples"), fairseq_examples) package_data = { "fairseq": get_files("fairseq/examples"), } do_setup(package_data) finally: - if "build_ext" not in sys.argv[1:]: - os.unlink("fairseq/examples") + if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples): + os.unlink(fairseq_examples) From 
6d2cf0ddf64040543c346b3866eb636d14522dde Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 18 Nov 2020 22:26:34 -0800 Subject: [PATCH 305/707] convert wav2vec2 asr to hydra (#1444) Summary: this completes wav2vec migration to hydra Test Plan: - training wav2vec2 models - training wav2vec2 ctc models - training wav2vec2 seq2seq models - infer.py eval of ctc models - generate.py eval of seq2seq models Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1444 Reviewed By: myleott Differential Revision: D25040041 Pulled By: alexeib fbshipit-source-id: 2aac3b9c659667f7e696628a4b016ee863da68cf --- examples/wav2vec/README.md | 62 +-- .../wav2vec/config/finetuning/base_100h.yaml | 59 ++ .../wav2vec/config/finetuning/base_10h.yaml | 64 +++ .../wav2vec/config/finetuning/base_10m.yaml | 64 +++ .../wav2vec/config/finetuning/base_1h.yaml | 64 +++ .../wav2vec/config/finetuning/base_960h.yaml | 58 ++ .../wav2vec/config/finetuning/vox_100h.yaml | 59 ++ .../wav2vec/config/finetuning/vox_10h.yaml | 64 +++ .../wav2vec/config/finetuning/vox_10m.yaml | 64 +++ .../wav2vec/config/finetuning/vox_1h.yaml | 64 +++ .../wav2vec/config/finetuning/vox_960h.yaml | 58 ++ .../wav2vec2_base_librispeech.yaml | 55 ++ .../pretraining/wav2vec2_large_librivox.yaml | 69 +++ .../config/model/wav2vec2/wav2vec2_base.yaml | 8 + .../config/model/wav2vec2/wav2vec2_large.yaml | 20 + fairseq/criterions/ctc.py | 106 ++-- fairseq/criterions/fairseq_criterion.py | 6 +- fairseq/dataclass/configs.py | 2 +- fairseq/dataclass/utils.py | 23 +- fairseq/models/wav2vec/wav2vec2.py | 4 +- fairseq/models/wav2vec/wav2vec2_asr.py | 518 ++++++++---------- .../lr_scheduler/polynomial_decay_schedule.py | 85 +-- .../lr_scheduler/tri_stage_lr_scheduler.py | 128 ++--- 23 files changed, 1209 insertions(+), 495 deletions(-) create mode 100644 examples/wav2vec/config/finetuning/base_100h.yaml create mode 100644 examples/wav2vec/config/finetuning/base_10h.yaml create mode 100644 
examples/wav2vec/config/finetuning/base_10m.yaml create mode 100644 examples/wav2vec/config/finetuning/base_1h.yaml create mode 100644 examples/wav2vec/config/finetuning/base_960h.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_100h.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10h.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10m.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_1h.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_960h.yaml create mode 100644 examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml create mode 100644 examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml create mode 100644 fairseq/config/model/wav2vec2/wav2vec2_base.yaml create mode 100644 fairseq/config/model/wav2vec2/wav2vec2_large.yaml diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 1da42f388a..442a92553a 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -53,44 +53,27 @@ separately pre-processed manifest file. 
This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper -Note that this was tested with pytorch 1.4.0 and the input is expected to be single channel, sampled at 16 kHz +Note that the input is expected to be single channel, sampled at 16 kHz ```shell script -$ python train.py --distributed-world-size 64 --distributed-port $PORT /manifest/path \ ---save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ ---log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ ---conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \ ---latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \ ---adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 400000 \ ---lr 0.0005 --warmup-updates 32000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ ---encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ ---loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 \ ---max-sample-size 250000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ ---max-tokens 1400000 --max-update 400000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d +$ python fairseq_cli/hydra_train.py task.data=/path/to/data \ +--config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining --config-name wav2vec2_base_librispeech ``` -Note: you can simulate 64 GPUs by using k GPUs and setting --update-freq 64/k +Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before --config-path) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k ### Train a wav2vec 2.0 large model: This configuration was used for the large model trained on 
the Libri-light dataset in the wav2vec 2.0 paper ```shell script -$ python train.py --distributed-world-size 128 --distributed-port $PORT /manifest/path \ ---save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ ---log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ ---conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 768 --latent-vars 320 \ ---latent-groups 2 --latent-temp '(2.0,0.1,0.999995)' --infonce --optimizer adam \ ---adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 600000 \ ---lr 0.0003 --warmup-updates 32000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ ---encoder-layerdrop 0.0 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.03 \ ---loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --encoder-layers 24 --encoder-embed-dim 1024 \ ---encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --num-negatives 100 --cross-sample-negatives 0 \ ---max-sample-size 320000 --min-sample-size 32000 --dropout 0.0 --attention-dropout 0.1 --weight-decay 0.01 \ ---max-tokens 1200000 --max-update 600000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d +$ python fairseq_cli/hydra_train.py task.data=/path/to/data \ +--config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining --config-name wav2vec2_large_librivox ``` -Note: you can simulate 128 GPUs by using k GPUs and setting --update-freq 128/k +Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before --config-path) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k ### Fine-tune a pre-trained model with CTC: @@ -105,28 +88,19 @@ $ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $sp Fine-tuning on 100h of Librispeech with letter targets: ```shell script 
-valid_subset=dev_other -python train.py --distributed-world-size 24 --distributed-port $PORT /path/to/training_data --save-dir /model/path --fp16 \ ---wer-args '("/path/to/lm/4-gram.bin","/path/to/lexicon",2,-1)' \ ---post-process letter --valid-subset $valid_subset --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 \ ---max-update 80000 --sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /path/to/pretrained/model \ ---labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \ ---mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 --zero-infinity \ ---feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 --optimizer adam \ ---adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 --hold-steps 32000 \ ---decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 --activation-dropout 0.1 --criterion ctc \ ---attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json --log-interval 500 --ddp-backend no_c10d +python fairseq_cli/hydra_train.py distributed_training.distributed_port=$PORT task.data=/path/to/data \ +model.w2v_path=/path/to/model.pt --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \ +--config-name base_100h ``` -Note: you can simulate 24 GPUs by using k GPUs and setting --update-freq 24/k +There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. +You can specify the right config via the --config-name parameter. -Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). -Alternatively, simply omit the --wer-args flag. 
- -For hyper-parameters to fine-tune other Librispeech splits (10 minutes, 1 hour, etc) please refer to the table in Appendix B in the wav2vec 2.0 paper. -The main changes to make are adjusting --max-update, and then adjusting --warmup-steps, --hold-steps, and --decay steps so that they use 0.1/0.4/0.5 of max-update respectively. You then need to adjust --mask-prob and --mask-channel-prob. This should be set to the mask-length * x where x is the number in the table and mask-length is what you use for --mask-length (10 in this example. Use --mask-channel-length value for --mask-channel-prob). +Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before --config-path) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k -For example, for 10 hours, we see in the paper that timestep mask prob should be 0.065, so we set --mask-prob to 10* 0.065 = 0.65. channel mask prob is 0.004, so we set it to 64 * 0.004 = 0.256. then we set --max-updates to 20000 and change --warmup-steps to 20000 * 0.1 = 2000, --hold-steps to 8000 and --decay-steps to 10000. +Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). +If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. ### Evaluating a CTC model: diff --git a/examples/wav2vec/config/finetuning/base_100h.yaml b/examples/wav2vec/config/finetuning/base_100h.yaml new file mode 100644 index 0000000000..7d1664a184 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_100h.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? 
+ normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [0.00003] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 + diff --git a/examples/wav2vec/config/finetuning/base_10h.yaml b/examples/wav2vec/config/finetuning/base_10h.yaml new file mode 100644 index 0000000000..31125947c0 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_10h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 50 + save_interval_updates: 10000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 50 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 20000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.05 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/base_10m.yaml b/examples/wav2vec/config/finetuning/base_10m.yaml new file mode 100644 index 0000000000..2235504489 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_10m.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/base_1h.yaml b/examples/wav2vec/config/finetuning/base_1h.yaml new file mode 100644 index 0000000000..2235504489 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_1h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/base_960h.yaml b/examples/wav2vec/config/finetuning/base_960h.yaml new file mode 100644 index 0000000000..d742c94abf --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_960h.yaml @@ -0,0 +1,58 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 320000 + lr: [0.00001] + sentence_avg: true + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.1 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 + diff --git a/examples/wav2vec/config/finetuning/vox_100h.yaml b/examples/wav2vec/config/finetuning/vox_100h.yaml new file mode 100644 index 0000000000..8885c78470 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_100h.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? 
+ normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [0.00003] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_10h.yaml b/examples/wav2vec/config/finetuning/vox_10h.yaml new file mode 100644 index 0000000000..c0957c0058 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 50 + save_interval_updates: 10000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 50 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 20000 + lr: [0.0001] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.75 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_10m.yaml b/examples/wav2vec/config/finetuning/vox_10m.yaml new file mode 100644 index 0000000000..0d567552d7 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10m.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.0001] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_1h.yaml b/examples/wav2vec/config/finetuning/vox_1h.yaml new file mode 100644 index 0000000000..10c45a52d8 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_1h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.0003] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.75 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_960h.yaml b/examples/wav2vec/config/finetuning/vox_960h.yaml new file mode 100644 index 0000000000..6212a2e738 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_960h.yaml @@ -0,0 +1,58 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? 
+ normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 24 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 320000 + lr: [0.00003] + sentence_avg: true + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml new file mode 100644 index 0000000000..e2c2b7b0b3 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml @@ -0,0 +1,55 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? 
+ max_sample_size: 250000 + min_sample_size: 32000 + +dataset: + num_workers: 6 + max_tokens: 1400000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 64 + ddp_backend: no_c10d + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 10] + +optimization: + max_update: 400000 + lr: [0.0005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + final_dim: 256 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + feature_grad_mult: 0.1 diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml new file mode 100644 index 0000000000..0c911b7491 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml @@ -0,0 +1,69 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? 
+ max_sample_size: 320000 + min_sample_size: 32000 + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 128 + ddp_backend: no_c10d + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 768 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + + feature_grad_mult: 1.0 + diff --git a/fairseq/config/model/wav2vec2/wav2vec2_base.yaml b/fairseq/config/model/wav2vec2/wav2vec2_base.yaml new file mode 100644 index 0000000000..ce65499b80 --- /dev/null +++ b/fairseq/config/model/wav2vec2/wav2vec2_base.yaml @@ -0,0 +1,8 @@ +# @package _group_ + +quantize_targets: true +final_dim: 256 +encoder_layerdrop: 0.05 +dropout_input: 0.1 +dropout_features: 0.1 +feature_grad_mult: 0.1 diff --git a/fairseq/config/model/wav2vec2/wav2vec2_large.yaml b/fairseq/config/model/wav2vec2/wav2vec2_large.yaml new file mode 100644 index 0000000000..5846f75243 --- /dev/null +++ b/fairseq/config/model/wav2vec2/wav2vec2_large.yaml @@ -0,0 +1,20 @@ +# @package _group_ + +quantize_targets: true +extractor_mode: layer_norm +layer_norm_first: true +final_dim: 768 +latent_temp: [2.0,0.1,0.999995] +encoder_layerdrop: 0.0 +dropout_input: 0.0 +dropout_features: 0.0 +dropout: 0.0 +attention_dropout: 0.0 +conv_bias: true + +encoder_layers: 24 +encoder_embed_dim: 1024 
+encoder_ffn_embed_dim: 4096 +encoder_attention_heads: 16 + +feature_grad_mult: 1.0 diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 6b77ce47eb..deab4f2650 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -6,39 +6,92 @@ import math from argparse import Namespace +from dataclasses import dataclass, field +from omegaconf import II +from typing import Optional, Tuple import torch import torch.nn.functional as F from fairseq import metrics, utils -from fairseq.criterions import LegacyFairseqCriterion, register_criterion +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass from fairseq.data.data_utils import post_process +from fairseq.tasks import FairseqTask from fairseq.logging.meters import safe_round -@register_criterion("ctc") -class CtcCriterion(LegacyFairseqCriterion): - def __init__(self, args, task): - super().__init__(args, task) +@dataclass +class CtcCriterionConfig(FairseqDataclass): + zero_infinity: bool = field( + default=False, + metadata={"help": "zero inf loss when source length <= target length"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + post_process: str = field( + default="letter", + metadata={ + "help": "how to post process predictions into words. can be letter, " + "wordpiece, BPE symbols, etc. 
" + "See fairseq.data.data_utils.post_process() for full list of options" + }, + ) + wer_kenlm_model: Optional[str] = field( + default=None, + metadata={ + "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)" + }, + ) + wer_lexicon: Optional[str] = field( + default=None, + metadata={"help": "lexicon to use with wer_kenlm_model"}, + ) + wer_lm_weight: float = field( + default=2.0, + metadata={"help": "lm weight to use with wer_kenlm_model"}, + ) + wer_word_score: float = field( + default=-1.0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) + + wer_args: Optional[str] = field( + default=None, + metadata={ + "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)" + }, + ) + + +@register_criterion("ctc", dataclass=CtcCriterionConfig) +class CtcCriterion(FairseqCriterion): + def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): + super().__init__(task) self.blank_idx = task.target_dictionary.bos() self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() - self.post_process = args.post_process if args.post_process else "letter" + self.post_process = cfg.post_process - if args.wer_args is not None: - from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + if cfg.wer_args is not None: + ( + cfg.wer_kenlm_model, + cfg.wer_lexicon, + cfg.wer_lm_weight, + cfg.wer_word_score, + ) = eval(cfg.wer_args) - wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args) + if cfg.wer_kenlm_model is not None: + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder dec_args = Namespace() dec_args.nbest = 1 dec_args.criterion = "ctc" - dec_args.kenlm_model = wer_compute_kenlm - dec_args.lexicon = wer_lexicon + dec_args.kenlm_model = cfg.wer_kenlm_model + dec_args.lexicon = cfg.wer_lexicon dec_args.beam = 50 dec_args.beam_size_token = min(50, len(task.target_dictionary)) dec_args.beam_threshold = min(50, 
len(task.target_dictionary)) - dec_args.lm_weight = lm_w - dec_args.word_score = ws_w + dec_args.lm_weight = cfg.wer_lm_weight + dec_args.word_score = cfg.wer_word_score dec_args.unk_weight = -math.inf dec_args.sil_weight = 0 @@ -46,31 +99,8 @@ def __init__(self, args, task): else: self.w2l_decoder = None - self.zero_infinity = args.zero_infinity - self.sentence_avg = args.sentence_avg - - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - parser.add_argument( - "--zero-infinity", action="store_true", help="zero inf loss" - ) - try: - parser.add_argument( - "--post-process", - "--remove-bpe", - default="letter", - help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)", - ) - except: - pass # this option might have been added from eval args - parser.add_argument( - "--wer-args", - type=str, - default=None, - help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \ - path to lexicon, lm score, word score", - ) + self.zero_infinity = cfg.zero_infinity + self.sentence_avg = cfg.sentence_avg def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index b2eda1a7e4..ff4beb0250 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List from fairseq import metrics, utils +from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass -from omegaconf import DictConfig from torch.nn.modules.loss import _Loss @@ -28,7 +28,7 @@ def add_args(cls, parser): gen_parser_from_dataclass(parser, dc()) @classmethod - def build_criterion(cls, cfg: DictConfig, task): + def build_criterion(cls, cfg: FairseqDataclass, task): """Construct a criterion from command-line args.""" # arguments 
in the __init__. init_args = {} @@ -46,6 +46,8 @@ def build_criterion(cls, cfg: DictConfig, task): if p.name == "task": init_args["task"] = task + elif p.name == "cfg": + init_args["cfg"] = cfg elif hasattr(cfg, p.name): init_args[p.name] = getattr(cfg, p.name) elif p.default != p.empty: diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index ec921a41d7..28dc8905c7 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -874,7 +874,7 @@ class InteractiveConfig(FairseqDataclass): @dataclass -class FairseqConfig(object): +class FairseqConfig(FairseqDataclass): common: CommonConfig = CommonConfig() common_eval: CommonEvalConfig = CommonEvalConfig() distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index f8ed8f667f..a3a6c43281 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -5,6 +5,7 @@ import ast import os +import logging import re from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING @@ -15,8 +16,11 @@ from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import FairseqConfig from hydra.experimental import compose, initialize +from hydra.core.global_hydra import GlobalHydra from omegaconf import DictConfig, OmegaConf, open_dict +logger = logging.getLogger(__name__) + def eval_str_list(x, x_type=float): if x is None: @@ -210,7 +214,8 @@ def get_default(f): isinstance(val, str) and not val.startswith("${") # not interpolation and field_type != str - and inspect.isclass(field_type) and not issubclass(field_type, Enum) # not choices enum + and inspect.isclass(field_type) + and not issubclass(field_type, Enum) # not choices enum ): # upgrade old models that stored complex parameters as string val = ast.literal_eval(val) @@ -233,6 +238,10 @@ def get_default(f): overrides.append("{}.{}='{}'".format(sub_node, k, val)) elif 
isinstance(val, FairseqDataclass): overrides += _override_attr(f"{sub_node}.{k}", type(val), args) + elif isinstance(val, Namespace): + sub_overrides, _ = override_module_args(val) + for so in sub_overrides: + overrides.append(f"{sub_node}.{k}.{so}") else: overrides.append("{}.{}={}".format(sub_node, k, val)) @@ -321,8 +330,15 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: # configs will be in fairseq/config after installation config_path = os.path.join("..", "config") + GlobalHydra.instance().clear() + with initialize(config_path=config_path): - composed_cfg = compose("config", overrides=overrides, strict=False) + try: + composed_cfg = compose("config", overrides=overrides, strict=False) + except: + logger.error("Error when composing. Overrides: " + str(overrides)) + raise + for k in deletes: composed_cfg[k] = None @@ -374,7 +390,8 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: def populate_dataclass( - dataclass: FairseqDataclass, args: Namespace, + dataclass: FairseqDataclass, + args: Namespace, ) -> FairseqDataclass: for k in dataclass.__dataclass_fields__.keys(): if k.startswith("_"): diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index a00dc4d915..783ebcfe6b 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -92,7 +92,7 @@ class Wav2Vec2Config(FairseqDataclass): default=False, metadata={"help": "apply layernorm first in the transformer"} ) conv_feature_layers: str = field( - default="[(512, 10, 5), (512, 8, 4)] + [(512, 4, 2)] * 3 + [(512, 1, 1)]", + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", metadata={ "help": "string describing convolutional feature extraction layers in form of a python list that contains " "[(dim, kernel_size, stride), ...]" @@ -147,7 +147,7 @@ class Wav2Vec2Config(FairseqDataclass): default=0, metadata={ "help": "secondary mask argument (used for more complex distributions), " - "see 
help in compute_mask_indicesh" + "see help in compute_mask_indices" }, ) no_mask_overlap: bool = field( diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index f62ec633b4..790b0a8ad1 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -7,166 +7,145 @@ import contextlib import copy import math - import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from dataclasses import dataclass, field +from omegaconf import MISSING, II, open_dict +from typing import Any + from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.tasks import FairseqTask from fairseq.models import ( BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, FairseqIncrementalDecoder, register_model, - register_model_architecture, ) +from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer -def add_common_args(parser): - parser.add_argument("--w2v-path", help="path to wav2vec 2.0 model") - parser.add_argument( - "--no-pretrained-weights", - action="store_true", - help="if true, does not load pretrained weights", +@dataclass +class Wav2Vec2AsrConfig(FairseqDataclass): + w2v_path: str = field( + default=MISSING, metadata={"help": "path to wav2vec 2.0 model"} ) - parser.add_argument( - "--dropout-input", - type=float, - metavar="D", - help="dropout to apply to the input (after feat extr)", + no_pretrained_weights: bool = field( + default=False, metadata={"help": "if true, does not load pretrained weights"} ) - parser.add_argument( - "--final-dropout", - type=float, - metavar="D", - help="dropout after transformer and before final projection", + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat 
extr)"}, ) - parser.add_argument( - "--apply-mask", action="store_true", help="apply masking during fine-tuning" + final_dropout: float = field( + default=0.0, + metadata={"help": "dropout after transformer and before final projection"}, ) - parser.add_argument( - "--dropout", - type=float, - metavar="D", - help="dropout probability inside wav2vec 2.0 model", + dropout: float = field( + default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"} ) - parser.add_argument( - "--attention-dropout", - type=float, - metavar="D", - help="dropout probability for attention weights inside wav2vec 2.0 model", + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside wav2vec 2.0 model" + }, ) - parser.add_argument( - "--activation-dropout", - "--relu-dropout", - type=float, - metavar="D", - help="dropout probability after activation in FFN inside wav2vec 2.0 model", + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside wav2vec 2.0 model" + }, ) - parser.add_argument( - "--mask-length", type=int, help="repeat the mask indices multiple times" + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} ) - - parser.add_argument( - "--mask-prob", type=float, help="probability of replacing a token with mask" + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} ) - - parser.add_argument( - "--mask-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask (normalized by length)" + }, ) - - parser.add_argument( - "--mask-other", - type=float, - help="stdev of the mask length in case of 'normal' selection strategy", + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + 
default="static", metadata={"help": "how to choose masks"} ) - - parser.add_argument( - "--no-mask-overlap", - action="store_true", - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-channel-length", type=int, help="repeat the mask indices multiple times" + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, ) - - parser.add_argument( - "--mask-channel-prob", - type=float, - help="probability of replacing a token with mask", + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} ) - parser.add_argument( - "--mask-channel-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} ) - - parser.add_argument( - "--mask-channel-other", - type=float, - help="stdev of the mask length in case of 'normal' selection strategy", + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} ) - - parser.add_argument( - "--no-mask-channel-overlap", - action="store_true", - help="whether to allow masks to overlap", + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, ) - - parser.add_argument( - "--freeze-finetune-updates", + mask_channel_other: float = field( default=0, - type=int, - help="dont finetune wav2vec for this many updates", + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indicesh" + }, ) - - parser.add_argument( - "--feature-grad-mult", - default=None, - type=float, - help="reset feature grad mult in wav2vec 2.0 to this", + no_mask_channel_overlap: bool = field( + 
default=False, metadata={"help": "whether to allow channel masks to overlap"} ) - - parser.add_argument( - "--layerdrop", - default=0.0, - type=float, - help="probability of dropping a layer in wav2vec 2.0", + freeze_finetune_updates: int = field( + default=0, metadata={"help": "dont finetune wav2vec for this many updates"} + ) + feature_grad_mult: float = field( + default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"} ) + layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"} + ) + normalize: bool = II("task.normalize") + data: str = II("task.data") + # this holds the loaded wav2vec args + w2v_args: Any = None -@register_model("wav2vec_ctc") -class Wav2VecCtc(BaseFairseqModel): - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - add_common_args(parser) +@dataclass +class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): + pass - def __init__(self, w2v_encoder, args): + +@register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig) +class Wav2VecCtc(BaseFairseqModel): + def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel): super().__init__() + self.cfg = cfg self.w2v_encoder = w2v_encoder - self.args = args def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) return state_dict @classmethod - def build_model(cls, args, task): + def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask): """Build a new model instance.""" - base_architecture(args) - w2v_encoder = Wav2VecEncoder(args, task.target_dictionary) - return cls(w2v_encoder, args) + w2v_encoder = Wav2VecEncoder(cfg, task.target_dictionary) + return cls(cfg, w2v_encoder) def get_normalized_probs(self, net_output, log_probs): """Get normalized probabilities (or log probs) from a net's output.""" @@ -181,96 +160,67 @@ def forward(self, **kwargs): x = self.w2v_encoder(**kwargs) return x - # def max_positions(self): - # return None 
+@dataclass +class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"}) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) -@register_model("wav2vec_seq2seq") -class TransformerModel(FairseqEncoderDecoderModel): - def __init__(self, args, encoder, decoder): - super().__init__(encoder, decoder) - - @staticmethod - def add_args(parser): - add_common_args(parser) - - parser.add_argument( - "--decoder-embed-dim", - type=int, - metavar="N", - help="decoder embedding dimension", - 
) - parser.add_argument( - "--decoder-ffn-embed-dim", - type=int, - metavar="N", - help="decoder embedding dimension for FFN", - ) - parser.add_argument( - "--decoder-layers", type=int, metavar="N", help="num decoder layers" - ) - parser.add_argument( - "--decoder-layerdrop", - type=float, - metavar="D", - help="decoder layerdrop chance", - ) - parser.add_argument( - "--decoder-attention-heads", - type=int, - metavar="N", - help="num decoder attention heads", - ) - parser.add_argument( - "--decoder-learned-pos", - action="store_true", - help="use learned positional embeddings in the decoder", - ) - parser.add_argument( - "--decoder-normalize-before", - action="store_true", - help="apply layernorm before each decoder block", - ) - parser.add_argument( - "--no-token-positional-embeddings", - default=False, - action="store_true", - help="if set, disables positional embeddings (outside self attention)", - ) - - parser.add_argument( - "--decoder-dropout", - type=float, - metavar="D", - help="dropout probability in the decoder", - ) - parser.add_argument( - "--decoder-attention-dropout", - type=float, - metavar="D", - help="dropout probability for attention weights inside the decoder", - ) - parser.add_argument( - "--decoder-activation-dropout", - type=float, - metavar="D", - help="dropout probability after activation in FFN inside the decoder", - ) - # fmt: on +@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) +class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) @classmethod - def build_model(cls, args, task): + def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask): """Build a new model instance.""" - # make sure all arguments are present in older models - base_architecture(args) - - if not hasattr(args, "max_source_positions"): - args.max_source_positions = 2048 - if not hasattr(args, "max_target_positions"): - args.max_target_positions = 2048 - src_dict, 
tgt_dict = task.source_dictionary, task.target_dictionary def build_embedding(dictionary, embed_dim): @@ -279,19 +229,20 @@ def build_embedding(dictionary, embed_dim): emb = Embedding(num_embeddings, embed_dim, padding_idx) return emb - decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim) + decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim) + + encoder = cls.build_encoder(cfg) + decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) - encoder = cls.build_encoder(args) - decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) - return TransformerModel(args, encoder, decoder) + return Wav2Vec2Seq2SeqModel(encoder, decoder) @classmethod - def build_encoder(cls, args): - return Wav2VecEncoder(args) + def build_encoder(cls, cfg: Wav2Vec2AsrConfig): + return Wav2VecEncoder(cfg) @classmethod - def build_decoder(cls, args, tgt_dict, embed_tokens): - return TransformerDecoder(args, tgt_dict, embed_tokens) + def build_decoder(cls, cfg: Wav2Vec2Seq2SeqConfig, tgt_dict, embed_tokens): + return TransformerDecoder(cfg, tgt_dict, embed_tokens) def forward(self, **kwargs): encoder_out = self.encoder(tbc=False, **kwargs) @@ -304,52 +255,50 @@ def upgrade_state_dict_named(self, state_dict, name): class Wav2VecEncoder(FairseqEncoder): - def __init__(self, args, tgt_dict=None): - self.apply_mask = args.apply_mask + def __init__(self, cfg: Wav2Vec2AsrConfig, tgt_dict=None): + self.apply_mask = cfg.apply_mask arg_overrides = { - "dropout": args.dropout, - "activation_dropout": args.activation_dropout, - "dropout_input": args.dropout_input, - "attention_dropout": args.attention_dropout, - "mask_length": args.mask_length, - "mask_prob": args.mask_prob, - "mask_selection": args.mask_selection, - "mask_other": args.mask_other, - "no_mask_overlap": args.no_mask_overlap, - "mask_channel_length": args.mask_channel_length, - "mask_channel_prob": args.mask_channel_prob, - "mask_channel_selection": args.mask_channel_selection, - 
"mask_channel_other": args.mask_channel_other, - "no_mask_channel_overlap": args.no_mask_channel_overlap, - "encoder_layerdrop": args.layerdrop, - "feature_grad_mult": args.feature_grad_mult, + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, } - if getattr(args, "w2v_args", None) is None: - state = checkpoint_utils.load_checkpoint_to_cpu( - args.w2v_path, arg_overrides - ) + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides) w2v_args = state.get("cfg", None) if w2v_args is None: w2v_args = convert_namespace_to_omegaconf(state["args"]) - args.w2v_args = w2v_args + cfg.w2v_args = w2v_args else: state = None - w2v_args = args.w2v_args + w2v_args = cfg.w2v_args if isinstance(w2v_args, Namespace): - args.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) - assert ( - args.normalize == w2v_args.task.normalize - ), "Fine-tuning works best when data normalization is the same. " \ - "Please check that --normalize is set or unset for both" + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. 
" + "Please check that --normalize is set or unset for both pre-training and here" + ) - w2v_args.task.data = args.data + w2v_args.task.data = cfg.data task = tasks.setup_task(w2v_args.task) model = task.build_model(w2v_args.model) - if state is not None and not args.no_pretrained_weights: + if state is not None and not cfg.no_pretrained_weights: model.load_state_dict(state["model"], strict=True) model.remove_pretraining_modules() @@ -360,14 +309,14 @@ def __init__(self, args, tgt_dict=None): self.w2v_model = model - self.final_dropout = nn.Dropout(args.final_dropout) - self.freeze_finetune_updates = args.freeze_finetune_updates + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates self.num_updates = 0 if tgt_dict is not None: self.proj = Linear(d, len(tgt_dict)) - elif getattr(args, "decoder_embed_dim", d) != d: - self.proj = Linear(d, args.decoder_embed_dim) + elif getattr(cfg, "decoder_embed_dim", d) != d: + self.proj = Linear(d, cfg.decoder_embed_dim) else: self.proj = None @@ -436,21 +385,26 @@ class TransformerDecoder(FairseqIncrementalDecoder): (default: False). 
""" - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + def __init__( + self, + cfg: Wav2Vec2Seq2SeqConfig, + dictionary, + embed_tokens, + no_encoder_attn=False, + ): super().__init__(dictionary) - self.dropout = args.decoder_dropout - self.share_input_output_embed = args.share_decoder_input_output_embed + self.dropout = cfg.decoder_dropout + self.share_input_output_embed = cfg.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim - embed_dim = args.decoder_embed_dim - self.output_embed_dim = args.decoder_embed_dim - args.encoder_embed_dim = embed_dim + embed_dim = cfg.decoder_embed_dim + self.output_embed_dim = cfg.decoder_embed_dim - self.layerdrop = args.decoder_layerdrop + self.layerdrop = cfg.decoder_layerdrop padding_idx = embed_tokens.padding_idx - self.max_target_positions = args.max_target_positions + self.max_target_positions = cfg.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim @@ -463,25 +417,31 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.embed_positions = ( PositionalEmbedding( - args.max_target_positions, + cfg.max_target_positions, embed_dim, padding_idx, - learned=args.decoder_learned_pos, + learned=cfg.decoder_learned_pos, ) - if not args.no_token_positional_embeddings + if not cfg.no_token_positional_embeddings else None ) - args = copy.deepcopy(args) - args.dropout = args.decoder_dropout - args.attention_dropout = args.decoder_attention_dropout - args.activation_dropout = args.decoder_activation_dropout + # TODO: update this when transformer gets converted to dataclass configs + transformer_cfg = copy.deepcopy(cfg) + with open_dict(transformer_cfg): + transformer_cfg.dropout = transformer_cfg.decoder_dropout + transformer_cfg.attention_dropout = ( + transformer_cfg.decoder_attention_dropout + ) + transformer_cfg.activation_dropout = ( + transformer_cfg.decoder_activation_dropout 
+ ) self.layers = nn.ModuleList([]) self.layers.extend( [ - TransformerDecoderLayer(args, no_encoder_attn) - for _ in range(args.decoder_layers) + TransformerDecoderLayer(transformer_cfg, no_encoder_attn) + for _ in range(transformer_cfg.decoder_layers) ] ) @@ -491,9 +451,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): ) nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) - if args.decoder_normalize_before and not getattr( - args, "no_decoder_final_norm", False - ): + if transformer_cfg.decoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None @@ -633,51 +591,3 @@ def Linear(in_features, out_features, bias=True): if bias: nn.init.constant_(m.bias, 0.0) return m - - -@register_model_architecture("wav2vec_ctc", "wav2vec_ctc") -def base_architecture(args): - args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False) - args.dropout_input = getattr(args, "dropout_input", 0) - args.final_dropout = getattr(args, "final_dropout", 0) - args.apply_mask = getattr(args, "apply_mask", False) - args.dropout = getattr(args, "dropout", 0) - args.attention_dropout = getattr(args, "attention_dropout", 0) - args.activation_dropout = getattr(args, "activation_dropout", 0) - - args.mask_length = getattr(args, "mask_length", 10) - args.mask_prob = getattr(args, "mask_prob", 0.5) - args.mask_selection = getattr(args, "mask_selection", "static") - args.mask_other = getattr(args, "mask_other", 0) - args.no_mask_overlap = getattr(args, "no_mask_overlap", False) - args.mask_channel_length = getattr(args, "mask_channel_length", 10) - args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) - args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") - args.mask_channel_other = getattr(args, "mask_channel_other", 0) - args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) - - args.freeze_finetune_updates = getattr(args, 
"freeze_finetune_updates", 0) - args.feature_grad_mult = getattr(args, "feature_grad_mult", 0) - args.layerdrop = getattr(args, "layerdrop", 0.0) - - -@register_model_architecture("wav2vec_seq2seq", "wav2vec_seq2seq") -def seq2seq_architecture(args): - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) - args.decoder_layers = getattr(args, "decoder_layers", 10) - args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) - args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) - args.no_token_positional_embeddings = getattr( - args, "no_token_positional_embeddings", False - ) - args.decoder_dropout = getattr(args, "decoder_dropout", 0) - args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0) - args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0) - args.share_decoder_input_output_embed = getattr( - args, "share_decoder_input_output_embed", False - ) - - base_architecture(args) diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index 63adc740a9..be9c9aec1d 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -3,53 +3,60 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import LegacyFairseqLRScheduler, register_lr_scheduler +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II +from fairseq.dataclass import FairseqDataclass +from . 
import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler("polynomial_decay") -class PolynomialDecaySchedule(LegacyFairseqLRScheduler): + +@dataclass +class PolynomialDecayScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + end_learning_rate: float = field( + default=0.0, + metadata={"help": "learning rate to decay to"}, + ) + power: float = field( + default=1.0, + metadata={"help": "decay exponent"}, + ) + total_num_update: float = II("optimization.max_update") + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayScheduleConfig) +class PolynomialDecaySchedule(FairseqLRScheduler): """Decay the LR on a fixed schedule.""" - def __init__(self, args, optimizer): - super().__init__(args, optimizer) + cfg: PolynomialDecayScheduleConfig - # set defaults - args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 + def __init__(self, cfg: PolynomialDecayScheduleConfig, optimizer): + super().__init__(cfg, optimizer) - self.lr = args.lr[0] - if args.warmup_updates > 0: - self.warmup_factor = 1.0 / args.warmup_updates + assert cfg.total_num_update > 0 + + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates else: self.warmup_factor = 1 - self.end_learning_rate = args.end_learning_rate - self.total_num_update = args.total_num_update - self.power = args.power + self.end_learning_rate = cfg.end_learning_rate + self.total_num_update = cfg.total_num_update + self.power = cfg.power self.optimizer.set_lr(self.warmup_factor * self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - parser.add_argument( - "--force-anneal", - "--fa", - type=int, - metavar="N", - help="force annealing 
at specified epoch", - ) - parser.add_argument( - "--warmup-updates", - default=0, - type=int, - metavar="N", - help="warmup the learning rate linearly for the first N updates", - ) - parser.add_argument("--end-learning-rate", default=0.0, type=float) - parser.add_argument("--power", default=1.0, type=float) - parser.add_argument("--total-num-update", default=1000000, type=int) - def get_next_lr(self, epoch): - lrs = self.args.lr - if self.args.force_anneal is None or epoch < self.args.force_anneal: + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: # use fixed LR schedule next_lr = lrs[min(epoch, len(lrs) - 1)] else: @@ -65,13 +72,13 @@ def step_begin_epoch(self, epoch): def step_update(self, num_updates): """Update the learning rate after each update.""" - if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: - self.warmup_factor = num_updates / float(self.args.warmup_updates) + if self.cfg.warmup_updates > 0 and num_updates <= self.cfg.warmup_updates: + self.warmup_factor = num_updates / float(self.cfg.warmup_updates) lr = self.warmup_factor * self.lr elif num_updates >= self.total_num_update: lr = self.end_learning_rate else: - warmup = self.args.warmup_updates + warmup = self.cfg.warmup_updates lr_range = self.lr - self.end_learning_rate pct_remaining = 1 - (num_updates - warmup) / ( self.total_num_update - warmup diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index c573237f11..f0576d2d5f 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -5,11 +5,47 @@ import math -from . 
import LegacyFairseqLRScheduler, register_lr_scheduler - - -@register_lr_scheduler("tri_stage") -class TriStageLRSchedule(LegacyFairseqLRScheduler): +from dataclasses import dataclass, field +from typing import Optional, List, Tuple +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from . import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class TriStageLRScheduleConfig(FairseqDataclass): + warmup_steps: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + hold_steps: int = field( + default=0, + metadata={"help": "steps in hold stage"}, + ) + decay_steps: int = field( + default=0, + metadata={"help": "steps in decay stages"}, + ) + phase_ratio: Optional[Tuple[float, float, float]] = field( + default=None, + metadata={"help": "if set, automatically sets warmup/hold/decay steps to the ratio specified here " + "from max_updates. the ratios must add up to 1.0"}, + ) + init_lr_scale: float = field( + default=0.01, + metadata={"help": "initial learning rate scale during warmup phase"}, + ) + final_lr_scale: float = field( + default=0.01, + metadata={"help": "final learning rate scale"}, + ) + max_update: float = II("optimization.max_update") + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig) +class TriStageLRScheduleConfig(FairseqLRScheduler): """Tristage learning rate schedulr Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf @@ -29,92 +65,60 @@ class TriStageLRSchedule(LegacyFairseqLRScheduler): During warmup:: - init_lr = args.init_lr_scale * args.lr - lrs = torch.linspace(init_lr, args.lr, args.warmup_steps) + init_lr = cfg.init_lr_scale * cfg.lr + lrs = torch.linspace(init_lr, cfg.lr, cfg.warmup_steps) lr = lrs[update_num] During hold:: - lr = args.lr + lr = cfg.lr During decay:: - decay_factor = - math.log(args.final_lr_scale) / args.decay_steps - lr = args.lr * exp(- 
(update_num - warmup_steps - decay_steps) * decay_factor) + decay_factor = - math.log(cfg.final_lr_scale) / cfg.decay_steps + lr = cfg.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor) After that:: - lr = args.lr * args.final_lr_scale + lr = cfg.lr * cfg.final_lr_scale """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with tri-stage lr." " Consider --lr-scheduler=fixed instead." ) # calculate LR at each point - self.peak_lr = args.lr[0] - self.init_lr = args.init_lr_scale * args.lr[0] - self.final_lr = args.final_lr_scale * args.lr[0] - - # remember the steps at each stage - self.warmup_steps = args.warmup_steps - self.hold_steps = args.hold_steps - self.decay_steps = args.decay_steps + self.peak_lr = cfg.lr[0] + self.init_lr = cfg.init_lr_scale * cfg.lr[0] + self.final_lr = cfg.final_lr_scale * cfg.lr[0] + + if cfg.phase_ratio is not None: + assert sum(cfg.phase_ratio) == 1, 'phase ratios must add up to 1' + self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0]) + self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1]) + self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2]) + else: + self.warmup_steps = cfg.warmup_steps + self.hold_steps = cfg.hold_steps + self.decay_steps = cfg.decay_steps + + assert self.warmup_steps + self.hold_steps + self.decay_steps > 0, "please specify steps or phase_ratio" self.warmup_rate = ( (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0 ) - self.decay_factor = -math.log(args.final_lr_scale) / args.decay_steps + self.decay_factor = -math.log(cfg.final_lr_scale) / self.decay_steps # initial learning rate self.lr = self.init_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR 
scheduler.""" - # fmt: off - parser.add_argument( - '--warmup-steps', - default=4000, - type=int, - metavar='N', - help='warmup the learning rate linearly for the first N updates' - ) - parser.add_argument( - '--hold-steps', - default=20000, - type=int, - metavar='N', - help='steps in hold stage.' - ) - parser.add_argument( - '--decay-steps', - default=60000, - type=int, - metavar='N', - help='steps in decay stages' - ) - parser.add_argument( - '--init-lr-scale', - default=0.01, - type=float, - help=""" - initial learning rate scale during warmup phase; default is 0.01""") - parser.add_argument( - '--final-lr-scale', - default=0.01, - type=float, - help="final learning rate scale; default to 0.01" - ) - # fmt: on - def _decide_stage(self, update_step): """ return stage, and the corresponding steps within the current stage From 6e280bff2aecf5e7fe92d46aa2507c6cae036fe3 Mon Sep 17 00:00:00 2001 From: Suyoun Kim Date: Thu, 19 Nov 2020 14:56:54 -0800 Subject: [PATCH 306/707] set lr for each epoch by using input dictionary parameter Summary: [Manual LR Scheduler] set lr for each epoch by using input dictionary parameter Differential Revision: D25047764 fbshipit-source-id: 4b8ccfa1b1f5db99d73fdb478caa2c6ea8d80a50 --- .../optim/lr_scheduler/manual_lr_scheduler.py | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 fairseq/optim/lr_scheduler/manual_lr_scheduler.py diff --git a/fairseq/optim/lr_scheduler/manual_lr_scheduler.py b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py new file mode 100644 index 0000000000..7e06ec55c8 --- /dev/null +++ b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . 
import LegacyFairseqLRScheduler, register_lr_scheduler +import logging +import ast + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@register_lr_scheduler("manual") +class ManualSchedule(LegacyFairseqLRScheduler): + """Decay the LR on a manual schedule.""" + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + + self.epoch2lr = self.parse_manuallr_args(args.epoch2lr) + self.update2lr = self.parse_manuallr_args(args.update2lr) + logger.info("@@@ ManualSchedule epoch2lr={}".format(self.epoch2lr)) + logger.info("@@@ ManualSchedule update2lr={}".format(self.update2lr)) + + if 1 in self.epoch2lr: + self.lr = self.epoch2lr[1] + elif 1 in self.update2lr: + self.lr = self.update2lr[1] + else: + self.lr = args.lr[0] + + def parse_manuallr_args(self, lr_args_str): + lr_dict = ast.literal_eval(lr_args_str) + if not isinstance(lr_dict, dict): + raise ValueError("epoch2lr/update2lr must be able to be evaluated to a dict") + + lr_args = {} + logger.info("@@@ after parsing input dictionary lr_dict = {}".format(lr_dict)) + for key, val in lr_dict.items(): + if "," in key: + for k in key.split(","): + lr_args[int(k)] = float(val) + elif "-" in key: + s = int(key.split("-")[0]) + e = int(key.split("-")[1]) + for k in range(s, e + 1, 1): + lr_args[k] = float(val) + else: + lr_args[int(key)] = float(val) + + return lr_args + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument( + "--epoch2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each epoch manually", + ) + parser.add_argument( + "--update2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each update manually", + ) + # fmt: on + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + 
manual_keys = [k for k in self.epoch2lr if k <= epoch] + if manual_keys: + manual_lr = self.epoch2lr[max(manual_keys)] + else: + logger.warning("@@@ epoch={} does not exist in manual lr input. epoch2lr={}".format(epoch, self.epoch2lr)) + manual_lr = self.optimizer.get_lr() + return manual_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + manual_keys = [k for k in self.update2lr if k <= num_updates] + if manual_keys: + manual_lr = self.update2lr[max(manual_keys)] + else: + logger.warning("update={} does not exist in manual lr input. update2lr={}".format(num_updates, self.update2lr)) + manual_lr = self.optimizer.get_lr() + + self.optimizer.set_lr(manual_lr) + return self.optimizer.get_lr() From 7171cdec5b15b983764dacbc618c572546a8a692 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 19 Nov 2020 15:02:18 -0800 Subject: [PATCH 307/707] fix interpolated fields not being added to argparse (#1450) Summary: we were skipping fields that are interpolated from being added to argparse, but now we add them if they have their own help. 
this affected --total-num-updates for polynomial decay Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1450 Reviewed By: huihuifan Differential Revision: D25098158 Pulled By: alexeib fbshipit-source-id: 105bc67cb5ddfc86475f3b50c5d1b5cc00330d85 --- fairseq/dataclass/utils.py | 6 +++++- fairseq/optim/lr_scheduler/polynomial_decay_schedule.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index a3a6c43281..15b87b9e4d 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -157,7 +157,11 @@ def get_kwargs_from_dc( if isinstance(kwargs["default"], str) and kwargs["default"].startswith( "${" ): - continue + if kwargs["help"] is None: + # this is a field with a name that will be added elsewhere + continue + else: + del kwargs["default"] if delete_default: del kwargs["default"] try: diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index be9c9aec1d..2f6ae25c88 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -29,7 +29,10 @@ class PolynomialDecayScheduleConfig(FairseqDataclass): default=1.0, metadata={"help": "decay exponent"}, ) - total_num_update: float = II("optimization.max_update") + total_num_update: float = field( + default=II("optimization.max_update"), + metadata={"help": "total number of updates over which to decay learning rate"}, + ) lr: List[float] = II("optimization.lr") From 40fbb3744304de0eaa164fc84dd736d9a202a427 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 20 Nov 2020 05:59:25 -0800 Subject: [PATCH 308/707] Migrate remaining LR schedulers (#1448) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1448 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25092150 Pulled By: myleott fbshipit-source-id: 
fd066a0eba388bb0c344082a8fa1132974d53d40 --- fairseq/dataclass/utils.py | 6 +- .../optim/lr_scheduler/cosine_lr_scheduler.py | 35 +++--- .../lr_scheduler/fairseq_lr_scheduler.py | 3 +- fairseq/optim/lr_scheduler/fixed_schedule.py | 65 ++++++----- .../inverse_square_root_schedule.py | 35 +++--- .../lr_scheduler/polynomial_decay_schedule.py | 12 +- .../lr_scheduler/reduce_lr_on_plateau.py | 110 +++++++++++------- .../lr_scheduler/tri_stage_lr_scheduler.py | 22 ++-- .../lr_scheduler/triangular_lr_scheduler.py | 61 +++++----- 9 files changed, 191 insertions(+), 158 deletions(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 15b87b9e4d..694d878308 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -227,7 +227,11 @@ def get_default(f): if isinstance(val, tuple): val = list(val) - if getattr(v.type, "__origin__", None) is List: + if ( + getattr(v.type, "__origin__", None) is List + # skip interpolation + and not (isinstance(val, str) and val.startswith("${")) + ): # if type is int but val is float, then we will crash later - try to convert here t_args = v.type.__args__ if len(t_args) == 1: diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 646ac66be9..ef8645cd58 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -8,14 +8,14 @@ from dataclasses import dataclass, field from typing import List -from fairseq.dataclass import FairseqDataclass -from omegaconf import II, DictConfig +from omegaconf import II -from . 
import FairseqLRScheduler, register_lr_scheduler +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler @dataclass -class CosineConfig(FairseqDataclass): +class CosineLRScheduleConfig(FairseqDataclass): warmup_updates: int = field( default=0, metadata={"help": "warmup the learning rate linearly for the first N updates"}, @@ -23,11 +23,11 @@ class CosineConfig(FairseqDataclass): warmup_init_lr: float = field( default=-1, metadata={ - "help": "initial learning rate during warmup phase; default is args.lr" + "help": "initial learning rate during warmup phase; default is cfg.lr" }, ) max_lr: float = field( - default=1.0, metadata={"help": "max learning rate, must be more than args.lr"} + default=1.0, metadata={"help": "max learning rate, must be more than cfg.lr"} ) t_mult: float = field( default=1.0, metadata={"help": "factor to grow the length of each period"} @@ -38,13 +38,12 @@ class CosineConfig(FairseqDataclass): lr_shrink: float = field( default=0.1, metadata={"help": "shrink factor for annealing"} ) - # TODO common var for parent class lr: List[float] = II("optimization.lr") max_update: int = II("optimization.max_update") -@register_lr_scheduler("cosine", dataclass=CosineConfig) -class CosineSchedule(FairseqLRScheduler): +@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig) +class CosineLRSchedule(FairseqLRScheduler): """Assign LR based on a cyclical schedule that follows the cosine function. See https://arxiv.org/pdf/1608.03983.pdf for details. @@ -55,7 +54,7 @@ class CosineSchedule(FairseqLRScheduler): During warmup:: - lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) lr = lrs[update_num] After warmup:: @@ -67,9 +66,7 @@ class CosineSchedule(FairseqLRScheduler): after every iteration. 
""" - def __init__( - self, cfg: DictConfig, fairseq_optimizer - ): + def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer): super().__init__(cfg, fairseq_optimizer) if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( @@ -78,11 +75,7 @@ def __init__( ) warmup_end_lr = cfg.max_lr - lr = ( - cfg.lr[0] - if isinstance(cfg.lr, Collection) - else cfg.lr - ) + lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr if cfg.warmup_init_lr < 0: cfg.warmup_init_lr = lr @@ -100,10 +93,8 @@ def __init__( self.period = cfg.max_update - cfg.warmup_updates if cfg.warmup_updates > 0: - # linearly warmup for the first args.warmup_updates - self.lr_step = ( - warmup_end_lr - cfg.warmup_init_lr - ) / cfg.warmup_updates + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates else: self.lr_step = 1 diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index d0ac115829..dd75dc5e30 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -6,8 +6,7 @@ from argparse import Namespace from fairseq.dataclass.utils import gen_parser_from_dataclass - -from .. import FairseqOptimizer +from fairseq.optim import FairseqOptimizer class FairseqLRScheduler(object): diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index e91ba86f8c..d0e7e14b7e 100644 --- a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -3,37 +3,44 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . 
import LegacyFairseqLRScheduler, register_lr_scheduler +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler("fixed") -class FixedSchedule(LegacyFairseqLRScheduler): - """Decay the LR on a fixed schedule.""" - def __init__(self, args, optimizer): - super().__init__(args, optimizer) +@dataclass +class FixedLRScheduleConfig(FairseqDataclass): + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + lr_shrink: float = field( + default=0.1, + metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"}, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + lr: List[float] = II("optimization.lr") + - # set defaults - args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 +@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig) +class FixedLRSchedule(FairseqLRScheduler): + """Decay the LR on a fixed schedule.""" - self.lr = args.lr[0] - if args.warmup_updates > 0: - self.warmup_factor = 1.0 / args.warmup_updates + def __init__(self, cfg: FixedLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates else: self.warmup_factor = 1 - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', - help='force annealing at specified epoch (epochs start at 1)') - parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='shrink factor for annealing, lr_new = (lr * lr_shrink)') - parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', - 
help='warmup the learning rate linearly for the first N updates') - # fmt: on - def state_dict(self): return {"lr": self.lr} @@ -42,14 +49,14 @@ def load_state_dict(self, state_dict): self.lr = state_dict["lr"] def get_next_lr(self, epoch): - lrs = self.args.lr - if self.args.force_anneal is None or epoch < self.args.force_anneal: + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: # use fixed LR schedule next_lr = lrs[min(epoch - 1, len(lrs) - 1)] else: # annneal based on lr_shrink - next_lr = lrs[-1] * self.args.lr_shrink ** ( - epoch + 1 - self.args.force_anneal + next_lr = lrs[-1] * self.cfg.lr_shrink ** ( + epoch + 1 - self.cfg.force_anneal ) return next_lr @@ -61,8 +68,8 @@ def step_begin_epoch(self, epoch): def step_update(self, num_updates): """Update the learning rate after each update.""" - if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates: - self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates) + if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates: + self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates) self.optimizer.set_lr(self.warmup_factor * self.lr) else: self.optimizer.set_lr(self.lr) diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index c42e090677..d9321577bb 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -7,14 +7,14 @@ from dataclasses import dataclass, field from typing import List -from fairseq.dataclass import FairseqDataclass -from omegaconf import II, DictConfig +from omegaconf import II -from . 
import FairseqLRScheduler, register_lr_scheduler +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler @dataclass -class InverseSquareRootScheduleConfig(FairseqDataclass): +class InverseSquareRootLRScheduleConfig(FairseqDataclass): warmup_updates: int = field( default=4000, metadata={"help": "warmup the learning rate linearly for the first N updates"}, @@ -22,14 +22,13 @@ class InverseSquareRootScheduleConfig(FairseqDataclass): warmup_init_lr: float = field( default=-1, metadata={ - "help": "initial learning rate during warmup phase; default is args.lr" + "help": "initial learning rate during warmup phase; default is cfg.lr" }, ) - # TODO common vars at parent class lr: List[float] = II("optimization.lr") -@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig) +@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig) class InverseSquareRootSchedule(FairseqLRScheduler): """Decay the LR based on the inverse square root of the update number. @@ -40,36 +39,28 @@ class InverseSquareRootSchedule(FairseqLRScheduler): During warmup:: - lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) lr = lrs[update_num] After warmup:: - decay_factor = args.lr * sqrt(args.warmup_updates) + decay_factor = cfg.lr * sqrt(cfg.warmup_updates) lr = decay_factor / sqrt(update_num) """ - def __init__(self, cfg: DictConfig, optimizer): + def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer): super().__init__(cfg, optimizer) if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with inverse_sqrt." " Consider --lr-scheduler=fixed instead." 
) - warmup_end_lr = ( - cfg.lr[0] - if isinstance(cfg.lr, Collection) - else cfg.lr - ) + warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr if cfg.warmup_init_lr < 0: - cfg.warmup_init_lr = ( - 0 if cfg.warmup_updates > 0 else warmup_end_lr - ) + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr - # linearly warmup for the first args.warmup_updates - self.lr_step = ( - warmup_end_lr - cfg.warmup_init_lr - ) / cfg.warmup_updates + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates # then, decay prop. to the inverse square root of the update number self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5 diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index 2f6ae25c88..b8109a7c1e 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -8,11 +8,11 @@ from omegaconf import II from fairseq.dataclass import FairseqDataclass -from . 
import FairseqLRScheduler, register_lr_scheduler +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler @dataclass -class PolynomialDecayScheduleConfig(FairseqDataclass): +class PolynomialDecayLRScheduleConfig(FairseqDataclass): warmup_updates: int = field( default=0, metadata={"help": "warmup the learning rate linearly for the first N updates"}, @@ -36,13 +36,11 @@ class PolynomialDecayScheduleConfig(FairseqDataclass): lr: List[float] = II("optimization.lr") -@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayScheduleConfig) -class PolynomialDecaySchedule(FairseqLRScheduler): +@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig) +class PolynomialDecayLRSchedule(FairseqLRScheduler): """Decay the LR on a fixed schedule.""" - cfg: PolynomialDecayScheduleConfig - - def __init__(self, cfg: PolynomialDecayScheduleConfig, optimizer): + def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer): super().__init__(cfg, optimizer) assert cfg.total_num_update > 0 diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 82bb36efe9..6e29ba79b6 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -3,13 +3,59 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +from typing import List + import torch.optim.lr_scheduler +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + -from . 
import LegacyFairseqLRScheduler, register_lr_scheduler +@dataclass +class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass): + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + lr_threshold: float = field( + default=1e-4, + metadata={ + "help": ( + "threshold for measuring the new optimum, to only focus on " + "significant changes" + ) + }, + ) + lr_patience: int = field( + default=0, + metadata={ + "help": ( + "number of epochs with no improvement after which learning rate will " + "be reduced" + ) + }, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = II("optimization.lr") + maximize_best_checkpoint_metric: bool = II( + "checkpoint.maximize_best_checkpoint_metric" + ) -@register_lr_scheduler("reduce_lr_on_plateau") -class ReduceLROnPlateau(LegacyFairseqLRScheduler): +@register_lr_scheduler( + "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig +) +class ReduceLROnPlateauLRSchedule(FairseqLRScheduler): """ Decay the LR by a factor every time the validation loss plateaus. Also comes with optional warmup phase, where we linearly increase @@ -21,61 +67,43 @@ class ReduceLROnPlateau(LegacyFairseqLRScheduler): During warmup:: lrs = torch.linspace( - args.warmup_init_lr, args.lr, args.warmup_updates + cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates ) lr = lrs[update_num] """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." " Consider --lr-scheduler=fixed instead." 
) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, - patience=args.lr_patience, - factor=args.lr_shrink, - mode="max" if args.maximize_best_checkpoint_metric else "min", - threshold=args.lr_threshold, + patience=cfg.lr_patience, + factor=cfg.lr_shrink, + mode="max" if cfg.maximize_best_checkpoint_metric else "min", + threshold=cfg.lr_threshold, ) - warmup_end_lr = args.lr[0] - # if no warm up, sets initial lr to be args.lr[0] - if args.warmup_init_lr < 0: - args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + warmup_end_lr = cfg.lr[0] + # if no warm up, sets initial lr to be cfg.lr[0] + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr - # linearly warmup for the first args.warmup_updates - if args.warmup_updates > 0: - self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + # linearly warmup for the first cfg.warmup_updates + if cfg.warmup_updates > 0: + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates # this flag is either set from arg when no warm up, or set by # step_update() when warmup finishes - self.warmup_end = True if args.warmup_updates <= 0 else False + self.warmup_end = True if cfg.warmup_updates <= 0 else False # initial learning rate # this self.lr is used only during init and/or warm up period - self.lr = args.warmup_init_lr + self.lr = cfg.warmup_init_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='shrink factor for annealing, lr_new = (lr * lr_shrink)') - parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT', - help='threshold for measuring the new optimum, ' - 'to only focus on significant changes') - parser.add_argument('--lr-patience', default=0, type=int, - help='number of epochs with no 
improvement after which ' - 'learning rate will be reduced') - parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', - help='warmup the learning rate linearly for the first N updates') - parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', - help='initial learning rate during warmup phase; default is args.lr') - # fmt: on - def state_dict(self): """Return the LR scheduler state dict.""" return { @@ -104,9 +132,9 @@ def step_update(self, num_updates): """ Update the learning rate after each update.""" # if there is warmup - if self.args.warmup_updates > 0: - if num_updates <= self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + if self.cfg.warmup_updates > 0: + if num_updates <= self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step self.optimizer.set_lr(self.lr) else: if self.warmup_end is False: diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index f0576d2d5f..403de77c80 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -4,13 +4,12 @@ # LICENSE file in the root directory of this source tree. import math - from dataclasses import dataclass, field from typing import Optional, List, Tuple from omegaconf import II from fairseq.dataclass import FairseqDataclass -from . import FairseqLRScheduler, register_lr_scheduler +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler @dataclass @@ -29,8 +28,12 @@ class TriStageLRScheduleConfig(FairseqDataclass): ) phase_ratio: Optional[Tuple[float, float, float]] = field( default=None, - metadata={"help": "if set, automatically sets warmup/hold/decay steps to the ratio specified here " - "from max_updates. 
the ratios must add up to 1.0"}, + metadata={ + "help": ( + "if set, automatically sets warmup/hold/decay steps to the ratio " + "specified here from max_updates. the ratios must add up to 1.0" + ) + }, ) init_lr_scale: float = field( default=0.01, @@ -42,7 +45,7 @@ class TriStageLRScheduleConfig(FairseqDataclass): ) max_update: float = II("optimization.max_update") lr: List[float] = II("optimization.lr") - + @register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig) class TriStageLRScheduleConfig(FairseqLRScheduler): @@ -90,6 +93,7 @@ def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): "Cannot use a fixed learning rate schedule with tri-stage lr." " Consider --lr-scheduler=fixed instead." ) + assert cfg.max_update > 0 # calculate LR at each point self.peak_lr = cfg.lr[0] @@ -97,7 +101,7 @@ def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): self.final_lr = cfg.final_lr_scale * cfg.lr[0] if cfg.phase_ratio is not None: - assert sum(cfg.phase_ratio) == 1, 'phase ratios must add up to 1' + assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1" self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0]) self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1]) self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2]) @@ -105,8 +109,10 @@ def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): self.warmup_steps = cfg.warmup_steps self.hold_steps = cfg.hold_steps self.decay_steps = cfg.decay_steps - - assert self.warmup_steps + self.hold_steps + self.decay_steps > 0, "please specify steps or phase_ratio" + + assert ( + self.warmup_steps + self.hold_steps + self.decay_steps > 0 + ), "please specify steps or phase_ratio" self.warmup_rate = ( (self.peak_lr - self.init_lr) / self.warmup_steps diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py index 0f3193f2b8..bfe2a0d381 100644 --- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +++ 
b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -4,52 +4,61 @@ # LICENSE file in the root directory of this source tree. import math +from dataclasses import dataclass, field +from typing import List -from . import LegacyFairseqLRScheduler, register_lr_scheduler +from omegaconf import II +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler("triangular") -class TriangularSchedule(LegacyFairseqLRScheduler): + +@dataclass +class TriangularLRScheduleConfig(FairseqDataclass): + max_lr: float = field( + default="???", metadata={"help": "max learning rate, must be more than cfg.lr"} + ) + lr_period_updates: float = field( + default=5000, + metadata={"help": "initial number of updates per period (cycle length)"}, + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + shrink_min: bool = field( + default=False, metadata={"help": "if set, also shrinks min lr"} + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig) +class TriangularLRSchedule(FairseqLRScheduler): """Assign LR based on a triangular cyclical schedule. See https://arxiv.org/pdf/1506.01186.pdf for details. """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: TriangularLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with triangular." " Consider --lr-scheduler=fixed instead." 
) - lr = args.lr[0] + lr = cfg.lr[0] - assert args.max_lr > lr, "max_lr must be more than lr" + assert cfg.max_lr > lr, "max_lr must be more than lr" self.min_lr = lr - self.max_lr = args.max_lr - self.stepsize = args.lr_period_updates // 2 - self.lr_shrink = args.lr_shrink - self.shrink_min = args.shrink_min + self.max_lr = cfg.max_lr + self.stepsize = cfg.lr_period_updates // 2 + self.lr_shrink = cfg.lr_shrink + self.shrink_min = cfg.shrink_min # initial learning rate self.lr = self.min_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--max-lr', required=True, type=float, metavar='LR', - help='max learning rate, must be more than args.lr') - parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', - help='initial number of updates per period (cycle length)') - parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='shrink factor for annealing') - parser.add_argument('--shrink-min', action='store_true', - help='if set, also shrinks min lr') - # fmt: on - def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" super().step(epoch, val_loss) From 3b77a6160097408e01883c69e6f8fed017266311 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 20 Nov 2020 05:59:25 -0800 Subject: [PATCH 309/707] Add fairseq-hydra-train and update docs (#1449) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1449 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25094525 Pulled By: myleott fbshipit-source-id: 430387d11196d3292933bb168cf09ea16ebc0d3b --- docs/hydra_integration.md | 226 ++++++++++++++++++------------- examples/wav2vec/README.md | 21 ++- fairseq/config/config.yaml | 6 +- fairseq/dataclass/configs.py | 6 + fairseq/modules/cross_entropy.py | 6 +- fairseq_cli/hydra_train.py | 34 ++++- setup.py | 17 ++- 7 files 
changed, 203 insertions(+), 113 deletions(-) diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index f924de961b..8e4082cb24 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -1,57 +1,70 @@ ## Hydra -[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of -research and other complex applications. The key feature is the ability to dynamically create a hierarchical -configuration by composition and override it through config files and the command line. The name Hydra comes from its -ability to run multiple similar jobs - much like a Hydra with multiple heads. +[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python +framework that simplifies the development of research and other complex +applications. The key feature is the ability to dynamically create a +hierarchical configuration by composition and override it through config files +and the command line. The name Hydra comes from its ability to run multiple +similar jobs - much like a Hydra with multiple heads. ## Motivation -Until recently, all components in fairseq were configured through a shared "args" namespace that was created at -application startup. Components declared their own "add_args" method to update the argparse parser, hoping that -the names would not clash with arguments from other components. While this model works for smaller applications, -as fairseq grew and became integrated into other applications, this became problematic. -In order to determine how to configure each component, one needed to a) examine what args were added by this component, and -b) read the code to figure out what shared arguments it is using that were added in other places. Reproducing -models involved sharing commands that often contained dozens of command line switches. 
- -The model described above is still supported by fairseq for backward compatibility, but will be deprecated some time -in the future. - -New components in fairseq should now create a dataclass that encapsulates all parameters required to configure this -component. The dataclass is registered along with the component, and fairseq takes care of constructing and -providing this configuration object to the component's constructor. Note that sharing parameters can optionally -still work, but one has to explicitly point to the "source of truth" (see inheritance example below). -These changes make components in fairseq -more independent and re-usable by other applications: all that is needed to create a component is to initialize its -dataclass and overwrite some of the defaults. - -While configuring fairseq through command line (using either the legacy argparse based or the new Hydra based entry points) is still -fully supported, you can now take advantage of configuring fairseq completely or piece-by-piece through -hierarchical YAML configuration files. These files can also be shipped as examples that others can use to run -an identically configured job. - -Additionally, Hydra has a rich and growing -[library of plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that provide functionality such as -hyperparameter sweeping (including using bayesian optimization through the [Ax](https://github.com/facebook/Ax) library), -job launching across various platforms, and more. +Until recently, all components in fairseq were configured through a shared +`args` namespace that was created at application startup. Components declared +their own `add_args` method to update the argparse parser, hoping that the names +would not clash with arguments from other components. While this model works for +smaller applications, as fairseq grew and became integrated into other +applications, this became problematic. 
In order to determine how to configure +each component, one needed to a) examine what args were added by this component, +and b) read the code to figure out what shared arguments it is using that were +added in other places. Reproducing models involved sharing commands that often +contained dozens of command line switches. + +The model described above is still supported by fairseq for backward +compatibility, but will be deprecated some time in the future. + +New components in fairseq should now create a dataclass that encapsulates all +parameters required to configure this component. The dataclass is registered +along with the component, and fairseq takes care of constructing and providing +this configuration object to the component's constructor. Note that sharing +parameters can optionally still work, but one has to explicitly point to the +"source of truth" (see inheritance example below). These changes make components +in fairseq more independent and re-usable by other applications: all that is +needed to create a component is to initialize its dataclass and overwrite some +of the defaults. + +While configuring fairseq through command line (using either the legacy argparse +based or the new Hydra based entry points) is still fully supported, you can now +take advantage of configuring fairseq completely or piece-by-piece through +hierarchical YAML configuration files. These files can also be shipped as +examples that others can use to run an identically configured job. + +Additionally, Hydra has a rich and growing [library of +plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that +provide functionality such as hyperparameter sweeping (including using bayesian +optimization through the [Ax](https://github.com/facebook/Ax) library), job +launching across various platforms, and more. 
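The claim above, that creating a component only requires initializing its dataclass and overwriting some defaults, can be sketched with the standard library alone. This is a minimal illustration under stated assumptions: `ToySchedulerConfig` and `ToyScheduler` are hypothetical names that do not exist in fairseq, and the real classes inherit from `FairseqDataclass` and are registered through the `register_*()` functions.

```python
from dataclasses import dataclass, field, replace


@dataclass
class ToySchedulerConfig:
    # Hypothetical config, not a fairseq class.
    warmup_updates: int = field(
        default=0, metadata={"help": "linear warmup for the first N updates"}
    )
    lr: float = field(default=0.25, metadata={"help": "peak learning rate"})


class ToyScheduler:
    """Hypothetical component that is configured only through its dataclass."""

    def __init__(self, cfg: ToySchedulerConfig):
        self.cfg = cfg

    def lr_at(self, num_updates: int) -> float:
        # Linear warmup up to cfg.lr, then a constant learning rate.
        if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates:
            return self.cfg.lr * (num_updates + 1) / self.cfg.warmup_updates
        return self.cfg.lr


default_cfg = ToySchedulerConfig()
# Overwrite one default and keep the rest.
custom_cfg = replace(default_cfg, warmup_updates=4)

scheduler = ToyScheduler(custom_cfg)
warmup_lrs = [scheduler.lr_at(step) for step in range(6)]
```

The component never reads a shared namespace here; everything it needs arrives through `cfg`, which is the property the paragraph above argues makes components reusable outside fairseq.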
## Creating or migrating components -In general, each new (or updated) component should provide a companion [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are typically located in the same -file as the component and are passed as arguments to the register_*() functions. Top-level configs that should be -present in every fairseq application are placed in the [global](fairseq/dataclass/configs.py) config file and added -to the FairseqConfig object. - -Each dataclass is a plain-old-data object, similar to a NamedTuple. These classes are decorated with a @dataclass -decorator, and typically inherit from `FairseqDataclass` (which adds some functionality for backward compatibility). -Each field must have a type, and generally has metadata (such as a help string) and a default value. Only primitive types or other config objects are allowed as +In general, each new (or updated) component should provide a companion +[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are +typically located in the same file as the component and are passed as arguments +to the `register_*()` functions. Top-level configs that should be present in +every fairseq application are placed in the +[global](fairseq/dataclass/configs.py) config file and added to the +`FairseqConfig` object. + +Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These +classes are decorated with a `@dataclass` decorator, and typically inherit from +`FairseqDataclass` (which adds some functionality for backward compatibility). +Each field must have a type, and generally has metadata (such as a help string) +and a default value. Only primitive types or other config objects are allowed as data types for each field. 
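The field anatomy described above (a type, a default, and metadata such as a help string) is what makes it possible to derive a command-line parser from a dataclass. The sketch below shows the idea using only the standard library. It is not fairseq's implementation: `ToyInteractiveConfig` and `parser_from_dataclass` are hypothetical names, and the sketch assumes every field has a primitive type that `argparse` can call directly.

```python
import argparse
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class ToyInteractiveConfig:
    # Hypothetical config, not a fairseq class.
    buffer_size: int = field(
        default=0, metadata={"help": "read this many sentences into a buffer"}
    )
    input: str = field(default="-", metadata={"help": "file to read from"})


def parser_from_dataclass(config_cls) -> argparse.ArgumentParser:
    """Build an argparse parser with one --flag per dataclass field."""
    parser = argparse.ArgumentParser()
    for f in fields(config_cls):
        parser.add_argument(
            "--" + f.name.replace("_", "-"),
            type=f.type,  # assumes a primitive type such as int or str
            default=None if f.default is MISSING else f.default,
            help=f.metadata.get("help"),
        )
    return parser


parser = parser_from_dataclass(ToyInteractiveConfig)
args = parser.parse_args(["--buffer-size", "8"])
# argparse stores "--buffer-size" under "buffer_size", matching the field name.
cfg = ToyInteractiveConfig(**vars(args))
```

Because the type, default, and help text live in one place, the same dataclass can back both a legacy argparse entry point and a structured config, which is presumably why only primitive types or other config objects are allowed as field types.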
- Example: - +#### Example: -``` python +```python from dataclasses import dataclass, field from fairseq.dataclass import FairseqDataclass @@ -71,11 +84,12 @@ class InteractiveConfig(FairseqDataclass): ### Inherting values -Some components require sharing a value. For example, a learning rate scheduler and an optimizer may both need to -know the initial learning rate value. One can declare a field that, by default, will -inherit its value from another config node in the same hierarchy: +Some components require sharing a value. For example, a learning rate scheduler +and an optimizer may both need to know the initial learning rate value. One can +declare a field that, by default, will inherit its value from another config +node in the same hierarchy: -``` python +```python @dataclass FairseqAdamConfig(FairseqDataclass): ... @@ -83,18 +97,21 @@ FairseqAdamConfig(FairseqDataclass): ... ``` -`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"` , which is the value one can use in a YAML config file or through -command line to achieve the same effect. Note that this assumes that there is an "optimization" config object -in the root config and it has a field called "lr". +`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is +the value one can use in a YAML config file or through command line to achieve +the same effect. Note that this assumes that there is an "optimization" config +object in the root config and it has a field called "lr". ### Tasks and Models -Creating Tasks and Models works same as before, except that legacy implementations now inherit from Legacy* base classes, -while new components inherit from FairseqTask and FairseqModel and provide a dataclass to the register_*() functions. 
+Creating Tasks and Models works same as before, except that legacy +implementations now inherit from `LegacyFairseq*` base classes, while new +components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass +to the `register_*()` functions. -Task example: +#### Task example: -``` python +```python @dataclass class LanguageModelingConfig(FairseqDataclass): data: Optional[str] = field( @@ -110,9 +127,9 @@ class LanguageModelingTask(LegacyFairseqTask): ... ``` -Model example: +#### Model example: -``` python +```python @dataclass class TransformerLanguageModelConfig(FairseqDataclass): activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( @@ -131,9 +148,10 @@ class TransformerLanguageModel(FairseqLanguageModel): ### Other components -Other components work as before, but they now take their configuration dataclass as the only constructor argument: +Other components work as before, but they now take their configuration dataclass +as the only constructor argument: -``` python +```python @dataclass class MosesTokenizerConfig(FairseqDataclass): source_lang: str = field(default="en", metadata={"help": "source language"}) @@ -145,50 +163,61 @@ class MosesTokenizer(object): ... ``` -Note that if you are adding a new registry for a new set of components, you need to add it to the FairseqConfig object in -fairseq/dataclass/configs.py: +Note that if you are adding a new registry for a new set of components, you need +to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`: -``` python +```python @dataclass class FairseqConfig(object): ... my_new_registry: Any = None ``` -## Training with hydra_train.py +## Training with `fairseq-hydra-train` -To fully take advantage of configuration flexibility offered by Hydra, you may want to train new models using the -hydra_train.py entry point located in the fairseq_cli directory. 
Legacy CLI tools such as train.py, -will remain supported for the foreseeable future but will be deprecated eventually. +To fully take advantage of configuration flexibility offered by Hydra, you may +want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI +tools such as `fairseq-train` will remain supported for the foreseeable future +but will be deprecated eventually. -On startup, Hydra will create a configuration object that contains a hierarchy of all the necessary dataclasses -populated with their default values in the code. The default values are overwritten by values found in YAML files in -fairseq/config directory (which currently just set default task, optimizer, etc) and then further overwritten by values -provided through command line arguments. +On startup, Hydra will create a configuration object that contains a hierarchy +of all the necessary dataclasses populated with their default values in the +code. The default values are overwritten by values found in YAML files in +`fairseq/config` directory (which currently sets minimal defaults) and then +further overwritten by values provided through command line arguments. Some of the most common use cases are shown below: -### 1. Overwrite default values through command line: +### 1. 
Override default values through command line: ```shell script -python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 task.data=data-bin \ -model=transformer_lm/transformer_lm_gpt task=language_modeling optimization.max_update=5000 - +$ fairseq-hydra-train \ + distributed_training.distributed_world_size=1 \ + dataset.batch_size=2 \ + task.data=data-bin \ + model=transformer_lm/transformer_lm_gpt \ + task=language_modeling \ + optimization.max_update=5000 ``` -Note that along with explicitly providing values for parameters such as dataset.batch_size, this also tells Hydra to overlay configuration found in `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` -over the default values in the dataclass. If you want to train a model without specifying a particular architecture -you can simply specify model=transformer_lm. This only works for migrated tasks and models. +Note that along with explicitly providing values for parameters such as +`dataset.batch_size`, this also tells Hydra to overlay configuration found in +`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default +values in the dataclass. If you want to train a model without specifying a +particular architecture you can simply specify `model=transformer_lm`. This only +works for migrated tasks and models. ### 2. 
Replace bundled configs with an external config: ```shell script -python fairseq_cli/hydra_train.py --config-path /path/to/external/configs --config-name wiki103 +$ fairseq-hydra-train \ + --config-path /path/to/external/configs \ + --config-name wiki103 ``` -where /path/to/external/configs/wiki103.yaml contains: +where `/path/to/external/configs/wiki103.yaml` contains: -``` yaml +```yaml # @package _group_ model: @@ -211,24 +240,38 @@ lr_scheduler: _name: cosine ``` -Note that here bundled configs from `fairseq/config` directory are not used, however the defaults from each dataclass will still be used (unless overwritten by your external config). +Note that here bundled configs from `fairseq/config` directory are not used, +however the defaults from each dataclass will still be used (unless overwritten +by your external config). -Additionally you can choose to break up your configs by creating a directory structure in the same location as your main config file, with the names of the top-level fields -(such as "model", "dataset", etc), and placing config files with meaningful names that would populate that specific section of your -top-level config file (for example, you might have model/small_transformer_lm.yaml, model/big_transformer_lm.yaml, etc). You can then specify the correct configuration via command line, defaults in the main config, or even launch all of them as a sweep (see Hydra documentation on how to do this). +Additionally you can choose to break up your configs by creating a directory +structure in the same location as your main config file, with the names of the +top-level fields (such as "model", "dataset", etc), and placing config files +with meaningful names that would populate that specific section of your +top-level config file (for example, you might have +`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). 
You +can then specify the correct configuration via command line, defaults in the +main config, or even launch all of them as a sweep (see Hydra documentation on +how to do this). ### 3. Add an external config directory to Hydra search path: -This allows combining default configuration (including using any bundled config files), while specifying your own config files for some parts of the configuration. +This allows combining default configuration (including using any bundled config +files), while specifying your own config files for some parts of the +configuration. ```shell script -python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 \ -task.data=/path/to/data/ model=transformer_lm/2_layers task=language_modeling optimization.max_update=5000 \ ---config-dir /path/to/external/configs - +$ fairseq-hydra-train \ + distributed_training.distributed_world_size=1 \ + dataset.batch_size=2 \ + task.data=/path/to/data/ \ + model=transformer_lm/2_layers \ + task=language_modeling \ + optimization.max_update=5000 \ + --config-dir /path/to/external/configs ``` -where /path/to/external/configs has the following structure: +where `/path/to/external/configs` has the following structure: ``` . +-- model @@ -236,5 +279,6 @@ where /path/to/external/configs has the following structure: | | +-- 2_layers.yaml ``` -and 2_layers.yaml contains a copy of transformer_lm_gpt.yaml but with decoder_layers set to 2. You can add -other configs to configure other components as well. +and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with +`decoder_layers` set to 2. You can add other configs to configure other +components as well. 
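Note on the README text added above: it describes three configuration layers (dataclass defaults, YAML files under `fairseq/config`, and command-line values), each overriding the one before it. A minimal sketch of that precedence follows, using plain dictionaries rather than Hydra or OmegaConf. The helpers `merge` and `parse_overrides` and all sample values are hypothetical, introduced only for illustration; they are not part of fairseq or Hydra.

```python
from copy import deepcopy


def merge(base, override):
    """Recursively overlay `override` on top of `base`, returning a new dict."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def parse_overrides(args):
    """Turn ["dataset.batch_size=2", ...] into a nested dict (values stay strings)."""
    parsed = {}
    for arg in args:
        dotted_key, value = arg.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = parsed
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return parsed


# Hypothetical stand-ins for the three layers described in the README.
dataclass_defaults = {
    "dataset": {"batch_size": None, "num_workers": 1},
    "optimization": {"max_update": 0},
}
yaml_config = {"dataset": {"num_workers": 4}}
cli_overrides = parse_overrides(
    ["dataset.batch_size=2", "optimization.max_update=5000"]
)

# Later layers win: defaults < YAML < command line.
cfg = merge(merge(dataclass_defaults, yaml_config), cli_overrides)
```

This sketch models only the override order. It does not attempt Hydra's type conversion (values parsed from the command line stay strings here), nor config-group selection such as `model=transformer_lm/transformer_lm_gpt`, which picks a YAML file rather than setting a value.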
diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 442a92553a..fdbf844ec7 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -56,8 +56,10 @@ This configuration was used for the base model trained on the Librispeech datase Note that the input is expected to be single channel, sampled at 16 kHz ```shell script -$ python fairseq_cli/hydra_train.py task.data=/path/to/data \ ---config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining --config-name wav2vec2_base_librispeech +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_base_librispeech ``` Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before --config-path) @@ -68,8 +70,10 @@ Note: you can simulate 64 GPUs by using k GPUs and adding command line parameter This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper ```shell script -$ python fairseq_cli/hydra_train.py task.data=/path/to/data \ ---config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining --config-name wav2vec2_large_librivox +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox ``` Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before --config-path) @@ -88,9 +92,12 @@ $ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $sp Fine-tuning on 100h of Librispeech with letter targets: ```shell script -python fairseq_cli/hydra_train.py distributed_training.distributed_port=$PORT task.data=/path/to/data \ -model.w2v_path=/path/to/model.pt --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \ ---config-name base_100h +$ fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data \ + 
model.w2v_path=/path/to/model.pt \ + --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-name base_100h ``` There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. diff --git a/fairseq/config/config.yaml b/fairseq/config/config.yaml index 039609aece..9621baa5e9 100644 --- a/fairseq/config/config.yaml +++ b/fairseq/config/config.yaml @@ -1,10 +1,10 @@ # @package _group_ defaults: - - task: language_modeling + - task: null - model: null - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: cosine + - optimizer: null + - lr_scheduler: fixed - bpe: null - tokenizer: null - scoring: null diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 28dc8905c7..36d88d83f7 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -173,6 +173,12 @@ class CommonConfig(FairseqDataclass): profile: bool = field( default=False, metadata={"help": "enable autograd profiler emit_nvtx"} ) + reset_logging: bool = field( + default=True, + metadata={ + "help": "when using Hydra, reset the logging at the beginning of training" + }, + ) @dataclass diff --git a/fairseq/modules/cross_entropy.py b/fairseq/modules/cross_entropy.py index 0d2beb44bb..6f33c24cb5 100644 --- a/fairseq/modules/cross_entropy.py +++ b/fairseq/modules/cross_entropy.py @@ -26,12 +26,14 @@ def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"): import xentropy_cuda from apex.contrib import xentropy - logger.info("using fused cross entropy") - def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): if logits.device == torch.device("cpu"): return _cross_entropy_pytorch(logits, target, ignore_index, reduction) else: + if not getattr(cross_entropy, "_has_logged_once", False): + logger.info("using fused cross entropy") + cross_entropy._has_logged_once = True + half_to_float = logits.dtype == torch.half losses = xentropy.SoftmaxCrossEntropyLoss.apply( 
logits, diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index ffd3c5cd07..b092ce14ee 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -4,29 +4,32 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import hydra -from omegaconf import OmegaConf +import logging import os +import sys from fairseq.dataclass.initialize import hydra_init from fairseq_cli.train import main as pre_main from fairseq import distributed_utils from fairseq.dataclass.configs import FairseqConfig -import logging +import hydra import torch +from omegaconf import OmegaConf -logger = logging.getLogger(__name__) +logger = logging.getLogger("fairseq_cli.hydra_train") @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") def hydra_main(cfg: FairseqConfig) -> None: - cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) OmegaConf.set_struct(cfg, True) + if cfg.common.reset_logging: + reset_logging() # Hydra hijacks logging, fix that + if cfg.common.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): @@ -35,7 +38,22 @@ def hydra_main(cfg: FairseqConfig) -> None: distributed_utils.call_main(cfg, pre_main) -if __name__ == "__main__": +def reset_logging(): + root = logging.getLogger() + for handler in root.handlers: + root.removeHandler(handler) + root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + root.addHandler(handler) + + +def cli_main(): try: from hydra._internal.utils import get_args @@ -46,3 +64,7 @@ def hydra_main(cfg: FairseqConfig) -> None: hydra_init(cfg_name) hydra_main() + + +if __name__ == "__main__": + cli_main() diff --git a/setup.py b/setup.py index 2aae720d7e..6bc450a7fa 
100644 --- a/setup.py +++ b/setup.py @@ -22,14 +22,18 @@ def write_version_py(): # append latest commit hash to version string try: - sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + sha = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]) + .decode("ascii") + .strip() + ) version += "+" + sha[:7] except Exception: pass # write version info to fairseq/version.py with open(os.path.join("fairseq", "version.py"), "w") as f: - f.write("__version__ = \"{}\"\n".format(version)) + f.write('__version__ = "{}"\n'.format(version)) return version @@ -194,7 +198,8 @@ def do_setup(package_data): "tests", "tests.*", ] - ) + extra_packages, + ) + + extra_packages, package_data=package_data, ext_modules=extensions, test_suite="tests", @@ -202,6 +207,7 @@ def do_setup(package_data): "console_scripts": [ "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", "fairseq-generate = fairseq_cli.generate:cli_main", + "fairseq-hydra-train = fairseq_cli.hydra_train:cli_main", "fairseq-interactive = fairseq_cli.interactive:cli_main", "fairseq-preprocess = fairseq_cli.preprocess:cli_main", "fairseq-score = fairseq_cli.score:cli_main", @@ -230,8 +236,11 @@ def get_files(path, relative_to="fairseq"): fairseq_examples = os.path.join("fairseq", "examples") if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples): os.symlink(os.path.join("..", "examples"), fairseq_examples) + package_data = { - "fairseq": get_files("fairseq/examples"), + "fairseq": ( + get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config")) + ) } do_setup(package_data) finally: From 94f59bb67bf48d2913dc223fc20b8e94c7ed1bab Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 20 Nov 2020 12:40:49 -0800 Subject: [PATCH 310/707] Remove unused train_masked_language_model helper (#1452) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1452 Test Plan: Imported from OSS Reviewed By: lematt1991 Differential Revision: D25108462 Pulled 
By: myleott fbshipit-source-id: 3c17a9937a4c3edb69f64130dfd866c5f42a4aaf --- tests/test_binaries.py | 66 ------------------------------------------ 1 file changed, 66 deletions(-) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 6dd95cb4a5..a53d84118b 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1673,71 +1673,5 @@ def eval_lm_main(data_dir, extra_flags=None): eval_lm.main(eval_lm_args) -def train_masked_language_model(data_dir, arch, extra_args=()): - train_parser = options.get_training_parser() - # TODO: langs should be in and out right? - train_args = options.parse_args_and_arch( - train_parser, - [ - "--task", - "cross_lingual_lm", - data_dir, - "--arch", - arch, - # Optimizer args - "--optimizer", - "adam", - "--lr-scheduler", - "reduce_lr_on_plateau", - "--lr-shrink", - "0.5", - "--lr", - "0.0001", - "--min-lr", - "1e-09", - # dropout, attention args - "--dropout", - "0.1", - "--attention-dropout", - "0.1", - # MLM args - "--criterion", - "masked_lm_loss", - "--masked-lm-only", - "--monolingual-langs", - "in,out", - "--num-segment", - "5", - # Transformer args: use a small transformer model for fast training - "--encoder-layers", - "1", - "--encoder-embed-dim", - "32", - "--encoder-attention-heads", - "1", - "--encoder-ffn-embed-dim", - "32", - # Other training args - "--max-tokens", - "500", - "--tokens-per-sample", - "500", - "--save-dir", - data_dir, - "--max-epoch", - "1", - "--no-progress-bar", - "--distributed-world-size", - "1", - "--dataset-impl", - "raw", - "--num-workers", - "0", - ] - + list(extra_args), - ) - train.main(train_args) - - if __name__ == "__main__": unittest.main() From fa113ff1dee60bf493b6e61820303ab9d72cabcb Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 20 Nov 2020 12:40:49 -0800 Subject: [PATCH 311/707] Add test for activation checkpointing (#1453) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1453 Test Plan: Imported from OSS Reviewed By: sshleifer 
Differential Revision: D25108463 Pulled By: myleott fbshipit-source-id: 3cebce9be7fe503401eabba3f483c26847e7a3c0 --- fairseq/modules/checkpoint_activations.py | 5 +++-- tests/test_binaries.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index 1f99c24ca1..e0e5679c5a 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Tuple, Union import torch +import torch.utils.checkpoint as checkpoint from fairseq import utils @@ -133,7 +134,7 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation - torch.utils.checkpoint.check_backward_validity(args) + checkpoint.check_backward_validity(args) ctx.run_function = run_function ctx.kwarg_keys = kwarg_keys @@ -165,7 +166,7 @@ def backward(ctx, *args): ) tensor_inputs = ctx.saved_tensors - tensor_inputs = torch.utils.checkpoint.detach_variable(tensor_inputs) + tensor_inputs = checkpoint.detach_variable(tensor_inputs) inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) # Store the current states. 
diff --git a/tests/test_binaries.py b/tests/test_binaries.py index a53d84118b..8235702383 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -249,6 +249,29 @@ def test_transformer(self): ) generate_main(data_dir) + def test_transformer_with_activation_checkpointing(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--checkpoint-activations", + ], + run_validation=True, + ) + generate_main(data_dir) + def test_multilingual_transformer(self): # test with all combinations of encoder/decoder lang tokens encoder_langtok_flags = [ From d464af2feb4f5f7149e10241cf1d064071f404e1 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 20 Nov 2020 12:40:49 -0800 Subject: [PATCH 312/707] Fix NAT code (#1454) Summary: D23752010 (https://github.com/pytorch/fairseq/commit/add65adcc53a927f99a717d90a9672765237d937) broke some GPU-only tests for NAT. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1454 Test Plan: Imported from OSS Reviewed By: jmp84 Differential Revision: D25108461 Pulled By: myleott fbshipit-source-id: f32b890221578c421944d6f9a49f06ef1dc075c6 --- fairseq/models/nat/fairseq_nat_model.py | 33 ++++++++++++------- fairseq/models/nat/levenshtein_transformer.py | 8 ++--- .../models/nat/nonautoregressive_ensembles.py | 9 +++-- tests/gpu/test_binaries_gpu.py | 26 ++++++++++----- tests/utils.py | 6 ++-- 5 files changed, 51 insertions(+), 31 deletions(-) diff --git a/fairseq/models/nat/fairseq_nat_model.py b/fairseq/models/nat/fairseq_nat_model.py index 1dbc29d0f4..b09394112f 100644 --- a/fairseq/models/nat/fairseq_nat_model.py +++ b/fairseq/models/nat/fairseq_nat_model.py @@ -18,18 +18,23 @@ def ensemble_encoder(func): def wrapper(self, *args, **kwargs): if self.ensemble_models is None or len(self.ensemble_models) == 1: return func(self, *args, **kwargs) - encoder_outs = [func(model, *args, **kwargs) for model in self.ensemble_models] - _encoder_out = encoder_outs[0] + encoder_outs = [func(model, *args, **kwargs, return_all_hiddens=True) for model in self.ensemble_models] + _encoder_out = encoder_outs[0].copy() def stack(key): - outs = [getattr(e, key) for e in encoder_outs] - return torch.stack(outs, -1) if outs[0] is not None else None + outs = [e[key][0] for e in encoder_outs] + return [torch.stack(outs, -1) if outs[0] is not None else None] - return _encoder_out._replace( - encoder_out=stack("encoder_out"), - encoder_embedding=stack("encoder_embedding"), - encoder_states=stack("encoder_states"), - ) + _encoder_out["encoder_out"] = stack("encoder_out") + _encoder_out["encoder_embedding"] = stack("encoder_embedding") + + num_layers = len(_encoder_out["encoder_states"]) + if num_layers > 0: + _encoder_out["encoder_states"] = [ + torch.stack([e["encoder_states"][i] for e in encoder_outs], -1) + for i in range(num_layers) + ] + return _encoder_out return wrapper @@ -41,12 +46,18 
@@ def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs): self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs ) + def _replace(encoder_out, new_val): + new_encoder_out = encoder_out.copy() + new_encoder_out["encoder_out"] = [new_val] + return new_encoder_out + action_outs = [ func( model, normalize=normalize, - encoder_out=encoder_out._replace( - encoder_out=encoder_out.encoder_out[:, :, :, i] + encoder_out=_replace( + encoder_out, + encoder_out["encoder_out"][0][:, :, :, i] ), *args, **kwargs diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py index 17f1ee99be..9377c3c7f5 100644 --- a/fairseq/models/nat/levenshtein_transformer.py +++ b/fairseq/models/nat/levenshtein_transformer.py @@ -149,11 +149,11 @@ def forward_decoder( if max_ratio is None: max_lens = torch.zeros_like(output_tokens).fill_(255) else: - if encoder_out.encoder_padding_mask is None: - max_src_len = encoder_out.encoder_out.size(0) - src_lens = encoder_out.encoder_out.new(bsz).fill_(max_src_len) + if not encoder_out["encoder_padding_mask"]: + max_src_len = encoder_out["encoder_out"].size(0) + src_lens = encoder_out["encoder_out"].new(bsz).fill_(max_src_len) else: - src_lens = (~encoder_out.encoder_padding_mask).sum(1) + src_lens = (~encoder_out["encoder_padding_mask"][0]).sum(1) max_lens = (src_lens * max_ratio).clamp(min=10).long() # delete words diff --git a/fairseq/models/nat/nonautoregressive_ensembles.py b/fairseq/models/nat/nonautoregressive_ensembles.py index 46bb8aac43..705a04fb49 100644 --- a/fairseq/models/nat/nonautoregressive_ensembles.py +++ b/fairseq/models/nat/nonautoregressive_ensembles.py @@ -83,14 +83,13 @@ def forward_decoder( if max_ratio is None: max_lens = output_tokens.new().fill_(255) else: - if encoder_outs[0].encoder_padding_mask is None: + if not encoder_outs[0]["encoder_padding_mask"]: src_lens = ( - encoder_outs[0] - .encoder_out.new(bsz) - .fill_(encoder_outs[0].encoder_out.size(1)) + 
encoder_outs[0]["encoder_out"][0].new(bsz) + .fill_(encoder_outs[0]["encoder_out"][0].size(1)) ) else: - src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1) + src_lens = (~encoder_outs[0]["encoder_padding_mask"][0]).sum(1) max_lens = (src_lens * max_ratio).clamp(min=10).long() # delete words diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 2ac60a0934..5690e73752 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -93,17 +93,25 @@ def test_levenshtein_transformer(self): ], task="translation_lev", ) + gen_config = [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ] + # non-ensemble generation + generate_main(data_dir, gen_config) + # ensemble generation generate_main( data_dir, - [ - "--task", - "translation_lev", - "--iter-decode-max-iter", - "9", - "--iter-decode-eos-penalty", - "0", - "--print-step", - ], + gen_config, + path=os.pathsep.join([ + os.path.join(data_dir, "checkpoint_last.pt"), + os.path.join(data_dir, "checkpoint_last.pt"), + ]), ) diff --git a/tests/utils.py b/tests/utils.py index a145aa587d..178df5763e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -345,18 +345,20 @@ def train_translation_model( validate.main(validate_args) -def generate_main(data_dir, extra_flags=None): +def generate_main(data_dir, extra_flags=None, path=None): if extra_flags is None: extra_flags = [ "--print-alignment", ] + if path is None: + path = os.path.join(data_dir, "checkpoint_last.pt") generate_parser = options.get_generation_parser() generate_args = options.parse_args_and_arch( generate_parser, [ data_dir, "--path", - os.path.join(data_dir, "checkpoint_last.pt"), + path, "--beam", "3", "--batch-size", From e419db74e2ae557c60b6cbf304ae9f8cc812d9dd Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Fri, 20 Nov 2020 14:45:04 -0800 Subject: [PATCH 313/707] Add extra logging before/after checkpointing Summary: Makes it easier for 
ppl to notice if things break in the middle of writing checkpoint (ex: OOMing) (Also helps provide timing stats for how long it took to write checkpoints) Reviewed By: donhusa Differential Revision: D25120107 fbshipit-source-id: 35a7e9b7fe22a1ffa25fb8b461e7b7bef09fa063 --- fairseq/trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d7ba0be874..e9ac3c6bb1 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -178,10 +178,7 @@ def criterion(self): @property def model(self): if self._wrapped_model is None: - if ( - self.data_parallel_world_size > 1 - and not self.cfg.optimization.use_bmuf - ): + if self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf: self._wrapped_model = models.DistributedFairseqModel( self.cfg.distributed_training, self._model, @@ -266,6 +263,10 @@ def consolidate_optimizer(self): def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" if self.is_data_parallel_master: # only save one checkpoint + logger.info( + f"Preparing to save checkpoint to {filename} after " + f"{self.get_num_updates()} updates" + ) extra_state["metrics"] = metrics.state_dict() extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( @@ -279,6 +280,7 @@ def save_checkpoint(self, filename, extra_state): self._optim_history, extra_state, ) + logger.info(f"Finished saving checkpoint to {filename}") def load_checkpoint( self, From 521fccf93c821cabb3686b768f9d9152486b5bd6 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 20 Nov 2020 19:16:50 -0800 Subject: [PATCH 314/707] Fix torch.distributed.launch (fixes #2924) (#1456) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1456 Reviewed By: alexeib Differential Revision: D25133448 Pulled By: myleott fbshipit-source-id: 8a7573b69c471b237fffdfc7874f9f6b51143f5a --- fairseq/dataclass/configs.py | 4 ---- 1 file 
changed, 4 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 36d88d83f7..2cbfc9560a 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -209,10 +209,6 @@ class DistributedTrainingConfig(FairseqDataclass): }, ) device_id: int = field( - default=0, - metadata={"help": "which GPU to use (usually configured automatically)"}, - ) - local_rank: int = field( default=0, metadata={ "help": "which GPU to use (usually configured automatically)", From 2b854d8517c3398e576ff4bf8cc5f59c90af74c4 Mon Sep 17 00:00:00 2001 From: Kritika Singh Date: Sat, 21 Nov 2020 12:02:02 -0800 Subject: [PATCH 315/707] wav2vec in PySpeech Summary: Define `audio_pretraining_fb` task similar to [`audio_pretraining`](https://www.internalfb.com/intern/codesearch/path/fbsource/fbcode/deeplearning/projects/fairseq-py/fairseq/tasks/audio_pretraining.py) in fairseq. I was earlier trying to stay as close to the fairseq task as possible but sub-classing `FbSpeechRecognitionTask` from pyspeech made decoding much easier. The main differences between the fairseq and pyspeech end-to-end processes are: a. Inputs to the training `fairseq`: train.tsv, train.labels.txt, valid.tsv, valid.labels.txt, dict.labels.txt `pyspeech`: data.json. An example is: ```{ "fairseq_dict": "spm_char_32_fairseq.dict", "train": { "handles_file": "train_dataaug_full.tsv", "transforms": [ ["RawAudioDatasetTransform", {}], ["SentencePieceEncodeTransform", {"model_path": "spm_char_32.model", "append_eos": false}] ] }, "valid": { "handles_file": "valid_pages_10s_full.tsv", "transforms": [ ["RawAudioDatasetTransform", {}], ["SentencePieceEncodeTransform", {"model_path": "spm_char_32.model", "append_eos": false}] ] } } ``` The handles files follow the format `{handle}\t{length}\t{optional_ref}` b. 
Encoding the reference into target units: fairseq process was to use a dictionary file to specify the target units, an offline script [libri_labels.py](https://www.internalfb.com/intern/codesearch/path/fbsource/fbcode/deeplearning/projects/fairseq-py/examples/wav2vec/libri_labels.py) to split the reference and LabelEncoder to get the target unit sequence. This is replaced by [`SentencepieceEncodeTransform`](https://www.internalfb.com/intern/codesearch/path/fbsource/fbcode/deeplearning/projects/pyspeech/pyspeech/data/transforms/sentence_piece_encode_transform.py) such that when you pass a sentencepiece model and dictionary (from sentencepiece library) to the training, the transformations happen on the fly. c. FairseqDataset class: I use `FbEverstoreDataset` in place of `EverstoreAudioDataset` (from fairseq) to be able to use the existing pre-processing transforms like `SentencepieceEncodeTransform` and easier integration with CTC decoder. This required copying over the audio processing and collater code pieces from RawAudioDataset to new PySpeechTransform and collator classes. 
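For reference, the handles-file line format mentioned in (a), `{handle}\t{length}\t{optional_ref}`, could be read roughly as below. This is a sketch under assumptions: the `Handle` type and `parse_handles_line` are hypothetical names, and treating `length` as an integer is a guess, since the summary does not state its type or unit.

```python
from typing import NamedTuple, Optional


class Handle(NamedTuple):
    handle: str
    length: int  # assumed to be an integer; the unit is not stated in the summary
    ref: Optional[str]


def parse_handles_line(line: str) -> Handle:
    """Parse one `{handle}\\t{length}\\t{optional_ref}` line."""
    fields = line.rstrip("\n").split("\t")
    if len(fields) not in (2, 3):
        raise ValueError(f"expected 2 or 3 tab-separated fields, got {len(fields)}")
    ref = fields[2] if len(fields) == 3 else None
    return Handle(handle=fields[0], length=int(fields[1]), ref=ref)
```

A line without the third field yields `ref=None`, matching the "optional" reference in the stated format.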
Reviewed By: alexeib Differential Revision: D24265820 fbshipit-source-id: 68e2fef38a0cc1cf316410d83ed405d62a810578 --- fairseq/criterions/ctc.py | 2 +- fairseq/tasks/audio_pretraining.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index deab4f2650..b218175f21 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -66,7 +66,7 @@ class CtcCriterionConfig(FairseqDataclass): class CtcCriterion(FairseqCriterion): def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): super().__init__(task) - self.blank_idx = task.target_dictionary.bos() + self.blank_idx = task.target_dictionary.index(task.blank_symbol) self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() self.post_process = cfg.post_process diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index a2f7edc34d..6ea40a813f 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -106,6 +106,7 @@ def __init__( self._source_dictionary = source_dictionary if cfg.eval_wer: assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" + self.blank_symbol = "" @classmethod def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): From b1b02a828fec708aea1718e5336dd941a24f4276 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 21 Nov 2020 18:39:00 -0800 Subject: [PATCH 316/707] Fixed min-lr / max-lr swap on cosine schedule. (#2916) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? 
The cosine learning rate scheduler was implemented incorrectly. It annealed to the learning rate (`--lr`) instead of the minimum learning rate (`--min-lr`). This implementation is consistent with the PyTorch [CosineAnnealingLR](https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L461). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2916 Reviewed By: alexeib Differential Revision: D25146468 Pulled By: myleott fbshipit-source-id: 8704b6954dd40692eb930b882fecfa799ea98b00 --- fairseq/optim/lr_scheduler/cosine_lr_scheduler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index ef8645cd58..d73c7cc7ed 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -79,8 +79,10 @@ def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer): if cfg.warmup_init_lr < 0: cfg.warmup_init_lr = lr - self.min_lr = lr - self.max_lr = cfg.max_lr + # default min_lr=-1 -> cosine anneale to lr=0.0 + # otherwise pick min_lr from config + self.min_lr = cfg.min_lr if cfg.min_lr > 0.0 else 0.0 + self.max_lr = lr assert self.max_lr > self.min_lr, "max_lr must be more than lr" self.t_mult = cfg.t_mult From 158bd0321c4b915e4bddf738f5cb9d72d192f969 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 21 Nov 2020 18:43:09 -0800 Subject: [PATCH 317/707] Add .github/stale.yml (#2932) Summary: Mostly copied from https://github.com/facebook/react/blob/master/.github/stale.yml Pull Request resolved: https://github.com/pytorch/fairseq/pull/2932 Reviewed By: alexeib Differential Revision: D25146465 Pulled By: myleott fbshipit-source-id:
c11d695dcbd2f18609c04af2e520317977797e0f --- .github/stale.yml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/stale.yml diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 0000000000..b12867dab0 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,30 @@ +# Configuration for probot-stale - https://github.com/probot/stale +# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 90 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 +# Issues with these labels will never be considered stale +exemptLabels: + - bug +# Label to use when marking an issue as stale +staleLabel: stale +issues: + # Comment to post when marking an issue as stale. + markComment: > + This issue has been automatically marked as stale. + **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment! + # Comment to post when closing a stale issue. + closeComment: > + Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you! +pulls: + # Comment to post when marking a pull request as stale. + markComment: > + This pull request has been automatically marked as stale. + **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated. + # Comment to post when closing a stale pull request. + closeComment: > + Closing this pull request after a prolonged period of inactivity. 
If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you! + From b889b52ae9b91a0114112d00735df56c1aa36fad Mon Sep 17 00:00:00 2001 From: Xu Song Date: Sat, 21 Nov 2020 19:29:07 -0800 Subject: [PATCH 318/707] Update hub_utils.py (#2910) Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? has no attribute 'arg' Pull Request resolved: https://github.com/pytorch/fairseq/pull/2910 Reviewed By: alexeib Differential Revision: D25146481 Pulled By: myleott fbshipit-source-id: 11912bb2bcacd1d2f91da47bb0d868da90b38f17 --- fairseq/hub_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 3be7078b7a..1819a9460a 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -180,7 +180,7 @@ def generate( if verbose: def getarg(name, default): - return getattr(gen_args, name, getattr(self.args, name, default)) + return getattr(gen_args, name, getattr(self.cfg, name, default)) for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): src_str_with_unk = self.string(source_tokens) From 8e328aec86a2afee83c77095d0b5bb2c449ed5c4 Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 21 Nov 2020 22:50:04 -0800 Subject: [PATCH 319/707] minor fixes (#1457) Summary: - some minor fixes a) secondary loss logging in wav2vec criterion b) ability to have nested Dict[...] 
inside config objects c) remove debug param Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1457 Reviewed By: myleott Differential Revision: D25145151 Pulled By: alexeib fbshipit-source-id: 21e6422b91151b00b929447f0c73deced56450cb --- fairseq/criterions/ctc.py | 5 ++++- fairseq/criterions/wav2vec_criterion.py | 15 ++++++--------- fairseq/dataclass/utils.py | 3 ++- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index b218175f21..0e4e3577d2 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -118,7 +118,10 @@ def forward(self, model, sample, reduce=True): sample["target"] != self.eos_idx ) targets_flat = sample["target"].masked_select(pad_mask) - target_lengths = sample["target_lengths"] + if "target_lengths" in sample: + target_lengths = sample["target_lengths"] + else: + target_lengths = pad_mask.sum(-1) with torch.backends.cudnn.flags(enabled=False): loss = F.ctc_loss( diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 3a58390088..8a1c348a58 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -41,7 +41,7 @@ def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): self.loss_weights = loss_weights self.log_keys = [] if log_keys is None else log_keys - def forward(self, model, sample, reduce=True, log_pred=False): + def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
Returns a tuple with three elements: @@ -125,9 +125,6 @@ def forward(self, model, sample, reduce=True, log_pred=False): logging_output["correct"] = corr logging_output["count"] = count - if log_pred: - logging_output["logits"] = logits.cpu().numpy() - logging_output["target"] = target.cpu().numpy() return loss, sample_size, logging_output @staticmethod @@ -175,13 +172,13 @@ def reduce_metrics(logging_outputs) -> None: for k in logging_outputs[0]: if k not in builtin_keys: - val = sum(log.get(k, 0) for log in logging_outputs) / len( - logging_outputs - ) + val = sum(log.get(k, 0) for log in logging_outputs) if k.startswith("loss"): - metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) + metrics.log_scalar( + k, val / sample_size / math.log(2), sample_size, round=3 + ) else: - metrics.log_scalar(k, val, round=3) + metrics.log_scalar(k, val / len(logging_outputs), round=3) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 694d878308..9bf4f7d09f 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -441,6 +441,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass): dc_instance = DictConfig(dc) dc_instance.__dict__["_parent"] = cfg.__dict__["_parent"] - cfg = OmegaConf.merge(dc_instance, cfg) + with open_dict(dc_instance): + cfg = OmegaConf.merge(dc_instance, cfg) OmegaConf.set_struct(cfg, True) return cfg From 7cdef0a9bce575738ffb7b3c5fcad07181f149ac Mon Sep 17 00:00:00 2001 From: tuanh208 Date: Sat, 21 Nov 2020 23:00:33 -0800 Subject: [PATCH 320/707] Adding --mask-multiple-length and --mask-stdev options to masked_lm task (#2846) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: NOTE: this implements span masking for RoBERTa as described in vq-wav2vec paper # Before submitting - [ ] Was this discussed/approved 
via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Adding --mask-multiple-length and --mask-stdev options to masked_lm task, allowing to mask sequences of multiple lengths when training a masked language model. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Yes Pull Request resolved: https://github.com/pytorch/fairseq/pull/2846 Reviewed By: myleott Differential Revision: D25007978 Pulled By: alexeib fbshipit-source-id: a8b3bcb260c8308641362c8c59706f08142e6be9 --- fairseq/data/mask_tokens_dataset.py | 43 +++++++++++++++++++++++++++-- fairseq/tasks/masked_lm.py | 11 ++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 8ea86245f7..b239013c80 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -39,6 +39,10 @@ class MaskTokensDataset(BaseWrapperDataset): over vocab indices, indicating whether it is the beginning of a word. We will extend any mask to encompass the whole word. bpe: BPE to use for whole-word masking. + mask_multiple_length : repeat each mask index multiple times. Default + value is 1. + mask_stdev : standard deviation of masks distribution in case of + multiple masking. Default value is 0. 
""" @classmethod @@ -63,11 +67,15 @@ def __init__( random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, mask_whole_words: torch.Tensor = None, + mask_multiple_length: int = 1, + mask_stdev: float = 0.0, ): assert 0.0 < mask_prob < 1.0 assert 0.0 <= random_token_prob <= 1.0 assert 0.0 <= leave_unmasked_prob <= 1.0 assert random_token_prob + leave_unmasked_prob <= 1.0 + assert mask_multiple_length >= 1 + assert mask_stdev >= 0.0 self.dataset = dataset self.vocab = vocab @@ -79,6 +87,8 @@ def __init__( self.leave_unmasked_prob = leave_unmasked_prob self.random_token_prob = random_token_prob self.mask_whole_words = mask_whole_words + self.mask_multiple_length = mask_multiple_length + self.mask_stdev = mask_stdev if random_token_prob > 0.0: if freq_weighted_replacement: @@ -122,10 +132,39 @@ def __getitem__(self, index: int): mask = np.full(sz, False) num_mask = int( # add a random number for probabilistic rounding - self.mask_prob * sz + self.mask_prob * sz / float(self.mask_multiple_length) + np.random.rand() ) - mask[np.random.choice(sz, num_mask, replace=False)] = True + + # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) + mask_idc = np.random.choice(sz, num_mask, replace=False) + if self.mask_stdev > 0.0: + lengths = np.random.normal( + self.mask_multiple_length, self.mask_stdev, size=num_mask + ) + lengths = [max(0, int(round(x))) for x in lengths] + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ], + dtype=np.int64, + ) + else: + mask_idc = np.concatenate( + [mask_idc + i for i in range(self.mask_multiple_length)] + ) + mask_idc = mask_idc[mask_idc < len(mask)] + try: + mask[mask_idc] = True + except: # something wrong + print( + "Assigning mask indexes {} to mask {} failed!".format( + mask_idc, mask + ) + ) + raise if self.return_masked_tokens: # exit early if we're just returning the masked tokens diff --git 
a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 56086f5e81..70208bc4d5 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -88,6 +88,15 @@ def add_args(parser): action="store_true", help="mask whole words; you may also want to set --bpe", ) + parser.add_argument( + "--mask-multiple-length", + default=1, + type=int, + help="repeat the mask indices multiple times", + ) + parser.add_argument( + "--mask-stdev", default=0.0, type=float, help="stdev of the mask length" + ) parser.add_argument( "--shorten-method", default="none", @@ -180,6 +189,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=mask_whole_words, + mask_multiple_length=self.args.mask_multiple_length, + mask_stdev=self.args.mask_stdev, ) with data_utils.numpy_seed(self.args.seed + epoch): From 1519c80aef44849713581afdd83475e832f0dc1c Mon Sep 17 00:00:00 2001 From: alexeib Date: Sun, 22 Nov 2020 09:33:02 -0800 Subject: [PATCH 321/707] workaround hydra + submit it not supporting custom enums (#1458) Summary: this allows using submitit launcher with hydra to launch fairseq jobs something like this now works: ``` python fairseq_cli/hydra_train.py --multirun hydra/launcher=submitit_slurm hydra.launcher.cpus_per_task=80 hydra.launcher.gpus_per_node=8 hydra.launcher.tasks_per_node=1 hydra.launcher.nodes=2 hydra.launcher.partition=dev hydra.launcher.mem_gb=400 distributed_training.distributed_world_size=16 distributed_training.distributed_port=33333 +optimization.update_freq='[2]' --config-path /private/home/abaevski/fairseq-py/examples/wav2vec/config/pretraining --config-name wav2vec2_base_librispeech ``` (note that one has to specify distributed_port for this to work) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1458 Reviewed By: myleott Differential Revision: D25150369 Pulled By: alexeib 
fbshipit-source-id: 63b74a437fb92afff8b0faa579d07f4539a2f1d8 --- fairseq/dataclass/constants.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index fad04f3482..858f77a863 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -3,11 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum +from enum import Enum, EnumMeta from typing import List -class StrEnum(Enum): +class StrEnumMeta(EnumMeta): + # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see + # https://github.com/facebookresearch/hydra/issues/1156 + @classmethod + def __instancecheck__(cls, other): + return "enum" in str(type(other)) + + +class StrEnum(Enum, metaclass=StrEnumMeta): def __str__(self): return self.value From 74fc8ccce1cfb7367cb57e6605e125030f2b0d31 Mon Sep 17 00:00:00 2001 From: alexeib Date: Sun, 22 Nov 2020 13:51:53 -0800 Subject: [PATCH 322/707] default max_tokens_valid and batch_size_valid correctly (#1459) Summary: max_tokens_valid and batch_size_valid were not getting defaulted properly, leading to ooms Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1459 Reviewed By: myleott Differential Revision: D25151434 Pulled By: alexeib fbshipit-source-id: 0dc0f099973e6abc8ba9b20516da26b4fb2e0e33 --- fairseq/dataclass/configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 2cbfc9560a..d9ceb2a10b 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -402,14 +402,14 @@ class DatasetConfig(FairseqDataclass): default=False, metadata={"help": "disable validation"} ) max_tokens_valid: Optional[int] = field( - default=None, + default=II("dataset.max_tokens"), metadata={ "help": "maximum number 
of tokens in a validation batch" " (defaults to --max-tokens)" }, ) batch_size_valid: Optional[int] = field( - default=None, + default=II("dataset.batch_size"), metadata={ "help": "batch size of the validation batch (defaults to --batch-size)", "argparse_alias": "--max-sentences-valid", From 168480c9f11e2db0c1b0a40eb0a901133e05cb4a Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Mon, 23 Nov 2020 14:41:25 -0800 Subject: [PATCH 323/707] add model criterion Summary: - add model criterion that allows one to define any kind of loss(es) within your model and then just have this criterion do the logging Reviewed By: myleott Differential Revision: D25145814 fbshipit-source-id: bb0f01935b96d5c77f8adad40e931689ce6e3391 --- fairseq/criterions/model_criterion.py | 134 ++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 fairseq/criterions/model_criterion.py diff --git a/fairseq/criterions/model_criterion.py b/fairseq/criterions/model_criterion.py new file mode 100644 index 0000000000..da37f899ea --- /dev/null +++ b/fairseq/criterions/model_criterion.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass, field +import logging +from typing import Dict, List + +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCriterionConfig(FairseqDataclass): + loss_weights: Dict[str, float] = field( + default_factory=dict, + metadata={"help": "weights for the loss terms"}, + ) + log_keys: List[str] = field( + default_factory=list, + metadata={"help": "additional output keys to log"}, + ) + + +@register_criterion("model", dataclass=ModelCriterionConfig) +class ModelCriterion(FairseqCriterion): + """ + This criterion relies on the model to supply losses. + The losses should be a dictionary of name -> scalar returned by + the model either by including it in the net_output dict or by + implementing a get_losses(net_output, sample) method. The final loss is + a scaled sum of all losses according to weights in loss_weights. + If no weights are provided, then all losses are scaled by 1.0. + + The losses will be automatically logged. Additional keys from + net_output dict can be logged via the log_keys parameter. 
+ """ + + def __init__(self, task, loss_weights=None, log_keys=None): + super().__init__(task) + self.loss_weights = loss_weights + self.log_keys = log_keys + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + + sample_size = net_output["sample_size"] + scaled_losses = {} + + if hasattr(model, "get_losses"): + losses = model.get_losses(net_output, sample) + elif isinstance(net_output, dict) and "losses" in net_output: + losses = net_output["losses"] + else: + raise Exception("Could not retrieve losses") + + for lk, p in losses.items(): + try: + coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk] + except KeyError: + logger.error( + f"weight for loss {lk} is not in loss_weights ({self.loss_weights})" + ) + raise + if coef != 0 and p is not None: + scaled_losses[lk] = coef * p.float() + + loss = sum(scaled_losses.values()) + if reduce and loss.numel() > 1: + loss = loss.sum() + + logging_output = { + "loss": loss.data, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + for lk in self.log_keys: + if lk in net_output: + logging_output[lk] = float((net_output[lk])) + + if len(scaled_losses) > 1: + for lk, l in scaled_losses.items(): + logging_output[f"loss_{lk}"] = l.item() + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + + builtin_keys = { + 
"loss", + "ntokens", + "nsentences", + "sample_size", + "correct", + "count", + } + + for k in logging_outputs[0]: + if k not in builtin_keys: + val = sum(log.get(k, 0) for log in logging_outputs) + if k.startswith("loss_"): + metrics.log_scalar(k, val / sample_size, sample_size, round=3) + else: + metrics.log_scalar(k, val / len(logging_outputs), round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True From f13f2990935fce56f62a40b1243cac0ee4668433 Mon Sep 17 00:00:00 2001 From: alexeib Date: Mon, 23 Nov 2020 19:07:19 -0800 Subject: [PATCH 324/707] fix issubclass() call on python 3.7+ (#1462) Summary: Fixes #2897 Also updates readmes to use --config-dir instead of --config-path for hydra runs, and adds __init__.py to config dir Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1462 Reviewed By: myleott Differential Revision: D25163789 Pulled By: alexeib fbshipit-source-id: f45f432174771c5c458480f984aedf12130b8522 --- docs/hydra_integration.md | 2 +- examples/wav2vec/README.md | 18 +++++++++--------- examples/wav2vec/wav2vec_manifest.py | 3 +++ fairseq/config/__init__.py | 4 ++++ fairseq/config/config.yaml | 5 +++++ fairseq/dataclass/utils.py | 3 +-- 6 files changed, 23 insertions(+), 12 deletions(-) create mode 100644 fairseq/config/__init__.py diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 8e4082cb24..04c797fe50 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -211,7 +211,7 @@ works for migrated tasks and models. 
```shell script $ fairseq-hydra-train \ - --config-path /path/to/external/configs \ + --config-dir /path/to/external/configs \ --config-name wiki103 ``` diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index fdbf844ec7..10d231ed69 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -58,12 +58,12 @@ Note that the input is expected to be single channel, sampled at 16 kHz ```shell script $ fairseq-hydra-train \ task.data=/path/to/data \ - --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ --config-name wav2vec2_base_librispeech ``` -Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before --config-path) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k +Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k ### Train a wav2vec 2.0 large model: @@ -72,12 +72,12 @@ This configuration was used for the large model trained on the Libri-light datas ```shell script $ fairseq-hydra-train \ task.data=/path/to/data \ - --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ --config-name wav2vec2_large_librivox ``` -Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before --config-path) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k +Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k ### Fine-tune a pre-trained model with CTC: @@ -96,14 +96,14 @@ $ fairseq-hydra-train \ 
distributed_training.distributed_port=$PORT \ task.data=/path/to/data \ model.w2v_path=/path/to/model.pt \ - --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ --config-name base_100h ``` There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. -You can specify the right config via the --config-name parameter. +You can specify the right config via the `--config-name` parameter. -Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before --config-path) +Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) `distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). diff --git a/examples/wav2vec/wav2vec_manifest.py b/examples/wav2vec/wav2vec_manifest.py index 1d27f58afc..5417084554 100644 --- a/examples/wav2vec/wav2vec_manifest.py +++ b/examples/wav2vec/wav2vec_manifest.py @@ -47,6 +47,9 @@ def get_parser(): def main(args): assert args.valid_percent >= 0 and args.valid_percent <= 1.0 + if not os.path.exists(args.dest): + os.makedirs(args.dest) + dir_path = os.path.realpath(args.root) search_path = os.path.join(dir_path, "**/*." + args.ext) rand = random.Random(args.seed) diff --git a/fairseq/config/__init__.py b/fairseq/config/__init__.py new file mode 100644 index 0000000000..6264236915 --- /dev/null +++ b/fairseq/config/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/fairseq/config/config.yaml b/fairseq/config/config.yaml index 9621baa5e9..e20d914b9b 100644 --- a/fairseq/config/config.yaml +++ b/fairseq/config/config.yaml @@ -1,4 +1,9 @@ # @package _group_ + +hydra: + run: + dir: . + defaults: - task: null - model: null diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 9bf4f7d09f..beae592d1a 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -218,8 +218,7 @@ def get_default(f): isinstance(val, str) and not val.startswith("${") # not interpolation and field_type != str - and inspect.isclass(field_type) - and not issubclass(field_type, Enum) # not choices enum + and (not inspect.isclass(field_type) or not issubclass(field_type, Enum)) # not choices enum ): # upgrade old models that stored complex parameters as string val = ast.literal_eval(val) From dea66cc294a18dd4d9e59aa0af8d51f951e83884 Mon Sep 17 00:00:00 2001 From: Chau Tran Date: Wed, 25 Nov 2020 23:04:24 -0800 Subject: [PATCH 325/707] Fix mbart checkpoint Summary: activation_fn was hardcoded to 'gelu' when mbart was trained, and activation_fn was not saved to the checkpoint. The default on master is 'relu', so the old version would use 'relu'. 
I manually updated the activation_fn in the checkpoint to 'gelu' and uploaded to https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz Reviewed By: myleott Differential Revision: D25163364 fbshipit-source-id: 365ebbd39ebb341c92b1c9ad71c8fbb2edffb7e6 --- examples/mbart/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mbart/README.md b/examples/mbart/README.md index 510edeff64..217c48867c 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -9,7 +9,7 @@ MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scal Model | Description | # params | Download ---|---|---|--- -`mbart.CC25` | mBART model with 12 encoder and decoder layers trained on 25 languages' monolingual corpus | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz) +`mbart.CC25` | mBART model with 12 encoder and decoder layers trained on 25 languages' monolingual corpus | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz) `mbart.ft.ro_en` | finetune mBART cc25 model on ro-en language pairs | 610M | [mbart.cc25.ft.enro.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz) ## Results From 4ba5b4be98399c3c20d6e3c9b7750461afb83730 Mon Sep 17 00:00:00 2001 From: "Yuan Shangguan (June)" Date: Mon, 30 Nov 2020 11:38:53 -0800 Subject: [PATCH 326/707] Minor changes to manual LR scheduler Summary: 1. Update to make the learning rate at the beginning of training to be reasonable. Otherwise, job train_loss explodes like this f235003551 or f234959697 or f234961844. 2. Trivial update to avoid printing out massive update2lr dictionary that takes over the entire logging page, and makes logging difficult to see. 
Example: {F347075703} Differential Revision: D25146200 fbshipit-source-id: 071a591cf823e8c74a0380ec6850dc6b34d82ffc --- fairseq/optim/lr_scheduler/manual_lr_scheduler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/fairseq/optim/lr_scheduler/manual_lr_scheduler.py b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py index 7e06ec55c8..0269a1e285 100644 --- a/fairseq/optim/lr_scheduler/manual_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py @@ -29,9 +29,10 @@ def __init__(self, args, optimizer): self.lr = self.update2lr[1] else: self.lr = args.lr[0] + self.optimizer.set_lr(self.lr) # Set the beginning of the epoch. def parse_manuallr_args(self, lr_args_str): - lr_dict = ast.literal_eval(lr_args_str) + lr_dict = ast.literal_eval(lr_args_str.replace(' ', '')) if not isinstance(lr_dict, dict): raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict") @@ -83,7 +84,9 @@ def get_next_lr(self, epoch): if manual_keys: manual_lr = self.epoch2lr[max(manual_keys)] else: - logger.warning("@@@ epoch={} does not exist in manual lr input. epoch2lr={}".format(epoch, self.epoch2lr)) + logger.warning("@@@ epoch={} does not exist in manual lr input. 
epoch2lr={}...".format( + epoch, list(self.epoch2lr.items())[:min(10, len(self.epoch2lr.keys())-1)] + )) manual_lr = self.optimizer.get_lr() return manual_lr @@ -99,7 +102,8 @@ def step_update(self, num_updates): if manual_keys: manual_lr = self.update2lr[max(manual_keys)] else: - logger.warning("epoch={} does not exist in manual lr input update2lr={}".format(num_updates, self.update2lr)) + logger.warning("epoch={} does not exist in manual lr input update2lr={}...".format( + num_updates, list(self.update2lr.items())[:min(10, len(self.update2lr.keys())-1)])) manual_lr = self.optimizer.get_lr() self.optimizer.set_lr(manual_lr) From 9cf0bd96d645df23dfbacf6ee28e3ddb441e8717 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 30 Nov 2020 14:19:25 -0800 Subject: [PATCH 327/707] Add/fix tests (#1468) Summary: - add test for loading ensemble checkpoints (and confirmed it fails if I revert: https://github.com/pytorch/fairseq/commit/265791b727b664d4d7da3abd918a3f6fb70d7337) - add test for LayerDrop (and fix it) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1468 Reviewed By: alexeib Differential Revision: D25223272 Pulled By: myleott fbshipit-source-id: 3f06f753605af251567c70d2961f5506ea423499 --- fairseq/checkpoint_utils.py | 24 +++++++-- tests/test_binaries.py | 32 ++++++++++++ tests/test_checkpoint_utils.py | 89 ++++++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 4 deletions(-) create mode 100644 tests/test_checkpoint_utils.py diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 2bb055056e..235c660a5e 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -5,6 +5,7 @@ import ast import collections +import contextlib import logging import os import re @@ -239,7 +240,13 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): def load_model_ensemble( - filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None + filenames, + arg_overrides=None, + 
task=None, + strict=True, + suffix="", + num_shards=1, + state=None, ): """Loads an ensemble of models. @@ -265,7 +272,13 @@ def load_model_ensemble( def load_model_ensemble_and_task( - filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None + filenames, + arg_overrides=None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, ): assert state is None or len(filenames) == 1 @@ -563,8 +576,11 @@ def create_pruning_pass(layers_to_keep, layer_name): # Since layers are now pruned, *_layers_to_keep are no longer needed. # This is more of "It would make it work fix" rather than a proper fix. - - with open_dict(model_cfg): + if isinstance(model_cfg, DictConfig): + context = open_dict(model_cfg) + else: + context = contextlib.ExitStack() + with context: if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None if hasattr(model_cfg, "decoder_layers_to_keep"): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 8235702383..cad6f1eba4 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -925,6 +925,38 @@ def test_alignment_full_context(self): ) generate_main(data_dir) + def test_transformer_layerdrop(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_layerdrop") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "3", + "--decoder-layers", + "3", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--encoder-layerdrop", + "0.01", + "--decoder-layerdrop", + "0.01", + ], + ) + generate_main(data_dir) + generate_main( + data_dir, + [ + "--model-overrides", + "{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}" + ], + ) + class TestStories(unittest.TestCase): def setUp(self): diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py new file 
mode 100644 index 0000000000..e3c685deec --- /dev/null +++ b/tests/test_checkpoint_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import os +import tempfile +import unittest +from io import StringIO + +from fairseq import checkpoint_utils + +from tests.utils import ( + create_dummy_data, + preprocess_translation_data, + train_translation_model, +) + + +class TestCheckpointUtils(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @contextlib.contextmanager + def _train_transformer(self, seed, extra_args=None): + if extra_args is None: + extra_args = [] + with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "3", + "--decoder-layers", + "3", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--seed", + str(seed), + ] + + extra_args, + ) + yield os.path.join(data_dir, "checkpoint_last.pt") + + def test_load_model_ensemble_and_task(self): + with contextlib.redirect_stdout(StringIO()): + with self._train_transformer(seed=123) as model1: + with self._train_transformer(seed=456) as model2: + ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + filenames=[model1, model2] + ) + self.assertEqual(len(ensemble), 2) + + # after Transformer has been migrated to Hydra, this will probably + # become cfg.common.seed + self.assertEqual(ensemble[0].args.seed, 123) + self.assertEqual(ensemble[1].args.seed, 456) + + # the task from the first model should be returned + self.assertEqual(task.args.seed, 123) + + def test_prune_state_dict(self): + with 
contextlib.redirect_stdout(StringIO()): + extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"] + with self._train_transformer(seed=1, extra_args=extra_args) as model: + ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + filenames=[model], + arg_overrides={ + "encoder_layers_to_keep": "0,2", + "decoder_layers_to_keep": "1", + }, + ) + self.assertEqual(len(ensemble), 1) + self.assertEqual(len(ensemble[0].encoder.layers), 2) + self.assertEqual(len(ensemble[0].decoder.layers), 1) + + +if __name__ == "__main__": + unittest.main() From f732b403ec15244c41a24b9e28d6c5a411a511df Mon Sep 17 00:00:00 2001 From: "Yuan Shangguan (June)" Date: Mon, 30 Nov 2020 14:45:20 -0800 Subject: [PATCH 328/707] Max_update is not backward compatible. Fix in this diff. Summary: max_update assertion in tri_stage_lr is introduced in D25040041 (https://github.com/pytorch/fairseq/commit/6d2cf0ddf64040543c346b3866eb636d14522dde). It requires max-update to be defined, and breaks the backward compatibility of existing recipes. Since max-update is ONLY used when phase-ratio is defined. We recommend this change to keep it from breaking existing model recipes. 
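For illustration only, the guard this change relocates can be sketched as below. `resolve_stage_steps` and its parameters are hypothetical stand-ins for the scheduler config. Taking the decay length from `phase_ratio[2]` and falling back to explicit step counts are assumptions, not something shown in this diff. `math.isclose` is also a deviation, since the patch compares the sum to 1 exactly.

```python
import math
from typing import Optional, Sequence, Tuple


def resolve_stage_steps(
    max_update: int,
    phase_ratio: Optional[Sequence[float]] = None,
    warmup_steps: int = 0,
    hold_steps: int = 0,
    decay_steps: int = 0,
) -> Tuple[int, int, int]:
    """Return (warmup, hold, decay) step counts for a three-stage schedule."""
    if phase_ratio is not None:
        # max_update is only consumed on this branch, so it is only checked here.
        assert max_update > 0, "phase_ratio requires max_update > 0"
        assert math.isclose(sum(phase_ratio), 1.0), "phase ratios must add up to 1"
        return (
            int(max_update * phase_ratio[0]),
            int(max_update * phase_ratio[1]),
            int(max_update * phase_ratio[2]),  # assumed: third ratio sets decay
        )
    # Assumed fallback: recipes without phase_ratio supply explicit step counts,
    # so an unset (zero) max_update must not trip an assertion.
    return warmup_steps, hold_steps, decay_steps
```

With this shape, an older recipe that sets only the explicit step counts keeps working when `max_update` is left at 0, which is the backward-compatibility case described above.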
Reviewed By: myleott Differential Revision: D25204247 fbshipit-source-id: 01f6f2f0935dfaff9f23501158af608e5d507145 --- fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index 403de77c80..4d5547c39b 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -48,7 +48,7 @@ class TriStageLRScheduleConfig(FairseqDataclass): @register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig) -class TriStageLRScheduleConfig(FairseqLRScheduler): +class TriStageLRSchedule(FairseqLRScheduler): """Tristage learning rate schedulr Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf @@ -93,7 +93,6 @@ def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): "Cannot use a fixed learning rate schedule with tri-stage lr." " Consider --lr-scheduler=fixed instead." ) - assert cfg.max_update > 0 # calculate LR at each point self.peak_lr = cfg.lr[0] @@ -101,6 +100,7 @@ def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): self.final_lr = cfg.final_lr_scale * cfg.lr[0] if cfg.phase_ratio is not None: + assert cfg.max_update > 0 assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1" self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0]) self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1]) From 65d88f150c54f9549de0b565411684b52f4e2b50 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Tue, 1 Dec 2020 11:16:33 -0800 Subject: [PATCH 329/707] more accurate nan detection Summary: I've been debugging some nan issues related to fp16. This diff adds some improvements to make it more accurate: 1. Zero grad before doing nan detection. This enables the gradients/grad norms calculated to be accurate 2. Account for update frequency when doing nan detection. 
Without this, infs/nans can go undetected. Reviewed By: myleott Differential Revision: D25225729 fbshipit-source-id: 4ffd1dcdf4a643459b814e24e74776b144a068a8 --- fairseq/trainer.py | 50 ++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index e9ac3c6bb1..4583abd133 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -508,16 +508,7 @@ def train_step(self, samples, raise_oom=False): # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 for i, sample in enumerate(samples): - sample = self._prepare_sample(sample) - if sample is None: - # when sample is None, run forward/backward on a dummy batch - # and ignore the resulting gradients - sample = self._prepare_sample(self._dummy_batch) - is_dummy_batch = True - else: - if self._dummy_batch == "DUMMY": - self._dummy_batch = sample - is_dummy_batch = False + sample, is_dummy_batch = self._prepare_sample(sample) def maybe_no_sync(): """ @@ -652,15 +643,18 @@ def maybe_no_sync(): except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails + self.zero_grad() with NanDetector(self.get_model()): - self.task.train_step( - sample, - self.model, - self.criterion, - self.optimizer, - self.get_num_updates(), - ignore_grad=False, - ) + for _, sample in enumerate(samples): + sample, _ = self._prepare_sample(sample) + self.task.train_step( + sample, + self.model, + self.criterion, + self.optimizer, + self.get_num_updates(), + ignore_grad=False, + ) raise except OverflowError as e: overflow = True @@ -775,14 +769,7 @@ def valid_step(self, sample, raise_oom=False): self.model.eval() self.criterion.eval() - sample = self._prepare_sample(sample) - if sample is None: - sample = self._prepare_sample(self._dummy_batch) - is_dummy_batch = True - else: - if self._dummy_batch == "DUMMY": - self._dummy_batch = sample - is_dummy_batch = False + sample, is_dummy_batch 
= self._prepare_sample(sample) try: _loss, sample_size, logging_output = self.task.valid_step( @@ -932,7 +919,11 @@ def _prepare_sample(self, sample): ) if sample is None or len(sample) == 0: - return None + assert ( + self._dummy_batch is not None and len(self._dummy_batch) > 0 + ), "Invalid dummy batch: {}".format(self._dummy_batch) + sample, _ = self._prepare_sample(self._dummy_batch) + return sample, True if self.cuda: if self.pipeline_model_parallel: @@ -959,7 +950,10 @@ def apply_bfloat16(t): if self.cfg.common.bf16: sample = utils.apply_to_sample(apply_bfloat16, sample) - return sample + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample + + return sample, False def _set_seed(self): # Set seed based on args.seed and the update number so that we get From 0db28cdd0e50cad9c36e5e47ffceff40beaf6f60 Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 1 Dec 2020 17:44:23 -0800 Subject: [PATCH 330/707] fix generation config being properly passed (#1465) Summary: fixes #2961 Also fixes model criterion logging with world size > 1 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1465 Reviewed By: myleott Differential Revision: D25198394 Pulled By: alexeib fbshipit-source-id: fa52011a4d56eb41fe4bd59f9bd565632b87fba5 --- fairseq/criterions/model_criterion.py | 9 +++++---- fairseq/hub_utils.py | 2 +- fairseq/logging/progress_bar.py | 2 ++ fairseq_cli/interactive.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/fairseq/criterions/model_criterion.py b/fairseq/criterions/model_criterion.py index da37f899ea..c4f2c0b354 100644 --- a/fairseq/criterions/model_criterion.py +++ b/fairseq/criterions/model_criterion.py @@ -10,6 +10,7 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass +from fairseq.distributed_utils import get_data_parallel_world_size logger = logging.getLogger(__name__) @@ -83,7 +84,7 @@ def forward(self, model, 
sample, reduce=True): for lk in self.log_keys: if lk in net_output: - logging_output[lk] = float((net_output[lk])) + logging_output[lk] = float(net_output[lk]) if len(scaled_losses) > 1: for lk, l in scaled_losses.items(): @@ -112,17 +113,17 @@ def reduce_metrics(logging_outputs) -> None: "ntokens", "nsentences", "sample_size", - "correct", - "count", } + world_size = get_data_parallel_world_size() + for k in logging_outputs[0]: if k not in builtin_keys: val = sum(log.get(k, 0) for log in logging_outputs) if k.startswith("loss_"): metrics.log_scalar(k, val / sample_size, sample_size, round=3) else: - metrics.log_scalar(k, val / len(logging_outputs), round=3) + metrics.log_scalar(k, val / world_size, round=3) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 1819a9460a..b716884c78 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -157,7 +157,7 @@ def generate( )[0] # build generator using current args as well as any kwargs - gen_args = copy.copy(self.cfg) + gen_args = copy.copy(self.cfg.generation) with open_dict(gen_args): gen_args.beam = beam for k, v in kwargs.items(): diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 3183d2f476..2b3873794e 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -356,6 +356,8 @@ def _log_to_tensorboard(self, stats, tag=None, step=None): writer.add_scalar(key, stats[key].val, step) elif isinstance(stats[key], Number): writer.add_scalar(key, stats[key], step) + elif torch.is_tensor(stats[key]) and stats[key].numel() == 1: + writer.add_scalar(key, stats[key].item(), step) writer.flush() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 530830d6b0..4785855985 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -173,7 +173,7 @@ def main(cfg: FairseqConfig): model.prepare_for_inference_(cfg) # Initialize generator - generator = 
task.build_generator(models, cfg.task) + generator = task.build_generator(models, cfg.generation) # Handle tokenization and BPE tokenizer = encoders.build_tokenizer(cfg.tokenizer) From ffa158ff0cf2aa6c104ae844bfde361f125478f6 Mon Sep 17 00:00:00 2001 From: Sathish Indurthi Date: Thu, 3 Dec 2020 07:45:54 -0800 Subject: [PATCH 331/707] fix for MMA criterion initialization (#2911) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [X ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ X] Did you make sure to update the docs? N/A - [ X] Did you write any new necessary tests? N/A ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2122. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2911 Reviewed By: alexeib Differential Revision: D25146472 Pulled By: myleott fbshipit-source-id: 9cf02a9be679c2e1725dd3ae83aafef31900e640 --- ...moothed_cross_entropy_latency_augmented.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py index b3c8f6d53f..761cfe61a1 100644 --- a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -14,15 +14,30 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( LabelSmoothedCrossEntropyCriterion ): - def __init__(self, args, task): - super().__init__(args, task) - self.eps = args.label_smoothing - self.latency_weight_avg = args.latency_weight_avg - self.latency_weight_avg_type = args.latency_weight_avg_type - self.latency_weight_var = args.latency_weight_var - self.latency_weight_var_type = args.latency_weight_var_type - self.mass_preservation = args.mass_preservation - self.average_method = args.average_method + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + latency_weight_avg, + latency_weight_avg_type, + latency_weight_var, + latency_weight_var_type, + mass_preservation, + average_method, + ): + super().__init__( + task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + ) + self.eps = label_smoothing + self.latency_weight_avg = latency_weight_avg + self.latency_weight_avg_type = latency_weight_avg_type + self.latency_weight_var = latency_weight_var + self.latency_weight_var_type = latency_weight_var_type + self.mass_preservation = mass_preservation + self.average_method = 
average_method self.latency_train = LatencyTraining( self.latency_weight_avg, self.latency_weight_var, From d7e571c557e0f7833fa244ecb5cd0458ba28670c Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 3 Dec 2020 13:22:14 -0800 Subject: [PATCH 332/707] Small fixes for TPU (and support them in sweep) (#1433) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1433 Test Plan: Imported from OSS Reviewed By: huihuifan Differential Revision: D24959540 Pulled By: myleott fbshipit-source-id: a3b247be4ae7f4e09e571f972451e1e4ce76d5c5 --- fairseq/trainer.py | 7 +++++-- fairseq/utils.py | 2 +- fairseq_cli/train.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 4583abd133..04db13dce0 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -910,7 +910,7 @@ def _local_cumulative_training_time(self): """Aggregate training time in seconds.""" return time.time() - self._start_time + self._previous_training_time - def _prepare_sample(self, sample): + def _prepare_sample(self, sample, is_dummy=False): if sample == "DUMMY": raise Exception( "Trying to use an uninitialized 'dummy' batch. 
This usually indicates " @@ -922,7 +922,7 @@ def _prepare_sample(self, sample): assert ( self._dummy_batch is not None and len(self._dummy_batch) > 0 ), "Invalid dummy batch: {}".format(self._dummy_batch) - sample, _ = self._prepare_sample(self._dummy_batch) + sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) return sample, True if self.cuda: @@ -933,6 +933,9 @@ def _prepare_sample(self, sample): ) else: sample = utils.move_to_cuda(sample) + elif self.tpu and is_dummy: + # the dummy batch may not be on the appropriate device + sample = utils.move_to_cuda(sample, device=self.device) def apply_half(t): if t.dtype is torch.float32: diff --git a/fairseq/utils.py b/fairseq/utils.py index 8e9119124d..4046f6696c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -106,7 +106,7 @@ def move_to_cuda(sample, device=None): def _move_to_cuda(tensor): # non_blocking is ignored if tensor is not pinned, so we can always set # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) - return tensor.cuda(device=device, non_blocking=True) + return tensor.to(device=device, non_blocking=True) return apply_to_sample(_move_to_cuda, sample) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index e1af605348..7739759693 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -189,7 +189,7 @@ def train( else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(cfg.common, "tpu", False): + if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, From fa802c1034c4e6a38c80e7ab545b445aabd2d314 Mon Sep 17 00:00:00 2001 From: wei zhao Date: Thu, 3 Dec 2020 14:19:16 -0800 Subject: [PATCH 333/707] Fix another link to mbart checkpoint (#2976) Summary: Forgot to update another model URL after the fix https://github.com/pytorch/fairseq/commit/dea66cc294a18dd4d9e59aa0af8d51f951e83884. 
chtran Pull Request resolved: https://github.com/pytorch/fairseq/pull/2976 Reviewed By: tangyuq Differential Revision: D25257003 Pulled By: chtran fbshipit-source-id: 2fdb30547ed1fc82ff5cfa038a3b6d8fb9dc60ba --- examples/mbart/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mbart/README.md b/examples/mbart/README.md index 217c48867c..fa520a6825 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -26,7 +26,7 @@ Model | en-ro | ro-en ## BPE data # download model -wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz +wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz tar -xzvf mbart.CC25.tar.gz # bpe data install SPM [here](https://github.com/google/sentencepiece) From 793ec2b19d63b70717a84293b45e583f6c0b9dd5 Mon Sep 17 00:00:00 2001 From: Arthur Guo Date: Thu, 3 Dec 2020 15:32:33 -0800 Subject: [PATCH 334/707] Enable JIT on LAS Model Differential Revision: D25290353 fbshipit-source-id: 18ce98d32e49e9cebe1aed14302613a00e8c3c99 --- fairseq/modules/linearized_convolution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py index 09a8f201c0..b36cea91fa 100644 --- a/fairseq/modules/linearized_convolution.py +++ b/fairseq/modules/linearized_convolution.py @@ -38,6 +38,7 @@ def upgrade_state_dict_named(self, state_dict, name): if prefix + "_linearized_weight" in state_dict: del state_dict[prefix + "_linearized_weight"] + @torch.jit.ignore def forward(self, input, incremental_state=None): """ Args: From bc4ebcafb4f1535c528aea589d14db56a13bd763 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 3 Dec 2020 18:17:09 -0800 Subject: [PATCH 335/707] Fix tests (#1482) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1482 Reviewed By: michaelauli Differential Revision: D25318618 Pulled By: myleott fbshipit-source-id: bed171ffe5ca10e8359be96a15d0fe9bb1a630ea --- 
setup.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 6bc450a7fa..0a4be4b0dd 100644 --- a/setup.py +++ b/setup.py @@ -174,15 +174,17 @@ def do_setup(package_data): long_description_content_type="text/markdown", setup_requires=[ "cython", - "numpy", + 'numpy<1.20.0; python_version<"3.7"', + 'numpy; python_version>="3.7"', "setuptools>=18.0", ], install_requires=[ "cffi", "cython", - "dataclasses", + 'dataclasses; python_version<"3.7"', "hydra-core", - "numpy", + 'numpy<1.20.0; python_version<"3.7"', + 'numpy; python_version>="3.7"', "regex", "sacrebleu>=1.4.12", "torch", From a700e14ea3cbf3e99ca729caa30b5b1c0305c4c4 Mon Sep 17 00:00:00 2001 From: Robert Verkuil Date: Fri, 4 Dec 2020 04:43:09 -0800 Subject: [PATCH 336/707] Increase plasma reconnect attempts (#1480) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Avoids run failures due to insufficient time allowed for plasma client -> server connection. 
I've been seeing many failures in my recent runs, which use data sharding, due to plasma connection errors:

[train.stderr.33224323.txt](https://github.com/fairinternal/fairseq-py/files/5637976/train.stderr.33224323.txt)
[train.stderr.33245371.txt](https://github.com/fairinternal/fairseq-py/files/5637980/train.stderr.33245371.txt)
[train.stderr.33261567.txt](https://github.com/fairinternal/fairseq-py/files/5637981/train.stderr.33261567.txt)
[train.stderr.33261589.txt](https://github.com/fairinternal/fairseq-py/files/5637982/train.stderr.33261589.txt)

Currently plasma can attempt the client -> server connection 20 times, for a total retry time of ≈8-10s. When sharding, `epoch_itr`s are intentionally *not* cached, so dataset-related plasma arrays must be remade. This makes plasma connection errors a persistent concern throughout training, and it gets worse with many GPUs and workers. For a single training run that I inspected, the number of retries needed to connect increases and eventually exceeds 20.

![image](https://user-images.githubusercontent.com/4042063/101068717-10853100-3567-11eb-89ca-7d18a7ef0405.png)

The best solution would be to increase the retry interval. Looking at the plasma [source](https://github.com/apache/arrow/blob/016f76c8c02e769d58b3e785a87674d98ce83367/python/pyarrow/_plasma.pyx#L848-L868), this doesn't look possible. However, we can increase `num_retries`. Hopefully this isn't too annoying for the cluster file system?

## PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1480 Reviewed By: myleott Differential Revision: D25313114 Pulled By: joshim5 fbshipit-source-id: ad50c3b29e0698bf197e24f6392bda73b407a548 --- fairseq/data/plasma_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py index 2b12646783..f4bb6472d7 100644 --- a/fairseq/data/plasma_utils.py +++ b/fairseq/data/plasma_utils.py @@ -60,7 +60,7 @@ def start_server(self): def client(self): if self._client is None: assert self.path is not None - self._client = self.plasma.connect(self.path) + self._client = self.plasma.connect(self.path, num_retries=200) return self._client def __getstate__(self): From 8be488ff6b7fc49346c94323085e71e72b6583ea Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 4 Dec 2020 05:43:50 -0800 Subject: [PATCH 337/707] Add --load-checkpoint-on-all-dp-ranks (#1478) Summary: A recent commit [1] made it so that checkpoints are loaded on rank 0 and then broadcast to other workers. This is valuable to reduce I/O and especially helpful when training with optimizer state sharding, so we don't need to load redundant optimizer state on every worker. This diff adds a new option (`--load-checkpoint-on-all-dp-ranks`) that optionally reverts to the old behavior. 
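As a rough sketch of the per-rank decision described above (not the trainer code itself): `LoadPlan` and `plan_checkpoint_load` are hypothetical names, and the function only mirrors the boolean conditions in the `trainer.py` change, where TPUs always load on every worker.

```python
from typing import NamedTuple


class LoadPlan(NamedTuple):
    read_from_disk: bool   # this rank reads the checkpoint file itself
    joins_broadcast: bool  # this rank takes part in the broadcast from rank 0


def plan_checkpoint_load(
    rank: int,
    world_size: int,
    load_on_all_ranks: bool = False,
    tpu: bool = False,
) -> LoadPlan:
    """Decide how one data-parallel rank obtains the checkpoint state."""
    # The option and the TPU case both force every rank to load from disk.
    load_everywhere = load_on_all_ranks or tpu
    read_from_disk = load_everywhere or rank == 0
    # The broadcast is a collective call, so every rank joins it when it is used.
    joins_broadcast = world_size > 1 and not load_everywhere
    return LoadPlan(read_from_disk, joins_broadcast)
```

By default only rank 0 touches the file system and the others wait for the broadcast. With the option enabled, every rank reads the file and no broadcast happens.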
[1] https://github.com/pytorch/fairseq/commit/ea4ccd94de131d6b39163836418696369dd1d034 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1478 Test Plan: Imported from OSS Reviewed By: theweiho Differential Revision: D25291595 Pulled By: myleott fbshipit-source-id: 57b521c5e6a48f08140f9527162072ea1d4066db --- fairseq/dataclass/configs.py | 7 +++++++ fairseq/trainer.py | 19 +++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index d9ceb2a10b..3ff177d969 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -583,6 +583,13 @@ class CheckpointConfig(FairseqDataclass): "the checkpoint" }, ) + load_checkpoint_on_all_dp_ranks: bool = field( + default=False, + metadata={ + "help": "load checkpoints on all data parallel devices " + "(default: only load on rank 0 and broadcast to other devices)" + }, + ) model_parallel_size: int = II("common.model_parallel_size") distributed_rank: int = II("distributed_training.distributed_rank") diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 04db13dce0..94684f051b 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -299,12 +299,14 @@ def load_checkpoint( bexists = PathManager.isfile(filename) if bexists: - if ( - self.data_parallel_rank == 0 + load_on_all_ranks = ( + self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks # TPUs don't support broadcast yet, so load checkpoints # on every worker for now or self.tpu - ): + ) + + if load_on_all_ranks or self.data_parallel_rank == 0: state = checkpoint_utils.load_checkpoint_to_cpu(filename) last_optim_state = state.get("last_optimizer_state", None) @@ -312,7 +314,8 @@ def load_checkpoint( # state. Later we will broadcast sharded states to each rank # to avoid memory from exploding. 
if ( - self.cfg.distributed_training.zero_sharding == "os" + not load_on_all_ranks + and self.cfg.distributed_training.zero_sharding == "os" and "last_optimizer_state" in state and self.data_parallel_world_size > 1 ): @@ -321,11 +324,7 @@ def load_checkpoint( last_optim_state = None state = None - if ( - self.data_parallel_world_size > 1 - # disable on TPUs until they support broadcast - and not self.tpu - ): + if self.data_parallel_world_size > 1 and not load_on_all_ranks: state = distributed_utils.broadcast_object( state, src_rank=0, @@ -368,7 +367,7 @@ def load_checkpoint( if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) - if self.data_parallel_world_size > 1: + if not load_on_all_ranks and self.data_parallel_world_size > 1: last_optim_state = self.optimizer.broadcast_global_state_dict( last_optim_state ) From bb039fa2063dca1b388d6be2f64052b07fb556a2 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 4 Dec 2020 10:57:35 -0800 Subject: [PATCH 338/707] Improve performance of distributed_utils.broadcast_object (#1479) Summary: This diff dramatically speeds up and reduces memory usage of distributed_utils.broadcast_object. In particular, rather than pickling the whole state dict (and broadcasting it), we only pickle the non-tensors and broadcast the tensors directly. This improves speed (since pickling is expensive) and also saves RAM, since pickle duplicates the data and by separating out the tensors, we're only left with duplicate copies of non-tensor data in the state dict, which is quite small. 
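The split-and-restore idea can be sketched without any distributed setup. This is a simplified illustration, not the patch's code: `Placeholder`, `split_payloads` and `restore_payloads` are hypothetical names, the `is_payload` predicate stands in for `torch.is_tensor`, and sets are left out for brevity.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass(frozen=True)
class Placeholder:
    """Marks the position of a payload that was pulled out of the object."""
    index: int


def split_payloads(obj: Any, payloads: List[Any], is_payload: Callable[[Any], bool]) -> Any:
    """Replace every payload in a nested object with a Placeholder, collecting payloads in order."""
    if is_payload(obj):
        payloads.append(obj)
        return Placeholder(index=len(payloads) - 1)
    if isinstance(obj, dict):
        return {k: split_payloads(v, payloads, is_payload) for k, v in obj.items()}
    if isinstance(obj, list):
        return [split_payloads(v, payloads, is_payload) for v in obj]
    if isinstance(obj, tuple):
        return tuple(split_payloads(v, payloads, is_payload) for v in obj)
    return obj


def restore_payloads(obj: Any, payloads: List[Any]) -> Any:
    """Inverse of split_payloads: put each payload back where its Placeholder sits."""
    if isinstance(obj, Placeholder):
        return payloads[obj.index]
    if isinstance(obj, dict):
        return {k: restore_payloads(v, payloads) for k, v in obj.items()}
    if isinstance(obj, list):
        return [restore_payloads(v, payloads) for v in obj]
    if isinstance(obj, tuple):
        return tuple(restore_payloads(v, payloads) for v in obj)
    return obj


# Usage: bytearray plays the role of a tensor in this stand-alone example.
state = {"step": 3, "weights": [bytearray(b"abc"), 1], "pair": (bytearray(b"xy"), "tag")}
payloads: List[Any] = []
skeleton = split_payloads(state, payloads, lambda o: isinstance(o, bytearray))
restored = restore_payloads(skeleton, payloads)
```

Only the small `skeleton` would need to be serialized. The large payloads can travel separately and be reattached on the receiving side.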
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1479 Test Plan: Imported from OSS Differential Revision: D25291594 Pulled By: myleott fbshipit-source-id: 521102fae75a3bc71dcd5ac2bf238f7eb534a3d1 --- fairseq/distributed_utils.py | 126 +++++++++++++++++++++++++++++------ 1 file changed, 105 insertions(+), 21 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 9059d8aa2b..fa70607fbc 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -14,6 +14,7 @@ import warnings from argparse import Namespace from collections import OrderedDict +from dataclasses import dataclass from typing import Any, Dict, List, Mapping, Optional import torch @@ -637,44 +638,127 @@ def get_from_stack(key): return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) -# From fairscale/optim/utils.py +def broadcast_tensors( + tensors: Optional[List[torch.Tensor]], + src_rank: int, + group: object, + dist_device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + """ + Broadcasts a list of tensors without other (non-src) ranks needing to know + the dtypes/shapes of the tensors. 
+ """ + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + # share metadata first to simplify transfer + is_src_rank = (get_rank(group) == src_rank) + if is_src_rank: + metadata = [ + {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors + ] + metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) + else: + metadata = _broadcast_object_slow(None, src_rank, group, dist_device) + + out_tensors = [] + for i, meta in enumerate(metadata): + if is_src_rank: + tensor = tensors[i] + broadcast(tensors[i].to(dist_device), src=src_rank, group=group) + else: + tensor = torch.zeros( + [meta["size"].numel()], dtype=meta["dtype"], device=dist_device + ) + broadcast(tensor, src=src_rank, group=group) + tensor = tensor.view(meta["size"]).to(meta["device"]) + out_tensors.append(tensor) + return out_tensors + + def broadcast_object( obj: Any, src_rank: int, group: object, dist_device: Optional[torch.device] = None, - dist_length_dtype: Optional[torch.dtype] = torch.long, - dist_dtype: Optional[torch.dtype] = torch.uint8, ) -> Any: - """ - Either broadcast from master to the fleet (default), - or use the src setting as the original rank. 
- """ + """Broadcast an arbitrary Python object to other workers.""" if dist_device is None: if torch.distributed.get_backend(group) == "nccl": dist_device = torch.device("cuda") else: dist_device = torch.device("cpu") + if get_rank(group) == src_rank: + # split the tensors from the non-tensors so we can broadcast them + # directly, avoiding unnecessary serialization/deserialization + tensors = [] + obj = _split_tensors_from_obj(obj, tensors) + obj = _broadcast_object_slow(obj, src_rank, group, dist_device) + tensors = broadcast_tensors(tensors, src_rank, group, dist_device) + else: + obj = _broadcast_object_slow(None, src_rank, group, dist_device) + tensors = broadcast_tensors(None, src_rank, group, dist_device) + return _put_tensors_in_obj(obj, tensors) + + +def _broadcast_object_slow( + obj: Any, src_rank: int, group: object, dist_device: torch.device, +) -> Any: if get_rank(group) == src_rank: # Emit data buffer = io.BytesIO() torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor( - [len(data)], dtype=dist_length_dtype, device=dist_device - ) - broadcast(length_tensor, src=src_rank, group=group) - data_send_tensor = torch.tensor(data, dtype=dist_dtype, device=dist_device) - broadcast(data_send_tensor, src=src_rank, group=group) + buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device) + length = torch.LongTensor([len(buffer)]).to(dist_device) + broadcast(length, src=src_rank, group=group) + broadcast(buffer, src=src_rank, group=group) else: # Fetch from the source - length_tensor = torch.tensor([0], dtype=dist_length_dtype, device=dist_device) - broadcast(length_tensor, src=src_rank, group=group) - data_recv_tensor = torch.zeros( - [int(length_tensor.item())], dtype=dist_dtype, device=dist_device - ) - broadcast(data_recv_tensor, src=src_rank, group=group) - buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + length = torch.LongTensor([0]).to(dist_device) + broadcast(length, src=src_rank, group=group) + buffer = 
torch.ByteTensor(int(length.item())).to(dist_device) + broadcast(buffer, src=src_rank, group=group) + buffer = io.BytesIO(buffer.cpu().numpy()) obj = torch.load(buffer, map_location="cpu") return obj + + +@dataclass(frozen=True) +class _TensorPlaceholder: + index: int + + +def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if torch.is_tensor(obj): + placeholder = _TensorPlaceholder(index=len(tensors)) + tensors.append(obj) + return placeholder + elif isinstance(obj, dict): + return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_split_tensors_from_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_split_tensors_from_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_split_tensors_from_obj(v, tensors) for v in obj} + else: + return obj + + +def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if isinstance(obj, _TensorPlaceholder): + return tensors[obj.index] + elif isinstance(obj, dict): + return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_put_tensors_in_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_put_tensors_in_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_put_tensors_in_obj(v, tensors) for v in obj} + else: + return obj From 6f47704d4deb99061e7562710e1dbd0253b04ea4 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 4 Dec 2020 10:57:35 -0800 Subject: [PATCH 339/707] Add distributed tests (#1481) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1481 Test Plan: Imported from OSS Reviewed By: theweiho Differential Revision: D25313776 Pulled By: myleott fbshipit-source-id: 755bf4b77b2a7a3aee56e2344246ff2087a3af77 --- tests/distributed/__init__.py | 0 tests/distributed/test_distributed_utils.py | 69 +++++++++++++++++++++ tests/distributed/utils.py | 61 ++++++++++++++++++ 3 files 
changed, 130 insertions(+) create mode 100644 tests/distributed/__init__.py create mode 100644 tests/distributed/test_distributed_utils.py create mode 100644 tests/distributed/utils.py diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/distributed/test_distributed_utils.py b/tests/distributed/test_distributed_utils.py new file mode 100644 index 0000000000..161ee85eaa --- /dev/null +++ b/tests/distributed/test_distributed_utils.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import sys +import unittest + +import torch + +from fairseq import distributed_utils as dist_utils + +from .utils import objects_are_equal, spawn_and_init + + +class TestDistributedUtils(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available, skipping test") + if sys.platform == "win32": + raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") + if torch.cuda.device_count() < 2: + raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + + def test_broadcast_object_python(self): + spawn_and_init( + functools.partial( + TestDistributedUtils._test_broadcast_object, + "hello world", + ), + world_size=2, + ) + + def test_broadcast_object_tensor(self): + spawn_and_init( + functools.partial( + TestDistributedUtils._test_broadcast_object, + torch.rand(5), + ), + world_size=2, + ) + + def test_broadcast_object_complex(self): + spawn_and_init( + functools.partial( + TestDistributedUtils._test_broadcast_object, + { + "a": "1", + "b": [2, torch.rand(2, 3), 3], + "c": (torch.rand(2, 3), 4), + "d": {5, torch.rand(5)}, + "e": torch.rand(5), + "f": torch.rand(5).int().cuda(), + }, + ), + world_size=2, + ) + + @staticmethod + def 
_test_broadcast_object(ref_obj, rank, group): + obj = dist_utils.broadcast_object( + ref_obj if rank == 0 else None, src_rank=0, group=group + ) + assert objects_are_equal(ref_obj, obj) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/utils.py b/tests/distributed/utils.py new file mode 100644 index 0000000000..d2b3ddb1ff --- /dev/null +++ b/tests/distributed/utils.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import tempfile + +import torch + + +def spawn_and_init(fn, world_size, args=None): + if args is None: + args = () + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + torch.multiprocessing.spawn( + fn=functools.partial(init_and_run, fn, args), + args=(world_size, tmp_file.name,), + nprocs=world_size, + ) + + +def distributed_init(rank, world_size, tmp_file): + torch.distributed.init_process_group( + backend="nccl", + init_method="file://{}".format(tmp_file), + world_size=world_size, + rank=rank, + ) + torch.cuda.set_device(rank) + + +def init_and_run(fn, args, rank, world_size, tmp_file): + distributed_init(rank, world_size, tmp_file) + group = torch.distributed.new_group() + fn(rank, group, *args) + + +def objects_are_equal(a, b) -> bool: + if type(a) is not type(b): + return False + if isinstance(a, dict): + if set(a.keys()) != set(b.keys()): + return False + for k in a.keys(): + if not objects_are_equal(a[k], b[k]): + return False + return True + elif isinstance(a, (list, tuple, set)): + if len(a) != len(b): + return False + return all(objects_are_equal(x, y) for x, y in zip(a, b)) + elif torch.is_tensor(a): + return ( + a.size() == b.size() + and a.dtype == b.dtype + and a.device == b.device + and torch.all(a == b) + ) + else: + return a == b From ba79f7b781929e04c827c6dda9048e4e6e86ba6a Mon Sep 17 00:00:00 2001 From: Alex Conneau Date: 
Fri, 4 Dec 2020 13:45:22 -0800 Subject: [PATCH 340/707] Adding XLSR-53 model to wav2vec README (#1483) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1483 Reviewed By: alexeib Differential Revision: D25326050 Pulled By: aconneau fbshipit-source-id: 7428244d328ea0bbbbaaf23f715cd6d44d329b94 --- examples/wav2vec/README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 10d231ed69..52dce362ab 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -2,6 +2,8 @@ wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). +We learned speech representations in multiple languages as well in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). + We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). ## Pre-trained models @@ -26,6 +28,21 @@ Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https:// \* updated (Oct. 
24, 2020) +We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: + +Model | Architecture | Hours | Languages | Datasets | Model +|---|---|---|---|---|--- +XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) + +The XLSR model uses the following datasets for multilingual pretraining: + +* **[MLS: Multilingual LibriSpeech](https://indico2.conference4me.psnc.pl/event/35/contributions/3585/attachments/1060/1101/Wed-2-6-10.pdf)** (8 languages, 50.7k hours): *Dutch, English, French, German, Italian, Polish, Portuguese, Spanish* + +* **[CommonVoice](https://commonvoice.mozilla.org/en/languages)** (36 languages, 3.6k hours): *Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakh-Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh* (see also [finetuning splits]([https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz]) from [this paper](https://arxiv.org/abs/2002.02848)). 
+ +* **[Babel](https://catalog.ldc.upenn.edu/byyear)** (17 languages, 1.7k hours): *Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok, Turkish, Vietnamese, Zulu* + + ## Training a new model with the CLI tools Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) From 579002a0ea7d3161660a7beeb8c56b1f0fc26b63 Mon Sep 17 00:00:00 2001 From: Rui Hou Date: Fri, 4 Dec 2020 13:55:37 -0800 Subject: [PATCH 341/707] Add a dummy first_batch to EpochBatchIterating Summary: Add a dummy first_batch to EpochBatchIterating Reviewed By: myleott Differential Revision: D25337800 fbshipit-source-id: c387c4c39533c161cb160a84ad4f99e71c66f73e --- fairseq/data/iterators.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index ef41fed739..0f55026ef8 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -138,6 +138,10 @@ def load_state_dict(self, state_dict): """Copies the state of the iterator from the given *state_dict*.""" raise NotImplementedError + @property + def first_batch(self): + return "DUMMY" + class StreamingEpochBatchIterator(EpochBatchIterating): def __init__( From d5218f88275fd57825819c6dab523e30a41b6866 Mon Sep 17 00:00:00 2001 From: dingjiajia Date: Fri, 4 Dec 2020 17:09:08 -0800 Subject: [PATCH 342/707] update sequence generator device change (#2989) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes Sequence Generator Reproduce the bug: 1. 
Background: the **fairseq.utils.move_to_cuda** function supports changing the device.
2. Change the file **fairseq_cli/generate.py** as follows, supposing `gpu_id = 1`:
   - line 134: change `model.cuda()` to `model.cuda(gpu_id)`
   - line 189: change `sample = utils.move_to_cuda(sample) if use_cuda else sample` to `sample = utils.move_to_cuda(sample, gpu_id) if use_cuda else sample`
3. You will get an error:
```
fairseq/sequence_generator.py, line 382, in generate
    cand_bbsz_idx = cand_beams.add(bbsz_offsets)
RuntimeError: binary_op(): expected both inputs to be on same device, but input a is on cuda:1 and input b is on cuda:0
```
The cause of this bug is that `bbsz_offsets` is not assigned to the proper CUDA device. Therefore, we need to change lines 281 and 282 of the file **fairseq/sequence_generator.py**.

The test file for this fix is **tests/test_sequence_generator.py**.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding
very happy 🙂

Pull Request resolved: https://github.com/pytorch/fairseq/pull/2989

Reviewed By: alexeib

Differential Revision: D25344265

Pulled By: myleott

fbshipit-source-id: 8cb8b389e59a9aa67aec84dbdadcfa2c08c9648f
---
 fairseq/sequence_generator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index 47a20296cf..afc1500090 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -278,8 +278,8 @@ def _generate(
         cand_size = 2 * beam_size # 2 x beam size in case half are EOS
 
         # offset arrays for converting between different indexing schemes
-        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
-        cand_offsets = torch.arange(0, cand_size).type_as(tokens)
+        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens).to(src_tokens.device)
+        cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
 
         reorder_state: Optional[Tensor] = None
         batch_idxs: Optional[Tensor] = None
From ba4f54267af5c3f67f1b76a6e804b6ab593d1d39 Mon Sep 17 00:00:00 2001
From: Alexei Baevski
Date: Fri, 4 Dec 2020 17:34:08 -0800
Subject: [PATCH 343/707] composite optimizer

Summary:
This adds a composite optimizer and a pass-through learning rate scheduler that allow fairseq models to have separate optimizers (which can optionally have separate lr schedulers) for different parameters. To use this, add a "param_group" field to the parameters you wish to be optimized separately (the rest of the params are automatically placed into a "default" group), then specify a composite optimizer with nested optimizers (and, optionally, lr schedulers) for each group name (see example below).
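
As an illustration only (this sketch is not part of the commit and is separate from the author's own example further down): the grouping rule described above, where parameters tagged with `param_group` are bucketed under that name and everything else falls into "default", could look roughly like this in plain Python. The helper name `group_params` and the `SimpleNamespace` stand-ins for real `torch.nn.Parameter` objects are assumptions made for this example.

```python
from collections import defaultdict
from types import SimpleNamespace


def group_params(params, default_group="default"):
    """Bucket parameters by their optional `param_group` attribute."""
    grouped = defaultdict(list)
    for p in params:
        # Parameters without the attribute fall into the default group.
        grouped[getattr(p, "param_group", default_group)].append(p)
    return dict(grouped)


# Hypothetical stand-ins for model parameters; only one of them is tagged.
pos_conv_weight = SimpleNamespace(name="pos_conv.weight", param_group="pos_conv")
encoder_weight = SimpleNamespace(name="encoder.weight")
encoder_bias = SimpleNamespace(name="encoder.bias")

groups = group_params([pos_conv_weight, encoder_weight, encoder_bias])
group_names = sorted(groups)
```

Each bucket would then be handed to its own nested optimizer. The `composite.py` added by this patch asserts that the parameter group names match the configured optimizer group names, so a tagged group without a matching config entry fails early.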
For fp16 training, this requires setting fp16_no_flatten_grads to true.

One possible area for future improvement is to automatically create param groups based on module names, but this is still to be discussed.

For example, I can modify the wav2vec2 model and add

```python
for p in self.pos_conv.parameters():
    p.param_group = "pos_conv"
```

in the TransformerEncoder class, just after pos_conv is created. Then I create the following config:

```yaml
# package _group_
hydra:
  run:
    dir: .
  job_logging:
    disable_existing_loggers: false

common:
  fp16: true
  log_format: json
  log_interval: 10
  fp16_no_flatten_grads: true

checkpoint:
  save_interval_updates: 20
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  no_save: false
```

Reviewed By: myleott

Differential Revision: D25152032

fbshipit-source-id: c73ff95146ecc2a04660c67bcad02b637c5c5098
---
 fairseq/criterions/ctc.py | 2 +-
 fairseq/criterions/model_criterion.py | 9 +-
 fairseq/dataclass/utils.py | 20 +-
 fairseq/distributed_utils.py | 11 +-
 fairseq/models/__init__.py | 14 +-
 fairseq/modules/same_pad.py | 11 +-
 fairseq/optim/composite.py | 183 ++++++++++++++++++
 fairseq/optim/fairseq_optimizer.py | 17 +-
 fairseq/optim/fp16_optimizer.py | 22 ++-
 .../lr_scheduler/fairseq_lr_scheduler.py | 2 +-
 fairseq/optim/lr_scheduler/pass_through.py | 39 ++++
 fairseq/optim/nag.py | 2 +-
 fairseq/tasks/fairseq_task.py | 3 +
 fairseq/trainer.py | 11 +-
 14 files changed, 311 insertions(+), 35 deletions(-)
 create mode 100644 fairseq/optim/composite.py
 create mode 100644 fairseq/optim/lr_scheduler/pass_through.py

diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py
index 0e4e3577d2..8cb1331825 100644
--- a/fairseq/criterions/ctc.py
+++ b/fairseq/criterions/ctc.py
@@ -8,7 +8,7 @@
 from argparse import Namespace
 from dataclasses import dataclass, field
 from omegaconf import II
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
diff --git a/fairseq/criterions/model_criterion.py
b/fairseq/criterions/model_criterion.py index c4f2c0b354..8e366a5d85 100644 --- a/fairseq/criterions/model_criterion.py +++ b/fairseq/criterions/model_criterion.py @@ -3,14 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass, field import logging +from dataclasses import dataclass, field from typing import Dict, List from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass -from fairseq.distributed_utils import get_data_parallel_world_size logger = logging.getLogger(__name__) @@ -80,6 +79,7 @@ def forward(self, model, sample, reduce=True): "ntokens": sample_size, "nsentences": sample["id"].numel(), "sample_size": sample_size, + "_world_size": 1, } for lk in self.log_keys: @@ -113,9 +113,12 @@ def reduce_metrics(logging_outputs) -> None: "ntokens", "nsentences", "sample_size", + "_world_size", } - world_size = get_data_parallel_world_size() + world_size = utils.item( + sum(log.get("_world_size", 0) for log in logging_outputs) + ) for k in logging_outputs[0]: if k not in builtin_keys: diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index beae592d1a..e25838400c 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -4,19 +4,19 @@ # LICENSE file in the root directory of this source tree. 
import ast -import os +import inspect import logging +import os import re from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING from enum import Enum -import inspect from typing import Any, Dict, List, Tuple, Type from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import FairseqConfig -from hydra.experimental import compose, initialize from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize from omegaconf import DictConfig, OmegaConf, open_dict logger = logging.getLogger(__name__) @@ -218,7 +218,9 @@ def get_default(f): isinstance(val, str) and not val.startswith("${") # not interpolation and field_type != str - and (not inspect.isclass(field_type) or not issubclass(field_type, Enum)) # not choices enum + and ( + not inspect.isclass(field_type) or not issubclass(field_type, Enum) + ) # not choices enum ): # upgrade old models that stored complex parameters as string val = ast.literal_eval(val) @@ -438,9 +440,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass): - dc_instance = DictConfig(dc) - dc_instance.__dict__["_parent"] = cfg.__dict__["_parent"] - with open_dict(dc_instance): - cfg = OmegaConf.merge(dc_instance, cfg) - OmegaConf.set_struct(cfg, True) - return cfg + merged_cfg = OmegaConf.merge(dc, cfg) + merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index fa70607fbc..8f98ac88f9 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -161,8 +161,9 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs - assert cfg.distributed_world_size <= 
torch.cuda.device_count(), \ - f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" + assert ( + cfg.distributed_world_size <= torch.cuda.device_count() + ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" port = random.randint(10000, 20000) cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) @@ -376,8 +377,10 @@ def get_world_size(group): assert group[0] == "tpu" my_group = _find_my_group(group[1]) return len(my_group) - else: + elif torch.distributed.is_initialized(): return dist.get_world_size(group=group) + else: + return 1 def get_global_group(): @@ -416,6 +419,7 @@ def get_data_parallel_group(): global _USE_MEGATRON if _USE_MEGATRON: from fairseq.model_parallel.megatron import mpu + return mpu.get_data_parallel_group() else: return get_global_group() @@ -435,6 +439,7 @@ def get_model_parallel_group(): global _USE_MEGATRON if _USE_MEGATRON: from fairseq.model_parallel.megatron import mpu + return mpu.get_model_parallel_group() else: return None diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 600ca27c6a..135530d5c0 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -63,7 +63,11 @@ def build_model(cfg: FairseqDataclass, task): cfg = cfg[model_type] else: raise Exception( - "Could not infer model type from directory. Please add _name field to indicate model type" + "Could not infer model type from directory. Please add _name field to indicate model type. " + "Available models: " + + str(MODEL_DATACLASS_REGISTRY.keys()) + + " Requested model type: " + + model_type ) if model_type in ARCH_MODEL_REGISTRY: @@ -81,7 +85,13 @@ def build_model(cfg: FairseqDataclass, task): else: cfg = merge_with_parent(dc(), cfg) - assert model is not None, f"Could not infer model type from {cfg}" + assert model is not None, ( + f"Could not infer model type from {cfg}. 
" + f"Available models: " + + str(MODEL_DATACLASS_REGISTRY.keys()) + + " Requested model type: " + + model_type + ) return model.build_model(cfg, task) diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py index b46f94d635..4c04990ea6 100644 --- a/fairseq/modules/same_pad.py +++ b/fairseq/modules/same_pad.py @@ -8,11 +8,14 @@ class SamePad(nn.Module): - def __init__(self, kernel_size): + def __init__(self, kernel_size, causal=False): super().__init__() - self.remove = kernel_size % 2 == 0 + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 def forward(self, x): - if self.remove: - x = x[:, :, :-1] + if self.remove > 0: + x = x[:, :, : -self.remove] return x diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py new file mode 100644 index 0000000000..51e6999368 --- /dev/null +++ b/fairseq/optim/composite.py @@ -0,0 +1,183 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Any, List, Optional + +import torch.optim +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer +from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler +from omegaconf import II, open_dict + + +logger = logging.getLogger(__name__) + + +@dataclass +class OptimizerAndSchedulerConfig(FairseqDataclass): + optimizer: Any = None + lr_scheduler: Optional[Any] = None + lr: List[float] = II("optimization.lr") + + +@dataclass +class CompositeOptimizerConfig(FairseqDataclass): + groups: Dict[str, OptimizerAndSchedulerConfig] = field( + default_factory=lambda: {}, + metadata={ + "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. 
" + "Configures a different optimizer and (optionally) lr scheduler for each parameter group" + }, + ) + + +@register_optimizer("composite", dataclass=CompositeOptimizerConfig) +class FairseqCompositeOptimizer(FairseqOptimizer): + + optimizers: Dict[str, FairseqOptimizer] = {} + lr_schedulers: Dict[str, FairseqLRScheduler] = {} + lr_scheduler: FairseqLRScheduler = None + _optimizer: torch.optim.Optimizer + + def __init__(self, cfg: CompositeOptimizerConfig, params): + super().__init__(cfg) + + assert ( + len(params) > 1 + ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)" + + groupped_params = defaultdict(list) + for p in params: + group = getattr(p, "param_group", "default") + groupped_params[group].append(p) + + assert groupped_params.keys() == cfg.groups.keys(), ( + f"Parameter groups {groupped_params.keys()} and optimizer groups {cfg.groups.keys()} are not the same! " + "Try setting 'param_group' on your parameters in the model." + ) + + for group, group_params in groupped_params.items(): + group_cfg = cfg.groups[group] + with open_dict(group_cfg): + group_cfg.optimizer.lr = group_cfg.lr + group_cfg.lr_scheduler.lr = group_cfg.lr + self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params) + if group_cfg.lr_scheduler is not None: + self.lr_schedulers[group] = build_lr_scheduler( + group_cfg.lr_scheduler, self.optimizers[group] + ) + + if len(self.lr_schedulers) > 0: + assert len(self.lr_schedulers) == len(self.optimizers), ( + f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. 
" + f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}" + ) + self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers) + + self._optimizer = CompositeOptimizer(self.optimizers) + + @property + def supports_groups(self): + return True + + @property + def param_groups(self): + for opt in self.optimizers.values(): + for group in opt.param_groups: + yield group + + def get_lr(self): + """Return the current learning rate.""" + k = ( + "default" + if "default" in self.optimizers + else next(iter(self.optimizers.keys())) + ) + return self.optimizers[k].param_groups[0]["lr"] + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.optimizers.items()} + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + if k not in self.optimizers: + # skip extra keys like "loss_scale" added by fp16 optimizer + continue + + overrides = ( + optimizer_overrides[k] + if isinstance(optimizer_overrides, dict) and k in optimizer_overrides + else None + ) + self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides) + + +class CompositeOptimizer(torch.optim.Optimizer): + def __init__(self, optimizers: Dict[str, FairseqOptimizer]): + self.optimizers = optimizers + + @property + def supports_memory_efficient_fp16(self): + return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values()) + + @property + def supports_flat_params(self): + return all(o.supports_flat_params for o in self.optimizers.values()) + + def step(self, closure=None, groups=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for k, opt in self.optimizers.items(): + if groups is None or k in groups: + opt.step() + + return loss + + def zero_grad(self): + for opt in self.optimizers.values(): + opt.zero_grad() + + +class CompositeLRScheduler(FairseqLRScheduler): + def __init__(self, lr_schedulers): + super().__init__(None, None) + + self.lr_schedulers = lr_schedulers + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.lr_schedulers.items()} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + self.lr_schedulers[k].load_state_dict(state) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step_begin_epoch(epoch) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()} diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index f9864533b6..41c859355c 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -109,14 +109,21 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm.""" return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) - def step(self, closure=None, scale=1.0): + def step(self, closure=None, scale=1.0, groups=None): """Performs a single optimization step.""" if self.supports_step_with_scale: self.optimizer.step(closure, scale=scale) + if self.supports_groups: + self.optimizer.step(closure, scale=scale, groups=groups) + else: + self.optimizer.step(closure, scale=scale) else: if scale != 1.0: 
self.multiply_grads(1.0 / scale) - self.optimizer.step(closure) + if self.supports_groups: + self.optimizer.step(closure, groups=groups) + else: + self.optimizer.step(closure) def zero_grad(self): """Clears the gradients of all optimized parameters.""" @@ -136,6 +143,12 @@ def supports_step_with_scale(self): return self.optimizer.supports_step_with_scale return False + @property + def supports_groups(self): + if hasattr(self.optimizer, "supports_groups"): + return self.optimizer.supports_groups + return False + @property def supports_flat_params(self): """ diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 4457023527..a0da4948c8 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -65,6 +65,8 @@ def build_fp32_params(cls, args, params, flatten=True): for p in params: p32 = torch.nn.Parameter(p.data.float()) p32.grad = torch.zeros_like(p32.data) + if hasattr(p, "param_group"): + p32.param_group = p.param_group fp32_params.append(p32) return fp32_params @@ -198,15 +200,15 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): return grad_norm - def step(self, closure=None): + def step(self, closure=None, groups=None): """Performs a single optimization step.""" self._sync_fp16_grads_to_fp32() if getattr(self, "supports_step_with_scale", False): - self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) else: self._unscale_grads() - self.fp32_optimizer.step(closure) + self.fp32_optimizer.step(closure, groups=groups) if self.scaler is not None: self.scaler.update() @@ -303,6 +305,10 @@ def optimizer(self): def optimizer(self, optimizer): self.fp32_optimizer.optimizer = optimizer + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + @property def optimizer_config(self): return self.fp32_optimizer.optimizer_config @@ -416,14 +422,14 @@ def 
clip_grad_norm(self, max_norm, aggregate_norm_fn=None): return grad_norm - def step(self, closure=None): + def step(self, closure=None, groups=None): """Performs a single optimization step.""" if getattr(self, "supports_step_with_scale", False): # NOTE(msb) optimizer divides by scale factor - self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) else: self._unscale_grads() - self.wrapped_optimizer.step(closure) + self.wrapped_optimizer.step(closure, groups=groups) if self.scaler is not None: self.scaler.update() @@ -514,6 +520,10 @@ def optimizer(self, optimizer): def optimizer_config(self): return self.wrapped_optimizer.optimizer_config + @property + def lr_scheduler(self): + return getattr(self.wrapped_optimizer, "lr_scheduler", None) + def get_lr(self): return self.wrapped_optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index dd75dc5e30..6c12fa56b8 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -12,7 +12,7 @@ class FairseqLRScheduler(object): def __init__(self, cfg, optimizer): super().__init__() - if not isinstance(optimizer, FairseqOptimizer): + if optimizer is not None and not isinstance(optimizer, FairseqOptimizer): raise ValueError("optimizer must be an instance of FairseqOptimizer") self.cfg = cfg self.optimizer = optimizer diff --git a/fairseq/optim/lr_scheduler/pass_through.py b/fairseq/optim/lr_scheduler/pass_through.py new file mode 100644 index 0000000000..2f93db328c --- /dev/null +++ b/fairseq/optim/lr_scheduler/pass_through.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class PassThroughScheduleConfig(FairseqDataclass): + pass + + +@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig) +class PassThroughScheduleSchedule(FairseqLRScheduler): + """Delegate lr scheduling to the optimizer.""" + + def __init__(self, cfg: PassThroughScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + assert ( + hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None + ), "Pass-through schedule can only be used with optimizers with their own schedulers" + + def state_dict(self): + return self.optimizer.lr_scheduler.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.lr_scheduler.load_state_dict(state_dict) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + return self.optimizer.lr_scheduler.step_begin_epoch(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.lr_scheduler.step_update(num_updates) diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index c612d812c9..4f652fe6d3 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -75,7 +75,7 @@ def step(self, closure=None): momentum = group["momentum"] lr = group["lr"] lr_old = group.get("lr_old", lr) - lr_correct = lr / lr_old + lr_correct = lr / lr_old if lr_old > 0 else lr for p in group["params"]: if p.grad is None: diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index d34f09d1d7..c9b7477ae7 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -438,6 +438,9 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = criterion(model, sample) return loss, sample_size, logging_output + def optimizer_step(self, optimizer, 
model, update_num): + optimizer.step() + def build_dataset_for_inference( self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs ) -> torch.utils.data.Dataset: diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 94684f051b..cfeb63237b 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -637,7 +637,9 @@ def maybe_no_sync(): with torch.autograd.profiler.record_function("optimizer"): # take an optimization step - self.optimizer.step() + self.task.optimizer_step( + self.optimizer, model=self.model, update_num=self.get_num_updates() + ) except FloatingPointError: # re-run the forward and backward pass with hooks attached to print @@ -827,7 +829,12 @@ def lr_step(self, epoch, val_loss=None): def lr_step_update(self): """Update the learning rate after each update.""" new_lr = self.lr_scheduler.step_update(self.get_num_updates()) - metrics.log_scalar("lr", new_lr, weight=0, priority=300) + if isinstance(new_lr, dict): + for k, v in new_lr.items(): + metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300) + new_lr = new_lr.get("default", next(iter(new_lr.values()))) + else: + metrics.log_scalar("lr", new_lr, weight=0, priority=300) return new_lr def get_lr(self): From 4df4d0af8d706952013f8edf7da811937b8384c8 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 5 Dec 2020 07:36:28 -0800 Subject: [PATCH 344/707] Add missing `--optimizer` option to tutorial docs (fixes #2830) (#1485) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1485 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25342182 Pulled By: myleott fbshipit-source-id: 7eb2a4b2b7377da31d4f538053cc196437532db0 --- docs/getting_started.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index d227b95544..5d1d2d6979 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -90,7 +90,7 @@ well for the IWSLT 2014 dataset: > mkdir -p checkpoints/fconv > 
CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \ - --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ + --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ --arch fconv_iwslt_de_en --save-dir checkpoints/fconv By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the From 72a25a4e52402b6f53aa98cfb739c075c0d6f7ee Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 5 Dec 2020 07:36:28 -0800 Subject: [PATCH 345/707] Rename optimization.min_lr -> optimization.stop_min_lr (#1486) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1486 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25342181 Pulled By: myleott fbshipit-source-id: 7d1cfb26334fff26d688648724ab073e5fb956f5 --- docs/getting_started.rst | 3 ++- examples/cross_lingual_language_model/README.md | 2 +- examples/language_model/README.adaptive_inputs.md | 2 +- examples/latent_depth/README.md | 2 +- examples/mbart/README.md | 2 +- examples/multilingual/README.md | 4 ++-- examples/multilingual/finetune_multilingual_model.sh | 2 +- examples/multilingual/train_multilingual_model.sh | 2 +- examples/nonautoregressive_translation/README.md | 2 +- examples/nonautoregressive_translation/scripts.md | 12 ++++++------ examples/pay_less_attention_paper/README.md | 10 +++++----- examples/quant_noise/README.md | 4 ++-- examples/simultaneous_translation/README.md | 6 +++--- examples/translation/README.md | 2 +- examples/translation_moe/README.md | 2 +- examples/wav2vec/README.md | 4 ++-- fairseq/checkpoint_utils.py | 8 ++++++-- fairseq/dataclass/configs.py | 2 +- fairseq_cli/train.py | 10 +++++++++- tests/test_binaries.py | 2 +- 20 files changed, 48 insertions(+), 35 deletions(-) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 5d1d2d6979..745ad7763c 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -182,9 +182,10 @@ sure to update ``--master_addr`` to 
the IP address of the first node: --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ - --lr 0.0005 --min-lr 1e-09 \ + --lr 0.0005 \ --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --max-tokens 3584 \ + --max-epoch 70 \ --fp16 On SLURM clusters, fairseq will automatically detect the number of nodes and diff --git a/examples/cross_lingual_language_model/README.md b/examples/cross_lingual_language_model/README.md index a78f86d8da..f4c76cfed5 100644 --- a/examples/cross_lingual_language_model/README.md +++ b/examples/cross_lingual_language_model/README.md @@ -61,7 +61,7 @@ fairseq-train \ --max-update 2400000 --save-interval 1 --no-epoch-checkpoints \ --arch xlm_base \ --optimizer adam --lr-scheduler reduce_lr_on_plateau \ ---lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \ +--lr-shrink 0.5 --lr 0.0001 --stop-min-lr 1e-09 \ --dropout 0.1 \ --criterion legacy_masked_lm_loss \ --max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \ diff --git a/examples/language_model/README.adaptive_inputs.md b/examples/language_model/README.adaptive_inputs.md index 6873467115..2ab3733018 100644 --- a/examples/language_model/README.adaptive_inputs.md +++ b/examples/language_model/README.adaptive_inputs.md @@ -20,7 +20,7 @@ fairseq-train --task language_modeling \ --save-dir checkpoints/transformer_wikitext-103 \ --arch transformer_lm_wiki103 \ --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \ - --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \ + --warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \ --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \ 
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d ``` diff --git a/examples/latent_depth/README.md b/examples/latent_depth/README.md index a0ec55a3f6..e70e16405c 100644 --- a/examples/latent_depth/README.md +++ b/examples/latent_depth/README.md @@ -25,7 +25,7 @@ fairseq-train ${databin_dir} \ --share-decoder-input-output-embed \ --dropout 0.3 --attention-dropout 0.3 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \ + --lr-scheduler inverse_sqrt --stop-min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \ --max-tokens 4096 --update-freq 1 \ --lr 0.0015 \ --clip-norm 1.0 \ diff --git a/examples/mbart/README.md b/examples/mbart/README.md index fa520a6825..8a3e22d425 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -73,7 +73,7 @@ fairseq-train path_2_data \ --source-lang en_XX --target-lang ro_RO \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ + --lr-scheduler polynomial_decay --lr 3e-05 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/multilingual/README.md b/examples/multilingual/README.md index 3559c244e2..35eca89804 100644 --- a/examples/multilingual/README.md +++ b/examples/multilingual/README.md @@ -41,7 +41,7 @@ fairseq-train $path_2_data \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + 
--lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ @@ -69,7 +69,7 @@ fairseq-train $path_2_data \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/multilingual/finetune_multilingual_model.sh b/examples/multilingual/finetune_multilingual_model.sh index cfa9a86113..ffcf1fc722 100644 --- a/examples/multilingual/finetune_multilingual_model.sh +++ b/examples/multilingual/finetune_multilingual_model.sh @@ -20,7 +20,7 @@ fairseq-train "$path_2_data" \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/multilingual/train_multilingual_model.sh b/examples/multilingual/train_multilingual_model.sh index 09014c8217..c41730dfcd 100644 --- a/examples/multilingual/train_multilingual_model.sh +++ b/examples/multilingual/train_multilingual_model.sh @@ -16,7 +16,7 @@ 
fairseq-train "$path_2_data" \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/nonautoregressive_translation/README.md b/examples/nonautoregressive_translation/README.md index dfc592f0a0..7b2d42a91d 100644 --- a/examples/nonautoregressive_translation/README.md +++ b/examples/nonautoregressive_translation/README.md @@ -44,7 +44,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ diff --git a/examples/nonautoregressive_translation/scripts.md b/examples/nonautoregressive_translation/scripts.md index 63b945c1d3..a3a33e6e02 100644 --- a/examples/nonautoregressive_translation/scripts.md +++ b/examples/nonautoregressive_translation/scripts.md @@ -14,7 +14,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -43,7 +43,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' 
--warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -76,7 +76,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -109,7 +109,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -136,7 +136,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -165,7 +165,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ diff --git a/examples/pay_less_attention_paper/README.md b/examples/pay_less_attention_paper/README.md index 3fb93b23d1..537ca5f25b 100644 --- a/examples/pay_less_attention_paper/README.md +++ b/examples/pay_less_attention_paper/README.md @@ -110,7 +110,7 @@ mkdir -p $SAVE CUDA_VISIBLE_DEVICES=0 $(which fairseq-train) data-bin/iwslt14.tokenized.de-en \ --clip-norm 0 --optimizer adam --lr 0.0005 \ --source-lang de --target-lang en --max-tokens 4000 
--no-progress-bar \ - --log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \ + --log-interval 100 --stop-min-lr '1e-09' --weight-decay 0.0001 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --lr-scheduler inverse_sqrt \ --ddp-backend=no_c10d \ @@ -137,10 +137,10 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --max-update 30000 --share-all-embeddings --optimizer adam \ --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ - --min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ --ddp-backend=no_c10d --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ - --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \ + --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 20000 \ --arch lightconv_wmt_en_de_big --save-dir $SAVE \ --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \ @@ -162,10 +162,10 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --max-update 30000 --share-all-embeddings --optimizer adam \ --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ - --min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ --ddp-backend=no_c10d --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ - --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \ + --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 70000 \ --arch lightconv_wmt_en_fr_big --save-dir $SAVE \ --dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 
\ diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 057ea620ab..9fe492d0cf 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -212,7 +212,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --max-tokens 3072 --tokens-per-sample 3072 --momentum 0.99 --optimizer nag \ --sample-break-mode none --update-freq 3 \ --warmup-init-lr 1e-07 --warmup-updates 16000 \ - --weight-decay 0 --seed 1 --min-lr 1e-09 \ + --weight-decay 0 --seed 1 --stop-min-lr 1e-09 \ --quant-noise-pq 0.05 --quant-noise-pq-block-size 8 ``` @@ -269,7 +269,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --ddp-backend no_c10d \ --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ --fp16 --keep-last-epochs -1 \ - --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 0.05 --min-lr 1e-09 \ + --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 0.05 --stop-min-lr 1e-09 \ --max-tokens 2944 --tokens-per-sample 2944\ --momentum 0.99 --no-epoch-checkpoints --no-progress-bar --optimizer nag --required-batch-size-multiple 8 \ --sample-break-mode none --t-mult 2.0 --skip-invalid-size-inputs-valid-test \ diff --git a/examples/simultaneous_translation/README.md b/examples/simultaneous_translation/README.md index e27b65280e..bbc6dacdda 100644 --- a/examples/simultaneous_translation/README.md +++ b/examples/simultaneous_translation/README.md @@ -23,7 +23,7 @@ fairseq-train \ --optimizer adam --adam-betas '(0.9, 0.98)' \ --lr-scheduler 'inverse_sqrt' \ --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ --dropout 0.3 \ --label-smoothing 0.1\ --max-tokens 3584 @@ -44,7 +44,7 @@ fairseq-train 
\ --optimizer adam --adam-betas '(0.9, 0.98)' \ --lr-scheduler 'inverse_sqrt' \ --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ --dropout 0.3 \ --label-smoothing 0.1\ --max-tokens 3584 @@ -65,7 +65,7 @@ fairseq-train \ --optimizer adam --adam-betas '(0.9, 0.98)' \ --lr-scheduler 'inverse_sqrt' \ --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ --dropout 0.3 \ --label-smoothing 0.1\ --max-tokens 3584 diff --git a/examples/translation/README.md b/examples/translation/README.md index 3eb8e01310..7b1fcc8de2 100644 --- a/examples/translation/README.md +++ b/examples/translation/README.md @@ -268,7 +268,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \ --arch multilingual_transformer_iwslt_de_en \ --share-decoders --share-decoder-input-output-embed \ --optimizer adam --adam-betas '(0.9, 0.98)' \ - --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ --warmup-updates 4000 --warmup-init-lr '1e-07' \ --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \ --dropout 0.3 --weight-decay 0.0001 \ diff --git a/examples/translation_moe/README.md b/examples/translation_moe/README.md index ef7abdb44b..3cc3fb46dc 100644 --- a/examples/translation_moe/README.md +++ b/examples/translation_moe/README.md @@ -24,7 +24,7 @@ fairseq-train --ddp-backend='no_c10d' \ --arch transformer_wmt_en_de --share-all-embeddings \ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ - --lr 0.0007 --min-lr 1e-09 \ + --lr 0.0007 \ --dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \ --max-tokens 3584 ``` diff --git a/examples/wav2vec/README.md 
b/examples/wav2vec/README.md index 52dce362ab..bf501ab9af 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -186,7 +186,7 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa ``` $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ ---arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \ +--arch wav2vec --task audio_pretraining --lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \ --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ --skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ @@ -244,7 +244,7 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa ``` $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ ---save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 \ +--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --lr 1e-06 --stop-min-lr 1e-09 \ --optimizer adam --max-lr 1e-05 --lr-scheduler cosine \ --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 235c660a5e..f03875da50 100644 --- 
a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -458,7 +458,6 @@ def _upgrade_state_dict(state): "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), } - # old model checkpoints may not have separate source/target positions # backward compatibility, cfg updates if "args" in state and state["args"] is not None: # default to translation task @@ -474,15 +473,20 @@ def _upgrade_state_dict(state): state["extra_state"]["train_iterator"]["epoch"] = max( state["extra_state"]["train_iterator"].get("epoch", 1), 1 ) - + # --remove-bpe ==> --postprocess if hasattr(state["args"], "remove_bpe"): state["args"].post_process = state["args"].remove_bpe + # --min-lr ==> --stop-min-lr + if hasattr(state["args"], "min_lr"): + state["args"].stop_min_lr = state["args"].min_lr + del state["args"].min_lr state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: with open_dict(state["cfg"]): if state["cfg"].task is not None: + # old model checkpoints may not have separate source/target positions if hasattr(state["cfg"].task, "max_positions") and not hasattr( state["cfg"].task, "max_source_positions" ): diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 3ff177d969..3992e3c2d5 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -465,7 +465,7 @@ class OptimizationConfig(FairseqDataclass): " (note: this may be interpreted differently depending on --lr-scheduler)" }, ) - min_lr: float = field( + stop_min_lr: float = field( default=-1.0, metadata={"help": "stop training when the learning rate reaches this minimum"}, ) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 7739759693..82c30321eb 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -125,7 +125,15 @@ def main(cfg: DictConfig) -> None: lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - while lr > cfg.optimization.min_lr and 
epoch_itr.next_epoch_idx <= max_epoch: + while epoch_itr.next_epoch_idx <= max_epoch: + if lr <= cfg.optimization.stop_min_lr: + logger.info( + f"stopping training because current learning rate ({lr}) is smaller " + "than or equal to minimum learning rate " + f"(--stop-min-lr={cfg.optimization.stop_min_lr})" + ) + break + # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: diff --git a/tests/test_binaries.py b/tests/test_binaries.py index cad6f1eba4..58f86484f7 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1454,7 +1454,7 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()): "0.5", "--lr", "0.0001", - "--min-lr", + "--stop-min-lr", "1e-09", # dropout, attention args "--dropout", From 4817a9142f49793ec2eedbd71fe5bd872e58e7b5 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 5 Dec 2020 07:36:28 -0800 Subject: [PATCH 346/707] Cleanup CosineLRScheduler and change defaults (#1487) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1487 Here's the code for CosineLRScheduler that I used as a reference: https://github.com/pytorch/fairseq/blob/577e4fa78a295fd7cd3ee7e9fd4b936ca800ebea/fairseq/optim/lr_scheduler/cosine_lr_schedul In the reference: - `warmup_init_lr` defaults to `args.lr[0]` - `warmup_end_lr` defaults to `args.max_lr` - `min_lr` defaults to `args.lr[0]` (note that there's also a `args.min_lr` option defined in the global fairseq config, but this is unused by the cosine scheduler) - `max_lr` is a required option This diff removes `max_lr` and replaces it with `lr[0]` to be more consistent with other LR schedulers. We then add an explicit `min_lr` option to the Config. 
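
As a rough sketch of the resulting behavior (an illustration, not the code in this diff): `--lr` now acts as the peak of the cosine curve and `--min-lr` as its floor, and `warmup_init_lr` falls back to `min_lr` when it is left negative. The helper name `cosine_lr`, the single-float `lr`, the fixed `t_mult == 1`, and the `math.pi` factor (taken from the usual SGDR form) are assumptions made for this example.

```python
import math


def cosine_lr(num_updates, lr, min_lr=0.0, warmup_updates=0,
              warmup_init_lr=-1.0, period=1000, lr_shrink=0.1):
    """Hypothetical helper, not part of fairseq.

    Simplified cosine schedule with warm restarts, assuming t_mult == 1
    and a single learning rate instead of a list.
    """
    # After this change, lr is the maximum and min_lr the minimum.
    assert lr > min_lr, "lr must be more than min_lr"

    # A negative warmup_init_lr means "not set": start warmup at min_lr.
    if warmup_init_lr < 0:
        warmup_init_lr = min_lr

    # Linear warmup from warmup_init_lr up to lr.
    if num_updates < warmup_updates:
        step = (lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + num_updates * step

    # Position inside the current period (i counts completed periods).
    curr = num_updates - warmup_updates
    i = curr // period
    t_curr = curr - period * i

    # Both bounds shrink by lr_shrink after every completed period.
    shrink = lr_shrink ** i
    lo, hi = min_lr * shrink, lr * shrink

    # Cosine annealing between hi and lo (pi factor assumed, as in SGDR).
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / period))
```

With `lr=1.0`, `min_lr=0.1` and `warmup_updates=10`, the value climbs linearly from 0.1 to 1.0 over the first 10 updates and then anneals back toward 0.1 over the following `period` updates.
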
Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25342180 Pulled By: myleott fbshipit-source-id: 61281666e68839da8efc4714c2ce8c49dc4c8e6e --- .../language_model/README.adaptive_inputs.md | 4 +-- examples/pay_less_attention_paper/README.md | 4 +-- examples/quant_noise/README.md | 4 +-- examples/truncated_bptt/README.md | 2 +- examples/wav2vec/README.md | 6 ++-- .../optim/lr_scheduler/cosine_lr_scheduler.py | 32 +++++++++---------- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/language_model/README.adaptive_inputs.md b/examples/language_model/README.adaptive_inputs.md index 2ab3733018..98043c5377 100644 --- a/examples/language_model/README.adaptive_inputs.md +++ b/examples/language_model/README.adaptive_inputs.md @@ -19,8 +19,8 @@ fairseq-train --task language_modeling \ data-bin/wikitext-103 \ --save-dir checkpoints/transformer_wikitext-103 \ --arch transformer_lm_wiki103 \ - --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \ - --warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \ + --max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \ + --warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 0.0001 --clip-norm 0.1 \ --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \ --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d ``` diff --git a/examples/pay_less_attention_paper/README.md b/examples/pay_less_attention_paper/README.md index 537ca5f25b..d5b19af6cc 100644 --- a/examples/pay_less_attention_paper/README.md +++ b/examples/pay_less_attention_paper/README.md @@ -140,7 +140,7 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 
\ --ddp-backend=no_c10d --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ - --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --warmup-init-lr 1e-07 \ + --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 20000 \ --arch lightconv_wmt_en_de_big --save-dir $SAVE \ --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \ @@ -165,7 +165,7 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ --ddp-backend=no_c10d --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ - --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --warmup-init-lr 1e-07 \ + --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 70000 \ --arch lightconv_wmt_en_fr_big --save-dir $SAVE \ --dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \ diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 9fe492d0cf..7fe301f732 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -208,7 +208,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --ddp-backend no_c10d \ --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 \ --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ - --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 1.0 --t-mult 2.0 \ + --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 1.0 --t-mult 2.0 \ --max-tokens 3072 --tokens-per-sample 3072 --momentum 0.99 --optimizer nag \ --sample-break-mode none --update-freq 3 \ --warmup-init-lr 1e-07 --warmup-updates 16000 \ @@ -269,7 +269,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --ddp-backend no_c10d \ --decoder-attention-heads 8 --decoder-embed-dim 
1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ --fp16 --keep-last-epochs -1 \ - --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 0.05 --stop-min-lr 1e-09 \ + --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 0.05 --stop-min-lr 1e-09 \ --max-tokens 2944 --tokens-per-sample 2944\ --momentum 0.99 --no-epoch-checkpoints --no-progress-bar --optimizer nag --required-batch-size-multiple 8 \ --sample-break-mode none --t-mult 2.0 --skip-invalid-size-inputs-valid-test \ diff --git a/examples/truncated_bptt/README.md b/examples/truncated_bptt/README.md index f5c6447f1c..86518c9d5e 100644 --- a/examples/truncated_bptt/README.md +++ b/examples/truncated_bptt/README.md @@ -37,7 +37,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \ --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \ --optimizer adam --clip-norm 0.25 \ - --lr-scheduler cosine --warmup-updates 0 --lr 0.0 --max-lr 0.00025 \ + --lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \ --log-format json --log-interval 25 \ --fp16 ``` diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index bf501ab9af..4089edf42b 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -186,7 +186,7 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa ``` $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ ---arch wav2vec --task audio_pretraining --lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \ +--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), 
(512, 4, 2), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ --skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ @@ -244,8 +244,8 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa ``` $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ ---save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --lr 1e-06 --stop-min-lr 1e-09 \ ---optimizer adam --max-lr 1e-05 --lr-scheduler cosine \ +--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \ +--optimizer adam --lr 1e-05 --lr-scheduler cosine \ --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ --activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \ diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index d73c7cc7ed..38b57fe54c 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -26,9 +26,11 @@ class CosineLRScheduleConfig(FairseqDataclass): "help": "initial learning rate during warmup phase; default is cfg.lr" }, ) - max_lr: float = field( - default=1.0, metadata={"help": "max learning rate, must be more than cfg.lr"} + lr: List[float] = field( + default=II("optimization.lr"), + metadata={"help": "max learning rate, must be more than cfg.min_lr"}, ) + min_lr: float = 
field(default=0.0, metadata={"help": "min learning rate"}) t_mult: float = field( default=1.0, metadata={"help": "factor to grow the length of each period"} ) @@ -38,7 +40,7 @@ class CosineLRScheduleConfig(FairseqDataclass): lr_shrink: float = field( default=0.1, metadata={"help": "shrink factor for annealing"} ) - lr: List[float] = II("optimization.lr") + # This is not required, but is for convenience in inferring lr_period_updates max_update: int = II("optimization.max_update") @@ -50,7 +52,7 @@ class CosineLRSchedule(FairseqLRScheduler): We also support a warmup phase where we linearly increase the learning rate from some initial learning rate (``--warmup-init-lr``) until the configured - max learning rate (``--max-lr``). + max learning rate (``--lr``). During warmup:: @@ -59,7 +61,7 @@ class CosineLRSchedule(FairseqLRScheduler): After warmup:: - lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) + lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(t_curr / t_i)) where ``t_curr`` is current percentage of updates within the current period range and ``t_i`` is the current period range, which is scaled by ``t_mul`` @@ -74,23 +76,21 @@ def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer): f" Consider --lr-scheduler=fixed instead. 
({cfg.lr})" ) - warmup_end_lr = cfg.max_lr - lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr - if cfg.warmup_init_lr < 0: - cfg.warmup_init_lr = lr + self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + assert ( + self.max_lr > cfg.min_lr + ), f"max_lr (={cfg.lr}) must be more than min_lr (={cfg.min_lr})" - # default min_lr=-1 -> cosine anneale to lr=0.0 - # otherwise pick min_lr from config - self.min_lr = cfg.min_lr if cfg.min_lr > 0.0 else 0.0 - self.max_lr = lr - assert self.max_lr > self.min_lr, "max_lr must be more than lr" + warmup_end_lr = self.max_lr + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = cfg.min_lr self.t_mult = cfg.t_mult self.period = cfg.lr_period_updates if self.period <= 0: assert ( - cfg.max_update >= 0 + cfg.max_update > 0 ), "Either --max_update or --lr-period-updates must be set" self.period = cfg.max_update - cfg.warmup_updates @@ -136,7 +136,7 @@ def step_update(self, num_updates): t_curr = curr_updates - (self.period * i) lr_shrink = self.lr_shrink ** i - min_lr = self.min_lr * lr_shrink + min_lr = self.cfg.min_lr * lr_shrink max_lr = self.max_lr * lr_shrink self.lr = min_lr + 0.5 * (max_lr - min_lr) * ( From feb5f07ff4220b7908ff12c6692685784c7c9a71 Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 8 Dec 2020 15:47:53 -0800 Subject: [PATCH 347/707] fix wav2vec scripts (#1494) Summary: fixes #2942 + docs + migration of old models Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1494 Reviewed By: myleott Differential Revision: D25404601 Pulled By: alexeib fbshipit-source-id: 092f145602522f8e7ea3eaa709bfe602a4d29d8b --- examples/wav2vec/README.md | 12 ++++++------ examples/wav2vec/vq-wav2vec_featurize.py | 11 ++++------- examples/wav2vec/wav2vec_featurize.py | 8 +++----- fairseq/checkpoint_utils.py | 9 +++++++++ fairseq/dataclass/utils.py | 5 +++++ 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 
4089edf42b..05df59f214 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -160,11 +160,11 @@ Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl #### Example usage: ```python import torch -from fairseq.models.wav2vec import Wav2VecModel +import fairseq cp = torch.load('/path/to/wav2vec.pt') -model = Wav2VecModel.build_model(cp['args'], task=None) -model.load_state_dict(cp['model']) +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) +model = model[0] model.eval() wav_input_16khz = torch.randn(1,10000) @@ -217,11 +217,11 @@ Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download] #### Example usage: ```python import torch -from fairseq.models.wav2vec import Wav2VecModel +import fairseq cp = torch.load('/path/to/vq-wav2vec.pt') -model = Wav2VecModel.build_model(cp['args'], task=None) -model.load_state_dict(cp['model']) +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) +model = model[0] model.eval() wav_input_16khz = torch.randn(1,10000) diff --git a/examples/wav2vec/vq-wav2vec_featurize.py b/examples/wav2vec/vq-wav2vec_featurize.py index baabc1d365..1adb52de1c 100644 --- a/examples/wav2vec/vq-wav2vec_featurize.py +++ b/examples/wav2vec/vq-wav2vec_featurize.py @@ -16,8 +16,7 @@ import soundfile as sf import torch -import tqdm -from fairseq.models.wav2vec.wav2vec import Wav2VecModel +import fairseq from torch import nn from torch.utils.data import DataLoader @@ -211,13 +210,11 @@ def load_data(self, fnames): return loader def load_model(self): - cp = torch.load(self.checkpoint, map_location=lambda x, _: x) + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([self.checkpoint]) + model = model[0] - model = Wav2VecModel.build_model(cp["args"], None) + self.quantize_location = getattr(cfg.model, "vq", "encoder") - self.quantize_location = getattr(cp["args"], "vq", "encoder") - - model.load_state_dict(cp["model"]) 
model.eval().float() model.cuda() diff --git a/examples/wav2vec/wav2vec_featurize.py b/examples/wav2vec/wav2vec_featurize.py index 9283930587..b806316e5a 100644 --- a/examples/wav2vec/wav2vec_featurize.py +++ b/examples/wav2vec/wav2vec_featurize.py @@ -18,7 +18,7 @@ import soundfile as sf import torch import tqdm -from fairseq.models.wav2vec.wav2vec import Wav2VecModel +import fairseq from torch import nn @@ -35,10 +35,8 @@ class PretrainedWav2VecModel(nn.Module): def __init__(self, fname): super().__init__() - checkpoint = torch.load(fname) - self.args = checkpoint["args"] - model = Wav2VecModel.build_model(self.args, None) - model.load_state_dict(checkpoint["model"]) + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname]) + model = model[0] model.eval() self.model = model diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index f03875da50..6209c71aef 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -480,6 +480,15 @@ def _upgrade_state_dict(state): if hasattr(state["args"], "min_lr"): state["args"].stop_min_lr = state["args"].min_lr del state["args"].min_lr + # binary_cross_entropy => wav2vec criterion + if hasattr(state["args"], "criterion") and state["args"].criterion == "binary_cross_entropy": + state["args"].criterion = "wav2vec" + # speech_pretraining => audio pretraining + if hasattr(state["args"], "task") and state["args"].task == "speech_pretraining": + state["args"].task = "audio_pretraining" + # audio_cpc => wav2vec + if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": + state["args"].arch = "wav2vec" state["cfg"] = convert_namespace_to_omegaconf(state["args"]) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index e25838400c..9d52d45942 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -237,6 +237,11 @@ def get_default(f): t_args = v.type.__args__ if len(t_args) == 1: val = list(map(t_args[0], val)) + elif val 
is not None and (field_type is int or field_type is bool or field_type is float): + try: + val = field_type(val) + except: + pass # ignore errors here, they are often from interpolation args if val is None: overrides.append("{}.{}=null".format(sub_node, k)) From 606b3b8c8d7e15dad66b177cde66a04621349e6c Mon Sep 17 00:00:00 2001 From: Peng-Jen Chen Date: Thu, 10 Dec 2020 11:57:14 -0800 Subject: [PATCH 348/707] Release WMT20 MT/LM models Summary: Add README to expose wmt20 model paths to download and torch.hub examples. Reviewed By: ngoyal2707 Differential Revision: D25456298 fbshipit-source-id: 8c78bffb3f539963cbf61e508a56e421929925f0 --- fairseq/models/transformer.py | 13 +++++++++++++ fairseq/models/transformer_lm.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 70920ed779..9655578e52 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -71,6 +71,13 @@ def moses_fastbpe(path): 'bpe': 'fastbpe', } + def spm(path): + return { + 'path': path, + 'bpe': 'sentencepiece', + 'tokenizer': 'space', + } + return { 'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'), 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', @@ -83,6 +90,12 @@ def moses_fastbpe(path): 'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'), 'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'), 'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'), + 'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'), + 'transformer.wmt20.en-iu.news': 
spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'), + 'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'), + 'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'), + 'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'), + 'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'), } # fmt: on diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 35bfa6eb6f..d86b68b508 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -171,6 +171,8 @@ class TransformerLanguageModel(FairseqLanguageModel): def hub_models(cls): def moses_fastbpe(path): return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"} + def spm(path): + return {"path": path, "tokenizer": "space", "bpe": "sentencepiece"} return { "transformer_lm.gbw.adaptive_huge": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2", @@ -184,6 +186,18 @@ def moses_fastbpe(path): "transformer_lm.wmt19.ru": moses_fastbpe( "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2" ), + "transformer_lm.wmt20.en": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.en.tar.gz" + ), + "transformer_lm.wmt20.ta": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.ta.tar.gz" + ), + "transformer_lm.wmt20.iu.news": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.news.tar.gz" + ), + "transformer_lm.wmt20.iu.nh": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.nh.tar.gz" + ), } def __init__(self, decoder): From 43e4db8cd77dab63836a9d6dbb239d27ee2641d7 Mon Sep 17 00:00:00 2001 From: Peng-Jen Chen Date: Fri, 11 Dec 2020 06:22:57 -0800 Subject: [PATCH 349/707] Add default dataclass to space_tokenizer and nltk_tokenizer 
Summary: The default dataclass of `space_tokenizer` is `None`, this somehow breaks when we load a model with space tokenizer from `torch.hub` (full message: P154020599): ``` ... hydra.errors.MissingConfigException: Could not load tokenizer/space. Available options: moses ``` Adding default dataclass to `FairseqDataclass` to solve the problem. Note: another fix might be setting the default dataclass to `FairseqDataclass` here: https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/deeplearning/projects/fairseq-py/fairseq/registry.py?commit=1ae2e67d969ca090b41a6fa33ca9aaf360a26d3f&lines=63 Reviewed By: myleott Differential Revision: D25466594 fbshipit-source-id: aa4ec23731081e266ce641ea3db179169233842c --- fairseq/data/encoders/nltk_tokenizer.py | 3 ++- fairseq/data/encoders/space_tokenizer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py index ee164710a0..0ab92377b3 100644 --- a/fairseq/data/encoders/nltk_tokenizer.py +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. 
from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass -@register_tokenizer("nltk") +@register_tokenizer("nltk", dataclass=FairseqDataclass) class NLTKTokenizer(object): def __init__(self, *unused): try: diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py index 7c7f644d5c..925ad41b7c 100644 --- a/fairseq/data/encoders/space_tokenizer.py +++ b/fairseq/data/encoders/space_tokenizer.py @@ -6,9 +6,10 @@ import re from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass -@register_tokenizer("space") +@register_tokenizer("space", dataclass=FairseqDataclass) class SpaceTokenizer(object): def __init__(self, *unused): self.space_tok = re.compile(r"\s+") From 5430df004ff4a5928e0295f3c8ac0b29132bd6a8 Mon Sep 17 00:00:00 2001 From: Peng-Jen Chen Date: Fri, 11 Dec 2020 08:37:38 -0800 Subject: [PATCH 350/707] Add WMT20 page to fairseq-py/example Summary: Add WMT20 page to fairseq-py/example to release WMT20 models Reviewed By: myleott, ngoyal2707 Differential Revision: D25495578 fbshipit-source-id: 088a457e379b5227cf45a5db1901073499a2e4c1 --- examples/wmt20/README.md | 70 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 examples/wmt20/README.md diff --git a/examples/wmt20/README.md b/examples/wmt20/README.md new file mode 100644 index 0000000000..7bfec0ee55 --- /dev/null +++ b/examples/wmt20/README.md @@ -0,0 +1,70 @@ +# WMT 20 + +This page provides pointers to the models of Facebook-FAIR's WMT'20 news translation task submission [(Chen et al., 2020)](https://arxiv.org/abs/2011.08298). 
+ +## Single best MT models (after finetuning on part of WMT20 news dev set) + +Model | Description | Download +---|---|--- +`transformer.wmt20.ta-en` | Ta->En | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz) +`transformer.wmt20.en-ta` | En->Ta | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz) +`transformer.wmt20.iu-en.news` | Iu->En (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz) +`transformer.wmt20.en-iu.news` | En->Iu (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz) +`transformer.wmt20.iu-en.nh` | Iu->En (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz) +`transformer.wmt20.en-iu.nh` | En->Iu (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz) + +## Language models +`transformer_lm.wmt20.en` | En Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en.tar.gz) +`transformer_lm.wmt20.ta` | Ta Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta.tar.gz) +`transformer_lm.wmt20.iu.news` | Iu Language Model (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu.news.tar.gz) +`transformer_lm.wmt20.iu.nh` | Iu Language Model (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu.nh.tar.gz) + +## Example usage (torch.hub) + +#### Translation + +```python +import torch + +# English to Tamil translation +en2ta = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.en-ta') +en2ta.translate("Machine learning is great!") # 'இயந்திரக் கற்றல் அருமை!' 
+ +# Tamil to English translation +ta2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.ta-en') +ta2en.translate("இயந்திரக் கற்றல் அருமை!") # 'Machine learning is great!' + +# English to Inuktitut translation +en2iu = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.en-iu.news') +en2iu.translate("machine learning is great!") # 'ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ ᐱᐅᔪᒻᒪᕆᒃ!' + +# Inuktitut to English translation +iu2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.iu-en.news') +iu2en.translate("ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ ᐱᐅᔪᒻᒪᕆᒃ!") # 'Machine learning excellence!' +``` + +#### Language Modeling + +```python +# Sample from the English LM +en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.en') +en_lm.sample("Machine learning is") # 'Machine learning is a type of artificial intelligence that uses machine learning to learn from data and make predictions.' + +# Sample from the Tamil LM +ta_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.ta') +ta_lm.sample("இயந்திரக் கற்றல் என்பது செயற்கை நுண்ணறிவின்") # 'இயந்திரக் கற்றல் என்பது செயற்கை நுண்ணறிவின் ஒரு பகுதியாகும்.' + +# Sample from the Inuktitut LM +iu_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.iu.news') +iu_lm.sample("ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ") # 'ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ, ᐊᒻᒪᓗ ᓯᓚᐅᑉ ᐊᓯᙳᖅᐸᓪᓕᐊᓂᖓᓄᑦ ᖃᓄᐃᓕᐅᕈᑎᒃᓴᑦ, ᐃᓚᖃᖅᖢᑎᒃ ᐅᑯᓂᖓ:' +``` + +## Citation +```bibtex +@inproceedings{chen2020facebook + title={Facebook AI's WMT20 News Translation Task Submission}, + author={Peng-Jen Chen and Ann Lee and Changhan Wang and Naman Goyal and Angela Fan and Mary Williamson and Jiatao Gu}, + booktitle={Proc. 
of WMT}, + year={2020}, +} +``` From 3a2c0a2558aaf363c8ecc6967c9ae63f23a58502 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 11 Dec 2020 09:55:24 -0800 Subject: [PATCH 351/707] Back out "Improve performance of distributed_utils.broadcast_object" Summary: Original commit changeset: 521102fae75a Reviewed By: ngoyal2707 Differential Revision: D25494613 fbshipit-source-id: 64aead87ee84f4294fd37f2de12689b909e11ff1 --- fairseq/distributed_utils.py | 126 ++++++----------------------------- 1 file changed, 21 insertions(+), 105 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 8f98ac88f9..dd93cda35c 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -14,7 +14,6 @@ import warnings from argparse import Namespace from collections import OrderedDict -from dataclasses import dataclass from typing import Any, Dict, List, Mapping, Optional import torch @@ -643,127 +642,44 @@ def get_from_stack(key): return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) -def broadcast_tensors( - tensors: Optional[List[torch.Tensor]], - src_rank: int, - group: object, - dist_device: Optional[torch.device] = None, -) -> List[torch.Tensor]: - """ - Broadcasts a list of tensors without other (non-src) ranks needing to know - the dtypes/shapes of the tensors. 
- """ - if dist_device is None: - if torch.distributed.get_backend(group) == "nccl": - dist_device = torch.device("cuda") - else: - dist_device = torch.device("cpu") - - # share metadata first to simplify transfer - is_src_rank = (get_rank(group) == src_rank) - if is_src_rank: - metadata = [ - {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors - ] - metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) - else: - metadata = _broadcast_object_slow(None, src_rank, group, dist_device) - - out_tensors = [] - for i, meta in enumerate(metadata): - if is_src_rank: - tensor = tensors[i] - broadcast(tensors[i].to(dist_device), src=src_rank, group=group) - else: - tensor = torch.zeros( - [meta["size"].numel()], dtype=meta["dtype"], device=dist_device - ) - broadcast(tensor, src=src_rank, group=group) - tensor = tensor.view(meta["size"]).to(meta["device"]) - out_tensors.append(tensor) - return out_tensors - - +# From fairscale/optim/utils.py def broadcast_object( obj: Any, src_rank: int, group: object, dist_device: Optional[torch.device] = None, + dist_length_dtype: Optional[torch.dtype] = torch.long, + dist_dtype: Optional[torch.dtype] = torch.uint8, ) -> Any: - """Broadcast an arbitrary Python object to other workers.""" + """ + Either broadcast from master to the fleet (default), + or use the src setting as the original rank. 
+ """ if dist_device is None: if torch.distributed.get_backend(group) == "nccl": dist_device = torch.device("cuda") else: dist_device = torch.device("cpu") - if get_rank(group) == src_rank: - # split the tensors from the non-tensors so we can broadcast them - # directly, avoiding unnecessary serialization/deserialization - tensors = [] - obj = _split_tensors_from_obj(obj, tensors) - obj = _broadcast_object_slow(obj, src_rank, group, dist_device) - tensors = broadcast_tensors(tensors, src_rank, group, dist_device) - else: - obj = _broadcast_object_slow(None, src_rank, group, dist_device) - tensors = broadcast_tensors(None, src_rank, group, dist_device) - return _put_tensors_in_obj(obj, tensors) - - -def _broadcast_object_slow( - obj: Any, src_rank: int, group: object, dist_device: torch.device, -) -> Any: if get_rank(group) == src_rank: # Emit data buffer = io.BytesIO() torch.save(obj, buffer) - buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device) - length = torch.LongTensor([len(buffer)]).to(dist_device) - broadcast(length, src=src_rank, group=group) - broadcast(buffer, src=src_rank, group=group) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.tensor( + [len(data)], dtype=dist_length_dtype, device=dist_device + ) + broadcast(length_tensor, src=src_rank, group=group) + data_send_tensor = torch.tensor(data, dtype=dist_dtype, device=dist_device) + broadcast(data_send_tensor, src=src_rank, group=group) else: # Fetch from the source - length = torch.LongTensor([0]).to(dist_device) - broadcast(length, src=src_rank, group=group) - buffer = torch.ByteTensor(int(length.item())).to(dist_device) - broadcast(buffer, src=src_rank, group=group) - buffer = io.BytesIO(buffer.cpu().numpy()) + length_tensor = torch.tensor([0], dtype=dist_length_dtype, device=dist_device) + broadcast(length_tensor, src=src_rank, group=group) + data_recv_tensor = torch.zeros( + [int(length_tensor.item())], dtype=dist_dtype, device=dist_device + ) + broadcast(data_recv_tensor, 
src=src_rank, group=group) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) obj = torch.load(buffer, map_location="cpu") return obj - - -@dataclass(frozen=True) -class _TensorPlaceholder: - index: int - - -def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: - if torch.is_tensor(obj): - placeholder = _TensorPlaceholder(index=len(tensors)) - tensors.append(obj) - return placeholder - elif isinstance(obj, dict): - return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} - elif isinstance(obj, list): - return [_split_tensors_from_obj(v, tensors) for v in obj] - elif isinstance(obj, tuple): - return tuple(_split_tensors_from_obj(v, tensors) for v in obj) - elif isinstance(obj, set): - return {_split_tensors_from_obj(v, tensors) for v in obj} - else: - return obj - - -def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: - if isinstance(obj, _TensorPlaceholder): - return tensors[obj.index] - elif isinstance(obj, dict): - return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} - elif isinstance(obj, list): - return [_put_tensors_in_obj(v, tensors) for v in obj] - elif isinstance(obj, tuple): - return tuple(_put_tensors_in_obj(v, tensors) for v in obj) - elif isinstance(obj, set): - return {_put_tensors_in_obj(v, tensors) for v in obj} - else: - return obj From e8b195ac069600203da3e7d60ba29d0975dd0afd Mon Sep 17 00:00:00 2001 From: louismartin Date: Fri, 11 Dec 2020 10:18:23 -0800 Subject: [PATCH 352/707] Use deepcopy for copying cfg #3011 (#3022) Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3011. 
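For context, a minimal sketch of the general hazard that a deep copy avoids. It uses plain nested dicts rather than the OmegaConf config object touched by this patch, and the `override` helper and the `options` key are hypothetical names used only for illustration. The linked issue itself is not reproduced here.

```python
import copy


def override(cfg, copier, **kwargs):
    """Hypothetical helper: copy cfg with `copier`, then apply overrides to the copy."""
    local = copier(cfg)
    local["options"].update(kwargs)
    return local


# Shallow copy: the nested dict is shared, so the override leaks into the original.
base = {"beam": 5, "options": {"sampling": False}}
override(base, copy.copy, sampling=True)
leaked = base["options"]["sampling"]

# Deep copy: the nested dict is duplicated, so the original stays untouched.
base = {"beam": 5, "options": {"sampling": False}}
override(base, copy.deepcopy, sampling=True)
isolated = base["options"]["sampling"]
```

With the shallow copy, per-call overrides can persist into later calls that start from the same base config. The deep copy keeps each call independent, at the cost of copying the whole nested structure.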
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3022 Reviewed By: myleott Differential Revision: D25495116 Pulled By: louismartin fbshipit-source-id: bcd3bc04b92f083882dfbc9110b14bb2ac7c8ce0 --- fairseq/hub_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index b716884c78..775d1f7aeb 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -157,7 +157,7 @@ def generate( )[0] # build generator using current args as well as any kwargs - gen_args = copy.copy(self.cfg.generation) + gen_args = copy.deepcopy(self.cfg.generation) with open_dict(gen_args): gen_args.beam = beam for k, v in kwargs.items(): From 13dbdf279887012ecf1bf955a114b952c8d00927 Mon Sep 17 00:00:00 2001 From: Robert Verkuil Date: Fri, 11 Dec 2020 11:44:25 -0800 Subject: [PATCH 353/707] Add S3 PathHandler to Fairseq (#1441) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ## Description This PathHandler supports saving and loading from Amazon S3. For internal FB use. Intended for the use case of infrequent loading / saving (e.g. normal checkpoint reading and writing during training or when running an analysis script.) The PathHandler assumes the full contents of files to be saved / loaded can be stored temporarily in memory. Further optimizations could support streaming for handling large files. ## TODO: ~~- [ ] Switch to pid-based S3 client cache, to avoid multithreading problems.~~ - [x] Complete the test cases in test/file_io.py - [x] Unit tests complete - [x] Integration tested with the following command, to check for ability to save, resume from saved checkpoint, and re-download a file if local copy is stale. 
```bash pyscrun fairseq-py/fairseq_cli/train.py \ /checkpoint/bioseq_nonsecure/rverkuil/aws/data \ --restore-file checkpoint_last.pt \ --dataset-impl fasta \ --task masked_lm --criterion masked_lm \ --arch roberta --dropout 0.0 --attention-dropout 0.0 \ --encoder-layers 2 \ --optimizer adam --lr 0.0001 \ --max-sentences 24 \ --log-format simple --log-interval 1 \ --fp16 \ --max-epoch 20 \ --max-update 20 \ --save-dir s3://fairusersglobal// ``` # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1441 Reviewed By: myleott Differential Revision: D25370388 Pulled By: robert-verkuil fbshipit-source-id: ffefbd7345f8d8ae72513f1cead94469a17c4459 --- fairseq/file_io.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index d667256922..d74a48591a 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -5,14 +5,29 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import logging import os import shutil from typing import List, Optional +logger = logging.getLogger(__file__) + + try: from fvcore.common.file_io import PathManager as FVCorePathManager + try: + # [FB only - for now] AWS PathHandler for PathManager + from .fb_pathhandlers import S3PathHandler + + FVCorePathManager.register_handler(S3PathHandler()) + except KeyError: + logging.warning("S3PathHandler already registered.") + except ImportError: + logging.debug( + "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module.") + except ImportError: FVCorePathManager = None From fc7a787c01cca883330fc16d22b4693fb8277df8 Mon Sep 17 00:00:00 2001 From: Peng-Jen Chen Date: Fri, 11 Dec 2020 12:17:28 -0800 Subject: [PATCH 354/707] Update WMT20 README.md (#3027) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Fix the format # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3027 Reviewed By: ngoyal2707 Differential Revision: D25500155 Pulled By: pipibjc fbshipit-source-id: bf6298b0a4a1942e6e0e8aa632e0af7fd5c516a0 --- examples/wmt20/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/wmt20/README.md b/examples/wmt20/README.md index 7bfec0ee55..b4f2874652 100644 --- a/examples/wmt20/README.md +++ b/examples/wmt20/README.md @@ -14,6 +14,8 @@ Model | Description | Download `transformer.wmt20.en-iu.nh` | En->Iu (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz) ## Language models +Model | Description | Download +---|---|--- `transformer_lm.wmt20.en` | En Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en.tar.gz) `transformer_lm.wmt20.ta` | Ta Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta.tar.gz) `transformer_lm.wmt20.iu.news` | Iu Language Model (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu.news.tar.gz) From d4788a9f7082c9b4a1bbe81c6e2c898e7e16ceb9 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 11 Dec 2020 15:20:10 -0800 Subject: [PATCH 355/707] wandb QOL: infer run name from environment or save_dir (#1500) Summary: As documented, the name of a run can be controlled using either `wandb.init(name)` or the environment variable `WANDB_NAME`. We set ```python wandb_run_name=os.environ.get("WANDB_NAME", cfg.checkpoint.save_dir) ``` to preserve the environment variable functionality documented [here](https://docs.wandb.com/library/environment-variables), and choose a sane default if this is not specified. 
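A minimal sketch of the fallback described above, written as a standalone function so it can be tried without fairseq. The name `infer_wandb_run_name` and the `environ` parameter are hypothetical. The sketch takes the basename of the save directory, as the `fairseq_cli/train.py` hunk in this patch does, rather than the full path.

```python
import os


def infer_wandb_run_name(save_dir, environ=None):
    """Hypothetical helper: prefer WANDB_NAME, else fall back to the save_dir basename."""
    # `environ` is injectable only to make the sketch easy to test.
    environ = os.environ if environ is None else environ
    return environ.get("WANDB_NAME", os.path.basename(save_dir))


# No WANDB_NAME set: the run is named after the checkpoint directory.
name_from_dir = infer_wandb_run_name("checkpoints/my_run", environ={})

# WANDB_NAME set: the environment variable wins.
name_from_env = infer_wandb_run_name("checkpoints/my_run", environ={"WANDB_NAME": "dummy2"})
```

One caveat of this sketch: `os.path.basename` returns an empty string for a path with a trailing slash, so a save directory written as `checkpoints/my_run/` would give an empty fallback name unless the path is normalized first.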
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1500 Test Plan: ```bash WANDB_NAME=dummy2 CUDA_VISIBLE_DEVICES=0 python train.py --task dummy_lm --arch \ transformer_lm_gpt2_small --tokens-per-sample 512 --max-sentences 2 \ --lr 0.0001 --log-format simple --log-interval 1 --optimizer adam \ --fp16 --no-save --disable-validation --max-update 10 \ --restore-file x.pt --wandb-project dummy-lm ``` Reviewed By: myleott Differential Revision: D25497735 Pulled By: sshleifer fbshipit-source-id: fcd4e2a3263444e5759fae98641963fd3b9f6914 --- fairseq/logging/progress_bar.py | 15 ++++++++------- fairseq_cli/train.py | 14 ++++++++++++-- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 2b3873794e..07ee26f4fc 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -34,6 +34,7 @@ def progress_bar( tensorboard_logdir: Optional[str] = None, default_log_format: str = "tqdm", wandb_project: Optional[str] = None, + wandb_run_name: Optional[str] = None, ): if log_format is None: log_format = default_log_format @@ -62,7 +63,7 @@ def progress_bar( bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) if wandb_project: - bar = WandBProgressBarWrapper(bar, wandb_project) + bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name) return bar @@ -370,15 +371,15 @@ def _log_to_tensorboard(self, stats, tag=None, step=None): class WandBProgressBarWrapper(BaseProgressBar): """Log to Weights & Biases.""" - def __init__(self, wrapped_bar, wandb_project): + def __init__(self, wrapped_bar, wandb_project, run_name=None): self.wrapped_bar = wrapped_bar if wandb is None: - logger.warning('wandb not found, pip install wandb') + logger.warning("wandb not found, pip install wandb") return # reinit=False to ensure if wandb.init() is called multiple times # within one process it still references the same run - wandb.init(project=wandb_project, 
reinit=False) + wandb.init(project=wandb_project, reinit=False, name=run_name) def __iter__(self): return iter(self.wrapped_bar) @@ -397,11 +398,11 @@ def _log_to_wandb(self, stats, tag=None, step=None): if wandb is None: return if step is None: - step = stats['num_updates'] + step = stats["num_updates"] - prefix = '' if tag is None else tag + '/' + prefix = "" if tag is None else tag + "/" - for key in stats.keys() - {'num_updates'}: + for key in stats.keys() - {"num_updates"}: if isinstance(stats[key], AverageMeter): wandb.log({prefix + key: stats[key].val}, step=step) elif isinstance(stats[key], Number): diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 82c30321eb..11baf5a59b 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -211,7 +211,12 @@ def train( ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( - cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), ) @@ -352,7 +357,12 @@ def validate( ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( - cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), ) From 39e722ceabff11db00d9dd66998039236b40ae50 Mon Sep 17 00:00:00 2001 From: Davide Caroselli Date: Fri, 11 Dec 2020 18:54:40 -0800 Subject: [PATCH 356/707] Fix #3017: restore support for user modules in zip/jar files (#3018) Summary: ## What does this PR do? 
Fixes https://github.com/pytorch/fairseq/issues/3017 My original implementation of the function utils.import_user_module supported external user modules wrapped in zip/jar files (Python natively supports modules in zip/jar files). The new piece of code that checks for file existence breaks this functionality. This trivial fix can solve the problem! Pull Request resolved: https://github.com/pytorch/fairseq/pull/3018 Reviewed By: alexeib Differential Revision: D25485611 Pulled By: myleott fbshipit-source-id: 21667cd87e9e7a99095f8ad21d7b3bfdb547a993 --- fairseq/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/utils.py b/fairseq/utils.py index 4046f6696c..1b2e060650 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -437,7 +437,7 @@ def import_user_module(args): module_path = getattr(args, "user_dir", None) if module_path is not None: module_path = os.path.abspath(args.user_dir) - if not os.path.exists(module_path): + if not os.path.exists(module_path) and not os.path.isfile(os.path.dirname(module_path)): fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) if os.path.exists(fairseq_rel_path): module_path = fairseq_rel_path From 032a404d389307a0e8f7dd2a0d501c78afa78f39 Mon Sep 17 00:00:00 2001 From: Hiroyuki Deguchi Date: Fri, 11 Dec 2020 18:59:50 -0800 Subject: [PATCH 357/707] Add "soft" argument of "--print-alignment" (#2985) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: If the argument is set to "soft", the attention probability for each source token is printed, like this: A-0 0.365083,0.328207,0.306710 0.442428,0.340282,0.217290 0.378712,0.367315,0.253973 0.321335,0.425601,0.253064 Source tokens are separated by commas (,) and target tokens are separated by spaces ( ). This option is based on Marian NMT's option. # Before submitting - [ ] Was this discussed/approved via a Github issue?
(no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2985 Reviewed By: alexeib Differential Revision: D25344394 Pulled By: myleott fbshipit-source-id: 659eb8f7af1ccdafacaaa91ce5ddf5d71cb3e775 --- fairseq/dataclass/configs.py | 9 ++++++--- fairseq/dataclass/constants.py | 1 + fairseq/sequence_generator.py | 9 +++++++-- fairseq/tasks/fairseq_task.py | 4 +++- fairseq/utils.py | 17 +++++++++++++++++ fairseq_cli/generate.py | 15 ++++++++++++++- 6 files changed, 48 insertions(+), 7 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 3992e3c2d5..1a89560072 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -17,6 +17,7 @@ GENERATION_DECODING_FORMAT_CHOICES, LOG_FORMAT_CHOICES, PIPELINE_CHECKPOINT_CHOICES, + PRINT_ALIGNMENT_CHOICES, ZERO_SHARDING_CHOICES, ) @@ -737,10 +738,12 @@ class GenerationConfig(FairseqDataclass): default=-1.0, metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, ) - print_alignment: bool = field( - default=False, + print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field( + default=None, metadata={ - "help": "if set, uses attention feedback to compute and print alignment to source tokens" + "help": "if set, uses attention feedback to compute and print alignment to source tokens " + "(valid options are: hard, soft, otherwise treated as hard alignment)", + "argparse_const": "hard", }, ) print_step: bool = field( diff --git a/fairseq/dataclass/constants.py
b/fairseq/dataclass/constants.py index 858f77a863..46881786a8 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -44,3 +44,4 @@ def ChoiceEnum(choices: List[str]): ) ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) +PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"]) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index afc1500090..bd46f9e5b9 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -903,7 +903,7 @@ def reorder_incremental_state( class SequenceGeneratorWithAlignment(SequenceGenerator): - def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs): + def __init__(self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs): """Generates translations of a given source sentence. Produces alignments following "Jointly Learning to Align and @@ -917,6 +917,11 @@ def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs): super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) self.left_pad_target = left_pad_target + if print_alignment == "hard": + self.extract_alignment = utils.extract_hard_alignment + elif print_alignment == "soft": + self.extract_alignment = utils.extract_soft_alignment + @torch.no_grad() def generate(self, models, sample, **kwargs): finalized = super()._generate(sample, **kwargs) @@ -945,7 +950,7 @@ def generate(self, models, sample, **kwargs): # Process the attn matrix to extract hard alignments. 
for i in range(bsz * beam_size): - alignment = utils.extract_hard_alignment( + alignment = self.extract_alignment( attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos ) finalized[i // beam_size][i % beam_size]["alignment"] = alignment diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index c9b7477ae7..b99c511990 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -376,12 +376,14 @@ def build_generator( else: search_strategy = search.BeamSearch(self.target_dictionary) + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} if seq_gen_cls is None: if getattr(args, "print_alignment", False): seq_gen_cls = SequenceGeneratorWithAlignment + extra_gen_cls_kwargs['print_alignment'] = args.print_alignment else: seq_gen_cls = SequenceGenerator - extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + return seq_gen_cls( models, self.target_dictionary, diff --git a/fairseq/utils.py b/fairseq/utils.py index 1b2e060650..a20c83384c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -631,6 +631,23 @@ def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): return alignment +def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ( + ((tgt_sent != pad)).nonzero(as_tuple=False) + ) + src_valid = ( + ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + alignment = [] + if len(tgt_valid) != 0 and len(src_valid) != 0: + attn_valid = attn[tgt_valid, src_valid] + alignment = [ + ["{:.6f}".format(p) for p in src_probs.tolist()] + for src_probs in attn_valid + ] + return alignment + + def new_arange(x, *size): """ Return a Tensor of `size` filled with a range function on the device of x. 
diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 6be8150cda..4aeb4a56fa 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -299,7 +299,7 @@ def decode_fn(x): file=output_file, ) - if cfg.generation.print_alignment: + if cfg.generation.print_alignment == "hard": print( "A-{}\t{}".format( sample_id, @@ -312,6 +312,19 @@ def decode_fn(x): ), file=output_file, ) + if cfg.generation.print_alignment == "soft": + print( + "A-{}\t{}".format( + sample_id, + " ".join( + [ + ",".join(src_probs) + for src_probs in alignment + ] + ), + ), + file=output_file, + ) if cfg.generation.print_step: print( From f3d5045a71ae463bd3f05254d7c4216801a04bc2 Mon Sep 17 00:00:00 2001 From: Raphael Scheible Date: Fri, 11 Dec 2020 19:09:41 -0800 Subject: [PATCH 358/707] add German RoBERTa model (GottBERT) (#2992) Summary: # Before submitting - There is no related issue for this pull request. - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - We did not see any necessity for tests. ## What does this PR do? 
Add German RoBERTa model (GottBERT) Pull Request resolved: https://github.com/pytorch/fairseq/pull/2992 Reviewed By: alexeib Differential Revision: D25494927 Pulled By: myleott fbshipit-source-id: b6790124d7c3c8dc387c141706cd8a527cc950ab --- README.md | 1 + examples/gottbert/README.md | 64 ++++++++++++++++++++++++ examples/roberta/README.md | 1 + fairseq/data/encoders/hf_byte_bpe.py | 8 ++- fairseq/hub_utils.py | 2 + fairseq/models/roberta/__init__.py | 1 + fairseq/models/roberta/model_gottbert.py | 49 ++++++++++++++++++ 7 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 examples/gottbert/README.md create mode 100644 fairseq/models/roberta/model_gottbert.py diff --git a/README.md b/README.md index 3ae332b350..9cc5b7a559 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* December 2020: [GottBERT model and code released](examples/gottbert/README.md) * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) diff --git a/examples/gottbert/README.md b/examples/gottbert/README.md new file mode 100644 index 0000000000..1d58feb279 --- /dev/null +++ b/examples/gottbert/README.md @@ -0,0 +1,64 @@ +# GottBERT: a pure German language model + +## Introduction + +[GottBERT](http://arxiv.org/abs/2012.02110) is a pretrained language model trained on 145GB of German text based on RoBERTa. 
+ +## Example usage + +### fairseq +##### Load GottBERT from torch.hub (PyTorch >= 1.1): +```python +import torch +gottbert = torch.hub.load('pytorch/fairseq', 'gottbert-base') +gottbert.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Load GottBERT (for PyTorch 1.0 or custom models): +```python +# Download gottbert model +wget https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz +tar -xzvf gottbert-base.tar.gz + +# Load the model in fairseq +from fairseq.models.roberta import GottbertModel +gottbert = GottbertModel.from_pretrained('/path/to/gottbert') +gottbert.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Filling masks: +```python +masked_line = 'Gott ist <mask> ! :)' +gottbert.fill_mask(masked_line, topk=3) +# [('Gott ist gut ! :)', 0.3642110526561737, ' gut'), +# ('Gott ist überall ! :)', 0.06009674072265625, ' überall'), +# ('Gott ist großartig ! :)', 0.0370681993663311, ' großartig')] +``` + +##### Extract features from GottBERT + +```python +# Extract the last layer's features +line = "Der erste Schluck aus dem Becher der Naturwissenschaft macht atheistisch , aber auf dem Grunde des Bechers wartet Gott !"
+tokens = gottbert.encode(line) +last_layer_features = gottbert.extract_features(tokens) +assert last_layer_features.size() == torch.Size([1, 27, 768]) + +# Extract all layer's features (layer 0 is the embedding layer) +all_layers = gottbert.extract_features(tokens, return_all_hiddens=True) +assert len(all_layers) == 13 +assert torch.all(all_layers[-1] == last_layer_features) +``` +## Citation +If you use our work, please cite: + +```bibtex +@misc{scheible2020gottbert, + title={GottBERT: a pure German Language Model}, + author={Raphael Scheible and Fabian Thomczyk and Patric Tippmann and Victor Jaravine and Martin Boeker}, + year={2020}, + eprint={2012.02110}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/examples/roberta/README.md b/examples/roberta/README.md index ca86131eea..58091b2c7d 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -8,6 +8,7 @@ RoBERTa iterates on BERT's pretraining procedure, including training the model l ### What's New: +- December 2020: German model (GottBERT) is available: [GottBERT](https://github.com/pytorch/fairseq/tree/master/examples/gottbert). - January 2020: Italian model (UmBERTo) is available from Musixmatch Research: [UmBERTo](https://github.com/musixmatchresearch/umberto). - November 2019: French model (CamemBERT) is available: [CamemBERT](https://github.com/pytorch/fairseq/tree/master/examples/camembert). - November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/tree/master/examples/xlmr). 
diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py index 92d2c3922c..c508578d41 100644 --- a/fairseq/data/encoders/hf_byte_bpe.py +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -7,6 +7,7 @@ from fairseq.data.encoders import register_bpe from fairseq.dataclass import FairseqDataclass +from fairseq import file_utils @dataclass @@ -28,9 +29,12 @@ def __init__(self, cfg): "Please install huggingface/tokenizers with: " "pip install tokenizers" ) + bpe_vocab = file_utils.cached_path(cfg.bpe_vocab) + bpe_merges = file_utils.cached_path(cfg.bpe_merges) + self.bpe = ByteLevelBPETokenizer( - cfg.bpe_vocab, - cfg.bpe_merges, + bpe_vocab, + bpe_merges, add_prefix_space=cfg.bpe_add_prefix_space, ) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 775d1f7aeb..7de2e2b0d4 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -60,6 +60,8 @@ def from_pretrained( "code": "bpe_codes", "bpecodes": "bpe_codes", "sentencepiece.bpe.model": "sentencepiece_model", + "merges.txt": "bpe_merges", + "vocab.json": "bpe_vocab", }.items(): path = os.path.join(model_path, file) if os.path.exists(path): diff --git a/fairseq/models/roberta/__init__.py b/fairseq/models/roberta/__init__.py index 56579e5915..cf16914fbc 100644 --- a/fairseq/models/roberta/__init__.py +++ b/fairseq/models/roberta/__init__.py @@ -6,4 +6,5 @@ from .hub_interface import * # noqa from .model import * # noqa from .model_camembert import * # noqa +from .model_gottbert import * # noqa from .model_xlmr import * # noqa diff --git a/fairseq/models/roberta/model_gottbert.py b/fairseq/models/roberta/model_gottbert.py new file mode 100644 index 0000000000..2e8c66354a --- /dev/null +++ b/fairseq/models/roberta/model_gottbert.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+""" +GottBERT: a pure German Language Model +""" + +from fairseq.models import register_model + +from .hub_interface import RobertaHubInterface +from .model import RobertaModel + + +@register_model('gottbert') +class GottbertModel(RobertaModel): + + @classmethod + def hub_models(cls): + return { + 'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz', + } + + @classmethod + def from_pretrained(cls, + model_name_or_path, + checkpoint_file='model.pt', + data_name_or_path='.', + bpe='hf_byte_bpe', + bpe_vocab='vocab.json', + bpe_merges='merges.txt', + bpe_add_prefix_space=False, + **kwargs + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + bpe_vocab=bpe_vocab, + bpe_merges=bpe_merges, + bpe_add_prefix_space=bpe_add_prefix_space, + **kwargs, + ) + return RobertaHubInterface(x['args'], x['task'], x['models'][0]) From 881e9f8920bc3d9c3b526f6e52405f7059c926b4 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 12 Dec 2020 07:18:45 -0800 Subject: [PATCH 359/707] Fix bug in FP16 training (#1503) Summary: Fix a critical bug in FP16 training (introduced on Dec 4: https://github.com/pytorch/fairseq/commit/ba4f54267af5c3f67f1b76a6e804b6ab593d1d39) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1503 Reviewed By: donhusa, daniellepintz Differential Revision: D25514055 Pulled By: myleott fbshipit-source-id: 38ebb1f41f365702ce7706846085c7c7cc24a98e --- fairseq/optim/fairseq_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 41c859355c..a1c1d219a0 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -112,7 +112,6 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): def step(self, closure=None, scale=1.0, groups=None): """Performs a single 
optimization step.""" if self.supports_step_with_scale: - self.optimizer.step(closure, scale=scale) if self.supports_groups: self.optimizer.step(closure, scale=scale, groups=groups) else: From ac11107ed41cb06a758af850373c239309d1c961 Mon Sep 17 00:00:00 2001 From: Colin Clement Date: Sat, 12 Dec 2020 08:01:56 -0800 Subject: [PATCH 360/707] Azure ML Logging to view training/validation progress in AzureML workspace (#2999) Summary: # Before submitting - [no] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [yes] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [yes, the CLI flag has a help string] Did you make sure to update the docs? - [no, but the code was successfully tested in training] Did you write any new necessary tests? ## What does this PR do? Adds a CLI flag `--azureml-logging` to `fairseq-train` which allows fairseq to log to the default Azure ML context to improve training on Azure Machine Learning services. If `azureml-core` is not installed, it fails with a logging message like the `wandb` logging. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Always! 
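As a rough illustration of what `--azureml-logging` ends up recording, the sketch below flattens a stats dict into `log_row`-style calls, one row per scalar. `RecordingRun` and `log_stats` are hypothetical names used only here; the stub stands in for the object returned by `Run.get_context()` so that the example runs without `azureml-core` installed. Unlike the patch, it skips `AverageMeter` values and sorts the keys to get a deterministic order.

```python
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple


class RecordingRun:
    """Hypothetical stand-in for an Azure ML run; it only records log_row calls."""

    def __init__(self) -> None:
        self.rows: List[Tuple[str, Dict[str, Any]]] = []

    def log_row(self, name: str, **kwargs: Any) -> None:
        self.rows.append((name, kwargs))


def log_stats(
    run: RecordingRun,
    stats: Dict[str, Any],
    tag: Optional[str] = None,
    step: Optional[int] = None,
) -> None:
    """Log every numeric stat as one row, keyed by an optional tag prefix."""
    if step is None:
        # Same default as the patch: use the update counter as the step.
        step = stats["num_updates"]
    prefix = "" if tag is None else tag + "/"
    # Sorting is an addition for reproducibility; the patch iterates an unordered set.
    for key in sorted(stats.keys() - {"num_updates"}):
        value = stats[key]
        if isinstance(value, Number):
            run.log_row(name=prefix + key, **{"step": step, key: value})


run = RecordingRun()
log_stats(run, {"num_updates": 10, "loss": 2.5, "note": "not a number"}, tag="train")
```

Non-numeric entries such as `note` are silently dropped, which matches the patch's behaviour of logging only meters and plain numbers.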
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2999 Reviewed By: alexeib Differential Revision: D25494986 Pulled By: myleott fbshipit-source-id: fadd7569aeb72b5b6f9db0508cbec3a138c332d3 --- fairseq/dataclass/configs.py | 6 ++++ fairseq/logging/progress_bar.py | 54 +++++++++++++++++++++++++++++++++ fairseq_cli/train.py | 3 ++ 3 files changed, 63 insertions(+) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 1a89560072..7d27dc0d4b 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -109,6 +109,12 @@ class CommonConfig(FairseqDataclass): "help": "Weights and Biases project name to use for logging" }, ) + azureml_logging: Optional[bool] = field( + default=False, + metadata={ + "help": "Log scalars to AzureML context" + }, + ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} ) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 07ee26f4fc..e2a1711121 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -35,6 +35,7 @@ def progress_bar( default_log_format: str = "tqdm", wandb_project: Optional[str] = None, wandb_run_name: Optional[str] = None, + azureml_logging: Optional[bool] = False, ): if log_format is None: log_format = default_log_format @@ -65,6 +66,9 @@ def progress_bar( if wandb_project: bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name) + if azureml_logging: + bar = AzureMLProgressBarWrapper(bar) + return bar @@ -407,3 +411,53 @@ def _log_to_wandb(self, stats, tag=None, step=None): wandb.log({prefix + key: stats[key].val}, step=step) elif isinstance(stats[key], Number): wandb.log({prefix + key: stats[key]}, step=step) + + +try: + from azureml.core import Run +except ImportError: + Run = None + + +class AzureMLProgressBarWrapper(BaseProgressBar): + """Log to Azure ML""" + + def __init__(self, wrapped_bar): + self.wrapped_bar = wrapped_bar + if Run is None: + 
logger.warning("azureml.core not found, pip install azureml-core") + return + self.run = Run.get_context() + + def __exit__(self, *exc): + if Run is not None: + self.run.complete() + return False + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to AzureML""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def _log_to_azureml(self, stats, tag=None, step=None): + if Run is None: + return + if step is None: + step = stats['num_updates'] + + prefix = '' if tag is None else tag + '/' + + for key in stats.keys() - {'num_updates'}: + name = prefix + key + if isinstance(stats[key], AverageMeter): + self.run.log_row(name=name, **{'step': step, key: stats[key].val}) + elif isinstance(stats[key], Number): + self.run.log_row(name=name, **{'step': step, key: stats[key]}) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 11baf5a59b..81ecc80140 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -218,6 +218,9 @@ def train( wandb_run_name=os.environ.get( "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), + azureml_logging=( + cfg.common.azureml_logging if distributed_utils.is_master(cfg.distributed_training) else False + ), ) trainer.begin_epoch(epoch_itr.epoch) From f49bb2c4d165d7c134e0dadc70c6fcda4ccd5e26 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 14 Dec 2020 10:45:36 -0800 Subject: [PATCH 361/707] Improve performance of distributed_utils.broadcast_object Summary: Recommit D25291594 (https://github.com/pytorch/fairseq/commit/bb039fa2063dca1b388d6be2f64052b07fb556a2) Original commit changeset: 64aead87ee84 Reviewed By: arendu Differential Revision: D25518027 fbshipit-source-id: 5da303f835eb4b598728c7c7ef7c07538ea3b9b4 --- 
fairseq/distributed_utils.py | 126 +++++++++++++++++++++++++++++------ 1 file changed, 105 insertions(+), 21 deletions(-) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index dd93cda35c..8f98ac88f9 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -14,6 +14,7 @@ import warnings from argparse import Namespace from collections import OrderedDict +from dataclasses import dataclass from typing import Any, Dict, List, Mapping, Optional import torch @@ -642,44 +643,127 @@ def get_from_stack(key): return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) -# From fairscale/optim/utils.py +def broadcast_tensors( + tensors: Optional[List[torch.Tensor]], + src_rank: int, + group: object, + dist_device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + """ + Broadcasts a list of tensors without other (non-src) ranks needing to know + the dtypes/shapes of the tensors. + """ + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + # share metadata first to simplify transfer + is_src_rank = (get_rank(group) == src_rank) + if is_src_rank: + metadata = [ + {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors + ] + metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) + else: + metadata = _broadcast_object_slow(None, src_rank, group, dist_device) + + out_tensors = [] + for i, meta in enumerate(metadata): + if is_src_rank: + tensor = tensors[i] + broadcast(tensors[i].to(dist_device), src=src_rank, group=group) + else: + tensor = torch.zeros( + [meta["size"].numel()], dtype=meta["dtype"], device=dist_device + ) + broadcast(tensor, src=src_rank, group=group) + tensor = tensor.view(meta["size"]).to(meta["device"]) + out_tensors.append(tensor) + return out_tensors + + def broadcast_object( obj: Any, src_rank: int, group: object, dist_device: Optional[torch.device] = 
None, - dist_length_dtype: Optional[torch.dtype] = torch.long, - dist_dtype: Optional[torch.dtype] = torch.uint8, ) -> Any: - """ - Either broadcast from master to the fleet (default), - or use the src setting as the original rank. - """ + """Broadcast an arbitrary Python object to other workers.""" if dist_device is None: if torch.distributed.get_backend(group) == "nccl": dist_device = torch.device("cuda") else: dist_device = torch.device("cpu") + if get_rank(group) == src_rank: + # split the tensors from the non-tensors so we can broadcast them + # directly, avoiding unnecessary serialization/deserialization + tensors = [] + obj = _split_tensors_from_obj(obj, tensors) + obj = _broadcast_object_slow(obj, src_rank, group, dist_device) + tensors = broadcast_tensors(tensors, src_rank, group, dist_device) + else: + obj = _broadcast_object_slow(None, src_rank, group, dist_device) + tensors = broadcast_tensors(None, src_rank, group, dist_device) + return _put_tensors_in_obj(obj, tensors) + + +def _broadcast_object_slow( + obj: Any, src_rank: int, group: object, dist_device: torch.device, +) -> Any: if get_rank(group) == src_rank: # Emit data buffer = io.BytesIO() torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor( - [len(data)], dtype=dist_length_dtype, device=dist_device - ) - broadcast(length_tensor, src=src_rank, group=group) - data_send_tensor = torch.tensor(data, dtype=dist_dtype, device=dist_device) - broadcast(data_send_tensor, src=src_rank, group=group) + buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device) + length = torch.LongTensor([len(buffer)]).to(dist_device) + broadcast(length, src=src_rank, group=group) + broadcast(buffer, src=src_rank, group=group) else: # Fetch from the source - length_tensor = torch.tensor([0], dtype=dist_length_dtype, device=dist_device) - broadcast(length_tensor, src=src_rank, group=group) - data_recv_tensor = torch.zeros( - [int(length_tensor.item())], dtype=dist_dtype, 
device=dist_device - ) - broadcast(data_recv_tensor, src=src_rank, group=group) - buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + length = torch.LongTensor([0]).to(dist_device) + broadcast(length, src=src_rank, group=group) + buffer = torch.ByteTensor(int(length.item())).to(dist_device) + broadcast(buffer, src=src_rank, group=group) + buffer = io.BytesIO(buffer.cpu().numpy()) obj = torch.load(buffer, map_location="cpu") return obj + + +@dataclass(frozen=True) +class _TensorPlaceholder: + index: int + + +def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if torch.is_tensor(obj): + placeholder = _TensorPlaceholder(index=len(tensors)) + tensors.append(obj) + return placeholder + elif isinstance(obj, dict): + return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_split_tensors_from_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_split_tensors_from_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_split_tensors_from_obj(v, tensors) for v in obj} + else: + return obj + + +def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if isinstance(obj, _TensorPlaceholder): + return tensors[obj.index] + elif isinstance(obj, dict): + return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_put_tensors_in_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_put_tensors_in_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_put_tensors_in_obj(v, tensors) for v in obj} + else: + return obj From d740093bac8a6115b68ab5de9a6b63099a07497b Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Mon, 14 Dec 2020 11:34:21 -0800 Subject: [PATCH 362/707] Porting adaptive span to fairseq (#1428) Summary: ## What does this PR do? 1. 
We add an enwik8 character-level LM task sweep for Transformer-XL, which lands at 1.05 and matches the performance (1.06) of the original implementation: https://github.com/kimiyoung/transformer-xl/tree/master/pytorch Eval with ``` PYTHONPATH=. python fairseq_cli/eval_lm.py /private/home/daju/data/enwik8/eos-data-bin/ --path /checkpoint/daju/2020-11-19/enwiki8.transformer_xl.fp16.transformer_xl.adam.cl0.25.cosine.lr0.00025.s2.ngpu4/checkpoint_best.pt --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 80 --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' ``` 2. Implements adaptive span in fairseq. It reproduces the enwik8 result at 1.03, compared to 1.02 (for the 12-layer model) reported in https://github.com/facebookresearch/adaptive-span, which is a consistent improvement over the Transformer-XL baseline listed above with a smaller model. You can evaluate the example run with: ``` PYTHONPATH=. python fairseq_cli/eval_lm.py /private/home/daju/data/enwik8/eos-data-bin/ --path /checkpoint/daju/2020-11-20/enwiki8.adaptivespan.headwise.adaptive_span.adagrad_with_grad_clip.ag_cl0.03.fixed.wu32000.lr0.07.s2.loss5e-07.ngpu4/checkpoint_best.pt --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test ``` Paper: https://arxiv.org/abs/1905.07799 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1428 Reviewed By: myleott Differential Revision: D25495754 Pulled By: dexterju fbshipit-source-id: 15a875a5f82d506a4964dea934a374132ce39f8b --- README.md | 3 +- examples/adaptive_span/README.md | 90 ++++++ examples/adaptive_span/__init__.py | 19 ++ .../adaptive_span/adagrad_with_grad_clip.py | 128 +++++++++ .../adaptive_span/adaptive_span_attention.py | 160 +++++++++++ examples/adaptive_span/adaptive_span_loss.py | 106 +++++++ examples/adaptive_span/adaptive_span_model.py | 263 ++++++++++++++++++ .../adaptive_span_model_wrapper.py | 145 ++++++++++
.../adaptive_span/truncated_bptt_lm_task.py | 1 + fairseq/optim/adagrad.py | 2 +- 10 files changed, 915 insertions(+), 2 deletions(-) create mode 100644 examples/adaptive_span/README.md create mode 100644 examples/adaptive_span/__init__.py create mode 100644 examples/adaptive_span/adagrad_with_grad_clip.py create mode 100644 examples/adaptive_span/adaptive_span_attention.py create mode 100644 examples/adaptive_span/adaptive_span_loss.py create mode 100644 examples/adaptive_span/adaptive_span_model.py create mode 100644 examples/adaptive_span/adaptive_span_model_wrapper.py create mode 120000 examples/adaptive_span/truncated_bptt_lm_task.py diff --git a/README.md b/README.md index 9cc5b7a559..cc1c76ec36 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,8 @@ We provide reference implementations of various sequence modeling papers: + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](examples/truncated_bptt/README.md) + + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) + + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) diff --git a/examples/adaptive_span/README.md b/examples/adaptive_span/README.md new file mode 
100644 index 0000000000..913a873386 --- /dev/null +++ b/examples/adaptive_span/README.md @@ -0,0 +1,90 @@ +# Adaptive Span + +Adaptive Span is a novel self-attention mechanism that can learn its optimal +attention span. This allows us to significantly extend the maximum context size +used in Transformers, while maintaining control over their memory footprint +and computational time. It uses the Truncated BPTT technique for training, +as in [Transformer-XL](https://github.com/pytorch/fairseq/blob/master/examples/truncated_bptt/README.md). + +Adaptive Span was introduced by the paper +[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799), +which achieved state-of-the-art language modeling results at the time of publication. + +We reproduce their result in fairseq and keep most of the +[original implementation](https://github.com/facebookresearch/adaptive-span) untouched. +You can also refer to their sweep file if any combination of hyperparameters is unclear. + +##### 0. Setup + +First, process the Enwik8 dataset. We use the pre-tokenized dataset +from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh). +Download the dataset, and then run: +```bash +fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \ + --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \ + --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20 +``` + +##### 1. Train an Adaptive Span model on Enwik8 + +We will train a 12-layer Adaptive Span model following the [hyperparameters +used in the original +paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh). + +The following command assumes 4 GPUs, so that the total batch size is 64 +sequences (4 x 16). 
Training should take 2-3 days on 4 V100 GPUs: +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/adaptive_span \ + --data ~/data/enwik8/data-bin/ \ + --fp16 --fp16-no-flatten-grads --max-update 600000 \ + --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \ + --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \ + --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \ + --validate-interval-updates 1000 \ + --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \ + --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \ + --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07 +``` +This should land around 1.05 on validation, 1.03 on test. You can lower +`--aux-loss-scaler` for better performance (longer span). It gives a ~0.03 bpc +improvement over the Transformer-XL baseline here. +If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients +and simulate training on 4 GPUs. +You can also reproduce the Transformer-XL result on enwik8 using this code base. +It should land around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh). +You can try it with: +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/truncated_bptt \ + ~/data/enwik8/data-bin/ \ + --task truncated_bptt_lm --fp16 --max-update 400000 \ + --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \ + --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \ + --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \ + --lr-scheduler cosine --warmup-updates 0 \ + --lr 0.0 --lr 0.00025 --batch-size 15 \ + --update-freq 1 --seed 2 --log-format json --log-interval 25 \ + --fp16 +``` + +##### 2. 
Evaluate +For Adaptive Span: +```bash +fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \ + --user-dir examples/adaptive_span \ + --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test +``` +For Transformer-XL evaluation: +```bash +fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \ + --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \ + --tokens-per-sample 80 \ + --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \ + --gen-subset valid +``` + +*Note:* During training the model saw 512 tokens of context +(``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation +settings from [the original +paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh). diff --git a/examples/adaptive_span/__init__.py b/examples/adaptive_span/__init__.py new file mode 100644 index 0000000000..e0a142a769 --- /dev/null +++ b/examples/adaptive_span/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +# automatically import any Python files in the current directory +cur_dir = os.path.dirname(__file__) +for file in os.listdir(cur_dir): + path = os.path.join(cur_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + mod_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module(__name__ + "." + mod_name) diff --git a/examples/adaptive_span/adagrad_with_grad_clip.py b/examples/adaptive_span/adagrad_with_grad_clip.py new file mode 100644 index 0000000000..585ce184ab --- /dev/null +++ b/examples/adaptive_span/adagrad_with_grad_clip.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.optim import Adagrad + +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adagrad_with_grad_clip") +class FairseqAdagradWithGradClip(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = AdagradWithGradClip(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D', + help='internal grad clip') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. 
+ """ + return { + "lr": self.args.lr[0], + "weight_decay": self.args.weight_decay, + "grad_clip": self.args.adagrad_clip, + } + + @property + def supports_flat_params(self): + return False + + +def _clip_grad(clr, grad, group_grad_clip): + if group_grad_clip > 0: + norm = grad.norm(2).item() + if norm > group_grad_clip: + clr *= group_grad_clip / (norm + 1e-10) + return clr + + +class AdagradWithGradClip(Adagrad): + """Adagrad algorithm with custom gradient clipping""" + + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + grad_clip=0, + ): + Adagrad.__init__( + self, + params, + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + ) + self.defaults["grad_clip"] = grad_clip + self.param_groups[0].setdefault("grad_clip", grad_clip) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad.data + state = self.state[p] + + state["step"] += 1 + + if group["weight_decay"] != 0: + if p.grad.data.is_sparse: + raise RuntimeError( + "weight_decay option is " + "not compatible with sparse " + "gradients" + ) + grad = grad.add(group["weight_decay"], p.data) + + clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"]) + + # clip + clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"]) + + if grad.is_sparse: + # the update is non-linear so indices must be unique + grad = grad.coalesce() + grad_indices = grad._indices() + grad_values = grad._values() + size = grad.size() + + def make_sparse(values): + constructor = grad.new + if grad_indices.dim() == 0 or values.dim() == 0: + return constructor().resize_as_(grad) + return constructor(grad_indices, values, size) + + state["sum"].add_(make_sparse(grad_values.pow(2))) + std = state["sum"]._sparse_mask(grad) + std_values = 
std._values().sqrt_().add_(1e-10) + p.data.add_(-clr, make_sparse(grad_values / std_values)) + else: + state["sum"].addcmul_(1, grad, grad) + std = state["sum"].sqrt().add_(1e-10) + p.data.addcdiv_(-clr, grad, std) + + return loss diff --git a/examples/adaptive_span/adaptive_span_attention.py b/examples/adaptive_span/adaptive_span_attention.py new file mode 100644 index 0000000000..07f757bb8e --- /dev/null +++ b/examples/adaptive_span/adaptive_span_attention.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AdaptiveMask(nn.Module): + """Soft masking function for adaptive size. + It masks out the last K values of an input. The masking value + goes from 1 to 0 gradually, so K can be learned with + back-propagation. + Args: + max_size: maximum size (i.e. 
input dimension) + ramp_size: size of the ramp going from 0 to 1 + init_val: initial size proportion not to be masked out + shape: learn multiple sizes independent of each other + """ + + def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)): + nn.Module.__init__(self) + self._max_size = max_size + self._ramp_size = ramp_size + self.current_val = nn.Parameter(torch.zeros(*shape) + init_val) + mask_template = torch.linspace(1 - max_size, 0, steps=max_size) + self.register_buffer("mask_template", mask_template) + + def forward(self, x): + mask = self.mask_template.float() + self.current_val.float() * self._max_size + mask = mask / self._ramp_size + 1 + mask = mask.clamp(0, 1) + if x.size(-1) < self._max_size: + # the input could have been trimmed beforehand to save computation + mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1)) + x = (x * mask).type_as(x) + return x + + def get_current_max_size(self, include_ramp=True): + current_size = math.ceil(self.current_val.max().item() * self._max_size) + if include_ramp: + current_size += self._ramp_size + current_size = max(0, min(self._max_size, current_size)) + return current_size + + def get_current_avg_size(self, include_ramp=True): + current_size = math.ceil( + self.current_val.float().mean().item() * self._max_size + ) + if include_ramp: + current_size += self._ramp_size + current_size = max(0, min(self._max_size, current_size)) + return current_size + + def clamp_param(self): + """this needs to be called after each update""" + self.current_val.data.clamp_(0, 1) + + +class AdaptiveSpan(nn.Module): + """Adaptive attention span for Transformer. + This module learns an attention span length from data for each + self-attention head. 
+ Args: + attn_span: maximum attention span + adapt_span_loss: loss coefficient for the span length + adapt_span_ramp: length of the masking ramp + adapt_span_init: initial size ratio + adapt_span_cache: adapt cache size to reduce memory usage + """ + + def __init__( + self, + attn_span, + adapt_span_ramp, + adapt_span_init, + n_head, + adapt_span_layer, + **kargs + ): + nn.Module.__init__(self) + self._max_span = attn_span + self._n_head = n_head + self._adapt_span_layer = adapt_span_layer + if self._adapt_span_layer: + self._mask = AdaptiveMask( + max_size=self._max_span, + ramp_size=adapt_span_ramp, + init_val=adapt_span_init, + ) + else: + self._mask = AdaptiveMask( + max_size=self._max_span, + ramp_size=adapt_span_ramp, + init_val=adapt_span_init, + shape=(n_head, 1, 1), + ) + + def forward(self, attn, normalize=True): + """mask attention with the right span""" + # batch and head dimensions are merged together, so separate them first + self.clamp_param() + if self._adapt_span_layer: + attn = self._mask(attn) + else: + B = attn.size(0) # batch size + M = attn.size(1) # block size + attn = attn.reshape(B // self._n_head, self._n_head, M, -1) + attn = self._mask(attn) + attn = attn.view(B, M, -1) + return attn + + def get_trim_len(self): + """how much of memory can be trimmed to reduce computation""" + L = self._max_span + trim_len = min(L - 1, L - self._mask.get_current_max_size()) + # too fine granularity might be bad for the memory management + trim_len = math.floor(trim_len / 64) * 64 + return trim_len + + def trim_memory(self, query, key, value, key_pe): + """trim out unnecessary memory beforehand to reduce computation""" + trim_len = self.get_trim_len() + cache_size = key.size(1) - query.size(1) + trim_len_cache = trim_len - (self._max_span - cache_size) + if trim_len_cache > 0: + key = key[:, trim_len_cache:, :] + value = value[:, trim_len_cache:, :] + elif trim_len_cache < 0: + # cache is too short! 
this happens when validation resumes + # after a lot of updates. + key = F.pad(key, [0, 0, -trim_len_cache, 0]) + value = F.pad(value, [0, 0, -trim_len_cache, 0]) + if trim_len > 0: + if key_pe is not None: + key_pe = key_pe[:, :, trim_len:] + return key, value, key_pe + + def get_cache_size(self): + """determine how long the cache should be""" + trim_len = self.get_trim_len() + # give a buffer of 64 steps since a span might increase + # in future updates + return min(self._max_span, self._max_span - trim_len + 64) + + def get_loss(self): + """a loss term for regularizing the span length""" + return self._max_span * self._mask.current_val.float().mean() + + def get_current_max_span(self): + return self._mask.get_current_max_size() + + def get_current_avg_span(self): + return self._mask.get_current_avg_size() + + def clamp_param(self): + self._mask.clamp_param() diff --git a/examples/adaptive_span/adaptive_span_loss.py b/examples/adaptive_span/adaptive_span_loss.py new file mode 100644 index 0000000000..056245807e --- /dev/null +++ b/examples/adaptive_span/adaptive_span_loss.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from dataclasses import dataclass + +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import register_criterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class AdaptiveSpanCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + + +@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig) +class AdaptiveSpanCriterion(CrossEntropyCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task, sentence_avg) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss here is summed, different from the adaptive span code + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, aux_loss, avg_span, max_span = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + loss /= sample_size + total_loss = loss + aux_loss + sample_size = 1 + + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "total_loss": total_loss.data, + "avg_span": avg_span * sample_size, + "max_span": max_span * sample_size, + } + return total_loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + loss, _ = super().compute_loss(model, net_output, sample, reduce) + aux_loss = model.get_aux_loss() + avg_span = model.get_current_avg_span() + max_span = model.get_current_max_span() + return loss, aux_loss, avg_span, max_span + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging 
outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs) + avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs) + max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3) + metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3) + # total loss contains the L1 norm on adaptive-span + metrics.log_scalar( + "total_loss", + total_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improve distributed training speed. + """ + return True diff --git a/examples/adaptive_span/adaptive_span_model.py b/examples/adaptive_span/adaptive_span_model.py new file mode 100644 index 0000000000..d96c95b85d --- /dev/null +++ b/examples/adaptive_span/adaptive_span_model.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq.modules.layer_norm import LayerNorm + +from .adaptive_span_attention import AdaptiveSpan + +# Size notations: +# B = batch_size, H = d_model, M = block_size, L = attn_span + + +def _skew(X, pad_value): + """shift every row 1 step to right""" + # X = B x M x L + B, M, L = X.size() + X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1) + X = X.view(B, -1) # B x ML+MM+M + X = X[:, :-M] # B x ML+MM + X = X.view(B, M, M + L) # B x M x L+M + return X + + +def _unskew(X): + """reverse _skew operation""" + # X = B x M x L+M + B, M, L = X.size() + L -= M + X = X.view(B, -1) # B x ML+MM + X = F.pad(X, (0, M)) # B x ML+MM+M + X = X.view(B, M, M + L + 1) # B x M x L+M+1 + X = X[:, :, :L] # B x M x L + return X + + +class SeqAttention(nn.Module): + """Sequential self-attention layer. + Each token will attend to its previous fixed number of steps. + Note that attention doesn't include the current step itself. 
+ """ + + def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs): + nn.Module.__init__(self) + self.dropout = nn.Dropout(dropout) + self.d_model = d_model # size of a single head + self.attn_span = attn_span + self.adaptive_span = AdaptiveSpan( + attn_span=attn_span, + n_head=n_head, + adapt_span_layer=adapt_span_layer, + **kargs + ) + + def forward(self, query, key, value, key_pe): + # query size = B x M x H + # key, value sizes = B x (M+L) x H + + key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe) + + # compute attention from context + # B x M (dest) x (M+L) (src) + attn_cont = torch.matmul(query, key.transpose(-1, -2)) + attn_cont = _unskew(attn_cont) # B x M x L + + # compute the effect of position embedding + attn_pos = torch.matmul(query, key_pe) # B x M x L_pos + attn = attn_cont + attn_pos + + attn = attn / math.sqrt(self.d_model) # B x M X L_pos + + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + + # trim attention lengths according to the learned span + attn = self.adaptive_span(attn) + + attn = self.dropout(attn) # B x M X L_pos + + attn_cont = _skew(attn, 0) # B x M X (L+M) + out = torch.matmul(attn_cont, value) # B x M x H + return out + + def get_cache_size(self): + return self.adaptive_span.get_cache_size() + + +class MultiHeadSeqAttention(nn.Module): + def __init__(self, d_model, n_head, **kargs): + nn.Module.__init__(self) + assert d_model % n_head == 0 + self.n_head = n_head + self.head_dim = d_model // n_head + self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs) + self.proj_query = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_query.weight) + self.proj_out = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_out.weight) + self.proj_val = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_val.weight) + self.proj_key = nn.Linear(d_model, d_model, bias=False) + 
nn.init.xavier_normal_(self.proj_key.weight) + + def head_reshape(self, x): + K = self.n_head + D = self.head_dim + x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D + x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D + x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D + return x + + def forward(self, query, key, value, key_pe): + B = query.size(0) + K = self.n_head + D = self.head_dim + M = query.size(1) + + query = self.proj_query(query) + query = self.head_reshape(query) + value = self.proj_val(value) + value = self.head_reshape(value) + key = self.proj_key(key) + key = self.head_reshape(key) + + out = self.attn(query, key, value, key_pe) # B_K x M x D + out = out.view(B, K, M, D) # B x K x M x D + out = out.transpose(1, 2).contiguous() # B x M x K x D + out = out.view(B, M, -1) # B x M x K_D + out = self.proj_out(out) + return out + + +class FeedForwardLayer(nn.Module): + def __init__(self, d_model, d_inner, dropout, **kargs): + nn.Module.__init__(self) + self.fc1 = nn.Linear(d_model, d_inner) + self.fc2 = nn.Linear(d_inner, d_model) + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + self.dropout = nn.Dropout(dropout) + + def forward(self, h): + h1 = F.relu(self.fc1(h)) + h1 = self.dropout(h1) + h2 = self.fc2(h1) + return h2 + + +class TransformerSeqLayer(nn.Module): + def __init__(self, d_model, **kargs): + nn.Module.__init__(self) + self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs) + self.norm1 = LayerNorm(d_model) + self.ff = FeedForwardLayer(d_model=d_model, **kargs) + self.norm2 = LayerNorm(d_model) + + def forward(self, h, h_cache, key_pe): + # h = B x M x H + # h_cache = B x L x H + h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H + attn_out = self.attn(h, h_all, h_all, key_pe) + h = self.norm1(h + attn_out) # B x M x H + if self.ff is not None: + ff_out = self.ff(h) + out = self.norm2(h + ff_out) # B x M x H + else: + out = h + return out + + def get_cache_size(self): + return 
self.attn.attn.get_cache_size() + + +class TransformerSeq(nn.Module): + def __init__( + self, + vocab_size, + d_model, + n_head, + n_layer, + attn_span, + emb_dropout, + aux_loss_scaler, + adapt_span_layer, + **kargs + ): + nn.Module.__init__(self) + # token embeddings + self.in_emb = nn.Embedding(vocab_size, d_model) + nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5) + self.out_emb = nn.Linear(d_model, vocab_size) + self.aux_loss_scaler = aux_loss_scaler + if emb_dropout > 0: + self.emb_dropout = nn.Dropout(emb_dropout) + else: + self.emb_dropout = None + # position embeddings + self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span)) + + self.layers = nn.ModuleList() + self.layers.extend( + TransformerSeqLayer( + d_model=d_model, + n_head=n_head, + attn_span=attn_span, + adapt_span_layer=adapt_span_layer, + **kargs + ) + for _ in range(n_layer) + ) + + def forward(self, x, h_cache, target=None): + # x size = B x M + block_size = x.size(1) + h = self.in_emb(x) # B x M x H + if self.emb_dropout is not None: + h = self.emb_dropout(h) + + h_cache_next = [] + for l, layer in enumerate(self.layers): + cache_size = layer.attn.attn.get_cache_size() + if cache_size > block_size: + h_cache_next_l = torch.cat( + [h_cache[l][:, -cache_size + block_size :, :], h], dim=1 + ).detach() + else: + h_cache_next_l = h[:, -cache_size:, :].detach() + h_cache_next.append(h_cache_next_l) + h = layer(h, h_cache[l], self.key_pe) # B x M x H + + if self.emb_dropout is not None: + h = self.emb_dropout(h) + + out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h) + dummy_loss = None + + return out, h_cache_next, dummy_loss + + def get_aux_loss(self): + loss = 0.0 + for layer in self.layers: + loss += layer.attn.attn.adaptive_span.get_loss() + return self.aux_loss_scaler * loss + + def get_current_max_span(self): + max_span = 0.0 + for layer in self.layers: + max_span = max( + max_span, layer.attn.attn.adaptive_span.get_current_max_span() + ) + return 
max_span + + def get_current_avg_span(self): + avg_span = 0.0 + for layer in self.layers: + avg_span += layer.attn.attn.adaptive_span.get_current_avg_span() + return avg_span / len(self.layers) diff --git a/examples/adaptive_span/adaptive_span_model_wrapper.py b/examples/adaptive_span/adaptive_span_model_wrapper.py new file mode 100644 index 0000000000..5b147fe11f --- /dev/null +++ b/examples/adaptive_span/adaptive_span_model_wrapper.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, +) +from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel + + +logger = logging.getLogger(__name__) + + +@dataclass +class AdaptiveSpanSmallConfig(FairseqDataclass): + # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh + vocab_size: int = 50 + d_model: int = 256 + n_head: int = 4 + d_inner: int = 1024 + n_layer: int = 8 + attn_span: int = 1024 + dropout: float = 0.0 + emb_dropout: float = 0.0 + adapt_span_ramp: int = 32 + adapt_span_init: float = 0.0 + aux_loss_scaler: float = 0.000002 + adapt_span_layer: bool = False + + +@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig) +class AdaptiveSpanTransformer(FairseqLanguageModel): + @classmethod + def build_model(cls, cfg: AdaptiveSpanSmallConfig, task): + return cls(AdaptiveSpanDecoder(cfg, task)) + + def get_aux_loss(self): + return self.decoder.get_aux_loss() + + def get_current_max_span(self): + return self.decoder.get_current_max_span() + + def get_current_avg_span(self): + return self.decoder.get_current_avg_span() + + +class 
AdaptiveSpanDecoder(FairseqIncrementalDecoder): + def __init__(self, cfg, task): + + super().__init__(task.target_dictionary) + + self.config = cfg + config = AdaptiveSpanSmallConfig( + vocab_size=len(task.target_dictionary), + d_model=cfg.d_model, + n_head=cfg.n_head, + d_inner=cfg.d_inner, + n_layer=cfg.n_layer, + attn_span=cfg.attn_span, + dropout=cfg.dropout, + emb_dropout=cfg.emb_dropout, + adapt_span_ramp=cfg.adapt_span_ramp, + adapt_span_init=cfg.adapt_span_init, + aux_loss_scaler=cfg.aux_loss_scaler, + adapt_span_layer=cfg.adapt_span_layer, + ) + logger.info(config) + self.model = AdaptiveSpanTransformerModel(**config.__dict__) + + self._mems = None + + def forward( + self, + src_tokens, + incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None, + encoder_out=None, + ): + bsz = src_tokens.size(0) + if incremental_state is not None: # used during inference + mems = self.get_incremental_state("mems") + src_tokens = src_tokens[:, -1:] # only keep the most recent token + else: + mems = self._mems + + if mems is None: + # first time init + mems = self.init_hid_cache(bsz) + output = self.model(x=src_tokens, h_cache=mems,) + if incremental_state is not None: + self.set_incremental_state(incremental_state, "mems", output[1]) + else: + self._mems = output[1] + return (output[0],) + + def max_positions(self): + return self.config.attn_span + + def init_hid_cache(self, batch_sz): + hid = [] + for layer in self.model.layers: + param = next(self.model.parameters()) + h = torch.zeros( + batch_sz, + layer.get_cache_size(), + self.config.d_model, + dtype=param.dtype, + device=param.device, + ) + hid.append(h) + return hid + + def get_aux_loss(self): + return self.model.get_aux_loss() + + def get_current_max_span(self): + return self.model.get_current_max_span() + + def get_current_avg_span(self): + return self.model.get_current_avg_span() + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], + new_order: 
torch.Tensor, + ): + """Reorder incremental state. + + This will be called when the order of the input has changed from the + previous time step. A typical use case is beam search, where the input + order changes between time steps based on the selection of beams. + """ + raise NotImplementedError("This is required for generation/beam search") + # mems = self.get_incremental_state(incremental_state, "mems") + # if mems is not None: + # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems] + # self.set_incremental_state(incremental_state, "mems", new_mems) diff --git a/examples/adaptive_span/truncated_bptt_lm_task.py b/examples/adaptive_span/truncated_bptt_lm_task.py new file mode 120000 index 0000000000..a92da3a298 --- /dev/null +++ b/examples/adaptive_span/truncated_bptt_lm_task.py @@ -0,0 +1 @@ +../truncated_bptt/truncated_bptt_lm_task.py \ No newline at end of file diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py index a79b6c39da..4f539541c1 100644 --- a/fairseq/optim/adagrad.py +++ b/fairseq/optim/adagrad.py @@ -37,4 +37,4 @@ def optimizer_config(self): @property def supports_flat_params(self): - return True + return False From 5a3e51d21187415a66632f65e92c71de61367a93 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 15 Dec 2020 17:19:53 -0800 Subject: [PATCH 363/707] Infer TPU flag automatically and deprecate prepare_for_tpu (#1514) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1514 Reviewed By: ngoyal2707 Differential Revision: D25568540 Pulled By: myleott fbshipit-source-id: 93cf2dfd55323bc396b893eb6d658f1c6283a88b --- .../modules/multihead_linear_attention.py | 4 ---- .../model_parallel/modules/multihead_attention.py | 9 +++------ fairseq/models/fairseq_model.py | 15 --------------- fairseq/modules/multihead_attention.py | 10 ++++------ fairseq/modules/transformer_sentence_encoder.py | 7 ++----- fairseq/tasks/fairseq_task.py | 4 ---- 6 files changed, 9 insertions(+), 40 deletions(-) diff --git 
a/examples/linformer/linformer_src/modules/multihead_linear_attention.py b/examples/linformer/linformer_src/modules/multihead_linear_attention.py index ba2c36b1ef..6be1007279 100644 --- a/examples/linformer/linformer_src/modules/multihead_linear_attention.py +++ b/examples/linformer/linformer_src/modules/multihead_linear_attention.py @@ -111,14 +111,10 @@ def __init__( self.compress_v.weight.requires_grad = False self.onnx_trace = False - self.tpu = False def prepare_for_onnx_export_(self): self.onnx_trace = True - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py index 4164bf9131..8eb9d09dad 100644 --- a/fairseq/model_parallel/modules/multihead_attention.py +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -93,11 +93,6 @@ def __init__( embed_dim, embed_dim, bias=bias, input_is_parallel=True ) - self.tpu = False - - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def forward( self, query, @@ -123,6 +118,8 @@ def forward( assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] + is_tpu = query.device.type == "xla" + if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if saved_state is not None and "prev_key" in saved_state: @@ -250,7 +247,7 @@ def forward( attn_weights = attn_weights.view( bsz, self.num_heads_partition, tgt_len, src_len ) - if not self.tpu: + if not is_tpu: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 926d952f77..244cbc0c66 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -223,21 +223,6 @@ def 
apply_prepare_for_onnx_export_(module): self.apply(apply_prepare_for_onnx_export_) - def prepare_for_tpu_(self, **kwargs): - """Optionally modify model for use on TPUs.""" - seen = set() - - def apply_prepare_for_tpu_(module): - if ( - module != self - and hasattr(module, "prepare_for_tpu_") - and module not in seen - ): - seen.add(module) - module.prepare_for_tpu_(**kwargs) - - self.apply(apply_prepare_for_tpu_) - @classmethod def from_pretrained( cls, diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 99f95deb5f..6ab86245d2 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -87,14 +87,10 @@ def __init__( self.reset_parameters() self.onnx_trace = False - self.tpu = False def prepare_for_onnx_export_(self): self.onnx_trace = True - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with @@ -148,13 +144,15 @@ def forward( if need_head_weights: need_weights = True + is_tpu = query.device.type == "xla" + tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] if ( not self.onnx_trace - and not self.tpu # don't use PyTorch version on TPUs + and not is_tpu # don't use PyTorch version on TPUs and incremental_state is None and not static_kv # A workaround for quantization to work. 
Otherwise JIT compilation @@ -337,7 +335,7 @@ def forward( if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - if not self.tpu: + if not is_tpu: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 208488f562..7a5dcbdde3 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -113,7 +113,6 @@ def __init__( self.apply_bert_init = apply_bert_init self.learned_pos_embedding = learned_pos_embedding self.traceable = traceable - self.tpu = False # whether we're on TPU self.embed_tokens = self.build_embedding( self.vocab_size, self.embedding_dim, self.padding_idx @@ -220,9 +219,6 @@ def build_transformer_sentence_encoder_layer( qn_block_size=qn_block_size, ) - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def forward( self, tokens: torch.Tensor, @@ -231,10 +227,11 @@ def forward( positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + is_tpu = tokens.device.type == "xla" # compute padding mask. 
This is needed for multi-head attention padding_mask = tokens.eq(self.padding_idx) - if not self.traceable and not self.tpu and not padding_mask.any(): + if not self.traceable and not is_tpu and not padding_mask.any(): padding_mask = None if token_embeddings is not None: diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index b99c511990..24116bfd52 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -280,8 +280,6 @@ def build_model(self, cfg: FairseqDataclass): from fairseq import models, quantization_utils model = models.build_model(cfg, self) - if getattr(cfg, "tpu", False): - model.prepare_for_tpu_() model = quantization_utils.quantize_model_scalar(model, cfg) return model @@ -567,8 +565,6 @@ def build_model(self, args: Namespace): from fairseq import models, quantization_utils model = models.build_model(args, self) - if getattr(args, "tpu", False): - model.prepare_for_tpu_() model = quantization_utils.quantize_model_scalar(model, args) return model From 8d7ee5bf813f1fc0a0685e0cb1ef58fcc63e7855 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 15 Dec 2020 17:46:37 -0800 Subject: [PATCH 364/707] Fix hydra with Python 3.8 (#1511) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1511 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25570468 Pulled By: myleott fbshipit-source-id: 98efc6983479e163e6cf0a7ef33decaa1bc429f1 --- fairseq/dataclass/utils.py | 3 ++- setup.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 9d52d45942..45e7ed9170 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -228,8 +228,9 @@ def get_default(f): if isinstance(val, tuple): val = list(val) + v_type = getattr(v.type, "__origin__", None) if ( - getattr(v.type, "__origin__", None) is List + (v_type is List or v_type is list) # skip interpolation and not (isinstance(val, str) 
and val.startswith("${")) ): diff --git a/setup.py b/setup.py index 0a4be4b0dd..1954298034 100644 --- a/setup.py +++ b/setup.py @@ -168,6 +168,8 @@ def do_setup(package_data): "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], long_description=readme, @@ -182,7 +184,8 @@ def do_setup(package_data): "cffi", "cython", 'dataclasses; python_version<"3.7"', - "hydra-core", + "hydra-core<1.1", + "omegaconf<2.1", 'numpy<1.20.0; python_version<"3.7"', 'numpy; python_version>="3.7"', "regex", From 409032596bd80240f7fbc833b5d37485dee85b0e Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 15 Dec 2020 17:46:37 -0800 Subject: [PATCH 365/707] Fix loading of very old checkpoints (#1512) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1512 See https://github.com/pytorch/fairseq/issues/3032 for context Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25570470 Pulled By: myleott fbshipit-source-id: 9227b1ca36cd81ff72acdb5e03fd574e3e8769be --- fairseq/checkpoint_utils.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 6209c71aef..f178617b5a 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -451,6 +451,12 @@ def _upgrade_state_dict(state): # keep track of number of updates if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 + # old model checkpoints may not have separate source/target positions + if hasattr(state["args"], "max_positions") and not hasattr( + state["args"], "max_source_positions" + ): + state["args"].max_source_positions = state["args"].max_positions + state["args"].max_target_positions = 
state["args"].max_positions # use stateful training data iterator if "train_iterator" not in state["extra_state"]: state["extra_state"]["train_iterator"] = { @@ -489,22 +495,16 @@ def _upgrade_state_dict(state): # audio_cpc => wav2vec if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": state["args"].arch = "wav2vec" + # convert legacy float learning rate to List[float] + if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): + state["args"].lr = [state["args"].lr] state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: with open_dict(state["cfg"]): - if state["cfg"].task is not None: - # old model checkpoints may not have separate source/target positions - if hasattr(state["cfg"].task, "max_positions") and not hasattr( - state["cfg"].task, "max_source_positions" - ): - state["cfg"].task.max_source_positions = state[ - "cfg" - ].task.max_positions - state["cfg"].task.max_target_positions = state[ - "cfg" - ].task.max_positions + # any upgrades for Hydra-based configs + pass return state From c8a0659be5cdc15caa102d5bbf72b872567c4859 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 16 Dec 2020 19:06:45 -0800 Subject: [PATCH 366/707] Stronger --checkpoint-activations test (#1505) Summary: - captures and inspects train and valid logs using unittest's `assertLogs` - asserts that `--checkpoint-activations` does not change `train_loss` or `valid_loss`.
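The log-capture approach listed above can be sketched roughly as follows. This is a minimal illustration, not the test added by this PR: `fake_run` and `LogCaptureExample` are hypothetical stand-ins for a real training run and test class, and the sketch assumes each run logs one JSON object per message under the `train` and `valid` logger names.

```python
import json
import logging
import unittest
from typing import Dict, List


def read_last_log_entry(
    records: List[logging.LogRecord], logger_name: str
) -> Dict[str, float]:
    """Return the JSON payload of the newest record emitted by `logger_name`."""
    for record in reversed(records):
        if record.name == logger_name:
            return json.loads(record.getMessage())
    raise ValueError(f"No entries from {logger_name} found in captured logs")


def fake_run(loss: float) -> None:
    # Hypothetical stand-in for a training run that logs JSON stats
    # (the real test trains a small model with --log-format json).
    logging.getLogger("train").info(json.dumps({"train_loss": loss}))
    logging.getLogger("valid").info(json.dumps({"valid_loss": loss + 1.0}))


class LogCaptureExample(unittest.TestCase):
    def capture(self, loss: float) -> List[logging.LogRecord]:
        # With no arguments, assertLogs() captures INFO and above on the
        # root logger, which also sees records from child loggers.
        with self.assertLogs() as logs:
            fake_run(loss)
        return logs.records

    def test_losses_match(self):
        # Two runs with the same settings should report identical losses.
        first = self.capture(0.5)
        second = self.capture(0.5)
        for name, key in (("train", "train_loss"), ("valid", "valid_loss")):
            self.assertEqual(
                read_last_log_entry(first, name)[key],
                read_last_log_entry(second, name)[key],
            )
```

Saved to a file, the sketch can be run with `python -m unittest`. Comparing only the last entry per logger keeps the check insensitive to how many intermediate messages each run emits.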
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1505 Reviewed By: myleott Differential Revision: D25544991 Pulled By: sshleifer fbshipit-source-id: 2762095ab4e7c819a803b3556f5774db8c6b6f39 --- tests/test_binaries.py | 126 ++++++++++++++++++++++++++++------------- 1 file changed, 88 insertions(+), 38 deletions(-) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 58f86484f7..2b57aa66da 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -5,13 +5,14 @@ import contextlib import logging +import json import os import random import sys import tempfile import unittest from io import StringIO - +from typing import List, Dict import torch from fairseq import options from fairseq_cli import eval_lm, train, validate @@ -249,29 +250,6 @@ def test_transformer(self): ) generate_main(data_dir) - def test_transformer_with_activation_checkpointing(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir) - preprocess_translation_data(data_dir) - train_translation_model( - data_dir, - "transformer_iwslt_de_en", - [ - "--encoder-layers", - "2", - "--decoder-layers", - "2", - "--encoder-embed-dim", - "8", - "--decoder-embed-dim", - "8", - "--checkpoint-activations", - ], - run_validation=True, - ) - generate_main(data_dir) - def test_multilingual_transformer(self): # test with all combinations of encoder/decoder lang tokens encoder_langtok_flags = [ @@ -326,7 +304,9 @@ def test_multilingual_transformer(self): + dec_ltok_flag, ) - @unittest.skipIf(sys.platform.lower() == "darwin", "skip latent depth test on MacOS") + @unittest.skipIf( + sys.platform.lower() == "darwin", "skip latent depth test on MacOS" + ) def test_multilingual_translation_latent_depth(self): # test with latent depth in encoder, decoder, or both encoder_latent_layer = [[], ["--encoder-latent-layer"]] @@ -465,9 +445,7 @@ def 
test_translation_multi_simple_epoch_no_vepoch(self): "test_translation_multi_simple_epoch_dict" ) as data_dir: create_dummy_data(data_dir) - preprocess_translation_data( - data_dir, extra_flags=[] - ) + preprocess_translation_data(data_dir, extra_flags=[]) train_translation_model( data_dir, arch="transformer", @@ -517,9 +495,7 @@ def test_translation_multi_simple_epoch_dicts(self): "test_translation_multi_simple_epoch_dict" ) as data_dir: create_dummy_data(data_dir) - preprocess_translation_data( - data_dir, extra_flags=[] - ) + preprocess_translation_data(data_dir, extra_flags=[]) train_translation_model( data_dir, arch="transformer", @@ -619,11 +595,17 @@ def test_transformer_pointer_generator(self): "0", ], run_validation=True, - extra_valid_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"], + extra_valid_flags=[ + "--user-dir", + "examples/pointer_generator/pointer_generator_src", + ], ) generate_main( data_dir, - extra_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"], + extra_flags=[ + "--user-dir", + "examples/pointer_generator/pointer_generator_src", + ], ) def test_lightconv(self): @@ -953,7 +935,7 @@ def test_transformer_layerdrop(self): data_dir, [ "--model-overrides", - "{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}" + "{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}", ], ) @@ -1080,7 +1062,9 @@ def test_transformer_lm(self): def test_transformer_lm_with_adaptive_softmax(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_transformer_lm_with_adaptive_softmax") as data_dir: + with tempfile.TemporaryDirectory( + "test_transformer_lm_with_adaptive_softmax" + ) as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) train_language_model( @@ -1199,7 +1183,8 @@ def test_transformer_xl_bptt_lm(self): train_language_model( data_dir=data_dir, arch="transformer_xl", - extra_flags=task_flags + [ + extra_flags=task_flags + + [ 
"--n-layer", "2", ], @@ -1537,6 +1522,65 @@ def test_optimizers(self): generate_main(data_dir) +def read_last_log_entry( + logs: List[logging.LogRecord], logger_name: str +) -> Dict[str, float]: + for x in reversed(logs): + if x.name == logger_name: + return json.loads(x.message) + raise ValueError(f"No entries from {logger_name} found in captured logs") + + +class TestActivationCheckpointing(unittest.TestCase): + def test_activation_checkpointing_does_not_change_metrics(self): + """--checkpoint-activations should not change loss""" + base_flags = [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--restore-file", + "x.pt", + "--log-format", + "json", + "--log-interval", + "1", + "--max-update", + "2", + ] + + def _train(extra_flags): + with self.assertLogs() as logs: + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + base_flags + extra_flags, + run_validation=True, + extra_valid_flags=["--log-format", "json"], + ) + return logs.records + + with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: + + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) + ckpt_logs = _train(["--checkpoint-activations"]) + baseline_logs = _train([]) + assert len(baseline_logs) == len(ckpt_logs) + + baseline_train_stats = read_last_log_entry(baseline_logs, "train") + ckpt_train_stats = read_last_log_entry(ckpt_logs, "train") + assert baseline_train_stats["train_loss"] == ckpt_train_stats["train_loss"] + + baseline_valid_stats = read_last_log_entry(baseline_logs, "valid") + ckpt_valid_stats = read_last_log_entry(ckpt_logs, "valid") + assert baseline_valid_stats["valid_loss"] == ckpt_valid_stats["valid_loss"] + + def create_dummy_roberta_head_data( data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False ): @@ -1653,7 +1697,12 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): def train_language_model( 
- data_dir, arch, extra_flags=None, run_validation=False, extra_valid_flags=None, task="language_modeling" + data_dir, + arch, + extra_flags=None, + run_validation=False, + extra_valid_flags=None, + task="language_modeling", ): train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( @@ -1723,7 +1772,8 @@ def eval_lm_main(data_dir, extra_flags=None): "--no-progress-bar", "--num-workers", "0", - ] + (extra_flags or []), + ] + + (extra_flags or []), ) eval_lm.main(eval_lm_args) From edc321e767ca4e980463d7af7f3d5eb751f60962 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 18 Dec 2020 07:39:39 -0800 Subject: [PATCH 367/707] Support atomic saves for checkpoints (#1520) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1520 Test Plan: Imported from OSS Reviewed By: stephenroller Differential Revision: D25632782 Pulled By: myleott fbshipit-source-id: bdbe2aed6254d0b023b33f8027dfbd939f1fd271 --- .github/workflows/build.yml | 4 ++++ fairseq/checkpoint_utils.py | 11 +++++++++-- fairseq/file_io.py | 23 +++++++++++++++++++++-- tests/test_iopath.py | 29 +++++++++++++++++++++++++++++ tests/test_reproducibility.py | 2 +- 5 files changed, 64 insertions(+), 5 deletions(-) create mode 100644 tests/test_iopath.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a2d44dd57f..29e5254d33 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,6 +37,10 @@ jobs: python setup.py build_ext --inplace python -m pip install --editable . 
+ - name: Install optional test requirements + run: | + python -m pip install fairscale iopath transformers + - name: Lint with flake8 run: | pip install flake8 diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index f178617b5a..34fef13387 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -408,8 +408,15 @@ def save_state( # keep everything on CPU state_dict = utils.move_to_cpu(state_dict) - with PathManager.open(filename, "wb") as f: - torch_persistent_save(state_dict, f) + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + torch_persistent_save(state_dict, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + torch_persistent_save(state_dict, f) def _upgrade_state_dict(state): diff --git a/fairseq/file_io.py b/fairseq/file_io.py index d74a48591a..7d6c28dccd 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -26,7 +26,8 @@ logging.warning("S3PathHandler already registered.") except ImportError: logging.debug( - "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module.") + "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." 
+ ) except ImportError: FVCorePathManager = None @@ -112,7 +113,7 @@ def rm(path: str) -> None: @staticmethod def chmod(path: str, mode: int) -> None: - if "manifold" not in path: + if not PathManager.path_requires_pathmanager(path): os.chmod(path, mode) @staticmethod @@ -129,3 +130,21 @@ def copy_from_local( local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs ) return shutil.copyfile(local_path, dst_path) + + @staticmethod + def path_requires_pathmanager(path: str) -> bool: + """Do we require PathManager to access given path?""" + if FVCorePathManager: + for p in FVCorePathManager._path_handlers.keys(): + if path.startswith(p): + return True + return False + + @staticmethod + def supports_rename(path: str) -> bool: + # PathManager doesn't yet support renames + return not PathManager.path_requires_pathmanager(path) + + @staticmethod + def rename(src: str, dst: str): + os.rename(src, dst) diff --git a/tests/test_iopath.py b/tests/test_iopath.py new file mode 100644 index 0000000000..908261a661 --- /dev/null +++ b/tests/test_iopath.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest +from unittest import mock + + +class TestIOPath(unittest.TestCase): + + def test_no_iopath(self): + from .test_reproducibility import TestReproducibility + + with mock.patch.dict("sys.modules", {"iopath": None}): + # reuse reproducibility tests, which are e2e tests that should cover + # most checkpoint related functionality + TestReproducibility._test_reproducibility(self, "test_reproducibility") + + def test_no_supports_rename(self): + from .test_reproducibility import TestReproducibility + + with mock.patch("fairseq.file_io.PathManager.supports_rename") as mock_fn: + mock_fn.return_value = False + TestReproducibility._test_reproducibility(self, "test_reproducibility") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 517e23c39e..405d545593 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -26,7 +26,7 @@ def _test_reproducibility( ): def get_last_log_stats_containing_string(log_records, search_string): for log_record in logs.records[::-1]: - if search_string in log_record.msg: + if isinstance(log_record.msg, str) and search_string in log_record.msg: return json.loads(log_record.msg) if extra_flags is None: From a041e1ae9cd5d69af993f5da6561223ad407f5da Mon Sep 17 00:00:00 2001 From: Xu Song Date: Fri, 18 Dec 2020 07:40:57 -0800 Subject: [PATCH 368/707] Fix parameter (#3045) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? `src_lengths` is not a required parameter in `TransformerEncoder`. It is a dummy variable. 
Further changes may be needed to fix this issue in classes such as `Transformer`, `FairseqEncoderDecoderModel`, `BARTModel` etc. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3045 Reviewed By: ngoyal2707 Differential Revision: D25632992 Pulled By: myleott fbshipit-source-id: 762d595144b611e1a6c236248d7001049afed0ab --- fairseq/models/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 9655578e52..fa4c29855b 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -402,7 +402,7 @@ def forward_embedding( def forward( self, src_tokens, - src_lengths, + src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): @@ -418,7 +418,7 @@ def forward( default `None` will recompute embeddings Returns: - namedtuple: + dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of From 3a597d11731c7b7949072856aa51dbf4581963b0 Mon Sep 17 00:00:00 2001 From: Yiding Tian Date: Fri, 18 Dec 2020 07:41:03 -0800 Subject: [PATCH 369/707] Removing an unwanted bracket character in logging message. (#3028) Summary: # Before submitting - [no] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [yes] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [no need] Did you make sure to update the docs? - [no need] Did you write any new necessary tests? ## What does this PR do? Just fixing a small typo of logging one additional bracket before starting training. ## PR review Anyone in the community is free to review the PR once the tests have passed. > If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
It's really a small change, no need to discuss. ## Did you have fun? Small change although, do have fun reading the code. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3028 Reviewed By: ngoyal2707 Differential Revision: D25632978 Pulled By: myleott fbshipit-source-id: 62c85a9727af523d4082678a12a71b78f3ea84c0 --- fairseq_cli/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 81ecc80140..245263402e 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -76,7 +76,7 @@ def main(cfg: DictConfig) -> None: logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) logger.info("model: {}".format(model.__class__.__name__)) - logger.info("criterion: {})".format(criterion.__class__.__name__)) + logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. model params: {} (num. trained: {})".format( sum(p.numel() for p in model.parameters()), From 828960f5dace4787ad81aeadca60043c907adc67 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 18 Dec 2020 07:41:49 -0800 Subject: [PATCH 370/707] fix inconsistency in wav2vec documentation (#3039) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3039 Reviewed By: ngoyal2707 Differential Revision: D25632983 Pulled By: myleott fbshipit-source-id: 32a60da9d41e600d2b047171d548f954196c0560 --- examples/wav2vec/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 05df59f214..a0c95e9c34 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -162,8 +162,8 @@ Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl import torch import fairseq -cp = torch.load('/path/to/wav2vec.pt') -model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) +cp_path = '/path/to/wav2vec.pt' +model, cfg, task = 
fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) model = model[0] model.eval() From 9693504a8a75bafd7bddd6caa47cc5aed6821a2b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 18 Dec 2020 07:42:44 -0800 Subject: [PATCH 371/707] Bugfix: Early exit in creating dictionaries due to unexpected f.tell() behavior (#2956) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? This PR fixes an issue with `f.tell()` when creating dictionaries. Before this bugfix, the dictionary generation had silently nondeterministic behavior which worsens with multiple workers. Please see the comment in the commit for more details. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/2956 Reviewed By: ngoyal2707 Differential Revision: D25633020 Pulled By: myleott fbshipit-source-id: 08a4ae4a8d6e03f72484baafe012212a99003ada --- fairseq/binarizer.py | 8 +++++++- fairseq/data/dictionary.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index 0255c084b5..c736c8754d 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -46,7 +46,13 @@ def replaced_consumer(word, idx): # next(f) breaks f.tell(), hence readline() must be used line = safe_readline(f) while line: - if end > 0 and f.tell() > end: + # f.tell() does not always give the byte position in the file + # sometimes it skips to a very large number + # it is unlikely that through a normal read we go from + # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely + # that the procedure breaks by the undeterministic behavior of + # f.tell() + if end > 0 and f.tell() > end and f.tell() < end + 2**32: break if already_numberized: id_strings = line.strip().split() diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index efb5f1542c..127d023f4c 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -335,7 +335,13 @@ def _add_file_to_dictionary_single_worker( for word in tokenize(line): counter.update([word]) counter.update([eos_word]) - if f.tell() > end: + # f.tell() returns only an opaque number which can + # return to the position in the file via f.seek() + # and does not necessarily represent a byte position + # in the file. However, f.tell() is faithful to the + # byte position _most of the time_. Thus we can just + # check against the file size to prevent early exit.
+ if f.tell() > end and f.tell() < size: break line = f.readline() return counter From 36c63c826d2292c9df56065b5816c02eefc87713 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 18 Dec 2020 11:44:05 -0800 Subject: [PATCH 372/707] Refactor eval_lm to support library usage (#1513) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1513 Test Plan: Imported from OSS Reviewed By: alexeib Differential Revision: D25570467 Pulled By: myleott fbshipit-source-id: 062f748e287797f4f01c605e0b544ef3e698851f --- .../truncated_bptt/transformer_xl_model.py | 9 +- .../truncated_bptt/truncated_bptt_lm_task.py | 36 +- fairseq/checkpoint_utils.py | 6 +- fairseq/data/data_utils.py | 8 +- fairseq/dataclass/configs.py | 8 +- fairseq/tasks/language_modeling.py | 37 ++ fairseq_cli/eval_lm.py | 344 ++++++++++-------- tests/test_binaries.py | 1 + 8 files changed, 294 insertions(+), 155 deletions(-) diff --git a/examples/truncated_bptt/transformer_xl_model.py b/examples/truncated_bptt/transformer_xl_model.py index 7466c951ab..83b248479e 100644 --- a/examples/truncated_bptt/transformer_xl_model.py +++ b/examples/truncated_bptt/transformer_xl_model.py @@ -49,8 +49,13 @@ def build_model(cls, cfg: TransformerXLConfig, task): class TransformerXLDecoder(FairseqIncrementalDecoder): def __init__(self, cfg, task): - from transformers.configuration_transfo_xl import TransfoXLConfig - from transformers.modeling_transfo_xl import TransfoXLLMHeadModel + try: + from transformers.models.transfo_xl import ( + TransfoXLConfig, TransfoXLLMHeadModel + ) + except ImportError: + from transformers.configuration_transfo_xl import TransfoXLConfig + from transformers.modeling_transfo_xl import TransfoXLLMHeadModel super().__init__(task.target_dictionary) self.cfg = cfg diff --git a/examples/truncated_bptt/truncated_bptt_lm_task.py b/examples/truncated_bptt/truncated_bptt_lm_task.py index 5f81ec4b84..34c4f03955 100644 --- a/examples/truncated_bptt/truncated_bptt_lm_task.py +++ 
b/examples/truncated_bptt/truncated_bptt_lm_task.py @@ -10,7 +10,12 @@ import torch from fairseq import distributed_utils as dist_utils, utils -from fairseq.data import Dictionary, TokenBlockDataset, data_utils, iterators +from fairseq.data import ( + Dictionary, + TokenBlockDataset, + data_utils, + iterators, +) from fairseq.dataclass import FairseqDataclass from fairseq.tasks import FairseqTask, register_task from omegaconf import II @@ -182,6 +187,35 @@ def inference_step( models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token ) + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + context_window: int = 0, + ): + if context_window > 0: + raise NotImplementedError( + "Transformer-XL doesn't need --context-window, try " + "--model-overrides '{\"mem_len\":42}' instead " + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + @property def source_dictionary(self): return self.dictionary diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 34fef13387..f3c2d8aa16 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -11,7 +11,7 @@ import re import traceback from collections import OrderedDict -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig @@ -241,7 +241,7 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): def load_model_ensemble( filenames, - arg_overrides=None, + arg_overrides: Optional[Dict[str, Any]] = None, task=None, 
strict=True, suffix="", @@ -273,7 +273,7 @@ def load_model_ensemble( def load_model_ensemble_and_task( filenames, - arg_overrides=None, + arg_overrides: Optional[Dict[str, Any]] = None, task=None, strict=True, suffix="", diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 81f457365a..1efe352dd2 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -346,8 +346,12 @@ def post_process(sentence: str, symbol: str): sentence = sentence.replace(" ", "").replace("|", " ").strip() elif symbol == "_EOW": sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() - elif symbol is not None and symbol != "none": - sentence = (sentence + " ").replace(symbol, "").rstrip() + elif symbol in {"subword_nmt", "@@ "}: + sentence = (sentence + " ").replace("@@ ", "").rstrip() + elif symbol == "none": + pass + elif symbol is not None: + raise NotImplementedError(f"Unknown post_process option: {symbol}") return sentence diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 7d27dc0d4b..584e9201a3 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -831,9 +831,11 @@ class CommonEvalConfig(FairseqDataclass): post_process: Optional[str] = field( default=None, metadata={ - "help": "post-process text by removing pre-processing such as BPE, letter segmentation, etc " - "(valid options are: sentencepiece, wordpiece, letter, _EOW, none, otherwise treated as BPE symbol)", - "argparse_const": "@@ ", + "help": ( + "post-process text by removing BPE, letter segmentation, etc. " + "Valid options can be found in fairseq.data.utils.post_process." 
+ ), + "argparse_const": "subword_nmt", "argparse_alias": "--remove-bpe", }, ) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index e0bf1f9b2b..b68c4ad4d1 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -15,6 +15,7 @@ AppendTokenDataset, Dictionary, IdDataset, + LMContextWindowDataset, MonolingualDataset, NestedDictionaryDataset, NumelDataset, @@ -312,6 +313,42 @@ def inference_step( models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token ) + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + assert self.args.tokens_per_sample > context_window + # reduce tokens per sample by the required context window size + tokens_per_sample = self.args.tokens_per_sample - context_window + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + @property def source_dictionary(self): """Return the :class:`~fairseq.data.Dictionary` for the language diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index d962a8145b..a872245881 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -11,14 +11,16 @@ import logging import math import os +import sys from argparse import Namespace +from typing import 
Iterable, List, Optional import torch +import fairseq from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils -from fairseq.data import LMContextWindowDataset from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar -from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.logging.meters import StopwatchMeter from fairseq.sequence_scorer import SequenceScorer from omegaconf import DictConfig @@ -27,144 +29,74 @@ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, ) logger = logging.getLogger("fairseq_cli.eval_lm") -class WordStat(object): - def __init__(self, word, is_bpe): - self.word = word - self.is_bpe = is_bpe - self.log_prob = 0 - self.next_word_prob = 0 - self.count = 0 - self.missing_next_words = 0 - - def add(self, log_prob, next_word_prob): - """increments counters for the sum of log probs of current word and next - word (given context ending at current word). 
Since the next word might be at the end of the example, - or it might be not counted because it is not an ending subword unit, - also keeps track of how many of those we have seen""" - if next_word_prob is not None: - self.next_word_prob += next_word_prob - else: - self.missing_next_words += 1 - self.log_prob += log_prob - self.count += 1 - - def __str__(self): - return "{}\t{}\t{}\t{}\t{}\t{}".format( - self.word, - self.count, - self.log_prob, - self.is_bpe, - self.next_word_prob, - self.count - self.missing_next_words, - ) - - -def main(cfg: DictConfig, **unused_kwargs): - if isinstance(cfg, Namespace): - cfg = convert_namespace_to_omegaconf(cfg) - - utils.import_user_module(cfg.common) - - use_fp16 = cfg.common.fp16 - use_cuda = torch.cuda.is_available() and not cfg.common.cpu - - if use_cuda: - torch.cuda.set_device(cfg.distributed_training.device_id) - - logger.info(cfg) - - # Load ensemble - logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - - # reduce tokens per sample by the required context window size - cfg.task.tokens_per_sample -= cfg.eval_lm.context_window - - # Initialize the task using the current *cfg* - task = tasks.setup_task(cfg.task) - - # Initialize the model (but not the task) using the checkpoint's *cfg* - models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( - [cfg.common_eval.path], - arg_overrides=eval(cfg.common_eval.model_overrides), - suffix=cfg.checkpoint.checkpoint_suffix, - strict=(cfg.checkpoint.checkpoint_shard_count == 1), - num_shards=cfg.checkpoint.checkpoint_shard_count, - task=task, - ) - - # Load dataset splits - gen_subset = cfg.dataset.gen_subset - task.load_dataset(gen_subset) - dataset = task.dataset(gen_subset) - if cfg.eval_lm.context_window > 0: - dataset = LMContextWindowDataset( - dataset=dataset, - tokens_per_sample=cfg.task.tokens_per_sample, - context_window=cfg.eval_lm.context_window, - pad_idx=task.source_dictionary.pad(), - ) - logger.info("{} {} {} 
examples".format(cfg.task.data, gen_subset, len(dataset))) - - # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) - for model in models: - if use_fp16: - model.half() - if use_cuda and not cfg.distributed_training.pipeline_model_parallel: - model.cuda() - model.prepare_for_inference_(cfg) - - assert len(models) > 0 - - logger.info( - "num. model params: {}".format(sum(p.numel() for p in models[0].parameters())) - ) - - itr = task.get_batch_iterator( - dataset=dataset, - max_tokens=cfg.dataset.max_tokens or 36000, - max_sentences=cfg.dataset.batch_size, - max_positions=utils.resolve_max_positions( - *[model.max_positions() for model in models] - ), - ignore_invalid_inputs=True, - num_shards=max( - cfg.dataset.num_shards, - cfg.distributed_training.distributed_world_size, - ), - shard_id=max( - cfg.dataset.shard_id, - cfg.distributed_training.distributed_rank, - ), - num_workers=cfg.dataset.num_workers, - data_buffer_size=cfg.dataset.data_buffer_size, - ).next_epoch_itr(shuffle=False) - progress = progress_bar.progress_bar( - itr, - log_format=cfg.common.log_format, - log_interval=cfg.common.log_interval, - default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), - ) +def eval_lm( + models: List[fairseq.models.FairseqModel], + source_dictionary: fairseq.data.Dictionary, + batch_iterator: Iterable, + post_process: Optional[str] = None, + output_word_probs: bool = False, + output_word_stats: bool = False, + target_dictionary: Optional[fairseq.data.Dictionary] = None, + softmax_batch: int = False, + remove_bos_token: bool = False, + device: Optional[torch.device] = None, +): + """ + Args: + models (List[~fairseq.models.FairseqModel]): list of models to + evaluate. Models are essentially `nn.Module` instances, but + must be compatible with fairseq's `SequenceScorer`. + source_dictionary (~fairseq.data.Dictionary): dictionary for + applying any relevant post processing or outputting word + probs/stats.
+ batch_iterator (Iterable): yield batches of data + post_process (Optional[str]): post-process text by removing BPE, + letter segmentation, etc. Valid options can be found in + fairseq.data.data_utils.post_process, although not all options + are implemented here. + output_word_probs (Optional[bool]): output words and their + predicted log probabilities + output_word_stats (Optional[bool]): output word statistics such + as word count and average probability + target_dictionary (Optional[~fairseq.data.Dictionary]): output + dictionary (defaults to *source_dictionary*) + softmax_batch (Optional[int]): if BxT is more than this, will + batch the softmax over vocab to this amount of tokens, in + order to fit into GPU memory + remove_bos_token (Optional[bool]): if True, confirm that the + first token is the beginning-of-sentence symbol (according + to the relevant dictionary) and remove it from the output + device (Optional[torch.device]): device to use for evaluation + (defaults to device of first model parameter) + """ + if target_dictionary is None: + target_dictionary = source_dictionary + if device is None: + device = next(models[0].parameters()).device gen_timer = StopwatchMeter() - scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch) + scorer = SequenceScorer(target_dictionary, softmax_batch) score_sum = 0.0 count = 0 - if cfg.common_eval.post_process is not None: - if cfg.common_eval.post_process == "sentencepiece": - raise NotImplementedError - else: - bpe_cont = cfg.common_eval.post_process.rstrip() + if post_process is not None: + if post_process in {"subword_nmt", "@@ "}: + bpe_cont = "@@" bpe_toks = { i - for i in range(len(task.source_dictionary)) - if task.source_dictionary[i].endswith(bpe_cont) + for i in range(len(source_dictionary)) + if source_dictionary[i].endswith(bpe_cont) } + else: + raise NotImplementedError( + f"--post-process={post_process} is not implemented" + ) bpe_len = len(bpe_cont) else: bpe_toks = None @@
-172,13 +104,11 @@ def main(cfg: DictConfig, **unused_kwargs): word_stats = dict() - wps_meter = TimeMeter() - - for sample in progress: + for sample in batch_iterator: if "net_input" not in sample: continue - sample = utils.move_to_cuda(sample) if use_cuda else sample + sample = utils.move_to_cuda(sample, device=device) gen_timer.start() hypos = scorer.generate(models, sample) @@ -192,8 +122,8 @@ def main(cfg: DictConfig, **unused_kwargs): tgt_len = tokens.numel() pos_scores = hypo["positional_scores"].float() - if getattr(cfg.task, "add_bos_token", False): - assert hypo["tokens"][0].item() == task.target_dictionary.bos() + if remove_bos_token: + assert hypo["tokens"][0].item() == target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] @@ -209,19 +139,19 @@ def main(cfg: DictConfig, **unused_kwargs): if inf_scores.any(): logger.info( "skipping tokens with inf scores:", - task.target_dictionary.string(tokens[inf_scores.nonzero()]), + target_dictionary.string(tokens[inf_scores.nonzero()]), ) pos_scores = pos_scores[(~inf_scores).nonzero()] score_sum += pos_scores.sum().cpu() count += pos_scores.numel() - skipped_toks - if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats: + if output_word_probs or output_word_stats: w = "" word_prob = [] is_bpe = False for i in range(len(tokens)): w_ind = tokens[i].item() - w += task.source_dictionary[w_ind] + w += source_dictionary[w_ind] if bpe_toks is not None and w_ind in bpe_toks: w = w[:-bpe_len] is_bpe = True @@ -241,7 +171,7 @@ def main(cfg: DictConfig, **unused_kwargs): ) is_bpe = False w = "" - if cfg.eval_lm.output_word_probs: + if output_word_probs: logger.info( str(int(sample_id)) + " " @@ -252,24 +182,150 @@ def main(cfg: DictConfig, **unused_kwargs): ) ) - wps_meter.update(sample["ntokens"]) - progress.log({"wps": round(wps_meter.avg)}) - avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0 # convert to base 2 logger.info( "Evaluated {} tokens in {:.1f}s ({:.2f} 
tokens/s)".format( gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 ) ) + + if output_word_stats: + for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): + logger.info(ws) + + return { + "loss": avg_nll_loss, + "perplexity": 2 ** avg_nll_loss, + } + + +class WordStat(object): + def __init__(self, word, is_bpe): + self.word = word + self.is_bpe = is_bpe + self.log_prob = 0 + self.next_word_prob = 0 + self.count = 0 + self.missing_next_words = 0 + + def add(self, log_prob, next_word_prob): + """increments counters for the sum of log probs of current word and next + word (given context ending at current word). Since the next word might be at the end of the example, + or it might be not counted because it is not an ending subword unit, + also keeps track of how many of those we have seen""" + if next_word_prob is not None: + self.next_word_prob += next_word_prob + else: + self.missing_next_words += 1 + self.log_prob += log_prob + self.count += 1 + + def __str__(self): + return "{}\t{}\t{}\t{}\t{}\t{}".format( + self.word, + self.count, + self.log_prob, + self.is_bpe, + self.next_word_prob, + self.count - self.missing_next_words, + ) + + +def main(cfg: DictConfig, **unused_kwargs): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + logger.info(cfg) + + # Initialize the task using the current *cfg* + task = tasks.setup_task(cfg.task) + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=eval(cfg.common_eval.model_overrides), + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + task=task, + ) + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + if use_cuda: + 
torch.cuda.set_device(cfg.distributed_training.device_id) + + # Optimize ensemble for generation and set the source and dest dicts on the model + # (required by scorer) + for model in models: + if use_fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + assert len(models) > 0 + + logger.info( + "num. model params: {}".format(sum(p.numel() for p in models[0].parameters())) + ) + + # Load dataset splits + task.load_dataset(cfg.dataset.gen_subset) + dataset = task.dataset(cfg.dataset.gen_subset) + logger.info( + "{} {} {} examples".format(cfg.task.data, cfg.dataset.gen_subset, len(dataset)) + ) + + itr = task.eval_lm_dataloader( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens or 36000, + batch_size=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + *[model.max_positions() for model in models] + ), + num_shards=max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ), + shard_id=max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ), + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + context_window=cfg.eval_lm.context_window, + ) + + itr = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + results = eval_lm( + models=models, + source_dictionary=task.source_dictionary, + batch_iterator=itr, + post_process=cfg.common_eval.post_process, + output_word_probs=cfg.eval_lm.output_word_probs, + output_word_stats=cfg.eval_lm.output_word_stats, + target_dictionary=task.target_dictionary, + softmax_batch=cfg.eval_lm.softmax_batch, + remove_bos_token=getattr(cfg.task, "add_bos_token", False), + ) + logger.info( "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( - avg_nll_loss, 2 ** avg_nll_loss + results["loss"], 
results["perplexity"] ) ) - if cfg.eval_lm.output_word_stats: - for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): - logger.info(ws) + return results def cli_main(): diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 2b57aa66da..4e605bd0b1 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1048,6 +1048,7 @@ def test_transformer_lm(self): run_validation=True, ) eval_lm_main(data_dir) + eval_lm_main(data_dir, extra_flags=["--context-window", "25"]) generate_main( data_dir, [ From 4af87e3c150191f12e6e023c65ec97e38859b7f9 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Tue, 22 Dec 2020 18:00:04 -0800 Subject: [PATCH 373/707] Fix validate_and_save logic so validation is done for training terminated by stop_time_hours Summary: Also improves stopping criteria logging The final validation step wasn't being done when training was terminated by stop_time_hours. This was more of an issue for toy test cases (ex: wanting training to terminate in just a few min in order to test subsequent pipeline logic) since that could result in no validation loss ever being produced - which can break the rest of our pipeline Reviewed By: chtran Differential Revision: D25630252 fbshipit-source-id: 32cadbf977b0c9775830c1e8eb7f26bac12fe9ae --- fairseq_cli/train.py | 45 +++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 245263402e..adf07729c5 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -16,7 +16,6 @@ import numpy as np import torch - from fairseq import ( checkpoint_utils, distributed_utils, @@ -29,8 +28,8 @@ from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer -from omegaconf import DictConfig from fairseq.trainer import Trainer +from omegaconf import DictConfig 
logging.basicConfig( @@ -219,7 +218,9 @@ def train( "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), azureml_logging=( - cfg.common.azureml_logging if distributed_utils.is_master(cfg.distributed_training) else False + cfg.common.azureml_logging + if distributed_utils.is_master(cfg.distributed_training) + else False ), ) @@ -273,9 +274,32 @@ def validate_and_save( ) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() max_update = cfg.optimization.max_update or math.inf + + # Stopping conditions (and an additional one based on validation loss later + # on) + should_stop = False + if num_updates >= max_update: + should_stop = True + logger.info( + f"Stopping training due to " + f"num_updates: {num_updates} >= max_update: {max_update}" + ) + + training_time_hours = trainer.cumulative_training_time() / (60 * 60) + if ( + cfg.optimization.stop_time_hours > 0 + and training_time_hours > cfg.optimization.stop_time_hours + ): + should_stop = True + logger.info( + f"Stopping training due to " + f"cumulative_training_time: {training_time_hours} > " + f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)" + ) + do_save = ( (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) - or num_updates >= max_update + or should_stop or ( cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 @@ -286,7 +310,7 @@ def validate_and_save( do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) - or num_updates >= max_update + or should_stop or ( cfg.dataset.validate_interval_updates > 0 and num_updates > 0 @@ -299,16 +323,7 @@ def validate_and_save( if do_validate: valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) - # Stopping conditions - should_stop = ( - should_stop_early(cfg, valid_losses[0]) - or num_updates >= max_update - or ( - cfg.optimization.stop_time_hours > 0 - and 
trainer.cumulative_training_time() / (60 * 60) - > cfg.optimization.stop_time_hours - ) - ) + should_stop |= should_stop_early(cfg, valid_losses[0]) # Save checkpoint if do_save or should_stop: From 16a5fca05d3bcb5c1116ca987fa41c86599dfdf3 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Wed, 23 Dec 2020 11:15:53 -0800 Subject: [PATCH 374/707] fairseq checkpoint improvements Reviewed By: myleott Differential Revision: D25677238 fbshipit-source-id: b43075034c953491211f19a5464148de4758df83 --- fairseq/checkpoint_utils.py | 24 +++++++++++++++++------- fairseq/trainer.py | 19 +++++++++---------- fairseq_cli/train.py | 1 - 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index f3c2d8aa16..36a28f35dc 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -47,9 +47,6 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): if not trainer.is_data_parallel_master: return - def is_better(a, b): - return a >= b if cfg.maximize_best_checkpoint_metric else a <= b - write_timer = meters.StopwatchMeter() write_timer.start() @@ -57,6 +54,11 @@ def is_better(a, b): end_of_epoch = epoch_itr.end_of_epoch() updates = trainer.get_num_updates() + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + suffix = cfg.checkpoint_suffix or "" checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( @@ -91,11 +93,13 @@ def is_better(a, b): if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) for cp in checkpoints[1:]: - PathManager.copy(checkpoints[0], cp, overwrite=True) + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" write_timer.stop() logger.info( - "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} 
seconds)".format( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( checkpoints[0], epoch, updates, val_loss, write_timer.sum ) ) @@ -494,10 +498,16 @@ def _upgrade_state_dict(state): state["args"].stop_min_lr = state["args"].min_lr del state["args"].min_lr # binary_cross_entropy => wav2vec criterion - if hasattr(state["args"], "criterion") and state["args"].criterion == "binary_cross_entropy": + if ( + hasattr(state["args"], "criterion") + and state["args"].criterion == "binary_cross_entropy" + ): state["args"].criterion = "wav2vec" # speech_pretraining => audio pretraining - if hasattr(state["args"], "task") and state["args"].task == "speech_pretraining": + if ( + hasattr(state["args"], "task") + and state["args"].task == "speech_pretraining" + ): state["args"].task = "audio_pretraining" # audio_cpc => wav2vec if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": diff --git a/fairseq/trainer.py b/fairseq/trainer.py index cfeb63237b..8f42743ac3 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -263,10 +263,6 @@ def consolidate_optimizer(self): def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" if self.is_data_parallel_master: # only save one checkpoint - logger.info( - f"Preparing to save checkpoint to {filename} after " - f"{self.get_num_updates()} updates" - ) extra_state["metrics"] = metrics.state_dict() extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( @@ -297,6 +293,7 @@ def load_checkpoint( """ extra_state, self._optim_history, last_optim_state = None, [], None + logger.info(f"Preparing to load checkpoint {filename}") bexists = PathManager.isfile(filename) if bexists: load_on_all_ranks = ( @@ -377,11 +374,6 @@ def load_checkpoint( if extra_state is not None: epoch = extra_state["train_iterator"]["epoch"] - logger.info( - "loaded checkpoint {} (epoch {} @ {} updates)".format( - filename, 
epoch, self.get_num_updates() - ) - ) if "previous_training_time" in extra_state: self._previous_training_time = extra_state["previous_training_time"] @@ -396,8 +388,15 @@ def load_checkpoint( for meter in metrics.get_meters("default"): if isinstance(meter, meters.TimeMeter): meter.reset() + + logger.info( + "Loaded checkpoint {} (epoch {} @ {} updates)".format( + filename, epoch, self.get_num_updates() + ) + ) + else: - logger.info("no existing checkpoint found {}".format(filename)) + logger.info("No existing checkpoint found {}".format(filename)) return extra_state diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index adf07729c5..165ed86b58 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -327,7 +327,6 @@ def validate_and_save( # Save checkpoint if do_save or should_stop: - logger.info("begin save checkpoint") checkpoint_utils.save_checkpoint( cfg.checkpoint, trainer, epoch_itr, valid_losses[0] ) From 996ae207075db47f4061519c7dc39a86ab6d9535 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 23 Dec 2020 16:41:59 -0800 Subject: [PATCH 375/707] Add --heartbeat-timeout (#1527) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1527 Test Plan: Train roberta large with FP32 so step time is moderate: No timeout: ``` CUDA_VISIBLE_DEVICES=0,1 python train.py --task dummy_masked_lm --arch roberta_large --criterion masked_lm --max-sentences 8 --optimizer adam --lr 0.0001 --log-format simple --log-interval 1 --update-freq 16 2020-12-23 13:45:07 | INFO | train_inner | epoch 001: 1 / 391 loss=15.957, ppl=63597.8, wps=0, ups=0, wpb=131072, bsz=256, num_updates=1, lr=0.0001, gnorm=12.446, train_wall=19, wall=22 2020-12-23 13:45:07 | INFO | root | Reducer buckets have been rebuilt in this iteration. 
2020-12-23 13:45:20 | INFO | train_inner | epoch 001: 2 / 391 loss=14.75, ppl=27553.9, wps=10635.6, ups=0.08, wpb=131072, bsz=256, num_updates=2, lr=0.0001, gnorm=8.173, train_wall=12, wall=35 2020-12-23 13:45:32 | INFO | train_inner | epoch 001: 3 / 391 loss=13.894, ppl=15223, wps=10653, ups=0.08, wpb=131072, bsz=256, num_updates=3, lr=0.0001, gnorm=5.141, train_wall=12, wall=47 ``` Timeout of 1 second (fails): ``` CUDA_VISIBLE_DEVICES=0,1 python train.py --task dummy_masked_lm --arch roberta_large --criterion masked_lm --max-sentences 8 --optimizer adam --lr 0.0001 --log-format simple --log-interval 1 --update-freq 16 --heartbeat-timeout 1 2020-12-23 13:50:11 | ERROR | fairseq.models.distributed_fairseq_model | Killing job for not making progress in 1 seconds. Set --heartbeat-timeout=-1 to disable this timeout. 2020-12-23 13:50:11 | ERROR | fairseq.models.distributed_fairseq_model | Killing job for not making progress in 1 seconds. Set --heartbeat-timeout=-1 to disable this timeout. ``` Timeout of 3 seconds (doesn't fail): ``` CUDA_VISIBLE_DEVICES=0,1 python train.py --task dummy_masked_lm --arch roberta_large --criterion masked_lm --max-sentences 8 --optimizer adam --lr 0.0001 --log-format simple --log-interval 1 --update-freq 16 --heartbeat-timeout 3 2020-12-23 13:55:25 | INFO | train_inner | epoch 001: 1 / 391 loss=15.957, ppl=63597.8, wps=0, ups=0, wpb=131072, bsz=256, num_updates=1, lr=0.0001, gnorm=12.446, train_wall=19, wall=21 2020-12-23 13:55:25 | INFO | root | Reducer buckets have been rebuilt in this iteration. 
2020-12-23 13:55:37 | INFO | train_inner | epoch 001: 2 / 391 loss=14.75, ppl=27553.9, wps=10682, ups=0.08, wpb=131072, bsz=256, num_updates=2, lr=0.0001, gnorm=8.173, train_wall=12, wall=33 2020-12-23 13:55:50 | INFO | train_inner | epoch 001: 3 / 391 loss=13.894, ppl=15223, wps=10654.5, ups=0.08, wpb=131072, bsz=256, num_updates=3, lr=0.0001, gnorm=5.141, train_wall=12, wall=46 ``` Reviewed By: joshim5, ngoyal2707 Differential Revision: D25696904 Pulled By: myleott fbshipit-source-id: b2dbb7ddd8ce3ea83491f479314a0b3caa09b4b7 --- fairseq/dataclass/configs.py | 7 ++++ fairseq/models/distributed_fairseq_model.py | 45 ++++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 584e9201a3..caf4a7a2b8 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -252,6 +252,13 @@ class DistributedTrainingConfig(FairseqDataclass): default=False, metadata={"help": "[deprecated] this is now defined per Criterion"}, ) + heartbeat_timeout: int = field( + default=-1, + metadata={ + "help": "kill the job if no progress is made in N seconds; " + "set to -1 to disable" + } + ) broadcast_buffers: bool = field( default=False, metadata={ diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index b78a0125e3..909b3757b2 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -4,6 +4,10 @@ # LICENSE file in the root directory of this source tree. 
import inspect +import logging +import os +import signal +import threading import torch import torch.nn as nn @@ -12,6 +16,9 @@ from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel +logger = logging.getLogger(__name__) + + _GOSSIP_DISABLED = False try: import gossip @@ -97,12 +104,41 @@ def DistributedFairseqModel(args, model, process_group): else: raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) + heartbeat_timeout = getattr(args, "heartbeat_timeout", -1) + class _DistributedFairseqModel(ddp_class): - """Extend DistributedDataParallel to check for missing - attributes in the wrapped module.""" + """ + Extend DistributedDataParallel to check for missing attributes in the + wrapped module and to add a timeout to kill the job if no progress is + made (--heartbeat-timeout). + """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._heartbeat_timeout = heartbeat_timeout + if self._heartbeat_timeout > 0: + self._heartbeat = threading.Event() + self._heartbeat_thread = threading.Thread( + target=self._check_heartbeat, + args=(os.getpid(),), + daemon=True, + ) + self._heartbeat_thread.start() + else: + self._heartbeat = None + + def _check_heartbeat(self, parent_pid): + self._heartbeat.wait() # wait for the first forward pass + while True: + self._heartbeat.clear() + success = self._heartbeat.wait(timeout=self._heartbeat_timeout) + if not success: + logger.error(( + "Killing job for not making progress in {} seconds. " + "Set --heartbeat-timeout=-1 to disable this timeout." 
+ ).format(int(self._heartbeat_timeout))) + os.kill(parent_pid, signal.SIGKILL) + return def __getattr__(self, name): wrapped_module = super().__getattr__("module") @@ -110,6 +146,11 @@ def __getattr__(self, name): return getattr(wrapped_module, name) return super().__getattr__(name) + def forward(self, *args, **kwargs): + if self._heartbeat is not None: + self._heartbeat.set() + return super().forward(*args, **kwargs) + return _DistributedFairseqModel(**init_kwargs) From b8ea8a9b72c82192da07e3377adf4ebbde16716d Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 23 Dec 2020 18:34:59 -0800 Subject: [PATCH 376/707] Fix --context-window and add test (#1526) Summary: This was broken in the recent refactoring: https://github.com/pytorch/fairseq/commit/36c63c826d2292c9df56065b5816c02eefc87713 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1526 Reviewed By: sshleifer Differential Revision: D25697706 Pulled By: myleott fbshipit-source-id: 4d9a735c0071a0d71a4ae46e1c3fc3aba572117b --- fairseq/data/lm_context_window_dataset.py | 21 ++++++++-- fairseq/data/monolingual_dataset.py | 8 ++-- fairseq/tasks/language_modeling.py | 5 +-- fairseq_cli/eval_lm.py | 4 ++ tests/test_lm_context_window.py | 51 +++++++++++++++++++++++ 5 files changed, 77 insertions(+), 12 deletions(-) create mode 100644 tests/test_lm_context_window.py diff --git a/fairseq/data/lm_context_window_dataset.py b/fairseq/data/lm_context_window_dataset.py index 29ad887b7d..39512797bc 100644 --- a/fairseq/data/lm_context_window_dataset.py +++ b/fairseq/data/lm_context_window_dataset.py @@ -11,10 +11,23 @@ class LMContextWindowDataset(FairseqDataset): - """Wraps a MonolingualDataset and provides more context for evaluation.""" - - def __init__(self, dataset, tokens_per_sample, context_window, pad_idx): - assert isinstance(dataset, MonolingualDataset) + """ + Wraps a MonolingualDataset and provides more context for evaluation. 
+ + Each item in the new dataset will have a maximum size of + ``tokens_per_sample + context_window``. + + Args: + dataset: dataset to wrap + tokens_per_sample (int): the max number of tokens in each dataset item + context_window (int): the number of accumulated tokens to add to each + dataset item + pad_idx (int): padding symbol + """ + + def __init__( + self, dataset, tokens_per_sample: int, context_window: int, pad_idx: int + ): assert context_window > 0 self.dataset = dataset self.tokens_per_sample = tokens_per_sample diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index ec73f1fda8..bf7aa86f6c 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -70,16 +70,16 @@ def __init__( dataset, sizes, src_vocab, - tgt_vocab, - add_eos_for_other_targets, - shuffle, + tgt_vocab=None, + add_eos_for_other_targets=False, + shuffle=False, targets=None, add_bos_token=False, ): self.dataset = dataset self.sizes = np.array(sizes) self.vocab = src_vocab - self.tgt_vocab = tgt_vocab + self.tgt_vocab = tgt_vocab or src_vocab self.add_eos_for_other_targets = add_eos_for_other_targets self.shuffle = shuffle self.add_bos_token = add_bos_token diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index b68c4ad4d1..4a44d967b3 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -328,12 +328,9 @@ def eval_lm_dataloader( context_window: int = 0, ): if context_window > 0: - assert self.args.tokens_per_sample > context_window - # reduce tokens per sample by the required context window size - tokens_per_sample = self.args.tokens_per_sample - context_window dataset = LMContextWindowDataset( dataset=dataset, - tokens_per_sample=tokens_per_sample, + tokens_per_sample=self.args.tokens_per_sample, context_window=context_window, pad_idx=self.source_dictionary.pad(), ) diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 
a872245881..f27e0258d0 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -239,6 +239,10 @@ def main(cfg: DictConfig, **unused_kwargs): logger.info(cfg) + if cfg.eval_lm.context_window > 0: + # reduce tokens per sample by the required context window size + cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + # Initialize the task using the current *cfg* task = tasks.setup_task(cfg.task) diff --git a/tests/test_lm_context_window.py b/tests/test_lm_context_window.py new file mode 100644 index 0000000000..7415e86abd --- /dev/null +++ b/tests/test_lm_context_window.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from fairseq.data import MonolingualDataset +from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig +from tests import utils as test_utils + + +class TestLMContextWindow(unittest.TestCase): + + def test_eval_dataloader(self): + dictionary = test_utils.dummy_dictionary(10) + assert len(dictionary) == 14 # 4 extra special symbols + assert dictionary.pad() == 1 + + dataset = test_utils.TestDataset([ + torch.tensor([4, 5, 6, 7], dtype=torch.long), + torch.tensor([8, 9, 10, 11], dtype=torch.long), + torch.tensor([12, 13], dtype=torch.long), + ]) + dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary) + + config = LanguageModelingConfig(tokens_per_sample=4) + task = LanguageModelingTask(config, dictionary) + + eval_dataloader = task.eval_lm_dataloader( + dataset=dataset, + batch_size=1, + context_window=2, + ) + + batch = next(eval_dataloader) + assert batch["net_input"]["src_tokens"][0].tolist() == [4, 5, 6, 7, 1, 1] + assert batch["target"][0].tolist() == [4, 5, 6, 7, 1, 1] + + batch = next(eval_dataloader) + assert batch["net_input"]["src_tokens"][0].tolist() == [6, 7, 8, 9, 10, 11] + assert 
batch["target"][0].tolist() == [1, 1, 8, 9, 10, 11] + + batch = next(eval_dataloader) + assert batch["net_input"]["src_tokens"][0].tolist() == [10, 11, 12, 13] + assert batch["target"][0].tolist() == [1, 1, 12, 13] + + +if __name__ == "__main__": + unittest.main() From 982ec329769cd189ba3735eecc9687d072bcdb72 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 28 Dec 2020 15:46:51 -0800 Subject: [PATCH 377/707] logger: format big numbers with commas for readability (#1525) Summary: Before: ``` 2020-12-23 11:46:16 | INFO | fairseq_cli.eval_lm | num. model params: 353781760 2020-12-23 11:46:21 | INFO | fairseq.data.data_utils | loaded 89663978 examples from: /private/home/sshleifer/data-bin/new_hybrid_data/train ``` After: ``` 2020-12-23 11:46:16 | INFO | fairseq_cli.eval_lm | num. model params: 353,781,760 2020-12-23 11:46:21 | INFO | fairseq.data.data_utils | loaded 89,663,978 examples from: /private/home/sshleifer/data-bin/new_hybrid_data/train ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1525 Test Plan: Run `fairseq-eval-lm` or `fairseq-train` and look at logs. 
For example, ``` export dd2=/private/home/sshleifer/data-bin/new_hybrid_data export m=/private/home/myleott/models/public_models/LM/roberta_lm.me_fp16.bm_none.tps1024.transformer_lm_gpt2_small.share.adam.b2_0.98.eps1e-08.cl0.0.lr0.003.wu3000.dr0.1.atdr0.1.wd0.01.ms2.uf4.mu100000.s1.ngpu64/model.pt fairseq-eval-lm $dd2 \ --path $m \ --sample-break-mode complete --gen-subset train \ --tokens-per-sample 3072 --max-tokens 3072 --context-window 2560 --softmax-batch 1024 --fp16 ``` Reviewed By: myleott Differential Revision: D25693004 Pulled By: sshleifer fbshipit-source-id: bfeb93fc6607cca2cb7a6e820f51e174d02d1f62 --- fairseq/data/data_utils.py | 2 +- fairseq/tasks/fairseq_task.py | 4 ++-- fairseq_cli/eval_lm.py | 12 ++++++++---- fairseq_cli/generate.py | 2 +- fairseq_cli/train.py | 2 +- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 1efe352dd2..cac11ed1f9 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -99,7 +99,7 @@ def load_indexed_dataset( ) if dataset is None: break - logger.info("loaded {} examples from: {}".format(len(dataset), path_k)) + logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k)) datasets.append(dataset) if not combine: break diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 24116bfd52..eb5e6a7694 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -154,7 +154,7 @@ def filter_indices_by_size( ) logger.warning( ( - "{} samples have invalid sizes and will be skipped, " + "{:,} samples have invalid sizes and will be skipped, " "max_positions={}, first few sample ids={}" ).format(len(ignored), max_positions, ignored[:10]) ) @@ -378,7 +378,7 @@ def build_generator( if seq_gen_cls is None: if getattr(args, "print_alignment", False): seq_gen_cls = SequenceGeneratorWithAlignment - extra_gen_cls_kwargs['print_alignment'] = args.print_alignment + 
extra_gen_cls_kwargs["print_alignment"] = args.print_alignment else: seq_gen_cls = SequenceGenerator diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index f27e0258d0..4501cac67e 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -182,9 +182,11 @@ def eval_lm( ) ) - avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0 # convert to base 2 + avg_nll_loss = ( + -score_sum / count / math.log(2) if count > 0 else 0 + ) # convert to base 2 logger.info( - "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format( + "Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format( gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 ) ) @@ -274,14 +276,16 @@ def main(cfg: DictConfig, **unused_kwargs): assert len(models) > 0 logger.info( - "num. model params: {}".format(sum(p.numel() for p in models[0].parameters())) + "num. model params: {:,}".format(sum(p.numel() for p in models[0].parameters())) ) # Load dataset splits task.load_dataset(cfg.dataset.gen_subset) dataset = task.dataset(cfg.dataset.gen_subset) logger.info( - "{} {} {} examples".format(cfg.task.data, cfg.dataset.gen_subset, len(dataset)) + "{} {} {:,} examples".format( + cfg.task.data, cfg.dataset.gen_subset, len(dataset) + ) ) itr = task.eval_lm_dataloader( diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 4aeb4a56fa..ff8369b539 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -370,7 +370,7 @@ def decode_fn(x): logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info( - "Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( + "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( num_sentences, gen_timer.n, gen_timer.sum, diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 165ed86b58..6069cf48fe 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -77,7 +77,7 @@ def main(cfg: 
DictConfig) -> None: logger.info("model: {}".format(model.__class__.__name__)) logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( - "num. model params: {} (num. trained: {})".format( + "num. model params: {:,} (num. trained: {:,})".format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), ) From 4e3895be1ccb59e36de85441cd049294cbad2d15 Mon Sep 17 00:00:00 2001 From: Ruslan Mavlyutov Date: Mon, 28 Dec 2020 21:03:59 -0800 Subject: [PATCH 378/707] batch_by_size refactoring: 100x speedup and optimization of memory footprint Summary: Refactoring batch_by_size. You may be required to rebuild Cython components with: `python setup.py build_ext --inplace`. Reviewed By: myleott Differential Revision: D25705733 fbshipit-source-id: a263505276e3d820a9e44b93354ee5ace70d7fc5 --- fairseq/data/data_utils.py | 35 +++-- fairseq/data/data_utils_fast.pyx | 144 ++++++++++++------ fairseq/data/fairseq_dataset.py | 14 ++ fairseq/data/language_pair_dataset.py | 8 + .../multilingual/sampled_multi_dataset.py | 5 + tests/test_data_utils.py | 136 +++++++++++++++++ 6 files changed, 287 insertions(+), 55 deletions(-) create mode 100644 tests/test_data_utils.py diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index cac11ed1f9..9a0580977d 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -80,8 +80,8 @@ def load_indexed_dataset( combine 'data-bin/train', 'data-bin/train1', ... and return a single ConcatDataset instance. 
""" - from fairseq.data.concat_dataset import ConcatDataset import fairseq.data.indexed_dataset as indexed_dataset + from fairseq.data.concat_dataset import ConcatDataset datasets = [] for k in itertools.count(): @@ -276,6 +276,7 @@ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_siz def batch_by_size( indices, num_tokens_fn, + num_tokens_vec=None, max_tokens=None, max_sentences=None, required_batch_size_multiple=1, @@ -289,6 +290,8 @@ def batch_by_size( indices (List[int]): ordered list of dataset indices num_tokens_fn (callable): function that returns the number of tokens at a given index + num_tokens_vec (List[int], optional): precomputed vector of the number + of tokens for each index in indices (to enable faster batch generation) max_tokens (int, optional): max number of tokens in each batch (default: None). max_sentences (int, optional): max number of sentences in each @@ -301,7 +304,8 @@ def batch_by_size( """ try: from fairseq.data.data_utils_fast import ( - batch_by_size_fast, + batch_by_size_fn, + batch_by_size_vec, batch_fixed_shapes_fast, ) except ImportError: @@ -317,14 +321,27 @@ def batch_by_size( if not isinstance(indices, np.ndarray): indices = np.fromiter(indices, dtype=np.int64, count=-1) + if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray): + num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1) + if fixed_shapes is None: - return batch_by_size_fast( - indices, - num_tokens_fn, - max_tokens, - max_sentences, - bsz_mult, - ) + if num_tokens_vec is None: + return batch_by_size_fn( + indices, + num_tokens_fn, + max_tokens, + max_sentences, + bsz_mult, + ) + else: + return batch_by_size_vec( + indices, + num_tokens_vec, + max_tokens, + max_sentences, + bsz_mult, + ) + else: fixed_shapes = np.array(fixed_shapes, dtype=np.int64) sort_order = np.lexsort( diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx index 38b4aa67dd..d197d3f00e 100644 --- 
a/fairseq/data/data_utils_fast.pyx +++ b/fairseq/data/data_utils_fast.pyx @@ -10,63 +10,115 @@ cimport cython cimport numpy as np from libc.stdint cimport int32_t, int64_t +from libcpp cimport bool as bool_t ctypedef int64_t DTYPE_t - -cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences): - if num_sentences == 0: - return 0 - if max_sentences > 0 and num_sentences == max_sentences: - return 1 - if max_tokens > 0 and num_tokens > max_tokens: - return 1 - return 0 - - @cython.cdivision(True) -cpdef list batch_by_size_fast( +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef list batch_by_size_vec( + np.ndarray[int64_t, ndim=1] indices, + np.ndarray[int64_t, ndim=1] num_tokens_vec, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, +): + assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, ( + f"Sentences lengths should not exceed max_tokens={max_tokens}" + ) + + cdef int32_t indices_len = indices.shape[0] + cdef np.ndarray[int32_t, ndim=1] batches_ends = \ + np.zeros(indices_len, dtype=np.int32) + cdef int32_t[:] batches_ends_view = batches_ends + cdef int64_t[:] num_tokens_view = num_tokens_vec + + cdef int32_t pos = 0 + cdef int32_t new_batch_end = 0 + + cdef int64_t new_batch_max_tokens = 0 + cdef int32_t new_batch_sentences = 0 + cdef int64_t new_batch_num_tokens = 0 + + cdef bool_t overflow = False + cdef bool_t size_matches_with_bsz_mult = False + + cdef int32_t batches_count = 0 + cdef int32_t batch_start = 0 + cdef int64_t tail_max_tokens = 0 + cdef int64_t batch_max_tokens = 0 + + for pos in range(indices_len): + # At every pos we keep stats about the last complete batch [batch_start:batch_end), + # and tail [batch_end:pos]. + # 1) Every time when (batch + tail) forms a valid batch + # (according to max_tokens, max_sentences and bsz_mult) we append tail to batch. 
+ # 2) When (batch+tail) violates max_tokens or max_sentences constraints + # we finalize running batch, and tail becomes a new batch. + # 3) There is a corner case when tail also violates constraints. + # In that situation [batch_end:pos-1] (tail without the current pos) + # gets added to the finalized batches, while [pos:pos] becomes a new tail. + # + # Important: For the sake of performance try to avoid using function calls within this loop. + + tail_max_tokens = tail_max_tokens \ + if tail_max_tokens > num_tokens_view[pos] \ + else num_tokens_view[pos] + new_batch_end = pos + 1 + new_batch_max_tokens = batch_max_tokens \ + if batch_max_tokens > tail_max_tokens \ + else tail_max_tokens + new_batch_sentences = new_batch_end - batch_start + new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens + + overflow = (new_batch_sentences > max_sentences > 0 or + new_batch_num_tokens > max_tokens > 0) + size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or + new_batch_sentences % bsz_mult == 0) + + if overflow: + tail_num_tokens = tail_max_tokens * \ + (new_batch_end - batches_ends_view[batches_count]) + tail_overflow = tail_num_tokens > max_tokens > 0 + # In case of a tail overflow finalize two batches + if tail_overflow: + batches_count += 1 + batches_ends_view[batches_count] = pos + tail_max_tokens = num_tokens_view[pos] + batch_start = batches_ends_view[batches_count] + batches_count += 1 + new_batch_max_tokens = tail_max_tokens + + if overflow or size_matches_with_bsz_mult: + batches_ends_view[batches_count] = new_batch_end + batch_max_tokens = new_batch_max_tokens + tail_max_tokens = 0 + if batches_ends_view[batches_count] != indices_len: + batches_count += 1 + # Memory and time-efficient split + return np.split(indices, batches_ends[:batches_count]) + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef list batch_by_size_fn( np.ndarray[DTYPE_t, ndim=1] indices, num_tokens_fn, int64_t max_tokens, int64_t max_sentences, int32_t 
bsz_mult, ): - cdef int64_t sample_len = 0 - cdef list sample_lens = [] - cdef list batch = [] - cdef list batches = [] - cdef int64_t mod_len - cdef int64_t i - cdef int64_t idx - cdef int64_t num_tokens + cdef int32_t indices_len = indices.shape[0] + cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, + dtype=np.int64) cdef DTYPE_t[:] indices_view = indices - - for i in range(len(indices_view)): - idx = indices_view[i] - num_tokens = num_tokens_fn(idx) - sample_lens.append(num_tokens) - sample_len = max(sample_len, num_tokens) - - assert max_tokens <= 0 or sample_len <= max_tokens, ( - "sentence at index {} of size {} exceeds max_tokens " - "limit of {}!".format(idx, sample_len, max_tokens) - ) - num_tokens = (len(batch) + 1) * sample_len - - if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences): - mod_len = max( - bsz_mult * (len(batch) // bsz_mult), - len(batch) % bsz_mult, - ) - batches.append(batch[:mod_len]) - batch = batch[mod_len:] - sample_lens = sample_lens[mod_len:] - sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 - batch.append(idx) - if len(batch) > 0: - batches.append(batch) - return batches + cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec + cdef int64_t pos + for pos in range(indices_len): + num_tokens_vec[pos] = num_tokens_fn(indices_view[pos]) + return batch_by_size_vec(indices, num_tokens_vec, max_tokens, + max_sentences, bsz_mult,) cdef _find_valid_shape( diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index ed08c1ba20..23e6992dba 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -3,10 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import logging import numpy as np import torch.utils.data from fairseq.data import data_utils +logger = logging.getLogger(__name__) + class EpochListening: """Mixin for receiving updates whenever the epoch increments.""" @@ -54,6 +57,11 @@ def num_tokens(self, index): enforce ``--max-tokens`` during batching.""" raise NotImplementedError + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" @@ -129,9 +137,15 @@ def adjust_bsz(bsz, num_tokens): ] ) + try: + num_tokens_vec = self.num_tokens_vec(indices).astype('int64') + except NotImplementedError: + num_tokens_vec = None + return data_utils.batch_by_size( indices, num_tokens_fn=self.num_tokens, + num_tokens_vec=num_tokens_vec, max_tokens=max_tokens, max_sentences=max_sentences, required_batch_size_multiple=required_batch_size_multiple, diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 62e7109b33..8858cec84e 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -408,6 +408,14 @@ def num_tokens(self, index): self.tgt_sizes[index] if self.tgt_sizes is not None else 0, ) + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + sizes = self.src_sizes[indices] + if self.tgt_sizes is not None: + sizes = np.maximum(sizes, self.tgt_sizes[indices]) + return sizes + def size(self, index): """Return an example's size as a float or tuple. 
This value is used when filtering a dataset with ``--max-positions``.""" diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index 599f3a862b..f74ec18141 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -238,6 +238,11 @@ def __getitem__(self, index): def num_tokens(self, index): return self.sizes[index].max() + def num_tokens_vec(self, indices): + sizes_vec = self.sizes[np.array(indices)] + # max across all dimensions but first one + return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape)))) + def size(self, index): return self.sizes[index] diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py new file mode 100644 index 0000000000..2acfc8dc18 --- /dev/null +++ b/tests/test_data_utils.py @@ -0,0 +1,136 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import numpy as np +from fairseq.data.data_utils_fast import batch_by_size_fn +from fairseq.data.data_utils_fast import batch_by_size_vec + + +class TestBatchBySize(unittest.TestCase): + @classmethod + def batch_by_size_baseline( + cls, + indices, + num_tokens_vec, + max_tokens, + max_sentences, + bsz_mult, + ): + """Simple, reliable and slow implementation of batch by size """ + batches = [] + start = 0 + while start < len(indices): + for end in range(start + 1, len(indices) + 1): + max_val = max(num_tokens_vec[pos] for pos in range(start, end)) + sent_count = end - start + num_tokens = max_val * sent_count + overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0 + terminate = overflow or end == len(indices) + if overflow: + sent_count -= 1 + if terminate: + if sent_count > bsz_mult: + sent_count = sent_count - sent_count % bsz_mult + batches.append(indices[start : start + sent_count]) + start = start + sent_count + break + return batches + + @classmethod + def _get_error_message( + cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results + ): + return f"""Reference batch_by_size implementation should produce + same output as the baseline method. 
+ Params: + max_sentences={max_sentences}, + max_tokens={max_tokens}, + bsz_mult={bsz_mult}, + num_tokens_vec={num_tokens_vec}, + expected_batches={validation}, + returned_batches={results}""" + + def _compare_results( + self, + indices_len, + batch_by_size_impl, + max_sentences, + max_tokens, + bsz_mult, + num_tokens_vec, + ): + indices = np.array(list(range(indices_len))) + validation = self.batch_by_size_baseline( + indices, + num_tokens_vec, + max_tokens=max_tokens, + max_sentences=max_sentences, + bsz_mult=bsz_mult, + ) + results = batch_by_size_impl( + indices, + num_tokens_vec, + max_tokens=max_tokens, + max_sentences=max_sentences, + bsz_mult=bsz_mult, + ) + error_msg = self._get_error_message( + max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results + ) + self.assertEqual(len(validation), len(results), error_msg) + for first, second in zip(validation, results): + self.assertTrue(np.array_equal(first, second), error_msg) + + def _run_compare_with_baseline_sweep(self, batch_by_size_impl): + """Compare reference batch_by_size implementation with batch_by_size_baseline + across a dense grid of hyperparam values""" + MAX_MAX_TOKENS = 10 + NUM_TOKENS_VECS_COUNT = 5 + for indices_len in [10, 11]: # try odd and even len of indices + for max_sentences in range(0, indices_len + 2): + for max_tokens in range(0, MAX_MAX_TOKENS): + for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2): + for _ in range(NUM_TOKENS_VECS_COUNT): + num_tokens_vec = np.random.randint( + 0, max_tokens + 1, size=indices_len + ) + self._compare_results( + indices_len, + batch_by_size_impl, + max_sentences, + max_tokens, + bsz_mult, + num_tokens_vec, + ) + + +class TestBatchBySizeVec(TestBatchBySize): + def test_compare_with_baseline(self): + self._run_compare_with_baseline_sweep(batch_by_size_vec) + + +class TestBatchBySizeFn(TestBatchBySize): + def test_compare_with_baseline(self): + def batch_by_size_fn_wrapper( + indices, + num_tokens_vec, + max_tokens, + 
max_sentences, + bsz_mult, + ): + def num_tokens_fn(idx): + return num_tokens_vec[idx] + + return batch_by_size_fn( + indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult + ) + + self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper) + + +if __name__ == "__main__": + unittest.main() From e2e80c6f2dca01dd8c04b3e5b0b356abf4b429cf Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Tue, 29 Dec 2020 15:53:03 -0800 Subject: [PATCH 379/707] Rename "Arguments:" to "Args:" (#3060) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow—and now PyTorch—codebases: inconsistent use of `Args:` and `Arguments:` in its docstrings. It is easy enough to extend my parsers to support both variants, however it looks like `Arguments:` is wrong anyway, as per: - https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md) - https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md) - https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst) Therefore, only `Args:` is valid. This PR replaces them throughout the codebase. 
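As an illustration of the replacement described above, here is a minimal sketch that rewrites `Arguments:` section headers to `Args:` in Python source text. It assumes a plain regex over the file contents is sufficient. `normalize_section_header` is a hypothetical helper written for this note; it is not part of fairseq and not necessarily how this PR was produced.

```python
import re

# Matches a docstring section header "Arguments:" that stands alone on its
# line, capturing the leading indentation so it can be preserved.
_ARGUMENTS_HEADER = re.compile(
    r"^(?P<indent>[ \t]*)Arguments:[ \t]*$", re.MULTILINE
)


def normalize_section_header(source: str) -> str:
    """Rewrite ``Arguments:`` section headers to ``Args:`` (hypothetical helper)."""
    return _ARGUMENTS_HEADER.sub(r"\g<indent>Args:", source)
```

Only headers that occupy a whole line are rewritten, so prose that merely contains the word "Arguments:" followed by other text is left untouched.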
PS: For related PRs, see pytorch/pytorch/pull/49736 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3060 Reviewed By: ngoyal2707 Differential Revision: D25692815 Pulled By: myleott fbshipit-source-id: 461543ad3e2acd1fc475da799495841be73250bd --- fairseq/optim/adafactor.py | 4 ++-- fairseq/optim/adam.py | 4 ++-- fairseq/optim/adamax.py | 4 ++-- fairseq/optim/composite.py | 2 +- fairseq/optim/fused_adam.py | 4 ++-- fairseq/optim/nag.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index 91745ce10e..c969b9fbc0 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -76,7 +76,7 @@ class Adafactor(torch.optim.Optimizer): schedule you should set `scale_parameter=False` and `relative_step=False`. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): external learning rate (default: None) @@ -168,7 +168,7 @@ def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index 1a4f213707..f73804718a 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -103,7 +103,7 @@ class Adam(torch.optim.Optimizer): It has been proposed in `Adam: A Method for Stochastic Optimization`_. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) @@ -146,7 +146,7 @@ def supports_flat_params(self): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py index 577a688166..98ff8ad7ad 100644 --- a/fairseq/optim/adamax.py +++ b/fairseq/optim/adamax.py @@ -53,7 +53,7 @@ class Adamax(torch.optim.Optimizer): Compared to the version in PyTorch, this version implements a fix for weight decay. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 2e-3) @@ -107,7 +107,7 @@ def supports_flat_params(self): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py index 51e6999368..1a581bc010 100644 --- a/fairseq/optim/composite.py +++ b/fairseq/optim/composite.py @@ -134,7 +134,7 @@ def supports_flat_params(self): def step(self, closure=None, groups=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py index 1780f9c0bb..e2b8e1bcd1 100644 --- a/fairseq/optim/fused_adam.py +++ b/fairseq/optim/fused_adam.py @@ -47,7 +47,7 @@ class FusedAdamV1(torch.optim.Optimizer): Compared to the original version in Apex, the fairseq version casts grads and params to FP32 internally to support ``--memory-efficient-fp16``. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. (default: 1e-3) @@ -113,7 +113,7 @@ def supports_step_with_scale(self): def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
grads (list of tensors, optional): weight gradient to use for the diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 4f652fe6d3..c30a6c0fb1 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -62,7 +62,7 @@ def supports_flat_params(self): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ From 48a607527a8c2435c63795cca9a05348ef6e1f9d Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 29 Dec 2020 15:53:06 -0800 Subject: [PATCH 380/707] Reorganize self.emb_layer_norm (#3057) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Reorganize `self.emb_layer_norm` in order to keep right order while `print(model)` ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3057 Reviewed By: ngoyal2707 Differential Revision: D25692819 Pulled By: myleott fbshipit-source-id: d371955152e7fcb9e356351311f194a2418ca4b5 --- fairseq/modules/transformer_sentence_encoder.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 7a5dcbdde3..6e9c32f467 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -145,6 +145,11 @@ def __init__( else None ) + if encoder_normalize_before: + self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) + else: + self.emb_layer_norm = None + if self.layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.layerdrop) else: @@ -167,11 +172,6 @@ def __init__( ] ) - if encoder_normalize_before: - self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) - else: - self.emb_layer_norm = None - # Apply initialization of model params after building the model if self.apply_bert_init: self.apply(init_bert_params) From bff7f85206f6f64b9455035893d44d66b98e33b0 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 30 Dec 2020 12:57:02 -0800 Subject: [PATCH 381/707] fastseq ngram blocking (#1509) Summary: Command: ```bash fairseq-generate \ ~myleott/data/data-bin/wmt16_en_de_bpe32k/ \ --path /checkpoint/myleott/s3/models/wmt16.en-de.joined-dict.transformer/model.pt \ --beam 4 --remove-bpe --lenpen 0.6 --batch-size 256 --no-repeat-ngram-size 3 \ --gen-subset test --fp16 ``` master/devfair: 297.8s (10.08 sentences/s, 286.47 tokens/s) branch/devfair: 31.9s (94.27 sentences/s, 2678.66 tokens/s) master/v100: 227.4s (13.21 sentences/s, 375.24 tokens/s) branch/v100: 13.1s (228.68 sentences/s, 6497.99 tokens/s) (all BLEU4=29.17) ### ToDo: - tests ### Future Work - test other fastseq proposed improvements.
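
### Illustration

For readers unfamiliar with `--no-repeat-ngram-size`, the rule that this change accelerates can be sketched in plain Python. This is a reference-style sketch under stated assumptions: it handles a single hypothesis as a list of token ids, and `banned_next_tokens` is a hypothetical name used only here. It is not the CUDA kernel added by this patch.

```python
from typing import List, Set


def banned_next_tokens(tokens: List[int], no_repeat_ngram_size: int) -> Set[int]:
    """Return the tokens that would complete an n-gram already present in ``tokens``.

    Hypothetical helper for illustration; operates on one hypothesis.
    """
    n = no_repeat_ngram_size
    if n <= 0 or len(tokens) < n - 1:
        return set()
    # The last n-1 generated tokens form the prefix of the next n-gram.
    prefix = tokens[len(tokens) - (n - 1):] if n > 1 else []
    banned: Set[int] = set()
    # Scan every n-gram generated so far; if its first n-1 tokens equal the
    # current prefix, emitting its last token would repeat that n-gram.
    for start in range(len(tokens) - n + 1):
        ngram = tokens[start:start + n]
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned
```

During generation, the log-probabilities of the returned tokens would typically be set to negative infinity before the next beam search step, so that no n-gram of the configured size is produced twice.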
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1509 Reviewed By: myleott Differential Revision: D25587857 Pulled By: sshleifer fbshipit-source-id: d42af5c50e3f94c90e878f92da5ce5ef3fc8b988 --- fairseq/clib/cuda/ngram_repeat_block_cuda.cpp | 47 ++++++ .../cuda/ngram_repeat_block_cuda_kernel.cu | 76 +++++++++ fairseq/ngram_repeat_block.py | 150 ++++++++++++++++++ fairseq/sequence_generator.py | 77 ++------- setup.py | 10 +- tests/test_sequence_generator.py | 128 +++++++++++++-- 6 files changed, 411 insertions(+), 77 deletions(-) create mode 100644 fairseq/clib/cuda/ngram_repeat_block_cuda.cpp create mode 100644 fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu create mode 100644 fairseq/ngram_repeat_block.py diff --git a/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp b/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp new file mode 100644 index 0000000000..4199cd6ea8 --- /dev/null +++ b/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp @@ -0,0 +1,47 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. 
+*/ + +#include +#include + +/* +CPP Binding for CUDA OP +*/ + +// CUDA forward declarations +torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens, + torch::Tensor lprobs, int bsz, + int step, int beam_size, + int no_repeat_ngram_size); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +// Input check and call to CUDA OP +// Backward method not required +torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens, + torch::Tensor lprobs, int bsz, + int step, int beam_size, + int no_repeat_ngram_size) { + CHECK_INPUT(tokens); + CHECK_INPUT(lprobs); + assert(bsz > 0); + assert(step >= 0); + assert(beam_size > 0); + assert(no_repeat_ngram_size > 0); + + return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size, + no_repeat_ngram_size); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &ngram_repeat_block_forward, + "No Repeat Ngram Block forward (CUDA)"); +} diff --git a/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu new file mode 100644 index 0000000000..b458b0916a --- /dev/null +++ b/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -0,0 +1,76 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for blocking repeated n-grams. 
+*/ + +#include +#include +#include +#include +#include + +// Ban repeated ngrams of length = 'no_repeat_ngram_size' +__global__ void banRepeatedTokens(long* __restrict__ tokens, + float* __restrict__ lprobs, + int max_predict_len, int vocab_size, + int no_repeat_ngram_size) { + auto row = blockIdx.x; + auto col = threadIdx.x; + auto start = row * (max_predict_len) + col; + // Each thread compares ngram starting from + // thread index with final ngram starting from + // step - no_repeat_ngram_size +2 + auto check_start_pos = blockDim.x; + auto lprob_start = row * vocab_size; + bool is_banned = true; + extern __shared__ long tokens_shm[]; + tokens_shm[col] = tokens[start]; + if (col == blockDim.x - 1) { + for (int i=1; i(); + auto lprob_ptr = lprobs.data_ptr(); + int blocks = bsz * beam_size; + int shared_mem_size = (step + 1) * sizeof(long); + + // Launching N blocks where N is number of samples in a batch (beams*bsz) + // Launching T threads where T is number of previous ngrams in a sample + // Allocating shared mem per block for fastser access of input tokens since + // each token will be accessed N times to compare with current Ngram where + // N is Ngram size. + banRepeatedTokens<<>>( + token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size); + return lprobs; +} diff --git a/fairseq/ngram_repeat_block.py b/fairseq/ngram_repeat_block.py new file mode 100644 index 0000000000..856c9e64f7 --- /dev/null +++ b/fairseq/ngram_repeat_block.py @@ -0,0 +1,150 @@ +# Originally from Microsoft Corporation. +# Licensed under the MIT License. 
+ +""" Wrapper for ngram_repeat_block cuda extension """ +import torch +from torch import nn + +import math +from typing import Dict, List, Optional +import warnings + +try: + from fairseq import ngram_repeat_block_cuda + + EXTENSION_BUILT = True +except ImportError: + EXTENSION_BUILT = False + + +def is_cuda_extension_usable() -> bool: + """Check whether ngram_repeat_block_cuda is built properly""" + if not EXTENSION_BUILT: + return False + bsz = 2 + tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda") + lprobs = torch.rand((8, 12), device="cuda") + try: + outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) + outputs = outputs + 4 # This line breaks if the extension is built incorrectly. + return True + except RuntimeError: + warnings.warn( + "NGramRepeatBlock extension must be rebuilt." + 'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' + ) + return False + + +class NGramRepeatBlock(nn.Module): + """ Wrapper class for calling ngram_repeat_block cuda extension """ + + def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): + super().__init__() + self.use_extension = is_cuda_extension_usable() if use_extension else False + self.no_repeat_ngram_size = no_repeat_ngram_size + + def reset_parameters(self): + pass + + @torch.jit.unused + def call_cuda_extension( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + return ngram_repeat_block_cuda.forward( + tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size + ) + + def forward( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability, + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + msg = f"expected {bsz *beam_size} got" + assert 
tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}" + assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}" + if self.use_extension: + return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step) + + else: + return self._no_repeat_ngram( + tokens, + lprobs, + bsz, + beam_size, + step, + ) + + def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): + """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" + gen_ngrams: List[Dict[str, List[int]]] = [ + torch.jit.annotate(Dict[str, List[int]], {}) + for bbsz_idx in range(bsz * beam_size) + ] + cpu_tokens = tokens.cpu() + for bbsz_idx in range(bsz * beam_size): + gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() + for ngram in self.transpose_list( + [gen_tokens[i:] for i in range(self.no_repeat_ngram_size)] + ): + key = ",".join([str(x) for x in ngram[:-1]]) + gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( + key, torch.jit.annotate(List[int], []) + ) + [ngram[-1]] + if step + 2 - self.no_repeat_ngram_size >= 0: + # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + banned_tokens = [ + self.calculate_banned_tokens( + tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx + ) + for bbsz_idx in range(bsz * beam_size) + ] + else: + banned_tokens = [ + torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size) + ] + for bbsz_idx in range(bsz * beam_size): + lprobs[bbsz_idx][ + torch.tensor(banned_tokens[bbsz_idx]).long() + ] = torch.tensor(-math.inf).to(lprobs) + return lprobs + + @staticmethod + def calculate_banned_tokens( + tokens, + step: int, + gen_ngrams: List[Dict[str, List[int]]], + no_repeat_ngram_size: int, + bbsz_idx: int, + ): + tokens_list: List[int] = tokens[ + bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1 + ].tolist() + # before decoding the next token, prevent decoding of ngrams that have already appeared + ngram_index = ",".join([str(x) for x in 
tokens_list]) + return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], [])) + + @staticmethod + def transpose_list(l: List[List[int]]): + # GeneratorExp aren't supported in TS so ignoring the lint + min_len = min([len(x) for x in l]) # noqa + l2 = [[row[i] for row in l] for i in range(min_len)] + return l2 diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index bd46f9e5b9..b0249888ce 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -12,6 +12,7 @@ from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder from torch import Tensor +from fairseq.ngram_repeat_block import NGramRepeatBlock class SequenceGenerator(nn.Module): @@ -84,7 +85,10 @@ def __init__( self.unk_penalty = unk_penalty self.temperature = temperature self.match_source_len = match_source_len + self.no_repeat_ngram_size = no_repeat_ngram_size + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + assert temperature > 0, "--temperature must be greater than 0" self.search = ( @@ -278,7 +282,12 @@ def _generate( cand_size = 2 * beam_size # 2 x beam size in case half are EOS # offset arrays for converting between different indexing schemes - bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens).to(src_tokens.device) + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_tokens.device) + ) cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) reorder_state: Optional[Tensor] = None @@ -365,7 +374,7 @@ def _generate( self.search.set_src_lengths(src_lengths) if self.no_repeat_ngram_size > 0: - lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step) + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( @@ -709,62 +718,6 @@ def is_finished( return True return False - def 
calculate_banned_tokens( - self, - tokens, - step: int, - gen_ngrams: List[Dict[str, List[int]]], - no_repeat_ngram_size: int, - bbsz_idx: int, - ): - tokens_list: List[int] = tokens[ - bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1 - ].tolist() - # before decoding the next token, prevent decoding of ngrams that have already appeared - ngram_index = ",".join([str(x) for x in tokens_list]) - return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], [])) - - def transpose_list(self, l: List[List[int]]): - # GeneratorExp aren't supported in TS so ignoring the lint - min_len = min([len(x) for x in l]) # noqa - l2 = [[row[i] for row in l] for i in range(min_len)] - return l2 - - def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): - # for each beam and batch sentence, generate a list of previous ngrams - gen_ngrams: List[Dict[str, List[int]]] = [ - torch.jit.annotate(Dict[str, List[int]], {}) - for bbsz_idx in range(bsz * beam_size) - ] - cpu_tokens = tokens.cpu() - for bbsz_idx in range(bsz * beam_size): - gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() - for ngram in self.transpose_list( - [gen_tokens[i:] for i in range(self.no_repeat_ngram_size)] - ): - key = ",".join([str(x) for x in ngram[:-1]]) - gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( - key, torch.jit.annotate(List[int], []) - ) + [ngram[-1]] - - if step + 2 - self.no_repeat_ngram_size >= 0: - # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet - banned_tokens = [ - self.calculate_banned_tokens( - tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx - ) - for bbsz_idx in range(bsz * beam_size) - ] - else: - banned_tokens = [ - torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size) - ] - for bbsz_idx in range(bsz * beam_size): - lprobs[bbsz_idx][ - torch.tensor(banned_tokens[bbsz_idx]).long() - ] = torch.tensor(-math.inf).to(lprobs) - return lprobs - class EnsembleModel(nn.Module): """A wrapper 
around an ensemble of models.""" @@ -867,7 +820,9 @@ def forward_decoder( return avg_probs, avg_attn @torch.jit.export - def reorder_encoder_out(self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order): + def reorder_encoder_out( + self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order + ): """ Reorder encoder output according to *new_order*. @@ -903,7 +858,9 @@ def reorder_incremental_state( class SequenceGeneratorWithAlignment(SequenceGenerator): - def __init__(self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs): + def __init__( + self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs + ): """Generates translations of a given source sentence. Produces alignments following "Jointly Learning to Align and diff --git a/setup.py b/setup.py index 1954298034..08fe0dcccc 100644 --- a/setup.py +++ b/setup.py @@ -109,7 +109,6 @@ def include_dirs(self, dirs): ) ] ) - if "CUDA_HOME" in os.environ: extensions.extend( [ @@ -119,7 +118,14 @@ def include_dirs(self, dirs): "fairseq/clib/libnat_cuda/edit_dist.cu", "fairseq/clib/libnat_cuda/binding.cpp", ], - ) + ), + cpp_extension.CppExtension( + "fairseq.ngram_repeat_block_cuda", + sources=[ + "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp", + "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu", + ], + ), ] ) cmdclass["build_ext"] = cpp_extension.BuildExtension diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py index c890b655ff..afbdfb6c2c 100644 --- a/tests/test_sequence_generator.py +++ b/tests/test_sequence_generator.py @@ -6,6 +6,9 @@ import argparse import tempfile import unittest +import math +import numpy as np + import tests.utils as test_utils import torch @@ -13,6 +16,7 @@ from fairseq.data.dictionary import Dictionary from fairseq.models.transformer import TransformerModel from fairseq.sequence_generator import EnsembleModel, SequenceGenerator +from fairseq.ngram_repeat_block import NGramRepeatBlock from 
fairseq.tasks.fairseq_task import LegacyFairseqTask @@ -41,7 +45,7 @@ def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE): dummy_dict = Dictionary() # add dummy symbol to satisfy vocab size for id, _ in enumerate(range(vocab_size)): - dummy_dict.add_symbol("{}".format(id), 1000) + dummy_dict.add_symbol("{}".format(id), n=1000) return dummy_dict @@ -107,30 +111,27 @@ def _test_save_and_load(self, scripted_module): torch.jit.load(f.name) -class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase): - @unittest.skipIf( - torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" - ) +JIT_MSG = "Targeting OSS scriptability for the 1.6 release" + + +@unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG) +class TestJitSequenceGenerator(TestJitSequenceGeneratorBase): def test_export_transformer(self): model = self.transformer_model torch.jit.script(model) - @unittest.skipIf( - torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" - ) def test_ensemble_sequence_generator(self): model = self.transformer_model generator = SequenceGenerator( - [model], self.task.tgt_dict, beam_size=2, no_repeat_ngram_size=2 + [model], + self.task.tgt_dict, + beam_size=2, + no_repeat_ngram_size=2, + max_len_b=10, ) scripted_model = torch.jit.script(generator) self._test_save_and_load(scripted_model) - -class TestJitEnsemble(TestJitSequenceGeneratorBase): - @unittest.skipIf( - torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" - ) def test_export_ensemble_model(self): model = self.transformer_model ensemble_models = EnsembleModel([model]) @@ -185,7 +186,7 @@ def assertTensorEqual(self, t1, t2): self.assertEqual(t1.ne(t2).long().sum(), 0) -class TestSequeneceGenerator(TestSequenceGeneratorBase): +class TestSequenceGenerator(TestSequenceGeneratorBase): def setUp(self): ( self.tgt_dict, @@ -326,6 +327,103 @@ def test_generation_with_additional_input(self): self.assertHypoScore(hypos[0][0], [0.9, 1.0]) 
+@unittest.skipUnless(torch.cuda.is_available(), "") +class TestRepeatNgramBlocking(TestSequenceGeneratorBase): + @classmethod + def setUpClass(cls): + ( + cls.tgt_dict, + cls.w1, + cls.w2, + src_tokens, + src_lengths, + cls.model, + ) = test_utils.sequence_generator_setup() + return cls + + def test_finds_repetitive_tokens(self): + bsz, vocab_size, beam_size, step = 2, 4, 1, 3 + generated_tok = torch.tensor( + [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda" + ) + lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda") + desired_result = lprobs.new_tensor( + [[0.0, 0.0, -math.inf, 0.0], [0.0, 0.0, 0.0, -math.inf]] + ) + + cuda_ext_result, baseline_result = self._compare_cuda_ext_to_default_implem( + bsz, beam_size, generated_tok, lprobs, step, 2 + ) + self.assertTensorEqual(cuda_ext_result, desired_result) + self.assertTensorEqual(baseline_result, desired_result) + + @unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG) + def test_jit_no_extension(self): + bsz, vocab_size, beam_size, step = 2, 4, 1, 3 + generated_tok = torch.tensor( + [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda" + ) + lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda") + blocker = NGramRepeatBlock(2, use_extension=False) + base_result = blocker(generated_tok, lprobs.clone(), bsz, beam_size, step) + scripted_blocker = torch.jit.script(blocker) + jit_result = scripted_blocker( + generated_tok, lprobs.clone(), bsz, beam_size, step + ) + self.assertTensorEqual(base_result, jit_result) + + def test_ngram_blocking_same_as_default_implem(self): + """Test that cuda extension returns same things as default impl in many settings.""" + vocab_size = 4 + step = 6 + for _ in range(2): + block_param = np.random.choice([1, 2, 3, 4]) + batch_size = np.random.randint(1, 8) + beam_size = np.random.choice([1, 2, 4, 8]) + lprobs = torch.zeros((beam_size * batch_size, vocab_size), device="cuda") + + generated_tok = torch.tensor( + np.random.randint( + 0, 
vocab_size, size=(batch_size * beam_size, step + 1) + ), + device="cuda", + dtype=torch.long, + ) + self._compare_cuda_ext_to_default_implem( + batch_size, + beam_size, + generated_tok, + lprobs, + step, + block_param, + ) + + def _compare_cuda_ext_to_default_implem( + self, bsz, beam_size, generated_tok, lprobs, step, block_param + ): + """Assert that cuda extension and default implem return the same thing.""" + blocker = NGramRepeatBlock(block_param) + assert blocker.use_extension, "Extension not compiled" + cuda_ext_result = blocker( + generated_tok, + lprobs.clone(), + bsz, + beam_size, + step, + ) + blocker.use_extension = False + baseline_result = blocker( + generated_tok, + lprobs.clone(), + bsz, + beam_size, + step, + ) + self.assertTensorEqual(cuda_ext_result, baseline_result) + blocker.use_extension = True + return cuda_ext_result, baseline_result + + class TestDiverseBeamSearch(TestSequenceGeneratorBase): def setUp(self): # construct dummy dictionary From 01fcec5fc3dcc695c59cf3fdf7f178c174edcf0d Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Wed, 30 Dec 2020 20:15:58 -0800 Subject: [PATCH 382/707] Fix incorrect local cache for checkpoint_last.pt when training is restarted on the same host Reviewed By: myleott Differential Revision: D25719057 fbshipit-source-id: 2bf7dd93b6d223804da0326a0ed347e5e353f1f0 --- fairseq/checkpoint_utils.py | 37 ++++++++++++++++++++++++++++++++++--- fairseq/trainer.py | 11 +++++++---- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 36a28f35dc..79c811424a 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -226,9 +226,40 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): return extra_state, epoch_itr -def load_checkpoint_to_cpu(path, arg_overrides=None): - """Loads a checkpoint to CPU (with upgrading for backward compatibility).""" - with open(PathManager.get_local_path(path), "rb") as f: +def 
load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): + """Loads a checkpoint to CPU (with upgrading for backward compatibility). + + If doing single-GPU training or if the checkpoint is only being loaded by at + most one process on each node (current default behavior is for only rank 0 + to read the checkpoint from disk), load_on_all_ranks should be False to + avoid errors from torch.distributed not having been initialized or + torch.distributed.barrier() hanging. + + If all processes on each node may be loading the checkpoint + simultaneously, load_on_all_ranks should be set to True to avoid I/O + conflicts. + + There's currently no support for > 1 but < all processes loading the + checkpoint on each node. + """ + local_path = PathManager.get_local_path(path) + # The locally cached file returned by get_local_path() may be stale for + # remote files that are periodically updated/overwritten (ex: + # checkpoint_last.pt) - so we remove the local copy, sync across processes + # (if needed), and then download a fresh copy. + if local_path != path and PathManager.path_requires_pathmanager(path): + try: + os.remove(local_path) + except FileNotFoundError: + # With potentially multiple processes removing the same file, the + # file being missing is benign (missing_ok isn't available until + # Python 3.8). 
+ pass + if load_on_all_ranks: + torch.distributed.barrier() + local_path = PathManager.get_local_path(path) + + with open(local_path, "rb") as f: state = torch.load(f, map_location=torch.device("cpu")) if "args" in state and state["args"] is not None and arg_overrides is not None: diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 8f42743ac3..a6c1013635 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -294,6 +294,7 @@ def load_checkpoint( extra_state, self._optim_history, last_optim_state = None, [], None logger.info(f"Preparing to load checkpoint {filename}") + is_distributed = self.data_parallel_world_size > 1 bexists = PathManager.isfile(filename) if bexists: load_on_all_ranks = ( @@ -304,7 +305,9 @@ def load_checkpoint( ) if load_on_all_ranks or self.data_parallel_rank == 0: - state = checkpoint_utils.load_checkpoint_to_cpu(filename) + state = checkpoint_utils.load_checkpoint_to_cpu( + filename, load_on_all_ranks=load_on_all_ranks + ) last_optim_state = state.get("last_optimizer_state", None) # If doing zero_sharding, do not broadcast global optimizer @@ -314,14 +317,14 @@ def load_checkpoint( not load_on_all_ranks and self.cfg.distributed_training.zero_sharding == "os" and "last_optimizer_state" in state - and self.data_parallel_world_size > 1 + and is_distributed ): state["last_optimizer_state"] = "SHARDED" else: last_optim_state = None state = None - if self.data_parallel_world_size > 1 and not load_on_all_ranks: + if is_distributed and not load_on_all_ranks: state = distributed_utils.broadcast_object( state, src_rank=0, @@ -364,7 +367,7 @@ def load_checkpoint( if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) - if not load_on_all_ranks and self.data_parallel_world_size > 1: + if not load_on_all_ranks and is_distributed: last_optim_state = self.optimizer.broadcast_global_state_dict( last_optim_state ) From 336942734c85791a90baa373c212d27e7c722662 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: 
Sat, 2 Jan 2021 10:21:42 -0800 Subject: [PATCH 383/707] Better support for WandB (#1530) Summary: Logs full config Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1530 Reviewed By: sshleifer Differential Revision: D25723894 Pulled By: myleott fbshipit-source-id: bb4b168c774bef498e336bbb3ba92eda4b08df3b --- README.md | 4 +-- fairseq/logging/progress_bar.py | 28 ++++++++++++++++++--- fairseq/models/distributed_fairseq_model.py | 2 -- fairseq_cli/train.py | 16 +++++++++++- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index cc1c76ec36..c22abba8c0 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,8 @@ pip install --editable ./ # on MacOS: # CFLAGS="-stdlib=libc++" pip install --editable ./ -# to install the latest stable release (0.10.0) -# pip install fairseq==0.10.0 +# to install the latest stable release (0.10.1) +# pip install fairseq==0.10.1 ``` * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index e2a1711121..dc061a1821 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -145,6 +145,10 @@ def print(self, stats, tag=None, step=None): """Print end-of-epoch stats.""" raise NotImplementedError + def update_config(self, config): + """Log latest configuration.""" + pass + def _str_commas(self, stats): return ", ".join(key + "=" + stats[key].strip() for key in stats.keys()) @@ -303,9 +307,12 @@ def print(self, stats, tag=None, step=None): try: _tensorboard_writers = {} - from tensorboardX import SummaryWriter + from torch.utils.tensorboard import SummaryWriter except ImportError: - SummaryWriter = None + try: + from tensorboardX import SummaryWriter + except ImportError: + SummaryWriter = None def _close_writers(): @@ -325,7 +332,7 @@ def __init__(self, wrapped_bar, tensorboard_logdir): if SummaryWriter is None: logger.warning( - "tensorboard not 
found, please install with: pip install tensorboardX" + "tensorboard not found, please install with: pip install tensorboard" ) def _writer(self, key): @@ -350,6 +357,11 @@ def print(self, stats, tag=None, step=None): self._log_to_tensorboard(stats, tag, step) self.wrapped_bar.print(stats, tag=tag, step=step) + def update_config(self, config): + """Log latest configuration.""" + # TODO add hparams to Tensorboard + self.wrapped_bar.update_config(config) + def _log_to_tensorboard(self, stats, tag=None, step=None): writer = self._writer(tag or "") if writer is None: @@ -398,6 +410,12 @@ def print(self, stats, tag=None, step=None): self._log_to_wandb(stats, tag, step) self.wrapped_bar.print(stats, tag=tag, step=step) + def update_config(self, config): + """Log latest configuration.""" + if wandb is not None: + wandb.config.update(config) + self.wrapped_bar.update_config(config) + def _log_to_wandb(self, stats, tag=None, step=None): if wandb is None: return @@ -447,6 +465,10 @@ def print(self, stats, tag=None, step=None): self._log_to_azureml(stats, tag, step) self.wrapped_bar.print(stats, tag=tag, step=step) + def update_config(self, config): + """Log latest configuration.""" + self.wrapped_bar.update_config(config) + def _log_to_azureml(self, stats, tag=None, step=None): if Run is None: return diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index 909b3757b2..ffa3c37b19 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -60,8 +60,6 @@ def DistributedFairseqModel(args, model, process_group): process_group=process_group, ) # Maintain backward compatibility - if "check_reduction" in inspect.getargspec(ddp_class)[0]: - init_kwargs["check_reduction"] = True if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]: init_kwargs["find_unused_parameters"] = args.find_unused_parameters elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d": 
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 6069cf48fe..1156222642 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -29,7 +29,7 @@ from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.trainer import Trainer -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf logging.basicConfig( @@ -223,6 +223,7 @@ def train( else False ), ) + progress.update_config(_flatten_config(cfg)) trainer.begin_epoch(epoch_itr.epoch) @@ -264,6 +265,19 @@ def train( return valid_losses, should_stop +def _flatten_config(cfg: DictConfig): + config = OmegaConf.to_container(cfg) + # remove any legacy Namespaces and replace with a single "args" + namespace = None + for k, v in list(config.items()): + if isinstance(v, argparse.Namespace): + namespace = v + del config[k] + if namespace is not None: + config["args"] = vars(namespace) + return config + + def validate_and_save( cfg: DictConfig, trainer: Trainer, From 7e5e45b483ec5896830231ebbd3ed26472bcff47 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 5 Jan 2021 11:27:00 -0800 Subject: [PATCH 384/707] Add omegaconf dependency to hubconf.py (fixes #3093) (#3102) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3102 Reviewed By: alexeib Differential Revision: D25784689 Pulled By: myleott fbshipit-source-id: 66b0755e7f8abca2a522266bd58881b789bae581 --- hubconf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hubconf.py b/hubconf.py index ce7d76cfe1..4eb6998504 100644 --- a/hubconf.py +++ b/hubconf.py @@ -17,6 +17,7 @@ "dataclasses", "hydra", "numpy", + "omegaconf", "regex", "requests", "torch", From 540fb42c523e98f066989c3e3b61a18caaca24f5 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 5 Jan 2021 12:13:07 -0800 Subject: [PATCH 385/707] Move dep checks before fairseq imports in hubconf.py (fixes #3093) (#3104) Summary: Pull Request resolved: 
https://github.com/pytorch/fairseq/pull/3104 Reviewed By: alexeib Differential Revision: D25786013 Pulled By: myleott fbshipit-source-id: 894b104f275573ce824d7f2318d043516f0e0c5c --- hubconf.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/hubconf.py b/hubconf.py index 4eb6998504..5949e274ed 100644 --- a/hubconf.py +++ b/hubconf.py @@ -2,16 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""isort:skip_file""" import functools import importlib -from fairseq.hub_utils import ( # noqa; noqa - BPEHubInterface as bpe, - TokenizerHubInterface as tokenizer, -) -from fairseq.models import MODEL_REGISTRY # noqa - dependencies = [ "dataclasses", @@ -40,6 +35,14 @@ raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) +# only do fairseq imports after checking for dependencies +from fairseq.hub_utils import ( # noqa; noqa + BPEHubInterface as bpe, + TokenizerHubInterface as tokenizer, +) +from fairseq.models import MODEL_REGISTRY # noqa + + # torch.hub doesn't build Cython components, so if they are not found then try # to build them here try: From 70b3f529659099b1bfca4099b11fc6d5577f7b0e Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 5 Jan 2021 14:33:29 -0800 Subject: [PATCH 386/707] Migrate wav2letter to flashlight (#2876) Summary: With this PR we start using flashlight bindings instead of wav2letter. # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2876 Reviewed By: myleott Differential Revision: D25785525 Pulled By: alexeib fbshipit-source-id: 245b3cebffedfd7db26c002ae3d26a1fe66c7156 --- examples/speech_recognition/README.md | 43 +++----------- .../speech_recognition/criterions/ASG_loss.py | 2 +- .../speech_recognition/criterions/__init__.py | 4 +- examples/speech_recognition/data/replabels.py | 4 +- examples/speech_recognition/infer.py | 2 +- examples/speech_recognition/w2l_decoder.py | 57 +++++++++---------- examples/wav2vec/README.md | 4 +- examples/wav2vec/libri_labels.py | 2 +- examples/wav2vec/vq-wav2vec_featurize.py | 2 +- examples/wav2vec/wav2vec_featurize.py | 12 ++-- 10 files changed, 53 insertions(+), 79 deletions(-) diff --git a/examples/speech_recognition/README.md b/examples/speech_recognition/README.md index 19f7cc563e..8207fe257c 100644 --- a/examples/speech_recognition/README.md +++ b/examples/speech_recognition/README.md @@ -32,41 +32,16 @@ sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.w ``` `Sum/Avg` row from first table of the report has WER -## Using wav2letter components -[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes: +## Using flashlight (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) components +[flashlight](https://github.com/facebookresearch/flashlight) now has integration with fairseq. Currently this includes: * AutoSegmentationCriterion (ASG) -* wav2letter-style Conv/GLU model -* wav2letter's beam search decoder +* flashlight-style Conv/GLU model +* flashlight's beam search decoder -To use these, follow the instructions on [this page](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python) to install python bindings. 
Please note that python bindings are for a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required). +To use these, follow the instructions on [this page](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) to install python bindings. -To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps: -``` -# additional prerequisites - use equivalents for your distro -sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev -# install KenLM from source -git clone https://github.com/kpu/kenlm.git -cd kenlm -mkdir -p build && cd build -cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON -make -j16 -cd .. -export KENLM_ROOT_DIR=$(pwd) -cd .. -# install wav2letter python bindings -git clone https://github.com/facebookresearch/wav2letter.git -cd wav2letter/bindings/python -# make sure your python environment is active at this point -pip install torch packaging -pip install -e . 
-# try some examples to verify installation succeeded -python ./examples/criterion_example.py -python ./examples/decoder_example.py ../../src/decoder/test -python ./examples/feature_example.py ../../src/feature/test/data -``` - -## Training librispeech data (wav2letter style, Conv/GLU + ASG loss) +## Training librispeech data (flashlight style, Conv/GLU + ASG loss) Training command: ``` python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition @@ -74,13 +49,13 @@ python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 10 Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`. -## Inference for librispeech (wav2letter decoder, n-gram LM) +## Inference for librispeech (flashlight decoder, n-gram LM) Inference command: ``` python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition ``` -`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels): +`$KENLM_MODEL_PATH` should be a standard n-gram language model file. 
`$LEXICON_PATH` should be a flashlight-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels): ``` doorbell D O 1 R B E L 1 ▁ ``` @@ -99,7 +74,7 @@ doorbell ▁DO OR BE L L ``` Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`). -## Inference for librispeech (wav2letter decoder, viterbi only) +## Inference for librispeech (flashlight decoder, viterbi only) Inference command: ``` python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition diff --git a/examples/speech_recognition/criterions/ASG_loss.py b/examples/speech_recognition/criterions/ASG_loss.py index 7493654afc..41f50bbd70 100644 --- a/examples/speech_recognition/criterions/ASG_loss.py +++ b/examples/speech_recognition/criterions/ASG_loss.py @@ -46,7 +46,7 @@ def __init__( linseg_updates, hide_linseg_messages, ): - from wav2letter.criterion import ASGLoss, CriterionScaleMode + from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode super().__init__(task) self.tgt_dict = task.target_dictionary diff --git a/examples/speech_recognition/criterions/__init__.py b/examples/speech_recognition/criterions/__init__.py index 88af9f340f..a667b1c918 100644 --- a/examples/speech_recognition/criterions/__init__.py +++ b/examples/speech_recognition/criterions/__init__.py @@ -2,10 +2,10 @@ import os -# ASG loss requires wav2letter +# ASG loss requires flashlight bindings files_to_skip = set() try: - import wav2letter + import flashlight.lib.sequence.criterion except ImportError: files_to_skip.add("ASG_loss.py") 
diff --git a/examples/speech_recognition/data/replabels.py b/examples/speech_recognition/data/replabels.py index d76bda7aef..441f1bd432 100644 --- a/examples/speech_recognition/data/replabels.py +++ b/examples/speech_recognition/data/replabels.py @@ -6,13 +6,13 @@ # LICENSE file in the root directory of this source tree. """ -Replabel transforms for use with wav2letter's ASG criterion. +Replabel transforms for use with flashlight's ASG criterion. """ def replabel_symbol(i): """ - Replabel symbols used in wav2letter, currently just "1", "2", ... + Replabel symbols used in flashlight, currently just "1", "2", ... This prevents training with numeral tokens, so this might change in the future """ return str(i) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index ddd3fd6340..5a582c54af 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -277,7 +277,7 @@ def build_generator(args): return W2lFairseqLMDecoder(args, task.target_dictionary) else: print( - "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment" + "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment" ) # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index e2870df6a7..1fb20757d0 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. """ -Wav2letter decoders. +Flashlight decoders. 
""" import gc @@ -25,11 +25,11 @@ try: - from wav2letter.common import create_word_dict, load_words - from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes - from wav2letter.decoder import ( + from flashlight.lib.text.dictionary import create_word_dict, load_words + from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes + from flashlight.lib.text.decoder import ( CriterionType, - DecoderOptions, + LexiconDecoderOptions, KenLM, LM, LMState, @@ -39,7 +39,7 @@ ) except: warnings.warn( - "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings" + "flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python" ) LM = object LMState = object @@ -156,19 +156,19 @@ def __init__(self, args, tgt_dict): self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) - self.decoder_opts = DecoderOptions( - args.beam, - int(getattr(args, "beam_size_token", len(tgt_dict))), - args.beam_threshold, - args.lm_weight, - args.word_score, - args.unk_weight, - args.sil_weight, - 0, - False, - self.criterion_type, + self.decoder_opts = LexiconDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + word_score=args.word_score, + unk_score=args.unk_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, ) + if self.asg_transitions is None: N = 768 # self.asg_transitions = torch.FloatTensor(N, N).zero_() @@ -368,17 +368,16 @@ def __init__(self, args, tgt_dict): self.unk_word = self.word_dict.unk() self.lm = FairseqLM(self.word_dict, model) - self.decoder_opts = DecoderOptions( - args.beam, - int(getattr(args, "beam_size_token", len(tgt_dict))), - args.beam_threshold, - 
args.lm_weight, - args.word_score, - args.unk_weight, - args.sil_weight, - 0, - False, - self.criterion_type, + self.decoder_opts = LexiconDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + word_score=args.word_score, + unk_score=args.unk_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, ) if self.lexicon: @@ -411,7 +410,7 @@ def __init__(self, args, tgt_dict): self.unit_lm, ) else: - from wav2letter.decoder import LexiconFreeDecoder + from flashlight.lib.text.decoder import LexiconFreeDecoder self.decoder = LexiconFreeDecoder( self.decoder_opts, self.lm, self.silence, self.blank, [] ) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index a0c95e9c34..b3300a8ed8 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -123,12 +123,12 @@ You can specify the right config via the `--config-name` parameter. Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) `distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k -Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). +Decoding with a language model during training requires flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)). If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. ### Evaluating a CTC model: -Evaluating a CTC model with a language model requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings) to be installed.
+Evaluating a CTC model with a language model requires [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) to be installed. Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). Be sure to upper-case the language model vocab after downloading it. diff --git a/examples/wav2vec/libri_labels.py b/examples/wav2vec/libri_labels.py index 3fa1ec4c8b..694a202604 100644 --- a/examples/wav2vec/libri_labels.py +++ b/examples/wav2vec/libri_labels.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Helper script to pre-compute embeddings for a wav2letter++ dataset +Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset """ import argparse diff --git a/examples/wav2vec/vq-wav2vec_featurize.py b/examples/wav2vec/vq-wav2vec_featurize.py index 1adb52de1c..627072ee17 100644 --- a/examples/wav2vec/vq-wav2vec_featurize.py +++ b/examples/wav2vec/vq-wav2vec_featurize.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Helper script to pre-compute embeddings for a wav2letter++ dataset +Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset """ import argparse diff --git a/examples/wav2vec/wav2vec_featurize.py b/examples/wav2vec/wav2vec_featurize.py index b806316e5a..588268b708 100644 --- a/examples/wav2vec/wav2vec_featurize.py +++ b/examples/wav2vec/wav2vec_featurize.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree.
""" -Helper script to pre-compute embeddings for a wav2letter++ dataset +Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset """ import argparse @@ -52,7 +52,7 @@ def forward(self, x): class EmbeddingWriterConfig(argparse.ArgumentParser): def __init__(self): - super().__init__("Pre-compute embeddings for wav2letter++ datasets") + super().__init__("Pre-compute embeddings for flashlight datasets") kwargs = {"action": "store", "type": str, "required": True} @@ -67,7 +67,7 @@ def __init__(self): self.add_argument( "--no-copy-labels", action="store_true", - help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.", + help="Do not copy label files. Useful for large datasets, use --targetdir in flashlight then.", ) self.add_argument( "--use-feat", @@ -93,7 +93,7 @@ def __call__(self, x): class H5Writer: - """ Write features as hdf5 file in wav2letter++ compatible format """ + """ Write features as hdf5 file in flashlight compatible format """ def __init__(self, fname): self.fname = fname @@ -109,11 +109,11 @@ def write(self, data): class EmbeddingDatasetWriter(object): - """Given a model and a wav2letter++ dataset, pre-compute and store embeddings + """Given a model and a flashlight dataset, pre-compute and store embeddings Args: input_root, str : - Path to the wav2letter++ dataset + Path to the flashlight dataset output_root, str : Desired output directory. 
Will be created if non-existent split, str : From 53b474f8ac071da5aad94d255aa698278a492923 Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 5 Jan 2021 17:41:46 -0800 Subject: [PATCH 387/707] minor changes/fixes (#1536) Summary: - some minor guardrails - ability to suppress crashes (useful for sweepers) - hydra_train main method returns the best validation value (useful for sweepers) - add more typing checks to work with python 3.9 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1536 Reviewed By: myleott Differential Revision: D25768210 Pulled By: alexeib fbshipit-source-id: 3df421efb261eb61331a9af1da11b8ef34bfd8f9 --- fairseq/criterions/ctc.py | 2 +- fairseq/criterions/model_criterion.py | 2 +- fairseq/dataclass/configs.py | 7 ++++ fairseq/dataclass/utils.py | 13 +++---- .../sinusoidal_positional_embedding.py | 2 +- fairseq/tasks/__init__.py | 2 +- fairseq_cli/hydra_train.py | 35 ++++++++++++++----- 7 files changed, 45 insertions(+), 18 deletions(-) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 8cb1331825..543e796da3 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -66,7 +66,7 @@ class CtcCriterionConfig(FairseqDataclass): class CtcCriterion(FairseqCriterion): def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): super().__init__(task) - self.blank_idx = task.target_dictionary.index(task.blank_symbol) + self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0 self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() self.post_process = cfg.post_process diff --git a/fairseq/criterions/model_criterion.py b/fairseq/criterions/model_criterion.py index 8e366a5d85..30350f13b1 100644 --- a/fairseq/criterions/model_criterion.py +++ b/fairseq/criterions/model_criterion.py @@ -83,7 +83,7 @@ def forward(self, model, sample, reduce=True): } for lk in self.log_keys: - if lk in net_output: + if lk in net_output and 
net_output[lk] is not None: logging_output[lk] = float(net_output[lk]) if len(scaled_losses) > 1: diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index caf4a7a2b8..2968d2ab0f 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -186,6 +186,13 @@ class CommonConfig(FairseqDataclass): "help": "when using Hydra, reset the logging at the beginning of training" }, ) + suppress_crashes: bool = field( + default=False, + metadata={ + "help": "suppress crashes when training with the hydra_train entry point so that the " + "main method can return a value (useful for sweeps)" + }, + ) @dataclass diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 45e7ed9170..4dc978409e 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -11,7 +11,7 @@ from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING from enum import Enum -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import FairseqConfig @@ -43,7 +43,7 @@ def interpret_dc_type(field_type): return str typestring = str(field_type) - if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): + if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring) or typestring.startswith("typing.Optional"): return field_type.__args__[0] return field_type @@ -230,14 +230,15 @@ def get_default(f): v_type = getattr(v.type, "__origin__", None) if ( - (v_type is List or v_type is list) + (v_type is List or v_type is list or v_type is Optional) # skip interpolation and not (isinstance(val, str) and val.startswith("${")) ): # if type is int but val is float, then we will crash later - try to convert here - t_args = v.type.__args__ - if len(t_args) == 1: - val = list(map(t_args[0], val)) + if hasattr(v.type, '__args__'): + t_args = v.type.__args__ + if 
len(t_args) == 1: + val = list(map(t_args[0], val)) elif val is not None and (field_type is int or field_type is bool or field_type is float): try: val = field_type(val) diff --git a/fairseq/modules/sinusoidal_positional_embedding.py b/fairseq/modules/sinusoidal_positional_embedding.py index 857830faf7..4793ecfb52 100644 --- a/fairseq/modules/sinusoidal_positional_embedding.py +++ b/fairseq/modules/sinusoidal_positional_embedding.py @@ -21,7 +21,7 @@ class SinusoidalPositionalEmbedding(nn.Module): def __init__(self, embedding_dim, padding_idx, init_size=1024): super().__init__() self.embedding_dim = embedding_dim - self.padding_idx = padding_idx + self.padding_idx = padding_idx if padding_idx is not None else 0 self.weights = SinusoidalPositionalEmbedding.get_embedding( init_size, embedding_dim, padding_idx ) diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 0e55d093b1..95b4a9647f 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -39,7 +39,7 @@ def setup_task(cfg: FairseqDataclass, **kwargs): cfg = merge_with_parent(dc(), cfg) task = TASK_REGISTRY[task_name] - assert task is not None, f"Could not infer task type from {cfg}" + assert task is not None, f"Could not infer task type from {cfg}. 
Available tasks: {TASK_REGISTRY.keys()}" return task.setup_task(cfg, **kwargs) diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index b092ce14ee..cf48337462 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -10,7 +10,7 @@ from fairseq.dataclass.initialize import hydra_init from fairseq_cli.train import main as pre_main -from fairseq import distributed_utils +from fairseq import distributed_utils, metrics from fairseq.dataclass.configs import FairseqConfig import hydra @@ -22,7 +22,7 @@ @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") -def hydra_main(cfg: FairseqConfig) -> None: +def hydra_main(cfg: FairseqConfig) -> float: cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) OmegaConf.set_struct(cfg, True) @@ -30,12 +30,31 @@ def hydra_main(cfg: FairseqConfig) -> None: if cfg.common.reset_logging: reset_logging() # Hydra hijacks logging, fix that - if cfg.common.profile: - with torch.cuda.profiler.profile(): - with torch.autograd.profiler.emit_nvtx(): - distributed_utils.call_main(cfg, pre_main) - else: - distributed_utils.call_main(cfg, pre_main) + try: + if cfg.common.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, pre_main) + else: + distributed_utils.call_main(cfg, pre_main) + except BaseException as e: + if not cfg.common.suppress_crashes: + raise + else: + logger.error("Crashed! 
" + str(e)) + + # get best val and return - useful for sweepers + try: + best_val = metrics.get_smoothed_value( + "valid", cfg.checkpoint.best_checkpoint_metric + ) + except: + best_val = None + + if best_val is None: + best_val = float("inf") + + return best_val def reset_logging(): From 4daa41bedcc8e3ac7fd35dd1ec27088ecc71ebec Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 5 Jan 2021 17:41:46 -0800 Subject: [PATCH 388/707] migrate label smoothed cross entropy (#1537) Summary: migrate label smoothed cross entropy to hydra Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1537 Reviewed By: myleott Differential Revision: D25768702 Pulled By: alexeib fbshipit-source-id: a90fa40802ae67ad81a8f5b0735d5316d41fea2d --- .../label_smoothed_cross_entropy.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index 2dc7f7a47d..cb47a1582f 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -4,10 +4,30 @@ # LICENSE file in the root directory of this source tree. 
import math +from dataclasses import dataclass, field import torch from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass): + label_smoothing: float = field( + default=0.0, + metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, + ) + report_accuracy: bool = field( + default=False, + metadata={"help": "report accuracy metric"}, + ) + ignore_prefix_size: int = field( + default=0, + metadata={"help": "Ignore first N tokens"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): @@ -30,7 +50,9 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T return loss, nll_loss -@register_criterion("label_smoothed_cross_entropy") +@register_criterion( + "label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig +) class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): def __init__( self, @@ -46,18 +68,6 @@ def __init__( self.ignore_prefix_size = ignore_prefix_size self.report_accuracy = report_accuracy - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', - help='epsilon for label smoothing, 0 means no label smoothing') - parser.add_argument('--report-accuracy', action='store_true', - help='report accuracy metric') - parser.add_argument('--ignore-prefix-size', default=0, type=int, - help='Ignore first N tokens') - # fmt: on - def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
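Reviewer note on the migration above: the removed `add_args` block and the new config dataclass carry the same information (name, type, default, help text). The sketch below shows, under stated assumptions, how dataclass field metadata can be turned back into command-line flags. `SmoothingConfig` and `add_dataclass_args` are hypothetical names used only here. fairseq's own conversion lives in `fairseq/dataclass/utils.py` and handles many more cases (optionals, lists, enums, interpolated values such as `II(...)`), none of which are covered by this sketch.

```python
import argparse
from dataclasses import dataclass, field, fields


@dataclass
class SmoothingConfig:
    """Hypothetical stand-in for the criterion config; not the fairseq class."""

    label_smoothing: float = field(
        default=0.0,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )
    report_accuracy: bool = field(
        default=False,
        metadata={"help": "report accuracy metric"},
    )
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )


def add_dataclass_args(parser, config_cls):
    """Register one ``--flag`` per dataclass field (illustrative helper)."""
    for f in fields(config_cls):
        flag = "--" + f.name.replace("_", "-")
        help_text = f.metadata.get("help")
        if f.type is bool:
            # Boolean fields that default to False become on/off switches.
            parser.add_argument(flag, action="store_true", help=help_text)
        else:
            parser.add_argument(flag, type=f.type, default=f.default, help=help_text)


parser = argparse.ArgumentParser()
add_dataclass_args(parser, SmoothingConfig)
args = parser.parse_args(["--label-smoothing", "0.1", "--report-accuracy"])
```

Keeping the help strings in `metadata` is what lets a single definition serve both the Hydra config and the legacy argparse interface.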
From cb7398299b4a8c51eb195fc0c054068470e8a1c9 Mon Sep 17 00:00:00 2001 From: Arthur Guo Date: Tue, 5 Jan 2021 19:08:27 -0800 Subject: [PATCH 389/707] Add typing and JIT support for TransformerDecoder and dependencies Reviewed By: zhengwy888 Differential Revision: D25451106 fbshipit-source-id: f9aa7f1ca6120f00c4938d4d72219471035211ca --- fairseq/modules/conv_tbc.py | 6 +++++- fairseq/modules/linearized_convolution.py | 23 ++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/fairseq/modules/conv_tbc.py b/fairseq/modules/conv_tbc.py index 2dc46c4b9b..79b2b2ad57 100644 --- a/fairseq/modules/conv_tbc.py +++ b/fairseq/modules/conv_tbc.py @@ -5,6 +5,7 @@ import torch from torch.nn.modules.utils import _single +from torch import Tensor class ConvTBC(torch.nn.Module): @@ -26,11 +27,14 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0): ) self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) - def forward(self, input): + def conv_tbc(self, input: Tensor): return torch.conv_tbc( input.contiguous(), self.weight, self.bias, self.padding[0] ) + def forward(self, input: Tensor): + return self.conv_tbc(input) + def __repr__(self): s = ( "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}" diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py index b36cea91fa..f7e156cb0c 100644 --- a/fairseq/modules/linearized_convolution.py +++ b/fairseq/modules/linearized_convolution.py @@ -10,6 +10,8 @@ from .conv_tbc import ConvTBC +from typing import Dict, Optional +from torch import Tensor @with_incremental_state class LinearizedConvolution(ConvTBC): @@ -38,8 +40,8 @@ def upgrade_state_dict_named(self, state_dict, name): if prefix + "_linearized_weight" in state_dict: del state_dict[prefix + "_linearized_weight"] - @torch.jit.ignore - def forward(self, input, incremental_state=None): + @torch.jit.export + def forward(self, input, incremental_state: Optional[Dict[str,
Dict[str, Optional[Tensor]]]] = None): """ Args: incremental_state: Used to buffer signal; if not None, then input is @@ -50,7 +52,7 @@ def forward(self, input, incremental_state=None): Batch x Time x Channel during inference """ if incremental_state is None: - output = super().forward(input) + output = self.conv_tbc(input) if self.kernel_size[0] > 1 and self.padding[0] > 0: # remove future timesteps added by padding output = output[: -self.padding[0], :, :] @@ -77,29 +79,32 @@ def forward(self, input, incremental_state=None): output = F.linear(input.view(bsz, -1), weight, self.bias) return output.view(bsz, 1, -1) - def reorder_incremental_state(self, incremental_state, new_order): + @torch.jit.unused + def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order): input_buffer = self._get_input_buffer(incremental_state) if input_buffer is not None: input_buffer = input_buffer.index_select(0, new_order) self._set_input_buffer(incremental_state, input_buffer) - def _get_input_buffer(self, incremental_state): + @torch.jit.unused + def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): return utils.get_incremental_state(self, incremental_state, "input_buffer") - def _set_input_buffer(self, incremental_state, new_buffer): + @torch.jit.unused + def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_buffer): return utils.set_incremental_state( self, incremental_state, "input_buffer", new_buffer ) + @torch.jit.unused def _get_linearized_weight(self): if self._linearized_weight is None: kw = self.kernel_size[0] weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() assert weight.size() == (self.out_channels, kw, self.in_channels) - self._linearized_weight = torch.nn.Parameter( - weight.view(self.out_channels, -1) - ) + return weight.view(self.out_channels, -1) return self._linearized_weight + @torch.jit.unused def 
_clear_linearized_weight(self, *args): self._linearized_weight = None From 89a4d2bc70fd680c4768803d20707ef65df89b0f Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Tue, 5 Jan 2021 20:54:21 -0800 Subject: [PATCH 390/707] S2T bug fix for issue #3095 Summary: S2T bug fix for issue [#3095](https://github.com/pytorch/fairseq/issues/3095): - Revert the dropped S2T de-tokenization in `fairseq-generate` - Update S2T Transformer encoder output to dict type (following the updates on the text Transformer) Reviewed By: jmp84 Differential Revision: D25788341 fbshipit-source-id: f226fb9d5e001bbc7dd245293819e0a36d5a88e7 --- .../models/speech_to_text/s2t_transformer.py | 70 +++++++------------ fairseq_cli/generate.py | 4 +- 2 files changed, 28 insertions(+), 46 deletions(-) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index fc2f14fea6..afd43f1ec7 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -4,7 +4,6 @@ import math from typing import Dict, List, Optional, Tuple -import torch import torch.nn as nn from fairseq import checkpoint_utils, utils from fairseq.data.data_utils import lengths_to_padding_mask @@ -14,7 +13,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import Embedding, TransformerDecoder from fairseq.modules import ( FairseqDropout, @@ -308,70 +306,54 @@ def forward(self, src_tokens, src_lengths): for layer in self.transformer_layers: x = layer(x, encoder_padding_mask) - if not encoder_padding_mask.any(): - encoder_padding_mask = None - if self.layer_norm is not None: x = self.layer_norm(x) - return EncoderOut( - encoder_out=x, - encoder_padding_mask=encoder_padding_mask, - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=None, - ) - - @torch.jit.export - def reorder_encoder_out(self, encoder_out: EncoderOut, 
new_order): - """ - Since encoder_padding_mask and encoder_embedding are both of type - Optional[Tensor] in EncoderOut, they need to be copied as local - variables for Torchscript Optional refinement - """ - - encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask - encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": [], # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + def reorder_encoder_out(self, encoder_out, new_order): new_encoder_out = ( - encoder_out.encoder_out - if encoder_out.encoder_out is None - else encoder_out.encoder_out.index_select(1, new_order) + [] if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] ) new_encoder_padding_mask = ( - encoder_padding_mask - if encoder_padding_mask is None - else encoder_padding_mask.index_select(0, new_order) + [] if len(encoder_out["encoder_padding_mask"]) == 0 + else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]] ) new_encoder_embedding = ( - encoder_embedding - if encoder_embedding is None - else encoder_embedding.index_select(0, new_order) + [] if len(encoder_out["encoder_embedding"]) == 0 + else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]] ) - encoder_states = encoder_out.encoder_states - if encoder_states is not None: + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) - return EncoderOut( - encoder_out=new_encoder_out, # T x B x C - encoder_padding_mask=new_encoder_padding_mask, # B x T - encoder_embedding=new_encoder_embedding, # B x T x C - encoder_states=encoder_states, # List[T x B x C] - src_tokens=None, - 
src_lengths=None, - ) + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } class TransformerDecoderScriptable(TransformerDecoder): def extract_features( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index ff8369b539..0a523680f0 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -170,8 +170,8 @@ def _main(cfg: DictConfig, output_file): ) # Handle tokenization and BPE - tokenizer = encoders.build_tokenizer(cfg.tokenizer) - bpe = encoders.build_bpe(cfg.bpe) + tokenizer = task.build_tokenizer(cfg.tokenizer) + bpe = task.build_bpe(cfg.bpe) def decode_fn(x): if bpe is not None: From d1d487395e8206c39640f07f3be5b2bce33edee6 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Tue, 5 Jan 2021 22:08:51 -0800 Subject: [PATCH 391/707] bug fix for scorer config Summary: bug fix for scorer config - additional scorer arguments (e.g. 
WER punctuation removal) from cli are not passed into `build_scorer` properly Reviewed By: jmp84 Differential Revision: D25797236 fbshipit-source-id: 909e272931ca0fd1cad2d7d8a2ca7e79bd7dcd43 --- fairseq/scoring/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 9163be87e7..2372727883 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -37,10 +37,9 @@ def result_string(self) -> str: def build_scorer(choice, tgt_dict): - if isinstance(choice, DictConfig): - choice = choice._name + _choice = choice._name if isinstance(choice, DictConfig) else choice - if choice == "bleu": + if _choice == "bleu": from fairseq.scoring import bleu return bleu.Scorer( From 2e649232d4455826f7110f7c321b1d5e193b891e Mon Sep 17 00:00:00 2001 From: Rui Hou Date: Wed, 6 Jan 2021 21:41:58 -0800 Subject: [PATCH 392/707] Make StreamingEpochBatchIterator work with batch size > 1 Summary: Previously, StreamingEpochBatchIterator only supported batch_size = 1 for each GPU. This diff makes it possible to collate samples into a mini-batch such that batch_size can be greater than 1. Reviewed By: spencerp Differential Revision: D25295572 fbshipit-source-id: 398a4247701aa1ec71c1d16bba87430decf61ee6 --- fairseq/data/iterators.py | 69 ++++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 0f55026ef8..00bf41375c 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -144,19 +144,46 @@ def first_batch(self): class StreamingEpochBatchIterator(EpochBatchIterating): + """A streaming-style iterator over a :class:`torch.utils.data.IterableDataset`.
+ + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + max_sentences: batch size + collate_fn (callable): merges a list of samples to form a mini-batch + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speed up data loading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative (default: ``0``). + """ + def __init__( self, dataset, + max_sentences=1, + collate_fn=None, epoch=1, - num_shards=1, - shard_id=0, + num_workers=0, + buffer_size=0, + timeout=0, ): assert isinstance(dataset, torch.utils.data.IterableDataset) self.dataset = dataset + self.max_sentences = max_sentences + self.collate_fn = collate_fn self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.num_workers = num_workers + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment.
+ self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + self._current_epoch_iterator = None - self.num_shards = num_shards - self.shard_id = shard_id @property def next_epoch_idx(self): @@ -170,13 +197,7 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): self.epoch = self.next_epoch_idx if hasattr(self.dataset, "set_epoch"): self.dataset.set_epoch(self.epoch) - self._current_epoch_iterator = CountingIterator( - iterable=ShardedIterator( - iterable=self.dataset, - num_shards=self.num_shards, - shard_id=self.shard_id, - ), - ) + self._current_epoch_iterator = self._get_iterator_for_epoch(self.epoch, shuffle) return self._current_epoch_iterator def end_of_epoch(self) -> bool: @@ -196,6 +217,30 @@ def state_dict(self): def load_state_dict(self, state_dict): self.epoch = state_dict["epoch"] + def _get_iterator_for_epoch(self, epoch, shuffle, offset=0): + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + worker_init_fn = getattr(self.dataset, "worker_init_fn", None) + itr = torch.utils.data.DataLoader( + self.dataset, + batch_size=self.max_sentences, + collate_fn=self.collate_fn, + num_workers=self.num_workers, + timeout=self.timeout, + worker_init_fn=worker_init_fn, + ) + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CountingIterator + itr = CountingIterator(itr, start=offset) + + return itr + class EpochBatchIterator(EpochBatchIterating): """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. 
@@ -442,7 +487,7 @@ def shuffle_batches(batches, seed): if self.buffer_size > 0: itr = BufferedIterator(self.buffer_size, itr) - # Wrap with CoutingIterator + # Wrap with CountingIterator itr = CountingIterator(itr, start=offset) return itr From a9f5741f58c05c581686b73465d7e3f9df5528f3 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Fri, 8 Jan 2021 15:25:59 -0800 Subject: [PATCH 393/707] update S2T examples and small fixes for S2T Summary: - Update S2T examples: documentation (rendered version: https://github.com/fairinternal/fairseq-py/tree/v2_fairseq_s2t/examples/speech_to_text), bug fixes and pre-trained models - Revert `--share-decoder-input-output-embed`'s default value to `False` (for s2t_transformer) Reviewed By: yuntang Differential Revision: D25821616 fbshipit-source-id: 3dba2eb5566bff39305d0056daf1b9f5adf1a926 --- examples/speech_recognition/README.md | 4 +- examples/speech_to_text/README.md | 259 ++---------------- examples/speech_to_text/data_utils.py | 118 ++++---- .../speech_to_text/docs/covost_example.md | 93 +++++++ .../docs/librispeech_example.md | 61 +++++ examples/speech_to_text/docs/mustc_example.md | 155 +++++++++++ examples/speech_to_text/prep_covost_data.py | 96 +++---- .../speech_to_text/prep_librispeech_data.py | 37 ++- examples/speech_to_text/prep_mustc_data.py | 73 ++--- .../audio/feature_transforms/global_cmvn.py | 4 + .../models/speech_to_text/s2t_transformer.py | 4 +- fairseq_cli/generate.py | 1 - 12 files changed, 495 insertions(+), 410 deletions(-) create mode 100644 examples/speech_to_text/docs/covost_example.md create mode 100644 examples/speech_to_text/docs/librispeech_example.md create mode 100644 examples/speech_to_text/docs/mustc_example.md diff --git a/examples/speech_recognition/README.md b/examples/speech_recognition/README.md index 8207fe257c..17030bf0fd 100644 --- a/examples/speech_recognition/README.md +++ b/examples/speech_recognition/README.md @@ -1,3 +1,5 @@ +### 2021 Update: We are merging this example into the 
[S2T framework](../speech_to_text), which supports more generic speech-to-text tasks (e.g. speech translation) and more flexible data processing pipelines. Please stay tuned. + # Speech Recognition `examples/speech_recognition` is implementing ASR task in Fairseq, along with needed features, datasets, models and loss functions to train and infer model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660). @@ -39,7 +41,7 @@ sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.w * flashlight-style Conv/GLU model * flashlight's beam search decoder -To use these, follow the instructions on [this page](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) to install python bindings. +To use these, follow the instructions on [this page](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) to install python bindings. ## Training librispeech data (flashlight style, Conv/GLU + ASG loss) Training command: diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 4030af0144..0bd8bfdac9 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -1,8 +1,8 @@ # Speech-to-Text (S2T) Modeling -[https://arxiv.org/abs/2010.05171](https://arxiv.org/abs/2010.05171) +[https://www.aclweb.org/anthology/2020.aacl-demo.6](https://www.aclweb.org/anthology/2020.aacl-demo.6.pdf) -Examples for speech recognition (ASR) and speech-to-text translation (ST) with fairseq. +Speech recognition (ASR) and speech-to-text translation (ST) with fairseq. ## Data Preparation S2T modeling data consists of source speech features, target text and other optional information @@ -21,239 +21,31 @@ temperature-based resampling, etc. ## Model Training & Evaluation Fairseq S2T uses the unified `fairseq-train`/`fairseq-generate` interface for model training and evaluation. 
-It requires arguments `--task speech_to_text` and `--arch `. +It requires arguments `--task speech_to_text` and `--arch `. -## Example 1: Speech Recognition (ASR) on LibriSpeech +## Examples +- [Speech Recognition (ASR) on LibriSpeech](docs/librispeech_example.md) -#### Data preparation -Download and preprocess [LibriSpeech](https://www.danielpovey.com/files/2015_icassp_librispeech.pdf) data with -```bash -python examples/speech_to_text/prep_librispeech_data.py --output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000 -``` -where `LS_ROOT` is the root path for downloaded data as well as generated manifest and feature files. - -#### Training -```bash -fairseq-train ${LS_ROOT} --train-subset train --valid-subset dev --save-dir ${SAVE_DIR} --num-workers 4 \ - --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy --max-update 300000 \ - --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \ - --clip-norm 10.0 --seed 1 --update-freq 8 -``` -where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as example. -You may switch to `s2t_transformer_m` (71M) or `s2t_transformer_l` (268M) for better performance. We set -`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU. 
- -#### Inference & Evaluation -Average the last 10 checkpoints and evaluate on the 4 splits -(`dev-clean`, `dev-other`, `test-clean` and `test-other`): -```bash -CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py --inputs ${SAVE_DIR} --num-epoch-checkpoints 10 \ - --output "${SAVE_DIR}/${CHECKPOINT_FILENAME}" -for SUBSET in dev-clean dev-other test-clean test-other; do - fairseq-generate ${LS_ROOT} --gen-subset ${SUBSET} --task speech_to_text \ - --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring wer -done -``` +- [Speech-to-Text Translation (ST) on MuST-C](docs/mustc_example.md) -#### Result +- [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md) -| --arch | Params | dev-clean | dev-other | test-clean | test-other | -|---|---|---|---|---|---| -| s2t_transformer_s | 30M | 4.1 | 9.3 | 4.4 | 9.2 | -| s2t_transformer_sp | 35M | 3.9 | 9.3 | 4.3 | 8.8 | -| s2t_transformer_m | 71M | 3.5 | 8.1 | 3.7 | 8.1 | -| s2t_transformer_mp | 84M | 3.3 | 7.8 | 3.7 | 8.2 | -| s2t_transformer_l | 268M | 3.3 | 7.7 | 3.5 | 7.8 | -| s2t_transformer_lp | 318M | 3.1 | 7.5 | 3.4 | 7.6 | - - -## Example 2: Speech Translation (ST) on MuST-C - -#### Data Preparation -[Download](https://ict.fbk.eu/must-c) and unpack [MuST-C](https://www.aclweb.org/anthology/N19-1202) data -to a path `${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with -```bash -# Generate TSV manifests, features, vocabulary and configuration for each language -python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} --task asr \ - --vocab-type unigram --vocab-size 5000 -python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} --task st \ - --vocab-type unigram --vocab-size 8000 - -# Add vocabulary and configuration for joint data (based on the manifests and features generated above) -python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} --task asr --joint \ - --vocab-type unigram 
--vocab-size 10000 -python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} --task st --joint \ - --vocab-type unigram --vocab-size 10000 -``` -The generated files will be available under `${MUSTC_ROOT}/en-${TARGET_LANG_ID}` (per-language data) and -`MUSTC_ROOT` (joint data). - -#### ASR -###### Training -ASR data from En-De as example: -```bash -fairseq-train ${MUSTC_ROOT}/en-de --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \ - --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \ - --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \ - --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 -``` -Using joint data from all directions: -```bash -fairseq-train ${MUSTC_ROOT} \ - --train-subset train_de_asr,train_nl_asr,train_es_asr,train_fr_asr,train_it_asr,train_pt_asr,train_ro_asr,train_ru_asr \ - --valid-subset dev_de_asr,dev_nl_asr,dev_es_asr,dev_fr_asr,dev_it_asr,dev_pt_asr,dev_ro_asr,dev_ru_asr \ - --save-dir ${JOINT_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --task speech_to_text --arch s2t_transformer_s \ - --criterion label_smoothed_cross_entropy --report-accuracy --max-update 100000 --optimizer adam --lr 1e-3 \ - --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 -``` -where `ASR_SAVE_DIR` (`JOINT_ASR_SAVE_DIR`) is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs -with 1 GPU. You may want to update it accordingly when using more than 1 GPU. 
- -###### Inference & Evaluation -```bash -CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ - --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" -fairseq-generate ${MUSTC_ROOT}/en-de --gen-subset tst-COMMON_asr --task speech_to_text \ - --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ - --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct - -# For models trained on joint data -python scripts/average_checkpoints.py --inputs ${JOINT_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ - --output "${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" -for LANG in de nl es fr it pt ro ru; do - fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_${LANG}_asr --task speech_to_text \ - --path ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ - --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct -done -``` -###### Result -| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | -|---|---|---|---|---|---|---|---|---|---|---| -| Single | s2t_transformer_s | 31M | 18.2 | 17.6 | 17.7 | 17.2 | 17.9 | 19.1 | 18.1 | 17.7 | -| Joint | s2t_transformer_m | 76M | 16.8 | 16.7 | 16.9 | 16.9 | 17.0 | 17.4 | 17.0 | 16.9 | - -#### ST -###### Training -En-De as example: -```bash -fairseq-train ${MUSTC_ROOT}/en-de --train-subset train_st --valid-subset dev_st --save-dir ${ST_SAVE_DIR} \ - --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \ - --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \ - --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ - --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} -``` -Example for multilingual models: -```bash -fairseq-train ${MUSTC_ROOT} \ - --train-subset 
train_de_st,train_nl_st,train_es_st,train_fr_st,train_it_st,train_pt_st,train_ro_st,train_ru_st \ - --valid-subset dev_de_st,dev_nl_st,dev_es_st,dev_fr_st,dev_it_st,dev_pt_st,dev_ro_st,dev_ru_st \ - --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --task speech_to_text \ - --arch s2t_transformer_s --criterion label_smoothed_cross_entropy --report-accuracy --ignore-prefix-size 1 \ - --max-update 100000 --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 \ - --seed 1 --update-freq 8 --load-pretrained-encoder-from ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} -``` -where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR -for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set -`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU. -For multilingual models, we prepend target language ID token as target BOS, which should be excluded from -the training loss via `--ignore-prefix-size 1`. +## Updates +- 01/08/2021: Several fixes for S2T Transformer model, inference-time de-tokenization, scorer configuration and data + preparation scripts. We also add pre-trained models to the examples and revise the instructions. + Breaking changes: the data preparation scripts now extract filterbank features without CMVN. CMVN is instead applied + on-the-fly (defined in the config YAML). 
-###### Inference & Evaluation -Average the last 10 checkpoints and evaluate on the `tst-COMMON` split: -```bash -CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ - --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" -fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_st --task speech_to_text \ - --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu - -# For multilingual models -python scripts/average_checkpoints.py --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \ - --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" -for LANG in de nl es fr it pt ro ru; do - fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_${LANG}_st --task speech_to_text --prefix-size 1 \ - --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu -done -``` -For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`. - -###### Result -| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | -|---|---|---|---|---|---|---|---|---|---|---| -| Bilingual | s2t_transformer_s | 31M | 22.7 | 27.3 | 27.2 | 32.9 | 22.7 | 28.1 | 21.9 | 15.3 | -| Multilingual | s2t_transformer_m | 76M | 24.5 | 28.6 | 28.2 | 34.9 | 24.6 | 31.1 | 23.8 | 16.0 | - - -## Example 3: ST on CoVoST -We replicate the experiments in -[CoVoST 2 and Massively Multilingual Speech-to-Text Translation (Wang et al., 2020)](https://arxiv.org/abs/2007.10310). 
- -#### Data Preparation -Download and preprocess [CoVoST (version 2)](https://arxiv.org/abs/2007.10310) data with -```bash -# En ASR -python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} --vocab-type char --src-lang en -# ST -python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} --vocab-type char \ - --src-lang fr --tgt-lang en -``` -where `COVOST_ROOT` is the root path for downloaded data as well as generated manifest and feature files. - -#### ASR -###### Training -```bash -fairseq-train ${COVOST_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \ - --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \ - --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \ - --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 -``` -where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. -You may want to update it accordingly when using more than 1 GPU. 
- -###### Inference & Evaluation -```bash -CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ - --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" -fairseq-generate ${COVOST_ROOT} --gen-subset test_asr_en --task speech_to_text \ - --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ - --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct -``` -###### Result -| --arch | Params | En | -|---|---|---| -| s2t_transformer_s | 31M | 25.6 | - -#### ST -###### Training -```bash -fairseq-train ${COVOST_ROOT} --train-subset train_st_fr_en --valid-subset dev_st_fr_en --save-dir ${ST_SAVE_DIR} \ - --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \ - --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \ - --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ - --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} -``` -where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by En ASR for faster training and better -performance: `--load-pretrained-encoder-from `. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. -You may want to update it accordingly when using more than 1 GPU. 
- -###### Inference & Evaluation -Average the last 10 checkpoints and evaluate on test split: -```bash -CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt -python scripts/average_checkpoints.py --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ - --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" -fairseq-generate ${COVOST_ROOT} --gen-subset test_st_fr_en --task speech_to_text \ - --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu -``` - -###### Result -| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | -|---|---|---|---|---|---|---|---|---|---| -| s2t_transformer_s | 31M | 26.3 | 17.1 | 23.0 | 18.8 | 16.3 | 21.8 | 13.1 | 13.2 | +## What's Next +- We are migrating the old fairseq [ASR example](../speech_recognition) into this S2T framework and + merging the features from both sides. +- The following papers also base their experiments on fairseq S2T. We are adding more examples for replication. + - [Improving Cross-Lingual Transfer Learning for End-to-End Speech Recognition with Speech Translation (Wang et al., 2020)](https://arxiv.org/abs/2006.05474) + - [Self-Supervised Representations Improve End-to-End Speech Translation (Wu et al., 2020)](https://arxiv.org/abs/2006.12124) + - [Self-Training for End-to-End Speech Translation (Pino et al., 2020)](https://arxiv.org/abs/2006.02490) + - [CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus (Wang et al., 2020)](https://arxiv.org/abs/2002.01320) + - [Harnessing Indirect Training Data for End-to-End Automatic Speech Translation: Tricks of the Trade (Pino et al., 2019)](https://arxiv.org/abs/1909.06515) ## Citation Please cite as: @@ -272,12 +64,3 @@ Please cite as: year = {2019}, } ``` - -## More Paper Code -The following papers also base their experiments on fairseq S2T. We are adding more examples for replication. 
- -- [Improving Cross-Lingual Transfer Learning for End-to-End Speech Recognition with Speech Translation (Wang et al., 2020)](https://arxiv.org/abs/2006.05474) -- [Self-Supervised Representations Improve End-to-End Speech Translation (Wu et al., 2020)](https://arxiv.org/abs/2006.12124) -- [Self-Training for End-to-End Speech Translation (Pino et al., 2020)](https://arxiv.org/abs/2006.02490) -- [CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus (Wang et al., 2020)](https://arxiv.org/abs/2002.01320) -- [Harnessing Indirect Training Data for End-to-End Automatic Speech Translation: Tricks of the Trade (Pino et al., 2019)](https://arxiv.org/abs/1909.06515) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 083d7316cd..0d7c034419 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -5,19 +5,16 @@ # LICENSE file in the root directory of this source tree. import csv -import os -import os.path as op +from pathlib import Path import zipfile from functools import reduce -from glob import glob from multiprocessing import cpu_count -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union import numpy as np import pandas as pd import sentencepiece as sp from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank -from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN from tqdm import tqdm @@ -28,12 +25,13 @@ def gen_vocab( - input_path: str, output_path_prefix: str, model_type="bpe", vocab_size=1000, + input_path: Path, output_path_prefix: Path, model_type="bpe", + vocab_size=1000, special_symbols: Optional[List[str]] = None ): # Train SentencePiece Model arguments = [ - f"--input={input_path}", - f"--model_prefix={output_path_prefix}", + f"--input={input_path.as_posix()}", + f"--model_prefix={output_path_prefix.as_posix()}", f"--model_type={model_type}", f"--vocab_size={vocab_size}", 
"--character_coverage=1.0", @@ -43,10 +41,13 @@ def gen_vocab( f"--eos_id={EOS_TOKEN_ID}", f"--pad_id={PAD_TOKEN_ID}", ] + if special_symbols is not None: + _special_symbols = ",".join(special_symbols) + arguments.append(f"--user_defined_symbols={_special_symbols}") sp.SentencePieceTrainer.Train(" ".join(arguments)) # Export fairseq dictionary spm = sp.SentencePieceProcessor() - spm.Load(output_path_prefix + ".model") + spm.Load(output_path_prefix.as_posix() + ".model") vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())} assert ( vocab.get(UNK_TOKEN_ID) == UNK_TOKEN @@ -59,20 +60,19 @@ def gen_vocab( for i, s in vocab.items() if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN} } - with open(output_path_prefix + ".txt", "w") as f_out: + with open(output_path_prefix.as_posix() + ".txt", "w") as f_out: for _, s in sorted(vocab.items(), key=lambda x: x[0]): f_out.write(f"{s} 1\n") def extract_fbank_features( waveform, - sample_rate, - output_path=None, - n_mel_bins=80, - apply_utterance_cmvn=True, - overwrite=False, + sample_rate: int, + output_path: Optional[Path] = None, + n_mel_bins: int = 80, + overwrite: bool = False, ): - if output_path is not None and op.exists(output_path) and not overwrite: + if output_path is not None and output_path.is_file() and not overwrite: return _waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers @@ -83,42 +83,36 @@ def extract_fbank_features( features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins) if features is None: raise ImportError( - "Please install pyKaldi or torchaudio to enable " - "online filterbank feature extraction" + "Please install pyKaldi or torchaudio to enable fbank feature extraction" ) - if apply_utterance_cmvn: - cmvn = UtteranceCMVN(norm_means=True, norm_vars=True) - features = cmvn(features) if output_path is not None: - np.save(output_path, features) + np.save(output_path.as_posix(), features) else: return features -def create_zip(data_root, zip_path): - cwd 
= os.path.abspath(os.curdir) - os.chdir(data_root) +def create_zip(data_root: Path, zip_path: Path): + paths = list(data_root.glob("*.npy")) with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f: - for filename in tqdm(glob("*.npy")): - f.write(filename) - os.chdir(cwd) + for path in tqdm(paths): + f.write(path, arcname=path.name) def is_npy_data(data: bytes) -> bool: return data[0] == 147 and data[1] == 78 -def get_zip_manifest(zip_root, zip_filename): - zip_path = op.join(zip_root, zip_filename) - with zipfile.ZipFile(zip_path, mode="r") as f: +def get_zip_manifest(zip_path: Path, zip_root: Optional[Path] = None): + _zip_path = zip_path if zip_root is None else Path.joinpath(zip_root, zip_path) + with zipfile.ZipFile(_zip_path, mode="r") as f: info = f.infolist() manifest = {} for i in tqdm(info): - utt_id = op.splitext(i.filename)[0] + utt_id = Path(i.filename).stem offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size - manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}" - with open(zip_path, "rb") as f: + manifest[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}" + with open(_zip_path, "rb") as f: f.seek(offset) data = f.read(file_size) assert len(data) > 1 and is_npy_data(data) @@ -126,16 +120,16 @@ def get_zip_manifest(zip_root, zip_filename): def gen_config_yaml( - data_root, - spm_filename, - yaml_filename="config.yaml", - specaugment_policy="lb", - prepend_tgt_lang_tag=False, - sampling_alpha=1.0, + manifest_root: Path, + spm_filename: str, + yaml_filename: str = "config.yaml", + specaugment_policy: str = "lb", + prepend_tgt_lang_tag: bool = False, + sampling_alpha: float = 1.0, + audio_root: str = "" ): - data_root = op.abspath(data_root) - writer = S2TDataConfigWriter(op.join(data_root, yaml_filename)) - writer.set_audio_root(op.abspath(data_root)) + manifest_root = manifest_root.absolute() + writer = S2TDataConfigWriter(manifest_root / yaml_filename) writer.set_vocab_filename(spm_filename.replace(".model", ".txt")) 
writer.set_input_channels(1) writer.set_input_feat_per_channel(80) @@ -145,24 +139,29 @@ def gen_config_yaml( "sm": writer.set_specaugment_sm_policy, "ss": writer.set_specaugment_ss_policy, } - assert specaugment_policy in specaugment_setters - specaugment_setters[specaugment_policy]() + specaugment_setter = specaugment_setters.get(specaugment_policy, None) + if specaugment_setter is not None: + specaugment_setter() writer.set_bpe_tokenizer( { "bpe": "sentencepiece", - "sentencepiece_model": op.join(data_root, spm_filename), + "sentencepiece_model": (manifest_root / spm_filename).as_posix(), } ) if prepend_tgt_lang_tag: writer.set_prepend_tgt_lang_tag(True) writer.set_sampling_alpha(sampling_alpha) - writer.set_feature_transforms("_train", ["specaugment"]) + writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"]) + writer.set_feature_transforms("*", ["utterance_cmvn"]) + if len(audio_root) > 0: + writer.set_audio_root(audio_root) writer.flush() -def load_df_from_tsv(path: str): +def load_df_from_tsv(path: Union[str, Path]): + _path = path if isinstance(path, str) else path.as_posix() return pd.read_csv( - path, + _path, sep="\t", header=0, encoding="utf-8", @@ -172,9 +171,10 @@ def load_df_from_tsv(path: str): ) -def save_df_to_tsv(dataframe, path): +def save_df_to_tsv(dataframe, path: Union[str, Path]): + _path = path if isinstance(path, str) else path.as_posix() dataframe.to_csv( - path, + _path, sep="\t", header=True, index=False, @@ -211,11 +211,11 @@ class S2TDataConfigWriter(object): DEFAULT_INPUT_FEAT_PER_CHANNEL = 80 DEFAULT_INPUT_CHANNELS = 1 - def __init__(self, yaml_path): + def __init__(self, yaml_path: Path): try: import yaml except ImportError: - print("Please install PyYAML to load YAML files for S2T data config") + print("Please install PyYAML for S2T data config YAML files") self.yaml = yaml self.yaml_path = yaml_path self.config = {} @@ -227,7 +227,7 @@ def flush(self): def set_audio_root(self, audio_root=""): 
self.config["audio_root"] = audio_root - def set_vocab_filename(self, vocab_filename="dict.txt"): + def set_vocab_filename(self, vocab_filename: str = "dict.txt"): self.config["vocab_filename"] = vocab_filename def set_specaugment( @@ -288,22 +288,22 @@ def set_specaugment_ss_policy(self): time_mask_p=0.2, ) - def set_input_channels(self, input_channels=1): + def set_input_channels(self, input_channels: int = 1): self.config["input_channels"] = input_channels - def set_input_feat_per_channel(self, input_feat_per_channel=80): + def set_input_feat_per_channel(self, input_feat_per_channel: int = 80): self.config["input_feat_per_channel"] = input_feat_per_channel def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): self.config["bpe_tokenizer"] = bpe_tokenizer - def set_feature_transforms(self, split, transforms: List[str]): + def set_feature_transforms(self, split: str, transforms: List[str]): if "transforms" not in self.config: self.config["transforms"] = {} self.config["transforms"][split] = transforms - def set_prepend_tgt_lang_tag(self, flag=True): + def set_prepend_tgt_lang_tag(self, flag: bool = True): self.config["prepend_tgt_lang_tag"] = flag - def set_sampling_alpha(self, sampling_alpha=1.0): + def set_sampling_alpha(self, sampling_alpha: float = 1.0): self.config["sampling_alpha"] = sampling_alpha diff --git a/examples/speech_to_text/docs/covost_example.md b/examples/speech_to_text/docs/covost_example.md new file mode 100644 index 0000000000..a4ce8a10e4 --- /dev/null +++ b/examples/speech_to_text/docs/covost_example.md @@ -0,0 +1,93 @@ +[[Back]](..) + +# S2T Example: ST on CoVoST +We replicate the experiments in +[CoVoST 2 and Massively Multilingual Speech-to-Text Translation (Wang et al., 2020)](https://arxiv.org/abs/2007.10310). 
+ +## Data Preparation +[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path +`${COVOST_ROOT}/${SOURCE_LANG_ID}`, then preprocess it with +```bash +# additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +# En ASR +python examples/speech_to_text/prep_covost_data.py \ + --data-root ${COVOST_ROOT} --vocab-type char --src-lang en +# ST +python examples/speech_to_text/prep_covost_data.py \ + --data-root ${COVOST_ROOT} --vocab-type char \ + --src-lang fr --tgt-lang en +``` +The generated files (manifest, features, vocabulary and data configuration) will be added to +`${COVOST_ROOT}/${SOURCE_LANG_ID}`. + +Download our vocabulary files if you want to use our pre-trained models: +- ASR: [En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_vocab_char.zip) +- ST: [Fr-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_vocab_char.zip), [De-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_vocab_char.zip), [Es-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_vocab_char.zip), [Ca-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_vocab_char.zip), [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip), [En-Ca](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_vocab_char.zip), [En-Fa](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_vocab_char.zip), [En-Et](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_vocab_char.zip) + +## ASR +#### Training +We train an En ASR model for encoder pre-training of all ST models: +```bash +fairseq-train ${COVOST_ROOT}/en \ + --config-yaml config_asr_en.yaml --train-subset train_asr_en --valid-subset dev_asr_en \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 60000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 
2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 +``` +where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. +You may want to update it accordingly when using more than 1 GPU. + +#### Inference & Evaluation +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" +fairseq-generate ${COVOST_ROOT}/en \ + --config-yaml config_asr_en.yaml --gen-subset test_asr_en --task speech_to_text \ + --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct +``` +#### Results +| --arch | Params | En | Model | +|---|---|---|---| +| s2t_transformer_s | 31M | 25.6 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_transformer_s.pt) | + +## ST +#### Training +Fr-En as example: +```bash +fairseq-train ${COVOST_ROOT}/fr \ + --config-yaml config_st_fr_en.yaml --train-subset train_st_fr_en --valid-subset dev_st_fr_en \ + --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 60000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} +``` +where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by En ASR for faster training and better +performance: `--load-pretrained-encoder-from `. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. +You may want to update it accordingly when using more than 1 GPU. 
+ +#### Inference & Evaluation +Average the last 10 checkpoints and evaluate on the test split: +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" +fairseq-generate ${COVOST_ROOT}/fr \ + --config-yaml config_st_fr_en.yaml --gen-subset test_st_fr_en --task speech_to_text \ + --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 --scoring sacrebleu +``` + +#### Results +| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | Model | +|---|---|---|---|---|---|---|---|---|---|---| +| s2t_transformer_s | 31M | [26.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_transformer_s.pt) | [17.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_transformer_s.pt) | [23.0](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_transformer_s.pt) | [18.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_transformer_s.pt) | [16.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt) | [21.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_transformer_s.pt) | [13.0](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_transformer_s.pt) | [13.2](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_transformer_s.pt) | (<-Download) | + +[[Back]](..) diff --git a/examples/speech_to_text/docs/librispeech_example.md b/examples/speech_to_text/docs/librispeech_example.md new file mode 100644 index 0000000000..21b754ee11 --- /dev/null +++ b/examples/speech_to_text/docs/librispeech_example.md @@ -0,0 +1,61 @@ +[[Back]](..) + +# S2T Example: Speech Recognition (ASR) on LibriSpeech +[LibriSpeech](https://www.danielpovey.com/files/2015_icassp_librispeech.pdf) is a de facto standard English ASR +benchmark. 
We provide competitive +vanilla [Transformer](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) baselines. + +## Data preparation +Download and preprocess LibriSpeech data with +```bash +# additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +python examples/speech_to_text/prep_librispeech_data.py \ + --output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000 +``` +where `LS_ROOT` is the root path for downloaded data as well as generated files (manifest, features, vocabulary and +data configuration). + +[Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_vocab_unigram10000.zip) our vocabulary files +if you want to use our pre-trained models. + +## Training +```bash +fairseq-train ${LS_ROOT} --save-dir ${SAVE_DIR} \ + --config-yaml config.yaml --train-subset train --valid-subset dev \ + --num-workers 4 --max-tokens 40000 --max-update 300000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --share-decoder-input-output-embed \ + --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \ + --clip-norm 10.0 --seed 1 --update-freq 8 +``` +where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as an example. +For better performance, you may switch to `s2t_transformer_m` (71M, with `--lr 1e-3`) or `s2t_transformer_l` +(268M, with `--lr 5e-4`). We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly +when using more than 1 GPU. 
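The note on `--update-freq` above can be sketched in a few lines of Python. This is only an illustration, not part of fairseq: the helper name `update_freq_for` is hypothetical, and it assumes the goal is to keep the product of GPU count and `--update-freq` at 8, as implied by "simulate 8 GPUs with 1 GPU".

```python
def update_freq_for(num_gpus: int, simulated_gpus: int = 8) -> int:
    """Return an --update-freq value such that num_gpus * update_freq
    equals simulated_gpus (hypothetical helper, not a fairseq API)."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if simulated_gpus % num_gpus != 0:
        raise ValueError("simulated_gpus must be divisible by num_gpus")
    return simulated_gpus // num_gpus


# Print the flag value for a few common GPU counts.
for n in (1, 2, 4, 8):
    print(f"{n} GPU(s): --update-freq {update_freq_for(n)}")
```

Under that assumption, a 4-GPU run would use `--update-freq 2`, which keeps the effective batch size per update the same as in the single-GPU command.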
+ +## Inference & Evaluation +Average the last 10 checkpoints and evaluate on the 4 splits +(`dev-clean`, `dev-other`, `test-clean` and `test-other`): +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \ + --num-epoch-checkpoints 10 \ + --output "${SAVE_DIR}/${CHECKPOINT_FILENAME}" +for SUBSET in dev-clean dev-other test-clean test-other; do + fairseq-generate ${LS_ROOT} --config-yaml config.yaml --gen-subset ${SUBSET} \ + --task speech_to_text --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 --scoring wer +done +``` + +## Results + +| --arch | Params | dev-clean | dev-other | test-clean | test-other | Model | +|---|---|---|---|---|---|---| +| s2t_transformer_s | 30M | 3.8 | 8.9 | 4.4 | 9.0 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_s.pt) | +| s2t_transformer_m | 71M | 3.2 | 8.0 | 3.4 | 7.9 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_m.pt) | +| s2t_transformer_l | 268M | 3.0 | 7.5 | 3.2 | 7.5 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_l.pt) | + +[[Back]](..) diff --git a/examples/speech_to_text/docs/mustc_example.md b/examples/speech_to_text/docs/mustc_example.md new file mode 100644 index 0000000000..7628dc77ef --- /dev/null +++ b/examples/speech_to_text/docs/mustc_example.md @@ -0,0 +1,155 @@ +[[Back]](..) + +# S2T Example: Speech Translation (ST) on MuST-C + +[MuST-C](https://www.aclweb.org/anthology/N19-1202) is a multilingual speech-to-text translation corpus with +8-language translations on English TED talks. We match the state-of-the-art performance in +[ESPNet-ST](https://arxiv.org/pdf/2004.10234.pdf) with a simpler model training pipeline. 
+ +## Data Preparation +[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path +`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with +```bash +# additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +# Generate TSV manifests, features, vocabulary +# and configuration for each language +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task asr \ + --vocab-type unigram --vocab-size 5000 +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task st \ + --vocab-type unigram --vocab-size 8000 + +# Add vocabulary and configuration for joint data +# (based on the manifests and features generated above) +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task asr --joint \ + --vocab-type unigram --vocab-size 10000 +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task st --joint \ + --vocab-type unigram --vocab-size 10000 +``` +The generated files (manifest, features, vocabulary and data configuration) will be added to +`${MUSTC_ROOT}/en-${TARGET_LANG_ID}` (per-language data) and `MUSTC_ROOT` (joint data). 
+ +Download our vocabulary files if you want to use our pre-trained models: +- ASR: [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_vocab_unigram5000.zip), [En-Nl](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_asr_vocab_unigram5000.zip), [En-Es](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_vocab_unigram5000.zip), [En-Fr](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_asr_vocab_unigram5000.zip), [En-It](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_asr_vocab_unigram5000.zip), [En-Pt](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_asr_vocab_unigram5000.zip), [En-Ro](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_asr_vocab_unigram5000.zip), [En-Ru](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_asr_vocab_unigram5000.zip), [Joint](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_vocab_unigram10000.zip) +- ST: [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_vocab_unigram8000.zip), [En-Nl](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_st_vocab_unigram8000.zip), [En-Es](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_st_vocab_unigram8000.zip), [En-Fr](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_st_vocab_unigram8000.zip), [En-It](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_st_vocab_unigram8000.zip), [En-Pt](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_st_vocab_unigram8000.zip), [En-Ro](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_st_vocab_unigram8000.zip), [En-Ru](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_st_vocab_unigram8000.zip), [Multilingual](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_vocab_unigram10000.zip) + +## ASR +#### Training +En-De as example: +```bash +fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ + --task speech_to_text --criterion 
label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 +``` +For joint model (using ASR data from all 8 directions): +```bash +fairseq-train ${MUSTC_ROOT} \ + --config-yaml config_asr.yaml \ + --train-subset train_de_asr,train_nl_asr,train_es_asr,train_fr_asr,train_it_asr,train_pt_asr,train_ro_asr,train_ru_asr \ + --valid-subset dev_de_asr,dev_nl_asr,dev_es_asr,dev_fr_asr,dev_it_asr,dev_pt_asr,dev_ro_asr,dev_ru_asr \ + --save-dir ${JOINT_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 +``` +where `ASR_SAVE_DIR` (`JOINT_ASR_SAVE_DIR`) is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs +with 1 GPU. You may want to update it accordingly when using more than 1 GPU. 
+ +#### Inference & Evaluation +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" +fairseq-generate ${MUSTC_ROOT}/en-de \ + --config-yaml config_asr.yaml --gen-subset tst-COMMON_asr --task speech_to_text \ + --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct + +# For models trained on joint data +python scripts/average_checkpoints.py \ + --inputs ${JOINT_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" +for LANG in de nl es fr it pt ro ru; do + fairseq-generate ${MUSTC_ROOT} \ + --config-yaml config_asr.yaml --gen-subset tst-COMMON_${LANG}_asr --task speech_to_text \ + --path ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct +done +``` +#### Results +| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | Model | +|---|---|---|---|---|---|---|---|---|---|---|---| +| Single | s2t_transformer_s | 31M | [18.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_transformer_s.pt) | [17.6](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_asr_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_transformer_s.pt) | [17.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_asr_transformer_s.pt) | [17.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_asr_transformer_s.pt) | [19.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_asr_transformer_s.pt) | [18.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_asr_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_asr_transformer_s.pt) | (<-Download) | +| Joint | s2t_transformer_m | 76M | 16.8 | 16.7 | 16.9 | 16.9 | 17.0 | 
17.4 | 17.0 | 16.9 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_transformer_m.pt) | + +## ST +#### Training +En-De as example: +```bash +fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} +``` +For the multilingual model (all 8 directions): +```bash +fairseq-train ${MUSTC_ROOT} \ + --config-yaml config_st.yaml \ + --train-subset train_de_st,train_nl_st,train_es_st,train_fr_st,train_it_st,train_pt_st,train_ro_st,train_ru_st \ + --valid-subset dev_de_st,dev_nl_st,dev_es_st,dev_fr_st,dev_it_st,dev_pt_st,dev_ro_st,dev_ru_st \ + --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --ignore-prefix-size 1 --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ + --load-pretrained-encoder-from ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} +``` +where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained with ASR +for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set +`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend the target language ID token as the target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. 
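The effect of `--ignore-prefix-size 1` can be illustrated with a toy loss computation. The sketch below is not fairseq's criterion: it uses a plain negative log-likelihood instead of label-smoothed cross entropy, and the function name `nll_ignoring_prefix` and the `<lang:de>` tag spelling are assumptions made for illustration only.

```python
import math
from typing import Dict, List


def nll_ignoring_prefix(
    log_probs: List[Dict[str, float]],
    target: List[str],
    ignore_prefix_size: int = 1,
) -> float:
    """Sum the negative log-likelihood over target tokens, skipping the
    first ignore_prefix_size positions (e.g. the language ID token)."""
    assert len(log_probs) == len(target)
    scored = list(zip(log_probs, target))[ignore_prefix_size:]
    return -sum(step[token] for step, token in scored)


# Toy example: position 0 holds a hypothetical language tag.
target = ["<lang:de>", "hallo", "welt"]
log_probs = [
    {"<lang:de>": math.log(0.1)},
    {"hallo": math.log(0.5)},
    {"welt": math.log(0.25)},
]
loss_without_tag = nll_ignoring_prefix(log_probs, target, ignore_prefix_size=1)
loss_with_tag = nll_ignoring_prefix(log_probs, target, ignore_prefix_size=0)
print(loss_without_tag, loss_with_tag)
```

Skipping the first position means the model is never penalized for its prediction of the language tag, which is always supplied rather than predicted.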
+ +#### Inference & Evaluation +Average the last 10 checkpoints and evaluate on the `tst-COMMON` split: +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" +fairseq-generate ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --gen-subset tst-COMMON_st --task speech_to_text \ + --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 --scoring sacrebleu + +# For multilingual models +python scripts/average_checkpoints.py \ + --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" +for LANG in de nl es fr it pt ro ru; do + fairseq-generate ${MUSTC_ROOT} \ + --config-yaml config_st.yaml --gen-subset tst-COMMON_${LANG}_st --task speech_to_text \ + --prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 --scoring sacrebleu +done +``` +For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`. 
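Forced-prefix decoding, as enabled by `--prefix-size 1`, can be sketched as follows. This is a simplified greedy decoder rather than the beam search that fairseq uses, and `greedy_decode_with_prefix`, `toy_model` and the `<lang:..>` tag spelling are hypothetical names introduced for illustration.

```python
from typing import Callable, Dict, List


def greedy_decode_with_prefix(
    next_token_scores: Callable[[List[str]], Dict[str, float]],
    prefix: List[str],
    eos: str = "</s>",
    max_len: int = 10,
) -> List[str]:
    """Greedy decoding in which the first len(prefix) tokens are forced."""
    hypothesis = list(prefix)  # forced tokens are copied, never scored
    while len(hypothesis) < max_len:
        scores = next_token_scores(hypothesis)
        best = max(scores, key=scores.get)
        if best == eos:
            break
        hypothesis.append(best)
    return hypothesis


def toy_model(history: List[str]) -> Dict[str, float]:
    # Hypothetical stand-in for a multilingual model: the output language
    # depends only on the language tag that starts the hypothesis.
    table = {
        "<lang:de>": {"hallo": 0.9, "</s>": 0.1},
        "<lang:fr>": {"bonjour": 0.9, "</s>": 0.1},
    }
    return table.get(history[-1], {"</s>": 1.0})


german = greedy_decode_with_prefix(toy_model, ["<lang:de>"])
french = greedy_decode_with_prefix(toy_model, ["<lang:fr>"])
print(german, french)
```

The same model produces different output depending on the forced first token, which is how one multilingual checkpoint serves all 8 directions.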
+ +#### Results +| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | Model | +|---|---|---|---|---|---|---|---|---|---|---|---| +| Bilingual | s2t_transformer_s | 31M | [22.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_transformer_s.pt) | [27.3](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_st_transformer_s.pt) | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_st_transformer_s.pt) | [32.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_st_transformer_s.pt) | [22.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_st_transformer_s.pt) | [28.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_st_transformer_s.pt) | [21.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_st_transformer_s.pt) | [15.3](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_st_transformer_s.pt) | (<-Download) | +| Multilingual | s2t_transformer_m | 76M | 24.5 | 28.6 | 28.2 | 34.9 | 24.6 | 31.1 | 23.8 | 16.0 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_transformer_m.pt) | + +[[Back]](..) diff --git a/examples/speech_to_text/prep_covost_data.py b/examples/speech_to_text/prep_covost_data.py index e8a028b446..af1d3fc6b8 100644 --- a/examples/speech_to_text/prep_covost_data.py +++ b/examples/speech_to_text/prep_covost_data.py @@ -5,10 +5,8 @@ # LICENSE file in the root directory of this source tree. import argparse -import csv import logging -import os -import os.path as op +from pathlib import Path import shutil from tempfile import NamedTemporaryFile from typing import Optional, Tuple @@ -22,6 +20,7 @@ gen_config_yaml, gen_vocab, get_zip_manifest, + load_df_from_tsv, save_df_to_tsv, ) from torch import Tensor @@ -49,10 +48,6 @@ class CoVoST(Dataset): found at root path. (default: ``False``). 
""" - CV_URL_TEMPLATE = ( - "https://voice-prod-bundler-ee1969a6ce8178826482b88" - "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz" - ) COVOST_URL_TEMPLATE = ( "https://dl.fbaipublicfiles.com/covost/" "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz" @@ -61,8 +56,6 @@ class CoVoST(Dataset): VERSIONS = {2} SPLITS = ["train", "dev", "test"] - CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"} - XX_EN_LANGUAGES = { 1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"], 2: [ @@ -117,7 +110,6 @@ def __init__( source_language: str, target_language: Optional[str] = None, version: int = 2, - download: bool = False, ) -> None: assert version in self.VERSIONS and split in self.SPLITS assert source_language is not None @@ -134,30 +126,22 @@ def __init__( # to Common Voice train split. target_language = "de" if source_language == "en" else "en" - self.root = os.path.join(root, "raw") - os.makedirs(self.root, exist_ok=True) + self.root: Path = Path(root) - cv_url = self.CV_URL_TEMPLATE.format( - ver=self.CV_VERSION_ID[version], lang=source_language - ) - cv_archive = os.path.join(self.root, os.path.basename(cv_url)) - if download: - if not os.path.isfile(cv_archive): - download_url(cv_url, self.root, hash_value=None) - extract_archive(cv_archive) + cv_tsv_path = self.root / "validated.tsv" + assert cv_tsv_path.is_file() covost_url = self.COVOST_URL_TEMPLATE.format( src_lang=source_language, tgt_lang=target_language ) - covost_archive = os.path.join(self.root, os.path.basename(covost_url)) - if download: - if not os.path.isfile(covost_archive): - download_url(covost_url, self.root, hash_value=None) - extract_archive(covost_archive) + covost_archive = self.root / Path(covost_url).name + if not covost_archive.is_file(): + download_url(covost_url, self.root.as_posix(), hash_value=None) + extract_archive(covost_archive.as_posix()) - cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv")) - covost_tsv = self.load_from_tsv( - 
os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", "")) + cv_tsv = load_df_from_tsv(cv_tsv_path) + covost_tsv = load_df_from_tsv( + self.root / Path(covost_url).name.replace(".tar.gz", "") ) df = pd.merge( left=cv_tsv[["path", "sentence", "client_id"]], @@ -169,20 +153,16 @@ def __init__( df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")] else: df = df[df["split"] == split] - self.data = df.to_dict(orient="index").items() - self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])] - - @classmethod - def load_from_tsv(cls, path: str): - return pd.read_csv( - path, - sep="\t", - header=0, - encoding="utf-8", - escapechar="\\", - quoting=csv.QUOTE_NONE, - na_filter=False, - ) + data = df.to_dict(orient="index").items() + data = [v for k, v in sorted(data, key=lambda x: x[0])] + self.data = [] + for e in data: + try: + path = self.root / "clips" / e["path"] + _ = torchaudio.info(path.as_posix()) + self.data.append(e) + except RuntimeError: + pass def __getitem__( self, n: int @@ -197,7 +177,7 @@ def __getitem__( sample_id)`` """ data = self.data[n] - path = os.path.join(self.root, "clips", data["path"]) + path = self.root / "clips" / data["path"] waveform, sample_rate = torchaudio.load(path) sentence = data["sentence"] translation = None if self.no_translation else data["translation"] @@ -210,26 +190,26 @@ def __len__(self) -> int: def process(args): - root = op.join(args.data_root, args.src_lang) - os.makedirs(root, exist_ok=True) + root = Path(args.data_root).absolute() / args.src_lang + if not root.is_dir(): + raise NotADirectoryError(f"{root} does not exist") # Extract features - feature_root = op.join(root, "fbank80") - os.makedirs(feature_root, exist_ok=True) + feature_root = root / "fbank80" + feature_root.mkdir(exist_ok=True) for split in CoVoST.SPLITS: print(f"Fetching split {split}...") - dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True) + dataset = CoVoST(root, split, args.src_lang, 
args.tgt_lang) print("Extracting log mel filter bank features...") for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): extract_fbank_features( - waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy") + waveform, sample_rate, feature_root / f"{utt_id}.npy" ) # Pack features into ZIP - zip_filename = "fbank80.zip" - zip_path = op.join(root, zip_filename) + zip_path = root / "fbank80.zip" print("ZIPing features...") create_zip(feature_root, zip_path) print("Fetching ZIP manifest...") - zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}") + zip_manifest = get_zip_manifest(zip_path) # Generate TSV manifest print("Generating manifest...") train_text = [] @@ -251,7 +231,7 @@ def process(args): train_text.extend(manifest["tgt_text"]) df = pd.DataFrame.from_dict(manifest) df = filter_manifest_df(df, is_train_split=is_train_split) - save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv")) + save_df_to_tsv(df, root / f"{split}_{task}.tsv") # Generate vocab vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}" @@ -259,7 +239,10 @@ def process(args): for t in train_text: f.write(t + "\n") gen_vocab( - f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size + Path(f.name), + root / spm_filename_prefix, + args.vocab_type, + args.vocab_size ) # Generate config YAML gen_config_yaml( @@ -274,7 +257,10 @@ def process(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument( + "--data-root", "-d", required=True, type=str, + help="data root with sub-folders for each language /" + ) parser.add_argument( "--vocab-type", default="unigram", diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py index 95fcec8fe3..6a6f55ded4 100644 --- a/examples/speech_to_text/prep_librispeech_data.py +++ 
b/examples/speech_to_text/prep_librispeech_data.py @@ -6,8 +6,7 @@ import argparse import logging -import os -import os.path as op +from pathlib import Path import shutil from tempfile import NamedTemporaryFile @@ -40,34 +39,34 @@ def process(args): - os.makedirs(args.output_root, exist_ok=True) + out_root = Path(args.output_root).absolute() + out_root.mkdir(exist_ok=True) # Extract features - feature_root = op.join(args.output_root, "fbank80") - os.makedirs(feature_root, exist_ok=True) + feature_root = out_root / "fbank80" + feature_root.mkdir(exist_ok=True) for split in SPLITS: print(f"Fetching split {split}...") - dataset = LIBRISPEECH(args.output_root, url=split, download=True) + dataset = LIBRISPEECH(out_root.as_posix(), url=split, download=True) print("Extracting log mel filter bank features...") - for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset): - sample_id = f"{spk_id}-{chapter_id}-{utt_id}" + for wav, sample_rate, _, spk_id, chapter_no, utt_no in tqdm(dataset): + sample_id = f"{spk_id}-{chapter_no}-{utt_no}" extract_fbank_features( - wav, sample_rate, op.join(feature_root, f"{sample_id}.npy") + wav, sample_rate, feature_root / f"{sample_id}.npy" ) # Pack features into ZIP - zip_filename = "fbank80.zip" - zip_path = op.join(args.output_root, zip_filename) + zip_path = out_root / "fbank80.zip" print("ZIPing features...") create_zip(feature_root, zip_path) print("Fetching ZIP manifest...") - zip_manifest = get_zip_manifest(args.output_root, zip_filename) + zip_manifest = get_zip_manifest(zip_path) # Generate TSV manifest print("Generating manifest...") train_text = [] for split in SPLITS: manifest = {c: [] for c in MANIFEST_COLUMNS} - dataset = LIBRISPEECH(args.output_root, url=split) - for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset): - sample_id = f"{spk_id}-{chapter_id}-{utt_id}" + dataset = LIBRISPEECH(out_root.as_posix(), url=split) + for wav, sample_rate, utt, spk_id, chapter_no, utt_no in tqdm(dataset): + 
sample_id = f"{spk_id}-{chapter_no}-{utt_no}" manifest["id"].append(sample_id) manifest["audio"].append(zip_manifest[sample_id]) duration_ms = int(wav.size(1) / sample_rate * 1000) @@ -75,7 +74,7 @@ def process(args): manifest["tgt_text"].append(utt) manifest["speaker"].append(spk_id) save_df_to_tsv( - pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv") + pd.DataFrame.from_dict(manifest), out_root / f"{split}.tsv" ) if split.startswith("train"): train_text.extend(manifest["tgt_text"]) @@ -86,14 +85,14 @@ def process(args): for t in train_text: f.write(t + "\n") gen_vocab( - f.name, - op.join(args.output_root, spm_filename_prefix), + Path(f.name), + out_root / spm_filename_prefix, args.vocab_type, args.vocab_size, ) # Generate config YAML gen_config_yaml( - args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld" + out_root, spm_filename_prefix + ".model", specaugment_policy="ld" ) # Clean up shutil.rmtree(feature_root) diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 59a42803f9..520968401c 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -7,7 +7,7 @@ import argparse import logging import os -import os.path as op +from pathlib import Path import shutil from itertools import groupby from tempfile import NamedTemporaryFile @@ -48,19 +48,19 @@ class MUSTC(Dataset): def __init__(self, root: str, lang: str, split: str) -> None: assert split in self.SPLITS and lang in self.LANGUAGES - _root = op.join(root, f"en-{lang}", "data", split) - wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt") - assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root) + _root = Path(root) / f"en-{lang}" / "data" / split + wav_root, txt_root = _root / "wav", _root / "txt" + assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir() # Load audio segments try: import yaml except ImportError: - print("Please 
install PyYAML to load YAML files for " "the MuST-C dataset") - with open(op.join(txt_root, f"{split}.yaml")) as f: + print("Please install PyYAML to load the MuST-C YAML files") + with open(txt_root / f"{split}.yaml") as f: segments = yaml.load(f, Loader=yaml.BaseLoader) # Load source and target utterances for _lang in ["en", lang]: - with open(op.join(txt_root, f"{split}.{_lang}")) as f: + with open(txt_root / f"{split}.{_lang}") as f: utterances = [r.strip() for r in f] assert len(segments) == len(utterances) for i, u in enumerate(utterances): @@ -68,16 +68,16 @@ def __init__(self, root: str, lang: str, split: str) -> None: # Gather info self.data = [] for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): - wav_path = op.join(wav_root, wav_filename) - sample_rate = torchaudio.info(wav_path)[0].rate + wav_path = wav_root / wav_filename + sample_rate = torchaudio.info(wav_path.as_posix())[0].rate seg_group = sorted(_seg_group, key=lambda x: x["offset"]) for i, segment in enumerate(seg_group): offset = int(float(segment["offset"]) * sample_rate) n_frames = int(float(segment["duration"]) * sample_rate) - _id = f"{op.splitext(wav_filename)[0]}_{i}" + _id = f"{wav_path.stem}_{i}" self.data.append( ( - wav_path, + wav_path.as_posix(), offset, n_frames, sample_rate, @@ -98,29 +98,29 @@ def __len__(self) -> int: def process(args): + root = Path(args.data_root).absolute() for lang in MUSTC.LANGUAGES: - cur_root = op.join(args.data_root, f"en-{lang}") - if not op.isdir(cur_root): - print(f"{cur_root} does not exist. Skipped.") + cur_root = root / f"en-{lang}" + if not cur_root.is_dir(): + print(f"{cur_root.as_posix()} does not exist. 
Skipped.") continue # Extract features - feature_root = op.join(cur_root, "fbank80") - os.makedirs(feature_root, exist_ok=True) + feature_root = cur_root / "fbank80" + feature_root.mkdir(exist_ok=True) for split in MUSTC.SPLITS: print(f"Fetching split {split}...") - dataset = MUSTC(args.data_root, lang, split) + dataset = MUSTC(root.as_posix(), lang, split) print("Extracting log mel filter bank features...") for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): extract_fbank_features( - waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy") + waveform, sample_rate, feature_root / f"{utt_id}.npy" ) # Pack features into ZIP - zip_filename = "fbank80.zip" - zip_path = op.join(cur_root, zip_filename) + zip_path = cur_root / "fbank80.zip" print("ZIPing features...") create_zip(feature_root, zip_path) print("Fetching ZIP manifest...") - zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}") + zip_manifest = get_zip_manifest(zip_path) # Generate TSV manifest print("Generating manifest...") train_text = [] @@ -139,7 +139,7 @@ def process(args): train_text.extend(manifest["tgt_text"]) df = pd.DataFrame.from_dict(manifest) df = filter_manifest_df(df, is_train_split=is_train_split) - save_df_to_tsv(df, op.join(cur_root, f"{split}_{args.task}.tsv")) + save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv") # Generate vocab v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}" @@ -147,8 +147,8 @@ def process(args): for t in train_text: f.write(t + "\n") gen_vocab( - f.name, - op.join(cur_root, spm_filename_prefix), + Path(f.name), + cur_root / spm_filename_prefix, args.vocab_type, args.vocab_size, ) @@ -164,39 +164,42 @@ def process(args): def process_joint(args): - assert all( - op.isdir(op.join(args.data_root, f"en-{lang}")) for lang in MUSTC.LANGUAGES - ), "do not have downloaded data available for all 8 languages" - cur_root = args.data_root + 
cur_root = Path(args.data_root) + assert all((cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES), \ + "do not have downloaded data available for all 8 languages" # Generate vocab vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}" with NamedTemporaryFile(mode="w") as f: for lang in MUSTC.LANGUAGES: - tsv_path = op.join(cur_root, f"en-{lang}", f"train_{args.task}.tsv") + tsv_path = cur_root / f"en-{lang}" / f"train_{args.task}.tsv" df = load_df_from_tsv(tsv_path) for t in df["tgt_text"]: f.write(t + "\n") + special_symbols = None + if args.task == 'st': + special_symbols = [f'' for lang in MUSTC.LANGUAGES] gen_vocab( - f.name, - op.join(cur_root, spm_filename_prefix), + Path(f.name), + cur_root / spm_filename_prefix, args.vocab_type, args.vocab_size, + special_symbols=special_symbols ) # Generate config YAML gen_config_yaml( cur_root, spm_filename_prefix + ".model", yaml_filename=f"config_{args.task}.yaml", - specaugment_policy="lb", + specaugment_policy="ld", prepend_tgt_lang_tag=(args.task == "st"), ) # Make symbolic links to manifests for lang in MUSTC.LANGUAGES: for split in MUSTC.SPLITS: - src_path = op.join(cur_root, f"en-{lang}", f"{split}_{args.task}.tsv") - desc_path = op.join(cur_root, f"{split}_{lang}_{args.task}.tsv") - if not op.islink(desc_path): + src_path = cur_root / f"en-{lang}" / f"{split}_{args.task}.tsv" + desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv" + if not desc_path.is_symlink(): os.symlink(src_path, desc_path) diff --git a/fairseq/data/audio/feature_transforms/global_cmvn.py b/fairseq/data/audio/feature_transforms/global_cmvn.py index d512fed300..e457ff176f 100644 --- a/fairseq/data/audio/feature_transforms/global_cmvn.py +++ b/fairseq/data/audio/feature_transforms/global_cmvn.py @@ -16,9 +16,13 @@ def from_config_dict(cls, config=None): return GlobalCMVN(_config.get("stats_npz_path")) def __init__(self, stats_npz_path): 
+ self.stats_npz_path = stats_npz_path stats = np.load(stats_npz_path) self.mean, self.std = stats["mean"], stats["std"] + def __repr__(self): + return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")' + def __call__(self, x): x = np.subtract(x, self.mean) x = np.divide(x, self.std) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index afd43f1ec7..1f556107a2 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -251,7 +251,7 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): """ The forward method inherited from the base class has a **kwargs argument in its input, which is not supported in torchscript. This - method overrites the forward method definition without **kwargs. + method overwrites the forward method definition without **kwargs. """ encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths) decoder_out = self.decoder( @@ -397,7 +397,7 @@ def base_architecture(args): args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.share_decoder_input_output_embed = getattr( - args, "share_decoder_input_output_embed", True + args, "share_decoder_input_output_embed", False ) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 0a523680f0..7bd582b256 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -18,7 +18,6 @@ import numpy as np import torch from fairseq import checkpoint_utils, options, scoring, tasks, utils -from fairseq.data import encoders from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter From f32de63e69aceb966b84c7c515a016ec96439125 
Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 11 Jan 2021 12:30:55 -0800 Subject: [PATCH 394/707] Fix IWSLT'14 link (fixes #2984) (#3113) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3113 Reviewed By: pritamdamania87 Differential Revision: D25836423 Pulled By: myleott fbshipit-source-id: 0fe9cafcfd0f3edab2db1025d2fcc8dbb8af570a --- examples/translation/prepare-iwslt14.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/translation/prepare-iwslt14.sh b/examples/translation/prepare-iwslt14.sh index 0bf0dc2a2e..2fb6643fbc 100644 --- a/examples/translation/prepare-iwslt14.sh +++ b/examples/translation/prepare-iwslt14.sh @@ -15,7 +15,7 @@ CLEAN=$SCRIPTS/training/clean-corpus-n.perl BPEROOT=subword-nmt/subword_nmt BPE_TOKENS=10000 -URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz" +URL="http://dl.fbaipublicfiles.com/fairseq/data/iwslt14/de-en.tgz" GZ=de-en.tgz if [ ! -d "$SCRIPTS" ]; then From 60d2da7055ad696ef037c98dacc931f79d4ce117 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Mon, 11 Jan 2021 14:01:45 -0800 Subject: [PATCH 395/707] cast scale window to int for memory efficient fp16 Summary: In some of my runs I found that the loss scale never increased. This was because the scale window is not a whole number by default when using 48 GPUs. Let's cast it to int, as the default fp16 optimizer does.
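The effect described in this summary can be reproduced with a minimal sketch. It assumes a simplified scaler that raises the loss scale only when the number of overflow-free updates is an exact multiple of the scale window; `count_scale_increases` is a hypothetical helper written for illustration, not fairseq's actual loss scaler, and `update_freq = 1` is an assumed setting.

```python
def count_scale_increases(scale_window, num_updates=10_000):
    """Count how often a simplified scaler would raise the loss scale.

    Assumption: the scale grows only when the step count is an exact
    multiple of scale_window and no overflow ever occurs.
    """
    increases = 0
    for step in range(1, num_updates + 1):
        if step % scale_window == 0:
            increases += 1
    return increases


data_parallel_size = 48  # GPU count quoted in the summary
update_freq = 1          # assumed value for illustration

# Fractional window: 2 ** 14 / 48 is roughly 341.33, so no integer step
# count in this range is an exact multiple of it.
float_window = 2 ** 14 / data_parallel_size / update_freq

# Truncated window, as in the patched code path.
int_window = int(float_window)

print("float window:", float_window, "increases:", count_scale_increases(float_window))
print("int window:", int_window, "increases:", count_scale_increases(int_window))
```

Under these assumptions the fractional window never triggers an increase, while the truncated window of 341 triggers one every 341 updates.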
Reviewed By: myleott Differential Revision: D25872832 fbshipit-source-id: 7ab5c01c555dca07bda72cae3633bb4b17709a77 --- fairseq/optim/fp16_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index a0da4948c8..e0b069f172 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -480,7 +480,7 @@ def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): cfg.distributed_training.distributed_world_size / cfg.common.model_parallel_size ) - scale_window = ( + scale_window = int( 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] ) else: From 6f6f704d10fc00b89fb4a01e0a6857624132573e Mon Sep 17 00:00:00 2001 From: Yuqing Tang Date: Tue, 12 Jan 2021 11:26:11 -0800 Subject: [PATCH 396/707] added data scripts and model download links (#1554) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting ## What does this PR do? added data scripts and model download links for mBART 50 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1554 Reviewed By: pipibjc Differential Revision: D25871769 Pulled By: tangyuq fbshipit-source-id: 2c1b2f8c70e9cc4a083c87f76758fc472d42eaf0 --- examples/multilingual/ML50_langs.txt | 52 + examples/multilingual/README.md | 34 + examples/multilingual/data_scripts/README.md | 24 + .../multilingual/data_scripts/binarize.py | 200 ++++ .../data_scripts/check_iswlt_test_data.py | 67 ++ .../data_scripts/check_self_overlaps.py | 103 ++ .../data_scripts/check_valid_test_overlaps.py | 124 +++ .../multilingual/data_scripts/dedup_all.py | 52 + .../data_scripts/download_ML50_v1.sh | 30 + .../data_scripts/download_af_xh.sh | 164 ++++ .../data_scripts/download_flores_data.sh | 246 +++++ .../data_scripts/download_iitb.sh | 35 + .../download_iwslt_and_extract.sh | 225 +++++ .../data_scripts/download_lotus.sh | 46 + .../data_scripts/download_ted_and_extract.py | 338 +++++++ .../data_scripts/download_wat19_my.sh | 36 + .../data_scripts/download_wmt19_and_before.py | 899 ++++++++++++++++++ .../data_scripts/download_wmt20.sh | 547 +++++++++++ .../data_scripts/preprocess_ML50_v1.sh | 27 + .../remove_valid_test_in_train.py | 290 ++++++ .../multilingual/data_scripts/requirement.txt | 2 + .../multilingual/data_scripts/utils/dedup.py | 41 + .../utils/fasttext_multi_filter.py | 63 ++ .../data_scripts/utils/strip_sgm.sh | 1 + .../finetune_multilingual_model.sh | 5 + .../multilingual/multilingual_fairseq_gen.sh | 5 + .../multilingual/train_multilingual_model.sh | 5 + 27 files changed, 3661 insertions(+) create mode 100644 examples/multilingual/ML50_langs.txt create mode 100644 examples/multilingual/data_scripts/README.md create mode 100755 examples/multilingual/data_scripts/binarize.py create mode 100644 examples/multilingual/data_scripts/check_iswlt_test_data.py create mode 100644 examples/multilingual/data_scripts/check_self_overlaps.py create mode 100644
examples/multilingual/data_scripts/check_valid_test_overlaps.py create mode 100644 examples/multilingual/data_scripts/dedup_all.py create mode 100644 examples/multilingual/data_scripts/download_ML50_v1.sh create mode 100644 examples/multilingual/data_scripts/download_af_xh.sh create mode 100644 examples/multilingual/data_scripts/download_flores_data.sh create mode 100644 examples/multilingual/data_scripts/download_iitb.sh create mode 100644 examples/multilingual/data_scripts/download_iwslt_and_extract.sh create mode 100644 examples/multilingual/data_scripts/download_lotus.sh create mode 100644 examples/multilingual/data_scripts/download_ted_and_extract.py create mode 100644 examples/multilingual/data_scripts/download_wat19_my.sh create mode 100644 examples/multilingual/data_scripts/download_wmt19_and_before.py create mode 100644 examples/multilingual/data_scripts/download_wmt20.sh create mode 100644 examples/multilingual/data_scripts/preprocess_ML50_v1.sh create mode 100755 examples/multilingual/data_scripts/remove_valid_test_in_train.py create mode 100644 examples/multilingual/data_scripts/requirement.txt create mode 100644 examples/multilingual/data_scripts/utils/dedup.py create mode 100644 examples/multilingual/data_scripts/utils/fasttext_multi_filter.py create mode 100755 examples/multilingual/data_scripts/utils/strip_sgm.sh diff --git a/examples/multilingual/ML50_langs.txt b/examples/multilingual/ML50_langs.txt new file mode 100644 index 0000000000..558abbc785 --- /dev/null +++ b/examples/multilingual/ML50_langs.txt @@ -0,0 +1,52 @@ +ar_AR +cs_CZ +de_DE +en_XX +es_XX +et_EE +fi_FI +fr_XX +gu_IN +hi_IN +it_IT +ja_XX +kk_KZ +ko_KR +lt_LT +lv_LV +my_MM +ne_NP +nl_XX +ro_RO +ru_RU +si_LK +tr_TR +vi_VN +zh_CN +af_ZA +az_AZ +bn_IN +fa_IR +he_IL +hr_HR +id_ID +ka_GE +km_KH +mk_MK +ml_IN +mn_MN +mr_IN +pl_PL +ps_AF +pt_XX +sv_SE +sw_KE +ta_IN +te_IN +th_TH +tl_XX +uk_UA +ur_PK +xh_ZA +gl_ES +sl_SI \ No newline at end of file diff --git 
a/examples/multilingual/README.md b/examples/multilingual/README.md index 35eca89804..0076f5e8f0 100644 --- a/examples/multilingual/README.md +++ b/examples/multilingual/README.md @@ -108,7 +108,41 @@ cat {source_lang}_${target_lang}.txt | grep -P "^T" |sort -V |cut -f 2- |$TOK_CM sacrebleu -tok 'none' -s 'none' ${source_lang}_${target_lang}.ref < ${source_lang}_${target_lang}.hyp ``` +# mBART50 models + +* [mMBART 50 pretrained model](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.pretrained.tar.gz). +* [mMBART 50 finetuned many-to-one](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.n1.tar.gz). +* [mMBART 50 finetuned one-to-many](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.1n.tar.gz). +* [mMBART 50 finetuned many-to-many](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.nn.tar.gz). + +Please download and extract from the above tarballs. Each tarball contains +* The fairseq model checkpoint: model.pt +* The list of supported languages: ML50_langs.txt +* Sentence piece model: sentence.bpe.model +* Fairseq dictionary of each language: dict.{lang}.txt (please replace lang with a language specified in ML50_langs.txt) + +To use the trained models, +* use the tool [binarize.py](./data_scripts/binarize.py) to binarize your data using sentence.bpe.model and dict.{lang}.txt, and copy the dictionaries to your data path +* then run the generation command: +```bash +path_2_data= +model=/model.pt +lang_list=/ML50_langs.txt +source_lang= +target_lang= +fairseq-generate $path_2_data \ + --path $model \ + --task translation_multi_simple_epoch \ + --gen-subset test \ + --source-lang $source_lang \ + --target-lang $target_lang \ + --sacrebleu --remove-bpe 'sentencepiece' \ + --batch-size 32 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" +``` ## Citation diff --git a/examples/multilingual/data_scripts/README.md b/examples/multilingual/data_scripts/README.md new file mode 100644 index
0000000000..cc610c0c9e --- /dev/null +++ b/examples/multilingual/data_scripts/README.md @@ -0,0 +1,24 @@ + +# Install dependency +```bash +pip install -r requirement.txt +``` + +# Download the data set +```bash +export WORKDIR_ROOT= + +``` +The downloaded data will be at $WORKDIR_ROOT/ML50 + +# preprocess the data +Install SPM [here](https://github.com/google/sentencepiece) +```bash +export WORKDIR_ROOT= +export SPM_PATH= +``` +* $WORKDIR_ROOT/ML50/raw: extracted raw data +* $WORKDIR_ROOT/ML50/dedup: dedup data +* $WORKDIR_ROOT/ML50/clean: data with valid and test sentences removed from the dedup data + + diff --git a/examples/multilingual/data_scripts/binarize.py b/examples/multilingual/data_scripts/binarize.py new file mode 100755 index 0000000000..ee54c6aabf --- /dev/null +++ b/examples/multilingual/data_scripts/binarize.py @@ -0,0 +1,200 @@ +import shutil +import os, sys +from subprocess import check_call, check_output +import glob +import argparse +import shutil +import pathlib +import itertools + +def call_output(cmd): + print(f"Executing: {cmd}") + ret = check_output(cmd, shell=True) + print(ret) + return ret + +def call(cmd): + print(cmd) + check_call(cmd, shell=True) + + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + +SPM_PATH = os.environ.get('SPM_PATH', None) + +if SPM_PATH is None or not SPM_PATH.strip(): + print("Please install sentence piecence from https://github.com/google/sentencepiece and set SPM_PATH pointing to the installed spm_encode.py. 
Exitting...") + sys.exit(-1) + + +SPM_MODEL = f'{WORKDIR_ROOT}/sentence.bpe.model' +SPM_VOCAB = f'{WORKDIR_ROOT}/dict_250k.txt' + +SPM_ENCODE = f'{SPM_PATH}' + +if not os.path.exists(SPM_MODEL): + call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/sentence.bpe.model -O {SPM_MODEL}") + + +if not os.path.exists(SPM_VOCAB): + call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/dict_250k.txt -O {SPM_VOCAB}") + + + +def get_data_size(raw): + cmd = f'wc -l {raw}' + ret = call_output(cmd) + return int(ret.split()[0]) + +def encode_spm(model, direction, prefix='', splits=['train', 'test', 'valid'], pairs_per_shard=None): + src, tgt = direction.split('-') + + for split in splits: + src_raw, tgt_raw = f'{RAW_DIR}/{split}{prefix}.{direction}.{src}', f'{RAW_DIR}/{split}{prefix}.{direction}.{tgt}' + if os.path.exists(src_raw) and os.path.exists(tgt_raw): + cmd = f"""python {SPM_ENCODE} \ + --model {model}\ + --output_format=piece \ + --inputs {src_raw} {tgt_raw} \ + --outputs {BPE_DIR}/{direction}{prefix}/{split}.bpe.{src} {BPE_DIR}/{direction}{prefix}/{split}.bpe.{tgt} """ + print(cmd) + call(cmd) + + +def binarize_( + bpe_dir, + databin_dir, + direction, spm_vocab=SPM_VOCAB, + splits=['train', 'test', 'valid'], +): + src, tgt = direction.split('-') + + try: + shutil.rmtree(f'{databin_dir}', ignore_errors=True) + os.mkdir(f'{databin_dir}') + except OSError as error: + print(error) + cmds = [ + "fairseq-preprocess", + f"--source-lang {src} --target-lang {tgt}", + f"--destdir {databin_dir}/", + f"--workers 8", + ] + if isinstance(spm_vocab, tuple): + src_vocab, tgt_vocab = spm_vocab + cmds.extend( + [ + f"--srcdict {src_vocab}", + f"--tgtdict {tgt_vocab}", + ] + ) + else: + cmds.extend( + [ + f"--joined-dictionary", + f"--srcdict {spm_vocab}", + ] + ) + input_options = [] + if 'train' in splits and glob.glob(f"{bpe_dir}/train.bpe*"): + input_options.append( + f"--trainpref {bpe_dir}/train.bpe", + ) + if 'valid' in splits and 
glob.glob(f"{bpe_dir}/valid.bpe*"): + input_options.append(f"--validpref {bpe_dir}/valid.bpe") + if 'test' in splits and glob.glob(f"{bpe_dir}/test.bpe*"): + input_options.append(f"--testpref {bpe_dir}/test.bpe") + if len(input_options) > 0: + cmd = " ".join(cmds + input_options) + print(cmd) + call(cmd) + + +def binarize( + databin_dir, + direction, spm_vocab=SPM_VOCAB, prefix='', + splits=['train', 'test', 'valid'], + pairs_per_shard=None, +): + def move_databin_files(from_folder, to_folder): + for bin_file in glob.glob(f"{from_folder}/*.bin") \ + + glob.glob(f"{from_folder}/*.idx") \ + + glob.glob(f"{from_folder}/dict*"): + try: + shutil.move(bin_file, to_folder) + except OSError as error: + print(error) + bpe_databin_dir = f"{BPE_DIR}/{direction}{prefix}_databin" + bpe_dir = f"{BPE_DIR}/{direction}{prefix}" + if pairs_per_shard is None: + binarize_(bpe_dir, bpe_databin_dir, direction, spm_vocab=spm_vocab, splits=splits) + move_databin_files(bpe_databin_dir, databin_dir) + else: + # binarize valid and test which will not be sharded + binarize_( + bpe_dir, bpe_databin_dir, direction, + spm_vocab=spm_vocab, splits=[s for s in splits if s != "train"]) + for shard_bpe_dir in glob.glob(f"{bpe_dir}/shard*"): + path_strs = os.path.split(shard_bpe_dir) + shard_str = path_strs[-1] + shard_folder = f"{bpe_databin_dir}/{shard_str}" + databin_shard_folder = f"{databin_dir}/{shard_str}" + print(f'working from {shard_folder} to {databin_shard_folder}') + os.makedirs(databin_shard_folder, exist_ok=True) + binarize_( + shard_bpe_dir, shard_folder, direction, + spm_vocab=spm_vocab, splits=["train"]) + + for test_data in glob.glob(f"{bpe_databin_dir}/valid.*") + glob.glob(f"{bpe_databin_dir}/test.*"): + filename = os.path.split(test_data)[-1] + try: + os.symlink(test_data, f"{databin_shard_folder}/{filename}") + except OSError as error: + print(error) + move_databin_files(shard_folder, databin_shard_folder) + + +def load_langs(path): + with open(path) as fr: + langs = [l.strip() 
for l in fr] + return langs + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--data_root", default=f"{WORKDIR_ROOT}/ML50") + parser.add_argument("--raw-folder", default='raw') + parser.add_argument("--bpe-folder", default='bpe') + parser.add_argument("--databin-folder", default='databin') + + args = parser.parse_args() + + DATA_PATH = args.data_root #'/private/home/yuqtang/public_data/ML50' + RAW_DIR = f'{DATA_PATH}/{args.raw_folder}' + BPE_DIR = f'{DATA_PATH}/{args.bpe_folder}' + DATABIN_DIR = f'{DATA_PATH}/{args.databin_folder}' + os.makedirs(BPE_DIR, exist_ok=True) + + raw_files = itertools.chain( + glob.glob(f'{RAW_DIR}/train*'), + glob.glob(f'{RAW_DIR}/valid*'), + glob.glob(f'{RAW_DIR}/test*'), + ) + + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + + for direction in directions: + prefix = "" + splits = ['train', 'valid', 'test'] + try: + shutil.rmtree(f'{BPE_DIR}/{direction}{prefix}', ignore_errors=True) + os.mkdir(f'{BPE_DIR}/{direction}{prefix}') + os.makedirs(DATABIN_DIR, exist_ok=True) + except OSError as error: + print(error) + spm_model, spm_vocab = SPM_MODEL, SPM_VOCAB + encode_spm(spm_model, direction=direction, splits=splits) + binarize(DATABIN_DIR, direction, spm_vocab=spm_vocab, splits=splits) diff --git a/examples/multilingual/data_scripts/check_iswlt_test_data.py b/examples/multilingual/data_scripts/check_iswlt_test_data.py new file mode 100644 index 0000000000..f8e2eb0f15 --- /dev/null +++ b/examples/multilingual/data_scripts/check_iswlt_test_data.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import os, sys +import subprocess +import re +from subprocess import check_call, check_output + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + + +BLEU_REGEX = re.compile("^BLEU\\S* = (\\S+) ") +def run_eval_bleu(cmd): + output = check_output(cmd, shell=True, stderr=subprocess.STDOUT).decode("utf-8").strip() + print(output) + bleu = -1.0 + for line in output.strip().split('\n'): + m = BLEU_REGEX.search(line) + if m is not None: + bleu = m.groups()[0] + bleu = float(bleu) + break + return bleu + +def check_data_test_bleu(raw_folder, data_lang_pairs): + not_matchings = [] + for sacrebleu_set, src_tgts in data_lang_pairs: + for src_tgt in src_tgts: + print(f'checking test bleus for: {src_tgt} at {sacrebleu_set}') + src, tgt = src_tgt.split('-') + ssrc, stgt = src[:2], tgt[:2] + if os.path.exists(f'{raw_folder}/test.{tgt}-{src}.{src}'): + # reversed direction may have different test set + test_src = f'{raw_folder}/test.{tgt}-{src}.{src}' + else: + test_src = f'{raw_folder}/test.{src}-{tgt}.{src}' + cmd1 = f'cat {test_src} | sacrebleu -t "{sacrebleu_set}" -l {stgt}-{ssrc}; [ $? -eq 0 ] || echo ""' + test_tgt = f'{raw_folder}/test.{src}-{tgt}.{tgt}' + cmd2 = f'cat {test_tgt} | sacrebleu -t "{sacrebleu_set}" -l {ssrc}-{stgt}; [ $? 
-eq 0 ] || echo ""' + bleu1 = run_eval_bleu(cmd1) + if bleu1 != 100.0: + not_matchings.append(f'{sacrebleu_set}:{src_tgt} source side not matching: {test_src}') + bleu2 = run_eval_bleu(cmd2) + if bleu2 != 100.0: + not_matchings.append(f'{sacrebleu_set}:{src_tgt} target side not matching: {test_tgt}') + return not_matchings + +if __name__ == "__main__": + to_data_path = f'{WORKDIR_ROOT}/iwsltv2' + not_matching = check_data_test_bleu( + f'{to_data_path}/raw', + [ + ('iwslt17', ['en_XX-ar_AR', 'en_XX-ko_KR', 'ar_AR-en_XX', 'ko_KR-en_XX']), + ('iwslt17', ['en_XX-it_IT', 'en_XX-nl_XX', 'it_IT-en_XX', 'nl_XX-en_XX']), + ('iwslt17/tst2015', ['en_XX-vi_VN', "vi_VN-en_XX"]), + ] + ) + if len(not_matching) > 0: + print('the following datasets do not have matching test datasets:\n\t', '\n\t'.join(not_matching)) + diff --git a/examples/multilingual/data_scripts/check_self_overlaps.py b/examples/multilingual/data_scripts/check_self_overlaps.py new file mode 100644 index 0000000000..07b338dcfd --- /dev/null +++ b/examples/multilingual/data_scripts/check_self_overlaps.py @@ -0,0 +1,103 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import glob +import argparse +from utils.dedup import deup +import sys + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. 
Exitting..."') + sys.exit(-1) + +def get_directions(folder): + raw_files = glob.glob(f'{folder}/train*') + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + return directions + +def diff_list(lhs, rhs): + return set(lhs).difference(set(rhs)) + +def check_diff( + from_src_file, from_tgt_file, + to_src_file, to_tgt_file, +): + seen_in_from = set() + seen_src_in_from = set() + seen_tgt_in_from = set() + from_count = 0 + with open(from_src_file, encoding='utf-8') as fsrc, \ + open(from_tgt_file, encoding='utf-8') as ftgt: + for s, t in zip(fsrc, ftgt): + seen_in_from.add((s, t)) + seen_src_in_from.add(s) + seen_tgt_in_from.add(t) + from_count += 1 + common = 0 + common_src = 0 + common_tgt = 0 + to_count = 0 + seen = set() + + with open(to_src_file, encoding='utf-8') as fsrc, \ + open(to_tgt_file, encoding='utf-8') as ftgt: + for s, t in zip(fsrc, ftgt): + to_count += 1 + if (s, t) not in seen: + if (s, t) in seen_in_from: + common += 1 + if s in seen_src_in_from: + common_src += 1 + seen_src_in_from.remove(s) + if t in seen_tgt_in_from: + common_tgt += 1 + seen_tgt_in_from.remove(t) + seen.add((s, t)) + return common, common_src, common_tgt, from_count, to_count + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--folder", type=str, required=True, + help="the data folder ") + parser.add_argument("--split", type=str, default='test', + help="split (valid, test) to check against training data") + parser.add_argument('--directions', type=str, default=None, required=False) + + args = parser.parse_args() + + if args.directions is None: + directions = set(get_directions(args.folder)) + directions = sorted(directions) + else: + directions = args.directions.split(',') + directions = sorted(set(directions)) + + results = [] + print(f'checking where {args.split} split data are in training') + print(f'direction\tcommon_count\tsrc common\ttgt common\tfrom_size\tto_size') + + for direction in directions: + src, tgt = 
direction.split('-') + from_src_file = f'{args.folder}/{args.split}.{src}-{tgt}.{src}' + from_tgt_file = f'{args.folder}/{args.split}.{src}-{tgt}.{tgt}' + if not os.path.exists(from_src_file): + # some test/valid data might in reverse directinos: + from_src_file = f'{args.folder}/{args.split}.{tgt}-{src}.{src}' + from_tgt_file = f'{args.folder}/{args.split}.{tgt}-{src}.{tgt}' + to_src_file = f'{args.folder}/train.{src}-{tgt}.{src}' + to_tgt_file = f'{args.folder}/train.{src}-{tgt}.{tgt}' + if not os.path.exists(to_src_file) or not os.path.exists(from_src_file): + continue + r = check_diff(from_src_file, from_tgt_file, to_src_file, to_tgt_file) + results.append(r) + print(f'{direction}\t', '\t'.join(map(str, r))) + + +if __name__ == "__main__": + main() diff --git a/examples/multilingual/data_scripts/check_valid_test_overlaps.py b/examples/multilingual/data_scripts/check_valid_test_overlaps.py new file mode 100644 index 0000000000..40fa9aecdf --- /dev/null +++ b/examples/multilingual/data_scripts/check_valid_test_overlaps.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import argparse +import pandas as pd +import sys + + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. 
Exitting..."') + sys.exit(-1) + +def load_langs(path): + with open(path) as fr: + langs = [l.strip() for l in fr] + return langs + + + +def load_sentences(raw_data, split, direction): + src, tgt = direction.split('-') + src_path = f"{raw_data}/{split}.{direction}.{src}" + tgt_path = f"{raw_data}/{split}.{direction}.{tgt}" + if os.path.exists(src_path) and os.path.exists(tgt_path): + return [(src, open(src_path).read().splitlines()), (tgt, open(tgt_path).read().splitlines())] + else: + return [] + +def swap_direction(d): + src, tgt = d.split('-') + return f'{tgt}-{src}' + +def get_all_test_data(raw_data, directions, split='test'): + test_data = [ + x + for dd in directions + for d in [dd, swap_direction(dd)] + for x in load_sentences(raw_data, split, d) + ] + # all_test_data = {s for _, d in test_data for s in d} + all_test_data = {} + for lang, d in test_data: + for s in d: + s = s.strip() + lgs = all_test_data.get(s, set()) + lgs.add(lang) + all_test_data[s] = lgs + return all_test_data, test_data + + +def check_train_sentences(src_path, tgt_path, direction, all_test_data, mess_up_train={}): + # src, tgt = direction.split('-') + print(f'check training data for {direction} in {src_path} and {tgt_path}') + size = 0 + overlapped_size_counted_dup = 0 + if not os.path.exists(tgt_path) or not os.path.exists(src_path): + return mess_up_train, size, overlapped_size_counted_dup + + with open(src_path) as f, open(tgt_path) as g: + for src_line, tgt_line in zip(f, g): + s = src_line.strip() + t = tgt_line.strip() + size += 1 + if s in all_test_data: + langs = mess_up_train.get(s, set()) + langs.add(direction) + mess_up_train[s] = langs + overlapped_size_counted_dup += 1 + if t in all_test_data: + langs = mess_up_train.get(t, set()) + langs.add(direction) + mess_up_train[t] = langs + overlapped_size_counted_dup += 1 + print(f'{direction}: size={size}, overlapped={overlapped_size_counted_dup}') + return mess_up_train, size, overlapped_size_counted_dup + +def 
check_train_all(raw_data, directions, all_test_data): + mess_up_train = {} + data_sizes = {} + # raw_data = '~chau/data-bin/MineBART/multilingual_mined_100M/en_XX/et_EE-en_XX/all.{en_XX, et_EE}' + print(f'checking training data againsts # {len(all_test_data)} sentences') + print(f'example test data: ', [s for i, s in enumerate(all_test_data.keys()) if i < 10]) + for direction in directions: + src, tgt = direction.split('-') + path = f'{raw_data}/en_XX/{direction}/all' + src_path = f'{path}.{src}' + tgt_path = f'{path}.{tgt}' + print(f'checking {src_path} {tgt_path}') + _, size, overlapped_size_counted_dup = check_train_sentences(src_path, tgt_path, direction, all_test_data, mess_up_train) + data_sizes[direction] = (size, overlapped_size_counted_dup) + return mess_up_train, data_sizes + + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--folder", type=str, required=True, + help="the data folder ") + parser.add_argument("--test-data", type=str, required=True, + help="the test data folder ") + parser.add_argument('--directions', type=str, default=None, required=False) + + args = parser.parse_args() + directions = args.directions.split(',') + directions = sorted(set(directions)) + + results = [] + # print(f'checking where {args.split} split data are in training') + # print(f'direction\tcommon_count\tsrc common\ttgt common\tfrom_size\tto_size') + raw_data = args.folder + all_test_data, test_data = get_all_test_data(args.test_data, directions, split='test') + mess_up_train, data_sizes = check_train_all(raw_data, directions, all_test_data) + print(data_sizes) + + +if __name__ == "__main__": + main() diff --git a/examples/multilingual/data_scripts/dedup_all.py b/examples/multilingual/data_scripts/dedup_all.py new file mode 100644 index 0000000000..ef39c05ee6 --- /dev/null +++ b/examples/multilingual/data_scripts/dedup_all.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + + +import os +import glob +import argparse +from utils.dedup import deup + +import sys +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--from-folder", type=str, required=True, + help="the data folder to be dedup") + parser.add_argument("--to-folder", type=str, required=True, + help="the data folder to save deduped data") + parser.add_argument('--directions', type=str, default=None, required=False) + + args = parser.parse_args() + + if args.directions is None: + raw_files = glob.glob(f'{args.from_folder}/train*') + + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + else: + directions = args.directions.split(',') + directions = sorted(set(directions)) + + for direction in directions: + src, tgt = direction.split('-') + src_file = f'{args.from_folder}/train.{src}-{tgt}.{src}' + tgt_file = f'{args.from_folder}/train.{src}-{tgt}.{tgt}' + src_file_out = f'{args.to_folder}/train.{src}-{tgt}.{src}' + tgt_file_out = f'{args.to_folder}/train.{src}-{tgt}.{tgt}' + assert src_file != src_file_out + assert tgt_file != tgt_file_out + print(f'deduping {src_file}, {tgt_file}') + deup(src_file, tgt_file, src_file_out, tgt_file_out) + + +if __name__ == "__main__": + main() diff --git a/examples/multilingual/data_scripts/download_ML50_v1.sh b/examples/multilingual/data_scripts/download_ML50_v1.sh new file mode 100644 index 0000000000..99fbc75920 --- /dev/null +++ b/examples/multilingual/data_scripts/download_ML50_v1.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + +# first run download_wmt20.sh; it will install a few useful tools for other scripts +# TODO: need to print out instructions on downloading a few files which requires manually authentication from the websites +bash ./download_wmt20.sh + +python ./download_wmt19_and_before.py +bash ./download_wat19_my.sh +python ./download_ted_and_extract.py +bash ./download_lotus.sh +bash ./download_iitb.sh +bash ./download_af_xh.sh + + +# IWSLT downloading URLs have changed in between; TODO: fix them: +bash ./download_iwslt_and_extract.sh + +# TODO: globalvoices URLs changed; need to be fixed +bash ./download_flores_data.sh diff --git a/examples/multilingual/data_scripts/download_af_xh.sh b/examples/multilingual/data_scripts/download_af_xh.sh new file mode 100644 index 0000000000..a78fbbbbcc --- /dev/null +++ b/examples/multilingual/data_scripts/download_af_xh.sh @@ -0,0 +1,164 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# set -x -e + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + +# put intermediate files +TMP_DIR=$WORKDIR_ROOT/temp/af_xhv2 +# output {train,valid,test} files to dest +DEST=${WORKDIR_ROOT}/ML50/raw + + + +ROOT=${WORKDIR_ROOT} +UTILS=$PWD/utils +TMX2CORPUS="${UTILS}/tmx2corpus" +TMX_TOOL="python ${TMX2CORPUS}/tmx2corpus.py" + +mkdir -p $TMP_DIR +mkdir -p $DEST +mkdir -p $UTILS + +function download_opus(){ + src=$1 + tgt=$2 + subset=$3 + ulr=$4 + + mkdir extract_$subset.$src-$tgt + pushd extract_$subset.$src-$tgt + if [ ! 
-f "$subset.$src-$tgt.tmx.gz" ]; then + wget $url -O "$subset.$src-$tgt.tmx.gz" + gzip -d "$subset.$src-$tgt.tmx.gz" + f=$subset.$src-$tgt.tmx + $TMX_TOOL $f + mv bitext.$src ../$subset.$src-$tgt.$src + mv bitext.$tgt ../$subset.$src-$tgt.$tgt + fi + popd +} + +function concat_subsets(){ + src=$1 + tgt=$2 + subsets=$3 + src_train=raw_train.$src-$tgt.$src + tgt_train=raw_train.$src-$tgt.$tgt + > $src_train + > $tgt_train + for subset in $subsets; do + cat $subset.$src-$tgt.$src >> $src_train + cat $subset.$src-$tgt.$tgt >> $tgt_train + done +} + + + +function get_seeded_random() +{ + seed="$1" + openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt \ + </dev/zero 2>/dev/null +} + +function split_train_valid(){ + src=$1 + tgt=$2 + raw_src_train=raw_train.$src-$tgt.$src + raw_tgt_train=raw_train.$src-$tgt.$tgt + + shuf --random-source=<(get_seeded_random 43) $raw_src_train > shuffled.$src-$tgt.$src + shuf --random-source=<(get_seeded_random 43) $raw_tgt_train > shuffled.$src-$tgt.$tgt + + head -n 1500 shuffled.$src-$tgt.$src > valid.$src-$tgt.$src + head -n 1500 shuffled.$src-$tgt.$tgt > valid.$src-$tgt.$tgt + + tail -n +1501 shuffled.$src-$tgt.$src > train.$src-$tgt.$src + tail -n +1501 shuffled.$src-$tgt.$tgt > train.$src-$tgt.$tgt +} + +function copy2dst(){ + lsrc=$1 + ltgt=$2 + src=${lsrc:0:2} + tgt=${ltgt:0:2} + + + cp valid.$src-$tgt.$src $DEST/valid.$lsrc-$ltgt.$lsrc + cp valid.$src-$tgt.$tgt $DEST/valid.$lsrc-$ltgt.$ltgt + + cp train.$src-$tgt.$src $DEST/train.$lsrc-$ltgt.$lsrc + cp train.$src-$tgt.$tgt $DEST/train.$lsrc-$ltgt.$ltgt +} + + + + +#for xh-en +declare -A xh_en_urls +xh_en_urls=( + [Tatoeba]=https://object.pouta.csc.fi/OPUS-Tatoeba/v20190709/tmx/en-xh.tmx.gz + [wikimedia]=https://object.pouta.csc.fi/OPUS-wikimedia/v20190628/tmx/en-xh.tmx.gz + [memat]=https://object.pouta.csc.fi/OPUS-memat/v1/tmx/en-xh.tmx.gz + [uedin]=https://object.pouta.csc.fi/OPUS-bible-uedin/v1/tmx/en-xh.tmx.gz + [GNOME]=https://object.pouta.csc.fi/OPUS-GNOME/v1/tmx/en-xh.tmx.gz +
[XhosaNavy]=https://object.pouta.csc.fi/OPUS-XhosaNavy/v1/tmx/en-xh.tmx.gz + [KDE4]=https://object.pouta.csc.fi/OPUS-KDE4/v2/tmx/en-xh.tmx.gz + [Ubuntu]=https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/tmx/en-xh.tmx.gz +) + +mkdir $TMP_DIR/xh-en +pushd $TMP_DIR/xh-en +for k in "${!xh_en_urls[@]}" +do + name=$k + url=${xh_en_urls[$k]} + echo "$name: $url" + download_opus xh en $name $url +done +concat_subsets xh en "${!xh_en_urls[*]}" +split_train_valid xh en +copy2dst xh_ZA en_XX +popd + + +## +#for af-en +declare -A af_en_urls +af_en_urls=( + [Tatoeba]=https://object.pouta.csc.fi/OPUS-Tatoeba/v20190709/tmx/af-en.tmx.gz + [uedin]=https://object.pouta.csc.fi/OPUS-bible-uedin/v1/tmx/af-en.tmx.gz + [GNOME]=https://object.pouta.csc.fi/OPUS-GNOME/v1/tmx/af-en.tmx.gz + [QED]=https://object.pouta.csc.fi/OPUS-QED/v2.0a/tmx/af-en.tmx.gz + [KDE4]=https://object.pouta.csc.fi/OPUS-KDE4/v2/tmx/af-en.tmx.gz + [OpenSubtitles]=https://object.pouta.csc.fi/OPUS-OpenSubtitles/v2018/tmx/af-en.tmx.gz + [SPC]=https://object.pouta.csc.fi/OPUS-SPC/v1/tmx/af-en.tmx.gz + [Ubuntu]=https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/tmx/af-en.tmx.gz +) + +mkdir $TMP_DIR/af-en +pushd $TMP_DIR/af-en +for k in "${!af_en_urls[@]}" +do + name=$k + url=${af_en_urls[$k]} + echo "$name: $url" + download_opus af en $name $url +done +concat_subsets af en "${!af_en_urls[*]}" +split_train_valid af en +copy2dst af_ZA en_XX +popd + + diff --git a/examples/multilingual/data_scripts/download_flores_data.sh b/examples/multilingual/data_scripts/download_flores_data.sh new file mode 100644 index 0000000000..e6175ce0c3 --- /dev/null +++ b/examples/multilingual/data_scripts/download_flores_data.sh @@ -0,0 +1,246 @@ +#!/bin/bash + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree.
+# + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + +set -e +set -o pipefail + +SRC=en +SI_TGT=si +NE_TGT=ne + +DESTDIR=${WORKDIR_ROOT}/ML50/raw/ + +ROOT=${WORKDIR_ROOT}/tmp +mkdir -p $ROOT +DATA=$ROOT/data +NE_ROOT=$DATA/all-clean-ne +SI_ROOT=$DATA/all-clean-si + +mkdir -p $DATA $NE_ROOT $SI_ROOT + +SI_OPUS_DATASETS=( + "$SI_ROOT/GNOME.en-si" + "$SI_ROOT/Ubuntu.en-si" + "$SI_ROOT/KDE4.en-si" + "$SI_ROOT/OpenSubtitles.en-si" +) + +SI_OPUS_URLS=( + "https://object.pouta.csc.fi/OPUS-GNOME/v1/moses/en-si.txt.zip" + "https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/moses/en-si.txt.zip" + "https://object.pouta.csc.fi/OPUS-KDE4/v2/moses/en-si.txt.zip" + "https://object.pouta.csc.fi/OPUS-OpenSubtitles/v2018/moses/en-si.txt.zip" +) + +NE_OPUS_DATASETS=( + "$NE_ROOT/GNOME.en-ne" + "$NE_ROOT/Ubuntu.en-ne" + "$NE_ROOT/KDE4.en-ne" +) + +NE_OPUS_URLS=( + "https://object.pouta.csc.fi/OPUS-GNOME/v1/moses/en-ne.txt.zip" + "https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/moses/en-ne.txt.zip" + "https://object.pouta.csc.fi/OPUS-KDE4/v2/moses/en-ne.txt.zip" +) + +REMOVE_FILE_PATHS=() + +# Download data +download_data() { + CORPORA=$1 + URL=$2 + + if [ -f $CORPORA ]; then + echo "$CORPORA already exists, skipping download" + else + echo "Downloading $URL" + wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA + if [ -f $CORPORA ]; then + echo "$URL successfully downloaded." + else + echo "$URL not successfully downloaded." 
+ rm -f $CORPORA + exit -1 + fi + fi +} + +# Example: download_opus_data $LANG_ROOT $TGT +download_opus_data() { + LANG_ROOT=$1 + TGT=$2 + + if [ "$TGT" = "si" ]; then + URLS=("${SI_OPUS_URLS[@]}") + DATASETS=("${SI_OPUS_DATASETS[@]}") + else + URLS=("${NE_OPUS_URLS[@]}") + DATASETS=("${NE_OPUS_DATASETS[@]}") + fi + + # Download and extract data + for ((i=0;i<${#URLS[@]};++i)); do + URL=${URLS[i]} + CORPORA=${DATASETS[i]} + + download_data $CORPORA $URL + unzip -o $CORPORA -d $LANG_ROOT + REMOVE_FILE_PATHS+=( $CORPORA $CORPORA.xml $CORPORA.ids $LANG_ROOT/README $LANG_ROOT/LICENSE ) + done + + cat ${DATASETS[0]}.$SRC ${DATASETS[1]}.$SRC ${DATASETS[2]}.$SRC > $LANG_ROOT/GNOMEKDEUbuntu.$SRC-$TGT.$SRC + cat ${DATASETS[0]}.$TGT ${DATASETS[1]}.$TGT ${DATASETS[2]}.$TGT > $LANG_ROOT/GNOMEKDEUbuntu.$SRC-$TGT.$TGT + + REMOVE_FILE_PATHS+=( ${DATASETS[0]}.$SRC ${DATASETS[1]}.$SRC ${DATASETS[2]}.$SRC ) + REMOVE_FILE_PATHS+=( ${DATASETS[0]}.$TGT ${DATASETS[1]}.$TGT ${DATASETS[2]}.$TGT ) +} + +download_opus_data $SI_ROOT $SI_TGT +cp ${SI_OPUS_DATASETS[3]}.$SRC $SI_ROOT/OpenSubtitles2018.$SRC-$SI_TGT.$SRC +cp ${SI_OPUS_DATASETS[3]}.$SI_TGT $SI_ROOT/OpenSubtitles2018.$SRC-$SI_TGT.$SI_TGT +REMOVE_FILE_PATHS+=( ${SI_OPUS_DATASETS[3]}.$SRC ${SI_OPUS_DATASETS[3]}.$SI_TGT ) + +download_opus_data $NE_ROOT $NE_TGT + + +# Download and extract Global Voices data +GLOBAL_VOICES="$NE_ROOT/globalvoices.2018q4.ne-en" +GLOBAL_VOICES_URL="http://www.casmacat.eu/corpus/global-voices/globalvoices.ne-en.xliff.gz" + +download_data $GLOBAL_VOICES.gz $GLOBAL_VOICES_URL +gunzip -Nf $GLOBAL_VOICES.gz + +sed -ne 's?.*<source>\(.*\)</source>.*?\1?p' $GLOBAL_VOICES > $GLOBAL_VOICES.$NE_TGT +sed -ne 's?.*<target[^>]*>\(.*\)</target>.*?\1?p' $GLOBAL_VOICES > $GLOBAL_VOICES.$SRC + +REMOVE_FILE_PATHS+=( $GLOBAL_VOICES ) + +# Download and extract the bible dataset +BIBLE_TOOLS=bible-corpus-tools +XML_BIBLES=XML_Bibles +XML_BIBLES_DUP=XML_Bibles_dup + +if [ ! -e $BIBLE_TOOLS ]; then + echo "Cloning bible-corpus-tools repository..."
+ git clone https://github.com/christos-c/bible-corpus-tools.git +fi + +mkdir -p $BIBLE_TOOLS/bin $XML_BIBLES $XML_BIBLES_DUP +javac -cp "$BIBLE_TOOLS/lib/*" -d $BIBLE_TOOLS/bin $BIBLE_TOOLS/src/bible/readers/*.java $BIBLE_TOOLS/src/bible/*.java + +download_data bible.tar.gz "https://github.com/christos-c/bible-corpus/archive/v1.2.1.tar.gz" +tar xvzf bible.tar.gz + +cp bible-corpus-1.2.1/bibles/{Greek.xml,English.xml,Nepali.xml} $XML_BIBLES/ +cp bible-corpus-1.2.1/bibles/{Greek.xml,English-WEB.xml,Nepali.xml} $XML_BIBLES_DUP/ + +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateMLBooks $XML_BIBLES +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateMLBooks $XML_BIBLES_DUP +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateVerseAlignedBooks $XML_BIBLES +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateVerseAlignedBooks $XML_BIBLES_DUP + +cat $XML_BIBLES/aligned/*/English.txt > $NE_ROOT/bible.$SRC-$NE_TGT.$SRC +cat $XML_BIBLES/aligned/*/Nepali.txt > $NE_ROOT/bible.$SRC-$NE_TGT.$NE_TGT +cat $XML_BIBLES_DUP/aligned/*/English-WEB.txt > $NE_ROOT/bible_dup.$SRC-$NE_TGT.$SRC +cat $XML_BIBLES_DUP/aligned/*/Nepali.txt > $NE_ROOT/bible_dup.$SRC-$NE_TGT.$NE_TGT +REMOVE_FILE_PATHS+=( bible-corpus-1.2.1 bible.tar.gz $BIBLE_TOOLS $XML_BIBLES $XML_BIBLES_DUP ) + +# Download and extract the Penn Treebank dataset +NE_TAGGED=$ROOT/new_submissions_parallel_corpus_project_Nepal +NE_TAGGED_URL="http://www.cle.org.pk/Downloads/ling_resources/parallelcorpus/NepaliTaggedCorpus.zip" +EN_TAGGED_PATCH_URL="https://dl.fbaipublicfiles.com/fairseq/data/nepali-penn-treebank.en.patch" +NE_TAGGED_PATCH_URL="https://dl.fbaipublicfiles.com/fairseq/data/nepali-penn-treebank.ne.patch" +MOSES=mosesdecoder +MOSES_TOK=$MOSES/scripts/tokenizer +EN_PATCH_REGEX="{s:\\\/:\/:g;s/\*\T\*\-\n+//g;s/\-LCB\-/\{/g;s/\-RCB\-/\}/g; s/\-LSB\-/\[/g; s/\-RSB\-/\]/g;s/\-LRB\-/\(/g; s/\-RRB\-/\)/g; s/\'\'/\"/g; s/\`\`/\"/g; s/\ +\'s\ +/\'s /g; s/\ +\'re\ +/\'re /g; s/\"\ +/\"/g; s/\ +\"/\"/g; 
s/\ n't([\ \.\"])/n't\1/g; s/\r+(.)/\1/g;}" +NE_PATCH_REGEX="{s:\p{Cf}::g;s:\\\/:\/:g;s/\*\T\*\-\n+//g;s/\-LCB\-/\{/g;s/\-RCB\-/\}/g; s/\-LSB\-/\[/g; s/\-RSB\-/\]/g;s/\-LRB\-/\(/g; s/\-RRB\-/\)/g; s/\'\'/\"/g; s/\`\`/\"/g; s/\ +\'s\ +/\'s /g; s/\ +\'re\ +/\'re /g; s/\"\ +/\"/g; s/\ +\"/\"/g; s/\ n't([\ \.\"])/n't\1/g; s/\r+(.)/\1/g;}" + +download_data $DATA/nepali-penn-treebank.$SRC.patch $EN_TAGGED_PATCH_URL +download_data $DATA/nepali-penn-treebank.$NE_TGT.patch $NE_TAGGED_PATCH_URL +download_data original.zip $NE_TAGGED_URL +unzip -o original.zip -d $ROOT + +cat $NE_TAGGED/00.txt $NE_TAGGED/01.txt $NE_TAGGED/02.txt > $NE_TAGGED/nepali-penn-treebank.$SRC +cat $NE_TAGGED/00ne_revised.txt $NE_TAGGED/01ne_revised.txt $NE_TAGGED/02ne_revised.txt > $NE_TAGGED/nepali-penn-treebank.$NE_TGT + +patch $NE_TAGGED/nepali-penn-treebank.$SRC -i $DATA/nepali-penn-treebank.$SRC.patch -o $NE_TAGGED/nepali-penn-treebank-patched.$SRC +patch $NE_TAGGED/nepali-penn-treebank.$NE_TGT -i $DATA/nepali-penn-treebank.$NE_TGT.patch -o $NE_TAGGED/nepali-penn-treebank-patched.$NE_TGT + +if [ ! -e $MOSES ]; then + echo "Cloning moses repository..." 
+ git clone https://github.com/moses-smt/mosesdecoder.git +fi + +cat $NE_TAGGED/nepali-penn-treebank-patched.$SRC | \ + perl -anpe "$EN_PATCH_REGEX" | \ + $MOSES_TOK/tokenizer.perl -l $SRC | \ + $MOSES_TOK/detokenizer.perl -l $SRC > $NE_ROOT/nepali-penn-treebank.$SRC + +cat $NE_TAGGED/nepali-penn-treebank-patched.$NE_TGT | \ + perl -CIO -anpe "$NE_PATCH_REGEX" | \ + $MOSES_TOK/detokenizer.perl -l $SRC > $NE_ROOT/nepali-penn-treebank.$NE_TGT + + +# Download nepali dictionary data +NE_DICT=$NE_ROOT/dictionaries +download_data $NE_DICT "http://www.seas.upenn.edu/~nlp/resources/TACL-data-release/dictionaries.tar.gz" +tar xvzf $NE_DICT +cp dictionaries/dict.ne $NE_ROOT/dictionary.$NE_TGT-$SRC +REMOVE_FILE_PATHS+=( $NE_DICT dictionaries ) + +REMOVE_FILE_PATHS+=( $MOSES $NE_TAGGED original.zip $DATA/nepali-penn-treebank.$SRC.patch $DATA/nepali-penn-treebank.$NE_TGT.patch ) + + +# Remove the temporary files +for ((i=0;i<${#REMOVE_FILE_PATHS[@]};++i)); do + rm -rf ${REMOVE_FILE_PATHS[i]} +done + +# Copy the training data +si=si_LK +ne=ne_NP +en=en_XX +cat $SI_ROOT/GNOMEKDEUbuntu.en-si.si $SI_ROOT/OpenSubtitles2018.en-si.si > $DESTDIR/train.$si-$en.$si +cat $SI_ROOT/GNOMEKDEUbuntu.en-si.en $SI_ROOT/OpenSubtitles2018.en-si.en > $DESTDIR/train.$si-$en.$en + +cat $NE_ROOT/bible_dup.en-ne.ne $NE_ROOT/bible.en-ne.ne $NE_ROOT/globalvoices.2018q4.ne-en.ne $NE_ROOT/GNOMEKDEUbuntu.en-ne.ne $NE_ROOT/nepali-penn-treebank.ne > $DESTDIR/train.$ne-$en.$ne +cat $NE_ROOT/bible_dup.en-ne.en $NE_ROOT/bible.en-ne.en $NE_ROOT/globalvoices.2018q4.ne-en.en $NE_ROOT/GNOMEKDEUbuntu.en-ne.en $NE_ROOT/nepali-penn-treebank.en > $DESTDIR/train.$ne-$en.$en + + +#Download the test sets +wget https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz +tar -xvzf wikipedia_en_ne_si_test_sets.tgz + +cp wikipedia_en_ne_si_test_sets/wikipedia.dev.ne-en.ne $DESTDIR/valid.$ne-$en.$ne +cp wikipedia_en_ne_si_test_sets/wikipedia.dev.ne-en.en $DESTDIR/valid.$ne-$en.$en + +cp 
wikipedia_en_ne_si_test_sets/wikipedia.dev.si-en.si $DESTDIR/valid.$si-$en.$si +cp wikipedia_en_ne_si_test_sets/wikipedia.dev.si-en.en $DESTDIR/valid.$si-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.ne-en.ne $DESTDIR/devtest.$ne-$en.$ne +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.ne-en.en $DESTDIR/devtest.$ne-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.si-en.si $DESTDIR/devtest.$si-$en.$si +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.si-en.en $DESTDIR/devtest.$si-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.test.ne-en.ne $DESTDIR/test.$ne-$en.$ne +cp wikipedia_en_ne_si_test_sets/wikipedia.test.ne-en.en $DESTDIR/test.$ne-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.test.si-en.si $DESTDIR/test.$si-$en.$si +cp wikipedia_en_ne_si_test_sets/wikipedia.test.si-en.en $DESTDIR/test.$si-$en.$en + +rm -rf wikipedia_en_ne_si_test_sets.tgz wikipedia_en_ne_si_test_sets diff --git a/examples/multilingual/data_scripts/download_iitb.sh b/examples/multilingual/data_scripts/download_iitb.sh new file mode 100644 index 0000000000..a884e20839 --- /dev/null +++ b/examples/multilingual/data_scripts/download_iitb.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." 
+ exit +fi + +IITB=$WORKDIR_ROOT/IITB +mkdir -p $IITB +pushd $IITB + +wget http://www.cfilt.iitb.ac.in/~moses/iitb_en_hi_parallel/iitb_corpus_download/parallel.tgz +tar -xvzf parallel.tgz + +wget http://www.cfilt.iitb.ac.in/~moses/iitb_en_hi_parallel/iitb_corpus_download/dev_test.tgz +tar -xvzf dev_test.tgz + +DESTDIR=${WORKDIR_ROOT}/ML50/raw/ + +cp parallel/IITB.en-hi.en $DESTDIR/train.hi_IN-en_XX.en_XX +cp parallel/IITB.en-hi.hi $DESTDIR/train.hi_IN-en_XX.hi_IN + +cp dev_test/dev.en $DESTDIR/valid.hi_IN-en_XX.en_XX +cp dev_test/dev.hi $DESTDIR/valid.hi_IN-en_XX.hi_IN + +cp dev_test/test.en $DESTDIR/test.hi_IN-en_XX.en_XX +cp dev_test/test.hi $DESTDIR/test.hi_IN-en_XX.hi_IN +popd \ No newline at end of file diff --git a/examples/multilingual/data_scripts/download_iwslt_and_extract.sh b/examples/multilingual/data_scripts/download_iwslt_and_extract.sh new file mode 100644 index 0000000000..ca3591b3db --- /dev/null +++ b/examples/multilingual/data_scripts/download_iwslt_and_extract.sh @@ -0,0 +1,225 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +#echo 'Cloning Moses github repository (for tokenization scripts)...' +#git clone https://github.com/moses-smt/mosesdecoder.git + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + + +data_root=${WORKDIR_ROOT}/iwsltv2 +DESTDIR=${WORKDIR_ROOT}/ML50/raw + + +langs="ar_AR it_IT nl_XX ko_KR vi_VN" +echo "data_root: $data_root" + +download_path=${data_root}/downloads +raw=${DESTDIR} +tmp=${data_root}/tmp +orig=${data_root}/orig + +mkdir -p $download_path $orig $raw $tmp +####################### +download_iwslt(){ + iwslt_key=$1 + src=$2 + tgt=$3 + save_prefix=$4 + pushd ${download_path} + if [[ ! 
-f ${save_prefix}$src-$tgt.tgz ]]; then + wget https://wit3.fbk.eu/archive/${iwslt_key}/texts/$src/$tgt/$src-$tgt.tgz -O ${save_prefix}$src-$tgt.tgz + [ $? -eq 0 ] && return 0 + fi + popd +} + +extract_iwslt(){ + src=$1 + tgt=$2 + prefix=$3 + pushd $orig + tar zxvf ${download_path}/${prefix}$src-${tgt}.tgz + popd +} + +generate_train(){ + lsrc=$1 + ltgt=$2 + src=${lsrc:0:2} + tgt=${ltgt:0:2} + for ll in $lsrc $ltgt; do + l=${ll:0:2} + f="$orig/*/train.tags.$src-$tgt.$l" + f_raw=$raw/train.$lsrc-$ltgt.$ll + cat $f \ + | grep -v '' \ + | grep -v '' \ + | grep -v '' \ + | grep -v '' \ + | grep -v '' \ + | sed -e 's/<title>//g' \ + | sed -e 's/<\/title>//g' \ + | sed -e 's/<description>//g' \ + | sed -e 's/<\/description>//g' \ + | sed 's/^\s*//g' \ + | sed 's/\s*$//g' \ + > $f_raw + [ $? -eq 0 ] && echo "extracted $f to $f_raw" + done + return 0 +} + +convert_valid_test(){ + src=$1 + tgt=$2 + for l in $src $tgt; do + echo "lang: ${l}" + for o in `ls $orig/*/IWSLT*.TED*.$src-$tgt.$l.xml`; do + fname=${o##*/} + f=$tmp/${fname%.*} + echo "$o => $f" + grep '<seg id' $o \ + | sed -e 's/<seg id="[0-9]*">\s*//g' \ + | sed -e 's/\s*<\/seg>\s*//g' \ + | sed -e "s/\’/\'/g" \ + > $f + echo "" + done + done +} + +generate_subset(){ + lsrc=$1 + ltgt=$2 + src=${lsrc:0:2} + tgt=${ltgt:0:2} + subset=$3 + prefix=$4 + for ll in $lsrc $ltgt; do + l=${ll:0:2} + f=$tmp/$prefix.${src}-${tgt}.$l + if [[ -f $f ]]; then + cp $f $raw/$subset.${lsrc}-$ltgt.${ll} + fi + done +} +################# + +echo "downloading iwslt training and dev data" +# using multilingual for it, nl +download_iwslt "2017-01-trnmted" DeEnItNlRo DeEnItNlRo +download_iwslt "2017-01-trnted" ar en +download_iwslt "2017-01-trnted" en ar +download_iwslt "2017-01-trnted" ko en +download_iwslt "2017-01-trnted" en ko +download_iwslt "2015-01" vi en +download_iwslt "2015-01" en vi + +echo "downloading iwslt test data" +download_iwslt "2017-01-mted-test" it en "test." +download_iwslt "2017-01-mted-test" en it "test."
+download_iwslt "2017-01-mted-test" nl en "test." +download_iwslt "2017-01-mted-test" en nl "test." + +download_iwslt "2017-01-ted-test" ar en "test." +download_iwslt "2017-01-ted-test" en ar "test." +download_iwslt "2017-01-ted-test" ko en "test." +download_iwslt "2017-01-ted-test" en ko "test." +download_iwslt "2015-01-test" vi en "test." +download_iwslt "2015-01-test" en vi "test." + +echo "extract training data tar balls" +extract_iwslt DeEnItNlRo DeEnItNlRo +extract_iwslt ar en +extract_iwslt en ar +extract_iwslt ko en +extract_iwslt en ko +extract_iwslt vi en +extract_iwslt en vi + + +echo "extracting iwslt test data" +for lang in $langs; do + l=${lang:0:2} + extract_iwslt $l en "test." + extract_iwslt en $l "test." +done + +echo "convert dev and test data" +for lang in $langs; do + s_lang=${lang:0:2} + convert_valid_test $s_lang en + convert_valid_test en $s_lang +done + + + +echo "creating training data into $raw" +for lang in $langs; do + generate_train $lang en_XX + generate_train en_XX $lang +done + +echo "creating iwslt dev data into raw" +generate_subset en_XX vi_VN valid "IWSLT15.TED.tst2013" +generate_subset vi_VN en_XX valid "IWSLT15.TED.tst2013" + +generate_subset en_XX ar_AR valid "IWSLT17.TED.tst2016" +generate_subset ar_AR en_XX valid "IWSLT17.TED.tst2016" +generate_subset en_XX ko_KR valid "IWSLT17.TED.tst2016" +generate_subset ko_KR en_XX valid "IWSLT17.TED.tst2016" + + +generate_subset en_XX it_IT valid "IWSLT17.TED.tst2010" +generate_subset it_IT en_XX valid "IWSLT17.TED.tst2010" +generate_subset en_XX nl_XX valid "IWSLT17.TED.tst2010" +generate_subset nl_XX en_XX valid "IWSLT17.TED.tst2010" + +echo "creating iswslt test data into raw" +generate_subset en_XX vi_VN test "IWSLT15.TED.tst2015" +generate_subset vi_VN en_XX test "IWSLT15.TED.tst2015" + +generate_subset en_XX ar_AR test "IWSLT17.TED.tst2017" +generate_subset ar_AR en_XX test "IWSLT17.TED.tst2017" +generate_subset en_XX ko_KR test "IWSLT17.TED.tst2017" +generate_subset ko_KR en_XX 
test "IWSLT17.TED.tst2017" + +generate_subset en_XX it_IT test "IWSLT17.TED.tst2017.mltlng" +generate_subset it_IT en_XX test "IWSLT17.TED.tst2017.mltlng" +generate_subset en_XX nl_XX test "IWSLT17.TED.tst2017.mltlng" +generate_subset nl_XX en_XX test "IWSLT17.TED.tst2017.mltlng" + +# normalize iwslt directions into x-en +pushd $raw +for lang in $langs; do + for split in test valid; do + x_en_f1=$split.$lang-en_XX.en_XX + x_en_f2=$split.$lang-en_XX.${lang} + + en_x_f1=$split.en_XX-$lang.en_XX + en_x_f2=$split.en_XX-$lang.${lang} + + if [ -f $en_x_f1 ] && [ ! -f $x_en_f1 ]; then + echo "cp $en_x_f1 $x_en_f1" + cp $en_x_f1 $x_en_f1 + fi + if [ -f $en_x_f2 ] && [ ! -f $x_en_f2 ]; then + echo "cp $en_x_f2 $x_en_f2" + cp $en_x_f2 $x_en_f2 + fi + done +done +popd \ No newline at end of file diff --git a/examples/multilingual/data_scripts/download_lotus.sh b/examples/multilingual/data_scripts/download_lotus.sh new file mode 100644 index 0000000000..c08c701314 --- /dev/null +++ b/examples/multilingual/data_scripts/download_lotus.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exiting..."
+ exit +fi + + +SRCDIR=$WORKDIR_ROOT/indic_languages_corpus +DESTDIR=${WORKDIR_ROOT}/ML50/raw/ +mkdir -p $SRCDIR +mkdir -p $DESTDIR + +cd $SRCDIR +wget http://lotus.kuee.kyoto-u.ac.jp/WAT/indic-multilingual/indic_languages_corpus.tar.gz +tar -xvzf indic_languages_corpus.tar.gz + +SRC_EXTRACT_DIR=$SRCDIR/indic_languages_corpus/bilingual + +cp $SRC_EXTRACT_DIR/ml-en/train.ml $DESTDIR/train.ml_IN-en_XX.ml_IN +cp $SRC_EXTRACT_DIR/ml-en/train.en $DESTDIR/train.ml_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ml-en/dev.ml $DESTDIR/valid.ml_IN-en_XX.ml_IN +cp $SRC_EXTRACT_DIR/ml-en/dev.en $DESTDIR/valid.ml_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ml-en/test.ml $DESTDIR/test.ml_IN-en_XX.ml_IN +cp $SRC_EXTRACT_DIR/ml-en/test.en $DESTDIR/test.ml_IN-en_XX.en_XX + +cp $SRC_EXTRACT_DIR/ur-en/train.ur $DESTDIR/train.ur_PK-en_XX.ur_PK +cp $SRC_EXTRACT_DIR/ur-en/train.en $DESTDIR/train.ur_PK-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ur-en/dev.ur $DESTDIR/valid.ur_PK-en_XX.ur_PK +cp $SRC_EXTRACT_DIR/ur-en/dev.en $DESTDIR/valid.ur_PK-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ur-en/test.ur $DESTDIR/test.ur_PK-en_XX.ur_PK +cp $SRC_EXTRACT_DIR/ur-en/test.en $DESTDIR/test.ur_PK-en_XX.en_XX + +cp $SRC_EXTRACT_DIR/te-en/train.te $DESTDIR/train.te_IN-en_XX.te_IN +cp $SRC_EXTRACT_DIR/te-en/train.en $DESTDIR/train.te_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/te-en/dev.te $DESTDIR/valid.te_IN-en_XX.te_IN +cp $SRC_EXTRACT_DIR/te-en/dev.en $DESTDIR/valid.te_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/te-en/test.te $DESTDIR/test.te_IN-en_XX.te_IN +cp $SRC_EXTRACT_DIR/te-en/test.en $DESTDIR/test.te_IN-en_XX.en_XX diff --git a/examples/multilingual/data_scripts/download_ted_and_extract.py b/examples/multilingual/data_scripts/download_ted_and_extract.py new file mode 100644 index 0000000000..eb756680fa --- /dev/null +++ b/examples/multilingual/data_scripts/download_ted_and_extract.py @@ -0,0 +1,338 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +import os +import csv +from collections import defaultdict +from six.moves import zip +import io +import wget +import sys + +from subprocess import check_call, check_output + +# scripts and data locations +CWD = os.getcwd() +UTILS = f"{CWD}/utils" + +MOSES = f"{UTILS}/mosesdecoder" + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + + +# please donwload mosesdecoder here: +detok_cmd = f'{MOSES}/scripts/tokenizer/detokenizer.perl' + + +def call(cmd): + print(f"Executing: {cmd}") + check_call(cmd, shell=True) + +class MultiLingualAlignedCorpusReader(object): + """A class to read TED talk dataset + """ + + def __init__(self, corpus_path, delimiter='\t', + target_token=True, bilingual=True, corpus_type='file', + lang_dict={'source': ['fr'], 'target': ['en']}, + eval_lang_dict=None, zero_shot=False, + detok=True, + ): + + self.empty_line_flag = 'NULL' + self.corpus_path = corpus_path + self.delimiter = delimiter + self.bilingual = bilingual + self.lang_dict = lang_dict + self.lang_set = set() + self.target_token = target_token + self.zero_shot = zero_shot + self.eval_lang_dict = eval_lang_dict + self.corpus_type = corpus_type + self.detok = detok + + for list_ in self.lang_dict.values(): + for lang in list_: + self.lang_set.add(lang) + + self.data = dict() + self.data['train'] = self.read_aligned_corpus(split_type='train') + self.data['test'] = self.read_aligned_corpus(split_type='test') + self.data['dev'] = self.read_aligned_corpus(split_type='dev') + + def read_data(self, file_loc_): + data_list = list() + with io.open(file_loc_, 'r', encoding='utf8') as fp: + for line in fp: + try: + text = line.strip() + except IndexError: + text = 
self.empty_line_flag + data_list.append(text) + return data_list + + def filter_text(self, dict_): + if self.target_token: + field_index = 1 + else: + field_index = 0 + data_dict = defaultdict(list) + list1 = dict_['source'] + list2 = dict_['target'] + for sent1, sent2 in zip(list1, list2): + try: + src_sent = ' '.join(sent1.split()[field_index: ]) + except IndexError: + src_sent = 'NULL' + + if src_sent.find(self.empty_line_flag) != -1 or len(src_sent) == 0: + continue + + elif sent2.find(self.empty_line_flag) != -1 or len(sent2) == 0: + continue + + else: + data_dict['source'].append(sent1) + data_dict['target'].append(sent2) + return data_dict + + def read_file(self, split_type, data_type): + return self.data[split_type][data_type] + + def save_file(self, path_, split_type, data_type, lang): + tok_file = tok_file_name(path_, lang) + with io.open(tok_file, 'w', encoding='utf8') as fp: + for line in self.data[split_type][data_type]: + fp.write(line + '\n') + if self.detok: + de_tok(tok_file, lang) + + def add_target_token(self, list_, lang_id): + new_list = list() + token = '__' + lang_id + '__' + for sent in list_: + new_list.append(token + ' ' + sent) + return new_list + + def read_from_single_file(self, path_, s_lang, t_lang): + data_dict = defaultdict(list) + with io.open(path_, 'r', encoding='utf8') as fp: + reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE) + for row in reader: + data_dict['source'].append(row[s_lang]) + data_dict['target'].append(row[t_lang]) + + if self.target_token: + text = self.add_target_token(data_dict['source'], t_lang) + data_dict['source'] = text + + return data_dict['source'], data_dict['target'] + + def read_aligned_corpus(self, split_type='train'): + data_dict = defaultdict(list) + iterable = [] + s_list = [] + t_list = [] + + if self.zero_shot: + if split_type == "train": + iterable = zip(self.lang_dict['source'], self.lang_dict['target']) + else: + iterable = zip(self.eval_lang_dict['source'], 
self.eval_lang_dict['target']) + + elif self.bilingual: + iterable = itertools.product(self.lang_dict['source'], self.lang_dict['target']) + + for s_lang, t_lang in iterable: + if s_lang == t_lang: + continue + if self.corpus_type == 'file': + split_type_file_path = os.path.join(self.corpus_path, + "all_talks_{}.tsv".format(split_type)) + s_list, t_list = self.read_from_single_file(split_type_file_path, + s_lang=s_lang, + t_lang=t_lang) + data_dict['source'] += s_list + data_dict['target'] += t_list + new_data_dict = self.filter_text(data_dict) + return new_data_dict + + +def read_langs(corpus_path): + split_type_file_path = os.path.join(corpus_path, 'extracted', + "all_talks_dev.tsv") + with io.open(split_type_file_path, 'r', encoding='utf8') as fp: + reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE) + header = next(reader) + return [k for k in header.keys() if k != 'talk_name'] + +def extra_english(corpus_path, split): + split_type_file_path = os.path.join(corpus_path, + f"all_talks_{split}.tsv") + output_split_type_file_path = os.path.join(corpus_path, + f"all_talks_{split}.en") + with io.open(split_type_file_path, 'r', encoding='utf8') as fp, io.open(output_split_type_file_path, 'w', encoding='utf8') as fw: + reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE) + for row in reader: + line = row['en'] + fw.write(line + '\n') + de_tok(output_split_type_file_path, 'en') + + + +def tok_file_name(filename, lang): + seps = filename.split('.') + seps.insert(-1, 'tok') + tok_file = '.'.join(seps) + return tok_file + +def de_tok(tok_file, lang): + # seps = tok_file.split('.') + # seps.insert(-1, 'detok') + # de_tok_file = '.'.join(seps) + de_tok_file = tok_file.replace('.tok.', '.') + cmd = 'perl {detok_cmd} -l {lang} < {tok_file} > {de_tok_file}'.format( + detok_cmd=detok_cmd, tok_file=tok_file, + de_tok_file=de_tok_file, lang=lang[:2]) + call(cmd) + +def extra_bitex( + ted_data_path, + lsrc_lang, + ltrg_lang, + target_token, + 
output_data_path, +): + def get_ted_lang(lang): + long_langs = ['pt-br', 'zh-cn', 'zh-tw', 'fr-ca'] + if lang[:5] in long_langs: + return lang[:5] + elif lang[:4] =='calv': + return lang[:5] + elif lang in ['pt_BR', 'zh_CN', 'zh_TW', 'fr_CA']: + return lang.lower().replace('_', '-') + return lang[:2] + src_lang = get_ted_lang(lsrc_lang) + trg_lang = get_ted_lang(ltrg_lang) + train_lang_dict={'source': [src_lang], 'target': [trg_lang]} + eval_lang_dict = {'source': [src_lang], 'target': [trg_lang]} + + obj = MultiLingualAlignedCorpusReader(corpus_path=ted_data_path, + lang_dict=train_lang_dict, + target_token=target_token, + corpus_type='file', + eval_lang_dict=eval_lang_dict, + zero_shot=False, + bilingual=True) + + os.makedirs(output_data_path, exist_ok=True) + lsrc_lang = lsrc_lang.replace('-', '_') + ltrg_lang = ltrg_lang.replace('-', '_') + obj.save_file(output_data_path + f"/train.{lsrc_lang}-{ltrg_lang}.{lsrc_lang}", + split_type='train', data_type='source', lang=src_lang) + obj.save_file(output_data_path + f"/train.{lsrc_lang}-{ltrg_lang}.{ltrg_lang}", + split_type='train', data_type='target', lang=trg_lang) + + obj.save_file(output_data_path + f"/test.{lsrc_lang}-{ltrg_lang}.{lsrc_lang}", + split_type='test', data_type='source', lang=src_lang) + obj.save_file(output_data_path + f"/test.{lsrc_lang}-{ltrg_lang}.{ltrg_lang}", + split_type='test', data_type='target', lang=trg_lang) + + obj.save_file(output_data_path + f"/valid.{lsrc_lang}-{ltrg_lang}.{lsrc_lang}", + split_type='dev', data_type='source', lang=src_lang) + obj.save_file(output_data_path + f"/valid.{lsrc_lang}-{ltrg_lang}.{ltrg_lang}", + split_type='dev', data_type='target', lang=trg_lang) + + +def bar_custom(current, total, width=80): + print("Downloading: %d%% [%d / %d] Ks" % (current / total * 100, current / 1000, total / 1000), end='\r') + + +def download_and_extract(download_to, extract_to): + url = 'http://phontron.com/data/ted_talks.tar.gz' + filename = f"{download_to}/ted_talks.tar.gz" + if 
os.path.exists(filename): + print(f'{filename} has already been downloaded so skip') + else: + filename = wget.download(url, filename, bar=bar_custom) + if os.path.exists(f'{extract_to}/all_talks_train.tsv'): + print(f'Already extracted so skip') + else: + extract_cmd = f'tar xzfv "{filename}" -C "{extract_to}"' + call(extract_cmd) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--ted_data_path', type=str, default=WORKDIR_ROOT, required=False) + parser.add_argument( + '--direction-list', + type=str, + # default=None, + #for ML50 + default=( + "bn_IN-en_XX,he_IL-en_XX,fa_IR-en_XX,id_ID-en_XX,sv_SE-en_XX,pt_XX-en_XX,ka_GE-en_XX,ka_GE-en_XX,th_TH-en_XX," + "mr_IN-en_XX,hr_HR-en_XX,uk_UA-en_XX,az_AZ-en_XX,mk_MK-en_XX,gl_ES-en_XX,sl_SI-en_XX,mn_MN-en_XX," + #non-english directions + # "fr_XX-de_DE," # replaced with wmt20 + # "ja_XX-ko_KR,es_XX-pt_XX,ru_RU-sv_SE,hi_IN-bn_IN,id_ID-ar_AR,cs_CZ-pl_PL,ar_AR-tr_TR" + ), + required=False) + parser.add_argument('--target-token', action='store_true', default=False) + parser.add_argument('--extract-all-english', action='store_true', default=False) + + args = parser.parse_args() + + import sys + import json + + # TED Talks data directory + ted_data_path = args.ted_data_path + + download_to = f'{ted_data_path}/downloads' + extract_to = f'{ted_data_path}/extracted' + + #DESTDIR=${WORKDIR_ROOT}/ML50/raw/ + output_path = f'{ted_data_path}/ML50/raw' + os.makedirs(download_to, exist_ok=True) + os.makedirs(extract_to, exist_ok=True) + os.makedirs(output_path, exist_ok=True) + download_and_extract(download_to, extract_to) + + + if args.extract_all_english: + for split in ['train', 'dev', 'test']: + extra_english(ted_data_path, split) + exit(0) + if args.direction_list is not None: + directions = args.direction_list.strip().split(',') + directions = [tuple(d.strip().split('-', 1)) for d in directions if d] + else: + langs = read_langs(ted_data_path) + # directions = [ + # 
'{}.{}'.format(src, tgt) + # for src in langs + # for tgt in langs + # if src < tgt + # ] + directions = [('en', tgt) for tgt in langs if tgt != 'en'] + print(f'num directions={len(directions)}: {directions}') + + for src_lang, trg_lang in directions: + print('--working on {}-{}'.format(src_lang, trg_lang)) + extra_bitex( + extract_to, + src_lang, + trg_lang, + target_token=args.target_token, + output_data_path=output_path + ) diff --git a/examples/multilingual/data_scripts/download_wat19_my.sh b/examples/multilingual/data_scripts/download_wat19_my.sh new file mode 100644 index 0000000000..c1e2d47287 --- /dev/null +++ b/examples/multilingual/data_scripts/download_wat19_my.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +if [ -z "$WORKDIR_ROOT" ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exiting..." + exit 1 +fi + + +SRCDIR=$WORKDIR_ROOT/indic_languages_corpus +DESTDIR=$WORKDIR_ROOT/ML50/raw +mkdir -p $SRCDIR +mkdir -p $DESTDIR + +WAT_MY_EN=wat2020.my-en.zip +cd $SRCDIR +# please refer to http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/ for the latest URL if the following URL has expired +#- The data used for WAT2020 are identical to those used in WAT2019. 
+wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/$WAT_MY_EN +unzip $WAT_MY_EN + + +SRC_EXTRACT_DIR=$SRCDIR/wat2020.my-en/alt + +cp $SRC_EXTRACT_DIR/train.alt.en $DESTDIR/train.my_MM-en_XX.en_XX +cp $SRC_EXTRACT_DIR/train.alt.my $DESTDIR/train.my_MM-en_XX.my_MM +cp $SRC_EXTRACT_DIR/dev.alt.en $DESTDIR/valid.my_MM-en_XX.en_XX +cp $SRC_EXTRACT_DIR/dev.alt.my $DESTDIR/valid.my_MM-en_XX.my_MM +cp $SRC_EXTRACT_DIR/test.alt.en $DESTDIR/test.my_MM-en_XX.en_XX +cp $SRC_EXTRACT_DIR/test.alt.my $DESTDIR/test.my_MM-en_XX.my_MM diff --git a/examples/multilingual/data_scripts/download_wmt19_and_before.py b/examples/multilingual/data_scripts/download_wmt19_and_before.py new file mode 100644 index 0000000000..3465731eb3 --- /dev/null +++ b/examples/multilingual/data_scripts/download_wmt19_and_before.py @@ -0,0 +1,899 @@ +from typing import NamedTuple, List +from urllib.parse import urlparse +import os, sys +import subprocess +from subprocess import check_call, check_output +import glob +import wget +import re +import multiprocessing as mp +from functools import partial +import pathlib +from collections import OrderedDict + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. 
Exiting...') + sys.exit(-1) + +# scripts and data locations +CWD = os.getcwd() +UTILS = f"{CWD}/utils" + +MOSES = f"{UTILS}/mosesdecoder" +SGM_TOOL = f'{MOSES}/scripts/ems/support/input-from-sgm.perl' + +TMX2CORPUS = f"{UTILS}/tmx2corpus" +TMX_TOOL = f'python {TMX2CORPUS}/tmx2corpus.py' + +to_data_path = f'{WORKDIR_ROOT}/wmt' +download_to = f'{to_data_path}/downloads' +manually_downloads = f'{to_data_path}/downloads' +extract_to = f'{to_data_path}/extracted' +#DESTDIR=${WORKDIR_ROOT}/ML50/raw/ +raw_data = f'{WORKDIR_ROOT}/ML50/raw' +#### + +class DLDataset(NamedTuple): + name: str + train_urls: List[str] + valid_urls: List[str] + test_urls: List[str] + train_files_patterns: List[str] = [] + valid_files_patterns: List[str] = [] + test_files_patterns: List[str] = [] + + + +def bar_custom(current, total, width=80): + print("Downloading: %d%% [%d / %d] Ks" % (current / total * 100, current / 1000, total / 1000), end='\r') + +def get_downloaded_file(dl_folder, url): + if isinstance(url, tuple): + url, f = url + else: + url_f = urlparse(url) + # f = os.path.split(url_f.path)[-1] + f = '_'.join(url_f.path.split('/')[1:]) + return url, f"{dl_folder}/{f}" + +def download_parts_and_combine(dl_folder, urls, filename): + parts = [] + for url_record in urls: + url, part_file = get_downloaded_file(dl_folder, url_record) + if os.path.exists(part_file): + print(f'{part_file} has already been downloaded so skip') + else: + part_file = wget.download(url, part_file, bar=bar_custom) + parts.append(part_file) + + def get_combine_cmd(parts): + #default as tar.gz.?? 
+ return f'cat {" ".join(parts)} > {filename}' + + combine_cmd = get_combine_cmd(parts) + call(combine_cmd, debug=True) + return filename + +def download_a_url(dl_folder, url): + url, filename = get_downloaded_file(dl_folder, url) + if os.path.exists(filename): + print(f'{filename} has already been downloaded so skip') + return filename + + print(f'downloading {url} to {filename}') + if isinstance(url, list) or isinstance(url, tuple): + download_parts_and_combine(dl_folder, url, filename) + else: + wget.download(url, filename, bar=bar_custom) + print(f'downloaded: {filename}') + return filename + +def download_files(dl_folder, urls, completed_urls={}): + for url_record in urls: + url, _ = get_downloaded_file(dl_folder, url_record) + filename = download_a_url(dl_folder, url_record) + completed_urls[str(url)] = filename + return completed_urls + +def check_need_manual_downalod(dl_folder, to_manually_download_urls): + to_be_manually_downloaded = [] + manually_completed_urls = {} + for url_record, instruction in to_manually_download_urls: + url, filename = get_downloaded_file(dl_folder, url_record) + if not os.path.exists(filename): + print(f'{url} needs to be downloaded manually; please download it following {instruction} and copy it to {filename}') + to_be_manually_downloaded.append((url, filename)) + else: + manually_completed_urls[url] = filename + # if len(to_be_manually_downloaded) > 0: + # raise ValueError('Missing files that need to be downloaded manually; stop the process now.') + return to_be_manually_downloaded + +def download_dataset(to_folder, dl_dataset, completed_urls={}): + download_files(to_folder, dl_dataset.train_urls, completed_urls) + download_files(to_folder, dl_dataset.valid_urls, completed_urls) + download_files(to_folder, dl_dataset.test_urls, completed_urls) + print('completed downloading') + return completed_urls + +def call(cmd, debug=False): + if debug: + print(cmd) + check_call(cmd, shell=True) + + +def get_extract_name(file_path): + 
path = os.path.split(file_path) + return path[-1] + '_extract' #.split('.')[0] + +def extract_file(downloaded_file, extract_folder, get_extract_name=get_extract_name, debug=False): + extract_name = get_extract_name(downloaded_file) + extract_to = f'{extract_folder}/{extract_name}' + os.makedirs(extract_to, exist_ok=True) + if os.path.exists(f'{extract_to}/DONE'): + print(f'{downloaded_file} has already been extracted to {extract_to} so skip') + return extract_to + def get_extract_cmd(filename): + if filename.endswith('.tgz') or filename.endswith('tar.gz'): + return f'tar xzfv {filename} -C {extract_to}' + elif filename.endswith('.gz.tar'): + return f'tar xfv {filename} -C {extract_to}; (cd {extract_to}; gzip -d *.gz; [ $? -eq 0 ] || gzip -d */*.gz)' + elif filename.endswith('.tar'): + return f'tar xfv {filename} -C {extract_to}' + elif filename.endswith('.gz'): + return f'cp {filename} {extract_to}; (cd {extract_to}; gzip -d *.gz)' + elif filename.endswith('.zip'): + return f'unzip {filename} -d {extract_to}' + extract_cmd = get_extract_cmd(downloaded_file) + print(f'extracting {downloaded_file}') + if isinstance(extract_cmd, list): + for c in extract_cmd: + call(c, debug=debug) + else: + call(extract_cmd, debug=debug) + call(f'echo DONE > {extract_to}/DONE') + return extract_to + + +def extract_all_files( + completed_urls, extract_folder, + get_extract_name=get_extract_name, + completed_extraction={}, + debug=False): + extracted_folders = OrderedDict() + for url, downloaded_file in set(completed_urls.items()): + if downloaded_file in completed_extraction: + print(f'{downloaded_file} is already extracted; so skip') + continue + folder = extract_file(downloaded_file, extract_folder, get_extract_name, debug) + extracted_folders[url] = folder + return extracted_folders + + +def my_glob(folder): + for p in [f'{folder}/*', f'{folder}/*/*', f'{folder}/*/*/*']: + for f in glob.glob(p): + yield f + + +def sgm2raw(sgm, debug): + to_file = sgm[0:len(sgm) - len('.sgm')] + if 
os.path.exists(to_file): + debug and print(f'{sgm} already converted to {to_file}; so skip') + return to_file + cmd = f'{SGM_TOOL} < {sgm} > {to_file}' + call(cmd, debug) + return to_file + +def tmx2raw(tmx, debug): + to_file = tmx[0:len(tmx) - len('.tmx')] + to_folder = os.path.join(*os.path.split(tmx)[:-1]) + if os.path.exists(f'{to_folder}/bitext.en'): + debug and print(f'{tmx} already extracted to {to_file}; so skip') + return to_file + cmd = f'(cd {to_folder}; {TMX_TOOL} {tmx})' + call(cmd, debug) + return to_file + +CZENG16_REGEX = re.compile(r'.*?data.plaintext-format/0[0-9]train$') +WMT19_WIKITITLES_REGEX = re.compile(r'.*?wikititles-v1.(\w\w)-en.tsv.gz') +TSV_REGEX = re.compile(r'.*?(\w\w)-(\w\w).tsv$') + + + +def cut_wikitles(wiki_file, debug): + # different languages have different file names: + if wiki_file.endswith('wiki/fi-en/titles.fi-en'): + to_file1 = f'{wiki_file}.fi' + to_file2 = f'{wiki_file}.en' + BACKSLASH = '\\' + cmd1 = f"cat {wiki_file} | sed 's/|||/{BACKSLASH}t/g' |cut -f1 |awk '{{$1=$1}};1' > {to_file1}" + cmd2 = f"cat {wiki_file} | sed 's/|||/{BACKSLASH}t/g' |cut -f2 |awk '{{$1=$1}};1' > {to_file2}" +# elif WMT19_WIKITITLES_REGEX.match(wiki_file): +# src = WMT19_WIKITITLES_REGEX.match(wiki_file).groups()[0] +# to_file1 = f'{wiki_file}.{src}' +# to_file2 = f'{wiki_file}.en' +# cmd1 = f"cat {wiki_file} | cut -f1 |awk '{{$1=$1}};1' > {to_file1}" +# cmd2 = f"cat {wiki_file} | cut -f2 |awk '{{$1=$1}};1' > {to_file2}" + else: + return None + if os.path.exists(to_file1) and os.path.exists(to_file2): + debug and print(f'{wiki_file} already processed to {to_file1} and {to_file2}; so skip') + return wiki_file + + call(cmd1, debug=debug) + call(cmd2, debug=debug) + return wiki_file + +def cut_tsv(file, debug): + m = TSV_REGEX.match(file) + if m is None: + raise ValueError(f'{file} is not matching tsv pattern') + src = m.groups()[0] + tgt = m.groups()[1] + + to_file1 = f'{file}.{src}' + to_file2 = f'{file}.{tgt}' + cmd1 = f"cat {file} | cut -f1 |awk 
'{{$1=$1}};1' > {to_file1}" + cmd2 = f"cat {file} | cut -f2 |awk '{{$1=$1}};1' > {to_file2}" + if os.path.exists(to_file1) and os.path.exists(to_file2): + debug and print(f'{file} already processed to {to_file1} and {to_file2}; so skip') + return file + + call(cmd1, debug=debug) + call(cmd2, debug=debug) + return file + + +def convert_file_if_needed(file, debug): + if file.endswith('.sgm'): + return sgm2raw(file, debug) + elif file.endswith('.tmx'): + return tmx2raw(file, debug) + elif file.endswith('wiki/fi-en/titles.fi-en'): + return cut_wikitles(file, debug) +# elif WMT19_WIKITITLES_REGEX.match(file): +# return cut_wikitles(file, debug) + elif file.endswith('.tsv'): + return cut_tsv(file, debug) + elif CZENG16_REGEX.match(file): + return convert2czeng17(file, debug) + else: + return file + + +def convert_files_if_needed(extracted_foldrs, my_glob=my_glob, debug=False): + return { + url: list(sorted(set(convert_file_if_needed(f, debug) for f in sorted(set(my_glob(folder)))))) + for url, folder in extracted_foldrs.items() + } + +def match_patt(file_path, file_pattern, src, tgt, lang): + return file_pattern.format(src=src, tgt=tgt, lang=lang) in file_path + +def match_patts(file_path, file_patterns, src, tgt, lang): + for file_pattern in file_patterns: + params = { k: v for k, v in [('src', src), ('tgt', tgt), ('lang', lang)] if k in file_pattern} + matching = file_pattern.format(**params) + + if isinstance(file_pattern, tuple): + pattern, directions = file_pattern + if f'{src}-{tgt}' in directions and matching in file_path: + return True + else: + if matching in file_path: + return True + return False + +def extracted_glob(extracted_folder, file_patterns, src, tgt, lang): + def get_matching_pattern(file_pattern): + params = { + k: v + for k, v in [('src', src), ('tgt', tgt), ('lang', lang)] + if '{' + k + '}' in file_pattern + } + file_pattern = re.sub(r'{src:(.*?)}', r'\1' if lang == src else '', file_pattern) + file_pattern = re.sub(r'{tgt:(.*?)}', r'\1' if lang 
== tgt else '', file_pattern) + file_pattern = file_pattern.format(**params) + return file_pattern + for file_pattern in file_patterns: + if isinstance(file_pattern, tuple): + file_pattern, lang_pairs = file_pattern + if f'{src}-{tgt}' not in lang_pairs: + continue +# print('working on pattern: ', file_pattern, lang_pairs ) + matching_pattern = get_matching_pattern(file_pattern) + if matching_pattern is None: + continue + glob_patterns = f'{extracted_folder}/{matching_pattern}' +# print('glob_patterns: ', glob_patterns) + for f in glob.glob(glob_patterns): + yield f + +# for debug usage +def all_extracted_files(split, src, tgt, extracted_folders, split_urls): + def get_url(url): + if isinstance(url, tuple): + url, downloaded_file = url + return url + return [ + f + for url in split_urls + for f in my_glob(extracted_folders[str(get_url(url))]) + ] + +def concat_files(split, src, tgt, extracted_folders, split_urls, path_patterns, to_folder, debug=False): +# if debug: +# print('extracted files to be filtered by patterns: ', +# '\n\t'.join(sorted(all_extracted_files(split, src, tgt, extracted_folders, split_urls)))) + for lang in [src, tgt]: + to_file = f'{to_folder}/{split}.{src}-{tgt}.{lang}' + s_src, s_tgt, s_lang = src.split('_')[0], tgt.split('_')[0], lang.split('_')[0] + files = [] + for url in split_urls: + if isinstance(url, tuple): + url, downloaded_file = url + if str(url) not in extracted_folders: + print(f'warning: {url} not in extracted files') + for extracted_file in set( + extracted_glob( + extracted_folders[str(url)], path_patterns, + s_src, s_tgt, s_lang)): + files.append(extracted_file) + if len(files) == 0: + print('warning: ', f'No files found for split {to_file}') + continue + files = sorted(set(files)) + print(f'concating {len(files)} files into {to_file}') + cmd = ['cat'] + [f'"{f}"' for f in files] + [f'>{to_file}'] + cmd = " ".join(cmd) + call(cmd, debug=debug) + +UTILS = os.path.join(pathlib.Path(__file__).parent, 'utils') +LID_MODEL = 
f'{download_to}/lid.176.bin' +LID_MULTI = f'{UTILS}/fasttext_multi_filter.py' + +def lid_filter(split, src, tgt, from_folder, to_folder, debug=False): + if not os.path.exists(LID_MODEL): + call(f'wget -nc https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin -O {LID_MODEL}') + from_prefix = f'{from_folder}/{split}.{src}-{tgt}' + to_prefix = f'{to_folder}/{split}.{src}-{tgt}' + if os.path.exists(f'{from_prefix}.{src}') and os.path.exists(f'{from_prefix}.{tgt}'): + s_src, s_tgt = src.split('_')[0], tgt.split('_')[0] + cmd = ( + f'python {LID_MULTI} --model {LID_MODEL} --inputs {from_prefix}.{src} {from_prefix}.{tgt} ' + f'--langs {s_src} {s_tgt} --outputs {to_prefix}.{src} {to_prefix}.{tgt}' + ) + print(f'filtering {from_prefix}') + call(cmd, debug=debug) + +def concat_into_splits(dl_dataset, src, tgt, extracted_folders, to_folder, debug): + to_folder_tmp = f"{to_folder}_tmp" + os.makedirs(to_folder_tmp, exist_ok=True) + concat_files('train', src, tgt, + extracted_folders, + split_urls=dl_dataset.train_urls, + path_patterns=dl_dataset.train_files_patterns, + to_folder=to_folder_tmp, debug=debug) + lid_filter('train', src, tgt, to_folder_tmp, to_folder, debug) + + concat_files('valid', src, tgt, + extracted_folders, + split_urls=dl_dataset.valid_urls, + path_patterns=dl_dataset.valid_files_patterns, + to_folder=to_folder, debug=debug) + concat_files('test', src, tgt, + extracted_folders, + split_urls=dl_dataset.test_urls, + path_patterns=dl_dataset.test_files_patterns, + to_folder=to_folder, debug=debug) + + +def download_multi(dl_folder, extract_folder, urls, num_processes=8, debug=False): + pool = mp.Pool(processes=num_processes) + download_f = partial(download_a_url, dl_folder) + downloaded_files = pool.imap_unordered(download_f, urls) + pool.close() + pool.join() + +BLEU_REGEX = re.compile("^BLEU\\S* = (\\S+) ") +def run_eval_bleu(cmd): + output = check_output(cmd, shell=True, stderr=subprocess.STDOUT).decode("utf-8").strip() + print(output) + bleu 
= -1.0 + for line in output.strip().split('\n'): + m = BLEU_REGEX.search(line) + if m is not None: + bleu = m.groups()[0] + bleu = float(bleu) + break + return bleu + +def check_wmt_test_bleu(raw_folder, wmt_lang_pairs): + not_matchings = [] + for wmt, src_tgts in wmt_lang_pairs: + for src_tgt in src_tgts: + print(f'checking test bleus for: {src_tgt} at {wmt}') + src, tgt = src_tgt.split('-') + ssrc, stgt = src[:2], tgt[:2] + if os.path.exists(f'{raw_folder}/test.{tgt}-{src}.{src}'): + # reversed direction may have different test set + test_src = f'{raw_folder}/test.{tgt}-{src}.{src}' + else: + test_src = f'{raw_folder}/test.{src}-{tgt}.{src}' + cmd1 = f'cat {test_src} | sacrebleu -t "{wmt}" -l {stgt}-{ssrc}; [ $? -eq 0 ] || echo ""' + test_tgt = f'{raw_folder}/test.{src}-{tgt}.{tgt}' + cmd2 = f'cat {test_tgt} | sacrebleu -t "{wmt}" -l {ssrc}-{stgt}; [ $? -eq 0 ] || echo ""' + bleu1 = run_eval_bleu(cmd1) + if bleu1 != 100.0: + not_matchings.append(f'{wmt}:{src_tgt} source side not matching: {test_src}') + bleu2 = run_eval_bleu(cmd2) + if bleu2 != 100.0: + not_matchings.append(f'{wmt}:{src_tgt} target side not matching: {test_tgt}') + return not_matchings + +def download_and_extract( + to_folder, lang_pairs, dl_dataset, + to_manually_download_urls, + completed_urls={}, completed_extraction={}, + debug=False): + + dl_folder = f'{to_folder}/downloads' + extract_folder = f'{to_folder}/extracted' + raw_folder = f'{to_folder}/raw' + lid_filtered = f'{to_folder}/lid_filtered' + + os.makedirs(extract_folder, exist_ok=True) + os.makedirs(raw_folder, exist_ok=True) + os.makedirs(lid_filtered, exist_ok=True) + + + to_be_manually_dowloaded = check_need_manual_downalod(dl_folder, to_manually_download_urls) + + completed_urls = download_dataset( + dl_folder, dl_dataset, completed_urls) + if debug: + print('completed urls: ', completed_urls) + + + extracted_folders = extract_all_files( + completed_urls, + extract_folder=extract_folder, + completed_extraction=completed_extraction, 
+ debug=debug) + if debug: + print('download files have been extracted to folders: ', extracted_folders) + + converted_files = convert_files_if_needed(extracted_folders, debug=False) + for src_tgt in lang_pairs: + print(f'working on {dl_dataset.name}: {src_tgt}') + src, tgt = src_tgt.split('-') + concat_into_splits(dl_dataset, + src=src, tgt=tgt, + extracted_folders=extracted_folders, + to_folder=raw_folder, debug=debug) + print('completed data into: ', raw_folder) + +def download_czang16(download_to, username=None): + wgets = [ + f'wget --user={username} --password=czeng -P {download_to} http://ufallab.ms.mff.cuni.cz/~bojar/czeng16-data/data-plaintext-format.{i}.tar' + for i in range(10)] + cmds = [] + for i, cmd in enumerate(wgets): + filename = f'{download_to}/data-plaintext-format.{i}.tar' + if os.path.exists(filename): + print(f'{filename} has already been downloaded; so skip') + continue + cmds.append(cmd) + if cmds and username is None: + raise ValueError('No czeng username is given; please register at http://ufal.mff.cuni.cz/czeng/czeng16 to obtain username to download') + for cmd in cmds: + call(cmd) + print('done with downloading czeng1.6') + +def download_czeng17_script(download_to, extract_folder, debug=False): + url = 'http://ufal.mff.cuni.cz/czeng/download.php?f=convert_czeng16_to_17.pl.zip' + filename = f'{download_to}/convert_czeng16_to_17.pl.zip' + extract_to = f'{extract_folder}/{get_extract_name(filename)}' + script_path = f'{extract_to}/convert_czeng16_to_17.pl' + + if not os.path.exists(script_path): + wget.download(url, filename, bar=bar_custom) + extract_to = extract_file(f'{download_to}/convert_czeng16_to_17.pl.zip', extract_folder, get_extract_name=get_extract_name, debug=debug) + return script_path + +czeng17_script_path = "" +def convert2czeng17(file, debug): + en_file = f'{file}.en' + cs_file = f'{file}.cs' + + if not os.path.exists(en_file) or not os.path.exists(cs_file): + cs_cmd = f'cat {file} | perl {czeng17_script_path} | cut -f3 > 
{cs_file}' + en_cmd = f'cat {file} | perl {czeng17_script_path} | cut -f4 > {en_file}' + call(cs_cmd, debug) + call(en_cmd, debug) + else: + print(f'already extracted: {en_file} and {cs_file}') + return file + +def extract_czeng17(extract_folder, debug=False): + url = 'http://ufal.mff.cuni.cz/czeng/download.php?f=convert_czeng16_to_17.pl.zip' + filename = f'{download_to}/convert_czeng16_to_17.pl.zip' + extract_to = f'{extract_folder}/{get_extract_name(filename)}' + script_path = f'{extract_to}/convert_czeng16_to_17.pl' + + if not os.path.exists(script_path): + wget.download(url, filename, bar=bar_custom) + extract_to = extract_file(f'{download_to}/convert_czeng16_to_17.pl.zip', extract_folder, get_extract_name=get_extract_name, debug=debug) + return script_path + +######### +# definitions of wmt data sources +# for es-en +# Punctuation in the official test sets will be encoded with ASCII characters (not complex Unicode characters) as much as possible. You may want to normalize your system's output before submission. You are able to use a rawer version of the test sets that does not have this normalization. 
+# script to normalize punctuation: http://www.statmt.org/wmt11/normalize-punctuation.perl +wmt13_es_en = DLDataset( + name='wmt13_es-en', + train_urls=[ + 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://www.statmt.org/wmt13/training-parallel-un.tgz', + 'http://www.statmt.org/wmt13/training-parallel-nc-v8.tgz', + ], + valid_urls=[ + ('http://www.statmt.org/wmt13/dev.tgz', 'wmt13_dev.tgz') + ], + test_urls=[ + ('http://www.statmt.org/wmt13/test.tgz', 'wmt13_test.tgz') + ], + train_files_patterns=[ + ('*/europarl-v7.{src}-{tgt}.{lang}', ['es-en']), + ('*commoncrawl.{src}-{tgt}.{lang}', ['es-en']), + ('*/news-commentary-v8.{src}-{tgt}.{lang}', ['es-en']), + ('un/*undoc.2000.{src}-{tgt}.{lang}', ['es-en']), + ] , + valid_files_patterns=[ + ('dev/newstest2012.{lang}', ['es-en']) + ], + test_files_patterns=[ + ('test/newstest*.{lang}', ['es-en']) + ], +) + +wmt14_de_fr_en = DLDataset( + name='wmt14_de_fr_en', + train_urls=[ + 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://www.statmt.org/wmt13/training-parallel-un.tgz', + 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz', + ('http://www.statmt.org/wmt10/training-giga-fren.tar', 'training-giga-fren.gz.tar'), # it is actually a gz.tar + ], + valid_urls=[ + ('http://www.statmt.org/wmt14/dev.tgz', 'wmt14_dev.tgz'), + ], + test_urls=[ + ('http://www.statmt.org/wmt14/test-full.tgz', 'wmt14_test_full.tgz'), # cleaned test sets + ], + train_files_patterns=[ + ('*/europarl-v7.{src}-{tgt}.{lang}', ['fr-en', 'de-en']), + ('*commoncrawl.{src}-{tgt}.{lang}', ['fr-en', 'de-en']), + ('*/*news-commentary-v9.{src}-{tgt}.{lang}', ['fr-en', 'de-en']), + ('un/undoc.2000.{src}-{tgt}.{lang}', ['fr-en']), + ('*giga-{src}{tgt}*{lang}', ['fr-en']) + ], + valid_files_patterns=[ + ('dev/newstest2013.{lang}', ['fr-en', 'de-en']) + ], + test_files_patterns=[ 
+ ('test-full/newstest*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['en-de', 'de-en', 'fr-en', 'en-fr']), + ], +) + +# pip install git+https://github.com/amake/tmx2corpus.git +wmt16_ro_en = DLDataset( + name='wmt16_ro-en', + train_urls=[ + ('http://data.statmt.org/wmt16/translation-task/training-parallel-ep-v8.tgz', 'wmt16_training-parallel-ep-v8.tgz'), + ('http://opus.nlpl.eu/download.php?f=SETIMES/v2/tmx/en-ro.tmx.gz', 'en-ro.tmx.gz'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt16/translation-task/dev-romanian-updated.tgz', 'wmt16_dev.tgz') + ], + test_urls=[ + ('http://data.statmt.org/wmt16/translation-task/test.tgz', 'wmt16_test.tgz') + ], + train_files_patterns=[ + ('*/*europarl-v8.{src}-{tgt}.{lang}', ['ro-en']), + ('bitext.{lang}', ['ro-en']) # SETIMES, converted from tmx + ] , + valid_files_patterns=[ + ('dev/newsdev2016*{src}{tgt}*.{lang}', ['ro-en', 'ro-en']) + ], + test_files_patterns=[ + ('test/newstest*{src}{tgt}*.{lang}', ['ro-en', 'en-ro']) + ], +) + +cwmt_wmt_instruction = 'cwmt download instruction at: http://nlp.nju.edu.cn/cwmt-wmt' +wmt17_fi_lv_tr_zh_en_manual_downloads = [ + # fake urls to have unique keys for the data + ( ('http://nlp.nju.edu.cn/cwmt-wmt/CASIA2015.zip', 'CASIA2015.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2011.zip', 'CASICT2011.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2015.zip', 'CASICT2015.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2015.zip', 'Datum2015.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2017.zip', 'Datum2017.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/NEU2017.zip', 'NEU2017.zip'), cwmt_wmt_instruction), +] +wmt17_fi_lv_tr_zh_en = DLDataset( + name='wmt17_fi_lv_tr_zh_en', + train_urls=[ + ('http://data.statmt.org/wmt17/translation-task/training-parallel-ep-v8.tgz', 'wmt17_training-parallel-ep-v8.tgz'), + 'http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz', + 
'http://www.statmt.org/wmt15/wiki-titles.tgz', + ('http://opus.nlpl.eu/download.php?f=SETIMES/v2/tmx/en-tr.tmx.gz', 'en-tr.tmx.gz'), + ('http://data.statmt.org/wmt17/translation-task/rapid2016.tgz', 'wmt17_rapid2016.tgz'), + 'http://data.statmt.org/wmt17/translation-task/leta.v1.tgz', + 'http://data.statmt.org/wmt17/translation-task/dcep.lv-en.v1.tgz', + 'http://data.statmt.org/wmt17/translation-task/books.lv-en.v1.tgz', + (('https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00', + 'https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01',), 'UNv1.0.en-zh.tar.gz'), + #manually download files: + ('http://nlp.nju.edu.cn/cwmt-wmt/CASIA2015.zip', 'CASIA2015.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2011.zip', 'CASICT2011.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2015.zip', 'CASICT2015.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2015.zip', 'Datum2015.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2017.zip', 'Datum2017.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/NEU2017.zip', 'NEU2017.zip'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt17/translation-task/dev.tgz', 'wmt17_dev.tgz'), + ], + test_urls=[ + #NEW: Improved translations for zh test sets + ('http://data.statmt.org/wmt17/translation-task/test-update-1.tgz', 'wmt17_test_zh_en.tgz'), + ('http://data.statmt.org/wmt17/translation-task/test.tgz', 'wmt17_test_others.tgz') + ], + train_files_patterns=[ + ('casict*/cas*{src:ch}{tgt:en}.txt', ['zh-en', 'zh-en'] ), + ('casia*/cas*{src:ch}{tgt:en}.txt', ['zh-en', 'zh-en'] ), + ('dataum*/Book*{src:cn}{tgt:en}.txt', ['zh-en', 'zh-en']), + ('neu*/NEU*{src:cn}{tgt:en}.txt', ['zh-en', 'zh-en'] ), + ('*/*UNv1.0.en-zh.{src:zh}{tgt:en}', ['zh-en']), + ('training/*news-commentary-v12.{src}-{tgt}.{lang}', ['zh-en', ]), + + ('*/*europarl-v8.{src}-{tgt}.{lang}', ['fi-en', 'lv-en']), + ('wiki/fi-en/titles.{src}-{tgt}.{lang}', ['fi-en', ]), + ('rapid2016.{tgt}-{src}.{lang}', ['fi-en', 'lv-en']), + ('*/leta.{lang}', 
['lv-en']), + ('*/dcep.{lang}', ['lv-en']), + ('*/farewell.{lang}', ['lv-en']), + ('bitext.{lang}', ['tr-en']), + ] , + valid_files_patterns=[ + ('dev/newsdev2017*{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'fi-en', 'lv-en', 'tr-en', 'zh-en', + 'en-fi', 'en-lv', 'en-tr', 'en-zh' + ]), + ('dev/newstest2016*{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'fi-en', 'tr-en', + 'en-fi', 'en-tr', + ]), + ], + test_files_patterns=[ + ('test/newstest2017-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'fi-en', 'lv-en', 'tr-en', + 'en-fi', 'en-lv', 'en-tr', + ]), + ('newstest2017-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'zh-en', + 'en-zh' + ]), + ], +) + +czeng_instruction = 'download instruction at: http://ufal.mff.cuni.cz/czeng/czeng16' +#alternative: use the prepared data but detokenize it? +wmt18_cs_et_en_manual_downloads = [ +#for cs, need to register and download; Register and download CzEng 1.6. +#Better results can be obtained by using a subset of sentences, released under a new version name CzEng 1.7. 
+ # ((f'http://ufallab.ms.mff.cuni.cz/~bojar/czeng16-data/data-plaintext-format.{i}.tar', + # f'data-plaintext-format.{i}.tar'), czeng_instruction) + # for i in range(10) +] + +wmt18_cs_et_en = DLDataset( + name='wmt18_cs_et_en', + train_urls=[ + 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz', + 'http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-cs.zipporah0-dedup-clean.tgz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-et.zipporah0-dedup-clean.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz', + ('http://data.statmt.org/wmt18/translation-task/rapid2016.tgz', 'wmt18_rapid2016.tgz'), + # (tuple( + # (f'http://ufallab.ms.mff.cuni.cz/~bojar/czeng16-data/data-plaintext-format.{i}.tar', + # f'data-plaintext-format.{i}.tar') + # for i in range(10) + # ), + # 'czeng16_data_plaintext.gz.tar'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt18/translation-task/dev.tgz', 'wmt18_dev.tgz'), + ], + test_urls=[ + ('http://data.statmt.org/wmt18/translation-task/test.tgz', 'wmt18_test.tgz'), + ], + train_files_patterns=[ + # ('*/*europarl-v7.{src}-{tgt}.{lang}', ['cs-en']), + ('*/*europarl-v8.{src}-{tgt}.{lang}', ['et-en']), + # ('*paracrawl-release1.{tgt}-{src}.zipporah0-dedup-clean.{lang}', ['cs-en', 'et-en']), + ('*paracrawl-release1.{tgt}-{src}.zipporah0-dedup-clean.{lang}', ['et-en']), + # ('*commoncrawl.{src}-{tgt}.{lang}', ['cs-en']), + # ('*/news-commentary-v13.{src}-{tgt}.{lang}', ['cs-en']), + # ('data.plaintext-format/*train.{lang}', ['cs-en']), + ('rapid2016.{tgt}-{src}.{lang}', ['et-en']), + ] , + valid_files_patterns=[ + ('dev/newsdev2018*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['et-en']), + # ('dev/newstest2017*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['cs-en']) + ], + test_files_patterns=[ + 
('test/newstest2018-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + # ['cs-en', 'et-en']), + ['et-en']), + ] +) + +ru_en_yandex_instruction = 'Yandex Corpus download instruction at: https://translate.yandex.ru/corpus?lang=en' +wmt19_ru_gu_kk_lt_manual_downloads = [ + (('https://translate.yandex.ru/corpus?lang=en', 'wmt19_1mcorpus.zip'), ru_en_yandex_instruction) +] +wmt19_ru_gu_kk_lt = DLDataset( + name='wmt19_ru_gu_kk_lt', + train_urls=[ + 'http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-lt.bicleaner07.tmx.gz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://data.statmt.org/news-commentary/v14/training/news-commentary-v14-wmt19.en-kk.tsv.gz', + 'http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-ru.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz', + (('https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00', + 'https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01', + 'https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02',), + 'wmt19_UNv1.0.en-ru.tar.gz'), + 'https://tilde-model.s3-eu-west-1.amazonaws.com/rapid2016.en-lt.tmx.zip', + ('https://translate.yandex.ru/corpus?lang=en', 'wmt19_1mcorpus.zip'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt19/translation-task/dev.tgz', 'wmt19_dev.tgz'), + ], + test_urls=[ + ('http://data.statmt.org/wmt19/translation-task/test.tgz', 'wmt19_test.tgz'), + ], + train_files_patterns=[ + 
('*europarl-v9.{src}-{tgt}.tsv.{lang}', ['lt-en']), + #paracrawl + ('*paracrawl-release1.{tgt}-{src}.zipporah0-dedup-clean.{lang}', ['ru-en']), + ('bitext.{lang}', ['lt-en',]), + ('*commoncrawl.{src}-{tgt}.{lang}', ['ru-en',]), + ('*news-commentary-v14-wmt19.{tgt}-{src}.tsv.{lang}', ['kk-en', ]), + ('*news-commentary-v14.{tgt}-{src}.tsv.{lang}', ['ru-en']), + #yandex + ('corpus.{tgt}_{src}.1m.{lang}', ['ru-en']), + ('wikititles_v1_wikititles-v1.{src}-{tgt}.tsv.{lang}', ['ru-en', 'kk-en', 'lt-en', 'gu-en']), + ('*/UNv1.0.{tgt}-{src}.{lang}', ['ru-en']), + #rapid + ('bitext.{lang}', ['lt-en']) + ], + valid_files_patterns=[ + ('dev/newsdev2019*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['gu-en', 'kk-en', 'lt-en']), + ('dev/newstest2018*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['ru-en']), + ], + test_files_patterns=[ + ('sgm/newstest2019-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + ['ru-en', 'gu-en', 'kk-en', 'lt-en', 'en-ru', 'en-gu', 'en-kk', 'en-lt']), + ] +) + + +######### + +if __name__ == "__main__": + # speed up the downloads with multiple processing + dl_folder = f'{to_data_path}/downloads' + extract_folder = f'{to_data_path}/extracted' + + urls = [ + url + for dataset in [wmt13_es_en, wmt14_de_fr_en, wmt16_ro_en, wmt18_cs_et_en, wmt19_ru_gu_kk_lt] + for urls in [dataset.train_urls, dataset.valid_urls, dataset.test_urls] + for url in urls + ] + urls = set(urls) + download_multi(dl_folder, extract_folder, urls, num_processes=8, debug=True) + + # check manually downlaods + to_manually_download_urls = ( + wmt17_fi_lv_tr_zh_en_manual_downloads + wmt18_cs_et_en_manual_downloads + wmt19_ru_gu_kk_lt_manual_downloads + ) + to_be_manually_dowloaded = check_need_manual_downalod(dl_folder, to_manually_download_urls) + if len(to_be_manually_dowloaded) > 0: + print('Missing files that need to be downloaded manually; stop the process now.') + exit(-1) + + completed_urls = {} + completed_extraction = {} + def work_on_wmt(directions, wmt_data): + download_and_extract( + to_data_path, + 
directions, + wmt_data, + to_manually_download_urls=to_manually_download_urls, + completed_urls=completed_urls, completed_extraction=completed_extraction, debug=True) + + work_on_wmt( + ['es_XX-en_XX'], + wmt13_es_en,) + work_on_wmt( + [ + 'fr_XX-en_XX', 'en_XX-fr_XX', + # 'en_XX-de_DE', 'de_DE-en_XX', + ], + wmt14_de_fr_en,) + work_on_wmt( + ['ro_RO-en_XX', 'en_XX-ro_XX'], + wmt16_ro_en,) + work_on_wmt( + [ + # 'zh_CN-en_XX', + 'lv_LV-en_XX', 'fi_FI-en_XX', 'tr_TR-en_XX', + #in case the reversed directions have different train/valid/test data + # 'en_XX-zh_CN', + 'en_XX-lv_LV', 'en_XX-fi_FI', 'en_XX-tr_TR', + ], + wmt17_fi_lv_tr_zh_en, ) + # czeng17_script_path = download_czeng17_script(download_to, extract_to, debug=False) + # cz_username = None + work_on_wmt( + [ + # 'cs_CZ-en_XX', + 'et_EE-en_XX'], + wmt18_cs_et_en,) + work_on_wmt( + [ + # 'ru_RU-en_XX', 'en_XX-ru_RU', + 'gu_IN-en_XX', 'kk_KZ-en_XX', 'lt_LT-en_XX', + #in case the reversed directions have different train/valid/test data + 'en_XX-gu_IN', 'en_XX-kk_KZ', 'en_XX-lt_LT' + ], + wmt19_ru_gu_kk_lt,) + + not_matching = check_wmt_test_bleu( + f'{to_data_path}/raw', + [ + ('wmt13', ['es_XX-en_XX']), + ('wmt14/full', ['fr_XX-en_XX',]), + ('wmt16', ['ro_RO-en_XX',]), + # ('wmt17/improved', ['zh_CN-en_XX']), + ('wmt17', [ 'lv_LV-en_XX', 'fi_FI-en_XX', 'tr_TR-en_XX']), + ('wmt18', ['cs_CZ-en_XX', 'et_EE-en_XX']), + ('wmt19', ['gu_IN-en_XX', 'kk_KZ-en_XX', 'lt_LT-en_XX']), + #'ru_RU-en_XX', + ] + ) + if len(not_matching) > 0: + print('the following datasets do not have matching test datasets:\n\t', '\n\t'.join(not_matching)) + diff --git a/examples/multilingual/data_scripts/download_wmt20.sh b/examples/multilingual/data_scripts/download_wmt20.sh new file mode 100644 index 0000000000..31cd5c76b7 --- /dev/null +++ b/examples/multilingual/data_scripts/download_wmt20.sh @@ -0,0 +1,547 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + + +set -x -e + +# TODO update the workdir and dest dir name +# put fasttext model +WORKDIR=$WORKDIR_ROOT +# put intermediate files +TMP_DIR=$WORKDIR_ROOT/tmp/tmp_wmt20_lowres_download +# output {train,valid,test} files to dest +DEST=$WORKDIR_ROOT/ML50/raw + +UTILS=$PWD/utils + +# per dataset locations +COMMONCRAWL_DIR=$TMP_DIR/commoncrawl +YANDEX_CORPUS=$WORKDIR_ROOT/wmt20/official/ru/yandex/1mcorpus.zip +# unzipped +CZENG_CORPUS=$WORKDIR_ROOT/wmt20/official/cs/czeng/czeng20-train +CCMT_DIR=$WORKDIR_ROOT/wmt20/official/zh/ccmt/parallel + +download_and_select() { + SUBFOLDER=$1 + URL=$2 + UNCOMPRESS_CMD=$3 + LANG=$4 + INPUT_FILEPATH=$5 + if [[ $# -gt 5 ]]; then + LANG_COL=$6 + EN_COL=$7 + fi + + mkdir -p $SUBFOLDER + cd $SUBFOLDER + wget -nc --content-disposition $URL + $UNCOMPRESS_CMD + + if [[ $# -gt 5 ]]; then + cut -f$LANG_COL $INPUT_FILEPATH > $INPUT_FILEPATH.$LANG + cut -f$EN_COL $INPUT_FILEPATH > $INPUT_FILEPATH.en + fi + cd .. + + ln -sf $SUBFOLDER/$INPUT_FILEPATH.$LANG $SUBFOLDER.$LANG + ln -sf $SUBFOLDER/$INPUT_FILEPATH.en $SUBFOLDER.en +} + +prepare_lid() { + pip install fasttext + + # TODO specify global workdir + MODEL=$WORKDIR/fasttext/lid.176.bin + LID_MULTI=$UTILS/fasttext_multi_filter.py + + if [ ! -f "$MODEL" ]; then + echo "downloading fasttext lid model..." + mkdir -p $WORKDIR/fasttext + wget -nc https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin -O $MODEL + fi +} + +prepare_moses() { + pushd $UTILS + echo 'Cloning Moses github repository (for tokenization scripts)...' 
+ git clone https://github.com/moses-smt/mosesdecoder.git + popd +} + +lid_filter() { + # TODO specify global workdir + MODEL=$WORKDIR/fasttext/lid.176.bin + LID_MULTI=$UTILS/fasttext_multi_filter.py + + prepare_lid + + SRC=$1 + SRC_FILE=$2 + SRC_OUTPUT=$3 + TGT=$4 + TGT_FILE=$5 + TGT_OUTPUT=$6 + python $LID_MULTI --model $MODEL --inputs $SRC_FILE $TGT_FILE --langs $SRC $TGT --outputs $SRC_OUTPUT $TGT_OUTPUT +} + +prepare_ja_ted() { + mkdir -p ted + cd ted + + wget -nc https://wit3.fbk.eu/archive/2017-01-trnted//texts/en/ja/en-ja.tgz + tar -zxvf en-ja.tgz + cat en-ja/train.tags.en-ja.en | grep -v -P "^[ ]*\<" | sed 's/^[ \t]*//g' | sed 's/[ \t]*$//g' > en-ja/train.en-ja.en + cat en-ja/train.tags.en-ja.ja | grep -v -P "^[ ]*\<" | sed 's/^[ \t]*//g' | sed 's/[ \t]*$//g' > en-ja/train.en-ja.ja + + cd .. + ln -sf ted/en-ja/train.en-ja.ja ted.ja + ln -sf ted/en-ja/train.en-ja.en ted.en +} + +prepare_ja() { + OUTPUT_DIR=$TMP_DIR/ja + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl/release/2.0/bitext/en-ja.tar.gz" "tar -zxvf en-ja.tar.gz" ja en-ja/en-ja.bicleaner05.txt 4 3 & + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-ja.tsv.gz" "gunzip -f news-commentary-v15.en-ja.tsv.gz" ja news-commentary-v15.en-ja.tsv 2 1 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ja-en.tsv.gz" "gunzip -f wikititles-v2.ja-en.tsv.gz" ja wikititles-v2.ja-en.tsv 1 2 & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-ja.langid.tsv.gz" "gunzip -f WikiMatrix.v1.en-ja.langid.tsv.gz" ja WikiMatrix.v1.en-ja.langid.tsv 3 2 & + download_and_select subtitle "https://nlp.stanford.edu/projects/jesc/data/split.tar.gz" "tar -zxvf split.tar.gz" ja split/train 2 1 & + download_and_select kftt "http://www.phontron.com/kftt/download/kftt-data-1.0.tar.gz" "tar -zxvf 
kftt-data-1.0.tar.gz" ja kftt-data-1.0/data/orig/kyoto-train & + + prepare_ja_ted & + + # ted data needs to + + wait + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.ja" | sort -V | xargs cat > all.ja + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ja all.ja $DEST/train.ja_XX-en_XX.ja_XX en all.en $DEST/train.ja_XX-en_XX.en_XX +} + +prepare_ta() { + OUTPUT_DIR=$TMP_DIR/ta + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ta-en.tsv.gz" "gunzip -f wikititles-v2.ta-en.tsv.gz" ta wikititles-v2.ta-en.tsv 1 2 & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-ta.langid.tsv.gz" "gunzip -f WikiMatrix.v1.en-ta.langid.tsv.gz" ta WikiMatrix.v1.en-ta.langid.tsv 3 2 & + download_and_select pmindia "http://data.statmt.org/pmindia/v1/parallel/pmindia.v1.ta-en.tsv" "" ta pmindia.v1.ta-en.tsv 2 1 & + download_and_select tanzil "https://object.pouta.csc.fi/OPUS-Tanzil/v1/moses/en-ta.txt.zip" "unzip en-ta.txt.zip" ta Tanzil.en-ta & + download_and_select pib "http://preon.iiit.ac.in/~jerin/resources/datasets/pib-v0.tar" "tar -xvf pib-v0.tar" ta pib/en-ta/train & + download_and_select mkb "http://preon.iiit.ac.in/~jerin/resources/datasets/mkb-v0.tar" "tar -xvf mkb-v0.tar" ta mkb/en-ta/mkb & + download_and_select ufal "http://ufal.mff.cuni.cz/~ramasamy/parallel/data/v2/en-ta-parallel-v2.tar.gz" "tar -zxvf en-ta-parallel-v2.tar.gz" ta en-ta-parallel-v2/corpus.bcn.train & + + wait + + # need special handling for nlpc + mkdir -p nlpc + cd nlpc + wget -nc https://raw.githubusercontent.com/nlpc-uom/English-Tamil-Parallel-Corpus/master/En-Ta%20Corpus/En-Ta%20English.txt + wget -nc https://github.com/nlpc-uom/English-Tamil-Parallel-Corpus/raw/master/En-Ta%20Corpus/En-Ta%20Tamil.txt + tail -n +4 "En-Ta English.txt" > en-ta.en + tail -n +4 "En-Ta Tamil.txt" > en-ta.ta + cd .. 
+ ln -sf nlpc/en-ta.en nlpc.en + ln -sf nlpc/en-ta.ta nlpc.ta + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.ta" | sort -V | xargs cat > all.ta + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ta all.ta $DEST/train.ta_IN-en_XX.ta_IN en all.en $DEST/train.ta_IN-en_XX.en_XX +} + +prepare_iu() { + OUTPUT_DIR=$TMP_DIR/iu + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select nh "https://nrc-digital-repository.canada.ca/eng/view/dataset/?id=c7e34fa7-7629-43c2-bd6d-19b32bf64f60" "tar -zxvf Nunavut-Hansard-Inuktitut-English-Parallel-Corpus-3.0.1.tgz" iu Nunavut-Hansard-Inuktitut-English-Parallel-Corpus-3.0/NunavutHansard > /dev/null & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.iu-en.tsv.gz" "gunzip -f wikititles-v2.iu-en.tsv.gz" iu wikititles-v2.iu-en.tsv 1 2 & + + wait + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.iu" | sort -V | xargs cat | nh/Nunavut-Hansard-Inuktitut-English-Parallel-Corpus-3.0/scripts/normalize-iu-spelling.pl > all.iu + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + paste all.iu all.en | awk -F $'\t' '$1!=""&&$2!=""' > all.iuen + cut -f1 all.iuen > $DEST/train.iu_CA-en_XX.iu_CA + cut -f2 all.iuen > $DEST/train.iu_CA-en_XX.en_XX +} + +prepare_km() { + OUTPUT_DIR=$TMP_DIR/km + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "http://data.statmt.org/wmt20/translation-task/ps-km/wmt20-sent.en-km.xz" "unxz wmt20-sent.en-km.xz" km wmt20-sent.en-km 2 1 & + + # km-parallel has multiple sets, concat all of them together + mkdir -p opus + cd opus + wget -nc "http://data.statmt.org/wmt20/translation-task/ps-km/km-parallel.tgz" + tar -zxvf km-parallel.tgz + find ./km-parallel -maxdepth 1 -name "*.km" | sort -V | xargs cat > opus.km + find ./km-parallel -maxdepth 1 -name "*.en" | sort -V | xargs cat > opus.en + cd .. + ln -sf opus/opus.km . + ln -sf opus/opus.en . 
+ + wait + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.km" | sort -V | xargs cat > all.km + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter km all.km $DEST/train.km_KH-en_XX.km_KH en all.en $DEST/train.km_KH-en_XX.en_XX +} + +prepare_ps() { + OUTPUT_DIR=$TMP_DIR/ps + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "http://data.statmt.org/wmt20/translation-task/ps-km/wmt20-sent.en-ps.xz" "unxz wmt20-sent.en-ps.xz" ps wmt20-sent.en-ps 2 1 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ps-en.tsv.gz" "gunzip -f wikititles-v2.ps-en.tsv.gz" ps wikititles-v2.ps-en.tsv 1 2 & + # ps-parallel has multiple sets, concat all of them together + mkdir -p opus + cd opus + wget -nc "http://data.statmt.org/wmt20/translation-task/ps-km/ps-parallel.tgz" + tar -zxvf ps-parallel.tgz + find ./ps-parallel -maxdepth 1 -name "*.ps" | sort -V | xargs cat > opus.ps + find ./ps-parallel -maxdepth 1 -name "*.en" | sort -V | xargs cat > opus.en + cd .. + ln -sf opus/opus.ps opus.ps + ln -sf opus/opus.en opus.en + + wait + + # remove previous results + rm -f all.?? 
+ find ./ -maxdepth 1 -name "*.ps" | sort -V | xargs cat > all.ps + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ps all.ps $DEST/train.ps_AF-en_XX.ps_AF en all.en $DEST/train.ps_AF-en_XX.en_XX +} + +download_commoncrawl() { + mkdir -p $COMMONCRAWL_DIR + cd $COMMONCRAWL_DIR + + wget -nc "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" + tar -zxvf training-parallel-commoncrawl.tgz +} +link_commoncrawl() { + LANG=$1 + ln -sf $COMMONCRAWL_DIR/commoncrawl.$LANG-en.en commoncrawl.en + ln -sf $COMMONCRAWL_DIR/commoncrawl.$LANG-en.$LANG commoncrawl.$LANG +} + +strip_xlf() { + INPUT_FILE=$1 + SRC=$2 + TGT=$3 + grep '<source xml:lang=' $INPUT_FILE | sed 's/^<[^<>]*>//g' | sed 's/<[^<>]*>$//g' > $INPUT_FILE.$SRC + grep '<target xml:lang=' $INPUT_FILE | sed 's/^<[^<>]*>//g' | sed 's/<[^<>]*>$//g' > $INPUT_FILE.$TGT +} + +download_and_process_tilde() { + URL=$1 + UNCOMPRESS_CMD=$2 + FILENAME=$3 + LANG=$4 + PROCESS_CMD=$5 + + mkdir -p tilde + cd tilde + wget -nc $URL + $UNCOMPRESS_CMD + echo "executing cmd" + echo $PROCESS_CMD + $PROCESS_CMD + cd .. 
+ ln -sf tilde/$FILENAME.$LANG tilde.$LANG + ln -sf tilde/$FILENAME.en tilde.en +} + +prepare_cs() { + OUTPUT_DIR=$TMP_DIR/cs + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + #download_and_select europarl "http://www.statmt.org/europarl/v10/training/europarl-v10.cs-en.tsv.gz" "gunzip europarl-v10.cs-en.tsv.gz" cs europarl-v10.cs-en.tsv 1 2 & + #download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release5.1/en-cs.txt.gz" "gunzip en-cs.txt.gz" cs en-cs.txt 2 1 & + #link_commoncrawl cs + #download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.cs-en.tsv.gz" "gunzip news-commentary-v15.cs-en.tsv.gz" cs news-commentary-v15.cs-en.tsv 1 2 & + #download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.cs-en.tsv.gz" "gunzip wikititles-v2.cs-en.tsv.gz" cs wikititles-v2.cs-en.tsv 1 2 & + #download_and_process_tilde "http://data.statmt.org/wmt20/translation-task/rapid/RAPID_2019.cs-en.xlf.gz" "gunzip RAPID_2019.cs-en.xlf.gz" RAPID_2019.cs-en.xlf cs "strip_xlf RAPID_2019.cs-en.xlf cs en" & + #download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.cs-en.langid.tsv.gz" "gunzip WikiMatrix.v1.cs-en.langid.tsv.gz" cs WikiMatrix.v1.cs-en.langid.tsv 2 3 & + + #wait + + # remove previous results + #rm -f all.?? + #find ./ -maxdepth 1 -name "*.cs" | sort -V | xargs cat > all.cs + #find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + if [ -z $CZENG_CORPUS ] ; + then + echo "Please download CZENG_CORPUS manually and place them at $CZENG_CORPUS. Exitting..." 
+ exit + fi + cat $CZENG_CORPUS | sed '/^$/d' | cut -f5 > all.cs + cat $CZENG_CORPUS | sed '/^$/d' | cut -f6 > all.en + + lid_filter cs all.cs $DEST/train.cs_CZ-en_XX.cs_CZ en all.en $DEST/train.cs_CZ-en_XX.en_XX +} + +prepare_de() { + OUTPUT_DIR=$TMP_DIR/de + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select europarl "http://www.statmt.org/europarl/v10/training/europarl-v10.de-en.tsv.gz" "gunzip europarl-v10.de-en.tsv.gz" de europarl-v10.de-en.tsv 1 2 & + download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release5.1/en-de.txt.gz" "gunzip en-de.txt.gz" de en-de.txt 2 1 & + link_commoncrawl de + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.de-en.tsv.gz" "gunzip news-commentary-v15.de-en.tsv.gz" de news-commentary-v15.de-en.tsv 1 2 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.de-en.tsv.gz" "gunzip wikititles-v2.de-en.tsv.gz" de wikititles-v2.de-en.tsv 1 2 & + download_and_process_tilde "http://data.statmt.org/wmt20/translation-task/rapid/RAPID_2019.de-en.xlf.gz" "gunzip RAPID_2019.de-en.xlf.gz" RAPID_2019.de-en.xlf de "strip_xlf RAPID_2019.de-en.xlf de en" & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.de-en.langid.tsv.gz" "gunzip WikiMatrix.v1.de-en.langid.tsv.gz" de WikiMatrix.v1.de-en.langid.tsv 2 3 & + + wait + + # remove previous results + rm -f all.?? 
+ find ./ -maxdepth 1 -name "*.de" | sort -V | xargs cat > all.de + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter de all.de $DEST/train.de_DE-en_XX.de_DE en all.en $DEST/train.de_DE-en_XX.en_XX +} + +prepare_tmx() { + TMX_FILE=$1 + git clone https://github.com/amake/TMX2Corpus $UTILS/tmx2corpus + pip install tinysegmenter + + python $UTILS/tmx2corpus/tmx2corpus.py $TMX_FILE +} + +prepare_pl() { + OUTPUT_DIR=$TMP_DIR/pl + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + # download_and_select europarl "http://www.statmt.org/europarl/v10/training/europarl-v10.pl-en.tsv.gz" "gunzip europarl-v10.pl-en.tsv.gz" pl europarl-v10.pl-en.tsv 1 2 & + # download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release5.1/en-pl.txt.gz" "gunzip en-pl.txt.gz" pl en-pl.txt 2 1 & + # download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.pl-en.tsv.gz" "gunzip wikititles-v2.pl-en.tsv.gz" pl wikititles-v2.pl-en.tsv 1 2 & + download_and_select tilde "https://tilde-model.s3-eu-west-1.amazonaws.com/rapid2019.en-pl.tmx.zip" "gunzip rapid2019.en-pl.tmx.zip" bitext pl "prepare_tmx RAPID_2019.UNIQUE.en-pl.tmx" & + # download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-pl.langid.tsv.gz" "gunzip WikiMatrix.v1.en-pl.langid.tsv.gz" pl WikiMatrix.v1.en-pl.langid.tsv 3 2 & + + wait + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.pl" | sort -V | xargs cat > all.pl + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter pl all.pl $DEST/train.pl_PL-en_XX.pl_PL en all.en $DEST/train.pl_PL-en_XX.en_XX +} + +prepare_uncorpus() { + URLS=$1 + FILES=$2 + + mkdir -p uncorpus + cd uncorpus + + for URL in $URLS; do + wget -nc $URL + done + cat $FILES > uncorpus.tar.gz + tar -zxvf uncorpus.tar.gz + + cd .. 
+ ln -sf uncorpus/en-$LANG/UNv1.0.en-$LANG.$LANG uncorpus.$LANG + ln -sf uncorpus/en-$LANG/UNv1.0.en-$LANG.en uncorpus.en +} + +prepare_yandex() { + mkdir -p yandex + cd yandex + unzip $YANDEX_CORPUS ./ + cd .. + ln -s yandex/corpus.en_ru.1m.en yandex.en + ln -s yandex/corpus.en_ru.1m.ru yandex.ru +} + +prepare_ru() { + OUTPUT_DIR=$TMP_DIR/ru + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz" "tar -zxvf paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz" ru paracrawl-release1.en-ru.zipporah0-dedup-clean & + link_commoncrawl ru + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-ru.tsv.gz" "gunzip news-commentary-v15.en-ru.tsv.gz" ru news-commentary-v15.en-ru.tsv 2 1 & + prepare_yandex & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ru-en.tsv.gz" "gunzip wikititles-v2.ru-en.tsv.gz" ru wikititles-v2.ru-en.tsv 1 2 & + prepare_uncorpus "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00 https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01 https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" "UNv1.0.en-ru.tar.gz.00 UNv1.0.en-ru.tar.gz.01 UNv1.0.en-ru.tar.gz.02" & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-ru.langid.tsv.gz" "gunzip WikiMatrix.v1.en-ru.langid.tsv.gz" ru WikiMatrix.v1.en-ru.langid.tsv 3 2 & + + wait + + # remove previous results + rm -f all.?? 
+ find ./ -maxdepth 1 -name "*.ru" | sort -V | xargs cat > all.ru + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ru all.ru $DEST/train.ru_RU-en_XX.ru_RU en all.en $DEST/train.ru_RU-en_XX.en_XX +} + +prepare_ccmt() { + mkdir -p ccmt + cd ccmt + # assume ccmt data is already unzipped under CCMT_DIR folder + cat $CCMT_DIR/datum2017/Book*_cn.txt | sed 's/ //g' > datum2017.detok.zh + cat $CCMT_DIR/datum2017/Book*_en.txt > datum2017.detok.en + cat $CCMT_DIR/casict2011/casict-A_ch.txt $CCMT_DIR/casict2011/casict-B_ch.txt $CCMT_DIR/casict2015/casict2015_ch.txt $CCMT_DIR/datum2015/datum_ch.txt $CCMT_DIR/neu2017/NEU_cn.txt datum2017.detok.zh > ccmt.zh + cat $CCMT_DIR/casict2011/casict-A_en.txt $CCMT_DIR/casict2011/casict-B_en.txt $CCMT_DIR/casict2015/casict2015_en.txt $CCMT_DIR/datum2015/datum_en.txt $CCMT_DIR/neu2017/NEU_en.txt datum2017.detok.en > ccmt.en + cd .. + ln -sf ccmt/ccmt.zh ccmt.zh + ln -sf ccmt/ccmt.en ccmt.en +} + +prepare_zh() { + OUTPUT_DIR=$TMP_DIR/zh + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-zh.tsv.gz" "gunzip news-commentary-v15.en-zh.tsv.gz" zh news-commentary-v15.en-zh.tsv 2 1 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.zh-en.tsv.gz" "gunzip wikititles-v2.zh-en.tsv.gz" zh wikititles-v2.zh-en.tsv 1 2 & + prepare_uncorpus "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00 https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" "UNv1.0.en-zh.tar.gz.00 UNv1.0.en-zh.tar.gz.01" & + prepare_ccmt & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-zh.langid.tsv.gz" "gunzip WikiMatrix.v1.en-zh.langid.tsv.gz" zh WikiMatrix.v1.en-zh.langid.tsv 3 2 & + + wait + + # remove previous results + rm -f all.?? 
+ find ./ -maxdepth 1 -name "*.zh" | sort -V | xargs cat > all.zh + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter zh all.zh $DEST/train.zh_CN-en_XX.zh_CN en all.en $DEST/train.zh_CN-en_XX.en_XX +} + +prepare_tests() { + OUTPUT_DIR=$TMP_DIR + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + wget -nc http://data.statmt.org/wmt20/translation-task/dev.tgz + tar -zxvf dev.tgz + cd dev + + cat newsdev2020-jaen-src.ja.sgm | $UTILS/strip_sgm.sh > newsdev2020-jaen.ja + cat newsdev2020-jaen-ref.en.sgm | $UTILS/strip_sgm.sh > newsdev2020-jaen.en + split newsdev2020-jaen.ja -a 0 -n r/1/2 > $DEST/valid.ja_XX-en_XX.ja_XX + split newsdev2020-jaen.en -a 0 -n r/1/2 > $DEST/valid.ja_XX-en_XX.en_XX + split newsdev2020-jaen.ja -a 0 -n r/2/2 > $DEST/test.ja_XX-en_XX.ja_XX + split newsdev2020-jaen.en -a 0 -n r/2/2 > $DEST/test.ja_XX-en_XX.en_XX + + cat newsdev2020-iuen-src.iu.sgm | strip_sgm.sh > newsdev2020-iuen.iu + cat newsdev2020-iuen-ref.en.sgm | strip_sgm.sh > newsdev2020-iuen.en + split newsdev2020-iuen.iu -a 0 -n r/1/2 > $DEST/valid.iu_CA-en_XX.iu_CA + split newsdev2020-iuen.en -a 0 -n r/1/2 > $DEST/valid.iu_CA-en_XX.en_XX + split newsdev2020-iuen.iu -a 0 -n r/2/2 > $DEST/test.iu_CA-en_XX.iu_CA + split newsdev2020-iuen.en -a 0 -n r/2/2 > $DEST/test.iu_CA-en_XX.en_XX + + cat newsdev2020-taen-src.ta.sgm | strip_sgm.sh > newsdev2020-taen.ta + cat newsdev2020-taen-ref.en.sgm | strip_sgm.sh > newsdev2020-taen.en + split newsdev2020-taen.ta -a 0 -n r/1/2 > $DEST/valid.ta_IN-en_XX.ta_IN + split newsdev2020-taen.en -a 0 -n r/1/2 > $DEST/valid.ta_IN-en_XX.en_XX + split newsdev2020-taen.ta -a 0 -n r/2/2 > $DEST/test.ta_IN-en_XX.ta_IN + split newsdev2020-taen.en -a 0 -n r/2/2 > $DEST/test.ta_IN-en_XX.en_XX + + cp wikipedia.dev.km-en.km $DEST/valid.km_KH-en_XX.km_KH + cp wikipedia.dev.km-en.en $DEST/valid.km_KH-en_XX.en_XX + cp wikipedia.devtest.km-en.km $DEST/test.km_KH-en_XX.km_KH + cp wikipedia.devtest.km-en.en $DEST/test.km_KH-en_XX.en_XX + + cp 
wikipedia.dev.ps-en.ps $DEST/valid.ps_AF-en_XX.ps_AF + cp wikipedia.dev.ps-en.en $DEST/valid.ps_AF-en_XX.en_XX + cp wikipedia.devtest.ps-en.ps $DEST/test.ps_AF-en_XX.ps_AF + cp wikipedia.devtest.ps-en.en $DEST/test.ps_AF-en_XX.en_XX + + cat newsdev2020-plen-src.pl.sgm | strip_sgm.sh > newsdev2020-plen.pl + cat newsdev2020-plen-ref.en.sgm | strip_sgm.sh > newsdev2020-plen.en + split newsdev2020-plen.pl -a 0 -n r/1/2 > $DEST/valid.pl_PL-en_XX.pl_PL + split newsdev2020-plen.en -a 0 -n r/1/2 > $DEST/valid.pl_PL-en_XX.en_XX + split newsdev2020-plen.pl -a 0 -n r/2/2 > $DEST/test.pl_PL-en_XX.pl_PL + split newsdev2020-plen.en -a 0 -n r/2/2 > $DEST/test.pl_PL-en_XX.en_XX + + cat newstest2018-encs-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-cs_CZ.en_XX + cat newstest2018-encs-ref.cs.sgm | strip_sgm.sh > $DEST/valid.en_XX-cs_CZ.cs_CZ + cat newstest2019-encs-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-cs_CZ.en_XX + cat newstest2019-encs-ref.cs.sgm | strip_sgm.sh > $DEST/test.en_XX-cs_CZ.cs_CZ + + cat newstest2018-deen-src.de.sgm | strip_sgm.sh > $DEST/valid.de_DE-en_XX.de_DE + cat newstest2018-deen-ref.en.sgm | strip_sgm.sh > $DEST/valid.de_DE-en_XX.en_XX + cat newstest2018-ende-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-de_DE.en_XX + cat newstest2018-ende-ref.de.sgm | strip_sgm.sh > $DEST/valid.en_XX-de_DE.de_DE + cat newstest2019-deen-src.de.sgm | strip_sgm.sh > $DEST/test.de_DE-en_XX.de_DE + cat newstest2019-deen-ref.en.sgm | strip_sgm.sh > $DEST/test.de_DE-en_XX.en_XX + cat newstest2019-ende-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-de_DE.en_XX + cat newstest2019-ende-ref.de.sgm | strip_sgm.sh > $DEST/test.en_XX-de_DE.de_DE + + cat newstest2018-ruen-src.ru.sgm | strip_sgm.sh > $DEST/valid.ru_RU-en_XX.ru_RU + cat newstest2018-ruen-ref.en.sgm | strip_sgm.sh > $DEST/valid.ru_RU-en_XX.en_XX + cat newstest2018-enru-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-ru_RU.en_XX + cat newstest2018-enru-ref.ru.sgm | strip_sgm.sh > $DEST/valid.en_XX-ru_RU.ru_RU + cat 
newstest2019-ruen-src.ru.sgm | strip_sgm.sh > $DEST/test.ru_RU-en_XX.ru_RU + cat newstest2019-ruen-ref.en.sgm | strip_sgm.sh > $DEST/test.ru_RU-en_XX.en_XX + cat newstest2019-enru-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-ru_RU.en_XX + cat newstest2019-enru-ref.ru.sgm | strip_sgm.sh > $DEST/test.en_XX-ru_RU.ru_RU + + cat newstest2018-zhen-src.zh.sgm | strip_sgm.sh > $DEST/valid.zh_CN-en_XX.zh_CN + cat newstest2018-zhen-ref.en.sgm | strip_sgm.sh > $DEST/valid.zh_CN-en_XX.en_XX + cat newstest2018-enzh-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-zh_CN.en_XX + cat newstest2018-enzh-ref.zh.sgm | strip_sgm.sh > $DEST/valid.en_XX-zh_CN.zh_CN + cat newstest2019-zhen-src.zh.sgm | strip_sgm.sh > $DEST/test.zh_CN-en_XX.zh_CN + cat newstest2019-zhen-ref.en.sgm | strip_sgm.sh > $DEST/test.zh_CN-en_XX.en_XX + cat newstest2019-enzh-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-zh_CN.en_XX + cat newstest2019-enzh-ref.zh.sgm | strip_sgm.sh > $DEST/test.en_XX-zh_CN.zh_CN +} + +mkdir -p $DEST + +prepare_lid +prepare_moses +download_commoncrawl + +prepare_ja & +prepare_ta & +prepare_km & +prepare_ps & +prepare_iu & +prepare_cs & +prepare_de & +prepare_pl & +prepare_ru & +prepare_zh & + +# prepare valid/test set +prepare_tests & + +# wait + +# TODO remove intermediate files +# rm -rf $TMP_DIR diff --git a/examples/multilingual/data_scripts/preprocess_ML50_v1.sh b/examples/multilingual/data_scripts/preprocess_ML50_v1.sh new file mode 100644 index 0000000000..4655936149 --- /dev/null +++ b/examples/multilingual/data_scripts/preprocess_ML50_v1.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." 
+ exit +fi + +if [ -z $SPM_PATH ] ; +then + echo "Please install sentencepiece from https://github.com/google/sentencepiece and set SPM_PATH pointing to the installed spm_encode.py. Exiting..." + exit +fi + +ML50=${WORKDIR_ROOT}/ML50 + +mkdir -p $ML50/dedup +mkdir -p $ML50/clean + +python ./dedup_all.py --from-folder $ML50/raw --to-folder $ML50/dedup +python ./remove_valid_test_in_train.py --from-folder $ML50/dedup --to-folder $ML50/clean +python ./binarize.py --raw-folder $ML50/clean \ No newline at end of file diff --git a/examples/multilingual/data_scripts/remove_valid_test_in_train.py b/examples/multilingual/data_scripts/remove_valid_test_in_train.py new file mode 100755 index 0000000000..ef618adef7 --- /dev/null +++ b/examples/multilingual/data_scripts/remove_valid_test_in_train.py @@ -0,0 +1,290 @@ +import os, sys +import glob, itertools +import pandas as pd + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. 
Exitting..."') + sys.exit(-1) + + +def load_langs(path): + with open(path) as fr: + langs = [l.strip() for l in fr] + return langs + + + +def load_sentences(raw_data, split, direction): + src, tgt = direction.split('-') + src_path = f"{raw_data}/{split}.{direction}.{src}" + tgt_path = f"{raw_data}/{split}.{direction}.{tgt}" + if os.path.exists(src_path) and os.path.exists(tgt_path): + return [(src, open(src_path).read().splitlines()), (tgt, open(tgt_path).read().splitlines())] + else: + return [] + +def swap_direction(d): + src, tgt = d.split('-') + return f'{tgt}-{src}' + +def get_all_test_data(raw_data, directions, split='test'): + test_data = [ + x + for dd in directions + for d in [dd, swap_direction(dd)] + for x in load_sentences(raw_data, split, d) + ] + # all_test_data = {s for _, d in test_data for s in d} + all_test_data = {} + for lang, d in test_data: + for s in d: + s = s.strip() + lgs = all_test_data.get(s, set()) + lgs.add(lang) + all_test_data[s] = lgs + return all_test_data, test_data + +def check_train_sentences(raw_data, direction, all_test_data, mess_up_train={}): + src, tgt = direction.split('-') + tgt_path = f"{raw_data}/train.{direction}.{tgt}" + src_path = f"{raw_data}/train.{direction}.{src}" + print(f'check training data in {raw_data}/train.{direction}') + size = 0 + if not os.path.exists(tgt_path) or not os.path.exists(src_path): + return mess_up_train, size + with open(src_path) as f, open(tgt_path) as g: + for src_line, tgt_line in zip(f, g): + s = src_line.strip() + t = tgt_line.strip() + size += 1 + if s in all_test_data: + langs = mess_up_train.get(s, set()) + langs.add(direction) + mess_up_train[s] = langs + if t in all_test_data: + langs = mess_up_train.get(t, set()) + langs.add(direction) + mess_up_train[t] = langs + return mess_up_train, size + +def check_train_all(raw_data, directions, all_test_data): + mess_up_train = {} + data_sizes = {} + for direction in directions: + _, size = check_train_sentences(raw_data, direction, 
all_test_data, mess_up_train) + data_sizes[direction] = size + return mess_up_train, data_sizes + +def count_train_in_other_set(mess_up_train): + train_in_others = [(direction, s) for s, directions in mess_up_train.items() for direction in directions] + counts = {} + for direction, s in train_in_others: + counts[direction] = counts.get(direction, 0) + 1 + return counts + +def train_size_if_remove_in_otherset(data_sizes, mess_up_train): + counts_in_other = count_train_in_other_set(mess_up_train) + remain_sizes = [] + for direction, count in counts_in_other.items(): + remain_sizes.append((direction, data_sizes[direction] - count, data_sizes[direction], count, 100 * count / data_sizes[direction] )) + return remain_sizes + + +def remove_messed_up_sentences(raw_data, direction, mess_up_train, mess_up_train_pairs, corrected_langs): + split = 'train' + src_lang, tgt_lang = direction.split('-') + + tgt = f"{raw_data}/{split}.{direction}.{tgt_lang}" + src = f"{raw_data}/{split}.{direction}.{src_lang}" + print(f'working on {direction}: ', src, tgt) + if not os.path.exists(tgt) or not os.path.exists(src) : + return 0, 0 + + corrected_tgt = f"{to_folder}/{split}.{direction}.{tgt_lang}" + corrected_src = f"{to_folder}/{split}.{direction}.{src_lang}" + line_num = 0 + keep_num = 0 + with open(src, encoding='utf8',) as fsrc, \ + open(tgt, encoding='utf8',) as ftgt, \ + open(corrected_src, 'w', encoding='utf8') as fsrc_corrected, \ + open(corrected_tgt, 'w', encoding='utf8') as ftgt_corrected: + for s, t in zip(fsrc, ftgt): + s = s.strip() + t = t.strip() + if t not in mess_up_train \ + and s not in mess_up_train \ + and (s, t) not in mess_up_train_pairs \ + and (t, s) not in mess_up_train_pairs: + corrected_langs.add(direction) + print(s, file=fsrc_corrected) + print(t, file=ftgt_corrected) + keep_num += 1 + line_num += 1 + if line_num % 1000 == 0: + print(f'completed {line_num} lines', end='\r') + return line_num, keep_num + +########## + + +def 
merge_valid_test_messup(mess_up_train_valid, mess_up_train_test): + merged_mess = [] + for s in set(list(mess_up_train_valid.keys()) + list(mess_up_train_test.keys())): + if not s: + continue + valid = mess_up_train_valid.get(s, set()) + test = mess_up_train_test.get(s, set()) + merged_mess.append((s, valid | test)) + return dict(merged_mess) + + + +######### +def check_train_pairs(raw_data, direction, all_test_data, mess_up_train={}): + src, tgt = direction.split('-') + #a hack; TODO: check the reversed directions + path1 = f"{raw_data}/train.{src}-{tgt}.{src}" + path2 = f"{raw_data}/train.{src}-{tgt}.{tgt}" + if not os.path.exists(path1) or not os.path.exists(path2) : + return + + with open(path1) as f1, open(path2) as f2: + for src_line, tgt_line in zip(f1, f2): + s = src_line.strip() + t = tgt_line.strip() + if (s, t) in all_test_data or (t, s) in all_test_data: + langs = mess_up_train.get( (s, t), set()) + langs.add(src) + langs.add(tgt) + mess_up_train[(s, t)] = langs + + +def load_pairs(raw_data, split, direction): + src, tgt = direction.split('-') + src_f = f"{raw_data}/{split}.{direction}.{src}" + tgt_f = f"{raw_data}/{split}.{direction}.{tgt}" + if tgt != 'en_XX': + src_f, tgt_f = tgt_f, src_f + if os.path.exists(src_f) and os.path.exists(tgt_f): + return list(zip(open(src_f).read().splitlines(), + open(tgt_f).read().splitlines(), + )) + else: + return [] + +# skip_langs = ['cs_CZ', 'en_XX', 'tl_XX', 'tr_TR'] +def get_messed_up_test_pairs(split, directions): + test_pairs = [ + (d, load_pairs(raw_data, split, d)) + for d in directions + ] + # all_test_data = {s for _, d in test_data for s in d} + all_test_pairs = {} + for direction, d in test_pairs: + src, tgt = direction.split('-') + for s in d: + langs = all_test_pairs.get(s, set()) + langs.add(src) + langs.add(tgt) + all_test_pairs[s] = langs + mess_up_train_pairs = {} + for direction in directions: + check_train_pairs(raw_data, direction, all_test_pairs, mess_up_train_pairs) + return all_test_pairs, 
mess_up_train_pairs + + + +if __name__ == "__main__": + ####### + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + '--from-folder', + required=True, + type=str) + parser.add_argument( + '--to-folder', + required=True, + type=str) + parser.add_argument( + '--directions', + default=None, + type=str) + + + args = parser.parse_args() + raw_data = args.from_folder + to_folder = args.to_folder + os.makedirs(to_folder, exist_ok=True) + + if args.directions: + directions = args.directions.split(',') + else: + raw_files = itertools.chain( + glob.glob(f'{raw_data}/train*'), + glob.glob(f'{raw_data}/valid*'), + glob.glob(f'{raw_data}/test*'), + ) + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + print('working on directions: ', directions) + + ########## + + + + all_test_data, test_data = get_all_test_data(raw_data, directions, 'test') + print('==loaded test data==') + all_valid_data, valid_data = get_all_test_data(raw_data, directions, 'valid') + print('==loaded valid data==') + all_valid_test_data = merge_valid_test_messup(all_test_data, all_valid_data) + mess_up_train, data_sizes = check_train_all(raw_data, directions, all_valid_test_data) + print('training messing up with valid, test data:', len(mess_up_train)) + data_situation = train_size_if_remove_in_otherset(data_sizes, mess_up_train) + df = pd.DataFrame(data_situation, columns=['direction', 'train_size_after_remove', 'orig_size', 'num_to_remove', 'remove_percent']) + df = df.sort_values('remove_percent', ascending=False) + df.to_csv(f'{raw_data}/clean_summary.tsv', sep='\t') + print(f'projected data clean summary in: {raw_data}/clean_summary.tsv') + + # correct the dataset: + all_test_pairs, mess_up_test_train_pairs = get_messed_up_test_pairs('test', directions) + all_valid_pairs, mess_up_valid_train_pairs = get_messed_up_test_pairs('valid', directions) + + all_messed_pairs = 
set(mess_up_test_train_pairs.keys()).union(set(mess_up_valid_train_pairs.keys())) + corrected_directions = set() + + real_data_situation = [] + for direction in directions: + org_size, new_size = remove_messed_up_sentences(raw_data, direction, mess_up_train, all_messed_pairs, corrected_directions) + if org_size == 0: + print(f"{direction} has size 0") + continue + real_data_situation.append( + (direction, new_size, org_size, org_size - new_size, (org_size - new_size) / org_size * 100) + ) + print('corrected directions: ', corrected_directions) + df = pd.DataFrame(real_data_situation, columns=['direction', 'train_size_after_remove', 'orig_size', 'num_to_remove', 'remove_percent']) + df = df.sort_values('remove_percent', ascending=False) + df.to_csv(f'{raw_data}/actual_clean_summary.tsv', sep='\t') + print(f'actual data clean summary (which can be different from the projected one because of duplications) in: {raw_data}/actual_clean_summary.tsv') + + import shutil + for direction in directions: + src_lang, tgt_lang = direction.split('-') + for split in ['train', 'valid', 'test']: + # copying valid, test and uncorrected train + if direction in corrected_directions and split == 'train': + continue + tgt = f"{raw_data}/{split}.{direction}.{tgt_lang}" + src = f"{raw_data}/{split}.{direction}.{src_lang}" + if not (os.path.exists(src) and os.path.exists(tgt)): + continue + corrected_tgt = f"{to_folder}/{split}.{direction}.{tgt_lang}" + corrected_src = f"{to_folder}/{split}.{direction}.{src_lang}" + print(f'copying {src} to {corrected_src}') + shutil.copyfile(src, corrected_src) + print(f'copying {tgt} to {corrected_tgt}') + shutil.copyfile(tgt, corrected_tgt) + + print('completed') \ No newline at end of file diff --git a/examples/multilingual/data_scripts/requirement.txt b/examples/multilingual/data_scripts/requirement.txt new file mode 100644 index 0000000000..e85d7d540e --- /dev/null +++ b/examples/multilingual/data_scripts/requirement.txt @@ -0,0 +1,2 @@ +wget +pandas \ 
No 
newline at end of file diff --git a/examples/multilingual/data_scripts/utils/dedup.py b/examples/multilingual/data_scripts/utils/dedup.py new file mode 100644 index 0000000000..d6fed8c695 --- /dev/null +++ b/examples/multilingual/data_scripts/utils/dedup.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse + +def deup(src_file, tgt_file, src_file_out, tgt_file_out): + seen = set() + dup_count = 0 + with open(src_file, encoding='utf-8') as fsrc, \ + open(tgt_file, encoding='utf-8') as ftgt, \ + open(src_file_out, 'w', encoding='utf-8') as fsrc_out, \ + open(tgt_file_out, 'w', encoding='utf-8') as ftgt_out: + for s, t in zip(fsrc, ftgt): + if (s, t) not in seen: + fsrc_out.write(s) + ftgt_out.write(t) + seen.add((s, t)) + else: + dup_count += 1 + print(f'number of duplicates: {dup_count}') + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--src-file", type=str, required=True, + help="src file") + parser.add_argument("--tgt-file", type=str, required=True, + help="tgt file") + parser.add_argument("--src-file-out", type=str, required=True, + help="src output file") + parser.add_argument("--tgt-file-out", type=str, required=True, + help="tgt output file") + args = parser.parse_args() + deup(args.src_file, args.tgt_file, args.src_file_out, args.tgt_file_out) + + +if __name__ == "__main__": + main() diff --git a/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py b/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py new file mode 100644 index 0000000000..41b38ba5be --- /dev/null +++ b/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +#!/bin/python + +import fasttext +from multiprocessing import Pool +import contextlib +import sys +import argparse +from functools import partial +import io + +model = None +def init(model_path): + global model + model = fasttext.load_model(model_path) + +def pred(lines): + return lines, [model.predict(line.strip())[0][0][9:] for line in lines] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True, + help="model to load") + parser.add_argument("--inputs", nargs="+", default=['-'], + help="input files to filter") + parser.add_argument("--langs", nargs="+", required=True, + help="lang ids of each input file") + parser.add_argument("--outputs", nargs="+", default=['-'], + help="path to save lid filtered outputs") + parser.add_argument("--num-workers", type=int, metavar="N", default=10, + help="number of processes in parallel") + args = parser.parse_args() + + assert len(args.inputs) == len(args.langs) and len(args.inputs) == len(args.outputs) + + with contextlib.ExitStack() as stack: + inputs = [ + stack.enter_context(open(input, "r", encoding="utf-8", newline="\n", errors="replace")) + if input != "-" else io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8', errors="replace") + for input in args.inputs + ] + outputs = [ + stack.enter_context(open(output, "w", encoding="utf-8", newline="\n")) + if output != "-" else sys.stdout + for output in args.outputs + ] + with Pool(args.num_workers, initializer=partial(init, args.model)) as p: + skip_cnt = 0 + for lines, preds in p.imap(pred, list(zip(*inputs)), chunksize=500): + if not all(a == b for a, b in zip(preds, args.langs)): + skip_cnt += 1 + continue + for line, output_h in zip(lines, outputs): + print(line.strip(), file=output_h) + print(f"Skipped {skip_cnt} lines.") + +if __name__ == "__main__": + main() diff --git a/examples/multilingual/data_scripts/utils/strip_sgm.sh b/examples/multilingual/data_scripts/utils/strip_sgm.sh new file mode 100755 index 
0000000000..7f4f61d7b1 --- /dev/null +++ b/examples/multilingual/data_scripts/utils/strip_sgm.sh @@ -0,0 +1 @@ +grep "seg id" | sed 's/<seg id="[0-9]\+">//g' | sed 's/<\/seg>//g' diff --git a/examples/multilingual/finetune_multilingual_model.sh b/examples/multilingual/finetune_multilingual_model.sh index ffcf1fc722..25960c5dc8 100644 --- a/examples/multilingual/finetune_multilingual_model.sh +++ b/examples/multilingual/finetune_multilingual_model.sh @@ -1,4 +1,9 @@ #!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. path_2_data=$1 # <path to data> which contains binarized data for each directions lang_list=$2 # <path to a file which contains a list of languages separted by new lines> diff --git a/examples/multilingual/multilingual_fairseq_gen.sh b/examples/multilingual/multilingual_fairseq_gen.sh index 8c2c7703b2..65aa322d7d 100644 --- a/examples/multilingual/multilingual_fairseq_gen.sh +++ b/examples/multilingual/multilingual_fairseq_gen.sh @@ -1,4 +1,9 @@ #!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. lang_pairs="en-fr,en-cs,fr-en,cs-en" path_2_data=$1 # <path to data> diff --git a/examples/multilingual/train_multilingual_model.sh b/examples/multilingual/train_multilingual_model.sh index c41730dfcd..cc050bd3f0 100644 --- a/examples/multilingual/train_multilingual_model.sh +++ b/examples/multilingual/train_multilingual_model.sh @@ -1,4 +1,9 @@ #!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
path_2_data=$1 # <path to data> which contains binarized data for each directions lang_list=$2 # <path to a file which contains a list of languages separted by new lines> From 8c7793b9d9ad272a4bee080839357539743a99d4 Mon Sep 17 00:00:00 2001 From: Yuqing Tang <yuqtang@fb.com> Date: Tue, 12 Jan 2021 21:37:26 -0800 Subject: [PATCH 397/707] Enable translation_multi_simple_epoch to load only two dictionaries for source and target only Summary: In the default settings, the translation_multi_simple_epoch task loads one dictionary per language, which can consume a large amount of memory when all languages share the same dictionary. Reviewed By: shruti-bh Differential Revision: D25265741 fbshipit-source-id: c5bc3664efd800b120f015b2525c9fba2b1be3c5 --- .../multilingual/multilingual_data_manager.py | 106 +++++++++++++----- .../tasks/translation_multi_simple_epoch.py | 12 +- tests/test_binaries.py | 56 +++++++++ 3 files changed, 143 insertions(+), 31 deletions(-) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 21fb23c047..a2fae5bf52 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -39,6 +39,9 @@ logger = logging.getLogger(__name__) +SRC_DICT_NAME = 'src' +TGT_DICT_NAME = 'tgt' + def _lang_id(dic: Dictionary, lang: str): """Return language ID index.""" @@ -59,6 +62,15 @@ def __init__(self, args, lang_pairs, langs, dicts, sampling_method): self.args = args self.seed = args.seed self.lang_pairs = lang_pairs + self.extra_lang_pairs = ( + list( + {p for _, v in args.extra_lang_pairs.items() for p in v.split(",")} + ) + if args.extra_lang_pairs + else [] + ) + self.src_langs = {p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs} + self.tgt_langs = {p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs} self.langs = langs self.dicts = dicts self.lang_dict = 
self.create_lang_dictionary(self.langs) @@ -98,6 +110,10 @@ def add_args(parser): "note that the ordering determines language token IDs; " "--langs and --lang-dict are two exclusive options", ) + parser.add_argument('--source-dict', default=None, type=str, + help='path to source dictionary; if specified it will override per language dictionary loading') + parser.add_argument('--target-dict', default=None, type=str, + help='path to target dictionary; if specified it will override per language dictionary loading') parser.add_argument( "--lang-tok-style", default=LangTokStyle.multilingual.value, @@ -346,7 +362,28 @@ def check_langs(langs, pairs): ), ) - # load dictionaries + def load_dictionary_and_postproc(path): + d = load_dictionary(path) + augment_dictionary( + dictionary=d, + language_list=language_list, + lang_tok_style=args.lang_tok_style, + langtoks_specs=args.langtoks_specs, + extra_data=args.extra_data, + ) + return d + + dicts = cls.load_all_dictionaries(args, language_list, load_dictionary_and_postproc, training) + return language_list, dicts, training + + @classmethod + def load_all_dictionaries(cls, args, language_list, load_dictionary, training): + dicts = OrderedDict() + if args.source_dict is not None: + dicts[SRC_DICT_NAME] = load_dictionary(args.source_dict) + if args.target_dict is not None: + dicts[TGT_DICT_NAME] = load_dictionary(args.target_dict) + if training: extra_lang_pairs = ( list( @@ -355,35 +392,52 @@ def check_langs(langs, pairs): if args.extra_lang_pairs else [] ) - langs_to_load_dicts = sorted( - {x for p in args.lang_pairs + extra_lang_pairs for x in p.split("-")} + src_langs_to_load_dicts = sorted( + {p.split("-")[0] for p in (args.lang_pairs + extra_lang_pairs)} + ) + tgt_langs_to_load_dicts = sorted( + {p.split("-")[1] for p in (args.lang_pairs + extra_lang_pairs)} ) else: - langs_to_load_dicts = sorted([args.source_lang, args.target_lang]) + src_langs_to_load_dicts = [args.source_lang] + tgt_langs_to_load_dicts = 
[args.target_lang] - dicts = OrderedDict() paths = utils.split_paths(args.data) assert len(paths) > 0 - for lang in langs_to_load_dicts: - if args.fixed_dictionary is not None: - dicts[lang] = load_dictionary(args.fixed_dictionary) - else: + + def load_dicts(langs_to_load_dicts): + for lang in langs_to_load_dicts: dicts[lang] = load_dictionary( os.path.join(paths[0], "dict.{}.txt".format(lang)) ) - augment_dictionary( - dictionary=dicts[lang], - language_list=language_list, - lang_tok_style=args.lang_tok_style, - langtoks_specs=args.langtoks_specs, - extra_data=args.extra_data, - ) if len(dicts) > 0: - assert dicts[lang].pad() == dicts[langs_to_load_dicts[0]].pad() - assert dicts[lang].eos() == dicts[langs_to_load_dicts[0]].eos() - assert dicts[lang].unk() == dicts[langs_to_load_dicts[0]].unk() + dict0 = next(iter(dicts.values())) + assert dicts[lang].pad() == dict0.pad() + assert dicts[lang].eos() == dict0.eos() + assert dicts[lang].unk() == dict0.unk() logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) - return language_list, dicts, training + + if args.fixed_dictionary is not None: + fixed_dict = load_dictionary(args.fixed_dictionary) + dicts = {lang: fixed_dict for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts} + else: + if args.source_dict is None: + load_dicts(src_langs_to_load_dicts) + if args.target_dict is None: + load_dicts(tgt_langs_to_load_dicts) + return dicts + + def get_source_dictionary(self, lang): + if self.args.source_dict is not None: + return self.dicts[SRC_DICT_NAME] + else: + return self.dicts[lang] + + def get_target_dictionary(self, lang): + if self.args.target_dict is not None: + return self.dicts[TGT_DICT_NAME] + else: + return self.dicts[lang] @classmethod def create_lang_dictionary(cls, langs): @@ -418,7 +472,7 @@ def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec ) return self.get_langtok_index( - langtok, self.dicts[src_lang 
if src_lang else tgt_lang] + langtok, self.get_source_dictionary(src_lang) if src_lang else self.get_target_dictionary(tgt_lang) ) def get_decoder_langtok(self, tgt_lang, spec=None): @@ -427,7 +481,7 @@ def get_decoder_langtok(self, tgt_lang, spec=None): langtok = get_lang_tok( lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec ) - return self.get_langtok_index(langtok, self.dicts[tgt_lang]) + return self.get_langtok_index(langtok, self.get_target_dictionary(tgt_lang)) @classmethod def load_data(cls, path, vdict, impl): @@ -760,9 +814,9 @@ def load_a_dataset( if self.args.lang_tok_replacing_bos_eos: ds = self.alter_dataset_langtok( langpair_ds, - src_eos=self.dicts[src if src else tgt].eos(), + src_eos=self.get_source_dictionary(src).eos() if src else self.get_target_dictionary(tgt).eos(), src_lang=src, - tgt_eos=self.dicts[tgt].eos(), + tgt_eos=self.get_target_dictionary(tgt).eos(), tgt_lang=tgt, src_langtok_spec=src_langtok_spec, tgt_langtok_spec=tgt_langtok_spec, @@ -906,11 +960,11 @@ def get_split_data_param_list(self, split, epoch, shard_epoch=None): "data_path": data_path, "split": split, "src": src, - "src_dict": self.dicts[src] + "src_dict": self.get_source_dictionary(src) if src and data_category != "mono_dae" else None, "tgt": tgt, - "tgt_dict": self.dicts[tgt], + "tgt_dict": self.get_target_dictionary(tgt), "data_category": data_category, "langtok_spec": lang_tok_spec, } diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index 34af9bf4a3..6f36e5b93e 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -105,8 +105,10 @@ def __init__(self, args, langs, dicts, training): args, self.lang_pairs, langs, dicts, self.sampling_method ) - @classmethod - def check_dicts(cls, dicts, source_langs, target_langs): + def check_dicts(self, dicts, source_langs, target_langs): + if self.args.source_dict is not None or 
self.args.target_dict is not None: + # no need to check whether the source side and target side are sharing dictionaries + return src_dict = dicts[source_langs[0]] tgt_dict = dicts[target_langs[0]] for src_lang in source_langs: @@ -123,7 +125,7 @@ def check_dicts(cls, dicts, source_langs, target_langs): @classmethod def setup_task(cls, args, **kwargs): langs, dicts, training = MultilingualDatasetManager.prepare( - cls.load_dictionary, args, **kwargs + cls.load_dictionary, args, **kwargs ) return cls(args, langs, dicts, training) @@ -263,11 +265,11 @@ def max_positions(self): @property def source_dictionary(self): - return self.dicts[self.source_langs[0]] + return self.data_manager.get_source_dictionary(self.source_langs[0]) @property def target_dictionary(self): - return self.dicts[self.target_langs[0]] + return self.data_manager.get_target_dictionary(self.target_langs[0]) def create_batch_sampler_func( self, diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 4e605bd0b1..ddfc1c4db5 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -538,6 +538,62 @@ def test_translation_multi_simple_epoch_dicts(self): + dec_ltok_flag, ) + def test_translation_multi_simple_epoch_src_tgt_dict_spec(self): + # test the specification of explicit --source-dict and --target-dict + with contextlib.redirect_stdout(StringIO()): + enc_ltok_flag = ["--encoder-langtok", "src"] + dec_ltok_flag = ["--decoder-langtok"] + with tempfile.TemporaryDirectory( + "test_translation_multi_simple_epoch_dict" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data( + data_dir, extra_flags=[] + ) + train_translation_model( + data_dir, + arch="transformer", + task="translation_multi_simple_epoch", + extra_flags=[ + "--source-dict", f"{data_dir}/dict.in.txt", + "--target-dict", f"{data_dir}/dict.out.txt", + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + 
"temperature", 
+ "--sampling-temperature", + "1.5", + "--virtual-epoch-size", + "1000", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, + ) + def test_transformer_cross_self_attention(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory( From 4a6f89d373dafc50b416092b41c070c304b31698 Mon Sep 17 00:00:00 2001 From: Frank Seide <seide@fb.com> Date: Wed, 13 Jan 2021 00:02:05 -0800 Subject: [PATCH 398/707] Make Fairseq trainer multiply_grads resilient to sample_size 0 Summary: The Fairseq `Trainer` class does not always guard its `multiply_grads` step against the special case of a `sample_size` of 0, which may happen in edge cases. This diff now guards in all conditions. 
Reviewed By: myleott Differential Revision: D25814612 fbshipit-source-id: 4974ee0148ab2a86f60980f3bf248878b2ebbb36 --- fairseq/trainer.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index a6c1013635..fec60f7742 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -508,7 +508,7 @@ def train_step(self, samples, raise_oom=False): # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 - for i, sample in enumerate(samples): + for i, sample in enumerate(samples): # delayed update loop sample, is_dummy_batch = self._prepare_sample(sample) def maybe_no_sync(): @@ -605,21 +605,29 @@ def maybe_no_sync(): overflow = False try: with torch.autograd.profiler.record_function("reduce-grads"): + # reduce gradients across workers self.optimizer.all_reduce_grads(self.model) if utils.has_parameters(self.criterion): self.optimizer.all_reduce_grads(self.criterion) with torch.autograd.profiler.record_function("multiply-grads"): # multiply gradients by (data_parallel_size / sample_size) since - # DDP already normalizes by the number of data parallel workers. + # DDP normalizes by the number of data parallel workers for + # improved fp16 precision. # Thus we get (sum_of_gradients / sample_size) at the end. - if not self.cfg.optimization.use_bmuf: - self.optimizer.multiply_grads( - self.data_parallel_world_size / sample_size - ) - elif sample_size > 0: # BMUF needs to check sample size - num = self.data_parallel_world_size if self._sync_stats() else 1 - self.optimizer.multiply_grads(num / sample_size) + # In case of fp16, this step also undoes loss scaling. + # (Debugging note: Some optimizers perform this scaling on the + # fly, so inspecting model.parameters() or optimizer.params may + # still show the original, unscaled gradients.) 
+ numer = ( + self.data_parallel_world_size + if not self.cfg.optimization.use_bmuf or self._sync_stats() + else 1 + ) + self.optimizer.multiply_grads(numer / (sample_size or 1.0)) + # Note: (sample_size or 1.0) handles the case of a zero gradient, in a + # way that avoids CPU/device transfers in case sample_size is a GPU or + # TPU object. The assumption is that the gradient itself is also 0. with torch.autograd.profiler.record_function("clip-grads"): # clip grads @@ -661,7 +669,7 @@ def maybe_no_sync(): raise except OverflowError as e: overflow = True - logger.info("NOTE: overflow detected, " + str(e)) + logger.info(f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}") grad_norm = torch.tensor(0.0).cuda() self.zero_grad() except RuntimeError as e: From cb84694c195afced474d17318b5e746d1a9d20a3 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Sat, 16 Jan 2021 19:08:50 -0800 Subject: [PATCH 399/707] fixes regression in lm decoding with flashlight (#1557) Summary: before: ``` PYTHONPATH=. 
python examples/speech_recognition/infer.py /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw --task audio_pretraining --nbest 1 --path /private/home/abaevski/models/wav2vec2/960h_scratch.pt --gen-subset test_clean --w2l-decoder kenlm --lm-model /checkpoint/abaevski/data/speech/libri/4-gram.bin --lexicon /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw/lexicon_ltr.lst --lm-weight 2.601664188829183 --word-score -1.4825337752451184 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 --remove-bpe letter INFO:__main__:Namespace(no_progress_bar=False, log_interval=100, log_format=None, tensorboard_logdir=None, wandb_project=None, azureml_logging=False, seed=1, cpu=False, tpu=False, bf16=False, memory_efficient_bf16=False, fp16=False, memory_efficient_fp16=False, fp16_no_flatten_grads=False, fp16_init_scale=128, fp16_scale_window=None, fp16_scale_tolerance=0.0, min_loss_scale=0.0001, threshold_loss_scale=None, user_dir=None, empty_cache_freq=0, all_gather_list_size=16384, model_parallel_size=1, quantization_config_path=None, profile=False, reset_logging=True, suppress_crashes=False, criterion='ctc', tokenizer=None, bpe=None, optimizer=None, lr_scheduler='fixed', scoring='bleu', task='audio_pretraining', num_workers=1, skip_invalid_size_inputs_valid_test=False, max_tokens=4000000, batch_size=None, required_batch_size_multiple=8, required_seq_len_multiple=1, dataset_impl=None, data_buffer_size=10, train_subset='train', valid_subset='valid', validate_interval=1, validate_interval_updates=0, validate_after_updates=0, fixed_validation_seed=None, disable_validation=False, max_tokens_valid=4000000, batch_size_valid=None, curriculum=0, gen_subset='test_clean', num_shards=1, shard_id=0, distributed_world_size=1, distributed_rank=0, distributed_backend='nccl', distributed_init_method=None, distributed_port=-1, device_id=0, distributed_no_spawn=False, ddp_backend='c10d', bucket_cap_mb=25, fix_batches_to_gpus=False, 
find_unused_parameters=False, fast_stat_sync=False, heartbeat_timeout=-1, broadcast_buffers=False, distributed_wrapper='DDP', slowmo_momentum=None, slowmo_algorithm='LocalSGD', localsgd_frequency=3, nprocs_per_node=2, pipeline_model_parallel=False, pipeline_balance=None, pipeline_devices=None, pipeline_chunks=0, pipeline_encoder_balance=None, pipeline_encoder_devices=None, pipeline_decoder_balance=None, pipeline_decoder_devices=None, pipeline_checkpoint='never', zero_sharding='none', path='/private/home/abaevski/models/wav2vec2/960h_scratch.pt', post_process='letter', quiet=False, model_overrides='{}', results_path=None, beam=5, nbest=1, max_len_a=0, max_len_b=200, min_len=1, match_source_len=False, unnormalized=False, no_early_stop=False, no_beamable_mm=False, lenpen=1, unkpen=0, replace_unk=None, sacrebleu=False, score_reference=False, prefix_size=0, no_repeat_ngram_size=0, sampling=False, sampling_topk=-1, sampling_topp=-1.0, constraints=None, temperature=1.0, diverse_beam_groups=-1, diverse_beam_strength=0.5, diversity_rate=-1.0, print_alignment=None, print_step=False, lm_path=None, lm_weight=2.601664188829183, iter_decode_eos_penalty=0.0, iter_decode_max_iter=10, iter_decode_force_max_iter=False, iter_decode_with_beam=1, iter_decode_with_external_reranker=False, retain_iter_history=False, retain_dropout=False, retain_dropout_modules=None, decoding_format=None, no_seed_provided=False, save_dir='checkpoints', restore_file='checkpoint_last.pt', finetune_from_model=None, reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, optimizer_overrides='{}', save_interval=1, save_interval_updates=0, keep_interval_updates=-1, keep_last_epochs=-1, keep_best_checkpoints=-1, no_save=False, no_epoch_checkpoints=False, no_last_checkpoints=False, no_save_optimizer_state=False, best_checkpoint_metric='loss', maximize_best_checkpoint_metric=False, patience=-1, checkpoint_suffix='', checkpoint_shard_count=1, 
load_checkpoint_on_all_dp_ranks=False, kspmodel=None, wfstlm=None, rnnt_decoding_type='greedy', rnnt_len_penalty=-0.5, w2l_decoder='kenlm', lexicon='/checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw/lexicon_ltr.lst', unit_lm=False, kenlm_model='/checkpoint/abaevski/data/speech/libri/4-gram.bin', beam_threshold=25.0, beam_size_token=100, word_score=-1.4825337752451184, unk_weight=-inf, sil_weight=0.0, dump_emissions=None, dump_features=None, load_emissions=None, data='/checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw', labels='ltr', sample_rate=16000, normalize=False, enable_padding=False, max_sample_size=None, min_sample_size=None, eval_wer=False, eval_wer_tokenizer=None, eval_wer_post_process='letter', autoregressive=False, zero_infinity=False, wer_kenlm_model=None, wer_lexicon=None, wer_lm_weight=2.0, wer_word_score=-1.0, wer_args=None, force_anneal=None, lr_shrink=0.1, warmup_updates=0, pad=1, eos=2, unk=3) INFO:__main__:| decoding with criterion ctc INFO:__main__:| loading model(s) from /private/home/abaevski/models/wav2vec2/960h_scratch.pt INFO:fairseq.data.audio.raw_audio_dataset:loaded 2620, skipped 0 samples INFO:__main__:| /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw test_clean 2620 examples INFO:__main__:WER: 10.415398660986002 INFO:__main__:| Processed 2620 sentences (291252 tokens) in 130.4s (20.09sentences/s, 2233.70 tokens/s) INFO:__main__:| Generate test_clean with beam=5 ``` after: ``` PYTHONPATH=. 
python examples/speech_recognition/infer.py /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw --task audio_pretraining --nbest 1 --path /private/home/abaevski/models/wav2vec2/960h_scratch.pt --gen-subset test_clean --w2l-decoder kenlm --lm-model /checkpoint/abaevski/data/speech/libri/4-gram.bin --lexicon /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw/lexicon_ltr.lst --lm-weight 2.601664188829183 --word-score -1.4825337752451184 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 5000000 --remove-bpe letter INFO:__main__:Namespace(no_progress_bar=False, log_interval=100, log_format=None, tensorboard_logdir=None, wandb_project=None, azureml_logging=False, seed=1, cpu=False, tpu=False, bf16=False, memory_efficient_bf16=False, fp16=False, memory_efficient_fp16=False, fp16_no_flatten_grads=False, fp16_init_scale=128, fp16_scale_window=None, fp16_scale_tolerance=0.0, min_loss_scale=0.0001, threshold_loss_scale=None, user_dir=None, empty_cache_freq=0, all_gather_list_size=16384, model_parallel_size=1, quantization_config_path=None, profile=False, reset_logging=True, suppress_crashes=False, criterion='ctc', tokenizer=None, bpe=None, optimizer=None, lr_scheduler='fixed', scoring='bleu', task='audio_pretraining', num_workers=1, skip_invalid_size_inputs_valid_test=False, max_tokens=5000000, batch_size=None, required_batch_size_multiple=8, required_seq_len_multiple=1, dataset_impl=None, data_buffer_size=10, train_subset='train', valid_subset='valid', validate_interval=1, validate_interval_updates=0, validate_after_updates=0, fixed_validation_seed=None, disable_validation=False, max_tokens_valid=5000000, batch_size_valid=None, curriculum=0, gen_subset='test_clean', num_shards=1, shard_id=0, distributed_world_size=1, distributed_rank=0, distributed_backend='nccl', distributed_init_method=None, distributed_port=-1, device_id=0, distributed_no_spawn=False, ddp_backend='c10d', bucket_cap_mb=25, fix_batches_to_gpus=False, 
find_unused_parameters=False, fast_stat_sync=False, heartbeat_timeout=-1, broadcast_buffers=False, distributed_wrapper='DDP', slowmo_momentum=None, slowmo_algorithm='LocalSGD', localsgd_frequency=3, nprocs_per_node=2, pipeline_model_parallel=False, pipeline_balance=None, pipeline_devices=None, pipeline_chunks=0, pipeline_encoder_balance=None, pipeline_encoder_devices=None, pipeline_decoder_balance=None, pipeline_decoder_devices=None, pipeline_checkpoint='never', zero_sharding='none', path='/private/home/abaevski/models/wav2vec2/960h_scratch.pt', post_process='letter', quiet=False, model_overrides='{}', results_path=None, beam=5, nbest=1, max_len_a=0, max_len_b=200, min_len=1, match_source_len=False, unnormalized=False, no_early_stop=False, no_beamable_mm=False, lenpen=1, unkpen=0, replace_unk=None, sacrebleu=False, score_reference=False, prefix_size=0, no_repeat_ngram_size=0, sampling=False, sampling_topk=-1, sampling_topp=-1.0, constraints=None, temperature=1.0, diverse_beam_groups=-1, diverse_beam_strength=0.5, diversity_rate=-1.0, print_alignment=None, print_step=False, lm_path=None, lm_weight=2.601664188829183, iter_decode_eos_penalty=0.0, iter_decode_max_iter=10, iter_decode_force_max_iter=False, iter_decode_with_beam=1, iter_decode_with_external_reranker=False, retain_iter_history=False, retain_dropout=False, retain_dropout_modules=None, decoding_format=None, no_seed_provided=False, save_dir='checkpoints', restore_file='checkpoint_last.pt', finetune_from_model=None, reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, optimizer_overrides='{}', save_interval=1, save_interval_updates=0, keep_interval_updates=-1, keep_last_epochs=-1, keep_best_checkpoints=-1, no_save=False, no_epoch_checkpoints=False, no_last_checkpoints=False, no_save_optimizer_state=False, best_checkpoint_metric='loss', maximize_best_checkpoint_metric=False, patience=-1, checkpoint_suffix='', checkpoint_shard_count=1, 
load_checkpoint_on_all_dp_ranks=False, kspmodel=None, wfstlm=None, rnnt_decoding_type='greedy', rnnt_len_penalty=-0.5, w2l_decoder='kenlm', lexicon='/checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw/lexicon_ltr.lst', unit_lm=False, kenlm_model='/checkpoint/abaevski/data/speech/libri/4-gram.bin', beam_threshold=25.0, beam_size_token=100, word_score=-1.4825337752451184, unk_weight=-inf, sil_weight=0.0, dump_emissions=None, dump_features=None, load_emissions=None, data='/checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw', labels='ltr', sample_rate=16000, normalize=False, enable_padding=False, max_sample_size=None, min_sample_size=None, eval_wer=False, eval_wer_tokenizer=None, eval_wer_post_process='letter', autoregressive=False, zero_infinity=False, wer_kenlm_model=None, wer_lexicon=None, wer_lm_weight=2.0, wer_word_score=-1.0, wer_args=None, force_anneal=None, lr_shrink=0.1, warmup_updates=0, pad=1, eos=2, unk=3) INFO:__main__:| decoding with criterion ctc INFO:__main__:| loading model(s) from /private/home/abaevski/models/wav2vec2/960h_scratch.pt INFO:fairseq.data.audio.raw_audio_dataset:loaded 2620, skipped 0 samples INFO:__main__:| /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw test_clean 2620 examples INFO:__main__:WER: 2.991859403530128 INFO:__main__:| Processed 2620 sentences (288370 tokens) in 129.8s (20.18sentences/s, 2220.83 tokens/s) INFO:__main__:| Generate test_clean with beam=5 ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1557 Reviewed By: wnhsu Differential Revision: D25935711 Pulled By: alexeib fbshipit-source-id: 36c1c9b9ba32a60b2c04275036514646d2fb33f5 --- examples/speech_recognition/w2l_decoder.py | 173 +++++++++++++-------- fairseq/models/wav2vec/wav2vec2_asr.py | 10 ++ 2 files changed, 116 insertions(+), 67 deletions(-) diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 1fb20757d0..706d9f1433 100644 --- 
a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -59,10 +59,17 @@ def __init__(self, args, tgt_dict): if "<ctc_blank>" in tgt_dict.indices else tgt_dict.bos() ) + if "<sep>" in tgt_dict.indices: + self.silence = tgt_dict.index("<sep>") + elif "|" in tgt_dict.indices: + self.silence = tgt_dict.index("|") + else: + self.silence = tgt_dict.eos() self.asg_transitions = None elif args.criterion == "asg_loss": self.criterion_type = CriterionType.ASG self.blank = -1 + self.silence = -1 self.asg_transitions = args.asg_transitions self.max_replabel = args.max_replabel assert len(self.asg_transitions) == self.vocab_size ** 2 @@ -81,10 +88,13 @@ def generate(self, models, sample, **unused): def get_emissions(self, models, encoder_input): """Run encoder and normalize emissions""" - # encoder_out = models[0].encoder(**encoder_input) - encoder_out = models[0](**encoder_input) + model = models[0] + encoder_out = model(**encoder_input) if self.criterion_type == CriterionType.CTC: - emissions = models[0].get_normalized_probs(encoder_out, log_probs=True) + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out) # no need to normalize emissions + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) elif self.criterion_type == CriterionType.ASG: emissions = encoder_out["encoder_out"] return emissions.transpose(0, 1).float().cpu().contiguous() @@ -132,58 +142,75 @@ class W2lKenLMDecoder(W2lDecoder): def __init__(self, args, tgt_dict): super().__init__(args, tgt_dict) - self.silence = ( - tgt_dict.index("<ctc_blank>") - if "<ctc_blank>" in tgt_dict.indices - else tgt_dict.bos() - ) - self.lexicon = load_words(args.lexicon) - self.word_dict = create_word_dict(self.lexicon) - self.unk_word = self.word_dict.get_index("<unk>") + self.unit_lm = getattr(args, "unit_lm", False) - self.lm = KenLM(args.kenlm_model, self.word_dict) - self.trie = Trie(self.vocab_size, self.silence) + if args.lexicon: + 
self.lexicon = load_words(args.lexicon) + self.word_dict = create_word_dict(self.lexicon) + self.unk_word = self.word_dict.get_index("<unk>") - start_state = self.lm.start(False) - for i, (word, spellings) in enumerate(self.lexicon.items()): - word_idx = self.word_dict.get_index(word) - _, score = self.lm.score(start_state, word_idx) - for spelling in spellings: - spelling_idxs = [tgt_dict.index(token) for token in spelling] - assert ( - tgt_dict.unk() not in spelling_idxs - ), f"{spelling} {spelling_idxs}" - self.trie.insert(spelling_idxs, word_idx, score) - self.trie.smear(SmearingMode.MAX) - - self.decoder_opts = LexiconDecoderOptions( - beam_size=args.beam, - beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), - beam_threshold=args.beam_threshold, - lm_weight=args.lm_weight, - word_score=args.word_score, - unk_score=args.unk_weight, - sil_score=args.sil_weight, - log_add=False, - criterion_type=self.criterion_type, - ) + self.lm = KenLM(args.kenlm_model, self.word_dict) + self.trie = Trie(self.vocab_size, self.silence) + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + word_idx = self.word_dict.get_index(word) + _, score = self.lm.score(start_state, word_idx) + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = LexiconDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + word_score=args.word_score, + unk_score=args.unk_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) + + if self.asg_transitions is None: + N = 768 + # self.asg_transitions = torch.FloatTensor(N, N).zero_() + self.asg_transitions 
= [] + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + self.asg_transitions, + self.unit_lm, + ) + else: + assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" + from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions + + d = {w: [[w]] for w in tgt_dict.symbols} + self.word_dict = create_word_dict(d) + self.lm = KenLM(args.kenlm_model, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) - if self.asg_transitions is None: - N = 768 - # self.asg_transitions = torch.FloatTensor(N, N).zero_() - self.asg_transitions = [] - - self.decoder = LexiconDecoder( - self.decoder_opts, - self.trie, - self.lm, - self.silence, - self.blank, - self.unk_word, - self.asg_transitions, - False, - ) def decode(self, emissions): B, T, N = emissions.size() @@ -341,8 +368,6 @@ class W2lFairseqLMDecoder(W2lDecoder): def __init__(self, args, tgt_dict): super().__init__(args, tgt_dict) - self.silence = tgt_dict.bos() - self.unit_lm = getattr(args, "unit_lm", False) self.lexicon = load_words(args.lexicon) if args.lexicon else None @@ -368,18 +393,6 @@ def __init__(self, args, tgt_dict): self.unk_word = self.word_dict.unk() self.lm = FairseqLM(self.word_dict, model) - self.decoder_opts = LexiconDecoderOptions( - beam_size=args.beam, - beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), - beam_threshold=args.beam_threshold, - lm_weight=args.lm_weight, - word_score=args.word_score, - unk_score=args.unk_weight, - sil_score=args.sil_weight, - log_add=False, - 
criterion_type=self.criterion_type, - ) - if self.lexicon: start_state = self.lm.start(False) for i, (word, spellings) in enumerate(self.lexicon.items()): @@ -399,6 +412,18 @@ def __init__(self, args, tgt_dict): self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) + self.decoder_opts = LexiconDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + word_score=args.word_score, + unk_score=args.unk_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) + self.decoder = LexiconDecoder( self.decoder_opts, self.trie, @@ -406,11 +431,25 @@ def __init__(self, args, tgt_dict): self.silence, self.blank, self.unk_word, - [], + self.asg_transitions, self.unit_lm, ) else: - from flashlight.lib.text.decoder import LexiconFreeDecoder + assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" + from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions + + d = {w: [[w]] for w in tgt_dict.symbols} + self.word_dict = create_word_dict(d) + self.lm = KenLM(args.kenlm_model, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=args.beam, + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), + beam_threshold=args.beam_threshold, + lm_weight=args.lm_weight, + sil_score=args.sil_weight, + log_add=False, + criterion_type=self.criterion_type, + ) self.decoder = LexiconFreeDecoder( self.decoder_opts, self.lm, self.silence, self.blank, [] ) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 790b0a8ad1..bbd2ab9ec5 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -156,6 +156,16 @@ def get_normalized_probs(self, net_output, log_probs): else: return utils.softmax(logits.float(), dim=-1) + def get_logits(self, 
net_output): + logits = net_output["encoder_out"] + padding = net_output["encoder_padding_mask"] + if padding is not None and padding.any(): + padding = padding.T + logits[padding][...,0] = 0 + logits[padding][...,1:] = float('-inf') + + return logits + def forward(self, **kwargs): x = self.w2v_encoder(**kwargs) return x From d927d69beefd695738eff0f5c9d9b0d6dc6abcb4 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Sun, 17 Jan 2021 10:01:00 -0800 Subject: [PATCH 400/707] allows overwriting nested properties with model-overrides (#1558) Summary: without this, it is not possible to use --model-overrides to override properties nested within a config Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1558 Reviewed By: myleott Differential Revision: D25935778 Pulled By: alexeib fbshipit-source-id: 1466f04b8e67842b91299ec88f1370ca0200c6a0 --- fairseq/dataclass/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 4dc978409e..401c212ecc 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -428,7 +428,14 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): for k in cfg.keys(): # "k in cfg" will return false if its a "mandatory value (e.g.
???)" if k in cfg and isinstance(cfg[k], DictConfig): - overwrite_args_by_name(cfg[k], overrides) + if k in overrides and isinstance(overrides[k], dict): + for ok, ov in overrides[k].items(): + if isinstance(ov, dict): + overwrite_args_by_name(cfg[k][ok], ov) + else: + cfg[k][ok] = ov + else: + overwrite_args_by_name(cfg[k], overrides) elif k in cfg and isinstance(cfg[k], Namespace): for override_key, val in overrides.items(): setattr(cfg[k], override_key, val) From 9f5eda48edfad4cb33610f272cb503dedf60ab67 Mon Sep 17 00:00:00 2001 From: Lior Deutsch <sliorde@gmail.com> Date: Sun, 17 Jan 2021 10:03:32 -0800 Subject: [PATCH 401/707] fixed dynamic convolution wrapper function unused arguments (#3136) Summary: The bug that this pull request addresses was discussed [in this GitHub issue](https://github.com/pytorch/fairseq/issues/3085#issue-777177450), and myleott has [asked for the pull request](https://github.com/pytorch/fairseq/issues/3085#issuecomment-754854074). As you can see from the diff, the pull request is very simple. However, I did not run any tests (I don't have a suitable environment, I think). Also: I did not add any tests, I did not change any documentation, I did not check linting. 
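Illustrative sketch only, not part of this patch: the bug class described above is a factory wrapper that accepts keyword arguments but does not forward them, so callers silently get the defaults. The names `Conv`, `make_conv_buggy` and `make_conv_fixed` are hypothetical stand-ins and are not fairseq APIs.

```python
class Conv:
    """Hypothetical stand-in for a convolution module with optional settings."""

    def __init__(self, input_size, bias=False, conv_bias=False, query_size=None):
        self.input_size = input_size
        self.bias = bias
        self.conv_bias = conv_bias
        # Fall back to input_size when no query size is given.
        self.query_size = input_size if query_size is None else query_size


def make_conv_buggy(input_size, bias=False, conv_bias=False, query_size=None):
    # Bug pattern: conv_bias and query_size are accepted but never forwarded,
    # so the constructed module always uses its own defaults for them.
    return Conv(input_size, bias=bias)


def make_conv_fixed(input_size, bias=False, conv_bias=False, query_size=None):
    # Fix pattern: forward every argument the wrapper accepts.
    return Conv(input_size, bias=bias, conv_bias=conv_bias, query_size=query_size)
```

With the buggy wrapper, `make_conv_buggy(8, conv_bias=True)` returns a module whose `conv_bias` is still `False`; the fixed wrapper passes the value through.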
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3136 Reviewed By: myleott Differential Revision: D25935770 Pulled By: alexeib fbshipit-source-id: b338e7cfb409fd14dac653121256276550f53b21 --- fairseq/modules/dynamic_convolution.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py index 5999a04539..9f2d28da65 100644 --- a/fairseq/modules/dynamic_convolution.py +++ b/fairseq/modules/dynamic_convolution.py @@ -37,7 +37,10 @@ def DynamicConv( num_heads=num_heads, weight_dropout=weight_dropout, weight_softmax=weight_softmax, + renorm_padding=renorm_padding, bias=bias, + conv_bias=conv_bias, + query_size=query_size, ) except ImportError as e: print(e) @@ -48,7 +51,10 @@ def DynamicConv( num_heads=num_heads, weight_dropout=weight_dropout, weight_softmax=weight_softmax, + renorm_padding=renorm_padding, bias=bias, + conv_bias=conv_bias, + query_size=query_size, ) From 1164a7fc432a188d401895018eaa85175fb06f9d Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Sun, 17 Jan 2021 23:25:04 -0800 Subject: [PATCH 402/707] Fix time warping for SpecAugment Summary: Fix time warping for SpecAugment Github issue: https://github.com/pytorch/fairseq/issues/3141 Reviewed By: jmp84 Differential Revision: D25941870 fbshipit-source-id: 97f9c67a49212556156b33aee0056a86ec990db4 --- fairseq/data/audio/feature_transforms/specaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/audio/feature_transforms/specaugment.py b/fairseq/data/audio/feature_transforms/specaugment.py index 2ef4778b85..ce5802b41a 100644 --- a/fairseq/data/audio/feature_transforms/specaugment.py +++ b/fairseq/data/audio/feature_transforms/specaugment.py @@ -98,7 +98,7 @@ def __call__(self, spectrogram): import cv2 w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w) - w = np.random.randint(0, self.time_warp_w) + w = np.random.randint(-self.time_warp_w + 1, 
self.time_warp_w) upper, lower = distorted[:w0, :], distorted[w0:, :] upper = cv2.resize( upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR From ecf0b60e124e1e795e30004ced00883bf8ba5192 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Tue, 19 Jan 2021 12:34:09 -0800 Subject: [PATCH 403/707] add defaults to configs (#1564) Summary: Previously, when using the hydra_train entry point, the config object that got created would only contain values explicitly specified in config files or on the command line. Normally this is not a problem, as we load defaults when creating any object for a particular config, but this config was also being stored in checkpoints. The checkpoints would then hold an incomplete config that would become incorrect if defaults were later changed in the code. This PR adds defaults into configs for all config objects that hydra doesn't know about. before: ``` {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 200, 'log_format': 'json', 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': True, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'user_dir': None, 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False, 'reset_logging': True, 'suppress_crashes': False}, 'common_eval': {'_name': None, 'path': None, 'post_process': None, 'quiet': False, 'model_overrides': '{}', 'results_path': None}, 'distributed_training': {'_name': None, 'distributed_world_size': 2, 'distributed_rank': 0, 'distributed_backend': 'nccl', 'distributed_init_method': None, 'distributed_port': -1, 'device_id': 0, 'distributed_no_spawn': False, 'ddp_backend': 'no_c10d', 'bucket_cap_mb': 25,
'fix_batches_to_gpus': False, 'find_unused_parameters': False, 'fast_stat_sync': False, 'heartbeat_timeout': -1, 'broadcast_buffers': False, 'distributed_wrapper': 'DDP', 'slowmo_momentum': None, 'slowmo_algorithm': 'LocalSGD', 'localsgd_frequency': 3, 'nprocs_per_node': 2, 'pipeline_model_parallel': False, 'pipeline_balance': None, 'pipeline_devices': None, 'pipeline_chunks': 0, 'pipeline_encoder_balance': None, 'pipeline_encoder_devices': None, 'pipeline_decoder_balance': None, 'pipeline_decoder_devices': None, 'pipeline_checkpoint': 'never', 'zero_sharding': 'none', 'tpu': False}, 'dataset': {'_name': None, 'num_workers': 6, 'skip_invalid_size_inputs_valid_test': True, 'max_tokens': 3200000, 'batch_size': None, 'required_batch_size_multiple': 8, 'required_seq_len_multiple': 1, 'dataset_impl': None, 'data_buffer_size': 10, 'train_subset': 'train', 'valid_subset': 'dev_other', 'validate_interval': 50, 'validate_interval_updates': 0, 'validate_after_updates': 10000, 'fixed_validation_seed': None, 'disable_validation': False, 'max_tokens_valid': 3200000, 'batch_size_valid': None, 'curriculum': 0, 'gen_subset': 'test', 'num_shards': 1, 'shard_id': 0}, 'optimization': {'_name': None, 'max_epoch': 0, 'max_update': 20000, 'stop_time_hours': 0.0, 'clip_norm': 0.0, 'sentence_avg': True, 'update_freq': [4], 'lr': [5e-05], 'stop_min_lr': -1.0, 'use_bmuf': False}, 'checkpoint': {'_name': None, 'save_dir': 'checkpoints', 'restore_file': 'checkpoint_last.pt', 'finetune_from_model': None, 'reset_dataloader': False, 'reset_lr_scheduler': False, 'reset_meters': False, 'reset_optimizer': False, 'optimizer_overrides': '{}', 'save_interval': 50, 'save_interval_updates': 10000, 'keep_interval_updates': 1, 'keep_last_epochs': -1, 'keep_best_checkpoints': -1, 'no_save': False, 'no_epoch_checkpoints': True, 'no_last_checkpoints': False, 'no_save_optimizer_state': False, 'best_checkpoint_metric': 'wer', 'maximize_best_checkpoint_metric': False, 'patience': -1, 'checkpoint_suffix': '', 
'checkpoint_shard_count': 1, 'load_checkpoint_on_all_dp_ranks': False, 'model_parallel_size': 1, 'distributed_rank': 0}, 'bmuf': {'_name': None, 'block_lr': 1.0, 'block_momentum': 0.875, 'global_sync_iter': 50, 'warmup_iterations': 500, 'use_nbm': False, 'average_sync': False, 'distributed_world_size': 2}, 'generation': {'_name': None, 'beam': 5, 'nbest': 1, 'max_len_a': 0.0, 'max_len_b': 200, 'min_len': 1, 'match_source_len': False, 'unnormalized': False, 'no_early_stop': False, 'no_beamable_mm': False, 'lenpen': 1.0, 'unkpen': 0.0, 'replace_unk': None, 'sacrebleu': False, 'score_reference': False, 'prefix_size': 0, 'no_repeat_ngram_size': 0, 'sampling': False, 'sampling_topk': -1, 'sampling_topp': -1.0, 'constraints': None, 'temperature': 1.0, 'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5, 'diversity_rate': -1.0, 'print_alignment': None, 'print_step': False, 'lm_path': None, 'lm_weight': 0.0, 'iter_decode_eos_penalty': 0.0, 'iter_decode_max_iter': 10, 'iter_decode_force_max_iter': False, 'iter_decode_with_beam': 1, 'iter_decode_with_external_reranker': False, 'retain_iter_history': False, 'retain_dropout': False, 'retain_dropout_modules': None, 'decoding_format': None, 'no_seed_provided': False}, 'eval_lm': {'_name': None, 'output_word_probs': False, 'output_word_stats': False, 'context_window': 0, 'softmax_batch': 9223372036854775807}, 'interactive': {'_name': None, 'buffer_size': 0, 'input': '-'}, 'model': {'_name': 'wav2vec_ctc', 'w2v_path': '/private/home/abaevski/models/wav2vec2/wav2vec_small.pt', 'apply_mask': True, 'mask_prob': 0.65, 'mask_channel_prob': 0.5, 'mask_channel_length': 64, 'layerdrop': 0.05, 'activation_dropout': 0.1, 'feature_grad_mult': 0.0, 'freeze_finetune_updates': 10000}, 'task': {'_name': 'audio_pretraining', 'data': '/checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw', 'normalize': False, 'labels': 'ltr'}, 'criterion': {'_name': 'ctc', 'zero_infinity': True}, 'optimizer': {'_name': 'adam', 'adam_betas': 
'(0.9,0.98)', 'adam_eps': 1e-08}, 'lr_scheduler': {'_name': 'tri_stage', 'phase_ratio': [0.1, 0.4, 0.5], 'final_lr_scale': 0.05}, 'scoring': None, 'bpe': None, 'tokenizer': None} ``` after: ``` {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 200, 'log_format': 'json', 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': True, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'user_dir': None, 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False, 'reset_logging': True, 'suppress_crashes': False}, 'common_eval': {'_name': None, 'path': None, 'post_process': None, 'quiet': False, 'model_overrides': '{}', 'results_path': None}, 'distributed_training': {'_name': None, 'distributed_world_size': 2, 'distributed_rank': 0, 'distributed_backend': 'nccl', 'distributed_init_method': 'tcp://localhost:16054', 'distributed_port': -1, 'device_id': 0, 'distributed_no_spawn': False, 'ddp_backend': 'no_c10d', 'bucket_cap_mb': 25, 'fix_batches_to_gpus': False, 'find_unused_parameters': False, 'fast_stat_sync': False, 'heartbeat_timeout': -1, 'broadcast_buffers': False, 'distributed_wrapper': 'DDP', 'slowmo_momentum': None, 'slowmo_algorithm': 'LocalSGD', 'localsgd_frequency': 3, 'nprocs_per_node': 2, 'pipeline_model_parallel': False, 'pipeline_balance': None, 'pipeline_devices': None, 'pipeline_chunks': 0, 'pipeline_encoder_balance': None, 'pipeline_encoder_devices': None, 'pipeline_decoder_balance': None, 'pipeline_decoder_devices': None, 'pipeline_checkpoint': 'never', 'zero_sharding': 'none', 'tpu': False, 'distributed_num_procs': 2}, 'dataset': {'_name': None, 'num_workers': 6, 
'skip_invalid_size_inputs_valid_test': True, 'max_tokens': 3200000, 'batch_size': None, 'required_batch_size_multiple': 8, 'required_seq_len_multiple': 1, 'dataset_impl': None, 'data_buffer_size': 10, 'train_subset': 'train', 'valid_subset': 'dev_other', 'validate_interval': 50, 'validate_interval_updates': 0, 'validate_after_updates': 0, 'fixed_validation_seed': None, 'disable_validation': False, 'max_tokens_valid': 3200000, 'batch_size_valid': None, 'curriculum': 0, 'gen_subset': 'test', 'num_shards': 1, 'shard_id': 0}, 'optimization': {'_name': None, 'max_epoch': 0, 'max_update': 20000, 'stop_time_hours': 0.0, 'clip_norm': 0.0, 'sentence_avg': True, 'update_freq': [4], 'lr': [5e-05], 'stop_min_lr': -1.0, 'use_bmuf': False}, 'checkpoint': {'_name': None, 'save_dir': 'checkpoints', 'restore_file': 'checkpoint_last.pt', 'finetune_from_model': None, 'reset_dataloader': False, 'reset_lr_scheduler': False, 'reset_meters': False, 'reset_optimizer': False, 'optimizer_overrides': '{}', 'save_interval': 50, 'save_interval_updates': 20, 'keep_interval_updates': 1, 'keep_last_epochs': -1, 'keep_best_checkpoints': -1, 'no_save': False, 'no_epoch_checkpoints': True, 'no_last_checkpoints': False, 'no_save_optimizer_state': False, 'best_checkpoint_metric': 'wer', 'maximize_best_checkpoint_metric': False, 'patience': -1, 'checkpoint_suffix': '', 'checkpoint_shard_count': 1, 'load_checkpoint_on_all_dp_ranks': False, 'model_parallel_size': 1, 'distributed_rank': 0}, 'bmuf': {'_name': None, 'block_lr': 1.0, 'block_momentum': 0.875, 'global_sync_iter': 50, 'warmup_iterations': 500, 'use_nbm': False, 'average_sync': False, 'distributed_world_size': 2}, 'generation': {'_name': None, 'beam': 5, 'nbest': 1, 'max_len_a': 0.0, 'max_len_b': 200, 'min_len': 1, 'match_source_len': False, 'unnormalized': False, 'no_early_stop': False, 'no_beamable_mm': False, 'lenpen': 1.0, 'unkpen': 0.0, 'replace_unk': None, 'sacrebleu': False, 'score_reference': False, 'prefix_size': 0, 
'no_repeat_ngram_size': 0, 'sampling': False, 'sampling_topk': -1, 'sampling_topp': -1.0, 'constraints': None, 'temperature': 1.0, 'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5, 'diversity_rate': -1.0, 'print_alignment': None, 'print_step': False, 'lm_path': None, 'lm_weight': 0.0, 'iter_decode_eos_penalty': 0.0, 'iter_decode_max_iter': 10, 'iter_decode_force_max_iter': False, 'iter_decode_with_beam': 1, 'iter_decode_with_external_reranker': False, 'retain_iter_history': False, 'retain_dropout': False, 'retain_dropout_modules': None, 'decoding_format': None, 'no_seed_provided': False}, 'eval_lm': {'_name': None, 'output_word_probs': False, 'output_word_stats': False, 'context_window': 0, 'softmax_batch': 9223372036854775807}, 'interactive': {'_name': None, 'buffer_size': 0, 'input': '-'}, 'model': {'_name': 'wav2vec_ctc', 'w2v_path': '/private/home/abaevski/models/wav2vec2/wav2vec_small.pt', 'no_pretrained_weights': False, 'dropout_input': 0.0, 'final_dropout': 0.0, 'dropout': 0.0, 'attention_dropout': 0.0, 'activation_dropout': 0.1, 'apply_mask': True, 'mask_length': 10, 'mask_prob': 0.65, 'mask_selection': 'static', 'mask_other': 0.0, 'no_mask_overlap': False, 'mask_channel_length': 64, 'mask_channel_prob': 0.5, 'mask_channel_selection': 'static', 'mask_channel_other': 0.0, 'no_mask_channel_overlap': False, 'freeze_finetune_updates': 0, 'feature_grad_mult': 0.0, 'layerdrop': 0.05, 'normalize': False, 'data': '/checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw', 'w2v_args': None}, 'task': {'_name': 'audio_pretraining', 'data': '/checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw', 'labels': 'ltr', 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_sample_size': None, 'min_sample_size': None, 'eval_wer': False, 'eval_wer_config': {'_name': None, 'beam': 5, 'nbest': 1, 'max_len_a': 0.0, 'max_len_b': 200, 'min_len': 1, 'match_source_len': False, 'unnormalized': False, 'no_early_stop': False, 'no_beamable_mm': 
False, 'lenpen': 1.0, 'unkpen': 0.0, 'replace_unk': None, 'sacrebleu': False, 'score_reference': False, 'prefix_size': 0, 'no_repeat_ngram_size': 0, 'sampling': False, 'sampling_topk': -1, 'sampling_topp': -1.0, 'constraints': None, 'temperature': 1.0, 'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5, 'diversity_rate': -1.0, 'print_alignment': None, 'print_step': False, 'lm_path': None, 'lm_weight': 0.0, 'iter_decode_eos_penalty': 0.0, 'iter_decode_max_iter': 10, 'iter_decode_force_max_iter': False, 'iter_decode_with_beam': 1, 'iter_decode_with_external_reranker': False, 'retain_iter_history': False, 'retain_dropout': False, 'retain_dropout_modules': None, 'decoding_format': None, 'no_seed_provided': False}, 'eval_wer_tokenizer': None, 'eval_wer_post_process': 'letter', 'autoregressive': False}, 'criterion': {'_name': 'ctc', 'zero_infinity': True, 'sentence_avg': True, 'post_process': 'letter', 'wer_kenlm_model': None, 'wer_lexicon': None, 'wer_lm_weight': 2.0, 'wer_word_score': -1.0, 'wer_args': None}, 'optimizer': {'_name': 'adam', 'adam_betas': '(0.9,0.98)', 'adam_eps': 1e-08, 'weight_decay': 0.0, 'use_old_adam': False, 'tpu': False, 'lr': [5e-05]}, 'lr_scheduler': {'_name': 'tri_stage', 'warmup_steps': 0, 'hold_steps': 0, 'decay_steps': 0, 'phase_ratio': [0.1, 0.4, 0.5], 'init_lr_scale': 0.01, 'final_lr_scale': 0.05, 'max_update': 20000, 'lr': [5e-05]}, 'scoring': None, 'bpe': None, 'tokenizer': None} ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1564 Reviewed By: myleott Differential Revision: D25938221 Pulled By: alexeib fbshipit-source-id: e088667accf974ad6d9898a63f7c33722837fcfb --- fairseq/dataclass/initialize.py | 34 ++++++++++++++++++++++++++++++++- fairseq_cli/hydra_train.py | 4 ++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index 7a1ebeff1c..385624f19b 100644 --- a/fairseq/dataclass/initialize.py +++ 
b/fairseq/dataclass/initialize.py @@ -5,9 +5,9 @@ """isort:skip_file""" import logging -from typing import Dict, Any from hydra.core.config_store import ConfigStore from fairseq.dataclass.configs import FairseqConfig +from omegaconf import DictConfig, open_dict logger = logging.getLogger(__name__) @@ -25,3 +25,35 @@ def hydra_init(cfg_name="config") -> None: except BaseException: logger.error(f"{k} - {v}") raise + + +def add_defaults(cfg: DictConfig) -> None: + """This function adds default values that are stored in dataclasses that hydra doesn't know about """ + + from fairseq.registry import REGISTRIES + from fairseq.tasks import TASK_DATACLASS_REGISTRY + from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY + from fairseq.dataclass.utils import merge_with_parent + from typing import Any + + for k, v in FairseqConfig.__dataclass_fields__.items(): + field_cfg = cfg.get(k) + if field_cfg is not None and v.type == Any: + dc = None + + if isinstance(field_cfg, str): + field_cfg = DictConfig({"_name": field_cfg}) + field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"] + + name = field_cfg.get("_name") + + if k == "task": + dc = TASK_DATACLASS_REGISTRY.get(name) + elif k == "model": + name = ARCH_MODEL_NAME_REGISTRY.get(name, name) + dc = MODEL_DATACLASS_REGISTRY.get(name) + elif k in REGISTRIES: + dc = REGISTRIES[k]["dataclass_registry"].get(name) + + if dc is not None: + cfg[k] = merge_with_parent(dc, field_cfg) diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index cf48337462..6754f9483d 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -8,7 +8,7 @@ import os import sys -from fairseq.dataclass.initialize import hydra_init +from fairseq.dataclass.initialize import add_defaults, hydra_init from fairseq_cli.train import main as pre_main from fairseq import distributed_utils, metrics from fairseq.dataclass.configs import FairseqConfig @@ -23,8 +23,8 @@ 
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") def hydra_main(cfg: FairseqConfig) -> float: + add_defaults(cfg) cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) - OmegaConf.set_struct(cfg, True) if cfg.common.reset_logging: From 1e21f7e462e53a3bf2f881632f15548ddaa43464 Mon Sep 17 00:00:00 2001 From: Wei-Ning Hsu <wnhsu@csail.mit.edu> Date: Tue, 19 Jan 2021 19:10:05 -0800 Subject: [PATCH 404/707] track loaded lines in raw_audio_dataset and load corresponding labels in audio_pretraining (#1566) Summary: Bug: `AudioPretrainingTask` is not aware of which samples have been skipped by `FileAudioDataset`, and hence would load labels of utterances that were skipped, causing `AddTargetDataset` to misalign utterances with labels. This PR tracks the line indices loaded by the `FileAudioDataset` to filter labels correspondingly in `AudioPretrainingTask`. Before: ``` INFO:__main__:| decoding with criterion ctc INFO:__main__:| loading model(s) from ... INFO:fairseq.data.audio.raw_audio_dataset:loaded 284, skipped 223 samples INFO:__main__:| /private/home/wnhsu/wav2vec2_robust/data/joint_swbd ted_dev 284 examples INFO:__main__:WER: 152.15605749486653 INFO:__main__:| Processed 284 sentences (72405 tokens) in 20.5s (13.88sentences/s, 3538.01 tokens/s) ``` After: ``` INFO:__main__:| decoding with criterion ctc INFO:__main__:| loading model(s) from ... 
INFO:fairseq.data.audio.raw_audio_dataset:loaded 284, skipped 223 samples INFO:__main__:WER: 9.904153354632587 INFO:__main__:| Processed 284 sentences (72405 tokens) in 20.7s (13.70sentences/s, 3492.20 tokens/s) ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1566 Reviewed By: alexeib Differential Revision: D25963317 Pulled By: wnhsu fbshipit-source-id: c9748f5dad1ff787642ba0bc28698c4ecfbcd221 --- fairseq/data/audio/raw_audio_dataset.py | 4 +++- fairseq/tasks/audio_pretraining.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 8d6ce85ecc..ac5acd03bb 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -153,11 +153,12 @@ def __init__( ) self.fnames = [] + self.line_inds = set() skipped = 0 with open(manifest_path, "r") as f: self.root_dir = f.readline().strip() - for line in f: + for i, line in enumerate(f): items = line.strip().split("\t") assert len(items) == 2, line sz = int(items[1]) @@ -165,6 +166,7 @@ def __init__( skipped += 1 continue self.fnames.append(items[0]) + self.line_inds.add(i) self.sizes.append(sz) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 6ea40a813f..7c82777331 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -148,8 +148,14 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") labels = [] with open(label_path, "r") as f: - for line in f: - labels.append(line) + labels = [ + line for i, line in enumerate(f) + if i in self.datasets[split].line_inds + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not 
match") process_label = LabelEncoder(self.target_dictionary) From bf54551cafa13678c0254d2c20354cc026cc0bac Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 20 Jan 2021 05:49:22 -0800 Subject: [PATCH 405/707] Fix param sharing in Linformer (#1561) Summary: Parameter sharing (both `--untie-weights-roberta` and `--shared-layer-kv-compressed`) was broken by one of my earlier refactors (D22411012 (https://github.com/pytorch/fairseq/commit/d73e543e3853bb813d8f7955a06ce19359810707)). This fixes it. Note: it was correct in the original version of the code for the paper. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1561 Test Plan: - confirmed that training gives identical losses as before when not using any param sharing (including `--untie-weights-roberta`): ``` CUDA_VISIBLE_DEVICES=0 python train.py --task dummy_masked_lm --arch linformer_roberta_base --untie-weights-roberta --user-dir examples/linformer/linformer_src/ --criterion masked_lm --batch-size 8 --optimizer adam --lr 0.0001 --log-format json --log-interval 1 --max-update 5 --disable-validation --no-save before: 2021-01-19 06:37:21 | INFO | fairseq_cli.train | num. model params: 164,465,744 (num. trained: 164,465,744) (...) 
2021-01-19 06:41:56 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "15.893", "ppl": "60870.7", "wps": "0", "ups": "0", "wpb": "4096", "bsz": "8", "num_updates": "1", "lr": "0.0001", "gnorm": "7.716", "train_wall": "1", "wall": "1"} 2021-01-19 06:41:56 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "13.176", "ppl": "9252.9", "wps": "11813.8", "ups": "2.88", "wpb": "4096", "bsz": "8", "num_updates": "2", "lr": "0.0001", "gnorm": "6.988", "train_wall": "0", "wall": "1"} 2021-01-19 06:41:57 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "11.049", "ppl": "2119.22", "wps": "12002.2", "ups": "2.93", "wpb": "4096", "bsz": "8", "num_updates": "3", "lr": "0.0001", "gnorm": "8.008", "train_wall": "0", "wall": "1"} 2021-01-19 06:41:57 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "9.044", "ppl": "527.7", "wps": "11894.2", "ups": "2.9", "wpb": "4096", "bsz": "8", "num_updates": "4", "lr": "0.0001", "gnorm": "7.893", "train_wall": "0", "wall": "2"} 2021-01-19 06:41:57 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "7.526", "ppl": "184.27", "wps": "11834.9", "ups": "2.89", "wpb": "4096", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "6.949", "train_wall": "0", "wall": "2"} after: 2021-01-19 06:39:20 | INFO | fairseq_cli.train | num. model params: 164,465,744 (num. trained: 164,465,744) (...) 
2021-01-19 06:39:22 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "15.893", "ppl": "60870.7", "wps": "0", "ups": "0", "wpb": "4096", "bsz": "8", "num_updates": "1", "lr": "0.0001", "gnorm": "7.716", "train_wall": "1", "wall": "1"} 2021-01-19 06:39:23 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "13.176", "ppl": "9252.9", "wps": "12094.7", "ups": "2.95", "wpb": "4096", "bsz": "8", "num_updates": "2", "lr": "0.0001", "gnorm": "6.988", "train_wall": "0", "wall": "1"} 2021-01-19 06:39:23 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "11.049", "ppl": "2119.22", "wps": "12290", "ups": "3", "wpb": "4096", "bsz": "8", "num_updates": "3", "lr": "0.0001", "gnorm": "8.008", "train_wall": "0", "wall": "1"} 2021-01-19 06:39:23 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "9.044", "ppl": "527.7", "wps": "11990.4", "ups": "2.93", "wpb": "4096", "bsz": "8", "num_updates": "4", "lr": "0.0001", "gnorm": "7.893", "train_wall": "0", "wall": "2"} 2021-01-19 06:39:24 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "7.526", "ppl": "184.27", "wps": "12073.8", "ups": "2.95", "wpb": "4096", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "6.949", "train_wall": "0", "wall": "2"} ``` - with input embedding and output LM head param sharing, the `num. model params` now goes down (as expected), whereas before it stayed constant: ``` CUDA_VISIBLE_DEVICES=0 python train.py --task dummy_masked_lm --arch linformer_roberta_base --user-dir examples/linformer/linformer_src/ --criterion masked_lm --batch-size 8 --optimizer adam --lr 0.0001 --log-format json --log-interval 1 --max-update 5 --disable-validation --no-save before: 2021-01-19 06:44:58 | INFO | fairseq_cli.train | num. model params: 164,465,744 (num. trained: 164,465,744) (...) after: 2021-01-19 06:43:03 | INFO | fairseq_cli.train | num. model params: 126,065,744 (num. trained: 126,065,744) (...) 
``` - confirmed that old checkpoints can be loaded and produce identical valid ppl: ``` python -m fairseq_cli.validate --path $MODEL --user-dir examples/linformer/linformer_src/ --task dummy_masked_lm --criterion masked_lm --max-sentences 8 --dataset-size 100 no sharing: before: 2021-01-19 07:07:54 | INFO | valid | | valid on 'valid' subset | loss 5.485 | ppl 44.8 | wps 0 | wpb 53248 | bsz 104 after: 2021-01-19 07:30:10 | INFO | valid | | valid on 'valid' subset | loss 5.485 | ppl 44.8 | wps 0 | wpb 53248 | bsz 104 shared_kv_compressed: before: 2021-01-19 07:08:50 | INFO | valid | | valid on 'valid' subset | loss 5.355 | ppl 40.94 | wps 0 | wpb 53248 | bsz 104 after: 2021-01-19 07:30:45 | INFO | valid | | valid on 'valid' subset | loss 5.355 | ppl 40.94 | wps 0 | wpb 53248 | bsz 104 shared_kv_compressed + shared_layer_kv_compressed: before: 2021-01-19 07:09:26 | INFO | valid | | valid on 'valid' subset | loss 5.482 | ppl 44.7 | wps 0 | wpb 53248 | bsz 104 after: 2021-01-19 08:09:36 | INFO | valid | | valid on 'valid' subset | loss 5.482 | ppl 44.7 | wps 0 | wpb 53248 | bsz 104 using a really old checkpoint with sharing (trained on commit cf4219b048d31f55970356520860b2543ee97570): before: | valid on 'valid' subset | loss 5.548 | ppl 46.8 | wps 0 | wpb 53248 | bsz 104 after: 2021-01-19 08:34:07 | INFO | valid | | valid on 'valid' subset | loss 5.548 | ppl 46.8 | wps 0 | wpb 53248 | bsz 104 ``` Reviewed By: madian9 Differential Revision: D25938236 Pulled By: myleott fbshipit-source-id: 4d515e5c8e0601476856ae27eb46c64c30033c88 --- .../linformer_src/models/linformer_roberta.py | 28 ++++++++++++- .../modules/linformer_sentence_encoder.py | 2 +- .../linformer_sentence_encoder_layer.py | 42 ++++++++++++++++--- fairseq/models/fairseq_decoder.py | 4 +- fairseq/models/fairseq_encoder.py | 4 +- fairseq/models/roberta/model.py | 28 ++++++++----- fairseq/modules/transformer_layer.py | 1 + .../transformer_sentence_encoder_layer.py | 5 +++ 8 files changed, 92 insertions(+), 22 
deletions(-) diff --git a/examples/linformer/linformer_src/models/linformer_roberta.py b/examples/linformer/linformer_src/models/linformer_roberta.py index 913351f238..be5d8e85ec 100644 --- a/examples/linformer/linformer_src/models/linformer_roberta.py +++ b/examples/linformer/linformer_src/models/linformer_roberta.py @@ -8,6 +8,8 @@ import logging +import torch +from fairseq import utils from fairseq.models import register_model, register_model_architecture from fairseq.models.roberta import RobertaEncoder, RobertaModel @@ -62,8 +64,10 @@ class LinformerEncoder(RobertaEncoder): def __init__(self, args, dictionary): super().__init__(args, dictionary) + self.register_buffer("version", torch.tensor(2)) - self.sentence_encoder = LinformerSentenceEncoder( + def build_encoder(self, args, dictionary): + return LinformerSentenceEncoder( padding_idx=dictionary.pad(), vocab_size=len(dictionary), num_encoder_layers=args.encoder_layers, @@ -87,6 +91,27 @@ def __init__(self, args, dictionary): freeze_compress=args.freeze_compress, ) + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + prefix = name + "." 
if name != "" else "" + + # some old checkpoints had weight sharing implemented incorrectly + # (note: this was correct in the original paper code) + if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2: + state_dict[f"{prefix}version"] = torch.tensor(1) + # check if input embeddings and output embeddings were tied + if not torch.allclose( + state_dict[f"{prefix}sentence_encoder.embed_tokens.weight"], + state_dict[f"{prefix}lm_head.weight"], + ): + # they weren't tied, re-init the LM head without weight sharing + self.lm_head = self.build_lm_head( + embed_dim=self.args.encoder_embed_dim, + output_dim=len(self.dictionary), + activation_fn=self.args.activation_fn, + weight=None, # don't share weights + ) + @register_model_architecture("linformer_roberta", "linformer_roberta") def base_architecture(args): @@ -104,6 +129,7 @@ def base_architecture(args): args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.compressed = getattr(args, "compressed", 4) args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) diff --git a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py index d6de9eeaae..3cdca01235 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py @@ -116,7 +116,7 @@ def build_transformer_sentence_encoder_layer( q_noise, qn_block_size, ): - if self.shared_layer_kv_compressed == 1: + if self.shared_layer_kv_compressed == 1 and self.compress_layer is None: compress_layer = nn.Linear( self.max_seq_len, self.max_seq_len // self.compressed ) diff --git 
a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py index d27c5afd09..0b80fabefe 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py @@ -5,6 +5,8 @@ from typing import Callable +import torch +from fairseq import utils from fairseq.modules import TransformerSentenceEncoderLayer from .multihead_linear_attention import MultiheadLinearAttention @@ -42,9 +44,8 @@ def __init__( self.shared_kv_compressed = shared_kv_compressed self.freeze_compress = freeze_compress - def init_fn(): - # This needs to be set after nn.Module.__init__ is called - self.shared_compress_layer = shared_compress_layer + # wrap in a list so it's not automatically registered by PyTorch + self.shared_compress_layer = [shared_compress_layer] super().__init__( embedding_dim=embedding_dim, @@ -57,8 +58,8 @@ def init_fn(): export=export, q_noise=q_noise, qn_block_size=qn_block_size, - init_fn=init_fn, ) + self.register_buffer("version", torch.tensor(2)) def build_self_attention( self, @@ -79,6 +80,37 @@ def build_self_attention( compressed=self.compressed, max_seq_len=self.max_seq_len, shared_kv_compressed=self.shared_kv_compressed, - shared_compress_layer=self.shared_compress_layer, + shared_compress_layer=self.shared_compress_layer[0], freeze_compress=self.freeze_compress, ) + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." 
if name != "" else "" + + # some old checkpoints had weight sharing implemented incorrectly + # (note: this was correct in the original paper code) + if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2: + state_dict[f"{prefix}version"] = torch.tensor(1) + # check compression layer sharing + if f"{prefix}shared_compress_layer.weight" in state_dict: + # reinitialize block without sharing compression layer to match + # old behavior + self.shared_compress_layer = [ + torch.nn.Linear( + self.shared_compress_layer[0].weight.size(1), + self.shared_compress_layer[0].weight.size(0), + ) + ] + self.self_attn = self.build_self_attention( + self.embedding_dim, + self.num_attention_heads, + dropout=self.attention_dropout, + self_attention=True, + q_noise=self.q_noise, + qn_block_size=self.qn_block_size, + ) + # delete shared_compress_layer, since it's already copied to + # self_attn.compress_k.weight + del state_dict[f"{prefix}shared_compress_layer.weight"] + if f"{prefix}shared_compress_layer.bias" in state_dict: + del state_dict[f"{prefix}shared_compress_layer.bias"] diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py index fb6c52dc7f..35a349fa5f 100644 --- a/fairseq/models/fairseq_decoder.py +++ b/fairseq/models/fairseq_decoder.py @@ -82,8 +82,8 @@ def max_positions(self): """Maximum input length supported by the decoder.""" return 1e6 # an arbitrary large number - def upgrade_state_dict(self, state_dict): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code.""" return state_dict def prepare_for_onnx_export_(self): diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py index c8873daa28..08cbde15a4 100644 --- a/fairseq/models/fairseq_encoder.py +++ b/fairseq/models/fairseq_encoder.py @@ -78,8 +78,8 @@ def max_positions(self): """Maximum input length supported by the 
encoder.""" return 1e6 # an arbitrary large number - def upgrade_state_dict(self, state_dict): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code.""" return state_dict def set_num_updates(self, num_updates): diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 96a7b9c8a2..00a5a5485f 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -401,7 +401,21 @@ def __init__(self, args, dictionary): if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - self.sentence_encoder = TransformerSentenceEncoder( + self.sentence_encoder = self.build_encoder(args, dictionary) + + self.lm_head = self.build_lm_head( + embed_dim=args.encoder_embed_dim, + output_dim=len(dictionary), + activation_fn=args.activation_fn, + weight=( + self.sentence_encoder.embed_tokens.weight + if not args.untie_weights_roberta + else None + ), + ) + + def build_encoder(self, args, dictionary): + return TransformerSentenceEncoder( padding_idx=dictionary.pad(), vocab_size=len(dictionary), num_encoder_layers=args.encoder_layers, @@ -421,16 +435,8 @@ def __init__(self, args, dictionary): qn_block_size=args.quant_noise_pq_block_size, ) - self.lm_head = RobertaLMHead( - embed_dim=args.encoder_embed_dim, - output_dim=len(dictionary), - activation_fn=args.activation_fn, - weight=( - self.sentence_encoder.embed_tokens.weight - if not args.untie_weights_roberta - else None - ), - ) + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) def forward( self, diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 6f3c79de7c..03e70f4279 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -31,6 +31,7 @@ class 
TransformerEncoderLayer(nn.Module): def __init__(self, args): super().__init__() + self.args = args self.embed_dim = args.encoder_embed_dim self.quant_noise = getattr(args, 'quant_noise_pq', 0) self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py index 3589c60fe6..f869c4b2f8 100644 --- a/fairseq/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -40,6 +40,11 @@ def __init__( # Initialize parameters self.embedding_dim = embedding_dim + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + self.q_noise = q_noise + self.qn_block_size = qn_block_size + self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__ ) From 338aa57966b11a31120e87840d6bb68e74257182 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 20 Jan 2021 07:37:50 -0800 Subject: [PATCH 406/707] Add test for activation checkpointing (#1563) Summary: Forgot to merge this with the original code Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1563 Reviewed By: sshleifer Differential Revision: D25948393 Pulled By: myleott fbshipit-source-id: b083001015e97f7e21cfa02d4126eba79cc34bfa --- tests/test_activation_checkpointing.py | 72 ++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/test_activation_checkpointing.py diff --git a/tests/test_activation_checkpointing.py b/tests/test_activation_checkpointing.py new file mode 100644 index 0000000000..4b86211bde --- /dev/null +++ b/tests/test_activation_checkpointing.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +import torch.nn as nn +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from torch.utils.checkpoint import checkpoint + + +class Model(nn.Module): + def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False): + super().__init__() + torch.manual_seed(0) + self.use_pytorch_checkpoint = use_pytorch_checkpoint + self.ffn = nn.Sequential( + nn.Linear(32, 128), + # add a Dropout layer to test RNG save/restore + nn.Dropout(p=0.5), + nn.Linear(128, 32), + ) + if use_fairseq_checkpoint: + self.ffn = checkpoint_wrapper(self.ffn) + self.out = nn.Linear(32, 1) + + def forward(self, x): + if self.use_pytorch_checkpoint: + x = checkpoint(self.ffn, x) + else: + x = self.ffn(x) + return self.out(x) + + +class TestActivationCheckpointing(unittest.TestCase): + def _test_checkpoint_wrapper(self, device, log_memory_usage=False): + def get_loss_and_gnorm(model): + torch.manual_seed(1) + input = torch.rand(2, 16, 32).requires_grad_(True).to(device) + model.zero_grad() + loss = model(input).sum() + loss.backward() + gnorm = torch.norm( + torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]) + ) + return {"loss": loss, "gnorm": gnorm} + + model = Model().to(device) + no_cpt = get_loss_and_gnorm(model) + + model = Model(use_pytorch_checkpoint=True).to(device) + pyt_cpt = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"]) + + model = Model(use_fairseq_checkpoint=True).to(device) + fairseq_cpt = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"]) + + def test_checkpoint_wrapper_cpu(self): + self._test_checkpoint_wrapper(device=torch.device("cpu")) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_checkpoint_wrapper_cuda(self): + 
self._test_checkpoint_wrapper(device=torch.device("cuda")) + + +if __name__ == "__main__": + unittest.main() From 9fc53d62177b8465b1eca2dd00185540dcd2fb92 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 20 Jan 2021 10:46:20 -0800 Subject: [PATCH 407/707] Support --post-process="@@" (#1571) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1571 Differential Revision: D25974834 Pulled By: myleott fbshipit-source-id: 8cf4c4874087408f76d9d47a6b5bee46c52ac33b --- fairseq/data/data_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 9a0580977d..d98c58a2f4 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -363,8 +363,10 @@ def post_process(sentence: str, symbol: str): sentence = sentence.replace(" ", "").replace("|", " ").strip() elif symbol == "_EOW": sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() - elif symbol in {"subword_nmt", "@@ "}: - sentence = (sentence + " ").replace("@@ ", "").rstrip() + elif symbol in {"subword_nmt", "@@ ", "@@"}: + if symbol == "subword_nmt": + symbol = "@@ " + sentence = (sentence + " ").replace(symbol, "").rstrip() elif symbol == "none": pass elif symbol is not None: From 15867e12841cf16b3e6a60d5efccc169f728b70a Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Wed, 20 Jan 2021 17:59:39 -0800 Subject: [PATCH 408/707] migrate translation task (#1569) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1569 Test Plan: Imported from OSS tests + ran ``` python fairseq_cli/train.py \ ~/data/iwslt14.de-en \ --arch transformer_iwslt_de_en --share-decoder-input-output-embed \ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ --dropout 0.3 
--weight-decay 0.0001 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --max-tokens 4096 \ --eval-bleu \ --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \ --eval-bleu-detok moses \ --eval-bleu-remove-bpe \ --eval-bleu-print-samples \ --best-checkpoint-metric bleu --maximize-best-checkpoint-metric ``` Reviewed By: myleott Differential Revision: D25967217 Pulled By: alexeib fbshipit-source-id: 808f3cb0939fa13e1e05f39bfa02a7fb0b152940 --- .../translation_moe_src/translation_moe.py | 107 +++++--- fairseq/tasks/translation.py | 253 ++++++++++-------- .../tasks/translation_from_pretrained_xlm.py | 12 +- fairseq/tasks/translation_lev.py | 54 ++-- tests/test_checkpoint_utils.py | 7 +- 5 files changed, 248 insertions(+), 185 deletions(-) diff --git a/examples/translation_moe/translation_moe_src/translation_moe.py b/examples/translation_moe/translation_moe_src/translation_moe.py index ae458aaad3..7f28c32dd6 100644 --- a/examples/translation_moe/translation_moe_src/translation_moe.py +++ b/examples/translation_moe/translation_moe_src/translation_moe.py @@ -3,16 +3,52 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field import torch +from omegaconf import II + from fairseq import metrics, utils +from fairseq.dataclass import ChoiceEnum from fairseq.tasks import register_task -from fairseq.tasks.translation import TranslationTask +from fairseq.tasks.translation import TranslationConfig, TranslationTask from .logsumexp_moe import LogSumExpMoE from .mean_pool_gating_network import MeanPoolGatingNetwork -@register_task("translation_moe") +METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"]) + + +@dataclass +class TranslationMoEConfig(TranslationConfig): + method: METHOD_CHOICES = field( + default="hMoEup", + metadata={"help": "MoE method"}, + ) + num_experts: int = field( + default=3, + metadata={"help": "number of experts"}, + ) + mean_pool_gating_network: bool = field( + default=False, + metadata={"help": "use a simple mean-pooling gating network"}, + ) + mean_pool_gating_network_dropout: float = field( + default=0, + metadata={"help": "dropout for mean-pooling gating network"}, + ) + mean_pool_gating_network_encoder_dim: int = field( + default=0, + metadata={"help": "encoder output dim for mean-pooling gating network"}, + ) + gen_expert: int = field( + default=0, + metadata={"help": "which expert to use for generation"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + + +@register_task("translation_moe", dataclass=TranslationMoEConfig) class TranslationMoETask(TranslationTask): """ Translation task for Mixture of Experts (MoE) models. 
@@ -37,77 +73,60 @@ class TranslationMoETask(TranslationTask): :prog: """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - TranslationTask.add_args(parser) - parser.add_argument('--method', default='hMoEup', - choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup']) - parser.add_argument('--num-experts', default=3, type=int, metavar='N', - help='number of experts') - parser.add_argument('--mean-pool-gating-network', action='store_true', - help='use a simple mean-pooling gating network') - parser.add_argument('--mean-pool-gating-network-dropout', type=float, - help='dropout for mean-pooling gating network') - parser.add_argument('--mean-pool-gating-network-encoder-dim', type=float, - help='encoder output dim for mean-pooling gating network') - parser.add_argument('--gen-expert', type=int, default=0, - help='which expert to use for generation') - # fmt: on - - def __init__(self, args, src_dict, tgt_dict): - if args.method == "sMoElp": + cfg: TranslationMoEConfig + + def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict): + if cfg.method == "sMoElp": # soft MoE with learned prior self.uniform_prior = False self.hard_selection = False - elif args.method == "sMoEup": + elif cfg.method == "sMoEup": # soft MoE with uniform prior self.uniform_prior = True self.hard_selection = False - elif args.method == "hMoElp": + elif cfg.method == "hMoElp": # hard MoE with learned prior self.uniform_prior = False self.hard_selection = True - elif args.method == "hMoEup": + elif cfg.method == "hMoEup": # hard MoE with uniform prior self.uniform_prior = True self.hard_selection = True # add indicator tokens for each expert - for i in range(args.num_experts): + for i in range(cfg.num_experts): # add to both dictionaries in case we're sharing embeddings src_dict.add_symbol("<expert_{}>".format(i)) tgt_dict.add_symbol("<expert_{}>".format(i)) - super().__init__(args, src_dict, tgt_dict) + super().__init__(cfg, src_dict, tgt_dict) 
- def build_model(self, args): + def build_model(self, cfg): from fairseq import models - model = models.build_model(args, self) + model = models.build_model(cfg, self) if not self.uniform_prior and not hasattr(model, "gating_network"): - if self.args.mean_pool_gating_network: - if getattr(args, "mean_pool_gating_network_encoder_dim", None): - encoder_dim = args.mean_pool_gating_network_encoder_dim - elif getattr(args, "encoder_embed_dim", None): + if self.cfg.mean_pool_gating_network: + if self.cfg.mean_pool_gating_network_encoder_dim > 0: + encoder_dim = self.cfg.mean_pool_gating_network_encoder_dim + elif getattr(cfg, "encoder_embed_dim", None): # assume that encoder_embed_dim is the encoder's output dimension - encoder_dim = args.encoder_embed_dim + encoder_dim = cfg.encoder_embed_dim else: raise ValueError( "Must specify --mean-pool-gating-network-encoder-dim" ) - if getattr(args, "mean_pool_gating_network_dropout", None): - dropout = args.mean_pool_gating_network_dropout - elif getattr(args, "dropout", None): - dropout = args.dropout + if self.cfg.mean_pool_gating_network_dropout > 0: + dropout = self.cfg.mean_pool_gating_network_dropout + elif getattr(cfg, "dropout", None): + dropout = cfg.dropout else: - raise ValueError("Must specify --mean-pool-gating-network-dropout") + raise ValueError("Must specify task.mean_pool_gating_network_dropout") model.gating_network = MeanPoolGatingNetwork( encoder_dim, - args.num_experts, + self.cfg.num_experts, dropout, ) else: @@ -125,7 +144,7 @@ def _get_loss(self, sample, model, criterion): criterion, "compute_loss" ), "translation_moe task requires the criterion to implement the compute_loss() method" - k = self.args.num_experts + k = self.cfg.num_experts bsz = sample["target"].size(0) def get_lprob_y(encoder_out, prev_output_tokens_k): @@ -185,7 +204,7 @@ def get_lprob_yz(winners=None): loss = loss.sum() sample_size = ( - sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] + 
sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"] ) logging_output = { "loss": utils.item(loss.data), @@ -221,7 +240,7 @@ def inference_step( expert=None, constraints=None, ): - expert = expert or self.args.gen_expert + expert = expert or self.cfg.gen_expert with torch.no_grad(): return generator.generate( models, diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 79007a6d9f..d975fd49d2 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -3,14 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field import itertools import json import logging import os +from typing import Optional from argparse import Namespace +from omegaconf import II import numpy as np -from fairseq import metrics, options, utils +from fairseq import metrics, utils from fairseq.data import ( AppendTokenDataset, ConcatDataset, @@ -22,7 +25,9 @@ encoders, indexed_dataset, ) -from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task EVAL_BLEU_ORDER = 4 @@ -161,8 +166,102 @@ def split_exists(split, src, tgt, lang, data_path): ) -@register_task("translation") -class TranslationTask(LegacyFairseqTask): +@dataclass +class TranslationConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, + metadata={ + "help": "colon separated path to data directories list, will be iterated upon during epochs " + "in round-robin manner; however, valid and test data are always in the first directory " + "to avoid the need for repeating them in all directories" + }, + ) + source_lang: Optional[str] = field( + default=None, + metadata={ + "help": "source language", + "argparse_alias": "-s", + }, + ) + target_lang: Optional[str] = 
field( + default=None, + metadata={ + "help": "target language", + "argparse_alias": "-t", + }, + ) + load_alignments: bool = field( + default=False, metadata={"help": "load the binarized alignments"} + ) + left_pad_source: bool = field( + default=True, metadata={"help": "pad the source on the left"} + ) + left_pad_target: bool = field( + default=False, metadata={"help": "pad the target on the left"} + ) + max_source_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + upsample_primary: int = field( + default=-1, metadata={"help": "the amount of upsample primary dataset"} + ) + truncate_source: bool = field( + default=False, metadata={"help": "truncate source to max-source-positions"} + ) + num_batch_buckets: int = field( + default=0, + metadata={ + "help": "if >0, then bucket source and target lengths into " + "N buckets and pad accordingly; this is useful on TPUs to minimize the number of compilations" + }, + ) + train_subset: str = II("dataset.train_subset") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") + + # options for reporting BLEU during validation + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_args: str = field( + default="{}", + metadata={ + "help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string' + }, + ) + eval_bleu_detok: str = field( + default="space", + metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; " + "use 'space' to disable detokenization; see fairseq.data.encoders for other options" + }, + ) + eval_bleu_detok_args: str = field( + default="{}", + metadata={"help": "args for building the 
tokenizer, if needed, as JSON string"}, + ) + eval_tokenized_bleu: bool = field( + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, + metadata={ + "help": "remove BPE before computing BLEU", + "argparse_const": "@@ ", + }, + ) + eval_bleu_print_samples: bool = field( + default=False, metadata={"help": "print sample generations during validation"} + ) + + +@register_task("translation", dataclass=TranslationConfig) +class TranslationTask(FairseqTask): """ Translate from one (source) language to another (target) language. @@ -174,108 +273,47 @@ class TranslationTask(LegacyFairseqTask): The translation task is compatible with :mod:`fairseq-train`, :mod:`fairseq-generate` and :mod:`fairseq-interactive`. - - The translation task provides the following additional command-line - arguments: - - .. argparse:: - :ref: fairseq.tasks.translation_parser - :prog: """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - parser.add_argument('data', help='colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner; \ - however, valid and test data are always in the first directory to \ - avoid the need for repeating them in all directories') - parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', - help='source language') - parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', - help='target language') - parser.add_argument('--load-alignments', action='store_true', - help='load the binarized alignments') - parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', - help='pad the source on the left') - parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', - help='pad the target on the left') - parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', - help='max number 
of tokens in the source sequence') - parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the target sequence') - parser.add_argument('--upsample-primary', default=1, type=int, - help='amount to upsample primary dataset') - parser.add_argument('--truncate-source', action='store_true', default=False, - help='truncate source to max-source-positions') - parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N', - help='if >0, then bucket source and target lengths into N ' - 'buckets and pad accordingly; this is useful on TPUs ' - 'to minimize the number of compilations') - - # options for reporting BLEU during validation - parser.add_argument('--eval-bleu', action='store_true', - help='evaluation with BLEU scores') - parser.add_argument('--eval-bleu-detok', type=str, default="space", - help='detokenize before computing BLEU (e.g., "moses"); ' - 'required if using --eval-bleu; use "space" to ' - 'disable detokenization; see fairseq.data.encoders ' - 'for other options') - parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON', - help='args for building the tokenizer, if needed') - parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False, - help='compute tokenized BLEU instead of sacrebleu') - parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None, - help='remove BPE before computing BLEU') - parser.add_argument('--eval-bleu-args', type=str, metavar='JSON', - help='generation args for BLUE scoring, ' - 'e.g., \'{"beam": 4, "lenpen": 0.6}\'') - parser.add_argument('--eval-bleu-print-samples', action='store_true', - help='print sample generations during validation') - # fmt: on - - def __init__(self, args, src_dict, tgt_dict): - super().__init__(args) + cfg: TranslationConfig + + def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict): + super().__init__(cfg) self.src_dict = src_dict self.tgt_dict = tgt_dict @classmethod 
- def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: TranslationConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: args (argparse.Namespace): parsed command-line arguments """ - args.left_pad_source = utils.eval_bool(args.left_pad_source) - args.left_pad_target = utils.eval_bool(args.left_pad_target) - paths = utils.split_paths(args.data) + paths = utils.split_paths(cfg.data) assert len(paths) > 0 # find language pair automatically - if args.source_lang is None or args.target_lang is None: - args.source_lang, args.target_lang = data_utils.infer_language_pair( - paths[0] - ) - if args.source_lang is None or args.target_lang is None: + if cfg.source_lang is None or cfg.target_lang is None: + cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0]) + if cfg.source_lang is None or cfg.target_lang is None: raise Exception( "Could not infer language pair, please provide it explicitly" ) # load dictionaries src_dict = cls.load_dictionary( - os.path.join(paths[0], "dict.{}.txt".format(args.source_lang)) + os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang)) ) tgt_dict = cls.load_dictionary( - os.path.join(paths[0], "dict.{}.txt".format(args.target_lang)) + os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang)) ) assert src_dict.pad() == tgt_dict.pad() assert src_dict.eos() == tgt_dict.eos() assert src_dict.unk() == tgt_dict.unk() - logger.info("[{}] dictionary: {} types".format(args.source_lang, len(src_dict))) - logger.info("[{}] dictionary: {} types".format(args.target_lang, len(tgt_dict))) + logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict))) + logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict))) - return cls(args, src_dict, tgt_dict) + return cls(cfg, src_dict, tgt_dict) def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
@@ -283,15 +321,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 - if split != getattr(self.args, "train_subset", None): + if split != self.cfg.train_subset: # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] # infer langcode - src, tgt = self.args.source_lang, self.args.target_lang + src, tgt = self.cfg.source_lang, self.cfg.target_lang self.datasets[split] = load_langpair_dataset( data_path, @@ -301,17 +339,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): tgt, self.tgt_dict, combine=combine, - dataset_impl=self.args.dataset_impl, - upsample_primary=self.args.upsample_primary, - left_pad_source=self.args.left_pad_source, - left_pad_target=self.args.left_pad_target, - max_source_positions=self.args.max_source_positions, - max_target_positions=self.args.max_target_positions, - load_alignments=self.args.load_alignments, - truncate_source=self.args.truncate_source, - num_buckets=self.args.num_batch_buckets, + dataset_impl=self.cfg.dataset_impl, + upsample_primary=self.cfg.upsample_primary, + left_pad_source=self.cfg.left_pad_source, + left_pad_target=self.cfg.left_pad_target, + max_source_positions=self.cfg.max_source_positions, + max_target_positions=self.cfg.max_target_positions, + load_alignments=self.cfg.load_alignments, + truncate_source=self.cfg.truncate_source, + num_buckets=self.cfg.num_batch_buckets, shuffle=(split != "test"), - pad_to_multiple=self.args.required_seq_len_multiple, + pad_to_multiple=self.cfg.required_seq_len_multiple, ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): @@ -323,22 +361,15 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None) constraints=constraints, ) - def 
build_model(self, args): - model = super().build_model(args) - if getattr(args, "eval_bleu", False): - assert getattr(args, "eval_bleu_detok", None) is not None, ( - "--eval-bleu-detok is required if using --eval-bleu; " - "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " - "to disable detokenization, e.g., when using sentencepiece)" - ) - detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}") + def build_model(self, cfg): + model = super().build_model(cfg) + if self.cfg.eval_bleu: + detok_args = json.loads(self.cfg.eval_bleu_detok_args) self.tokenizer = encoders.build_tokenizer( - Namespace( - tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args - ) + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) ) - gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}") + gen_args = json.loads(self.cfg.eval_bleu_args) self.sequence_generator = self.build_generator( [model], Namespace(**gen_args) ) @@ -346,7 +377,7 @@ def build_model(self, args): def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) - if self.args.eval_bleu: + if self.cfg.eval_bleu: bleu = self._inference_with_bleu(self.sequence_generator, sample, model) logging_output["_bleu_sys_len"] = bleu.sys_len logging_output["_bleu_ref_len"] = bleu.ref_len @@ -360,7 +391,7 @@ def valid_step(self, sample, model, criterion): def reduce_metrics(self, logging_outputs, criterion): super().reduce_metrics(logging_outputs, criterion) - if self.args.eval_bleu: + if self.cfg.eval_bleu: def sum_logs(key): return sum(log.get(key, 0) for log in logging_outputs) @@ -399,7 +430,7 @@ def compute_bleu(meters): def max_positions(self): """Return the max sentence length allowed by the task.""" - return (self.args.max_source_positions, self.args.max_target_positions) + return (self.cfg.max_source_positions, self.cfg.max_target_positions) @property def source_dictionary(self): @@ -417,7 +448,7 @@ def 
_inference_with_bleu(self, generator, sample, model): def decode(toks, escape_unk=False): s = self.tgt_dict.string( toks.int().cpu(), - self.args.eval_bleu_remove_bpe, + self.cfg.eval_bleu_remove_bpe, # The default unknown string in fairseq is `<unk>`, but # this is tokenized by sacrebleu as `< unk >`, inflating # BLEU scores. Instead, we use a somewhat more verbose @@ -439,10 +470,10 @@ def decode(toks, escape_unk=False): escape_unk=True, # don't count <unk> as matches to the hypo ) ) - if self.args.eval_bleu_print_samples: + if self.cfg.eval_bleu_print_samples: logger.info("example hypothesis: " + hyps[0]) logger.info("example reference: " + refs[0]) - if self.args.eval_tokenized_bleu: + if self.cfg.eval_tokenized_bleu: return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none") else: return sacrebleu.corpus_bleu(hyps, [refs]) diff --git a/fairseq/tasks/translation_from_pretrained_xlm.py b/fairseq/tasks/translation_from_pretrained_xlm.py index 347a6eccb7..a05f289152 100644 --- a/fairseq/tasks/translation_from_pretrained_xlm.py +++ b/fairseq/tasks/translation_from_pretrained_xlm.py @@ -3,13 +3,21 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary -from fairseq.tasks.translation import TranslationTask +from fairseq.tasks.translation import TranslationConfig, TranslationTask from . 
import register_task -@register_task("translation_from_pretrained_xlm") +@dataclass +class TranslationFromPretrainedXLMConfig(TranslationConfig): + pass + + +@register_task( + "translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig +) class TranslationFromPretrainedXLMTask(TranslationTask): """ Same as TranslationTask except use the MaskedLMDictionary class so that diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index 4678774922..041279305d 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -3,33 +3,35 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import os - +from dataclasses import dataclass, field import torch from fairseq import utils from fairseq.data import LanguagePairDataset +from fairseq.dataclass import ChoiceEnum from fairseq.tasks import register_task -from fairseq.tasks.translation import TranslationTask, load_langpair_dataset +from fairseq.tasks.translation import TranslationConfig, TranslationTask, load_langpair_dataset from fairseq.utils import new_arange -@register_task("translation_lev") +NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"]) + +@dataclass +class TranslationLevenshteinConfig(TranslationConfig): + noise: NOISE_CHOICES = field( + default="random_delete", + metadata={ + "help": "type of noise" + }, + ) + +@register_task("translation_lev", dataclass=TranslationLevenshteinConfig) class TranslationLevenshteinTask(TranslationTask): """ Translation (Sequence Generation) task for Levenshtein Transformer See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_. 
""" - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - TranslationTask.add_args(parser) - parser.add_argument( - '--noise', - default='random_delete', - choices=['random_delete', 'random_mask', 'no_noise', 'full_mask']) - # fmt: on + cfg: TranslationLevenshteinConfig def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. @@ -37,12 +39,12 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] # infer langcode - src, tgt = self.args.source_lang, self.args.target_lang + src, tgt = self.cfg.source_lang, self.cfg.target_lang self.datasets[split] = load_langpair_dataset( data_path, @@ -52,12 +54,12 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): tgt, self.tgt_dict, combine=combine, - dataset_impl=self.args.dataset_impl, - upsample_primary=self.args.upsample_primary, - left_pad_source=self.args.left_pad_source, - left_pad_target=self.args.left_pad_target, - max_source_positions=self.args.max_source_positions, - max_target_positions=self.args.max_target_positions, + dataset_impl=self.cfg.dataset_impl, + upsample_primary=self.cfg.upsample_primary, + left_pad_source=self.cfg.left_pad_source, + left_pad_target=self.cfg.left_pad_target, + max_source_positions=self.cfg.max_source_positions, + max_target_positions=self.cfg.max_target_positions, prepend_bos=True, ) @@ -133,13 +135,13 @@ def _full_mask(target_tokens): ) return target_tokens.masked_fill(~target_mask, unk) - if self.args.noise == "random_delete": + if self.cfg.noise == "random_delete": return _random_delete(target_tokens) - elif self.args.noise == "random_mask": + elif self.cfg.noise == "random_mask": return _random_mask(target_tokens) - elif self.args.noise == 
"full_mask": + elif self.cfg.noise == "full_mask": return _full_mask(target_tokens) - elif self.args.noise == "no_noise": + elif self.cfg.noise == "no_noise": return target_tokens else: raise NotImplementedError diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index e3c685deec..617a5f7c84 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -53,7 +53,7 @@ def _train_transformer(self, seed, extra_args=None): yield os.path.join(data_dir, "checkpoint_last.pt") def test_load_model_ensemble_and_task(self): - with contextlib.redirect_stdout(StringIO()): + # with contextlib.redirect_stdout(StringIO()): with self._train_transformer(seed=123) as model1: with self._train_transformer(seed=456) as model2: ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( @@ -67,7 +67,10 @@ def test_load_model_ensemble_and_task(self): self.assertEqual(ensemble[1].args.seed, 456) # the task from the first model should be returned - self.assertEqual(task.args.seed, 123) + self.assertTrue("seed123" in task.cfg.data) + + # last cfg is saved + self.assertEqual(cfg.common.seed, 456) def test_prune_state_dict(self): with contextlib.redirect_stdout(StringIO()): From fc989279647666d7ae537345954556f76ef326c1 Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Wed, 20 Jan 2021 19:14:14 -0800 Subject: [PATCH 409/707] TALNet Training: Log `logits` and `target` in Wav2VecCriterion Summary: TALNet needs to log the `logits` and `target` arrays in `Wav2VecCriterion` to compute the MAP and MAUC metrics on the validation set, which are used to decide when to reduce the learning rate. 
Reviewed By: alexeib

Differential Revision: D25957415

fbshipit-source-id: 9c8ccfc408dd0747611b36f430d9f0796bc5340d
---
 fairseq/criterions/wav2vec_criterion.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py
index 8a1c348a58..cc454b9309 100644
--- a/fairseq/criterions/wav2vec_criterion.py
+++ b/fairseq/criterions/wav2vec_criterion.py
@@ -102,8 +102,16 @@ def forward(self, model, sample, reduce=True):
         }
 
         for lk in self.log_keys:
-            if lk in net_output:
-                logging_output[lk] = float((net_output[lk]))
+            # Only store "logits" and "target" for computing MAP and MAUC
+            # during validation
+            if lk == "logits":
+                if not self.training:
+                    logging_output["logits"] = logits.cpu().numpy()
+            elif lk == "target":
+                if not self.training:
+                    logging_output["target"] = target.cpu().numpy()
+            elif lk in net_output:
+                logging_output[lk] = float(net_output[lk])
 
         if len(losses) > 1:
             for i, l in enumerate(losses):

From 9a1c49706b35bdfc0e4de1a905298113eb236603 Mon Sep 17 00:00:00 2001
From: alexeib <alexei.b@gmail.com>
Date: Wed, 20 Jan 2021 20:40:52 -0800
Subject: [PATCH 410/707] Make Hydra logging work with DDP (#1568)

Summary:
without this pr, when launching hydra based sweeps via "hydra_train.py", logging would only go to standard out/error files, rather than into "hydra_train.log" file.

The problem is that the standard out/err files are placed in a different folder by submitit without a clear way to modify this behavior, while hydra_train.log is always empty (except when training on one gpu) because either a) reset_logging will remove hydra logging hooks, or b) hydra logging does not work properly with ddp based training.
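
As a rough illustration of the DDP half of the problem described above: each spawned process starts with its own logging state, so the launcher's logging configuration has to be applied again inside the worker. The sketch below assumes that configuration has already been captured as a plain dict. `JOB_LOGGING_CFG` and `configure_worker_logging` are hypothetical names used only for this example, not fairseq or Hydra APIs; the only library call relied on is the standard library's `logging.config.dictConfig`.

```python
import logging
import logging.config

# Hypothetical stand-in for a resolved job-logging section. A real one would
# be supplied by the launcher; the handler and file name here are assumptions.
JOB_LOGGING_CFG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "simple": {"format": "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s"}
    },
    "handlers": {
        "file": {
            "class": "logging.FileHandler",
            "formatter": "simple",
            "filename": "hydra_train.log",
            # delay=True: the file is only opened when the first record is emitted
            "delay": True,
        }
    },
    "root": {"level": "INFO", "handlers": ["file"]},
}


def configure_worker_logging(cfg: dict, is_master: bool) -> bool:
    """Re-apply the captured logging config inside a spawned process.

    Returns True if the config was applied, False if this process skipped it
    (non-master rank, or no captured config present).
    """
    if is_master and "job_logging_cfg" in cfg:
        logging.config.dictConfig(cfg["job_logging_cfg"])
        return True
    return False
```

Restricting the call to the master rank keeps a single writer on the log file; other ranks leave their logging state untouched.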
to address a) we do not remove hydra logging by default (although it is optionally still possible)
to address b) we reconfigure the loggers in the train method which will be called by each spawned process

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1568

Reviewed By: myleott

Differential Revision: D25965658

Pulled By: alexeib

fbshipit-source-id: 77cbd4d310fe2d291fb1003c6a3e27e619d571aa
---
 fairseq/dataclass/configs.py    |  2 +-
 fairseq/dataclass/initialize.py |  4 +++-
 fairseq/distributed_utils.py    |  4 ++++
 fairseq_cli/hydra_train.py      | 12 +++++++++---
 fairseq_cli/train.py            |  8 +++++++-
 5 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index 2968d2ab0f..2ed27284dc 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -181,7 +181,7 @@ class CommonConfig(FairseqDataclass):
         default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
     )
     reset_logging: bool = field(
-        default=True,
+        default=False,
         metadata={
             "help": "when using Hydra, reset the logging at the beginning of training"
         },
diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py
index 385624f19b..479aeb8b16 100644
--- a/fairseq/dataclass/initialize.py
+++ b/fairseq/dataclass/initialize.py
@@ -7,7 +7,7 @@
 import logging
 from hydra.core.config_store import ConfigStore
 from fairseq.dataclass.configs import FairseqConfig
-from omegaconf import DictConfig, open_dict
+from omegaconf import DictConfig, OmegaConf
 
 
 logger = logging.getLogger(__name__)
@@ -36,6 +36,8 @@ def add_defaults(cfg: DictConfig) -> None:
     from fairseq.dataclass.utils import merge_with_parent
     from typing import Any
 
+    OmegaConf.set_struct(cfg, False)
+
     for k, v in FairseqConfig.__dataclass_fields__.items():
         field_cfg = cfg.get(k)
         if field_cfg is not None and v.type == Any:
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index 8f98ac88f9..79deba5b98 100644
---
a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -301,6 +301,10 @@ def distributed_main(i, main, cfg: FairseqConfig, kwargs): main(cfg, **kwargs) + # make sure checkpoints finish saving + if torch.distributed.is_initialized(): + torch.distributed.barrier() + def call_main(cfg: FairseqConfig, main, **kwargs): if cfg.distributed_training.distributed_init_method is None: diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index 6754f9483d..180bd40717 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -14,8 +14,9 @@ from fairseq.dataclass.configs import FairseqConfig import hydra +from hydra.core.hydra_config import HydraConfig import torch -from omegaconf import OmegaConf +from omegaconf import OmegaConf, open_dict logger = logging.getLogger("fairseq_cli.hydra_train") @@ -24,11 +25,16 @@ @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") def hydra_main(cfg: FairseqConfig) -> float: add_defaults(cfg) - cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) - OmegaConf.set_struct(cfg, True) if cfg.common.reset_logging: reset_logging() # Hydra hijacks logging, fix that + else: + with open_dict(cfg): + # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) + cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True) + + cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) + OmegaConf.set_struct(cfg, True) try: if cfg.common.profile: diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 1156222642..5f3fff8e7f 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -25,7 +25,9 @@ utils, ) from fairseq.data import iterators +from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.distributed_utils import is_master from fairseq.logging import meters, 
metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.trainer import Trainer @@ -41,12 +43,16 @@ logger = logging.getLogger("fairseq_cli.train") -def main(cfg: DictConfig) -> None: +def main(cfg: FairseqConfig) -> None: if isinstance(cfg, argparse.Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) + if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: + # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) + logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) + assert ( cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" From cfbf0dddbc2f06b4d2975655a3959d13e5ba6667 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 21 Jan 2021 07:32:08 -0800 Subject: [PATCH 411/707] Small changes to make tests more reliable (#1572) Summary: After this, `python setup.py test` should be more reliable (including when multiple GPUs are present) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1572 Reviewed By: alexeib Differential Revision: D25984113 Pulled By: myleott fbshipit-source-id: 7fef27ae90c079c07f592ed9fb350ccf8b56d23d --- README.md | 4 ++-- fairseq/data/data_utils_fast.pyx | 3 +++ fairseq/distributed_utils.py | 5 +--- fairseq/models/roberta/hub_interface.py | 2 +- fairseq/ngram_repeat_block.py | 2 +- fairseq/sequence_generator.py | 12 ++++++---- setup.py | 31 +++++++++++++------------ tests/distributed/utils.py | 1 + tests/test_binaries.py | 5 ++-- tests/test_bmuf.py | 17 ++++++-------- 10 files changed, 43 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index c22abba8c0..5fedac7eec 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,8 @@ pip install --editable ./ # on MacOS: # CFLAGS="-stdlib=libc++" pip install --editable ./ -# to install the latest stable 
release (0.10.1) -# pip install fairseq==0.10.1 +# to install the latest stable release (0.10.x) +# pip install fairseq ``` * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx index d197d3f00e..c61f31d6b2 100644 --- a/fairseq/data/data_utils_fast.pyx +++ b/fairseq/data/data_utils_fast.pyx @@ -24,6 +24,9 @@ cpdef list batch_by_size_vec( int64_t max_sentences, int32_t bsz_mult, ): + if indices.shape[0] == 0: + return [] + assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, ( f"Sentences lengths should not exceed max_tokens={max_tokens}" ) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 79deba5b98..37822362d4 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -301,10 +301,6 @@ def distributed_main(i, main, cfg: FairseqConfig, kwargs): main(cfg, **kwargs) - # make sure checkpoints finish saving - if torch.distributed.is_initialized(): - torch.distributed.barrier() - def call_main(cfg: FairseqConfig, main, **kwargs): if cfg.distributed_training.distributed_init_method is None: @@ -323,6 +319,7 @@ def call_main(cfg: FairseqConfig, main, **kwargs): torch.cuda.device_count(), cfg.distributed_training.distributed_world_size, ), + join=True, ) else: distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index 0c723f06dd..c9af434bde 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -173,7 +173,7 @@ def fill_mask(self, masked_input: str, topk: int = 5): add_if_not_exist=False, ) - masked_index = (tokens == self.task.mask_idx).nonzero() + masked_index = (tokens == self.task.mask_idx).nonzero(as_tuple=False) if tokens.dim() == 1: tokens = tokens.unsqueeze(0) diff --git a/fairseq/ngram_repeat_block.py 
b/fairseq/ngram_repeat_block.py index 856c9e64f7..ed2d744635 100644 --- a/fairseq/ngram_repeat_block.py +++ b/fairseq/ngram_repeat_block.py @@ -19,7 +19,7 @@ def is_cuda_extension_usable() -> bool: """Check whether ngram_repeat_block_cuda is built properly""" - if not EXTENSION_BUILT: + if not EXTENSION_BUILT or not torch.cuda.is_available(): return False bsz = 2 tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda") diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index b0249888ce..117c6116fb 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -86,8 +86,10 @@ def __init__( self.temperature = temperature self.match_source_len = match_source_len - self.no_repeat_ngram_size = no_repeat_ngram_size - self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None assert temperature > 0, "--temperature must be greater than 0" @@ -373,8 +375,10 @@ def _generate( if self.should_set_src_lengths: self.search.set_src_lengths(src_lengths) - if self.no_repeat_ngram_size > 0: - lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) + if self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker( + tokens, lprobs, bsz, beam_size, step + ) # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( diff --git a/setup.py b/setup.py index 08fe0dcccc..d1a976104e 100644 --- a/setup.py +++ b/setup.py @@ -242,18 +242,19 @@ def get_files(path, relative_to="fairseq"): return all_files -try: - # symlink examples into fairseq package so package_data accepts them - fairseq_examples = os.path.join("fairseq", "examples") - if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples): - os.symlink(os.path.join("..", "examples"), fairseq_examples) - - package_data = { - "fairseq": ( - 
get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config")) - ) - } - do_setup(package_data) -finally: - if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples): - os.unlink(fairseq_examples) +if __name__ == "__main__": + try: + # symlink examples into fairseq package so package_data accepts them + fairseq_examples = os.path.join("fairseq", "examples") + if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples): + os.symlink(os.path.join("..", "examples"), fairseq_examples) + + package_data = { + "fairseq": ( + get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config")) + ) + } + do_setup(package_data) + finally: + if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples): + os.unlink(fairseq_examples) diff --git a/tests/distributed/utils.py b/tests/distributed/utils.py index d2b3ddb1ff..c8040392a8 100644 --- a/tests/distributed/utils.py +++ b/tests/distributed/utils.py @@ -17,6 +17,7 @@ def spawn_and_init(fn, world_size, args=None): fn=functools.partial(init_and_run, fn, args), args=(world_size, tmp_file.name,), nprocs=world_size, + join=True, ) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index ddfc1c4db5..967b1bf7ba 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1623,8 +1623,9 @@ def _train(extra_flags): with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) ckpt_logs = _train(["--checkpoint-activations"]) baseline_logs = _train([]) assert len(baseline_logs) == len(ckpt_logs) diff --git a/tests/test_bmuf.py b/tests/test_bmuf.py index e7aa6da1ca..785da37bc2 100644 --- a/tests/test_bmuf.py +++ b/tests/test_bmuf.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import functools import random import unittest from multiprocessing import Manager @@ -141,16 +142,12 @@ class TestBMUF(unittest.TestCase): def bmuf_process(self, cfg, args, iterations): processes = [] results = Manager().dict() - ctx = torch.multiprocessing.get_context("spawn") - for rank in range(args.distributed_world_size): - p = ctx.Process( - target=single_gpu_training, args=(cfg, args, rank, iterations, results) - ) - p.start() - processes.append(p) - - for p in processes: - p.join() + torch.multiprocessing.spawn( + fn=functools.partial(single_gpu_training, cfg, args), + args=(iterations, results), + nprocs=args.distributed_world_size, + join=True, + ) return results def test_bmuf_sync(self): From bfcc13e20a6cfa18fb25daaae39644f9b7872699 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Sat, 23 Jan 2021 03:28:48 -0800 Subject: [PATCH 412/707] add return type hints for readability (#1575) Summary: As I was going through the dataset/preprocessing code, knowing return types would have made my life easier. Also I don't think you need to inherit from object in python3. 
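To illustrate the two points in the summary, here is a minimal sketch. It is not taken from the fairseq sources: `OldStyle`, `NewStyle`, `chunk_offsets` and `count_stats` are hypothetical names chosen for this note. It shows that an explicit `object` base changes nothing in Python 3, and how a return annotation tells the reader what a helper hands back without reading its body.

```python
from typing import Dict, List


class OldStyle(object):  # explicit base class, only needed in Python 2
    pass


class NewStyle:  # Python 3 adds `object` as the base implicitly
    pass


def chunk_offsets(size: int, num_chunks: int) -> List[int]:
    """Hypothetical helper: start offset of each of `num_chunks` equal chunks."""
    chunk_size = size // num_chunks
    return [i * chunk_size for i in range(num_chunks)]


def count_stats(lines: List[str]) -> Dict[str, int]:
    """Hypothetical helper: sentence and whitespace-token counts for `lines`."""
    return {"nseq": len(lines), "ntok": sum(len(line.split()) for line in lines)}
```

Both classes end up with the same method resolution order after the class itself, so dropping `(object)` should be a purely cosmetic change.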
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1575 Reviewed By: myleott Differential Revision: D25999257 Pulled By: sshleifer fbshipit-source-id: 818a623c68fb7812306c760f3ae6346a14937c51 --- fairseq/binarizer.py | 11 +++++++---- fairseq/data/dictionary.py | 4 ++-- fairseq/data/indexed_dataset.py | 12 ++++++------ fairseq/data/lm_context_window_dataset.py | 11 ++++++++--- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index c736c8754d..18ae67bf25 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -9,6 +9,7 @@ import torch from fairseq.file_io import PathManager from fairseq.tokenizer import tokenize_line +from typing import List, Dict def safe_readline(f): @@ -33,7 +34,7 @@ def binarize( offset=0, end=-1, already_numberized=False, - ): + ) -> Dict[str, int]: nseq, ntok = 0, 0 replaced = Counter() @@ -52,7 +53,7 @@ def replaced_consumer(word, idx): # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely # that the procedure breaks by the undeterministic behavior of # f.tell() - if end > 0 and f.tell() > end and f.tell() < end + 2**32: + if end > 0 and f.tell() > end and f.tell() < end + 2 ** 32: break if already_numberized: id_strings = line.strip().split() @@ -83,7 +84,9 @@ def replaced_consumer(word, idx): } @staticmethod - def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1): + def binarize_alignments( + filename, alignment_parser, consumer, offset=0, end=-1 + ) -> Dict[str, int]: nseq = 0 with open(PathManager.get_local_path(filename), "r") as f: @@ -99,7 +102,7 @@ def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1): return {"nseq": nseq} @staticmethod - def find_offsets(filename, num_chunks): + def find_offsets(filename, num_chunks) -> List[int]: with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: size = os.fstat(f.fileno()).st_size chunk_size = size // num_chunks diff --git 
a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 127d023f4c..8d219e20ef 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -15,7 +15,7 @@ from fairseq.tokenizer import tokenize_line -class Dictionary(object): +class Dictionary: """A mapping from symbols to consecutive integers""" def __init__( @@ -298,7 +298,7 @@ def encode_line( consumer=None, append_eos=True, reverse_order=False, - ): + ) -> torch.IntTensor: words = line_tokenizer(line) if reverse_order: words = list(reversed(words)) diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 827754d848..a821417321 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -159,7 +159,7 @@ def __del__(self): self.data_file.close() @lru_cache(maxsize=8) - def __getitem__(self, i): + def __getitem__(self, i) -> torch.Tensor: if not self.data_file: self.read_data(self.path) self.check_index(i) @@ -296,7 +296,7 @@ def exists(path): return PathManager.exists(path) -class IndexedDatasetBuilder(object): +class IndexedDatasetBuilder: element_sizes = { np.uint8: 1, np.int8: 1, @@ -363,12 +363,12 @@ def _warmup_mmap_file(path): class MMapIndexedDataset(torch.utils.data.Dataset): - class Index(object): + class Index: _HDR_MAGIC = b"MMIDIDX\x00\x00" @classmethod def writer(cls, path, dtype): - class _Writer(object): + class _Writer: def __enter__(self): self._file = open(path, "wb") @@ -517,7 +517,7 @@ def exists(path): ) -def get_indexed_dataset_to_local(path): +def get_indexed_dataset_to_local(path) -> str: local_index_path = PathManager.get_local_path(index_file_path(path)) local_data_path = PathManager.get_local_path(data_file_path(path)) @@ -531,7 +531,7 @@ def get_indexed_dataset_to_local(path): return local_path -class MMapIndexedDatasetBuilder(object): +class MMapIndexedDatasetBuilder: def __init__(self, out_file, dtype=np.int64): self._data_file = open(out_file, "wb") self._dtype = dtype diff --git 
a/fairseq/data/lm_context_window_dataset.py b/fairseq/data/lm_context_window_dataset.py index 39512797bc..1a945927cf 100644 --- a/fairseq/data/lm_context_window_dataset.py +++ b/fairseq/data/lm_context_window_dataset.py @@ -5,6 +5,8 @@ import numpy as np import torch +from typing import Dict + from fairseq.data.monolingual_dataset import MonolingualDataset from . import FairseqDataset @@ -26,7 +28,11 @@ class LMContextWindowDataset(FairseqDataset): """ def __init__( - self, dataset, tokens_per_sample: int, context_window: int, pad_idx: int + self, + dataset: MonolingualDataset, + tokens_per_sample: int, + context_window: int, + pad_idx: int, ): assert context_window > 0 self.dataset = dataset @@ -41,7 +47,7 @@ def __getitem__(self, index): def __len__(self): return len(self.dataset) - def collater(self, samples): + def collater(self, samples) -> Dict: sample = self.dataset.collater(samples) pad = self.pad_idx @@ -71,7 +77,6 @@ def collater(self, samples): sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks) sample["target"] = torch.from_numpy(new_tgt) sample["start_indices"] = start_idxs - return sample def num_tokens(self, index): From 1e6323e9346c172225f20415735c33f96dd9aad1 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@fb.com> Date: Mon, 25 Jan 2021 09:22:08 -0800 Subject: [PATCH 413/707] Offload inputs to CPU (V2) Reviewed By: myleott Differential Revision: D26035523 fbshipit-source-id: 7dc08a38c10d1f26a871106f143f92fd11f6073c --- .../truncated_bptt/transformer_xl_model.py | 10 +- fairseq/models/transformer.py | 26 +++- fairseq/models/transformer_lm.py | 8 ++ fairseq/modules/checkpoint_activations.py | 23 +++- tests/test_binaries.py | 116 ++++++++++++------ 5 files changed, 131 insertions(+), 52 deletions(-) diff --git a/examples/truncated_bptt/transformer_xl_model.py b/examples/truncated_bptt/transformer_xl_model.py index 83b248479e..a6c8b25a07 100644 --- a/examples/truncated_bptt/transformer_xl_model.py +++ 
b/examples/truncated_bptt/transformer_xl_model.py @@ -37,6 +37,7 @@ class TransformerXLConfig(FairseqDataclass): dropout: float = 0.0 dropatt: float = 0.0 checkpoint_activations: bool = False + offload_activations: bool = False max_target_positions: int = II("task.max_target_positions") @@ -51,7 +52,8 @@ class TransformerXLDecoder(FairseqIncrementalDecoder): def __init__(self, cfg, task): try: from transformers.models.transfo_xl import ( - TransfoXLConfig, TransfoXLLMHeadModel + TransfoXLConfig, + TransfoXLLMHeadModel, ) except ImportError: from transformers.configuration_transfo_xl import TransfoXLConfig @@ -96,11 +98,13 @@ def __init__(self, cfg, task): except Exception: pass - if cfg.checkpoint_activations: + if cfg.checkpoint_activations or cfg.offload_activations: for i in range(len(self.model.transformer.layers)): self.model.transformer.layers[i] = checkpoint_wrapper( - self.model.transformer.layers[i] + self.model.transformer.layers[i], + offload_to_cpu=cfg.offload_activations, ) + # TODO: may save mem to wrap(layer.pos_ff.CoreNet[3]) self._mems = None diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index fa4c29855b..362d9b28d6 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -167,6 +167,8 @@ def add_args(parser): parser.add_argument('--checkpoint-activations', action='store_true', help='checkpoint activations at each layer, which saves GPU ' 'memory usage at the cost of some additional compute') + parser.add_argument('--offload-activations', action='store_true', + help='checkpoint activations at each layer, then offload them to CPU.
Sets --checkpoint-activations.') # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) parser.add_argument('--no-cross-attention', default=False, action='store_true', help='do not perform cross-attention') @@ -234,7 +236,8 @@ def build_model(cls, args, task): decoder_embed_tokens = cls.build_embedding( args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path ) - + if getattr(args, "offload_activations", False): + args.checkpoint_activations = True # offloading implies checkpointing encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) return cls(args, encoder, decoder) @@ -380,7 +383,8 @@ def __init__(self, args, dictionary, embed_tokens): def build_encoder_layer(self, args): layer = TransformerEncoderLayer(args) if getattr(args, "checkpoint_activations", False): - layer = checkpoint_wrapper(layer) + offload_to_cpu = getattr(args, "offload_activations", False) + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) return layer def forward_embedding( @@ -670,7 +674,8 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): def build_decoder_layer(self, args, no_encoder_attn=False): layer = TransformerDecoderLayer(args, no_encoder_attn) if getattr(args, "checkpoint_activations", False): - layer = checkpoint_wrapper(layer) + offload_to_cpu = getattr(args, "offload_activations", False) + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) return layer def forward( @@ -947,6 +952,17 @@ def Linear(in_features, out_features, bias=True): return m +@register_model_architecture("transformer", "transformer_tiny") +def tiny_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64) + args.encoder_layers = getattr(args, "encoder_layers", 2) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2) + 
args.decoder_layers = getattr(args, "decoder_layers", 2) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2) + return base_architecture(args) + + @register_model_architecture("transformer", "transformer") def base_architecture(args): args.encoder_embed_path = getattr(args, "encoder_embed_path", None) @@ -991,7 +1007,9 @@ def base_architecture(args): args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) args.checkpoint_activations = getattr(args, "checkpoint_activations", False) - + args.offload_activations = getattr(args, "offload_activations", False) + if args.offload_activations: + args.checkpoint_activations = True args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index d86b68b508..edf62b12b3 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -144,6 +144,10 @@ class TransformerLanguageModelConfig(FairseqDataclass): checkpoint_activations: bool = field( default=False, metadata={"help": "checkpoint activations at each layer"} ) + offload_activations: bool = field( + default=False, + metadata={"help": "move checkpointed activations to CPU after they are used."}, + ) quant_noise_pq: float = field( default=0.0, metadata={"help": "iterative PQ quantization noise at training time"}, @@ -171,6 +175,7 @@ class TransformerLanguageModel(FairseqLanguageModel): def hub_models(cls): def moses_fastbpe(path): return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"} + def spm(path): return {"path": path, "tokenizer": "space", "bpe": "sentencepiece"} @@ -321,6 +326,9 @@ def base_lm_architecture(args): args.no_scale_embedding = getattr(args, "no_scale_embedding", False) 
args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + args.offload_activations = getattr(args, "offload_activations", False) + if args.offload_activations: + args.checkpoint_activations = True @register_model_architecture("transformer_lm", "transformer_lm_big") diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index e0e5679c5a..c84e70bf7b 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -11,7 +11,7 @@ from fairseq import utils -def checkpoint_wrapper(m): +def checkpoint_wrapper(m, offload_to_cpu=False): """ A friendlier wrapper for performing activation checkpointing. @@ -22,7 +22,7 @@ def checkpoint_wrapper(m): Usage:: - checkpointed_module = checkpoint_wrapper(my_module) + checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True) a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) """ original_forward = m.forward @@ -32,7 +32,7 @@ def _checkpointed_forward(*args, **kwargs): # the backward must return gradients (or None) for every input argument. # We can flatten keyword arguments to make this easier. 
kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) - parent_ctx_dict = {} + parent_ctx_dict = {"offload": offload_to_cpu} output = CheckpointFunction.apply( original_forward, parent_ctx_dict, kwarg_keys, *flat_args ) @@ -141,6 +141,14 @@ def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): ctx.fwd_rng_state = utils.get_rng_state() tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) + if parent_ctx_dict["offload"]: + ctx.fwd_device = tuple(x.device for x in tensor_inputs) + ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) + tensor_inputs = tuple(x.cpu() for x in tensor_inputs) + + else: + ctx.fwd_device, ctx.grad_requirements = None, None + ctx.save_for_backward(*tensor_inputs) ctx.packed_non_tensor_inputs = packed_non_tensor_inputs @@ -165,8 +173,14 @@ def backward(ctx, *args): "Checkpointing is not compatible with .grad(), please use .backward() if possible" ) - tensor_inputs = ctx.saved_tensors + tensor_inputs: Tuple = ctx.saved_tensors tensor_inputs = checkpoint.detach_variable(tensor_inputs) + if ctx.fwd_device is not None: + tensor_inputs = [ + t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs) + ] + for i, need_grad in enumerate(ctx.grad_requirements): + tensor_inputs[i].requires_grad = need_grad inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) # Store the current states. @@ -179,7 +193,6 @@ def backward(ctx, *args): unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) tensor_outputs, _ = split_non_tensors(outputs) - # Set the states back to what it was at the start of this function. 
utils.set_rng_state(bwd_rng_state) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 967b1bf7ba..981ffd49cd 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -547,16 +547,16 @@ def test_translation_multi_simple_epoch_src_tgt_dict_spec(self): "test_translation_multi_simple_epoch_dict" ) as data_dir: create_dummy_data(data_dir) - preprocess_translation_data( - data_dir, extra_flags=[] - ) + preprocess_translation_data(data_dir, extra_flags=[]) train_translation_model( data_dir, arch="transformer", task="translation_multi_simple_epoch", extra_flags=[ - "--source-dict", f"{data_dir}/dict.in.txt", - "--target-dict", f"{data_dir}/dict.out.txt", + "--source-dict", + f"{data_dir}/dict.in.txt", + "--target-dict", + f"{data_dir}/dict.out.txt", "--encoder-layers", "2", "--decoder-layers", @@ -1250,6 +1250,20 @@ def test_transformer_xl_bptt_lm(self): extra_valid_flags=task_flags, ) eval_lm_main(data_dir, extra_flags=task_flags) + # Train with activation offloading + train_language_model( + data_dir=data_dir, + arch="transformer_xl", + extra_flags=task_flags + + [ + "--n-layer", + "2", + "--offload-activations", + ], + task="truncated_bptt_lm", + run_validation=True, + extra_valid_flags=task_flags, + ) class TestMaskedLanguageModel(unittest.TestCase): @@ -1589,45 +1603,67 @@ def read_last_log_entry( class TestActivationCheckpointing(unittest.TestCase): + base_flags = [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--restore-file", + "x.pt", + "--log-format", + "json", + "--log-interval", + "1", + "--max-update", + "2", + ] + + def _train(self, data_dir, extra_flags): + with self.assertLogs() as logs: + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + self.base_flags + extra_flags, + run_validation=True, + extra_valid_flags=["--log-format", "json"], + ) + return logs.records + + def test_activation_offloading_does_not_change_metrics(self): + 
"""Neither ----checkpoint-activations nor --offload-activations should change loss""" + with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: + + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) + offload_logs = self._train(data_dir, ["--offload-activations"]) + baseline_logs = self._train(data_dir, []) + + assert len(baseline_logs) == len(offload_logs) + + baseline_valid_stats = read_last_log_entry(baseline_logs, "valid") + offload_valid_stats = read_last_log_entry(offload_logs, "valid") + baseline_train_stats = read_last_log_entry(baseline_logs, "train") + offload_train_stats = read_last_log_entry(offload_logs, "train") + + assert ( + baseline_train_stats["train_loss"] == offload_train_stats["train_loss"] + ) + assert ( + baseline_valid_stats["valid_loss"] == offload_valid_stats["valid_loss"] + ) + def test_activation_checkpointing_does_not_change_metrics(self): """--checkpoint-activations should not change loss""" - base_flags = [ - "--encoder-layers", - "2", - "--decoder-layers", - "2", - "--encoder-embed-dim", - "8", - "--decoder-embed-dim", - "8", - "--restore-file", - "x.pt", - "--log-format", - "json", - "--log-interval", - "1", - "--max-update", - "2", - ] - - def _train(extra_flags): - with self.assertLogs() as logs: - train_translation_model( - data_dir, - "transformer_iwslt_de_en", - base_flags + extra_flags, - run_validation=True, - extra_valid_flags=["--log-format", "json"], - ) - return logs.records with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - - with self.assertLogs(): - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) - ckpt_logs = _train(["--checkpoint-activations"]) - baseline_logs = _train([]) + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) + ckpt_logs = self._train(data_dir, ["--checkpoint-activations"]) + baseline_logs = self._train(data_dir, []) assert len(baseline_logs) == 
len(ckpt_logs) baseline_train_stats = read_last_log_entry(baseline_logs, "train") From 1f23d83fcff3718805447b8d83edd666cf586850 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 27 Jan 2021 10:45:38 -0800 Subject: [PATCH 414/707] Fix masked_lm task to be more deterministic when reloading checkpoints (#1581) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1581 Test Plan: Imported from OSS Reviewed By: joshim5 Differential Revision: D26018994 Pulled By: myleott fbshipit-source-id: 9013b2795936bd5877fb9e8a8397c9b0ea35ef60 --- fairseq/data/mask_tokens_dataset.py | 5 ++++- fairseq/tasks/masked_lm.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index b239013c80..9123235594 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -108,8 +108,11 @@ def set_epoch(self, epoch, **unused): super().set_epoch(epoch) self.epoch = epoch - @lru_cache(maxsize=8) def __getitem__(self, index: int): + return self.__getitem_cached__(self.seed, self.epoch, index) + + @lru_cache(maxsize=8) + def __getitem_cached__(self, seed: int, epoch: int, index: int): with data_utils.numpy_seed(self.seed, self.epoch, index): item = self.dataset[index] sz = len(item) diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 70208bc4d5..fd2ea6ade1 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -193,7 +193,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): mask_stdev=self.args.mask_stdev, ) - with data_utils.numpy_seed(self.args.seed + epoch): + with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( From c4a4562299eb58f295426b9b81b36701390e883e Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 27 Jan 2021 10:45:38 -0800 Subject: [PATCH 415/707] Reset meters when 
reloading an end-of-epoch checkpoint (#1582) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1582 Test Plan: Imported from OSS Reviewed By: joshim5 Differential Revision: D26018993 Pulled By: myleott fbshipit-source-id: a0e988d8a2d720443571814d8d9f51acc571f98f --- fairseq/trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index fec60f7742..5035751741 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -376,7 +376,8 @@ def load_checkpoint( self.set_num_updates(last_optim["num_updates"]) if extra_state is not None: - epoch = extra_state["train_iterator"]["epoch"] + itr_state = extra_state["train_iterator"] + epoch = itr_state["epoch"] if "previous_training_time" in extra_state: self._previous_training_time = extra_state["previous_training_time"] @@ -384,6 +385,10 @@ def load_checkpoint( self.lr_step(epoch) + if itr_state["version"] >= 2 and itr_state["iterations_in_epoch"] == 0: + # reset meters at start of epoch + reset_meters = True + if "metrics" in extra_state and not reset_meters: metrics.load_state_dict(extra_state["metrics"]) From 6225dccb989ebfb268274bad36a794b27e4dd43f Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 27 Jan 2021 10:45:38 -0800 Subject: [PATCH 416/707] Use a fixed epoch (=1) when creating validation batch iterators (#1583) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1583 Test Plan: Imported from OSS Reviewed By: joshim5 Differential Revision: D26018992 Pulled By: myleott fbshipit-source-id: f4908358f8db79b6034639419d47698f933ddf91 --- fairseq/data/iterators.py | 24 +++++++++++++++++------- fairseq/trainer.py | 3 +++ fairseq_cli/train.py | 4 +++- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 00bf41375c..66eaf875cb 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -109,15 +109,19 
@@ def __len__(self) -> int: def next_epoch_idx(self): raise NotImplementedError - def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): """Return a new iterator over the dataset. Args: shuffle (bool, optional): shuffle batches before returning the iterator (default: True). - fix_batches_to_gpus: ensure that batches are always + fix_batches_to_gpus (bool, optional): ensure that batches are always allocated to the same shards across epochs. Requires that :attr:`dataset` supports prefetching (default: False). + set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True). """ raise NotImplementedError @@ -193,9 +197,11 @@ def next_epoch_idx(self): else: return self.epoch - def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): self.epoch = self.next_epoch_idx - if hasattr(self.dataset, "set_epoch"): + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): self.dataset.set_epoch(self.epoch) self._current_epoch_iterator = self._get_iterator_for_epoch(self.epoch, shuffle) return self._current_epoch_iterator @@ -355,20 +361,24 @@ def next_epoch_idx(self): else: return self.epoch - def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): """Return a new iterator over the dataset. Args: shuffle (bool, optional): shuffle batches before returning the iterator (default: True). - fix_batches_to_gpus: ensure that batches are always + fix_batches_to_gpus (bool, optional): ensure that batches are always allocated to the same shards across epochs. Requires that :attr:`dataset` supports prefetching (default: False). 
+ set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True). """ if self.disable_shuffling: shuffle = False self.epoch = self.next_epoch_idx - if hasattr(self.dataset, "set_epoch"): + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): self.dataset.set_epoch(self.epoch) if self._next_epoch_itr is not None: self._cur_epoch_itr = self._next_epoch_itr diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5035751741..eea194b950 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -468,6 +468,9 @@ def get_valid_iterator( num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, num_workers=self.cfg.dataset.num_workers, + # always pass a fixed "epoch" to keep validation data consistent + # across training epochs + epoch=1, data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 5f3fff8e7f..9af7568a77 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -378,7 +378,9 @@ def validate( logger.info('begin validation on "{}" subset'.format(subset)) # Initialize data iterator - itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) + itr = trainer.get_valid_iterator(subset).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( From 96da4d38eb60c3971ca88df9383e09d53493cf1e Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 28 Jan 2021 14:18:48 -0800 Subject: [PATCH 417/707] Small Hydra fix (#1543) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1543 Test Plan: Imported from OSS Reviewed By: girifb Differential Revision: D25836852 Pulled By: myleott fbshipit-source-id: 7fda711d21f2d1b7bac26792233997e8dea2f835 --- fairseq/dataclass/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
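Note for reviewers: a minimal sketch of the guard this patch adds, assuming a plain `dict` of argparse keyword arguments. `strip_default` is a hypothetical helper used only for illustration and is not part of fairseq. Without the membership check, `del kwargs["default"]` raises `KeyError` when the key is already absent, for example because an earlier branch removed it.

```python
from typing import Any, Dict


def strip_default(kwargs: Dict[str, Any], delete_default: bool) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of kwargs without "default" if requested."""
    cleaned = dict(kwargs)
    # Guard the deletion so that a missing key does not raise KeyError.
    if delete_default and "default" in cleaned:
        del cleaned["default"]
    return cleaned
```

An equivalent one-liner would be `kwargs.pop("default", None)`; the explicit check mirrors the condition used in the diff.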
diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 401c212ecc..a4d4a412dd 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -162,7 +162,7 @@ def get_kwargs_from_dc( continue else: del kwargs["default"] - if delete_default: + if delete_default and "default" in kwargs: del kwargs["default"] try: parser.add_argument(*field_args, **kwargs) From 5e343f5f23b4a90cca2beec416b87d4dd7a4264f Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 28 Jan 2021 14:18:48 -0800 Subject: [PATCH 418/707] Remove --distributed-wrapper (consolidate to --ddp-backend) (#1544) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1544 Test Plan: Imported from OSS Reviewed By: girifb Differential Revision: D25836856 Pulled By: myleott fbshipit-source-id: eb0a6a02f4d9fe2b6b12a456ef95208dd92c97cb --- examples/cross_lingual_language_model/README.md | 2 +- examples/language_model/README.adaptive_inputs.md | 2 +- examples/language_model/README.conv.md | 2 +- examples/latent_depth/README.md | 2 +- examples/mbart/README.md | 2 +- examples/nonautoregressive_translation/README.md | 2 +- examples/nonautoregressive_translation/scripts.md | 12 ++++++------ examples/pay_less_attention_paper/README.md | 6 +++--- examples/quant_noise/README.md | 10 +++++----- examples/roberta/README.race.md | 2 +- examples/roberta/commonsense_qa/README.md | 2 +- examples/roberta/wsc/README.md | 4 ++-- examples/translation/README.md | 2 +- examples/translation_moe/README.md | 2 +- examples/wav2vec/config/finetuning/base_100h.yaml | 2 +- examples/wav2vec/config/finetuning/base_10h.yaml | 2 +- examples/wav2vec/config/finetuning/base_10m.yaml | 2 +- examples/wav2vec/config/finetuning/base_1h.yaml | 2 +- examples/wav2vec/config/finetuning/base_960h.yaml | 2 +- examples/wav2vec/config/finetuning/vox_100h.yaml | 2 +- examples/wav2vec/config/finetuning/vox_10h.yaml | 2 +- examples/wav2vec/config/finetuning/vox_10m.yaml | 2 +- 
examples/wav2vec/config/finetuning/vox_1h.yaml | 2 +- examples/wav2vec/config/finetuning/vox_960h.yaml | 2 +- .../pretraining/wav2vec2_base_librispeech.yaml | 2 +- .../config/pretraining/wav2vec2_large_librivox.yaml | 2 +- fairseq/criterions/adaptive_loss.py | 6 +++--- fairseq/dataclass/configs.py | 8 ++------ fairseq/dataclass/constants.py | 9 +++++++-- fairseq/distributed_utils.py | 2 +- fairseq/models/distributed_fairseq_model.py | 6 +++--- fairseq/trainer.py | 9 +++++---- 32 files changed, 59 insertions(+), 57 deletions(-) diff --git a/examples/cross_lingual_language_model/README.md b/examples/cross_lingual_language_model/README.md index f4c76cfed5..af9128e39e 100644 --- a/examples/cross_lingual_language_model/README.md +++ b/examples/cross_lingual_language_model/README.md @@ -68,7 +68,7 @@ fairseq-train \ --dataset-impl lazy --seed 0 \ --masked-lm-only \ --monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \ ---ddp-backend=no_c10d +--ddp-backend=legacy_ddp ``` Some Notes: diff --git a/examples/language_model/README.adaptive_inputs.md b/examples/language_model/README.adaptive_inputs.md index 98043c5377..6650d58f37 100644 --- a/examples/language_model/README.adaptive_inputs.md +++ b/examples/language_model/README.adaptive_inputs.md @@ -22,7 +22,7 @@ fairseq-train --task language_modeling \ --max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \ --warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 0.0001 --clip-norm 0.1 \ --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \ - --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d + --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=legacy_ddp ``` ## Citation diff --git a/examples/language_model/README.conv.md b/examples/language_model/README.conv.md index f0b6a3a921..1ff8635906 100644 --- 
a/examples/language_model/README.conv.md +++ b/examples/language_model/README.conv.md @@ -17,7 +17,7 @@ fairseq-train --task language_modeling \ --optimizer nag --clip-norm 0.1 --weight-decay 5e-06 \ --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --max-tokens 1024 --tokens-per-sample 1024 \ - --ddp-backend no_c10d \ + --ddp-backend legacy_ddp \ --max-epoch 35 ``` diff --git a/examples/latent_depth/README.md b/examples/latent_depth/README.md index e70e16405c..7774c33305 100644 --- a/examples/latent_depth/README.md +++ b/examples/latent_depth/README.md @@ -30,7 +30,7 @@ fairseq-train ${databin_dir} \ --lr 0.0015 \ --clip-norm 1.0 \ --seed 2 \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --encoder-layers 12 \ --decoder-layers 24 \ --decoder-latent-layer \ diff --git a/examples/mbart/README.md b/examples/mbart/README.md index 8a3e22d425..a45e37243c 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -81,7 +81,7 @@ fairseq-train path_2_data \ --restore-file $PRETRAIN \ --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \ --langs $langs \ - --ddp-backend no_c10d + --ddp-backend legacy_ddp ``` ## Generate on EN-RO Get sacrebleu on finetuned en-ro model diff --git a/examples/nonautoregressive_translation/README.md b/examples/nonautoregressive_translation/README.md index 7b2d42a91d..8793e225c9 100644 --- a/examples/nonautoregressive_translation/README.md +++ b/examples/nonautoregressive_translation/README.md @@ -36,7 +36,7 @@ The following command will train a *Levenshtein Transformer* on the binarized da fairseq-train \ data-bin/wmt14_en_de_distill \ --save-dir checkpoints \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task translation_lev \ --criterion nat_loss \ --arch levenshtein_transformer \ diff --git a/examples/nonautoregressive_translation/scripts.md b/examples/nonautoregressive_translation/scripts.md index a3a33e6e02..9d3d7b67dc 100644 --- 
a/examples/nonautoregressive_translation/scripts.md +++ b/examples/nonautoregressive_translation/scripts.md @@ -6,7 +6,7 @@ Note that we need to have an additional module to perform "length prediction" (` fairseq-train \ data-bin/wmt14_en_de_distill \ --save-dir checkpoints \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task translation_lev \ --criterion nat_loss \ --arch nonautoregressive_transformer \ @@ -35,7 +35,7 @@ Note that we implemented a low-rank appromixated CRF model by setting `--crf-low fairseq-train \ data-bin/wmt14_en_de_distill \ --save-dir checkpoints \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task translation_lev \ --criterion nat_loss \ --arch nacrf_transformer \ @@ -68,7 +68,7 @@ Note that `--train-step` means how many iterations of refinement we used during fairseq-train \ data-bin/wmt14_en_de_distill \ --save-dir checkpoints \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task translation_lev \ --criterion nat_loss \ --arch iterative_nonautoregressive_transformer \ @@ -101,7 +101,7 @@ Note that we need to specify the "slot-loss" (uniform or balanced tree) describe fairseq-train \ data-bin/wmt14_en_de_distill \ --save-dir checkpoints \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task translation_lev \ --criterion nat_loss \ --arch insertion_transformer \ @@ -128,7 +128,7 @@ fairseq-train \ fairseq-train \ data-bin/wmt14_en_de_distill \ --save-dir checkpoints \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task translation_lev \ --criterion nat_loss \ --arch cmlm_transformer \ @@ -157,7 +157,7 @@ fairseq-train \ fairseq-train \ data-bin/wmt14_en_de_distill \ --save-dir checkpoints \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task translation_lev \ --criterion nat_loss \ --arch levenshtein_transformer \ diff --git a/examples/pay_less_attention_paper/README.md b/examples/pay_less_attention_paper/README.md index d5b19af6cc..5adab11f4d 100644 --- 
a/examples/pay_less_attention_paper/README.md +++ b/examples/pay_less_attention_paper/README.md @@ -113,7 +113,7 @@ CUDA_VISIBLE_DEVICES=0 $(which fairseq-train) data-bin/iwslt14.tokenized.de-en \ --log-interval 100 --stop-min-lr '1e-09' --weight-decay 0.0001 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --lr-scheduler inverse_sqrt \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --max-update 50000 --warmup-updates 4000 --warmup-init-lr '1e-07' \ --adam-betas '(0.9, 0.98)' --keep-last-epochs 10 \ -a lightconv_iwslt_de_en --save-dir $SAVE \ @@ -138,7 +138,7 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ - --ddp-backend=no_c10d --max-tokens 3584 \ + --ddp-backend=legacy_ddp --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 20000 \ @@ -163,7 +163,7 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ - --ddp-backend=no_c10d --max-tokens 3584 \ + --ddp-backend=legacy_ddp --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 70000 \ diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 7fe301f732..539c3d5af9 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -154,7 +154,7 @@ fairseq-train $DATA_DIR \ --batch-size $MAX_SENTENCES \ 
--update-freq $UPDATE_FREQ --max-update $TOTAL_UPDATES \ --save-dir checkpoint/roberta \ - --ddp-backend no_c10d --encoder-layerdrop 0.2 \ + --ddp-backend legacy_ddp --encoder-layerdrop 0.2 \ --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 --untie-weights-roberta ``` @@ -189,7 +189,7 @@ fairseq-train /path/to/rte/data/ \ --max-epoch 10 \ --find-unused-parameters \ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ - --ddp-backend no_c10d \ + --ddp-backend legacy_ddp \ --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 ``` @@ -205,7 +205,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --arch transformer_lm_gbw \ --attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \ --clip-norm 0.1 --criterion adaptive_loss \ - --ddp-backend no_c10d \ + --ddp-backend legacy_ddp \ --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 \ --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 1.0 --t-mult 2.0 \ @@ -252,7 +252,7 @@ fairseq-train --task sentence_prediction /path/to/data/ \ --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ --clip-norm 0.0 --lr-scheduler polynomial_decay \ --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ - --no-progress-bar --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \ + --no-progress-bar --skip-invalid-size-inputs-valid-test --ddp-backend legacy_ddp \ --quantization-config-path /path/to/config/yaml ``` @@ -266,7 +266,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \ --bucket-cap-mb 25 --char-embedder-highway-layers 2 --character-embedding-dim 4 \ --clip-norm 0.1 --criterion adaptive_loss \ - --ddp-backend no_c10d \ + --ddp-backend legacy_ddp \ --decoder-attention-heads 8 --decoder-embed-dim 
1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ --fp16 --keep-last-epochs -1 \ --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 0.05 --stop-min-lr 1e-09 \ diff --git a/examples/roberta/README.race.md b/examples/roberta/README.race.md index 527a0bce14..13c917e8ec 100644 --- a/examples/roberta/README.race.md +++ b/examples/roberta/README.race.md @@ -19,7 +19,7 @@ UPDATE_FREQ=8 # Accumulate gradients to simulate training on 8 GPUs. DATA_DIR=/path/to/race-output-dir ROBERTA_PATH=/path/to/roberta/model.pt -CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=no_c10d \ +CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=legacy_ddp \ --restore-file $ROBERTA_PATH \ --reset-optimizer --reset-dataloader --reset-meters \ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ diff --git a/examples/roberta/commonsense_qa/README.md b/examples/roberta/commonsense_qa/README.md index 4f371f8b30..05c6f841a8 100644 --- a/examples/roberta/commonsense_qa/README.md +++ b/examples/roberta/commonsense_qa/README.md @@ -39,7 +39,7 @@ DATA_DIR=data/CommonsenseQA FAIRSEQ_PATH=/path/to/fairseq FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/commonsense_qa -CUDA_VISIBLE_DEVICES=0 fairseq-train --fp16 --ddp-backend=no_c10d \ +CUDA_VISIBLE_DEVICES=0 fairseq-train --fp16 --ddp-backend=legacy_ddp \ $DATA_DIR \ --user-dir $FAIRSEQ_USER_DIR \ --restore-file $ROBERTA_PATH \ diff --git a/examples/roberta/wsc/README.md b/examples/roberta/wsc/README.md index d40da6a5fd..21a045d999 100644 --- a/examples/roberta/wsc/README.md +++ b/examples/roberta/wsc/README.md @@ -51,7 +51,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \ --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ --valid-subset val \ - --fp16 --ddp-backend no_c10d \ + --fp16 --ddp-backend 
legacy_ddp \ --user-dir $FAIRSEQ_USER_DIR \ --task wsc --criterion wsc --wsc-cross-entropy \ --arch roberta_large --bpe gpt2 --max-positions 512 \ @@ -110,7 +110,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \ --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ --valid-subset val \ - --fp16 --ddp-backend no_c10d \ + --fp16 --ddp-backend legacy_ddp \ --user-dir $FAIRSEQ_USER_DIR \ --task winogrande --criterion winogrande \ --wsc-margin-alpha 5.0 --wsc-margin-beta 0.4 \ diff --git a/examples/translation/README.md b/examples/translation/README.md index 7b1fcc8de2..2941f5eb84 100644 --- a/examples/translation/README.md +++ b/examples/translation/README.md @@ -263,7 +263,7 @@ fairseq-preprocess --source-lang fr --target-lang en \ mkdir -p checkpoints/multilingual_transformer CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \ --max-epoch 50 \ - --ddp-backend=no_c10d \ + --ddp-backend=legacy_ddp \ --task multilingual_translation --lang-pairs de-en,fr-en \ --arch multilingual_transformer_iwslt_de_en \ --share-decoders --share-decoder-input-output-embed \ diff --git a/examples/translation_moe/README.md b/examples/translation_moe/README.md index 3cc3fb46dc..2e5c8af617 100644 --- a/examples/translation_moe/README.md +++ b/examples/translation_moe/README.md @@ -15,7 +15,7 @@ The model is trained with online responsibility assignment and shared parameteri The following command will train a `hMoElp` model with `3` experts: ```bash -fairseq-train --ddp-backend='no_c10d' \ +fairseq-train --ddp-backend='legacy_ddp' \ data-bin/wmt17_en_de \ --max-update 100000 \ --task translation_moe --user-dir examples/translation_moe/translation_moe_src \ diff --git a/examples/wav2vec/config/finetuning/base_100h.yaml b/examples/wav2vec/config/finetuning/base_100h.yaml index 7d1664a184..539dabb047 100644 --- a/examples/wav2vec/config/finetuning/base_100h.yaml +++ 
b/examples/wav2vec/config/finetuning/base_100h.yaml @@ -22,7 +22,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 2 criterion: diff --git a/examples/wav2vec/config/finetuning/base_10h.yaml b/examples/wav2vec/config/finetuning/base_10h.yaml index 31125947c0..16a3c4d96c 100644 --- a/examples/wav2vec/config/finetuning/base_10h.yaml +++ b/examples/wav2vec/config/finetuning/base_10h.yaml @@ -27,7 +27,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 2 criterion: diff --git a/examples/wav2vec/config/finetuning/base_10m.yaml b/examples/wav2vec/config/finetuning/base_10m.yaml index 2235504489..3ceb77a252 100644 --- a/examples/wav2vec/config/finetuning/base_10m.yaml +++ b/examples/wav2vec/config/finetuning/base_10m.yaml @@ -27,7 +27,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 2 criterion: diff --git a/examples/wav2vec/config/finetuning/base_1h.yaml b/examples/wav2vec/config/finetuning/base_1h.yaml index 2235504489..3ceb77a252 100644 --- a/examples/wav2vec/config/finetuning/base_1h.yaml +++ b/examples/wav2vec/config/finetuning/base_1h.yaml @@ -27,7 +27,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 2 criterion: diff --git a/examples/wav2vec/config/finetuning/base_960h.yaml b/examples/wav2vec/config/finetuning/base_960h.yaml index d742c94abf..e393805ad8 100644 --- a/examples/wav2vec/config/finetuning/base_960h.yaml +++ b/examples/wav2vec/config/finetuning/base_960h.yaml @@ -22,7 +22,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 8 criterion: diff --git a/examples/wav2vec/config/finetuning/vox_100h.yaml b/examples/wav2vec/config/finetuning/vox_100h.yaml 
index 8885c78470..2fdb0c568c 100644 --- a/examples/wav2vec/config/finetuning/vox_100h.yaml +++ b/examples/wav2vec/config/finetuning/vox_100h.yaml @@ -22,7 +22,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 4 criterion: diff --git a/examples/wav2vec/config/finetuning/vox_10h.yaml b/examples/wav2vec/config/finetuning/vox_10h.yaml index c0957c0058..f1a979e05d 100644 --- a/examples/wav2vec/config/finetuning/vox_10h.yaml +++ b/examples/wav2vec/config/finetuning/vox_10h.yaml @@ -27,7 +27,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 4 criterion: diff --git a/examples/wav2vec/config/finetuning/vox_10m.yaml b/examples/wav2vec/config/finetuning/vox_10m.yaml index 0d567552d7..d12439bb28 100644 --- a/examples/wav2vec/config/finetuning/vox_10m.yaml +++ b/examples/wav2vec/config/finetuning/vox_10m.yaml @@ -27,7 +27,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 4 criterion: diff --git a/examples/wav2vec/config/finetuning/vox_1h.yaml b/examples/wav2vec/config/finetuning/vox_1h.yaml index 10c45a52d8..7f3b04c034 100644 --- a/examples/wav2vec/config/finetuning/vox_1h.yaml +++ b/examples/wav2vec/config/finetuning/vox_1h.yaml @@ -27,7 +27,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 4 criterion: diff --git a/examples/wav2vec/config/finetuning/vox_960h.yaml b/examples/wav2vec/config/finetuning/vox_960h.yaml index 6212a2e738..0633915bb2 100644 --- a/examples/wav2vec/config/finetuning/vox_960h.yaml +++ b/examples/wav2vec/config/finetuning/vox_960h.yaml @@ -22,7 +22,7 @@ dataset: valid_subset: dev_other distributed_training: - ddp_backend: no_c10d + ddp_backend: legacy_ddp distributed_world_size: 24 criterion: diff --git 
a/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml index e2c2b7b0b3..767aee2852 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml @@ -23,7 +23,7 @@ dataset: distributed_training: distributed_world_size: 64 - ddp_backend: no_c10d + ddp_backend: legacy_ddp criterion: _name: wav2vec diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml index 0c911b7491..bee41157a9 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml @@ -24,7 +24,7 @@ dataset: distributed_training: distributed_world_size: 128 - ddp_backend: no_c10d + ddp_backend: legacy_ddp criterion: _name: wav2vec diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 15ad9a15bf..6209ceaedb 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -32,11 +32,11 @@ def __init__(self, task, sentence_avg): @classmethod def build_criterion(cls, cfg: AdaptiveLossConfig, task): - if cfg.ddp_backend == "c10d": + if cfg.ddp_backend in {"c10d", "pytorch_ddp"}: raise Exception( - "AdaptiveLoss is not compatible with the c10d " + "AdaptiveLoss is not compatible with the PyTorch " "version of DistributedDataParallel. Please use " - "`--ddp-backend=no_c10d` instead." + "`--ddp-backend=legacy_ddp` instead." 
) return cls(task, cfg.sentence_avg) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 2ed27284dc..f66e98fe83 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -12,7 +12,6 @@ from fairseq.dataclass.constants import ( DATASET_IMPL_CHOICES, DDP_BACKEND_CHOICES, - DISTRIBUTED_WRAPPER_CHOICES, GENERATION_CONSTRAINTS_CHOICES, GENERATION_DECODING_FORMAT_CHOICES, LOG_FORMAT_CHOICES, @@ -236,7 +235,7 @@ class DistributedTrainingConfig(FairseqDataclass): }, ) ddp_backend: DDP_BACKEND_CHOICES = field( - default="c10d", metadata={"help": "DistributedDataParallel backend"} + default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} @@ -252,7 +251,7 @@ class DistributedTrainingConfig(FairseqDataclass): default=False, metadata={ "help": "disable unused parameter detection (not applicable to " - "no_c10d ddp-backend" + "--ddp-backend=legacy_ddp)" }, ) fast_stat_sync: bool = field( @@ -273,9 +272,6 @@ class DistributedTrainingConfig(FairseqDataclass): "batchnorm population statistics" }, ) - distributed_wrapper: DISTRIBUTED_WRAPPER_CHOICES = field( - default="DDP", metadata={"help": "DistributedDataParallel backend"} - ) slowmo_momentum: Optional[float] = field( default=None, metadata={ diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 46881786a8..93bc6d03cb 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -35,9 +35,14 @@ def ChoiceEnum(choices: List[str]): LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) -DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) +DDP_BACKEND_CHOICES = ChoiceEnum([ + "c10d", # alias for pytorch_ddp + "legacy_ddp", + "no_c10d", # alias for legacy_ddp + "pytorch_ddp", + "slow_mo", +]) DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"]) -DISTRIBUTED_WRAPPER_CHOICES = 
ChoiceEnum(["DDP", "SlowMo"]) GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum( ["unigram", "ensemble", "vote", "dp", "bs"] diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 37822362d4..3b5fe6e7a8 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -591,7 +591,7 @@ def all_gather_list(data, group=None, max_size=16384): "sync if one of them runs out of memory, or if there are other conditions " "in your training script that can cause one worker to finish an epoch " "while other workers are still iterating over their portions of the data. " - "Try rerunning with --ddp-backend=no_c10d and see if that helps." + "Try rerunning with --ddp-backend=legacy_ddp and see if that helps." ) diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index ffa3c37b19..b8fbc37793 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -49,7 +49,7 @@ def DistributedFairseqModel(args, model, process_group): module=model, process_group=process_group, ) - elif args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d": + elif args.ddp_backend in {"c10d", "pytorch_ddp"}: ddp_class = nn.parallel.DistributedDataParallel init_kwargs = dict( module=model, @@ -62,14 +62,14 @@ def DistributedFairseqModel(args, model, process_group): # Maintain backward compatibility if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]: init_kwargs["find_unused_parameters"] = args.find_unused_parameters - elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d": + elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: ddp_class = LegacyDistributedDataParallel init_kwargs = dict( module=model, buffer_size=2 ** 28, process_group=process_group, ) - elif args.distributed_wrapper == "SlowMo": + elif args.ddp_backend == "slow_mo": if _GOSSIP_DISABLED: raise 
ImportError( "Cannot find gossip library. Please install from: " diff --git a/fairseq/trainer.py b/fairseq/trainer.py index eea194b950..d893518fea 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -646,7 +646,7 @@ def maybe_no_sync(): if not self.tpu: if ( not self.cfg.optimization.use_bmuf - and self.cfg.distributed_training.distributed_wrapper != "SlowMo" + and self.cfg.distributed_training.ddp_backend != "slow_mo" ): self._check_grad_norms(grad_norm) if not torch.isfinite(grad_norm).all(): @@ -686,7 +686,8 @@ def maybe_no_sync(): logger.error("OOM during optimization, irrecoverable") raise e - # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step + # Some distributed wrappers (e.g., SlowMo) need access to the optimizer + # after the step if hasattr(self.model, "perform_additional_optimizer_actions"): if hasattr(self.optimizer, "fp32_params"): self.model.perform_additional_optimizer_actions( @@ -700,7 +701,7 @@ def maybe_no_sync(): logging_output = None if ( not overflow - or self.cfg.distributed_training.distributed_wrapper == "SlowMo" + or self.cfg.distributed_training.ddp_backend == "slow_mo" ): self.set_num_updates(self.get_num_updates() + 1) @@ -1120,7 +1121,7 @@ def is_consistent(tensor): # use FloatingPointError to trigger NanDetector raise FloatingPointError( "Fatal error: gradients are inconsistent between workers. " - "Try --ddp-backend=no_c10d. " + "Try --ddp-backend=legacy_ddp. " "Or are you mixing up different generation of GPUs in training?" 
+ "\n" + "-" * 80 From 922528d58feea5ada68094df117c1cdbe67aec45 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 28 Jan 2021 14:18:48 -0800 Subject: [PATCH 419/707] Log amount of free GPU memory (#1545) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1545 Test Plan: Imported from OSS Reviewed By: girifb Differential Revision: D25836854 Pulled By: myleott fbshipit-source-id: 6bb5cb69a90022aa206618ee7a903a653fb1ed09 --- fairseq/trainer.py | 40 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d893518fea..274d556ea2 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -600,11 +600,7 @@ def maybe_no_sync(): ooms, total_train_time, ) = self._aggregate_logging_outputs( - logging_outputs, - sample_size, - ooms, - train_time, - ignore=is_dummy_batch, + logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch ) self._cumulative_training_time = ( total_train_time / self.data_parallel_world_size @@ -699,10 +695,7 @@ def maybe_no_sync(): ) logging_output = None - if ( - not overflow - or self.cfg.distributed_training.ddp_backend == "slow_mo" - ): + if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo": self.set_num_updates(self.get_num_updates() + 1) if self.tpu: @@ -720,24 +713,14 @@ def maybe_no_sync(): gb_free = mem_info["kb_free"] / 1024 / 1024 gb_total = mem_info["kb_total"] / 1024 / 1024 metrics.log_scalar( - "gb_free", - gb_free, - priority=1500, - round=1, - weight=0, + "gb_free", gb_free, priority=1500, round=1, weight=0 ) metrics.log_scalar( - "gb_total", - gb_total, - priority=1600, - round=1, - weight=0, + "gb_total", gb_total, priority=1600, round=1, weight=0 ) logging_output = self._reduce_and_log_stats( - logging_outputs, - sample_size, - grad_norm, + logging_outputs, sample_size, grad_norm ) # log whenever there's an XLA compilation, since these @@ -745,11 +728,18 @@ def 
maybe_no_sync(): # optimization self._check_xla_compilation() else: + if self.cuda and self.cuda_env is not None: + # log minimum free memory over the iteration + gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + torch.cuda.reset_peak_memory_stats() + gb_free = self.cuda_env.total_memory_in_GB - gb_used + metrics.log_scalar( + "gb_free", gb_free, priority=1500, round=1, weight=0 + ) + # log stats logging_output = self._reduce_and_log_stats( - logging_outputs, - sample_size, - grad_norm, + logging_outputs, sample_size, grad_norm ) # clear CUDA cache to reduce memory fragmentation From d68a3530dda7f8275e490864b28974ef30fe854b Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 28 Jan 2021 14:18:48 -0800 Subject: [PATCH 420/707] Refactor distributed code under fairseq.distributed (#1546) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1546 Test Plan: Imported from OSS Reviewed By: girifb Differential Revision: D25836853 Pulled By: myleott fbshipit-source-id: c5076615d49774633ecfaf0aa68b68e8b2331bd9 --- fairseq/__init__.py | 1 + fairseq/distributed/__init__.py | 17 +++ .../distributed_timeout_wrapper.py | 94 +++++++++++++ .../legacy_distributed_data_parallel.py | 11 +- fairseq/distributed/module_proxy_wrapper.py | 55 ++++++++ .../tpu_distributed_data_parallel.py | 43 ++++++ fairseq/models/distributed_fairseq_model.py | 131 +++++------------- fairseq/trainer.py | 27 +++- .../test_distributed_timeout_wrapper.py | 54 ++++++++ .../distributed/test_module_proxy_wrapper.py | 75 ++++++++++ 10 files changed, 391 insertions(+), 117 deletions(-) create mode 100644 fairseq/distributed/__init__.py create mode 100644 fairseq/distributed/distributed_timeout_wrapper.py rename fairseq/{ => distributed}/legacy_distributed_data_parallel.py (96%) create mode 100644 fairseq/distributed/module_proxy_wrapper.py create mode 100644 fairseq/distributed/tpu_distributed_data_parallel.py create mode 100644 
tests/distributed/test_distributed_timeout_wrapper.py create mode 100644 tests/distributed/test_module_proxy_wrapper.py diff --git a/fairseq/__init__.py b/fairseq/__init__.py index ccd45add79..8e51b61be0 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -28,6 +28,7 @@ hydra_init() import fairseq.criterions # noqa +import fairseq.distributed # noqa import fairseq.models # noqa import fairseq.modules # noqa import fairseq.optim # noqa diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py new file mode 100644 index 0000000000..7f4016e38c --- /dev/null +++ b/fairseq/distributed/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .distributed_timeout_wrapper import DistributedTimeoutWrapper +from .legacy_distributed_data_parallel import LegacyDistributedDataParallel +from .module_proxy_wrapper import ModuleProxyWrapper +from .tpu_distributed_data_parallel import TPUDistributedDataParallel + + +__all__ = [ + "DistributedTimeoutWrapper", + "LegacyDistributedDataParallel", + "ModuleProxyWrapper", + "TPUDistributedDataParallel", +] diff --git a/fairseq/distributed/distributed_timeout_wrapper.py b/fairseq/distributed/distributed_timeout_wrapper.py new file mode 100644 index 0000000000..c8ab477073 --- /dev/null +++ b/fairseq/distributed/distributed_timeout_wrapper.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import signal +import threading + +from torch import nn + + +logger = logging.getLogger(__name__) + + +class DistributedTimeoutWrapper(nn.Module): + """ + A wrapper that kills the process if no progress is made within a given + *timeout*. 
The timer is reset every time :func:`forward` is called. + + Usage:: + + module = DistributedTimeoutWrapper(module, timeout=30) + x = module(input) + time.sleep(20) # safe + x = module(input) + time.sleep(45) # job will be killed before this returns + + Args: + module (nn.Module): module to wrap + timeout (int): number of seconds before killing the process + (set to a value <= 0 to disable the timeout) + signal (Optional): signal to send once timeout is triggered + """ + def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGKILL): + super().__init__() + self.module = module + self.timeout = timeout + self.signal = signal + + if timeout > 0: + self._heartbeat = threading.Event() + self._heartbeat_thread = threading.Thread( + target=self._check_heartbeat, + args=(os.getpid(),), + daemon=True, + ) + self._heartbeat_thread.start() + self._terminated = False + else: + self._heartbeat = None + self._heartbeat_thread = None + + def __del__(self): + self.stop_timeout() + + def __getattr__(self, name): + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name) + + def stop_timeout(self): + if self._heartbeat_thread is not None: + self._terminated = True + self._heartbeat_thread.join() + + def state_dict(self, *args, **kwargs): + return self.module.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + return self.module.load_state_dict(*args, **kwargs) + + def forward(self, *args, **kwargs): + if self._heartbeat is not None: + self._heartbeat.set() + return self.module(*args, **kwargs) + + def _check_heartbeat(self, parent_pid): + self._heartbeat.wait() # wait for the first forward pass + while True: + self._heartbeat.clear() + success = self._heartbeat.wait(timeout=self.timeout) + if self._terminated: + break + elif not success: + logger.error(( + "Killing job for not making progress in {} seconds. 
" + "Set --heartbeat-timeout=-1 to disable this timeout." + ).format(int(self.timeout))) + os.kill(parent_pid, self.signal) + return diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py similarity index 96% rename from fairseq/legacy_distributed_data_parallel.py rename to fairseq/distributed/legacy_distributed_data_parallel.py index 7e176eaf3d..35de179b2f 100644 --- a/fairseq/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -14,15 +14,13 @@ training with `--update-freq`. """ -import copy from collections import OrderedDict from contextlib import contextmanager import torch from torch import nn -from torch.autograd import Variable -from . import distributed_utils +from fairseq import distributed_utils class LegacyDistributedDataParallel(nn.Module): @@ -64,13 +62,6 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): paramlists[device] += [param] self.per_device_params = list(paramlists.values()) - def __getstate__(self): - attrs = copy.copy(self.__dict__) - return attrs - - def __setstate__(self, state): - super().__setstate__(state) - @contextmanager def no_sync(self): """A context manager to disable gradient synchronization.""" diff --git a/fairseq/distributed/module_proxy_wrapper.py b/fairseq/distributed/module_proxy_wrapper.py new file mode 100644 index 0000000000..fc2c6f8c71 --- /dev/null +++ b/fairseq/distributed/module_proxy_wrapper.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn + + +class ModuleProxyWrapper(nn.Module): + """ + Wrap a DistributedDataParallel module and forward requests for missing + attributes to the module wrapped by DDP (the twice-wrapped module). + Also forward calls to :func:`state_dict` and :func:`load_state_dict`. 
+ + Usage:: + + module.xyz = "hello world" + wrapped_module = DistributedDataParallel(module, **ddp_args) + wrapped_module = ModuleProxyWrapper(wrapped_module) + assert wrapped_module.xyz == "hello world" + assert wrapped_module.state_dict().keys() == module.state_dict().keys() + + Args: + module (nn.Module): module to wrap + """ + + def __init__(self, module: nn.Module): + super().__init__() + assert hasattr(module, "module"), \ + "ModuleProxyWrapper expects input to wrap another module" + self.module = module + + def __getattr__(self, name): + """Forward missing attributes to twice-wrapped module.""" + try: + # defer to nn.Module's logic + return super().__getattr__(name) + except AttributeError: + try: + # forward to the once-wrapped module + return getattr(self.module, name) + except AttributeError: + # forward to the twice-wrapped module + return getattr(self.module.module, name) + + def state_dict(self, *args, **kwargs): + """Forward to the twice-wrapped module.""" + return self.module.module.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + """Forward to the twice-wrapped module.""" + return self.module.module.load_state_dict(*args, **kwargs) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) diff --git a/fairseq/distributed/tpu_distributed_data_parallel.py b/fairseq/distributed/tpu_distributed_data_parallel.py new file mode 100644 index 0000000000..2adcf1cb58 --- /dev/null +++ b/fairseq/distributed/tpu_distributed_data_parallel.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from torch import nn + +from fairseq import distributed_utils + + +class TPUDistributedDataParallel(nn.Module): + + def __init__(self, module, process_group): + super().__init__() + self.module = module + self.process_group = process_group + self.world_size = distributed_utils.get_world_size(self.process_group) + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce_grads(self): + gradients = [] + for p in self.parameters(): + if not p.requires_grad: + continue + if p.grad is None: + p.grad = torch.zeros_like(p) + if p.grad.requires_grad: + raise RuntimeError( + "TPUDistributedDataParallel only works with gradients that don't " + "require grad" + ) + gradients.append(p.grad) + + import torch_xla.core.xla_model as xm + xm.all_reduce( + 'sum', + gradients, + scale=1. / self.world_size, + groups=self.process_group[1], + ) diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index b8fbc37793..bee1033110 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import inspect import logging import os import signal @@ -11,9 +10,15 @@ import torch import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel from fairseq import distributed_utils -from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel +from fairseq.distributed import ( + DistributedTimeoutWrapper, + LegacyDistributedDataParallel, + ModuleProxyWrapper, + TPUDistributedDataParallel, +) logger = logging.getLogger(__name__) @@ -26,7 +31,7 @@ _GOSSIP_DISABLED = True -def DistributedFairseqModel(args, model, process_group): +def DistributedFairseqModel(args, model, process_group, device): """ Wrap a *model* to support distributed data parallel training. 
@@ -40,42 +45,42 @@ def DistributedFairseqModel(args, model, process_group): model (BaseFairseqModel): model to wrap process_group: the c10d process group to be used for distributed data parallel all-reduction. + device: device to move model to """ - # determine which DDP class to extend assert isinstance(model, nn.Module) if args.tpu: - ddp_class = TPUDistributedDataParallel - init_kwargs = dict( - module=model, + wrapped_model = TPUDistributedDataParallel( + module=model.to(device), process_group=process_group, ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) elif args.ddp_backend in {"c10d", "pytorch_ddp"}: - ddp_class = nn.parallel.DistributedDataParallel - init_kwargs = dict( - module=model, + wrapped_model = DistributedDataParallel( + module=model.to(device), device_ids=[args.device_id], output_device=args.device_id, broadcast_buffers=args.broadcast_buffers, bucket_cap_mb=args.bucket_cap_mb, process_group=process_group, + find_unused_parameters=args.find_unused_parameters, ) - # Maintain backward compatibility - if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]: - init_kwargs["find_unused_parameters"] = args.find_unused_parameters + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: - ddp_class = LegacyDistributedDataParallel - init_kwargs = dict( - module=model, + wrapped_model = LegacyDistributedDataParallel( + module=model.to(device), buffer_size=2 ** 28, process_group=process_group, ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) elif args.ddp_backend == "slow_mo": if _GOSSIP_DISABLED: raise ImportError( "Cannot find gossip library. 
Please install from: " "github.com/facebookresearch/stochastic_gradient_push" ) - ddp_class = gossip.GossipDataParallel # The values of slowmo_momentum below were obtained by tuning on the # En-De 16 dataset by training the transformer_wmt_en_de_large model @@ -89,8 +94,8 @@ def DistributedFairseqModel(args, model, process_group): else: args.slowmo_momentum = 0.6 - init_kwargs = dict( - module=model, + wrapped_model = gossip.GossipDataParallel( + module=model.to(device), device_ids=[args.device_id], output_device=args.device_id, broadcast_buffers=args.broadcast_buffers, @@ -99,88 +104,14 @@ def DistributedFairseqModel(args, model, process_group): localsgd=(args.slowmo_algorithm == "LocalSGD"), localsgd_frequency=args.localsgd_frequency, ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) else: raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) - heartbeat_timeout = getattr(args, "heartbeat_timeout", -1) + # kill hung distributed jobs after a timeout + wrapped_model = DistributedTimeoutWrapper( + wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) + ) - class _DistributedFairseqModel(ddp_class): - """ - Extend DistributedDataParallel to check for missing attributes in the - wrapped module and to add a timeout to kill the job if no progress is - made (--heartbeat-timeout). 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._heartbeat_timeout = heartbeat_timeout - if self._heartbeat_timeout > 0: - self._heartbeat = threading.Event() - self._heartbeat_thread = threading.Thread( - target=self._check_heartbeat, - args=(os.getpid(),), - daemon=True, - ) - self._heartbeat_thread.start() - else: - self._heartbeat = None - - def _check_heartbeat(self, parent_pid): - self._heartbeat.wait() # wait for the first forward pass - while True: - self._heartbeat.clear() - success = self._heartbeat.wait(timeout=self._heartbeat_timeout) - if not success: - logger.error(( - "Killing job for not making progress in {} seconds. " - "Set --heartbeat-timeout=-1 to disable this timeout." - ).format(int(self._heartbeat_timeout))) - os.kill(parent_pid, signal.SIGKILL) - return - - def __getattr__(self, name): - wrapped_module = super().__getattr__("module") - if hasattr(wrapped_module, name): - return getattr(wrapped_module, name) - return super().__getattr__(name) - - def forward(self, *args, **kwargs): - if self._heartbeat is not None: - self._heartbeat.set() - return super().forward(*args, **kwargs) - - return _DistributedFairseqModel(**init_kwargs) - - -class TPUDistributedDataParallel(nn.Module): - - def __init__(self, module, process_group): - super().__init__() - self.module = module - self.process_group = process_group - self.world_size = distributed_utils.get_world_size(self.process_group) - - def forward(self, *inputs, **kwargs): - return self.module(*inputs, **kwargs) - - def all_reduce_grads(self): - gradients = [] - for p in self.parameters(): - if not p.requires_grad: - continue - if p.grad is None: - p.grad = torch.zeros_like(p) - if p.grad.requires_grad: - raise RuntimeError( - "TPUDistributedDataParallel only works with gradients that don't " - "require grad" - ) - gradients.append(p.grad) - - import torch_xla.core.xla_model as xm - xm.all_reduce( - 'sum', - gradients, - scale=1. 
/ self.world_size, - groups=self.process_group[1], - ) + return wrapped_model diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 274d556ea2..b441eaa4d2 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -69,7 +69,12 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): elif cfg.common.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) - if not cfg.distributed_training.pipeline_model_parallel: + if ( + not cfg.distributed_training.pipeline_model_parallel + # the DistributedFairseqModel wrapper will handle moving to device, + # so only handle cases which don't use the wrapper + and not self.use_distributed_wrapper + ): self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel @@ -158,18 +163,25 @@ def is_data_parallel_master(self): # parallel rank 0 return self.data_parallel_rank == 0 + @property + def use_distributed_wrapper(self) -> bool: + return ( + self.data_parallel_world_size > 1 + and not self.cfg.optimization.use_bmuf + ) + @property def criterion(self): if self._wrapped_criterion is None: if ( utils.has_parameters(self._criterion) - and self.data_parallel_world_size > 1 - and not self.cfg.optimization.use_bmuf + and self.use_distributed_wrapper ): self._wrapped_criterion = models.DistributedFairseqModel( self.cfg.distributed_training, self._criterion, process_group=self.data_parallel_process_group, + device=self.device, ) else: self._wrapped_criterion = self._criterion @@ -178,11 +190,12 @@ def criterion(self): @property def model(self): if self._wrapped_model is None: - if self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf: + if self.use_distributed_wrapper: self._wrapped_model = models.DistributedFairseqModel( self.cfg.distributed_training, self._model, 
process_group=self.data_parallel_process_group, + device=self.device, ) else: self._wrapped_model = self._model @@ -268,8 +281,8 @@ def save_checkpoint(self, filename, extra_state): checkpoint_utils.save_state( filename, self.cfg, - self.get_model().state_dict(), - self.get_criterion(), + self.model.state_dict(), + self.criterion, self.optimizer, self.lr_scheduler, self.get_num_updates(), @@ -336,7 +349,7 @@ def load_checkpoint( # load model parameters try: - self.get_model().load_state_dict( + self.model.load_state_dict( state["model"], strict=True, model_cfg=self.cfg.model ) if utils.has_parameters(self.get_criterion()): diff --git a/tests/distributed/test_distributed_timeout_wrapper.py b/tests/distributed/test_distributed_timeout_wrapper.py new file mode 100644 index 0000000000..27908b9d3f --- /dev/null +++ b/tests/distributed/test_distributed_timeout_wrapper.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import signal +import time +import unittest + +import torch +from torch import nn + +from fairseq.distributed import DistributedTimeoutWrapper + + +class ModuleWithDelay(nn.Module): + + def __init__(self, delay): + super().__init__() + self.delay = delay + + def forward(self, x): + time.sleep(self.delay) + return x + + +class TestDistributedTimeoutWrapper(unittest.TestCase): + + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_no_timeout(self): + module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT) + module(torch.rand(5)) + module.stop_timeout() + + def test_timeout_safe(self): + module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT) + module(torch.rand(5)) + module.stop_timeout() + + def test_timeout_killed(self): + with self.assertRaises(KeyboardInterrupt): + module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT) + module(torch.rand(5)) + module.stop_timeout() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/test_module_proxy_wrapper.py b/tests/distributed/test_module_proxy_wrapper.py new file mode 100644 index 0000000000..2803a044cd --- /dev/null +++ b/tests/distributed/test_module_proxy_wrapper.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from torch import nn + +from fairseq.distributed import ModuleProxyWrapper + +from .utils import objects_are_equal + + +class MockDDPWrapper(nn.Module): + """A simple wrapper with an interface similar to DistributedDataParallel.""" + + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, x): + return self.module(x) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 10) + self.xyz = "hello" + + def forward(self, x): + return self.linear(x) + + def get_xyz(self): + return self.xyz + + +class TestModuleProxyWrapper(unittest.TestCase): + + def _get_module(self): + module = Model() + wrapped_module = MockDDPWrapper(module) + wrapped_module = ModuleProxyWrapper(wrapped_module) + return wrapped_module, module + + def test_getattr_forwarding(self): + wrapped_module, module = self._get_module() + assert module.xyz == "hello" + assert module.get_xyz() == "hello" + assert wrapped_module.xyz == "hello" + + wrapped_module.xyz = "world" + assert wrapped_module.xyz == "world" + assert module.get_xyz() == "hello" + + def test_state_dict(self): + wrapped_module, module = self._get_module() + assert objects_are_equal(wrapped_module.state_dict(), module.state_dict()) + + def test_load_state_dict(self): + wrapped_module, module = self._get_module() + wrapped_module.load_state_dict(module.state_dict()) + input = torch.rand(4, 5) + torch.testing.assert_allclose(wrapped_module(input), module(input)) + + def test_forward(self): + wrapped_module, module = self._get_module() + input = torch.rand(4, 5) + torch.testing.assert_allclose(wrapped_module(input), module(input)) + + +if __name__ == "__main__": + unittest.main() From 27b96eb698610d6ca8835b7bdf47528230ebfd00 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 28 Jan 2021 14:18:48 -0800 Subject: [PATCH 421/707] Move fairseq.distributed_utils -> fairseq.distributed.utils (#1547) Summary: Pull 
Request resolved: https://github.com/fairinternal/fairseq-py/pull/1547 Test Plan: Imported from OSS Reviewed By: girifb Differential Revision: D25836855 Pulled By: myleott fbshipit-source-id: addd8a7fe8dac43252b100d7331e04e95f555781 --- .../truncated_bptt/truncated_bptt_lm_task.py | 3 +- fairseq/__init__.py | 4 +- .../multilingual/sampled_multi_dataset.py | 2 +- .../legacy_distributed_data_parallel.py | 6 +- .../tpu_distributed_data_parallel.py | 4 +- .../utils.py} | 321 ++++++++++-------- fairseq/model_parallel/megatron_trainer.py | 4 +- fairseq/models/distributed_fairseq_model.py | 1 - fairseq/trainer.py | 3 +- fairseq/utils.py | 11 +- tests/{ => distributed}/test_bmuf.py | 4 +- ...est_distributed_utils.py => test_utils.py} | 2 +- 12 files changed, 199 insertions(+), 166 deletions(-) rename fairseq/{distributed_utils.py => distributed/utils.py} (73%) rename tests/{ => distributed}/test_bmuf.py (98%) rename tests/distributed/{test_distributed_utils.py => test_utils.py} (97%) diff --git a/examples/truncated_bptt/truncated_bptt_lm_task.py b/examples/truncated_bptt/truncated_bptt_lm_task.py index 34c4f03955..02be0e7fb4 100644 --- a/examples/truncated_bptt/truncated_bptt_lm_task.py +++ b/examples/truncated_bptt/truncated_bptt_lm_task.py @@ -9,7 +9,7 @@ from typing import List, Optional, Tuple import torch -from fairseq import distributed_utils as dist_utils, utils +from fairseq import utils from fairseq.data import ( Dictionary, TokenBlockDataset, @@ -17,6 +17,7 @@ iterators, ) from fairseq.dataclass import FairseqDataclass +from fairseq.distributed import utils as dist_utils from fairseq.tasks import FairseqTask, register_task from omegaconf import II diff --git a/fairseq/__init__.py b/fairseq/__init__.py index 8e51b61be0..dc9fd1886d 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -16,9 +16,11 @@ __all__ = ["pdb"] -# backwards compatibility to support `from fairseq.meters import AverageMeter` +# backwards compatibility to support `from fairseq.X import 
Y` +from fairseq.distributed import utils as distributed_utils from fairseq.logging import meters, metrics, progress_bar # noqa +sys.modules["fairseq.distributed_utils"] = distributed_utils sys.modules["fairseq.meters"] = meters sys.modules["fairseq.metrics"] = metrics sys.modules["fairseq.progress_bar"] = progress_bar diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index f74ec18141..b0a617424e 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -14,8 +14,8 @@ import numpy as np import torch -from fairseq import distributed_utils from fairseq.data import FairseqDataset, data_utils +from fairseq.distributed import utils as distributed_utils def get_time_gap(s, e): diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 35de179b2f..b586e76b7f 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -20,7 +20,7 @@ import torch from torch import nn -from fairseq import distributed_utils +from fairseq.distributed import utils class LegacyDistributedDataParallel(nn.Module): @@ -43,7 +43,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.module = module self.process_group = process_group - self.world_size = distributed_utils.get_world_size(self.process_group) + self.world_size = utils.get_world_size(self.process_group) # Never use a bigger buffer than the number of model params self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) @@ -107,7 +107,7 @@ def all_reduce_params(params): if nonzero_buffer: buffer.div_(self.world_size) - distributed_utils.all_reduce(buffer, self.process_group) + utils.all_reduce(buffer, self.process_group) # copy all-reduced grads back into their original place offset = 0 diff --git 
a/fairseq/distributed/tpu_distributed_data_parallel.py b/fairseq/distributed/tpu_distributed_data_parallel.py index 2adcf1cb58..e971cf07c5 100644 --- a/fairseq/distributed/tpu_distributed_data_parallel.py +++ b/fairseq/distributed/tpu_distributed_data_parallel.py @@ -6,7 +6,7 @@ import torch from torch import nn -from fairseq import distributed_utils +from fairseq.distributed import utils class TPUDistributedDataParallel(nn.Module): @@ -15,7 +15,7 @@ def __init__(self, module, process_group): super().__init__() self.module = module self.process_group = process_group - self.world_size = distributed_utils.get_world_size(self.process_group) + self.world_size = utils.get_world_size(self.process_group) def forward(self, *inputs, **kwargs): return self.module(*inputs, **kwargs) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed/utils.py similarity index 73% rename from fairseq/distributed_utils.py rename to fairseq/distributed/utils.py index 3b5fe6e7a8..e3c17859f4 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed/utils.py @@ -19,7 +19,6 @@ import torch import torch.distributed as dist -from fairseq import utils from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig from omegaconf import open_dict @@ -49,169 +48,193 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): return if cfg.pipeline_model_parallel: - balance_exists = ( - cfg.pipeline_balance is not None - or cfg.pipeline_encoder_balance is not None - or cfg.pipeline_decoder_balance is not None - ) - devices_exist = ( - cfg.pipeline_devices is not None - or cfg.pipeline_encoder_devices is not None - or cfg.pipeline_decoder_devices is not None - ) - if not balance_exists: - raise ValueError( - "--pipeline-balance is currently required for pipeline model parallelism" - ) - if not devices_exist: - raise ValueError( - "--pipeline-devices is currently required for pipeline model parallelism" - ) + num_pipeline_devices, 
num_pipelines_per_node = _pipeline_parallel_pre_init(cfg) - cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) - if cfg.pipeline_devices is not None: - cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) - num_pipeline_devices = len(set(cfg.pipeline_devices)) - else: - cfg.pipeline_encoder_devices = utils.eval_str_list( - cfg.pipeline_encoder_devices, type=int - ) - cfg.pipeline_decoder_devices = utils.eval_str_list( - cfg.pipeline_decoder_devices, type=int - ) - num_pipeline_devices = len( - set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) - ) - gpus_per_node = torch.cuda.device_count() - assert ( - gpus_per_node >= num_pipeline_devices - and gpus_per_node % num_pipeline_devices == 0 - ), ( - "the number of unique device IDs in --pipeline-devices must evenly divide " - "the number of GPUs per node (multi-node pipelining is not yet supported)" - ) - num_pipelines_per_node = gpus_per_node // num_pipeline_devices - - # support torch.distributed.launch if all( key in os.environ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] ): - cfg.distributed_init_method = "env://" - cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) - cfg.distributed_rank = int(os.environ["RANK"]) - # processes are created by torch.distributed.launch - cfg.distributed_no_spawn = True - - # we can determine the init method automatically for Slurm + # support torch.distributed.launch + _infer_torch_distributed_launch_init(cfg) elif cfg.distributed_port > 0: - node_list = os.environ.get("SLURM_STEP_NODELIST") - if node_list is None: - node_list = os.environ.get("SLURM_JOB_NODELIST") - if node_list is not None: - try: - hostnames = subprocess.check_output( - ["scontrol", "show", "hostnames", node_list] - ) - cfg.distributed_init_method = "tcp://{host}:{port}".format( - host=hostnames.split()[0].decode("utf-8"), - port=cfg.distributed_port, - ) - nnodes = int(os.environ.get("SLURM_NNODES")) - ntasks_per_node = 
os.environ.get("SLURM_NTASKS_PER_NODE") - if ntasks_per_node is not None: - ntasks_per_node = int(ntasks_per_node) - else: - ntasks = int(os.environ.get("SLURM_NTASKS")) - nnodes = int(os.environ.get("SLURM_NNODES")) - assert ntasks % nnodes == 0 - ntasks_per_node = int(ntasks / nnodes) - if ntasks_per_node == 1: - gpus_per_node = torch.cuda.device_count() - node_id = int(os.environ.get("SLURM_NODEID")) - cfg.distributed_rank = node_id * gpus_per_node - cfg.distributed_world_size = nnodes * gpus_per_node - elif cfg.pipeline_model_parallel: - assert ntasks_per_node == num_pipelines_per_node, ( - "SLURM --ntasks-per-node must match number of pipelines per " - "node (={})".format(num_pipelines_per_node) - ) - cfg.distributed_no_spawn = True - # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on - # the first node, [1, 2] on the second node, etc. This - # matches torch.distributed.launch. - node_id = int(os.environ.get("SLURM_NODEID")) - local_id = int(os.environ.get("SLURM_LOCALID")) - cfg.distributed_rank = node_id * num_pipelines_per_node + local_id - # In the above example, device_id will always be in [0, 1], - # which also matches torch.distributed.launch. - cfg.device_id = local_id - # We also want to set distributed_world_size to be the total - # number of pipelines across all nodes. 
- cfg.distributed_world_size = nnodes * num_pipelines_per_node - else: - assert ntasks_per_node == cfg.distributed_world_size // nnodes - cfg.distributed_no_spawn = True - cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) - cfg.device_id = int(os.environ.get("SLURM_LOCALID")) - except subprocess.CalledProcessError as e: # scontrol failed - raise e - except FileNotFoundError: # Slurm is not installed - pass - + # we can determine the init method automatically for Slurm + _infer_slurm_init(cfg, num_pipelines_per_node) elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs - assert ( - cfg.distributed_world_size <= torch.cuda.device_count() - ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" - port = random.randint(10000, 20000) - cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) + _infer_single_node_init(cfg) if cfg.pipeline_model_parallel: - if not cfg.distributed_no_spawn: - # When distributed_no_spawn is False, we expect distributed_rank and - # distributed_world_size to be based on the total number of GPUs, so - # we need to correct them to be based on the number of pipelines. - assert cfg.distributed_world_size % num_pipeline_devices == 0 - cfg.distributed_world_size = ( - cfg.distributed_world_size // num_pipeline_devices + _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node) + elif not cfg.distributed_no_spawn: + with open_dict(cfg): + cfg.distributed_num_procs = min( + torch.cuda.device_count(), cfg.distributed_world_size ) - # In the case of 4-way MP on nodes with 8 GPUs, we want - # distributed_rank to be the starting GPU index for each pipeline - # i.e., 0, 2, ... 
- assert cfg.distributed_rank % gpus_per_node == 0 - assert cfg.distributed_rank % num_pipeline_devices == 0 - - with open_dict(cfg): - cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices - # launch one process per pipeline - cfg.distributed_num_procs = num_pipelines_per_node - - # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 - # and 4, indicating the starting device IDs for each pipeline - cfg.device_id *= num_pipeline_devices - - if cfg.device_id > 0: - # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 - # GPU node), we need to adjust pipeline_devices accordingly - logger.debug( - "setting CUDA device={} on rank {}".format( - cfg.device_id, cfg.distributed_rank - ) + + +def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig): + cfg.distributed_init_method = "env://" + cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) + cfg.distributed_rank = int(os.environ["RANK"]) + # processes are created by torch.distributed.launch + cfg.distributed_no_spawn = True + + +def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node): + node_list = os.environ.get("SLURM_STEP_NODELIST") + if node_list is None: + node_list = os.environ.get("SLURM_JOB_NODELIST") + if node_list is not None: + try: + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", node_list] ) - torch.cuda.set_device(cfg.device_id) - with open_dict(cfg): - cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices] - logger.info( - "setting pipeline_devices={} on rank {}".format( - cfg.pipeline_devices, cfg.distributed_rank + cfg.distributed_init_method = "tcp://{host}:{port}".format( + host=hostnames.split()[0].decode("utf-8"), + port=cfg.distributed_port, + ) + nnodes = int(os.environ.get("SLURM_NNODES")) + ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") + if ntasks_per_node is not None: + ntasks_per_node = int(ntasks_per_node) + else: + ntasks = 
int(os.environ.get("SLURM_NTASKS")) + nnodes = int(os.environ.get("SLURM_NNODES")) + assert ntasks % nnodes == 0 + ntasks_per_node = int(ntasks / nnodes) + if ntasks_per_node == 1: + gpus_per_node = torch.cuda.device_count() + node_id = int(os.environ.get("SLURM_NODEID")) + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: + assert ntasks_per_node == num_pipelines_per_node, ( + "SLURM --ntasks-per-node must match number of pipelines per " + "node (={})".format(num_pipelines_per_node) ) + cfg.distributed_no_spawn = True + # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on + # the first node, [2, 3] on the second node, etc. This + # matches torch.distributed.launch. + node_id = int(os.environ.get("SLURM_NODEID")) + local_id = int(os.environ.get("SLURM_LOCALID")) + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id + # In the above example, device_id will always be in [0, 1], + # which also matches torch.distributed.launch. + cfg.device_id = local_id + # We also want to set distributed_world_size to be the total + # number of pipelines across all nodes.
+ cfg.distributed_world_size = nnodes * num_pipelines_per_node + else: + assert ntasks_per_node == cfg.distributed_world_size // nnodes + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.device_id = int(os.environ.get("SLURM_LOCALID")) + except subprocess.CalledProcessError as e: # scontrol failed + raise e + except FileNotFoundError: # Slurm is not installed + pass + + +def _infer_single_node_init(cfg: DistributedTrainingConfig): + assert ( + cfg.distributed_world_size <= torch.cuda.device_count() + ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" + port = random.randint(10000, 20000) + cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) + + +def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig): + from fairseq import utils + + balance_exists = ( + cfg.pipeline_balance is not None + or cfg.pipeline_encoder_balance is not None + or cfg.pipeline_decoder_balance is not None + ) + devices_exist = ( + cfg.pipeline_devices is not None + or cfg.pipeline_encoder_devices is not None + or cfg.pipeline_decoder_devices is not None + ) + if not balance_exists: + raise ValueError( + "--pipeline-balance is currently required for pipeline model parallelism" + ) + if not devices_exist: + raise ValueError( + "--pipeline-devices is currently required for pipeline model parallelism" + ) + + cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) + if cfg.pipeline_devices is not None: + cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) + num_pipeline_devices = len(set(cfg.pipeline_devices)) + else: + cfg.pipeline_encoder_devices = utils.eval_str_list( + cfg.pipeline_encoder_devices, type=int + ) + cfg.pipeline_decoder_devices = utils.eval_str_list( + cfg.pipeline_decoder_devices, type=int + ) + num_pipeline_devices = len( + set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) + ) + gpus_per_node 
= torch.cuda.device_count() + assert ( + gpus_per_node >= num_pipeline_devices + and gpus_per_node % num_pipeline_devices == 0 + ), ( + "the number of unique device IDs in --pipeline-devices must evenly divide " + "the number of GPUs per node (multi-node pipelining is not yet supported)" + ) + num_pipelines_per_node = gpus_per_node // num_pipeline_devices + return num_pipeline_devices, num_pipelines_per_node + + +def _pipeline_parallel_post_init( + cfg: DistributedTrainingConfig, num_pipeline_devices, num_pipelines_per_node +): + if not cfg.distributed_no_spawn: + # When distributed_no_spawn is False, we expect distributed_rank and + # distributed_world_size to be based on the total number of GPUs, so + # we need to correct them to be based on the number of pipelines. + assert cfg.distributed_world_size % num_pipeline_devices == 0 + cfg.distributed_world_size = ( + cfg.distributed_world_size // num_pipeline_devices + ) + # In the case of 4-way MP on nodes with 8 GPUs, we want + # distributed_rank to be the starting GPU index for each pipeline + # i.e., 0, 2, ... 
+ gpus_per_node = torch.cuda.device_count() + assert cfg.distributed_rank % gpus_per_node == 0 + assert cfg.distributed_rank % num_pipeline_devices == 0 + + with open_dict(cfg): + cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices + # launch one process per pipeline + cfg.distributed_num_procs = num_pipelines_per_node + + # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 + # and 4, indicating the starting device IDs for each pipeline + cfg.device_id *= num_pipeline_devices + + if cfg.device_id > 0: + # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 + # GPU node), we need to adjust pipeline_devices accordingly + logger.debug( + "setting CUDA device={} on rank {}".format( + cfg.device_id, cfg.distributed_rank ) - elif not cfg.distributed_no_spawn: + ) + torch.cuda.set_device(cfg.device_id) with open_dict(cfg): - cfg.distributed_num_procs = min( - torch.cuda.device_count(), cfg.distributed_world_size + cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices] + logger.info( + "setting pipeline_devices={} on rank {}".format( + cfg.pipeline_devices, cfg.distributed_rank ) + ) def distributed_init(cfg: FairseqConfig): @@ -537,6 +560,8 @@ def all_gather_list(data, group=None, max_size=16384): max_size (int, optional): maximum size of the data to be gathered across workers """ + from fairseq import utils + if group is None: group = get_global_group() rank = get_rank(group=group) diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 1a6e844aee..8ab4657f73 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -7,9 +7,9 @@ Train a network across multiple GPUs. 
""" -from fairseq import distributed_utils -from fairseq.trainer import Trainer from fairseq.dataclass.configs import FairseqConfig +from fairseq.distributed import utils as distributed_utils +from fairseq.trainer import Trainer try: from fairseq.model_parallel.megatron.mpu import ( diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index bee1033110..ca157f06e9 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -12,7 +12,6 @@ import torch.nn as nn from torch.nn.parallel import DistributedDataParallel -from fairseq import distributed_utils from fairseq.distributed import ( DistributedTimeoutWrapper, LegacyDistributedDataParallel, diff --git a/fairseq/trainer.py b/fairseq/trainer.py index b441eaa4d2..49129a7fb0 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -16,9 +16,10 @@ from typing import Any, Dict, List import torch -from fairseq import checkpoint_utils, distributed_utils, models, optim, utils +from fairseq import checkpoint_utils, models, optim, utils from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.distributed import utils as distributed_utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector diff --git a/fairseq/utils.py b/fairseq/utils.py index a20c83384c..d4bf73648b 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -17,10 +17,6 @@ import torch import torch.nn.functional as F -from fairseq.data import iterators -from fairseq.file_io import PathManager -from fairseq.logging.meters import safe_round -from fairseq.modules import gelu, gelu_accurate from fairseq.modules.multihead_attention import MultiheadAttention from torch import Tensor @@ -51,6 +47,8 @@ def __init__(self, option_strings, dest, nargs=None, **kwargs): super(FileContentsAction, self).__init__(option_strings, 
dest, **kwargs) def __call__(self, parser, namespace, values, option_string=None): + from fairseq.file_io import PathManager + if PathManager.isfile(values): with PathManager.open(values) as f: argument = f.read().strip() @@ -482,6 +480,8 @@ def log_softmax(x, dim: int, onnx_trace: bool = False): def get_perplexity(loss, round=2, base=2): + from fairseq.logging.meters import safe_round + if loss is None: return 0.0 try: @@ -497,6 +497,8 @@ def deprecation_warning(message, stacklevel=3): def get_activation_fn(activation: str) -> Callable: """ Returns the activation function corresponding to `activation` """ + from fairseq.modules import gelu, gelu_accurate + if activation == "relu": return F.relu elif activation == "gelu": @@ -665,6 +667,7 @@ def get_tpu_device(): def tpu_data_loader(itr): import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl + from fairseq.data import iterators xm.rendezvous("tpu_data_loader") # wait for all workers xm.mark_step() diff --git a/tests/test_bmuf.py b/tests/distributed/test_bmuf.py similarity index 98% rename from tests/test_bmuf.py rename to tests/distributed/test_bmuf.py index 785da37bc2..8b7cadb094 100644 --- a/tests/test_bmuf.py +++ b/tests/distributed/test_bmuf.py @@ -11,9 +11,11 @@ import torch import torch.nn as nn -from fairseq import distributed_utils, optim +from fairseq import optim +from fairseq.distributed import utils as distributed_utils from omegaconf import OmegaConf + class Model(nn.Module): def __init__(self, input_size, output_size): super(Model, self).__init__() diff --git a/tests/distributed/test_distributed_utils.py b/tests/distributed/test_utils.py similarity index 97% rename from tests/distributed/test_distributed_utils.py rename to tests/distributed/test_utils.py index 161ee85eaa..0a5d665068 100644 --- a/tests/distributed/test_distributed_utils.py +++ b/tests/distributed/test_utils.py @@ -9,7 +9,7 @@ import torch -from fairseq import distributed_utils as dist_utils +from 
fairseq.distributed import utils as dist_utils from .utils import objects_are_equal, spawn_and_init From 148327d8c1e3a5f9d17a11bbb1973a7cf3f955d3 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 28 Jan 2021 14:18:48 -0800 Subject: [PATCH 422/707] Add tests for fairseq.distributed.utils.all_gather_list (#1548) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1548 Test Plan: Imported from OSS Reviewed By: girifb Differential Revision: D25836857 Pulled By: myleott fbshipit-source-id: 3fb844fa21640cbda989dafa6592ef3e5c59bfa7 --- fairseq/distributed/utils.py | 3 +- tests/distributed/test_utils.py | 71 +++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index e3c17859f4..c39dc6d912 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -552,7 +552,8 @@ def all_gather_list(data, group=None, max_size=16384): """Gathers arbitrary data from all nodes into a list. Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python - data. Note that *data* must be picklable. + data. Note that *data* must be picklable and any CUDA tensors will be moved + to CPU and returned on CPU as well. 
Args: data (Any): data from the local worker to be gathered on other workers diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 0a5d665068..30f995b67a 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -14,7 +14,7 @@ from .utils import objects_are_equal, spawn_and_init -class TestDistributedUtils(unittest.TestCase): +class DistributedTest(unittest.TestCase): def setUp(self): if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA not available, skipping test") @@ -23,28 +23,29 @@ def setUp(self): if torch.cuda.device_count() < 2: raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") - def test_broadcast_object_python(self): + +class TestBroadcastObject(DistributedTest): + def test_str(self): spawn_and_init( functools.partial( - TestDistributedUtils._test_broadcast_object, - "hello world", + TestBroadcastObject._test_broadcast_object, "hello world" ), world_size=2, ) - def test_broadcast_object_tensor(self): + def test_tensor(self): spawn_and_init( functools.partial( - TestDistributedUtils._test_broadcast_object, + TestBroadcastObject._test_broadcast_object, torch.rand(5), ), world_size=2, ) - def test_broadcast_object_complex(self): + def test_complex(self): spawn_and_init( functools.partial( - TestDistributedUtils._test_broadcast_object, + TestBroadcastObject._test_broadcast_object, { "a": "1", "b": [2, torch.rand(2, 3), 3], @@ -65,5 +66,59 @@ def _test_broadcast_object(ref_obj, rank, group): assert objects_are_equal(ref_obj, obj) +class TestAllGatherList(DistributedTest): + def test_str_equality(self): + spawn_and_init( + functools.partial( + TestAllGatherList._test_all_gather_list_equality, + "hello world", + ), + world_size=2, + ) + + def test_tensor_equality(self): + spawn_and_init( + functools.partial( + TestAllGatherList._test_all_gather_list_equality, + torch.rand(5), + ), + world_size=2, + ) + + def test_complex_equality(self): + spawn_and_init( + 
functools.partial( + TestAllGatherList._test_all_gather_list_equality, + { + "a": "1", + "b": [2, torch.rand(2, 3), 3], + "c": (torch.rand(2, 3), 4), + "d": {5, torch.rand(5)}, + "e": torch.rand(5), + "f": torch.rand(5).int(), + }, + ), + world_size=2, + ) + + @staticmethod + def _test_all_gather_list_equality(ref_obj, rank, group): + objs = dist_utils.all_gather_list(ref_obj, group) + for obj in objs: + assert objects_are_equal(ref_obj, obj) + + def test_rank_tensor(self): + spawn_and_init( + TestAllGatherList._test_all_gather_list_rank_tensor, world_size=2 + ) + + @staticmethod + def _test_all_gather_list_rank_tensor(rank, group): + obj = torch.tensor([rank]) + objs = dist_utils.all_gather_list(obj, group) + for i, obj in enumerate(objs): + assert obj.item() == i + + if __name__ == "__main__": unittest.main() From da83e2f3568fa6c93edb528859eef7135be75c2a Mon Sep 17 00:00:00 2001 From: Guillaume Wenzek <gwenzek@users.noreply.github.com> Date: Tue, 2 Feb 2021 09:23:27 -0800 Subject: [PATCH 423/707] add fast filter_indices_by_size for RoundRobinZipDatasets (#1555) Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? this has been extracted from https://github.com/fairinternal/fairseq-py/issues/1538 - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Implements a fast RoundRobinZipDataset.filter_indices_by_size. Instead of filtering the dataset sample by sample, the different datasets that are part of the RoundRobinZipDataset, are now filtered before being zipped together. This might generate slightly different datasets. 
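The idea described above can be sketched without fairseq at all. The snippet below is only an illustration under stated assumptions, not the code in this patch: `filter_then_zip` and `round_robin_item` are hypothetical helper names, sub-datasets are represented by plain lists of sample lengths, and ordering is assumed to be a simple sort by length. Each sub-dataset is filtered on its own, and the round robin then wraps around the filtered index lists.

```python
from typing import Dict, List, Sequence, Tuple


def filter_then_zip(
    sizes: Dict[str, Sequence[int]],
    max_positions: Dict[str, int],
) -> Tuple[Dict[str, List[int]], int]:
    """Filter every sub-dataset independently, then size the round robin.

    `sizes[key][i]` is the length of sample `i` in sub-dataset `key`
    (a stand-in for a real dataset; hypothetical, for illustration only).
    """
    kept: Dict[str, List[int]] = {}
    for key, lengths in sizes.items():
        # Keep only the samples that fit, ordered by length (assumed ordering).
        valid = [i for i, n in enumerate(lengths) if n <= max_positions[key]]
        if not valid:
            raise ValueError(f"all samples of {key!r} were filtered out")
        kept[key] = sorted(valid, key=lambda i: lengths[i])
    # The zipped dataset is as long as the longest *filtered* sub-dataset.
    return kept, max(len(indices) for indices in kept.values())


def round_robin_item(kept: Dict[str, List[int]], index: int) -> Dict[str, int]:
    # Shorter sub-datasets wrap around with a modulo over their filtered length.
    return {key: indices[index % len(indices)] for key, indices in kept.items()}


kept, length = filter_then_zip(
    {"a": [10, 20, 8, 11, 1000, 7, 12], "b": [11, 20, 9, 1000]},
    {"a": 19, "b": 900},
)
```

Because the modulo is taken over the filtered length of each sub-dataset rather than its original length, the pairing of samples can differ from what sample-by-sample filtering of the zipped dataset would have produced.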
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1555 Reviewed By: myleott Differential Revision: D25924464 Pulled By: gwenzek fbshipit-source-id: bc64d9dc35eee62da7e3e17fd75a7f9facb60452 --- fairseq/data/data_utils.py | 6 --- fairseq/data/round_robin_zip_datasets.py | 67 +++++++++++++++++++----- tests/test_dataset.py | 59 +++++++++++++++++++++ 3 files changed, 114 insertions(+), 18 deletions(-) create mode 100644 tests/test_dataset.py diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index d98c58a2f4..1a83063542 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -164,12 +164,6 @@ def check_size(idx): for key in intersect_keys ) else: - # Hacky as heck, for the specific case of multilingual training with RoundRobin. - if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple): - return all( - a is None or b is None or compare_leq(a, b) - for a, b in zip(size_fn(idx).values(), max_positions) - ) # For MultiCorpusSampledDataset, will generalize it later if not isinstance(size_fn(idx), Iterable): return all(size_fn(idx) <= b for b in max_positions) diff --git a/fairseq/data/round_robin_zip_datasets.py b/fairseq/data/round_robin_zip_datasets.py index 690823fc86..d710335b81 100644 --- a/fairseq/data/round_robin_zip_datasets.py +++ b/fairseq/data/round_robin_zip_datasets.py @@ -3,11 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging from collections import OrderedDict +from typing import Dict, Sequence import numpy as np -from . import FairseqDataset +from . 
import FairseqDataset, LanguagePairDataset + +logger = logging.getLogger(__name__) class RoundRobinZipDatasets(FairseqDataset): @@ -25,25 +29,26 @@ class RoundRobinZipDatasets(FairseqDataset): def __init__(self, datasets, eval_key=None): super().__init__() + if isinstance(datasets, dict): + datasets = OrderedDict(datasets) assert isinstance(datasets, OrderedDict) + assert datasets, "Can't make a RoundRobinZipDatasets out of nothing" + for dataset in datasets.values(): + assert isinstance(dataset, FairseqDataset) + self.datasets = datasets self.eval_key = eval_key - self.longest_dataset = None - self.longest_dataset_key = None - for key, dataset in datasets.items(): - assert isinstance(dataset, FairseqDataset) - if self.longest_dataset is None or len(dataset) > len(self.longest_dataset): - self.longest_dataset = dataset - self.longest_dataset_key = key - - self._ordered_indices = None + self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k])) + self.longest_dataset = datasets[self.longest_dataset_key] + self._ordered_indices: Dict[str, Sequence[int]] = None def _map_index(self, key, index): assert ( self._ordered_indices is not None ), "Must call RoundRobinZipDatasets.ordered_indices() first" - return self._ordered_indices[key][index % len(self.datasets[key])] + o = self._ordered_indices[key] + return o[index % len(o)] def __getitem__(self, index): if self.eval_key is None: @@ -58,6 +63,8 @@ def __getitem__(self, index): return self.datasets[self.eval_key][self._map_index(self.eval_key, index)] def __len__(self): + if self._ordered_indices is not None: + return len(self._ordered_indices[self.longest_dataset_key]) return len(self.longest_dataset) def collater(self, samples): @@ -96,7 +103,7 @@ def ordered_indices(self): if self._ordered_indices is None: # Call the underlying dataset's ordered_indices() here, so that we # get the same random ordering as we would have from using the - # underlying dataset directly. + # underlying sub-datasets directly. 
self._ordered_indices = OrderedDict( [ (key, dataset.ordered_indices()) @@ -105,6 +112,42 @@ def ordered_indices(self): ) return np.arange(len(self)) + def filter_indices_by_size(self, indices, max_positions=None): + """ + Filter each sub-dataset independently, then update the round robin to work + on the filtered sub-datasets. + """ + + def _deep_until_language_pair(dataset): + if isinstance(dataset, LanguagePairDataset): + return dataset + if hasattr(dataset, "tgt_dataset"): + return _deep_until_language_pair(dataset.tgt_dataset) + if hasattr(dataset, "dataset"): + return _deep_until_language_pair(dataset.dataset) + raise Exception(f"Don't know how to unwrap this dataset: {dataset}") + + if not isinstance(max_positions, dict): + max_positions = {k: max_positions for k in self.datasets.keys()} + ignored_some = False + for key, dataset in self.datasets.items(): + dataset = _deep_until_language_pair(dataset) + self._ordered_indices[key], ignored = dataset.filter_indices_by_size( + self._ordered_indices[key], max_positions[key] + ) + if len(ignored) > 0: + ignored_some = True + logger.warning( + f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, " + f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}" + ) + # Since we are modifying _ordered_indices in place, + # it's no longer possible to return valid ignored indices. + # Hopefully the extra debug information logged above is enough to debug. + # Ideally we would receive ignore_invalid_inputs so that we could have + # a proper error message. + return (np.arange(len(self)), [0] if ignored_some else []) + @property def supports_prefetch(self): return all( diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000000..9fb69a5f77 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Sequence + +from fairseq.data import LanguagePairDataset, ListDataset, RoundRobinZipDatasets +from tests.test_train import mock_dict + + +def lang_pair_dataset(lengths: Sequence[int]) -> LanguagePairDataset: + tokens = [[i] * l for i, l in enumerate(lengths)] + return LanguagePairDataset(ListDataset(tokens), lengths, mock_dict()) + + +def sample(id: int, length: int): + return {"id": id, "source": [id] * length, "target": None} + + +class TestDataset(unittest.TestCase): + def test_round_robin_zip_datasets(self): + long_dataset = lang_pair_dataset([10, 9, 8, 11]) + short_dataset = lang_pair_dataset([11, 9]) + + dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset}) + # Dataset is now sorted by sentence length + dataset.ordered_indices() + assert dataset.longest_dataset is long_dataset + self.assertEqual(dict(dataset[0]), {"a": sample(2, 8), "b": sample(1, 9)}) + # The item 2 of dataset 'a' is with item (2 % 2 = 0) of dataset 'b' + self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(1, 9)}) + + def test_round_robin_zip_datasets_filtered(self): + long_dataset = lang_pair_dataset([10, 20, 8, 11, 1000, 7, 12]) + short_dataset = lang_pair_dataset([11, 20, 9, 1000]) + + dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset}) + # Dataset is now sorted by sentence length + idx = dataset.ordered_indices() + idx, _ = dataset.filter_indices_by_size(idx, {"a": 19, "b": 900}) + self.assertEqual(list(idx), [0, 1, 2, 3, 4]) + self.assertEqual(dict(dataset[0]), {"a": sample(5, 7), "b": sample(2, 9)}) + self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(1, 20)}) + self.assertEqual(dict(dataset[4]), {"a": sample(6, 12), "b": sample(0, 11)}) + + def test_round_robin_zip_datasets_filtered_with_tuple(self): + long_dataset = lang_pair_dataset([10, 20, 8, 
11, 1000, 7, 12]) + short_dataset = lang_pair_dataset([11, 20, 9, 1000]) + + dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset}) + # Dataset is now sorted by sentence length + idx = dataset.ordered_indices() + idx, _ = dataset.filter_indices_by_size(idx, 19) + self.assertEqual(list(idx), [0, 1, 2, 3, 4]) + self.assertEqual(dict(dataset[0]), {"a": sample(5, 7), "b": sample(2, 9)}) + self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(2, 9)}) + self.assertEqual(dict(dataset[4]), {"a": sample(6, 12), "b": sample(2, 9)}) From 4b152cbdc029b8a4b1aa8c4189afc16bd289a0ba Mon Sep 17 00:00:00 2001 From: Benjamin Bolte <ben@bolte.cc> Date: Tue, 2 Feb 2021 14:22:48 -0800 Subject: [PATCH 424/707] Speech recognition sharded infer script (#1587) Summary: Wrote a sharded version of `examples/speech_recognition/infer.py` (in a new `examples/speech_recognition/hydra/` folder) which uses the Hydra entry point for launching Slurm jobs. Tested by decoding a fine-tuned HUBERT model and got a reasonable WER. Also tested using Ax sweeper to sweep WER and it seems to work fine. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1587 Reviewed By: wnhsu Differential Revision: D26208091 Pulled By: codekansas fbshipit-source-id: 82671514a85394036b670fe9f7b299236d6c767a --- examples/speech_recognition/hydra/README.md | 45 ++ .../hydra/conf/hydra/sweeper/ax.yaml | 26 + .../speech_recognition/hydra/conf/infer.yaml | 22 + examples/speech_recognition/hydra/decoder.py | 599 ++++++++++++++++++ examples/speech_recognition/hydra/infer.py | 445 +++++++++++++ 5 files changed, 1137 insertions(+) create mode 100644 examples/speech_recognition/hydra/README.md create mode 100644 examples/speech_recognition/hydra/conf/hydra/sweeper/ax.yaml create mode 100644 examples/speech_recognition/hydra/conf/infer.yaml create mode 100644 examples/speech_recognition/hydra/decoder.py create mode 100644 examples/speech_recognition/hydra/infer.py diff --git a/examples/speech_recognition/hydra/README.md b/examples/speech_recognition/hydra/README.md new file mode 100644 index 0000000000..17d5946675 --- /dev/null +++ b/examples/speech_recognition/hydra/README.md @@ -0,0 +1,45 @@ +# Flashlight Decoder + +This script runs decoding for pre-trained speech recognition models. 
+ +## Usage + +Assuming a few variables: + +```bash +exp_dir=<path-to-experiment-directory> +data=<path-to-data-directory> +lm_model=<path-to-language-model> +lexicon=<path-to-lexicon> +``` + +Example usage for decoding a fine-tuned Wav2Vec model: + +```bash +python $FAIRSEQ_ROOT/examples/speech_recognition/hydra/infer.py --multirun \ + task=audio_pretraining \ + task.data=$data \ + task.labels=ltr \ + decoding.exp_dir=$exp_dir \ + decoding.decoder.name=kenlm \ + decoding.decoder.lexicon=$lexicon \ + decoding.decoder.lmpath=$lm_model \ + dataset.gen_subset=dev_clean,dev_other,test_clean,test_other +``` + +Example usage for using Ax to sweep WER parameters (requires `pip install hydra-ax-sweeper`): + +```bash +python $FAIRSEQ_ROOT/examples/speech_recognition/hydra/infer.py --multirun \ + hydra/sweeper=ax \ + task=audio_pretraining \ + task.data=$data \ + task.labels=ltr \ + decoding.exp_dir=$exp_dir \ + decoding.decoder.name=kenlm \ + decoding.decoder.lexicon=$lexicon \ + decoding.decoder.lmpath=$lm_model \ + decoding.write_sentences=false \ + decoding.unique_wer_file=true \ + dataset.gen_subset=dev_other +``` diff --git a/examples/speech_recognition/hydra/conf/hydra/sweeper/ax.yaml b/examples/speech_recognition/hydra/conf/hydra/sweeper/ax.yaml new file mode 100644 index 0000000000..7700712ea0 --- /dev/null +++ b/examples/speech_recognition/hydra/conf/hydra/sweeper/ax.yaml @@ -0,0 +1,26 @@ +# @package hydra.sweeper +_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper +max_batch_size: null +ax_config: + max_trials: 100 + early_stop: + minimize: true + max_epochs_without_improvement: 10 + epsilon: 1.0e-05 + experiment: + name: ${dataset.gen_subset} + objective_name: wer + minimize: true + parameter_constraints: null + outcome_constraints: null + status_quo: null + client: + verbose_logging: false + random_seed: null + params: + decoding.decoder.lmweight: + type: range + bounds: [0.0, 5.0] + decoding.decoder.wordscore: + type: range + bounds: [-5.0, 5.0] diff 
--git a/examples/speech_recognition/hydra/conf/infer.yaml b/examples/speech_recognition/hydra/conf/infer.yaml new file mode 100644 index 0000000000..1d78ba14cb --- /dev/null +++ b/examples/speech_recognition/hydra/conf/infer.yaml @@ -0,0 +1,22 @@ +# @package _group_ + +defaults: + - task: null + - model: null + +hydra: + run: + dir: ${common_eval.results_path}/${dataset.gen_subset} + sweep: + dir: ${common_eval.results_path} + subdir: ${dataset.gen_subset} +common_eval: + results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name} + path: ${decoding.exp_dir}/checkpoint_best.pt + post_process: letter +generation: + nbest: 1 + beam: 500 +dataset: + max_tokens: 1000000 + gen_subset: test diff --git a/examples/speech_recognition/hydra/decoder.py b/examples/speech_recognition/hydra/decoder.py new file mode 100644 index 0000000000..41fcbd7087 --- /dev/null +++ b/examples/speech_recognition/hydra/decoder.py @@ -0,0 +1,599 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import gc +import itertools as it +import math +import os.path as osp +import warnings +from collections import deque, namedtuple +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from examples.speech_recognition.data.replabels import unpack_replabels +from fairseq import tasks +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.dataclass.constants import ChoiceEnum +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models.fairseq_model import FairseqModel +from fairseq.utils import apply_to_sample +from omegaconf import MISSING, open_dict + +try: + from flashlight.lib.sequence.criterion import (CpuViterbiPath, + get_data_ptr_as_bytes) + from flashlight.lib.text.decoder import (LM, CriterionType, DecodeResult, + KenLM, LexiconDecoder, + LexiconDecoderOptions, + LexiconFreeDecoder, + LexiconFreeDecoderOptions, + LMState, SmearingMode, Trie) + from flashlight.lib.text.dictionary import create_word_dict, load_words +except ImportError: + warnings.warn( + "flashlight python bindings are required to use this functionality. 
" + "Please install from " + "https://github.com/facebookresearch/flashlight/tree/master/bindings/python" + ) + LM = object + LMState = object + + +CRITERION_CHOICES = ChoiceEnum(["ctc", "asg"]) +DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"]) + + +@dataclass +class DecoderConfig(FairseqDataclass): + name: DECODER_CHOICES = field( + default="viterbi", + metadata={"help": "The type of decoder to use"}, + ) + nbest: int = field( + default=1, + metadata={"help": "Number of decodings to return"}, + ) + criterion: CRITERION_CHOICES = field( + default="ctc", + metadata={"help": "Criterion to use"}, + ) + asgtransitions: List[int] = field( + default=MISSING, + metadata={"help": "ASG transition indices"}, + ) + maxreplabel: int = field( + default=2, + metadata={"help": "Maximum repeated labels for ASG criterion"}, + ) + unitlm: bool = field( + default=False, + metadata={"help": "If set, use unit language model"}, + ) + lmpath: str = field( + default=MISSING, + metadata={"help": "Language model for KenLM decoder"}, + ) + lexicon: Optional[str] = field( + default=None, + metadata={"help": "Lexicon for Flashlight decoder"}, + ) + beam: int = field( + default=50, + metadata={"help": "Number of beams to use for decoding"}, + ) + beamthreshold: float = field( + default=15.0, + metadata={"help": "Threshold for beam search decoding"}, + ) + beamsizetoken: Optional[int] = field( + default=None, + metadata={"help": "Beam size to use"} + ) + wordscore: float = field( + default=1.5, + metadata={"help": "Word score for KenLM decoder"}, + ) + unkweight: float = field( + default=-math.inf, + metadata={"help": "Unknown weight for KenLM decoder"}, + ) + silweight: float = field( + default=-0.3, + metadata={"help": "Silence weight for KenLM decoder"}, + ) + lmweight: float = field( + default=1.5, + metadata={"help": "Weight for LM while interpolating score"}, + ) + unitlm: bool = field( + default=False, + metadata={"help": "If using a unit language model"}, + ) + + +class 
BaseDecoder: + def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: + self.tgt_dict = tgt_dict + self.vocab_size = len(tgt_dict) + self.nbest = cfg.nbest + self.unitlm = cfg.unitlm + + if cfg.criterion == "ctc": + self.criterion_type = CriterionType.CTC + self.blank = ( + tgt_dict.index("<ctc_blank>") + if "<ctc_blank>" in tgt_dict.indices + else tgt_dict.bos() + ) + if "<sep>" in tgt_dict.indices: + self.silence = tgt_dict.index("<sep>") + elif "|" in tgt_dict.indices: + self.silence = tgt_dict.index("|") + else: + self.silence = tgt_dict.eos() + self.asgtransitions = None + elif cfg.criterion == "asg": + self.criterion_type = CriterionType.ASG + self.blank = -1 + self.silence = -1 + self.asgtransitions = cfg.asgtransitions + self.maxreplabel = cfg.maxreplabel + assert len(self.asgtransitions) == self.vocab_size ** 2 + else: + raise RuntimeError(f"unknown criterion: {cfg.criterion}") + + def generate( + self, + models: List[FairseqModel], + sample: Dict[str, Any], + **unused + ) -> List[List[Dict[str, torch.LongTensor]]]: + encoder_input = { + k: v + for k, v in sample["net_input"].items() + if k != "prev_output_tokens" + } + emissions = self.get_emissions(models, encoder_input) + return self.decode(emissions) + + def get_emissions( + self, + models: List[FairseqModel], + encoder_input: Dict[str, Any], + ) -> torch.FloatTensor: + model = models[0] + encoder_out = model(**encoder_input) + if self.criterion_type == CriterionType.CTC: + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out) + else: + emissions = model.get_normalized_probs( + encoder_out, log_probs=True) + elif self.criterion_type == CriterionType.ASG: + emissions = encoder_out["encoder_out"] + else: + raise ValueError("Criterion not implemented: " + f"{self.criterion_type}") + return emissions.transpose(0, 1).float().cpu().contiguous() + + def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: + idxs = (g[0] for g in it.groupby(idxs)) + if 
self.criterion_type == CriterionType.CTC: + idxs = filter(lambda x: x != self.blank, idxs) + elif self.criterion_type == CriterionType.ASG: + idxs = filter(lambda x: x >= 0, idxs) + idxs = unpack_replabels( + list(idxs), self.tgt_dict, self.maxreplabel) + return torch.LongTensor(list(idxs)) + + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + raise NotImplementedError + + +class ViterbiDecoder(BaseDecoder): + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + B, T, N = emissions.size() + if self.asgtransitions is None: + transitions = torch.FloatTensor(N, N).zero_() + else: + transitions = torch.FloatTensor(self.asgtransitions).view(N, N) + viterbi_path = torch.IntTensor(B, T) + workspace = torch.ByteTensor( + CpuViterbiPath.get_workspace_size(B, T, N)) + CpuViterbiPath.compute( + B, + T, + N, + get_data_ptr_as_bytes(emissions), + get_data_ptr_as_bytes(transitions), + get_data_ptr_as_bytes(viterbi_path), + get_data_ptr_as_bytes(workspace), + ) + return [ + [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] + for b in range(B) + ] + + +class KenLMDecoder(BaseDecoder): + def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: + super().__init__(cfg, tgt_dict) + + if cfg.lexicon: + self.lexicon = load_words(cfg.lexicon) + self.word_dict = create_word_dict(self.lexicon) + self.unk_word = self.word_dict.get_index("<unk>") + + self.lm = KenLM(cfg.lmpath, self.word_dict) + self.trie = Trie(self.vocab_size, self.silence) + + start_state = self.lm.start(False) + for word, spellings in self.lexicon.items(): + word_idx = self.word_dict.get_index(word) + _, score = self.lm.score(start_state, word_idx) + for spelling in spellings: + spelling_idxs = [ + tgt_dict.index(token) + for token in spelling + ] + assert tgt_dict.unk() not in spelling_idxs, \ + f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + 
self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = LexiconDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + word_score=cfg.wordscore, + unk_score=cfg.unkweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=self.criterion_type, + ) + + if self.asgtransitions is None: + self.asgtransitions = [] + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + self.asgtransitions, + self.unitlm, + ) + else: + assert self.unitlm, "Lexicon-free decoding requires unit LM" + + d = {w: [[w]] for w in tgt_dict.symbols} + self.word_dict = create_word_dict(d) + self.lm = KenLM(cfg.lmpath, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=self.criterion_type, + ) + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + B, T, N = emissions.size() + hypos = [] + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append([ + { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + "words": [ + self.word_dict.get_entry(x) + for x in result.words if x >= 0 + ], + } for result in nbest_results + ]) + return hypos + + +FairseqLMState = namedtuple( + "FairseqLMState", + [ + "prefix", + "incremental_state", + "probs", + ] +) + + +class FairseqLM(LM): + def __init__(self, dictionary: Dictionary, model: FairseqModel) -> None: + super().__init__() + + self.dictionary = dictionary + self.model 
= model + self.unk = self.dictionary.unk() + + self.save_incremental = False # this currently does not work properly + self.max_cache = 20_000 + + model.cuda() + model.eval() + model.make_generation_fast_() + + self.states = {} + self.stateq = deque() + + def start(self, start_with_nothing: bool) -> LMState: + state = LMState() + prefix = torch.LongTensor([[self.dictionary.eos()]]) + incremental_state = {} if self.save_incremental else None + with torch.no_grad(): + res = self.model( + prefix.cuda(), incremental_state=incremental_state) + probs = self.model.get_normalized_probs( + res, log_probs=True, sample=None) + + if incremental_state is not None: + incremental_state = apply_to_sample( + lambda x: x.cpu(), incremental_state) + self.states[state] = FairseqLMState( + prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy() + ) + self.stateq.append(state) + + return state + + def score( + self, + state: LMState, + token_index: int, + no_cache: bool = False, + ) -> Tuple[LMState, int]: + """ + Evaluate language model based on the current lm state and new word + Parameters: + ----------- + state: current lm state + token_index: index of the word + (can be lexicon index then you should store inside LM the + mapping between indices of lexicon and lm, or lm index of a word) + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + curr_state = self.states[state] + + def trim_cache(targ_size: int) -> None: + while len(self.stateq) > targ_size: + rem_k = self.stateq.popleft() + rem_st = self.states[rem_k] + rem_st = FairseqLMState(rem_st.prefix, None, None) + self.states[rem_k] = rem_st + + if curr_state.probs is None: + new_incremental_state = ( + curr_state.incremental_state.copy() + if curr_state.incremental_state is not None + else None + ) + with torch.no_grad(): + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cuda(), new_incremental_state + ) + elif self.save_incremental: 
+ new_incremental_state = {} + + res = self.model( + torch.from_numpy(curr_state.prefix).cuda(), + incremental_state=new_incremental_state, + ) + probs = self.model.get_normalized_probs( + res, log_probs=True, sample=None + ) + + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cpu(), new_incremental_state + ) + + curr_state = FairseqLMState( + curr_state.prefix, new_incremental_state, probs[0, -1].cpu( + ).numpy() + ) + + if not no_cache: + self.states[state] = curr_state + self.stateq.append(state) + + score = curr_state.probs[token_index].item() + + trim_cache(self.max_cache) + + outstate = state.child(token_index) + if outstate not in self.states and not no_cache: + prefix = np.concatenate( + [curr_state.prefix, torch.LongTensor([[token_index]])], -1 + ) + incr_state = curr_state.incremental_state + + self.states[outstate] = FairseqLMState(prefix, incr_state, None) + + if token_index == self.unk: + score = float("-inf") + + return outstate, score + + def finish(self, state: LMState) -> Tuple[LMState, int]: + """ + Evaluate eos for language model based on the current lm state + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + return self.score(state, self.dictionary.eos()) + + def empty_cache(self) -> None: + self.states = {} + self.stateq = deque() + gc.collect() + + +class FairseqLMDecoder(BaseDecoder): + def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: + super().__init__(cfg, tgt_dict) + + self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None + self.idx_to_wrd = {} + + checkpoint = torch.load(cfg.lmpath, map_location="cpu") + + if "cfg" in checkpoint and checkpoint["cfg"] is not None: + lm_args = checkpoint["cfg"] + else: + lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) + + with open_dict(lm_args.task): + lm_args.task.data = osp.dirname(cfg.lmpath) + + task = tasks.setup_task(lm_args.task) + model = 
task.build_model(lm_args.model) + model.load_state_dict(checkpoint["model"], strict=False) + + self.trie = Trie(self.vocab_size, self.silence) + + self.word_dict = task.dictionary + self.unk_word = self.word_dict.unk() + self.lm = FairseqLM(self.word_dict, model) + + if self.lexicon: + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + if self.unitlm: + word_idx = i + self.idx_to_wrd[i] = word + score = 0 + else: + word_idx = self.word_dict.index(word) + _, score = self.lm.score( + start_state, word_idx, no_cache=True) + + for spelling in spellings: + spelling_idxs = [ + tgt_dict.index(token) + for token in spelling + ] + assert tgt_dict.unk() not in spelling_idxs, \ + f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = LexiconDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + word_score=cfg.wordscore, + unk_score=cfg.unkweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=self.criterion_type, + ) + + if self.asgtransitions is None: + self.asgtransitions = [] + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + self.asgtransitions, + self.unitlm, + ) + else: + assert self.unitlm, "Lexicon-free decoding requires unit LM" + + d = {w: [[w]] for w in tgt_dict.symbols} + self.word_dict = create_word_dict(d) + self.lm = KenLM(cfg.lmpath, self.word_dict) + self.decoder_opts = LexiconFreeDecoderOptions( + beam_size=cfg.beam, + beam_size_token=cfg.beamsizetoken or len(tgt_dict), + beam_threshold=cfg.beamthreshold, + lm_weight=cfg.lmweight, + sil_score=cfg.silweight, + log_add=False, + criterion_type=self.criterion_type, + ) + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + def decode( + self, 
+ emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + B, T, N = emissions.size() + hypos = [] + + def make_hypo(result: DecodeResult) -> Dict[str, Any]: + hypo = { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + } + if self.lexicon: + hypo["words"] = [ + self.idx_to_wrd[x] if self.unitlm else self.word_dict[x] + for x in result.words if x >= 0 + ] + return hypo + + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[:self.nbest] + hypos.append([make_hypo(result) for result in nbest_results]) + self.lm.empty_cache() + + return hypos + + +def Decoder(cfg: DecoderConfig, tgt_dict: Dictionary) -> BaseDecoder: + if cfg.name == "viterbi": + return ViterbiDecoder(cfg, tgt_dict) + if cfg.name == "kenlm": + return KenLMDecoder(cfg, tgt_dict) + if cfg.name == "fairseqlm": + return FairseqLMDecoder(cfg, tgt_dict) + raise NotImplementedError(f"Invalid decoder name: {cfg.name}") diff --git a/examples/speech_recognition/hydra/infer.py b/examples/speech_recognition/hydra/infer.py new file mode 100644 index 0000000000..6afa066f25 --- /dev/null +++ b/examples/speech_recognition/hydra/infer.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import ast +import hashlib +import logging +import os +import shutil +import sys +from argparse import Namespace +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +import editdistance +import torch +import torch.distributed as dist +from examples.speech_recognition.hydra.decoder import Decoder, DecoderConfig +from fairseq import (checkpoint_utils, distributed_utils, progress_bar, tasks, + utils) +from fairseq.data.data_utils import post_process +from fairseq.dataclass.configs import (CheckpointConfig, CommonConfig, + CommonEvalConfig, DatasetConfig, + DistributedTrainingConfig, + FairseqDataclass, GenerationConfig) +from fairseq.dataclass.initialize import hydra_init +from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.logging.progress_bar import BaseProgressBar +from fairseq.models.fairseq_model import FairseqModel +from omegaconf import MISSING, OmegaConf + +import hydra +from hydra.core.config_store import ConfigStore + +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +config_path = Path(__file__).resolve().parent / "conf" + + +@dataclass +class DecodingConfig(FairseqDataclass): + exp_dir: str = field( + default=MISSING, + metadata={"help": "Path to the experiment directory"}, + ) + unique_wer_file: bool = field( + default=False, + metadata={"help": "If set, use a unique file for storing WER"}, + ) + write_sentences: bool = field( + default=True, + metadata={"help": "If set, write hypothesis and reference sentences"}, + ) + decoder: DecoderConfig = DecoderConfig() + + +@dataclass +class InferConfig(FairseqDataclass): + task: Any = None + decoding: DecodingConfig = DecodingConfig() + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + generation: GenerationConfig = GenerationConfig() + 
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + + +def reset_logging(): + root = logging.getLogger() + for handler in list(root.handlers): + root.removeHandler(handler) + root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + root.addHandler(handler) + + +class InferenceProcessor: + def __init__(self, cfg: InferConfig) -> None: + self.cfg = cfg + self.task = tasks.setup_task(cfg.task) + self.tgt_dict = self.task.target_dictionary + + models, saved_cfg = self.load_model_ensemble() + self.models = models + self.saved_cfg = saved_cfg + + self.task.load_dataset( + self.cfg.dataset.gen_subset, + task_cfg=saved_cfg.task, + ) + self.generator = Decoder(cfg.decoding.decoder, self.tgt_dict) + self.gen_timer = StopwatchMeter() + self.wps_meter = TimeMeter() + self.num_sentences = 0 + self.total_errors = 0 + self.total_length = 0 + + self.hypo_words_file = None + self.hypo_units_file = None + self.ref_words_file = None + self.ref_units_file = None + + self.progress_bar = self.build_progress_bar() + + def __enter__(self) -> "InferenceProcessor": + if self.cfg.decoding.write_sentences: + self.hypo_words_file = self.get_res_file("hypo.word") + self.hypo_units_file = self.get_res_file("hypo.units") + self.ref_words_file = self.get_res_file("ref.word") + self.ref_units_file = self.get_res_file("ref.units") + return self + + def __exit__(self, *exc) -> bool: + if self.cfg.decoding.write_sentences: + self.hypo_words_file.close() + self.hypo_units_file.close() + self.ref_words_file.close() + self.ref_units_file.close() + return False + + def __iter__(self) -> Any: + for sample in self.progress_bar: + if not self.cfg.common.cpu: + sample = utils.move_to_cuda(sample) + + # Happens on the last batch. 
+ if "net_input" not in sample: + continue + yield sample + + def log(self, *args, **kwargs): + self.progress_bar.log(*args, **kwargs) + + def print(self, *args, **kwargs): + self.progress_bar.print(*args, **kwargs) + + def get_res_file(self, fname: str) -> None: + if self.data_parallel_world_size > 1: + fname = f"{fname}.{self.data_parallel_rank}" + return open(fname, "w", buffering=1) + + def merge_shards(self) -> None: + """Merges all shard files into shard 0, then removes shard suffix.""" + + shard_id = self.data_parallel_rank + num_shards = self.data_parallel_world_size + + def merge_shards_with_root(fname: str) -> None: + logger.info("Merging %s on shard %d", fname, shard_id) + base_fpath = Path(f"{fname}.0") + with open(base_fpath, "a") as out_file: + for s in range(1, num_shards): + shard_fpath = Path(f"{fname}.{s}") + with open(shard_fpath, "r") as in_file: + for line in in_file: + out_file.write(line) + shard_fpath.unlink() + shutil.move(f"{fname}.0", fname) + + if shard_id == (0 % num_shards): + merge_shards_with_root("hypo.word") + if shard_id == (1 % num_shards): + merge_shards_with_root("hypo.units") + if shard_id == (2 % num_shards): + merge_shards_with_root("ref.word") + if shard_id == (3 % num_shards): + merge_shards_with_root("ref.units") + dist.barrier() + + def optimize_model(self, model: FairseqModel) -> None: + gcfg = self.cfg.generation + model.make_generation_fast_( + beamable_mm_beam_size=None if gcfg.no_beamable_mm else gcfg.beam, + need_attn=gcfg.print_alignment, + ) + if self.cfg.common.fp16: + model.half() + if not self.cfg.common.cpu: + model.cuda() + + def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]: + arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides) + models, saved_cfg = checkpoint_utils.load_model_ensemble( + utils.split_paths(self.cfg.common_eval.path), + arg_overrides=arg_overrides, + task=self.task, + suffix=self.cfg.checkpoint.checkpoint_suffix, + 
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=self.cfg.checkpoint.checkpoint_shard_count, + ) + for model in models: + self.optimize_model(model) + return models, saved_cfg + + def get_dataset_itr(self, disable_iterator_cache: bool = False) -> None: + return self.task.get_batch_iterator( + dataset=self.task.dataset(self.cfg.dataset.gen_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=(sys.maxsize, sys.maxsize), + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, + num_shards=self.data_parallel_world_size, + shard_id=self.data_parallel_rank, + num_workers=self.cfg.dataset.num_workers, + data_buffer_size=self.cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + ).next_epoch_itr(shuffle=False) + + def build_progress_bar( + self, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + default_log_format: str = "tqdm", + ) -> BaseProgressBar: + return progress_bar.progress_bar( + iterator=self.get_dataset_itr(), + log_format=self.cfg.common.log_format, + log_interval=self.cfg.common.log_interval, + epoch=epoch, + prefix=prefix, + tensorboard_logdir=self.cfg.common.tensorboard_logdir, + default_log_format=default_log_format, + ) + + @property + def data_parallel_world_size(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 1 + return distributed_utils.get_data_parallel_world_size() + + @property + def data_parallel_rank(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 0 + return distributed_utils.get_data_parallel_rank() + + def process_sentence( + self, + sample: Dict[str, Any], + hypo: Dict[str, Any], + sid: int, + batch_id: int, + ) -> Tuple[int, int]: + speaker = None # Speaker can't be parsed from dataset. 
+ + if "target_label" in sample: + toks = sample["target_label"] + else: + toks = sample["target"] + toks = toks[batch_id, :] + + # Processes hypothesis. + hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu()) + if "words" in hypo: + hyp_words = " ".join(hypo["words"]) + else: + hyp_words = post_process(hyp_pieces, + self.cfg.common_eval.post_process) + + # Processes target. + target_tokens = utils.strip_pad(toks, self.tgt_dict.pad()) + tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu()) + tgt_words = post_process(tgt_pieces, + self.cfg.common_eval.post_process) + + if self.cfg.decoding.write_sentences: + print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file) + print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file) + print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file) + print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file) + + hyp_words, tgt_words = hyp_words.split(), tgt_words.split() + + return editdistance.eval(hyp_words, tgt_words), len(tgt_words) + + def process_sample(self, sample: Dict[str, Any]) -> None: + self.gen_timer.start() + hypos = self.task.inference_step( + generator=self.generator, + models=self.models, + sample=sample, + ) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + self.gen_timer.stop(num_generated_tokens) + self.wps_meter.update(num_generated_tokens) + + for batch_id, sample_id in enumerate(sample["id"].tolist()): + errs, length = self.process_sentence( + sample=sample, + sid=sample_id, + batch_id=batch_id, + hypo=hypos[batch_id][0], + ) + self.total_errors += errs + self.total_length += length + + self.log({"wps": round(self.wps_meter.avg)}) + if "nsentences" in sample: + self.num_sentences += sample["nsentences"] + else: + self.num_sentences += sample["id"].numel() + + def log_generation_time(self) -> None: + logger.info("Processed %d sentences (%d tokens) in %.1fs %.2f " + "sentences per second, %.2f tokens per second)", + self.num_sentences, 
self.gen_timer.n, self.gen_timer.sum, + self.num_sentences / self.gen_timer.sum, + 1.0 / self.gen_timer.avg) + + +def parse_wer(wer_file: Path) -> float: + with open(wer_file, "r") as f: + return float(f.readline().strip().split(" ")[1]) + + +def get_wer_file(cfg: InferConfig) -> Path: + """Hashes the decoding parameters to a unique file ID.""" + if cfg.decoding.unique_wer_file: + yaml_str = OmegaConf.to_yaml(cfg.decoding) + fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) + return Path(f"wer.{fid % 1000000}") + else: + return Path("wer") + + +def main(cfg: InferConfig) -> float: + """Entry point for main processing logic. + + Args: + cfg: The inference configuration to use. + wer: Optional shared memory pointer for returning the WER. If not None, + the final WER value will be written here instead of being returned. + + Returns: + The final WER if `wer` is None, otherwise None. + """ + + yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg) + + # Validates the provided configuration. 
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 4000000 + if not cfg.common.cpu and not torch.cuda.is_available(): + raise ValueError("CUDA not found; set `cpu=True` to run without CUDA") + if cfg.generation.nbest > 1: + raise ValueError("`nbest > 1` not implemented yet") + + with InferenceProcessor(cfg) as processor: + for sample in processor: + processor.process_sample(sample) + + processor.log_generation_time() + + if cfg.decoding.write_sentences: + processor.merge_shards() + + errs_t, leng_t = processor.total_errors, processor.total_length + + if cfg.common.cpu: + logger.warning("Merging WER requires CUDA.") + else: + stats = torch.LongTensor([errs_t, leng_t]).cuda() + dist.all_reduce(stats, op=dist.ReduceOp.SUM) + errs_t, leng_t = stats[0].item(), stats[1].item() + + wer = errs_t * 100.0 / leng_t + + if distributed_utils.is_master(cfg.distributed_training): + with open(wer_file, "w") as f: + f.write(f"WER: {wer}\n\n{yaml_str}") + + return wer + + +@hydra.main(config_path=config_path, config_name="infer") +def hydra_main(cfg: InferConfig) -> float: + container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + cfg = OmegaConf.create(container) + OmegaConf.set_struct(cfg, True) + + if cfg.common.reset_logging: + reset_logging() + + logger.info("Config:\n%s", OmegaConf.to_yaml(cfg)) + logger.info("Working directory: %s", Path.cwd()) + wer = float("inf") + + try: + if cfg.common.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + wer = parse_wer(get_wer_file(cfg)) + except BaseException as e: # pylint: disable=broad-except + if not cfg.common.suppress_crashes: + raise + else: + logger.error("Crashed! 
%s", str(e)) + + logger.info("Word error rate: %.4f", wer) + return wer + + +def cli_main() -> None: + try: + from hydra._internal.utils import \ + get_args # pylint: disable=import-outside-toplevel + cfg_name = get_args().config_name or "infer" + except ImportError: + logger.warning("Failed to get config name from hydra args") + cfg_name = "infer" + + cs = ConfigStore.instance() + cs.store(name=cfg_name, node=InferConfig) + + for k in InferConfig.__dataclass_fields__: + v = InferConfig.__dataclass_fields__[k].default + try: + cs.store(name=k, node=v) + except BaseException: + logger.error(f"{k} - {v}") + raise + + hydra_main() # pylint: disable=no-value-for-parameter + + +if __name__ == "__main__": + cli_main() From 82ec2e722f6fe75686ab2abc872b487ca748f1ce Mon Sep 17 00:00:00 2001 From: Pranav Deshpande <pranavcd@fb.com> Date: Tue, 2 Feb 2021 15:49:39 -0800 Subject: [PATCH 425/707] Fix the task data arg conversion to string. Summary: We were getting some test failures on our end due to incompatibility of task data argument type. The actual exception is defined in this task: T83395097 and T83395052. Fixing the task data arg to be a string instead of list of strings. 
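For illustration only (this is not part of the patch), the conversion described above can be sketched on a bare `argparse.Namespace`. The helper name `normalize_data_arg` and the example path are hypothetical; the condition mirrors the check this commit adds to `_upgrade_state_dict`.

```python
from argparse import Namespace


def normalize_data_arg(args: Namespace) -> Namespace:
    """Collapse a legacy list-valued ``data`` argument into a single string.

    Hypothetical helper for illustration; it is not a fairseq function.
    """
    data = getattr(args, "data", None)
    # Only a non-empty list is rewritten; strings and empty lists pass through.
    if isinstance(data, list) and len(data) > 0:
        args.data = data[0]
    return args


# A legacy checkpoint stored the data path as a one-element list.
legacy = normalize_data_arg(Namespace(data=["/path/to/data-bin"]))
# A newer checkpoint already stores a plain string and is left unchanged.
modern = normalize_data_arg(Namespace(data="/path/to/data-bin"))
```

Note that only the first list element is kept, so any additional entries in a multi-element list would be dropped by this kind of conversion.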
Reviewed By: myleott Differential Revision: D26205482 fbshipit-source-id: d29d1ee7c469177e8bdad7ca603938f8450bd81c --- fairseq/checkpoint_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 79c811424a..2f209b6b39 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -546,6 +546,9 @@ def _upgrade_state_dict(state): # convert legacy float learning rate to List[float] if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): state["args"].lr = [state["args"].lr] + # convert task data arg to a string instead of List[string] + if hasattr(state["args"], "data") and isinstance(state["args"].data, list) and len(state["args"].data) > 0: + state["args"].data = state["args"].data[0] state["cfg"] = convert_namespace_to_omegaconf(state["args"]) From 62c1bb307582a481629662cca4ce7005d8c0c236 Mon Sep 17 00:00:00 2001 From: Giri Anantharaman <giriman@learnfair6000.h2.fair> Date: Wed, 3 Feb 2021 04:48:27 -0800 Subject: [PATCH 426/707] Adding initialization for `num_pipelines_per_node` (#1599) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …hod` to avoid an `UnboundLocalError`. # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Adding initialization for `num_pipelines_per_node` in `infer_init_method` in `distributed/utils.py` ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1599 Reviewed By: myleott Differential Revision: D26208044 Pulled By: girifb fbshipit-source-id: 98d3c0b70b59a5e0abb027850baa3bc44d9c3c78 --- fairseq/distributed/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index c39dc6d912..e3d8e1e0d3 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -47,6 +47,7 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): if cfg.distributed_init_method is not None or cfg.tpu: return + num_pipelines_per_node = None if cfg.pipeline_model_parallel: num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg) From 83391858c9ca951bb243d60ba3f5f23b87e99cf8 Mon Sep 17 00:00:00 2001 From: Yuriy Nazarov <nazarov.yuriy.pavlovich@gmail.com> Date: Wed, 3 Feb 2021 08:59:55 -0800 Subject: [PATCH 427/707] Update lang-pairs path and add fixed-dictionary for small models (#3084) Summary: When the file extension is missing from the --lang-pairs option, generation with the 418M and 1.2B models fails with the following error: ``` ValueError: language pair en-fr contains languages that are not in the language dictionary; langs: ['language_pairs_small_models'] ``` However, generation still fails after restoring the file extension, with the following error: ``` RuntimeError: Error(s) in loading state_dict for TransformerModel: size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([128112, 1024]) from checkpoint, the shape in current model is torch.Size([128104, 1024]). size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([128112, 1024]) from checkpoint, the shape in current model is torch.Size([128104, 1024]). 
size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([128112, 1024]) from checkpoint, the shape in current model is torch.Size([128104, 1024]). ``` This could be resolved by adding --fixed-dictionary model_dict.128k.txt like in Generation for the 12B model section. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3084 Reviewed By: huihuifan Differential Revision: D26225960 Pulled By: myleott fbshipit-source-id: 0cabe1fd074e45484264d551117704180c7ade9f --- examples/m2m_100/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/m2m_100/README.md b/examples/m2m_100/README.md index f1b465c7b9..05801584d6 100644 --- a/examples/m2m_100/README.md +++ b/examples/m2m_100/README.md @@ -130,7 +130,7 @@ wget https://dl.fbaipublicfiles.com/m2m_100/418M_last_checkpoint.pt wget https://dl.fbaipublicfiles.com/m2m_100/1.2B_last_checkpoint.pt # Generation: -fairseq-generate $binarized_data_path --batch-size 32 --path $path_to_model -s en -t fr --remove-bpe 'sentencepiece' --beam 5 --task translation_multi_simple_epoch --lang-pairs language_pairs_small_models --decoder-langtok --encoder-langtok src --gen-subset test > gen_out +fairseq-generate $binarized_data_path --batch-size 32 --path $path_to_model --fixed-dictionary model_dict.128k.txt -s en -t fr --remove-bpe 'sentencepiece' --beam 5 --task translation_multi_simple_epoch --lang-pairs language_pairs_small_models.txt --decoder-langtok --encoder-langtok src --gen-subset test > gen_out ``` ### 12B Model From e802a30bffdad0b22ad6efc413230ca348f8f50b Mon Sep 17 00:00:00 2001 From: Xu Song <xusong.vip@gmail.com> Date: Wed, 3 Feb 2021 09:43:42 -0800 Subject: [PATCH 428/707] Fix hyperlink (#3193) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fix hyperlink ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3193 Reviewed By: ngoyal2707 Differential Revision: D26225560 Pulled By: myleott fbshipit-source-id: a67a11cf76d1f003d8408b15edbe30f3f7b4fd5b --- examples/bart/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/bart/README.md b/examples/bart/README.md index e891894a84..013a809be6 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -1,6 +1,6 @@ # BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension -[https://arxiv.org/pdf/1910.13461.pdf] +[https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461) ## Introduction From 6ec7ed9920c64ae99a787c0885c543896b525df0 Mon Sep 17 00:00:00 2001 From: Xu Song <xusong.vip@gmail.com> Date: Wed, 3 Feb 2021 11:25:52 -0800 Subject: [PATCH 429/707] Fix logger error (#3184) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fix logger error ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3184 Reviewed By: ngoyal2707 Differential Revision: D26225518 Pulled By: myleott fbshipit-source-id: eeffc5ede7de1b335148d9a2a2a9cf69fc7630ad --- fairseq/models/bart/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 44f03b0162..71d0b27cd2 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -277,7 +277,7 @@ def truncate_emb(key): cur_state = self.classification_heads.state_dict() for k, v in cur_state.items(): if prefix + "classification_heads." + k not in state_dict: - logger.info("Overwriting", prefix + "classification_heads." + k) + logger.info("Overwriting " + prefix + "classification_heads." + k) state_dict[prefix + "classification_heads." + k] = v From fd624018bf3e834c09cc03695a8fa0bcaa4a10f3 Mon Sep 17 00:00:00 2001 From: Xu Song <xusong.vip@gmail.com> Date: Wed, 3 Feb 2021 11:31:36 -0800 Subject: [PATCH 430/707] Fix AttributeError: 'Namespace' object has no attribute 'max_positions' (#3183) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fix AttributeError: 'Namespace' object has no attribute 'max_positions' https://github.com/pytorch/fairseq/blob/master/examples/bart/README.glue.md#3-fine-tuning-on-glue-task ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3183 Reviewed By: ngoyal2707 Differential Revision: D26225511 Pulled By: myleott fbshipit-source-id: 29e219b3d9be552aa3f17963b1095c9aa610f4a1 --- fairseq/tasks/sentence_prediction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 0ec3824d04..f5bead972f 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -174,7 +174,7 @@ def make_dataset(key, dictionary): split, self.args.shorten_data_split_list, self.args.shorten_method, - self.args.max_positions, + self.max_positions(), self.args.seed, ) From 51c312a30f33e3366b1bb61084d037a90aa1d4a0 Mon Sep 17 00:00:00 2001 From: Sugiyama <h.sugi@ieee.org> Date: Wed, 3 Feb 2021 11:55:02 -0800 Subject: [PATCH 431/707] Modify eval_bleu_args to Optional (#3175) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3158 . ## PR review Fixes a bug when loading a previously trained translation-task model. It is not necessary to define eval_bleu_args and eval_bleu_detok_args when eval_bleu=False is used. ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3175 Reviewed By: alexeib Differential Revision: D26225882 Pulled By: myleott fbshipit-source-id: ec5908179560cc44c31bac29831beb62dd81305d --- fairseq/tasks/translation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index d975fd49d2..90635d882f 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -228,7 +228,7 @@ class TranslationConfig(FairseqDataclass): eval_bleu: bool = field( default=False, metadata={"help": "evaluation with BLEU scores"} ) - eval_bleu_args: str = field( + eval_bleu_args: Optional[str] = field( default="{}", metadata={ "help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string' @@ -241,7 +241,7 @@ class TranslationConfig(FairseqDataclass): "use 'space' to disable detokenization; see fairseq.data.encoders for other options" }, ) - eval_bleu_detok_args: str = field( + eval_bleu_detok_args: Optional[str] = field( default="{}", metadata={"help": "args for building the tokenizer, if needed, as JSON string"}, ) From 791ab7c20831a76a9196aaf0db3a2cb1cf906dde Mon Sep 17 00:00:00 2001 From: Hongfei XU <anoidgit@users.noreply.github.com> Date: Wed, 3 Feb 2021 12:03:15 -0800 Subject: [PATCH 432/707] More accurate label smoothing loss computation (#3182) Summary: It seems that the current implementation uses a slightly larger label smoothing value. For a large vocabulary this is fine, but the difference can be more noticeable with a small vocabulary. By changing these 2 lines, the computation of the label smoothing loss should be consistent with the standard.
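As an illustrative sketch (not part of the patch), the check below compares the patched weighting against cross-entropy with an explicit smoothed target distribution that puts `1 - epsilon` on the gold label and `epsilon / (V - 1)` on every other label. That distribution is assumed here to be the "standard" the summary refers to. The function names and the toy three-word vocabulary are made up for the example; only the two loss lines in each of the first two functions mirror the diff.

```python
import math


def patched_loss(lprobs, target, epsilon):
    """Loss for one position using the weighting introduced by this patch."""
    nll_loss = -lprobs[target]      # negative log-likelihood of the gold label
    smooth_loss = -sum(lprobs)      # summed over the whole vocabulary, gold label included
    eps_i = epsilon / (len(lprobs) - 1)
    return (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss


def previous_loss(lprobs, target, epsilon):
    """Loss for one position using the weighting before this patch."""
    nll_loss = -lprobs[target]
    smooth_loss = -sum(lprobs)
    eps_i = epsilon / len(lprobs)
    return (1.0 - epsilon) * nll_loss + eps_i * smooth_loss


def explicit_target_loss(lprobs, target, epsilon):
    """Cross-entropy against an explicit smoothed target (assumed reference)."""
    vocab = len(lprobs)
    q = [epsilon / (vocab - 1)] * vocab   # mass on each non-gold label
    q[target] = 1.0 - epsilon             # mass on the gold label
    return -sum(q_i * lp for q_i, lp in zip(q, lprobs))


# Toy example: a vocabulary of three words, gold label at index 0.
lprobs = [math.log(p) for p in (0.7, 0.2, 0.1)]
patched = patched_loss(lprobs, target=0, epsilon=0.1)
previous = previous_loss(lprobs, target=0, epsilon=0.1)
explicit = explicit_target_loss(lprobs, target=0, epsilon=0.1)

print(f"patched={patched:.6f} previous={previous:.6f} explicit={explicit:.6f}")
```

Under these assumptions the patched value agrees with the explicit-target value, while the previous weighting gives a different number on this small vocabulary.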
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3182 Reviewed By: glample Differential Revision: D26225506 Pulled By: myleott fbshipit-source-id: 75447275e32336ae3b52e732e6124e15d0043b74 --- fairseq/criterions/label_smoothed_cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index cb47a1582f..56d63e3e1b 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -45,8 +45,8 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T if reduce: nll_loss = nll_loss.sum() smooth_loss = smooth_loss.sum() - eps_i = epsilon / lprobs.size(-1) - loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss return loss, nll_loss From 81e38ed39da02b5104939cef941bd848b87f9e26 Mon Sep 17 00:00:00 2001 From: Tim Gates <tim.gates@iress.com> Date: Wed, 3 Feb 2021 12:04:25 -0800 Subject: [PATCH 433/707] docs: fix simple typo, efficieny -> efficiency (#3070) Summary: There is a small typo in fairseq/modules/dynamic_convolution.py, fairseq/modules/dynamicconv_layer/dynamicconv_layer.py. Should read `efficiency` rather than `efficieny`. 
Semi-automated pull request generated by https://github.com/timgates42/meticulous/blob/master/docs/NOTE.md Pull Request resolved: https://github.com/pytorch/fairseq/pull/3070 Reviewed By: huihuifan Differential Revision: D26225968 Pulled By: myleott fbshipit-source-id: fb7479f9678bc420e80fbb72c5389a54ad4d4c1d --- fairseq/modules/dynamic_convolution.py | 2 +- fairseq/modules/dynamicconv_layer/dynamicconv_layer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py index 9f2d28da65..0121d453b9 100644 --- a/fairseq/modules/dynamic_convolution.py +++ b/fairseq/modules/dynamic_convolution.py @@ -263,7 +263,7 @@ def _forward_expanded(self, x, incremental_stat, query): weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) else: P = self.padding_l - # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length + # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length if K > T and P == K - 1: weight = weight.narrow(2, K - T, T) K, P = T, T - 1 diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py index 4a683d2690..711ed03483 100644 --- a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py @@ -212,7 +212,7 @@ def _forward_expanded(self, x, incremental_stat, query): weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) else: P = self.padding_l - # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length + # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length if K > T and P == K - 1: weight = weight.narrow(2, K - T, T) K, P = T, T - 1 From e996bddcd7b244b1e22d476bc6f402e4ff86167c Mon Sep 17 00:00:00 2001 From: zzxn 
<zzxnhackman@foxmail.com> Date: Wed, 3 Feb 2021 12:07:16 -0800 Subject: [PATCH 434/707] A small fix on a problem on Windows: AttributeError: module 'signal' has no attribute 'SIGKILL' (#3188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3187 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3188 Reviewed By: lematt1991 Differential Revision: D26225553 Pulled By: myleott fbshipit-source-id: 7aff636f9ba3392bee6cdf305e849aa5c8994a5b --- fairseq/distributed/distributed_timeout_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/distributed/distributed_timeout_wrapper.py b/fairseq/distributed/distributed_timeout_wrapper.py index c8ab477073..18107ef27e 100644 --- a/fairseq/distributed/distributed_timeout_wrapper.py +++ b/fairseq/distributed/distributed_timeout_wrapper.py @@ -33,7 +33,7 @@ class DistributedTimeoutWrapper(nn.Module): (set to a value <= 0 to disable the timeout) signal (Optional): signal to send once timeout is triggered """ - def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGKILL): + def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT): super().__init__() self.module = module self.timeout = timeout From de3e0fc65158fecf7e6ffa464003839c70a7494f Mon Sep 17 00:00:00 2001 From: Muhammad Khalifa <moyle2010@gmail.com>
Date: Wed, 3 Feb 2021 12:11:31 -0800 Subject: [PATCH 435/707] Added shorten dataset to denoising task (#3148) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/2318 The denoising task had no truncate option, which caused errors with longer sentences. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3148 Reviewed By: joshim5 Differential Revision: D26225934 Pulled By: myleott fbshipit-source-id: 338194e570501293bc5d3b61b8522416d1e6cf07 --- fairseq/tasks/denoising.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index 41bddc1a05..cbf01e14df 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -21,6 +21,7 @@ data_utils, ) from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.tasks import LegacyFairseqTask, register_task import numpy as np @@ -121,6 +122,20 @@ def add_args(parser): help="max number of tokens in the target sequence", ) + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + help="comma-separated
list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) + + def __init__(self, args, dictionary): super().__init__(args) self.dictionary = dictionary @@ -162,6 +177,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dataset = StripTokenDataset(dataset, self.dictionary.eos()) + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.args.tokens_per_sample, + self.args.seed, + ) + # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, From 4c197de87f92b0bd7e427aa3e094d05112b325a0 Mon Sep 17 00:00:00 2001 From: Harveen Singh Chadha <30959215+harveenchadha@users.noreply.github.com> Date: Wed, 3 Feb 2021 12:15:10 -0800 Subject: [PATCH 436/707] Fixes #3005 (#3122) Summary: The normalize and encoder_embed_dim are not present in base pretraining config which leads to errors during finetuning. ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3005 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3122 Reviewed By: alexeib Differential Revision: D26225929 Pulled By: myleott fbshipit-source-id: 38067492b0241a30f84dc439f704324a032e054b --- .../wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml index 767aee2852..b686e21ab1 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml @@ -15,6 +15,7 @@ task: data: ??? 
max_sample_size: 250000 min_sample_size: 32000 + normalize: false dataset: num_workers: 6 @@ -53,3 +54,4 @@ model: dropout_input: 0.1 dropout_features: 0.1 feature_grad_mult: 0.1 + encoder_embed_dim: 768 From 8629245b0329a7e704ebc7ec05b94ac094468c1b Mon Sep 17 00:00:00 2001 From: markaa <oma654@yandex.ru> Date: Wed, 3 Feb 2021 12:25:02 -0800 Subject: [PATCH 437/707] Added weight initialization for ConvTBC (#3179) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3131 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3179 Reviewed By: huihuifan Differential Revision: D26225878 Pulled By: myleott fbshipit-source-id: 52267d10db6ec86be1c89207a768ae8b54ae1f82 --- fairseq/modules/conv_tbc.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fairseq/modules/conv_tbc.py b/fairseq/modules/conv_tbc.py index 79b2b2ad57..65e17ec94f 100644 --- a/fairseq/modules/conv_tbc.py +++ b/fairseq/modules/conv_tbc.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree.
import torch +from torch import nn from torch.nn.modules.utils import _single from torch import Tensor @@ -27,6 +28,12 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0): ) self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + nn.init.zeros_(self.bias) + def conv_tbc(self, input: Tensor): return torch.conv_tbc( input.contiguous(), self.weight, self.bias, self.padding[0] From b4843681b4d5af442febf8caba58ca9600b01656 Mon Sep 17 00:00:00 2001 From: Xu Song <xusong.vip@gmail.com> Date: Wed, 3 Feb 2021 12:27:18 -0800 Subject: [PATCH 438/707] Update sentence_prediction.py (#3165) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes warning local variable `label_dict` is not used https://github.com/pytorch/fairseq/blob/bfcc13e20a6cfa18fb25daaae39644f9b7872699/fairseq/tasks/sentence_prediction.py#L122-L132 ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3165 Reviewed By: ngoyal2707 Differential Revision: D26225892 Pulled By: myleott fbshipit-source-id: f4fe0ceb6f0959112a61c119cf437405d01179ed --- fairseq/tasks/sentence_prediction.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index f5bead972f..67acf7d377 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -119,9 +119,8 @@ def setup_task(cls, args, **kwargs): ) logger.info("[input] dictionary: {} types".format(len(data_dict))) - label_dict = None + # load label dictionary if not args.regression_target: - # load label dictionary label_dict = cls.load_dictionary( args, os.path.join(args.data, "label", "dict.txt"), From 4f9831bf847b8595f5590faf30b2f0af6a03bac4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen <patrick.v.platen@gmail.com> Date: Thu, 4 Feb 2021 18:34:37 -0800 Subject: [PATCH 439/707] Add small section for wav2vec 2.0 HF Transformers implementation (#3216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3216 Reviewed By: aconneau Differential Revision: D26269645 Pulled By: alexeib fbshipit-source-id: 239af3a16ef39b90fc7fe71b2e02e068b7727040 --- examples/wav2vec/README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index b3300a8ed8..663adf97dc 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -147,6 +147,35 @@ python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/lib To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. +## Use wav2vec 2.0 with 🤗Transformers: + +Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since version 4.3. + +Pretrained models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) +and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). + +Usage example: + +```python +# !pip install transformers +import soundfile as sf +import torch +from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer + +# load pretrained model +tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") +model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") + +# load audio +audio_input, _ = sf.read("path/to/audio/file") + +# transcribe +input_values = tokenizer(audio_input, return_tensors="pt").input_values +logits = model(input_values).logits +predicted_ids = torch.argmax(logits, dim=-1) +transcription = tokenizer.batch_decode(predicted_ids)[0] +``` + # wav2vec Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862).
From 9316f13ad53fa532c7306c7261e4d76c58e38b48 Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Thu, 4 Feb 2021 21:36:18 -0800 Subject: [PATCH 440/707] Add interactive.py support for S2T Summary: Add interactive.py support for S2T Github issue: https://github.com/pytorch/fairseq/issues/3146 Reviewed By: jmp84 Differential Revision: D26260681 fbshipit-source-id: 7f2f6e49f8e4b7767550665a3cfe12c962469a7d --- examples/speech_to_text/README.md | 14 +++++++++++--- examples/speech_to_text/docs/covost_example.md | 9 +++++++++ .../speech_to_text/docs/librispeech_example.md | 8 ++++++++ fairseq/tasks/fairseq_task.py | 10 ++++++++++ fairseq/tasks/speech_to_text.py | 12 +++++++++--- fairseq_cli/interactive.py | 16 +++++----------- 6 files changed, 52 insertions(+), 17 deletions(-) diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 0bd8bfdac9..4b6f89d105 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -19,9 +19,14 @@ Fairseq S2T also employs a YAML file for data related configurations: tokenizer for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment, temperature-based resampling, etc. -## Model Training & Evaluation -Fairseq S2T uses the unified `fairseq-train`/`fairseq-generate` interface for model training and evaluation. -It requires arguments `--task speech_to_text` and `--arch <model architecture in fairseq.models.speech_to_text.*>`. +## Model Training +Fairseq S2T uses the unified `fairseq-train` interface for model training. It requires arguments `--task speech_to_text`, + `--arch <model architecture in fairseq.models.speech_to_text.*>` and `--config-yaml <config YAML filename>`. + +## Inference & Evaluation +Fairseq S2T uses the unified `fairseq-generate`/`fairseq-interactive` interface for inference and evaluation. It +requires arguments `--task speech_to_text` and `--config-yaml <config YAML filename>`. 
The interactive console takes +audio paths (one per line) as inputs. ## Examples @@ -32,6 +37,9 @@ It requires arguments `--task speech_to_text` and `--arch <model architecture in - [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md) ## Updates +- 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples: + [ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding) + and [ST (CoVoST 2)](docs/covost_example.md#interactive-decoding). - 01/08/2021: Several fixes for S2T Transformer model, inference-time de-tokenization, scorer configuration and data preparation scripts. We also add pre-trained models to the examples and revise the instructions. Breaking changes: the data preparation scripts now extract filterbank features without CMVN. CMVN is instead applied diff --git a/examples/speech_to_text/docs/covost_example.md b/examples/speech_to_text/docs/covost_example.md index a4ce8a10e4..55cd134c16 100644 --- a/examples/speech_to_text/docs/covost_example.md +++ b/examples/speech_to_text/docs/covost_example.md @@ -85,6 +85,15 @@ fairseq-generate ${COVOST_ROOT}/fr \ --max-tokens 50000 --beam 5 --scoring sacrebleu ``` +## Interactive Decoding +Launch the interactive console via +```bash +fairseq-interactive ${COVOST_ROOT}/fr --config-yaml config_st_fr_en.yaml \ + --task speech_to_text --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 +``` +Type in WAV/FLAC/OGG audio paths (one per line) after the prompt. 
+ #### Results | --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | Model | |---|---|---|---|---|---|---|---|---|---|---| diff --git a/examples/speech_to_text/docs/librispeech_example.md b/examples/speech_to_text/docs/librispeech_example.md index 21b754ee11..4749e6cecc 100644 --- a/examples/speech_to_text/docs/librispeech_example.md +++ b/examples/speech_to_text/docs/librispeech_example.md @@ -50,6 +50,14 @@ for SUBSET in dev-clean dev-other test-clean test-other; do done ``` +## Interactive Decoding +Launch the interactive console via +```bash +fairseq-interactive ${LS_ROOT} --config-yaml config.yaml --task speech_to_text \ + --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 +``` +Type in WAV/FLAC/OGG audio paths (one per line) after the prompt. + ## Results | --arch | Params | dev-clean | dev-other | test-clean | test-other | Model | diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index eb5e6a7694..3fe3ac995c 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -532,6 +532,16 @@ def build_bpe(self, args): """Build the tokenizer for this task.""" return encoders.build_bpe(args) + def get_interactive_tokens_and_lengths(self, lines, encode_fn): + tokens = [ + self.source_dictionary.encode_line( + encode_fn(src_str), add_if_not_exist=False + ).long() + for src_str in lines + ] + lengths = [t.numel() for t in tokens] + return tokens, lengths + class LegacyFairseqTask(FairseqTask): def __init__(self, args: Namespace): diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 8fb341b0c5..8bdf215643 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -12,6 +12,7 @@ S2TDataConfig, SpeechToTextDataset, SpeechToTextDatasetCreator, + get_features_or_waveform ) from fairseq.tasks import LegacyFairseqTask, register_task @@ -138,6 +139,11 @@ def build_bpe(self, args): logger.info(f"tokenizer: 
{self.data_cfg.bpe_tokenizer}") return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer)) - @classmethod - def build_dataset_for_inference(cls, audio_paths, n_frames): - return SpeechToTextDataset("interactive", False, {}, audio_paths, n_frames) + def get_interactive_tokens_and_lengths(self, lines, encode_fn): + n_frames = [get_features_or_waveform(p).shape[0] for p in lines] + return lines, n_frames + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + return SpeechToTextDataset( + "interactive", False, self.data_cfg, src_tokens, src_lengths + ) diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 4785855985..cadef2821a 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -20,7 +20,6 @@ import numpy as np import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils -from fairseq.data import encoders from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints @@ -76,19 +75,13 @@ def encode_fn_target(x): for constraint in constraint_list ] - tokens = [ - task.source_dictionary.encode_line( - encode_fn(src_str), add_if_not_exist=False - ).long() - for src_str in lines - ] - if cfg.generation.constraints: constraints_tensor = pack_constraints(batch_constraints) else: constraints_tensor = None - lengths = [t.numel() for t in tokens] + tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn) + itr = task.get_batch_iterator( dataset=task.build_dataset_for_inference( tokens, lengths, constraints=constraints_tensor @@ -176,8 +169,8 @@ def main(cfg: FairseqConfig): generator = task.build_generator(models, cfg.generation) # Handle tokenization and BPE - tokenizer = encoders.build_tokenizer(cfg.tokenizer) - bpe = encoders.build_bpe(cfg.bpe) + tokenizer = task.build_tokenizer(cfg.tokenizer) + bpe = 
task.build_bpe(cfg.bpe) def encode_fn(x): if tokenizer is not None: @@ -256,6 +249,7 @@ def decode_fn(x): # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): + src_str = '' if src_dict is not None: src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) print("S-{}\t{}".format(id_, src_str)) From 0f93bd1a7d451944b77804aaf25e40696510411b Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Fri, 5 Feb 2021 15:42:02 -0800 Subject: [PATCH 441/707] Implement Mixup as a batch transform in PySpeech Summary: This diff implements the Mixup data augmentation in PySpeech. It is implemented as `MixupBatchTransform`, which acts on batches of data instead of single instances. Such a batch transform should be called by the collater, after it collates samples into a batch and before it converts the batch from numpy arrays to PyTorch tensors. See `TALNetTask` for an example of how to use it. Sometimes we may want to apply SpecAugment after Mixup, when data instances have already been collated into batches. The class `Batchify` is a wrapper that turns an instance transform into a batch transform, by applying the instance transform to every instance in a batch. 
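A minimal sketch of the wrapper idea described above, assuming an instance transform is any callable on a single example and a batch is simply an iterable of examples. PySpeech's actual `Batchify` is not shown in this patch, so the class body below and the `scale_by_two` helper are hypothetical and only illustrate the "apply to every instance" behavior; a real implementation would presumably operate on numpy arrays.

```python
class Batchify:
    """Hypothetical wrapper: turn a per-instance transform into a batch transform."""

    def __init__(self, instance_transform):
        # Any callable that maps one instance to one transformed instance.
        self.instance_transform = instance_transform

    def __call__(self, batch):
        # Apply the wrapped transform to each instance and keep the batch order.
        return [self.instance_transform(instance) for instance in batch]


def scale_by_two(instance):
    """Toy instance transform standing in for something like SpecAugment."""
    return [2 * value for value in instance]


batch_transform = Batchify(scale_by_two)
result = batch_transform([[1, 2], [3, 4]])
print(result)
```

With this shape, a collater could chain a batch-level transform such as Mixup with `Batchify(...)`-wrapped instance transforms in whichever order it needs.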
Reviewed By: nayansinghal Differential Revision: D26228942 fbshipit-source-id: b5784d8acab840d6ae6fa636d01cd7c68955d606 --- fairseq/criterions/wav2vec_criterion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index cc454b9309..859177f2b6 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -148,7 +148,7 @@ def reduce_metrics(logging_outputs) -> None: ) metrics.log_scalar( - "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + "loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3 ) metrics.log_scalar("ntokens", ntokens) metrics.log_scalar("nsentences", nsentences) @@ -183,7 +183,7 @@ def reduce_metrics(logging_outputs) -> None: val = sum(log.get(k, 0) for log in logging_outputs) if k.startswith("loss"): metrics.log_scalar( - k, val / sample_size / math.log(2), sample_size, round=3 + k, val / (sample_size or 1) / math.log(2), sample_size, round=3 ) else: metrics.log_scalar(k, val / len(logging_outputs), round=3) From 5a170841f2faba7413a2d59c792bee6a3ff38838 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Sat, 6 Feb 2021 08:05:41 -0800 Subject: [PATCH 442/707] Make checkpoint wrapper pickleable (#1603) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1603 Test Plan: Imported from OSS Reviewed By: sshleifer Differential Revision: D26237760 Pulled By: myleott fbshipit-source-id: 73c67bdea4b5b16e3159a5d4f0151e514e853357 --- fairseq/modules/checkpoint_activations.py | 44 ++++++++++++----------- tests/test_activation_checkpointing.py | 11 ++++-- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index c84e70bf7b..ae07dcfaa0 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -3,6 +3,7 @@ # 
This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools from typing import Any, Dict, List, Tuple, Union import torch @@ -25,29 +26,32 @@ def checkpoint_wrapper(m, offload_to_cpu=False): checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True) a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) """ - original_forward = m.forward - - def _checkpointed_forward(*args, **kwargs): - # Autograd Functions in PyTorch work best with positional args, since - # the backward must return gradients (or None) for every input argument. - # We can flatten keyword arguments to make this easier. - kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) - parent_ctx_dict = {"offload": offload_to_cpu} - output = CheckpointFunction.apply( - original_forward, parent_ctx_dict, kwarg_keys, *flat_args - ) - if isinstance(output, torch.Tensor): - return output - else: - packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] - if packed_non_tensor_outputs: - output = unpack_non_tensors(output, packed_non_tensor_outputs) - return output - - m.forward = _checkpointed_forward + m.forward = functools.partial( + _checkpointed_forward, + m.forward, # original_forward + offload_to_cpu, + ) return m +def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs): + # Autograd Functions in PyTorch work best with positional args, since + # the backward must return gradients (or None) for every input argument. + # We can flatten keyword arguments to make this easier. 
+ kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + parent_ctx_dict = {"offload": offload_to_cpu} + output = CheckpointFunction.apply( + original_forward, parent_ctx_dict, kwarg_keys, *flat_args + ) + if isinstance(output, torch.Tensor): + return output + else: + packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] + if packed_non_tensor_outputs: + output = unpack_non_tensors(output, packed_non_tensor_outputs) + return output + + def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: """ Usage:: diff --git a/tests/test_activation_checkpointing.py b/tests/test_activation_checkpointing.py index 4b86211bde..647a957288 100644 --- a/tests/test_activation_checkpointing.py +++ b/tests/test_activation_checkpointing.py @@ -12,7 +12,9 @@ class Model(nn.Module): - def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False): + def __init__( + self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs + ): super().__init__() torch.manual_seed(0) self.use_pytorch_checkpoint = use_pytorch_checkpoint @@ -23,7 +25,7 @@ def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False): nn.Linear(128, 32), ) if use_fairseq_checkpoint: - self.ffn = checkpoint_wrapper(self.ffn) + self.ffn = checkpoint_wrapper(self.ffn, **kwargs) self.out = nn.Linear(32, 1) def forward(self, x): @@ -60,6 +62,11 @@ def get_loss_and_gnorm(model): torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"]) torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"]) + model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device) + fairseq_cpt_offload = get_loss_and_gnorm(model) + torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"]) + torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"]) + def test_checkpoint_wrapper_cpu(self): self._test_checkpoint_wrapper(device=torch.device("cpu")) From 7aa999f2a8084428a9675ee9d9b782bb797fd6ce Mon Sep 17 00:00:00 2001 
From: Myle Ott <myleott@fb.com> Date: Sat, 6 Feb 2021 08:05:41 -0800 Subject: [PATCH 443/707] Add --optimizer=cpu_adam (#1604) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1604 Test Plan: Imported from OSS Reviewed By: shruti-bh Differential Revision: D26237761 Pulled By: myleott fbshipit-source-id: 2deb78a93ca23c261c38370ac810c317e4ec20ee --- fairseq/optim/cpu_adam.py | 187 +++++++++++++++++++++++++++++ fairseq/optim/fairseq_optimizer.py | 2 + 2 files changed, 189 insertions(+) create mode 100644 fairseq/optim/cpu_adam.py diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py new file mode 100644 index 0000000000..fad5a64ecb --- /dev/null +++ b/fairseq/optim/cpu_adam.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer +from omegaconf import II, DictConfig + + +try: + from deepspeed.ops.op_builder import CPUAdamBuilder + has_deepspeed_cpu_adam = True +except ImportError: + has_deepspeed_cpu_adam = False + + +@dataclass +class FairseqCPUAdamConfig(FairseqDataclass): + adam_betas: str = field( + default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"} + ) + adam_eps: float = field( + default=1e-8, metadata={"help": "epsilon for Adam optimizer"} + ) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + fp16_adam_stats: bool = field( + default=False, metadata={"help": "use FP16 stats (with automatic scaling)"} + ) + # TODO common vars below in parent + lr: List[float] = II("optimization.lr") + + +@register_optimizer("cpu_adam", dataclass=FairseqCPUAdamConfig) +class 
FairseqCPUAdam(FairseqOptimizer): + """Adam optimizer for fairseq, optimized for CPU tensors. + + Important note: this optimizer corresponds to the "AdamW" variant of + Adam in its weight decay behavior. As such, it is most closely + analogous to torch.optim.AdamW from PyTorch. + """ + + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) + self._optimizer = CPUAdam(params, **self.optimizer_config) + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, + "use_fp16_stats": self.cfg.fp16_adam_stats, + } + + +class CPUAdam(torch.optim.Optimizer): + + optimizer_id = 0 + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + use_fp16_stats=False, + ): + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + } + super().__init__(params, defaults) + + self.use_fp16_stats = use_fp16_stats + self.FLOAT16_MAX = 65504.0 + + if not has_deepspeed_cpu_adam: + raise ImportError("Please install DeepSpeed: pip install deepspeed") + + self.opt_id = CPUAdam.optimizer_id + CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1 + + self.ds_opt_adam = CPUAdamBuilder().load() + adamw_mode = True + self.ds_opt_adam.create_adam( + self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group_id, group in enumerate(self.param_groups): + for param_id, 
p in enumerate(group["params"]): + if p.grad is None: + continue + + state = self.state[p] + if len(state) == 0: + state["step"] = 0 + dtype = torch.float16 if self.use_fp16_stats else p.data.dtype + # gradient momentums + state["exp_avg"] = torch.zeros_like( + p.data, dtype=dtype, device="cpu" + ) + # gradient variances + state["exp_avg_sq"] = torch.zeros_like( + p.data, dtype=dtype, device="cpu" + ) + if self.use_fp16_stats: + assert torch.is_floating_point(p.data) + state["exp_avg_scale"] = 1.0 + state["exp_avg_sq_scale"] = 1.0 + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + p_data_bak = p.data # backup of the original data pointer + + p.data = p.data.to(dtype=torch.float32, device="cpu") + p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu") + + if self.use_fp16_stats: + exp_avg = exp_avg.float() * state["exp_avg_scale"] + exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"] + + state["step"] += 1 + beta1, beta2 = group["betas"] + + self.ds_opt_adam.adam_update( + self.opt_id, + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + exp_avg, + exp_avg_sq, + ) + + if p_data_bak.data_ptr() != p.data.data_ptr(): + p_data_bak.copy_(p.data) + p.data = p_data_bak + + if self.use_fp16_stats: + + def inf_norm(t): + return torch.norm(t, float("inf")) + + # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py + state["exp_avg_scale"], state["exp_avg_sq_scale"] = ( + 1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX, + 1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX, + ) + state["exp_avg"], state["exp_avg_sq"] = ( + (exp_avg / state["exp_avg_scale"]).half(), + (exp_avg_sq / state["exp_avg_sq_scale"]).half(), + ) + + return loss diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index a1c1d219a0..7e5411753a 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -103,6 
+103,8 @@ def multiply_grads(self, c): """Multiplies grads by a constant *c*.""" for p in self.params: if p.grad is not None: + if torch.is_tensor(c): + c = c.to(p.grad.device) p.grad.data.mul_(c) def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): From 5605d6fbb44d7ee8240048b6a172385ce3c07c6f Mon Sep 17 00:00:00 2001 From: Arthur Guo <arthurguo@fb.com> Date: Mon, 8 Feb 2021 12:13:39 -0800 Subject: [PATCH 444/707] Refactor Python LAS Rescoring Inference Summary: This diff cleans up the code in D26195511 Differential Revision: D25857284 fbshipit-source-id: 6b445c3d263078bf711a429130d9f983421e6a10 --- fairseq/models/fairseq_decoder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py index 35a349fa5f..7eeb5c652f 100644 --- a/fairseq/models/fairseq_decoder.py +++ b/fairseq/models/fairseq_decoder.py @@ -17,6 +17,8 @@ def __init__(self, dictionary): super().__init__() self.dictionary = dictionary self.onnx_trace = False + self.adaptive_softmax = None + def forward(self, prev_output_tokens, encoder_out=None, **kwargs): """ From 6381aa2bb24f125d271e241c726a2fea581bc3c4 Mon Sep 17 00:00:00 2001 From: Ruslan Mavlyutov <mavlyutov@fb.com> Date: Mon, 8 Feb 2021 14:13:34 -0800 Subject: [PATCH 445/707] Adding FBSequenceGenerator Reviewed By: mikekgfb Differential Revision: D26228721 fbshipit-source-id: b7a83bbc719d50d677d9b4c8a74f3c20def85357 --- fairseq/tasks/fairseq_task.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 3fe3ac995c..34264bdc01 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -313,6 +313,10 @@ def build_generator( SequenceGenerator, SequenceGeneratorWithAlignment, ) + try: + from fairseq.fb_sequence_generator import FBSequenceGenerator + except ModuleNotFoundError: + pass # Choose search strategy. Defaults to Beam Search. 
sampling = getattr(args, "sampling", False) @@ -379,6 +383,8 @@ def build_generator( if getattr(args, "print_alignment", False): seq_gen_cls = SequenceGeneratorWithAlignment extra_gen_cls_kwargs["print_alignment"] = args.print_alignment + elif getattr(args, "fb_seq_gen", False): + seq_gen_cls = FBSequenceGenerator else: seq_gen_cls = SequenceGenerator From 3aeb8fe1007f098f629bb20cc31c339dcbf5ad57 Mon Sep 17 00:00:00 2001 From: Ning Dong <dnn@fb.com> Date: Mon, 8 Feb 2021 16:18:00 -0800 Subject: [PATCH 446/707] Explicitly annotate attn_scores as Optional[Tensor] Summary: The behavior of ternary if's type inference changed a bit with D26278969. Need to annotate explicitly for it to work. Otherwise tests fail as in T84436394. Reviewed By: nikithamalgifb Differential Revision: D26320452 fbshipit-source-id: ad61f8ba5ea730150350cb839f5d9c73d476aed4 --- fairseq/models/lstm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 1a9dca3c75..12e3aff85d 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -535,7 +535,7 @@ def extract_features( assert ( srclen > 0 or self.attention is None ), "attention is not supported if there are no encoder outputs" - attn_scores = ( + attn_scores: Optional[Tensor] = ( x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None ) outs = [] From 4fed0beca64a52aa718371dc3b2cf1fd979197a4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen <patrick.v.platen@gmail.com> Date: Wed, 10 Feb 2021 14:03:24 -0800 Subject: [PATCH 447/707] Fix padding mask for new architectures (#3228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? 
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/pytorch/fairseq/issues/3227

All models that do **not** make use of group norm, such as
- Wav2Vec 2.0 Large (LV-60)*
- Wav2Vec 2.0 Large (LV-60) + Self Training *

do need this fix IMO to be able to correctly run batches through the model.

Before this PR, the following code snippet failed:

```python
import fairseq
import torch

# get model
wav2vec_path = "data/wav2vec2_vox_960h_new.pt"
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [wav2vec_path],
    arg_overrides={"data": "./data"}
)
model = model[0]
model.eval()

# create single input
input_wav_0 = torch.randn((1, 2000))
input_wav_1 = torch.randn((1, 3000))

# create batched input
batch_input_wav = torch.zeros((2, 3000))
batch_input_wav[0, :input_wav_0.shape[-1]] = input_wav_0
batch_input_wav[1, :input_wav_1.shape[-1]] = input_wav_1

# create padding mask
padding_mask = torch.zeros((2, 3000), dtype=torch.bool)
padding_mask[0, input_wav_0.shape[-1]:] = True

# run batch & single
output = model(source=input_wav_0, padding_mask=None)["encoder_out"]
batch_output = model(source=batch_input_wav, padding_mask=padding_mask)["encoder_out"]

# is equal?
print("Is batched forward and simple forward equal?", torch.allclose(output[:,0], batch_output[:output.shape[0], 0], atol=1e-3))
```
Note: It is assumed that both https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt and https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt were downloaded and stored in the folder data.

Also, see [this](https://colab.research.google.com/drive/1ASZ4lVZbKkj-dvRHDl1lo0mCcsaOERlG?usp=sharing) notebook for reproducibility.

This PR should fix the behavior and make the above code snippet / notebook run successfully.

## PR review
Gently pinging alexeib for Wav2Vec2
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3228 Reviewed By: aconneau Differential Revision: D26373721 Pulled By: alexeib fbshipit-source-id: 3d5aca2f8136d1a8c4b5b4bc9c03cd05a69a3b52 --- fairseq/models/wav2vec/wav2vec2.py | 32 +++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 783ebcfe6b..644add7b17 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -443,6 +443,21 @@ def compute_preds(self, x, y, negatives): return logits + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + conv_cfg_list = eval(self.cfg.conv_feature_layers) + + for i in range(len(conv_cfg_list)): + input_lengths = _conv_out_length(input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]) + + return input_lengths.to(torch.long) + def forward(self, source, padding_mask=None, mask=True, features_only=False): if self.feature_grad_mult > 0: @@ -460,11 +475,18 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): unmasked_features = features.clone() if padding_mask is not None: - extra = padding_mask.size(1) % features.size(1) - if extra > 0: - padding_mask = padding_mask[:, :-extra] - padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) - padding_mask = padding_mask.all(-1) + input_lengths = (1 - padding_mask.long()).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = torch.zeros( + features.shape[:2],
dtype=features.dtype, device=features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[(torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() if self.post_extract_proj is not None: features = self.post_extract_proj(features) From ac90cb3085439d15af1a33cf0a0b1a6703f07413 Mon Sep 17 00:00:00 2001 From: Ruslan Mavlyutov <mavlyutov@fb.com> Date: Wed, 10 Feb 2021 14:57:17 -0800 Subject: [PATCH 448/707] Extra logging to confirm OOM source Reviewed By: myleott, chtran Differential Revision: D26348808 fbshipit-source-id: 010ef00024e02c09ec35b624f0713ce5f1f387b4 --- fairseq/trainer.py | 1 + fairseq_cli/train.py | 1 + 2 files changed, 2 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 49129a7fb0..24f72e2f9a 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -277,6 +277,7 @@ def consolidate_optimizer(self): def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" if self.is_data_parallel_master: # only save one checkpoint + logger.info(f"Saving checkpoint to {filename}") extra_state["metrics"] = metrics.state_dict() extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 9af7568a77..ec4890b9e6 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -236,6 +236,7 @@ def train( valid_subsets = cfg.dataset.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() + logger.info("Start iterating over samples") for i, samples in enumerate(progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i From 7061a0ff83872ac491ba5963eb7fc04cb10d57c4 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Wed, 10 Feb 2021 
16:25:25 -0800 Subject: [PATCH 449/707] better error handling for expired handles Summary: At the start of the half there were some expired handles and it was annoying to track down which datasets were responsible when sampling data among multiple datasets and which flows were running them. Let's improve the error message to address several pain points: 1. Explicitly tell the user which dataset has expired handles 2. Link to a scuba query to enable the user to find all flows that have expired handles 3. Fail job if 10k handles have expired, rather than if 10k handles in a row have expired. This can detect failures from datasets that have, for example, 50% expired handles 4. add logging when handles fail Reviewed By: cruvadom Differential Revision: D26187820 fbshipit-source-id: 771a359ea01de80b38932921346e98cff812f2f7 --- fairseq/data/multi_corpus_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 9c7f1cb976..7207174bf3 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -126,7 +126,11 @@ def __len__(self): def __getitem__(self, index): index, key = self._map_index(index) - return self.datasets[key][index] + try: + return self.datasets[key][index] + except Exception as e: + e.args = (f"Error from {key} dataset", *e.args) + raise def collater(self, samples): """ From ee48d1b95835a0e5fa2129219d205f8d9e748b76 Mon Sep 17 00:00:00 2001 From: pritam <pritam.damania@fb.com> Date: Thu, 11 Feb 2021 09:41:54 -0800 Subject: [PATCH 450/707] Use torch pipe if available in fairseq. (#3149) Summary: fairscale.nn.Pipe has been ported to PyTorch: https://github.com/pytorch/pytorch/blob/master/torch/distributed/pipeline/sync/pipe.py#L138. As a result, modifying the pipeline transformer to use PyTorch pipe if available. This change depends on https://github.com/pytorch/pytorch/pull/50860.
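
Illustration only, not part of this patch: a minimal sketch of the "use one backend if available, otherwise fall back" idea described in this summary. The helper name `first_importable` is hypothetical. In the patch itself the preference order is `torch.distributed.pipeline.sync` first and `fairscale.nn` second; the demo below uses standard-library names so that it runs without either package installed.

```python
import importlib
from types import ModuleType
from typing import Optional, Sequence, Tuple


def first_importable(
    candidates: Sequence[str],
) -> Tuple[Optional[str], Optional[ModuleType]]:
    """Return (name, module) for the first candidate that imports.

    Hypothetical helper: returns (None, None) when no candidate can be
    imported, so the caller decides how to report the missing dependency.
    """
    for name in candidates:
        try:
            return name, importlib.import_module(name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError,
            # so a missing package simply moves on to the next candidate.
            continue
    return None, None


# Demo with stdlib names: the first candidate is missing, the second exists.
chosen_name, chosen_module = first_importable(
    ["module_that_does_not_exist_xyz", "json"]
)
print(chosen_name)
```

Returning the chosen name alongside the module lets the caller branch on which backend was picked, in the same spirit as the `TORCH_PIPE` flag in the diff.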
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3149 Test Plan: ``` python train.py ru_en_bin/ --arch transformer_iwslt_de_en_pipeline_parallel --share-decoder-input-output-embed --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --dropout 0.3 --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --max-tokens 4096 --eval-bleu --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' --eval-bleu-detok moses --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --maximize-best-checkpoint-metric --pipeline-model-parallel --pipeline-balance '[1,3,5,3,3,1]' --pipeline-devices '[0,1,0,2,3,0]' --pipeline-chunks 16 --distributed-world-size 1 --distributed-no-spawn --disable-validation --max-epoch 1 ``` Output with torch pipe: ``` 2021-01-20 16:13:35 | INFO | train | epoch 001 | loss 12.676 | nll_loss 12.331 | ppl 5151.97 | wps 5108 | ups 1.66 | wpb 3081.6 | bsz 131.6 | num_updates 380 | lr 4.75e-05 | gnorm 2.08 | train_wall 229 | wall 233 2021-01-20 16:13:36 | INFO | fairseq_cli.train | done training in 233.1 seconds ``` Output with fairscale pipe: ``` 2021-01-20 14:13:59 | INFO | train | epoch 001 | loss 12.677 | nll_loss 12.331 | ppl 5152.07 | wps 5198.9 | ups 1.69 | wpb 3081.6 | bsz 131.6 | num_updates 380 | lr 4.75e-05 | gnorm 2.08 | train_wall 224 | wall 228 2021-01-20 14:13:59 | INFO | fairseq_cli.train | done training in 228.0 seconds ``` Reviewed By: myleott Differential Revision: D26204633 Pulled By: shruti-bh fbshipit-source-id: 535f816e8d149b47fc6ba8385981accf67257257 --- .../pipeline_parallel_transformer/model.py | 128 +++++++++++++----- 1 file changed, 92 insertions(+), 36 deletions(-) diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index 7873611214..7f30dd98bb 100644 --- 
a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -39,15 +39,47 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 +TORCH_PIPE = False +RPC_INIT = False + +def import_pipe(): + global TORCH_PIPE + global RPC_INIT + try: + from torch.distributed.pipeline.sync import Pipe # noqa + global Pipe + from torch.distributed.pipeline.sync.utils import partition_model + global partition_model + from torch.distributed import rpc + import tempfile + TORCH_PIPE = True + # Initialize single process RPC agent since TORCH_PIPE requires + # RRef. RRef depends on RPC being initialized and as a result we initialize + # RPC with a single node. + tmpfile = tempfile.NamedTemporaryFile() + if not RPC_INIT: + rpc.init_rpc( + name="worker", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(tmpfile.name), + ) + ) + RPC_INIT = True + logger.info('Using torch pipe') + except ImportError: + try: + from fairscale.nn import Pipe # noqa + logger.info('Using fairscale pipe') + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") @register_model("pipeline_parallel_transformer") class PipelineParallelTransformerModel(BaseFairseqModel): def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() super().__init__() assert isinstance(encoder, FairseqEncoder) assert isinstance(decoder, FairseqDecoder) @@ -65,13 +97,20 @@ def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): self.num_decoder_modules = len(decoder_module_list) module_list = encoder_module_list + decoder_module_list self.devices = devices - self.model = Pipe( - nn.Sequential(*module_list), - balance=balance, - devices=devices, 
- chunks=chunks, - checkpoint=checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + partition_model(nn.Sequential(*module_list), balance, devices), + chunks=chunks, + checkpoint=checkpoint, + ) + else: + self.model = Pipe( + nn.Sequential(*module_list), + balance=balance, + devices=devices, + chunks=chunks, + checkpoint=checkpoint, + ) self.encoder_max_positions = self.max_positions_helper( encoder.embedding_layer, "max_source_positions" ) @@ -87,7 +126,10 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): if self.training: input_lst = [src_tokens, src_lengths, prev_output_tokens] input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) - return self.model(input) + if TORCH_PIPE: + return self.model(input).local_value() + else: + return self.model(input) else: assert self.encoder is not None and self.decoder is not None, ( "encoder and decoder need to be initialized by " @@ -425,10 +467,7 @@ class TransformerEncoder(FairseqEncoder): def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() self.use_pipeline = encoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) @@ -449,13 +488,20 @@ def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): f"Sum of encoder_balance={encoder_balance} is not equal " + f"to num_encoder_modules={len(encoder_module_list)}" ) - self.model = Pipe( - module=nn.Sequential(*encoder_module_list), - balance=encoder_balance, - devices=encoder_devices, - chunks=args.pipeline_chunks, - checkpoint=args.pipeline_checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model(nn.Sequential(*encoder_module_list), encoder_balance, encoder_devices), + 
chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*encoder_module_list), + balance=encoder_balance, + devices=encoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) def forward(self, src_tokens, src_lengths): """ @@ -485,7 +531,10 @@ def forward(self, src_tokens, src_lengths): input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) - encoder_out = self.model(input_tuple) + if TORCH_PIPE: + encoder_out = self.model(input_tuple).local_value() + else: + encoder_out = self.model(input_tuple) else: encoder_embed_output_tuple = self.embedding_layer(input_tuple) encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) @@ -561,10 +610,7 @@ def __init__( ): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() self.use_pipeline = decoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) @@ -586,13 +632,20 @@ def __init__( f"Sum of decoder_balance={decoder_balance} is not equal " + f"to num_decoder_modules={len(decoder_module_list)}" ) - self.model = Pipe( - module=nn.Sequential(*decoder_module_list), - balance=decoder_balance, - devices=decoder_devices, - chunks=args.pipeline_chunks, - checkpoint=args.pipeline_checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model(nn.Sequential(*decoder_module_list), decoder_balance, decoder_devices), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*decoder_module_list), + balance=decoder_balance, + devices=decoder_devices, + chunks=args.pipeline_chunks, + 
checkpoint=args.pipeline_checkpoint, + ) def forward( self, @@ -622,7 +675,10 @@ def forward( ) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) - return (self.model(input_tuple),) + if TORCH_PIPE: + return (self.model(input_tuple).local_value(),) + else: + return (self.model(input_tuple),) else: embed_layer_output = self.embedding_layer(input_tuple) state = self.decoder_layers(embed_layer_output) From fd7c2a8b371c2abf645f558282221eba6833f35f Mon Sep 17 00:00:00 2001 From: Mary Williamson <marywilliamson@fb.com> Date: Thu, 11 Feb 2021 13:53:33 -0800 Subject: [PATCH 451/707] More informative exception when numpy version changes (#3231) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: More informative exception when numpy version changes to ask the user to recompile Cython files # Before submitting - [With myleott ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [N/A ] Did you make sure to update the docs? - [N/A ] Did you write any new necessary tests? ## What does this PR do? Raises a more informative error to tell the user to recompile Cython files after an update to the numpy version. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3231 Reviewed By: myleott Differential Revision: D26375174 Pulled By: mwillwork fbshipit-source-id: f0a93e162bc4cf84619581110d21bea907baf7fc --- fairseq/data/data_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 1a83063542..47d8492ec9 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -307,6 +307,11 @@ def batch_by_size( "Please build Cython components with: `pip install --editable .` " "or `python setup.py build_ext --inplace`" ) + except ValueError: + raise ValueError( + "Please build (or rebuild) Cython components with: `pip install " + " --editable .` or `python setup.py build_ext --inplace`." + ) max_tokens = max_tokens if max_tokens is not None else -1 max_sentences = max_sentences if max_sentences is not None else -1 From 66e1803c60272602c719a5ba75acef1c530066ef Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 11 Feb 2021 13:59:08 -0800 Subject: [PATCH 452/707] save task state in the checkpoint (#1562) Summary: this allows tasks to declare some properties they'd like to save in the checkpoint (such as a dictionary), which are loaded when checkpoint is restored.
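
Illustration only, not the code added by this patch: a minimal sketch of how a task could declare lazily built properties that travel with the checkpoint. The class name `LazyTaskState` and the `get` method are hypothetical; `add_factory`, `state_dict` and `load_state_dict` are borrowed from the calls visible in the diff, but their behaviour here is an assumption.

```python
from typing import Any, Callable, Dict


class LazyTaskState:
    """Hypothetical stand-in for a task's state container."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable[[], Any]] = {}
        self._values: Dict[str, Any] = {}

    def add_factory(self, name: str, factory: Callable[[], Any]) -> None:
        # Register how to build a property; nothing is built yet.
        self._factories[name] = factory

    def get(self, name: str) -> Any:
        # Build on first access only if no value was restored or built before.
        # Raises KeyError when the name has neither a value nor a factory.
        if name not in self._values:
            self._values[name] = self._factories[name]()
        return self._values[name]

    def state_dict(self) -> Dict[str, Any]:
        # Only properties that have been built or restored are saved.
        return dict(self._values)

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Restored values take precedence over factories.
        self._values.update(state)


factory_calls = []


def load_dictionary():
    # Stands in for reading a dictionary file from disk.
    factory_calls.append("load")
    return ["<s>", "a", "b"]


# Fine-tuning: the dictionary is built once and saved with the checkpoint.
train_state = LazyTaskState()
train_state.add_factory("target_dictionary", load_dictionary)
train_state.get("target_dictionary")
checkpoint = {"task_state": train_state.state_dict()}

# Decoding: the dictionary comes from the checkpoint, the factory is not run.
decode_state = LazyTaskState()
decode_state.add_factory("target_dictionary", load_dictionary)
decode_state.load_state_dict(checkpoint["task_state"])
restored = decode_state.get("target_dictionary")
```

In this sketch the factory runs once, during the first stage, which matches the behaviour the test plan below checks for (the dictionary is loaded only during fine-tuning and is taken from the checkpoint when decoding).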
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1562 Test Plan: tested by training a new wav2vec model, then finetuning it, then decoding it and making sure the dict only loaded once, during fine tuning process (and was obtained from checkpoint for decoding) Reviewed By: myleott, gwenzek Differential Revision: D25937974 Pulled By: alexeib fbshipit-source-id: b9908042f76ec8cda943f33885eb9b1f121662ae --- examples/speech_recognition/hydra/infer.py | 4 +- examples/speech_recognition/infer.py | 13 +++--- fairseq/checkpoint_utils.py | 5 +++ fairseq/distributed/utils.py | 3 ++ fairseq/tasks/audio_pretraining.py | 28 ++++++------ fairseq/tasks/fairseq_task.py | 50 ++++++++++++++++++++-- fairseq/trainer.py | 5 ++- 7 files changed, 77 insertions(+), 31 deletions(-) diff --git a/examples/speech_recognition/hydra/infer.py b/examples/speech_recognition/hydra/infer.py index 6afa066f25..b1c985bc0d 100644 --- a/examples/speech_recognition/hydra/infer.py +++ b/examples/speech_recognition/hydra/infer.py @@ -10,10 +10,9 @@ import os import shutil import sys -from argparse import Namespace from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import editdistance import torch @@ -26,7 +25,6 @@ CommonEvalConfig, DatasetConfig, DistributedTrainingConfig, FairseqDataclass, GenerationConfig) -from fairseq.dataclass.initialize import hydra_init from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.logging.progress_bar import BaseProgressBar from fairseq.models.fairseq_model import FairseqModel diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 5a582c54af..f4efbf39c8 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -144,11 +144,11 @@ def process_predictions( print( "{} ({}-{})".format(tgt_words, speaker, id), 
file=res_files["ref.words"] ) - # only score top hypothesis - if not args.quiet: - logger.debug("HYPO:" + hyp_words) - logger.debug("TARGET:" + tgt_words) - logger.debug("___________________") + + if not args.quiet: + logger.info("HYPO:" + hyp_words) + logger.info("TARGET:" + tgt_words) + logger.info("___________________") hyp_words = hyp_words.split() tgt_words = tgt_words.split() @@ -216,7 +216,6 @@ def main(args, task=None, model_state=None): use_cuda = torch.cuda.is_available() and not args.cpu - logger.info("| decoding with criterion {}".format(args.criterion)) task = tasks.setup_task(args) @@ -227,7 +226,7 @@ def main(args, task=None, model_state=None): task.load_dataset(args.gen_subset) else: logger.info("| loading model(s) from {}".format(args.path)) - models, saved_cfg = checkpoint_utils.load_model_ensemble( + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( utils.split_paths(args.path), arg_overrides=ast.literal_eval(args.model_overrides), task=task, diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 2f209b6b39..55a546356e 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -349,6 +349,9 @@ def load_model_ensemble_and_task( if task is None: task = tasks.setup_task(cfg.task) + if "task_state" in state: + task.load_state_dict(state["task_state"]) + # build model for ensemble model = task.build_model(cfg.model) @@ -403,6 +406,7 @@ def save_state( num_updates, optim_history=None, extra_state=None, + task=None, **kwargs, ): from fairseq import utils @@ -425,6 +429,7 @@ def save_state( } ], "extra_state": extra_state, + "task_state": task.state_dict() if task is not None else {} } if utils.has_parameters(criterion): state_dict["criterion"] = criterion.state_dict() diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index e3d8e1e0d3..710ca18628 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -325,6 +325,9 @@ def distributed_main(i, 
main, cfg: FairseqConfig, kwargs): main(cfg, **kwargs) + if torch.distributed.is_initialized(): + torch.distributed.barrier(get_global_group()) + def call_main(cfg: FairseqConfig, main, **kwargs): if cfg.distributed_training.distributed_init_method is None: diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 7c82777331..92685160d4 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -15,7 +15,6 @@ from omegaconf import MISSING from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders -from fairseq.data.data_utils import post_process from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import GenerationConfig @@ -98,16 +97,14 @@ class AudioPretrainingTask(FairseqTask): def __init__( self, cfg: AudioPretrainingConfig, - source_dictionary=None, - target_dictionary=None, ): super().__init__(cfg) - self._target_dictionary = target_dictionary - self._source_dictionary = source_dictionary if cfg.eval_wer: assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" self.blank_symbol = "<s>" + self.state.add_factory("target_dictionary", self.load_target_dictionary) + @classmethod def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): """Setup the task (e.g., load dictionaries). 
@@ -116,13 +113,13 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): cfg (AudioPretrainingConfig): configuration of this task """ - if cfg.labels: - dict_path = os.path.join(cfg.data, f"dict.{cfg.labels}.txt") - target_dictionary = Dictionary.load(dict_path) - else: - target_dictionary = None + return cls(cfg) - return cls(cfg, target_dictionary=target_dictionary) + def load_target_dictionary(self): + if self.cfg.labels: + dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + return Dictionary.load(dict_path) + return None def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): data_path = self.cfg.data @@ -136,7 +133,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): manifest = os.path.join(data_path, "{}.tsv".format(split)) self.datasets[split] = FileAudioDataset( manifest, - sample_rate=task_cfg.sample_rate, + sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate), max_sample_size=self.cfg.max_sample_size, min_sample_size=self.cfg.max_sample_size, min_length=self.cfg.min_sample_size, @@ -146,7 +143,6 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") - labels = [] with open(label_path, "r") as f: labels = [ line for i, line in enumerate(f) @@ -166,18 +162,18 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, - add_to_input=task_cfg.autoregressive, + add_to_input=task_cfg.get('autoregressive', False), ) @property def source_dictionary(self): - return self._source_dictionary + return None @property def target_dictionary(self): """Return the :class:`~fairseq.data.Dictionary` for the language model.""" - return self._target_dictionary + return self.state.target_dictionary def max_positions(self): """Maximum input length supported by the 
encoder.""" diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 34264bdc01..04025023fa 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -7,7 +7,7 @@ import os import warnings from argparse import Namespace -from typing import List +from typing import Any, Callable, Dict, List import torch from fairseq import metrics, search, tokenizer, utils @@ -20,10 +20,45 @@ logger = logging.getLogger(__name__) +class StatefulContainer(object): + + _state: Dict[str, Any] = dict() + _factories: Dict[str, Callable[[], Any]] = dict() + + def add_factory(self, name, factory: Callable[[], Any]): + self._factories[name] = factory + + def merge_state_dict(self, state_dict: Dict[str, Any]): + self._state.update(state_dict) + + @property + def state_dict(self) -> Dict[str, Any]: + return self._state + + def __getattr__(self, name): + if name not in self._state and name in self._factories: + self._state[name] = self._factories[name]() + + if name in self._state: + return self._state[name] + + raise AttributeError(f"Task state has no factory for attribute {name}") + + class FairseqTask(object): """ Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss. + + Tasks have limited statefulness. In particular, state that needs to be + saved to/loaded from checkpoints needs to be stored in the `self.state` + :class:`StatefulContainer` object. For example:: + + self.state.add_factory("dictionary", self.load_dictionary) + print(self.state.dictionary) # calls self.load_dictionary() + + This is necessary so that when loading checkpoints, we can properly + recreate the task state after initializing the task instance. 
""" @classmethod @@ -42,10 +77,13 @@ def logging_outputs_can_be_summed(criterion) -> bool: """ return criterion.logging_outputs_can_be_summed() + cfg: FairseqDataclass + datasets: Dict[str, FairseqDataset] = dict() + dataset_to_epoch_iter: Dict[FairseqDataset, Any] = dict() + state: StatefulContainer = StatefulContainer() + def __init__(self, cfg: FairseqDataclass, **kwargs): self.cfg = cfg - self.datasets = {} - self.dataset_to_epoch_iter = {} @classmethod def load_dictionary(cls, filename): @@ -514,6 +552,12 @@ def reduce_metrics(self, logging_outputs, criterion): criterion.__class__.reduce_metrics(logging_outputs) + def state_dict(self): + return self.state.state_dict + + def load_state_dict(self, state_dict: Dict[str, Any]): + self.state.merge_state_dict(state_dict) + def max_positions(self): """Return the max input length allowed by the task.""" return None diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 24f72e2f9a..e860fb1832 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -288,8 +288,9 @@ def save_checkpoint(self, filename, extra_state): self.optimizer, self.lr_scheduler, self.get_num_updates(), - self._optim_history, - extra_state, + optim_history=self._optim_history, + extra_state=extra_state, + task=self.task, ) logger.info(f"Finished saving checkpoint to {filename}") From 138265ce15d198e6baceae334effed8fb384a286 Mon Sep 17 00:00:00 2001 From: Kritika Singh <skritika@fb.com> Date: Thu, 11 Feb 2021 15:39:32 -0800 Subject: [PATCH 453/707] Make wav2vec_asr encoder compatible with pyspeech fst decoder Summary: - I don't think there is a convention for the shapes of `encoder_out` and `encoder_padding_mask` in fairseq but `fst_external_decoder.py` expects `encoder_padding_mask` to be of shape T x B. `encoder_padding_mask` also seems unused in the fairseq [CTC criterion and w2l decoder integration](https://fburl.com/diffusion/ms1zi2px) so taking the easy way out and changing its shape. 
- Also checking in some changes to the pyspeech audio_pretraining task required to make decoding work Reviewed By: alexeib Differential Revision: D26382442 fbshipit-source-id: 87c8f9433026c0e011847f4e2e094beb2cd2182c --- fairseq/models/wav2vec/wav2vec2_asr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index bbd2ab9ec5..9cd17b635c 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -158,7 +158,7 @@ def get_normalized_probs(self, net_output, log_probs): def get_logits(self, net_output): logits = net_output["encoder_out"] - padding = net_output["encoder_padding_mask"] + padding = net_output["padding_mask"] if padding is not None and padding.any(): padding = padding.T logits[padding][...,0] = 0 @@ -359,7 +359,7 @@ def forward(self, source, padding_mask, tbc=True, **kwargs): return { "encoder_out": x, # T x B x C - "encoder_padding_mask": padding_mask, # B x T + "encoder_padding_mask": padding_mask.transpose(0, 1), # T x B "padding_mask": padding_mask, } @@ -539,7 +539,7 @@ def extract_features( x, attn, _ = layer( x, encoder_out["encoder_out"] if encoder_out is not None else None, - encoder_out["encoder_padding_mask"] + encoder_out["padding_mask"] if encoder_out is not None else None, incremental_state, From 1d5b075e3f30fd3f28af4c8851e8659285ded230 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 11 Feb 2021 18:12:11 -0800 Subject: [PATCH 454/707] fix fairseqlm decoder with flashlight changes (#1617) Summary: fixes fairseqlm integration with flashlight (formerly wav2letter) decoder Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1617 Reviewed By: xuqiantong Differential Revision: D26415650 Pulled By: alexeib fbshipit-source-id: 813684ba55047e92378f508101ff1eec55754420 --- examples/speech_recognition/w2l_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 706d9f1433..8b158293a0 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -431,7 +431,7 @@ def __init__(self, args, tgt_dict): self.silence, self.blank, self.unk_word, - self.asg_transitions, + [], self.unit_lm, ) else: From 506a8e0f45c1206b1306276fed9cec92c7061dd0 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 11 Feb 2021 21:22:42 -0800 Subject: [PATCH 455/707] seq2seq autoregressive flag check (#1618) Summary: raise an exception if trying to use wav2vec seq2seq finetuning without autoregressive flag Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1618 Reviewed By: xuqiantong Differential Revision: D26417249 Pulled By: alexeib fbshipit-source-id: 777b6d170b0f8196746e03b399e4d7c21ac0b837 --- fairseq/models/wav2vec/wav2vec2_asr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 9cd17b635c..afa51299b6 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -220,6 +220,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): share_decoder_input_output_embed: bool = field( default=False, metadata={"help": "share decoder input and output embeddings"} ) + autoregressive: bool = II("task.autoregressive") @register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) @@ -231,6 +232,8 @@ def __init__(self, encoder, decoder): def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask): """Build a new model instance.""" + assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models" + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary def build_embedding(dictionary, embed_dim): From 7ffb40d9c8e33b272e85604892a0935d8e57bb0e Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Jo=C3=A3o=20Pedro=20Megid=20Carrilho?= <joaopedro@nindoo.ai> Date: Fri, 12 Feb 2021 00:26:03 -0800 Subject: [PATCH 456/707] Fix typo Wav2Vec2 README.md (#3240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3240 Reviewed By: aconneau Differential Revision: D26420073 Pulled By: alexeib fbshipit-source-id: 5939535b945a64e61d655cd36dc955ae46410bfb --- examples/wav2vec/README.md | 294 ------------------------------------- 1 file changed, 294 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 663adf97dc..e69de29bb2 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -1,294 +0,0 @@ -# wav2vec 2.0 - -wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). - -We learned speech representations in multiple languages as well in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). - -We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). 
- -## Pre-trained models - -Model | Finetuning split | Dataset | Model -|---|---|---|--- -Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) -Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) -Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) -Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) -Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) -Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) -Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) -Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) -Wav2Vec 2.0 Large (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) -Wav2Vec 2.0 Large (LV-60)* | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_new.pt) -Wav2Vec 2.0 Large (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_new.pt) -Wav2Vec 2.0 
Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt) -Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) -Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) -Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) - -\* updated (Oct. 24, 2020) - -We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: - -Model | Architecture | Hours | Languages | Datasets | Model -|---|---|---|---|---|--- -XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) - -The XLSR model uses the following datasets for multilingual pretraining: - -* **[MLS: Multilingual LibriSpeech](https://indico2.conference4me.psnc.pl/event/35/contributions/3585/attachments/1060/1101/Wed-2-6-10.pdf)** (8 languages, 50.7k hours): *Dutch, English, French, German, Italian, Polish, Portuguese, Spanish* - -* **[CommonVoice](https://commonvoice.mozilla.org/en/languages)** (36 languages, 3.6k hours): *Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakh-Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, 
Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh* (see also [finetuning splits]([https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz]) from [this paper](https://arxiv.org/abs/2002.02848)). - -* **[Babel](https://catalog.ldc.upenn.edu/byyear)** (17 languages, 1.7k hours): *Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok, Turkish, Vietnamese, Zulu* - - -## Training a new model with the CLI tools - -Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) - -### Prepare training data manifest: - -First, install the `soundfile` library: -```shell script -pip install soundfile -``` - -Next, run: - -```shell script -$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid -``` - -$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. - -$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. -To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a -separately pre-processed manifest file. 
- -### Train a wav2vec 2.0 base model: - -This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper - -Note that the input is expected to be single channel, sampled at 16 kHz - -```shell script -$ fairseq-hydra-train \ - task.data=/path/to/data \ - --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ - --config-name wav2vec2_base_librispeech -``` - -Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k - -### Train a wav2vec 2.0 large model: - -This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper - -```shell script -$ fairseq-hydra-train \ - task.data=/path/to/data \ - --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ - --config-name wav2vec2_large_librivox -``` - -Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k - -### Fine-tune a pre-trained model with CTC: - -Fine-tuning a model requires parallel audio and labels file, as well as a vocabulary file in fairseq format. -A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). 
-An example [script](libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: - -```shell script -split=train -$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split -``` - -Fine-tuning on 100h of Librispeech with letter targets: -```shell script -$ fairseq-hydra-train \ - distributed_training.distributed_port=$PORT \ - task.data=/path/to/data \ - model.w2v_path=/path/to/model.pt \ - --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ - --config-name base_100h -``` - -There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. -You can specify the right config via the `--config-name` parameter. - -Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k - -Decoding with a language model during training requires flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter). -If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. - -### Evaluating a CTC model: - -Evaluating a CTC model with a language model requires [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter) to be installed. - -Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). -Be sure to upper-case the language model vocab after downloading it. 
- -Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). - -Next, run the evaluation command: - -```shell script -$subset=dev_other -python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \ ---nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ ---lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \ ---post-process letter -``` - -To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. - -## Use wav2vec 2.0 with 🤗Transformers: - -Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since vesion 4.3. - -Pretrained Models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) -and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). - -Usage example: - -```python -# !pip install transformers -import soundfile as sf -import torch -from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer - -# load pretrained model -tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") -model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") - -# load audio -audio_input, _ = sf.read("path/to/audio/file") - -# transcribe -input_values = tokenizer(audio_input, return_tensors="pt").input_values -logits = model(input_values).logits -predicted_ids = torch.argmax(logits, dim=-1) -transcription = tokenizer.batch_decode(predicted_ids)[0] -``` - -# wav2vec - -Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). 
- -## Pre-trained models - -Description | Dataset | Model ----|---|--- -Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt) - -#### Example usage: -```python -import torch -import fairseq - -cp_path = '/path/to/wav2vec.pt' -model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) -model = model[0] -model.eval() - -wav_input_16khz = torch.randn(1,10000) -z = model.feature_extractor(wav_input_16khz) -c = model.feature_aggregator(z) -``` - -## Training a new model with the CLI tools - -Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) - -### Prepare training data manifest: - -``` -$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav -``` - -### Train a wav2vec model: - -``` -$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ ---arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ ---conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ ---conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ ---skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ ---max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test -``` - -### Extract embeddings from the downstream task data: - -``` -$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \ ---model /model/path/checkpoint_best.pt --split train valid 
test -``` - -# vq-wav2vec - -Example to train a vq-wav2vec model as described in [vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/1910.05453). - -These models are also used in [Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019)](https://arxiv.org/abs/1911.03912). - -## Pre-trained models - -Description | Dataset | Model ----|---|--- -vq-wav2vec Gumbel | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt) -vq-wav2vec K-means | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt) -Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar) - -#### Example usage: -```python -import torch -import fairseq - -cp = torch.load('/path/to/vq-wav2vec.pt') -model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) -model = model[0] -model.eval() - -wav_input_16khz = torch.randn(1,10000) -z = model.feature_extractor(wav_input_16khz) -_, idxs = model.vector_quantizer.forward_idx(z) -print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model -``` - -## Training a new model with the CLI tools - -Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) - -### Prepare training data manifest: - -``` -$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav -``` - -### Train a gumbel vq-wav2vec model: - -``` -$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ ---save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \ ---optimizer adam --lr 1e-05 
--lr-scheduler cosine \ ---conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ ---conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ ---activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \ ---log-keys ["prob_perplexity","code_perplexity","temp"] --vq-type gumbel --vq-groups 2 --vq-depth 2 \ ---combine-groups --vq-vars 320 --vq-temp (2,0.5,0.999995) --prediction-steps 12 --warmup-updates 1000 \ ---warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \ ---max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test -``` - -for k-means training, set vq-type with "kmeans" and add --loss-weights [1] argument. Pre-trained models were trained on 16 GPUs. - -### Tokenize audio data (e.g. 
for BERT training): - -``` -$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/vq-wav2vec_featurize.py --data-dir /manifest/path --output-dir /path/to/output \ ---checkpoint /model/path/checkpoint_best.pt --split train valid test --extension tsv -``` From f3b6f5817fbee59057ae2506f01502ea3c301b4b Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Fri, 12 Feb 2021 11:32:52 -0800 Subject: [PATCH 457/707] Fix w2v readme (#1621) Summary: somehow merging previous pull request deleted the readme Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1621 Reviewed By: michaelauli Differential Revision: D26429893 Pulled By: alexeib fbshipit-source-id: 3e6ed1e4698e67e56e0b88d304f42907a4f6cf41 --- examples/wav2vec/README.md | 294 +++++++++++++++++++++++++++++++++++++ 1 file changed, 294 insertions(+) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index e69de29bb2..e95f292b51 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -0,0 +1,294 @@ +# wav2vec 2.0 + +wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). + +We learned speech representations in multiple languages as well in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). + +We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). 
+ +## Pre-trained models + +Model | Finetuning split | Dataset | Model +|---|---|---|--- +Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) +Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) +Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) +Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) +Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) +Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) +Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) +Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) +Wav2Vec 2.0 Large (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_new.pt) +Wav2Vec 2.0 
Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) + +\* updated (Oct. 24, 2020) + +We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: + +Model | Architecture | Hours | Languages | Datasets | Model +|---|---|---|---|---|--- +XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) + +The XLSR model uses the following datasets for multilingual pretraining: + +* **[MLS: Multilingual LibriSpeech](https://indico2.conference4me.psnc.pl/event/35/contributions/3585/attachments/1060/1101/Wed-2-6-10.pdf)** (8 languages, 50.7k hours): *Dutch, English, French, German, Italian, Polish, Portuguese, Spanish* + +* **[CommonVoice](https://commonvoice.mozilla.org/en/languages)** (36 languages, 3.6k hours): *Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakh-Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, 
Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh* (see also [finetuning splits](https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz) from [this paper](https://arxiv.org/abs/2002.02848)). + +* **[Babel](https://catalog.ldc.upenn.edu/byyear)** (17 languages, 1.7k hours): *Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok, Turkish, Vietnamese, Zulu* + + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) + +### Prepare training data manifest: + +First, install the `soundfile` library: +```shell script +pip install soundfile +``` + +Next, run: + +```shell script +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid +``` + +$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. + +$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. +To use a pre-defined validation set (like dev-other from librispeech), set it to 0 and then overwrite valid.tsv with a +separately pre-processed manifest file. 
+ +### Train a wav2vec 2.0 base model: + +This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper + +Note that the input is expected to be single channel, sampled at 16 kHz + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_base_librispeech +``` + +Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k + +### Train a wav2vec 2.0 large model: + +This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox +``` + +Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k + +### Fine-tune a pre-trained model with CTC: + +Fine-tuning a model requires parallel audio and labels file, as well as a vocabulary file in fairseq format. +A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). 
+An example [script](libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: + +```shell script +split=train +$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split +``` + +Fine-tuning on 100h of Librispeech with letter targets: +```shell script +$ fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data \ + model.w2v_path=/path/to/model.pt \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-name base_100h +``` + +There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. +You can specify the right config via the `--config-name` parameter. + +Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k + +Decoding with a language model during training requires flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)). +If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. + +### Evaluating a CTC model: + +Evaluating a CTC model with a language model requires [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) to be installed. + +The fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). +Be sure to upper-case the language model vocab after downloading it. 
+ +Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). + +Next, run the evaluation command: + +```shell script +subset=dev_other +python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \ +--nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ +--lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \ +--post-process letter +``` + +To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. + +## Use wav2vec 2.0 with 🤗Transformers: + +Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since version 4.3. + +Pretrained models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) +and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). + +Usage example: + +```python +# !pip install transformers +import soundfile as sf +import torch +from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer + +# load pretrained model +tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") +model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") + +# load audio +audio_input, _ = sf.read("path/to/audio/file") + +# transcribe +input_values = tokenizer(audio_input, return_tensors="pt").input_values +logits = model(input_values).logits +predicted_ids = torch.argmax(logits, dim=-1) +transcription = tokenizer.batch_decode(predicted_ids)[0] +``` + +# wav2vec + +Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). 
+ +## Pre-trained models + +Description | Dataset | Model +---|---|--- +Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt) + +#### Example usage: +```python +import torch +import fairseq + +cp_path = '/path/to/wav2vec.pt' +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) +model = model[0] +model.eval() + +wav_input_16khz = torch.randn(1,10000) +z = model.feature_extractor(wav_input_16khz) +c = model.feature_aggregator(z) +``` + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) + +### Prepare training data manifest: + +``` +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav +``` + +### Train a wav2vec model: + +``` +$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test +``` + +### Extract embeddings from the downstream task data: + +``` +$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \ +--model /model/path/checkpoint_best.pt --split train valid 
test +``` + +# vq-wav2vec + +Example to train a vq-wav2vec model as described in [vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/1910.05453). + +These models are also used in [Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019)](https://arxiv.org/abs/1911.03912). + +## Pre-trained models + +Description | Dataset | Model +---|---|--- +vq-wav2vec Gumbel | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt) +vq-wav2vec K-means | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt) +Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar) + +#### Example usage: +```python +import torch +import fairseq + +# pass the checkpoint path itself; load_model_ensemble_and_task expects a list of filenames +cp_path = '/path/to/vq-wav2vec.pt' +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) +model = model[0] +model.eval() + +wav_input_16khz = torch.randn(1,10000) +z = model.feature_extractor(wav_input_16khz) +_, idxs = model.vector_quantizer.forward_idx(z) +print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model +``` + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) + +### Prepare training data manifest: + +``` +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav +``` + +### Train a gumbel vq-wav2vec model: + +``` +$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ +--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \ +--optimizer adam --lr 1e-05 
--lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \ +--log-keys ["prob_perplexity","code_perplexity","temp"] --vq-type gumbel --vq-groups 2 --vq-depth 2 \ +--combine-groups --vq-vars 320 --vq-temp (2,0.5,0.999995) --prediction-steps 12 --warmup-updates 1000 \ +--warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \ +--max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test +``` + +For k-means training, set --vq-type to "kmeans" and add the --loss-weights [1] argument. Pre-trained models were trained on 16 GPUs. + +### Tokenize audio data (e.g. for BERT training): + +``` +$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/vq-wav2vec_featurize.py --data-dir /manifest/path --output-dir /path/to/output \ +--checkpoint /model/path/checkpoint_best.pt --split train valid test --extension tsv +``` From 02803a1be45642b4c2f9c2970a4f4ae645a2dccf Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Fri, 12 Feb 2021 14:04:21 -0800 Subject: [PATCH 458/707] broadcast the whole optimizer state to each rank Summary: OSS removed the 'partition' key from their state dict to accommodate changing partition sizes. This requires an update on the fairseq side: do not look into the parameter partition, just broadcast everything, and let the optimizer on each rank decide which parameters are relevant. 
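The idea can be illustrated with a minimal single-process sketch. It is not the fairscale or fairseq implementation: `simulate_broadcast`, `pick_local_state` and the `ownership` table are hypothetical names made up for illustration, and no `torch.distributed` call is involved.

```python
# Single-process simulation of "broadcast everything, let each rank pick".
# All helper names are hypothetical; they are not fairscale or fairseq APIs.

def simulate_broadcast(obj, src_rank, world_size):
    # Stand-in for a real broadcast: every rank ends up with a copy of the
    # object held by src_rank (src_rank is unused in this toy version).
    return [dict(obj) for _ in range(world_size)]


def pick_local_state(full_state, owned_params):
    # Each rank keeps only the entries for the parameters it owns.
    return {name: value for name, value in full_state.items() if name in owned_params}


global_state = {"w0": 0.1, "w1": 0.2, "w2": 0.3, "w3": 0.4}
ownership = [{"w0", "w1"}, {"w2", "w3"}]  # assumed rank -> parameter assignment

received = simulate_broadcast(global_state, src_rank=0, world_size=2)
local_states = [pick_local_state(state, owned) for state, owned in zip(received, ownership)]
```

Because every rank receives the same full dict, the sender no longer needs to know the partition layout, which is the property the summary relies on.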
This diff also needs D26419095 to function completely, and blefaudeux has made fixes upstream in https://github.com/facebookresearch/fairscale/pull/383 Reviewed By: myleott Differential Revision: D26382917 fbshipit-source-id: 95af1022be59e88814748acaee36a1a350f7dc5b --- fairseq/optim/shard.py | 58 ++++++++---------------------------------- 1 file changed, 10 insertions(+), 48 deletions(-) diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index 3d025a23ca..3c1b34ae60 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -5,11 +5,11 @@ from typing import Any, Dict -import torch +from fairseq.distributed import utils try: - from fairscale.optim import OSS, utils + from fairscale.optim import OSS _has_fairscale = True except ImportError: @@ -38,53 +38,15 @@ def broadcast_global_state_dict( self, state_dict: Dict[str, Any] ) -> Dict[str, Any]: """ - Broadcasts the relevant parts of a global state dict from rank 0 to - all other ranks. + Broadcasts the entire state_dict to all other ranks + each rank is responsible to load their own partition of data """ - if self.rank == 0: - - # Create template state dict for all other keys not related to sharding - template_state_dict = { - key: state_dict[key] - for key in state_dict - if key not in ("param_groups", "state") - } - template_state_dict["local_state_dict"] = True - - for dst_rank in range(self.world_size): - # Get the dst_rank's param_groups shard - send_state = { - "param_groups": state_dict["param_groups"][ - state_dict["partition"][dst_rank][0] : state_dict[ - "partition" - ][dst_rank][1] - ], - "state": state_dict["state"][dst_rank], - } - send_state.update(template_state_dict) - - if dst_rank == 0: - recv_state = send_state - else: - utils.broadcast_object( - send_state, - src_rank=0, - group=self.group, - dist_device=self._device, - ) - else: - empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device) - for dst_rank in range(1, self.world_size): - state = 
utils.broadcast_object( - empty_buffer, - src_rank=0, - group=self.group, - dist_device=self._device, - ) - if dst_rank == self.rank: - recv_state = state - - return recv_state + return utils.broadcast_object( + state_dict, + src_rank=0, + group=self.group, + dist_device=self._device, + ) torch_optimizer = optimizer.optimizer optim_cls = type(torch_optimizer) From 09945b45d4e2608563b1b18c3bbe289bf9351529 Mon Sep 17 00:00:00 2001 From: cordercorder <2205722269@qq.com> Date: Fri, 12 Feb 2021 14:35:55 -0800 Subject: [PATCH 459/707] Fixes bugs of evaluation with BLEU score when training with multi-gpus. (#3237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …ith BLEU scores # Before submitting - [no] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [yes] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [no need] Did you make sure to update the docs? - [no need] Did you write any new necessary tests? ## What does this PR do? Fixes a bug in evaluation with BLEU score when training with multiple GPUs. No error happens without distributed training. When --eval-bleu is set (it is `False` by default, and the best checkpoint is then selected according to loss) and training uses multiple GPUs (more than one GPU participates in distributed training), the following error happens. 
```bash Traceback (most recent call last): Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in <module> File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in <module> Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in <module> sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main distributed_utils.call_main(cfg, main)distributed_utils.call_main(cfg, main) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main distributed_utils.call_main(cfg, main) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main main(cfg, **kwargs) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 143, in main main(cfg, **kwargs) main(cfg, **kwargs)rder/fairseq/fairseq_cli/train.py", line 143, in main File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 143, in main valid_losses, 
should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in <module> return func(*args, **kwds) return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save cfg, trainer, task, epoch_itr, valid_subsets, end_of_epochsys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate trainer.valid_step(sample) File 
"/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner distributed_utils.call_main(cfg, main) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main trainer.valid_step(sample) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step trainer.valid_step(sample) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step return func(*args, **kwds)distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main main(cfg, **kwargs) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 143, in main logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save self.task.reduce_metrics(logging_outputs, self.get_criterion()) 
File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics self.task.reduce_metrics(logging_outputs, self.get_criterion())valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate self.task.reduce_metrics(logging_outputs, self.get_criterion()) File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ trainer.valid_step(sample) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ return func(*args, **kwds)metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ return self.numpy() TypeError: can't convert cuda:2 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. return self.numpy() TypeError: can't convert cuda:3 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. return self.numpy() TypeError: can't convert cuda:1 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. 
logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats self.task.reduce_metrics(logging_outputs, self.get_criterion()) File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ return self.numpy() TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/runpy.py", line 193, in _run_module_as_main "__main__", mod_spec) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/runpy.py", line 85, in _run_code exec(code, run_globals) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/distributed/launch.py", line 261, in <module> main() File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/distributed/launch.py", line 257, in main cmd=cmd) subprocess.CalledProcessError: Command '['/data/cordercorder/anaconda3/envs/nmt/bin/python', '-u', '/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train', '--local_rank=3', 'tiny_data_bin', '--distributed-world-size', '4', '--arch', 'transformer', '--share-decoder-input-output-embed', '--optimizer', 'adam', '--adam-betas', '(0.9, 0.98)', '--clip-norm', '0.0', '--lr-scheduler', 'inverse_sqrt', '--warmup-init-lr', '1e-07', '--warmup-updates', '3000', '--lr', '0.0005', '--stop-min-lr', '1e-09', '--dropout', '0.25', '--weight-decay', '0.0001', '--criterion', 'label_smoothed_cross_entropy', '--label-smoothing', '0.1', '--max-tokens', '5000', '--batch-size', '64', '--update-freq', '4', '--max-epoch', '30', '--save-dir', 'checkpoint', '--skip-invalid-size-inputs-valid-test', '--eval-bleu', '--eval-bleu-args', '{"beam": 5}', 
'--eval-bleu-remove-bpe', 'sentencepiece', '--eval-bleu-print-samples', '--eval-tokenized-bleu', '--best-checkpoint-metric', 'bleu', '--maximize-best-checkpoint-metric', '--validate-interval-updates', '1']' returned non-zero exit status 1. ``` The error is caused by the fact that numpy version 1.20.1 doesn't support code like the following: ```python import torch import numpy as np a = torch.tensor(0, device="cuda:0") b = np.array([a]) ``` The above code leads to the error "TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.", but it runs well if the numpy version is 1.18.1 or 1.17.0 (when the numpy version is below 1.20.0, it is ok, I guess). However, it seems that the latest version of fairseq needs numpy version 1.20.0 or higher (issue https://github.com/pytorch/fairseq/issues/3203 ). ### Reproduce the error Download the source code of fairseq (commit ID: 7061a0ff83872ac491ba5963eb7fc04cb10d57c4) and run the following code: ```bash export CUDA_VISIBLE_DEVICES=0,1,2,3 data_bin_dir=tiny_data_bin python -m torch.distributed.launch --nproc_per_node=4 \ --master_addr="127.0.0.1" \ --master_port=12345 \ $(which fairseq-train) ${data_bin_dir} \ --distributed-world-size 4 \ --arch transformer \ --share-decoder-input-output-embed \ --optimizer adam \ --adam-betas '(0.9, 0.98)' \ --clip-norm 0.0 \ --lr-scheduler inverse_sqrt \ --warmup-init-lr 1e-07 \ --warmup-updates 3000 \ --lr 0.0005 \ --stop-min-lr 1e-09 \ --dropout 0.25 \ --weight-decay 0.0001 \ --criterion label_smoothed_cross_entropy \ --label-smoothing 0.1 \ --max-tokens 5000 \ --batch-size 64 \ --update-freq 4 \ --max-epoch 30 \ --save-dir checkpoint \ --skip-invalid-size-inputs-valid-test \ --eval-bleu \ --eval-bleu-args '{"beam": 5}' \ --eval-bleu-remove-bpe sentencepiece \ --eval-bleu-print-samples \ --eval-tokenized-bleu \ --best-checkpoint-metric bleu \ --maximize-best-checkpoint-metric \ --validate-interval-updates 1 
``` ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3237 Reviewed By: myleott Differential Revision: D26429732 Pulled By: alexeib fbshipit-source-id: bc887ce952d28541cb07dbbdc7e80e99428a6b34 --- fairseq/tasks/translation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 90635d882f..331f685495 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -394,7 +394,11 @@ def reduce_metrics(self, logging_outputs, criterion): if self.cfg.eval_bleu: def sum_logs(key): - return sum(log.get(key, 0) for log in logging_outputs) + import torch + result = sum(log.get(key, 0) for log in logging_outputs) + if torch.is_tensor(result): + result = result.cpu() + return result counts, totals = [], [] for i in range(EVAL_BLEU_ORDER): From 5ac5e8a20a7a914698f9970c2a384f14015ece3d Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Fri, 12 Feb 2021 21:18:23 -0800 Subject: [PATCH 460/707] fix sharing objects between tasks (#1623) Summary: fixes previous change that changes state/dataset/etc to class variables instead of instance variables Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1623 Reviewed By: michaelauli Differential Revision: D26439560 Pulled By: alexeib fbshipit-source-id: ab9e75a425a47ac7ace006419259e254770e560e --- fairseq/tasks/fairseq_task.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 04025023fa..375b5277b9 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -78,12 +78,15 @@ def logging_outputs_can_be_summed(criterion) -> bool: return 
criterion.logging_outputs_can_be_summed() cfg: FairseqDataclass - datasets: Dict[str, FairseqDataset] = dict() - dataset_to_epoch_iter: Dict[FairseqDataset, Any] = dict() - state: StatefulContainer = StatefulContainer() + datasets: Dict[str, FairseqDataset] + dataset_to_epoch_iter: Dict[FairseqDataset, Any] + state: StatefulContainer = None def __init__(self, cfg: FairseqDataclass, **kwargs): self.cfg = cfg + self.datasets = dict() + self.dataset_to_epoch_iter = dict() + self.state = StatefulContainer() @classmethod def load_dictionary(cls, filename): @@ -553,10 +556,13 @@ def reduce_metrics(self, logging_outputs, criterion): criterion.__class__.reduce_metrics(logging_outputs) def state_dict(self): - return self.state.state_dict + if self.state is not None: + return self.state.state_dict + return {} def load_state_dict(self, state_dict: Dict[str, Any]): - self.state.merge_state_dict(state_dict) + if self.state is not None: + self.state.merge_state_dict(state_dict) def max_positions(self): """Return the max input length allowed by the task.""" From 43415b44781af6ac9c10adce0ae2a7d26d611bd1 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Tue, 16 Feb 2021 15:50:46 -0800 Subject: [PATCH 461/707] Prepend embedding layer when return_all_hiddens=True in TransformerEncoder (#1559) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1559 This matches the behavior of RobertaEncoder. 
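A toy sketch of the resulting behaviour, assuming nothing about the real `TransformerEncoder` beyond what this patch describes: padded positions are zeroed before the first layer, and with `return_all_hiddens=True` the embedding output is recorded ahead of the per-layer outputs. `toy_encoder` and its arguments are hypothetical and use plain Python lists instead of tensors.

```python
# Toy stand-in for an encoder; "layers" are plain functions over lists of floats.
# Names are illustrative only and are not fairseq APIs.

def toy_encoder(embedded, padding_mask, layers, return_all_hiddens=False):
    # Zero out padded positions before the first layer.
    x = [0.0 if is_pad else value for value, is_pad in zip(embedded, padding_mask)]

    encoder_states = []
    if return_all_hiddens:
        # The embedding output comes first, ahead of any layer output.
        encoder_states.append(list(x))

    for layer in layers:
        x = layer(x)
        if return_all_hiddens:
            encoder_states.append(list(x))

    return x, encoder_states


layers = [lambda xs: [v + 1.0 for v in xs], lambda xs: [v * 2.0 for v in xs]]
out, states = toy_encoder(
    [0.5, 1.5, 9.0], [False, False, True], layers, return_all_hiddens=True
)
```

With two layers the returned list holds three entries, so callers that index hidden states by layer number should expect index 0 to be the embedding output.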
Test Plan: Imported from OSS Reviewed By: gwenzek Differential Revision: D25936937 Pulled By: myleott fbshipit-source-id: 795ec8d50298a41d9e9638101436faa01cdf1586 --- fairseq/models/transformer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 362d9b28d6..78762ef924 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -435,14 +435,21 @@ def forward( """ x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) - # B x T x C -> T x B x C - x = x.transpose(0, 1) - # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) + # account for padding while computing the representation + if encoder_padding_mask is not None: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + encoder_states = [] + if return_all_hiddens: + encoder_states.append(x) + # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) @@ -454,7 +461,7 @@ def forward( x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in - # `foward` so we use a dictionary instead. + # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { From 54423d3b22a3e7f536e02e9e5445cef9becbd60d Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Tue, 16 Feb 2021 15:50:46 -0800 Subject: [PATCH 462/707] refactor RobertaEncoder (#1560) Summary: This is long overdue, but finally deprecating the RobertaEncoder components and just using TransformerEncoder directly. This will make it easier for some upcoming online backtranslation changes, and will eventually make migrating it to dataclasses/Hydra easier too. It also fixes some longstanding inconsistencies in layernorm placement in the model parallel roberta code. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1560 Test Plan: - confirmed that training gives identical losses as before: https://gist.github.com/myleott/9a4d213fb88a02b00094ea074f5a2e2d - confirmed that old roberta models can be loaded and produce identical results - confirmed that old linformer models can be loaded and produce identical results (reran commands from D25938236 (https://github.com/pytorch/fairseq/commit/bf54551cafa13678c0254d2c20354cc026cc0bac)) - confirmed that old model parallel models can be loaded and produce identical results: ``` python -m fairseq_cli.validate --path checkpoint.mp1/checkpoint_last.pt --task dummy_masked_lm --criterion masked_lm --max-sentences 8 --dataset-size 100 --model-parallel-size 2 --distributed-world-size 2 before: 2021-01-19 19:04:14 | INFO | valid | | valid on 'valid' subset | loss 14.62 | ppl 25174.3 | wps 0 | wpb 53248 | bsz 104 after: 2021-01-19 19:06:59 | INFO | valid | | valid on 'valid' subset | loss 14.62 | ppl 25174.3 | wps 0 | wpb 53248 | bsz 104 ``` Reviewed By: gwenzek, ngoyal2707 Differential Revision: D25937145 Pulled By: myleott fbshipit-source-id: 1ce0bc93e28e03fb926534ea4134684a49232599 --- .../linformer_src/models/linformer_roberta.py | 71 ++------- .../modules/linformer_sentence_encoder.py | 137 ++-------------- .../linformer_sentence_encoder_layer.py | 83 ++-------- .../model_parallel/models/roberta/model.py | 148 +++++------------- fairseq/model_parallel/models/transformer.py | 7 +- fairseq/model_parallel/modules/__init__.py | 6 - .../modules/transformer_sentence_encoder.py | 59 ------- .../transformer_sentence_encoder_layer.py | 77 --------- fairseq/models/roberta/model.py | 100 ++++++++---- fairseq/models/transformer.py | 1 + 10 files changed, 161 insertions(+), 528 deletions(-) delete mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder.py delete mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py diff --git 
a/examples/linformer/linformer_src/models/linformer_roberta.py b/examples/linformer/linformer_src/models/linformer_roberta.py index be5d8e85ec..18ad44f079 100644 --- a/examples/linformer/linformer_src/models/linformer_roberta.py +++ b/examples/linformer/linformer_src/models/linformer_roberta.py @@ -11,9 +11,15 @@ import torch from fairseq import utils from fairseq.models import register_model, register_model_architecture -from fairseq.models.roberta import RobertaEncoder, RobertaModel +from fairseq.models.roberta import ( + init_bert_params, + roberta_base_architecture, + roberta_large_architecture, + RobertaEncoder, + RobertaModel, +) -from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder +from ..modules.linformer_sentence_encoder import LinformerTransformerEncoder logger = logging.getLogger(__name__) @@ -66,30 +72,10 @@ def __init__(self, args, dictionary): super().__init__(args, dictionary) self.register_buffer("version", torch.tensor(2)) - def build_encoder(self, args, dictionary): - return LinformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=True, - apply_bert_init=True, - activation_fn=args.activation_fn, - q_noise=args.quant_noise_pq, - qn_block_size=args.quant_noise_pq_block_size, - compressed=args.compressed, - shared_kv_compressed=args.shared_kv_compressed, - shared_layer_kv_compressed=args.shared_layer_kv_compressed, - freeze_compress=args.freeze_compress, - ) + def build_encoder(self, args, dictionary, embed_tokens): + encoder = LinformerTransformerEncoder(args, dictionary, embed_tokens) + 
encoder.apply(init_bert_params) + return encoder def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) @@ -115,25 +101,11 @@ def upgrade_state_dict_named(self, state_dict, name): @register_model_architecture("linformer_roberta", "linformer_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - args.compressed = getattr(args, "compressed", 4) args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) args.freeze_compress = getattr(args, "freeze_compress", 0) + roberta_base_architecture(args) @register_model_architecture("linformer_roberta", "linformer_roberta_base") @@ -143,18 +115,5 @@ def linformer_roberta_base_architecture(args): @register_model_architecture("linformer_roberta", "linformer_roberta_large") def linformer_roberta_large_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 24) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) - - 
args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.compressed = getattr(args, "compressed", 4) - args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) - args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) + roberta_large_architecture(args) + base_architecture(args) diff --git a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py index 3cdca01235..44f7989bd8 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py @@ -6,12 +6,12 @@ import math import torch.nn as nn -from fairseq.modules import TransformerSentenceEncoder +from fairseq.models.transformer import TransformerEncoder -from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer +from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer -class LinformerSentenceEncoder(TransformerSentenceEncoder): +class LinformerTransformerEncoder(TransformerEncoder): """ Implementation for a Bi-directional Linformer based Sentence Encoder used in BERT/XLM style pre-trained models. @@ -35,135 +35,20 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder): in format B x C. 
""" - def __init__( - self, - padding_idx: int, - vocab_size: int, - num_encoder_layers: int = 6, - embedding_dim: int = 768, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - layerdrop: float = 0.0, - max_seq_len: int = 256, - num_segments: int = 2, - use_position_embeddings: bool = True, - offset_positions_by_padding: bool = True, - encoder_normalize_before: bool = False, - apply_bert_init: bool = False, - activation_fn: str = "relu", - learned_pos_embedding: bool = True, - embed_scale: float = None, - freeze_embeddings: bool = False, - n_trans_layers_to_freeze: int = 0, - export: bool = False, - traceable: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - compressed: int = 4, - shared_kv_compressed: int = 0, - shared_layer_kv_compressed: int = 0, - freeze_compress: int = 0, - ) -> None: - - # Initialize linformer parameters - self.compressed = compressed - self.shared_kv_compressed = shared_kv_compressed - self.shared_layer_kv_compressed = shared_layer_kv_compressed + def __init__(self, args, dictionary, embed_tokens): self.compress_layer = None - self.freeze_compress = freeze_compress - - super().__init__( - padding_idx=padding_idx, - vocab_size=vocab_size, - num_encoder_layers=num_encoder_layers, - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - layerdrop=layerdrop, - max_seq_len=max_seq_len, - num_segments=num_segments, - use_position_embeddings=use_position_embeddings, - offset_positions_by_padding=offset_positions_by_padding, - encoder_normalize_before=encoder_normalize_before, - apply_bert_init=apply_bert_init, - activation_fn=activation_fn, - learned_pos_embedding=learned_pos_embedding, - embed_scale=embed_scale, - freeze_embeddings=freeze_embeddings, - 
n_trans_layers_to_freeze=n_trans_layers_to_freeze, - export=export, - traceable=traceable, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) + super().__init__(args, dictionary, embed_tokens) - def build_transformer_sentence_encoder_layer( - self, - embedding_dim, - ffn_embedding_dim, - num_attention_heads, - dropout, - attention_dropout, - activation_dropout, - activation_fn, - export, - q_noise, - qn_block_size, - ): - if self.shared_layer_kv_compressed == 1 and self.compress_layer is None: + def build_encoder_layer(self, args): + if self.args.shared_layer_kv_compressed == 1 and self.compress_layer is None: compress_layer = nn.Linear( - self.max_seq_len, self.max_seq_len // self.compressed + self.args.max_positions, + self.args.max_positions // self.args.compressed, ) # intialize parameters for compressed layer nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2)) - if self.freeze_compress == 1: + if self.args.freeze_compress == 1: compress_layer.weight.requires_grad = False self.compress_layer = compress_layer - return LinformerSentenceEncoderLayer( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - compressed=self.compressed, - max_seq_len=self.max_seq_len, - shared_kv_compressed=self.shared_kv_compressed, - shared_compress_layer=( - None if self.shared_layer_kv_compressed == 0 else self.compress_layer - ), - freeze_compress=self.freeze_compress, - ) - - def upgrade_state_dict_named(self, state_dict, name): - prefix = name + "." 
if name != "" else "" - items_to_add = {} - keys_to_remove = [] - - # update key name for shared layer in new version of code - for k in state_dict.keys(): - if k.startswith(prefix + "compress_layer"): - if self.shared_layer_kv_compressed: - for layer_idx in range(len(self.layers)): - new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format( - layer_idx, - k[len(prefix + "compress_layer.") :], - ) - items_to_add[new_k] = state_dict[k] - - for k in keys_to_remove: - del state_dict[k] - - for key, value in items_to_add.items(): - state_dict[key] = value + return LinformerTransformerEncoderLayer(args, self.compress_layer) diff --git a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py index 0b80fabefe..7e2caa0340 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py @@ -3,88 +3,44 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable - import torch from fairseq import utils -from fairseq.modules import TransformerSentenceEncoderLayer +from fairseq.modules import TransformerEncoderLayer from .multihead_linear_attention import MultiheadLinearAttention -class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): +class LinformerTransformerEncoderLayer(TransformerEncoderLayer): """ Implements a Linformer Encoder Layer used in BERT/XLM style pre-trained models. 
""" - def __init__( - self, - embedding_dim: int = 768, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - activation_fn: str = "relu", - export: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - init_fn: Callable = None, - compressed: int = 1, - max_seq_len: int = 256, - shared_kv_compressed: int = 0, - shared_compress_layer: any = None, - freeze_compress: int = 0, - ) -> None: - - # Initialize linformer parameters - self.compressed = compressed - self.max_seq_len = max_seq_len - self.shared_kv_compressed = shared_kv_compressed - self.freeze_compress = freeze_compress - + def __init__(self, args, shared_compress_layer): # wrap in a list so it's not automatically registered by PyTorch self.shared_compress_layer = [shared_compress_layer] - super().__init__( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) + super().__init__(args) + self.register_buffer("version", torch.tensor(2)) - def build_self_attention( - self, - embed_dim, - num_attention_heads, - dropout, - self_attention, - q_noise, - qn_block_size, - ): + def build_self_attention(self, embed_dim, args): return MultiheadLinearAttention( embed_dim, - num_attention_heads, - dropout=dropout, + args.encoder_attention_heads, + dropout=args.dropout, self_attention=True, - q_noise=q_noise, - qn_block_size=qn_block_size, - compressed=self.compressed, - max_seq_len=self.max_seq_len, - shared_kv_compressed=self.shared_kv_compressed, + q_noise=args.quant_noise_pq, + qn_block_size=args.quant_noise_pq_block_size, + compressed=args.compressed, + max_seq_len=args.max_positions, + shared_kv_compressed=args.shared_kv_compressed, 
shared_compress_layer=self.shared_compress_layer[0], - freeze_compress=self.freeze_compress, + freeze_compress=args.freeze_compress, ) def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) prefix = name + "." if name != "" else "" # some old checkpoints had weight sharing implemented incorrectly @@ -101,14 +57,7 @@ def upgrade_state_dict_named(self, state_dict, name): self.shared_compress_layer[0].weight.size(0), ) ] - self.self_attn = self.build_self_attention( - self.embedding_dim, - self.num_attention_heads, - dropout=self.attention_dropout, - self_attention=True, - q_noise=self.q_noise, - qn_block_size=self.qn_block_size, - ) + self.self_attn = self.build_self_attention(self.embed_dim, self.args) # delete shared_compress_layer, since it's already copied to # self_attn.compress_k.weight del state_dict[f"{prefix}shared_compress_layer.weight"] diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py index 68ad88d2a5..77a80ef720 100644 --- a/fairseq/model_parallel/models/roberta/model.py +++ b/fairseq/model_parallel/models/roberta/model.py @@ -12,16 +12,15 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils -from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder -from fairseq.models import FairseqEncoder, register_model, register_model_architecture +from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder +from fairseq.models import register_model, register_model_architecture from fairseq.models.roberta import ( - RobertaClassificationHead, + roberta_base_architecture, + roberta_prenorm_architecture, RobertaEncoder, - RobertaLMHead, RobertaModel, ) -from fairseq.modules import LayerNorm, TransformerSentenceEncoder -from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.modules import LayerNorm try: @@ -29,7 +28,7 @@ 
copy_to_model_parallel_region, gather_from_model_parallel_region, ColumnParallelLinear, - RowParallelLinear, + VocabParallelEmbedding, ) has_megatron_submodule = True @@ -48,7 +47,15 @@ def __init__(self, args, encoder): @staticmethod def add_args(parser): - super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser) + RobertaModel.add_args(parser) + parser.add_argument( + "--no-final-layer-norm", + action="store_true", + help=( + "don't add final layernorm (only applicable when " + "--encoder-normalize-before=True" + ), + ) @classmethod def build_model(cls, args, task): @@ -165,121 +172,52 @@ def forward(self, features, **kwargs): return x -class ModelParallelRobertaEncoder(FairseqEncoder): - """RoBERTa encoder. - - Implements the :class:`~fairseq.models.FairseqDecoder` interface required - by :class:`~fairseq.models.FairseqLanguageModel`. - """ +class ModelParallelRobertaEncoder(RobertaEncoder): + """RoBERTa encoder.""" def __init__(self, args, dictionary): - super().__init__(dictionary) - self.args = args - - # RoBERTa is a sentence encoder model, so users will intuitively trim - # encoder layers. However, the implementation uses the fairseq decoder, - # so we fix here. 
- if args.encoder_layers_to_keep: - args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - args.decoder_layers_to_keep = args.encoder_layers_to_keep - args.encoder_layers_to_keep = None - - self.sentence_encoder = ModelParallelTransformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=False, - apply_bert_init=False, - activation_fn=args.activation_fn, - ) - self.lm_head = ModelParallelRobertaLMHead( - embed_dim=args.encoder_embed_dim, - output_dim=len(dictionary), - activation_fn=args.activation_fn, - weight=self.sentence_encoder.embed_tokens.weight, - ) - - def forward( - self, - src_tokens, - features_only=False, - return_all_hiddens=False, - masked_tokens=None, - **unused - ): - """ - Args: - src_tokens (LongTensor): input tokens of shape `(batch, src_len)` - features_only (bool, optional): skip LM head and just return - features. If True, the output will be of shape - `(batch, src_len, embed_dim)`. - return_all_hiddens (bool, optional): also return all of the - intermediate hidden states (default: False). - - Returns: - tuple: - - the LM output of shape `(batch, src_len, vocab)` - - a dictionary of additional data, where 'inner_states' - is a list of hidden states. Note that the hidden - states have shape `(src_len, batch, vocab)`. 
- """ - x, extra = self.extract_features( - src_tokens, return_all_hiddens=return_all_hiddens - ) - if not features_only: - x = self.output_layer(x, masked_tokens=masked_tokens) - return x, extra + super().__init__(args, dictionary) + assert not self.args.untie_weights_roberta - def extract_features(self, src_tokens, return_all_hiddens=False, **unused): - inner_states, _ = self.sentence_encoder( - src_tokens, - last_state_only=not return_all_hiddens, - ) - features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {"inner_states": inner_states if return_all_hiddens else None} + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) - def output_layer(self, features, masked_tokens=None, **unused): - return self.lm_head(features, masked_tokens) + def build_encoder(self, args, dictionary, embed_tokens): + return ModelParallelTransformerEncoder(args, dictionary, embed_tokens) - def max_positions(self): - """Maximum output length supported by the encoder.""" - return self.args.max_positions + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False) + # model parallel RoBERTa defaults to "Pre-LN" formulation + roberta_prenorm_architecture(args) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.dropout = getattr(args, 
"dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) +# earlier versions of model parallel RoBERTa removed the final layer norm +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1") +def model_parallel_roberta_v1_architecture(args): + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True) + base_architecture(args) + + +@register_model_architecture( + "model_parallel_roberta", "model_parallel_roberta_postnorm" +) +def model_parallel_roberta_postnorm_architecture(args): + # the original BERT/RoBERTa uses the "Post-LN" formulation + roberta_base_architecture(args) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base") -def roberta_base_architecture(args): +def model_parallel_roberta_base_architecture(args): base_architecture(args) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large") -def roberta_large_architecture(args): +def model_parallel_roberta_large_architecture(args): args.encoder_layers = getattr(args, "encoder_layers", 24) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py index 4f34645226..6b330ef1b7 100644 --- a/fairseq/model_parallel/models/transformer.py +++ b/fairseq/model_parallel/models/transformer.py @@ -6,7 +6,6 @@ import logging import torch.nn as nn -import torch.nn.functional as F from fairseq.model_parallel.modules import ( ModelParallelTransformerDecoderLayer, ModelParallelTransformerEncoderLayer, @@ -86,6 +85,12 @@ class 
ModelParallelTransformerEncoder(TransformerEncoder): is a :class:`ModelParallelTransformerEncoderLayer`. """ + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + + if args.no_final_layer_norm: + self.layer_norm = None + def build_encoder_layer(self, args): return ModelParallelTransformerEncoderLayer(args) diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py index fb45b3c9e0..11603217a1 100644 --- a/fairseq/model_parallel/modules/__init__.py +++ b/fairseq/model_parallel/modules/__init__.py @@ -9,15 +9,9 @@ ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer, ) -from .transformer_sentence_encoder_layer import ( - ModelParallelTransformerSentenceEncoderLayer, -) -from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder __all__ = [ "ModelParallelMultiheadAttention", "ModelParallelTransformerEncoderLayer", "ModelParallelTransformerDecoderLayer", - "ModelParallelTransformerSentenceEncoder", - "ModelParallelTransformerSentenceEncoderLayer", ] diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py deleted file mode 100644 index a5d50a33c6..0000000000 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import random -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer -from fairseq.modules import ( - LayerNorm, - MultiheadAttention, - PositionalEmbedding, - TransformerSentenceEncoder, -) - - -try: - from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder): - """ - Implementation for a Model Parallel Bi-directional Transformer based - Sentence Encoder used in BERT/XLM style pre-trained models. - """ - - def build_embedding(self, vocab_size, embedding_dim, padding_idx): - return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) - - def build_transformer_sentence_encoder_layer( - self, - embedding_dim, - ffn_embedding_dim, - num_attention_heads, - dropout, - attention_dropout, - activation_dropout, - activation_fn, - export, - **unused, - ): - return ModelParallelTransformerSentenceEncoderLayer( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - ) diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py deleted file mode 100644 index e10bf52332..0000000000 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn.functional as F -from fairseq import utils -from fairseq.model_parallel.modules import ModelParallelMultiheadAttention -from fairseq.modules import TransformerSentenceEncoderLayer - - -try: - from fairseq.model_parallel.megatron.mpu import ( - ColumnParallelLinear, - RowParallelLinear, - ) - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): - """ - Implements a Model Parallel Transformer Encoder Layer used in - BERT/XLM style pre-trained models. - """ - - def build_fc1(self, input_dim, output_dim, **unused): - return ColumnParallelLinear(input_dim, output_dim, gather_output=False) - - def build_fc2(self, input_dim, output_dim, **unused): - return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) - - def build_self_attention( - self, - embed_dim, - num_attention_heads, - dropout, - **kwargs, - ): - return ModelParallelMultiheadAttention( - embed_dim, num_attention_heads, dropout=dropout, self_attention=True - ) - - def forward( - self, - x: torch.Tensor, - self_attn_mask: torch.Tensor = None, - self_attn_padding_mask: torch.Tensor = None, - ): - """ - LayerNorm is applied either before or after the self-attention/ffn - modules similar to the original Transformer imlementation. 
- """ - residual = x - x = self.self_attn_layer_norm(x) - x, attn = self.self_attn( - query=x, - key=x, - value=x, - key_padding_mask=self_attn_padding_mask, - need_weights=False, - attn_mask=self_attn_mask, - ) - x = self.dropout_module(x) - x = residual + x - - residual = x - x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = self.dropout_module(x) - x = residual + x - return x, None diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 00a5a5485f..a2a40ba6e2 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -18,7 +18,8 @@ register_model, register_model_architecture, ) -from fairseq.modules import LayerNorm, TransformerSentenceEncoder +from fairseq.models.transformer import TransformerEncoder +from fairseq.modules import LayerNorm from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -87,6 +88,11 @@ def add_args(parser): action="store_true", help="apply layernorm before each encoder block", ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) parser.add_argument( "--dropout", type=float, metavar="D", help="dropout probability" ) @@ -264,6 +270,13 @@ def upgrade_state_dict_named(self, state_dict, name): state_dict[new_k] = state_dict[k] del state_dict[k] + # rename emb_layer_norm -> layernorm_embedding + for k in list(state_dict.keys()): + if ".emb_layer_norm." 
in k: + new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.") + state_dict[new_k] = state_dict[k] + del state_dict[k] + # upgrade children modules super().upgrade_state_dict_named(state_dict, name) @@ -401,7 +414,11 @@ def __init__(self, args, dictionary): if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - self.sentence_encoder = self.build_encoder(args, dictionary) + embed_tokens = self.build_embedding( + len(dictionary), args.encoder_embed_dim, dictionary.pad() + ) + + self.sentence_encoder = self.build_encoder(args, dictionary, embed_tokens) self.lm_head = self.build_lm_head( embed_dim=args.encoder_embed_dim, @@ -414,26 +431,16 @@ def __init__(self, args, dictionary): ), ) - def build_encoder(self, args, dictionary): - return TransformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=True, - apply_bert_init=True, - activation_fn=args.activation_fn, - q_noise=args.quant_noise_pq, - qn_block_size=args.quant_noise_pq_block_size, - ) + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return nn.Embedding(vocab_size, embedding_dim, padding_idx) + + def build_encoder(self, args, dictionary, embed_tokens): + encoder = TransformerEncoder(args, dictionary, embed_tokens) + encoder.apply(init_bert_params) + return encoder + + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): return RobertaLMHead(embed_dim, 
output_dim, activation_fn, weight) @@ -470,13 +477,15 @@ def forward( return x, extra def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs): - inner_states, _ = self.sentence_encoder( + encoder_out = self.sentence_encoder( src_tokens, - last_state_only=not return_all_hiddens, + return_all_hiddens=return_all_hiddens, token_embeddings=kwargs.get("token_embeddings", None), ) - features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {"inner_states": inner_states if return_all_hiddens else None} + # T x B x C -> B x T x C + features = encoder_out["encoder_out"][0].transpose(0, 1) + inner_states = encoder_out["encoder_states"] if return_all_hiddens else None + return features, {"inner_states": inner_states} def output_layer(self, features, masked_tokens=None, **unused): return self.lm_head(features, masked_tokens) @@ -493,21 +502,50 @@ def base_architecture(args): args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.dropout = getattr(args, "dropout", 0.1) args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_dropout = getattr(args, "activation_dropout", 0.0) args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + + args.max_source_positions = getattr(args, "max_positions", 512) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + # BERT has a few structural differences compared to the original Transformer + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + 
args.no_scale_embedding = getattr(args, "no_scale_embedding", True) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) + + # Adaptive input config + args.adaptive_input = getattr(args, "adaptive_input", False) + + # LayerDrop config + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + + # Quantization noise config + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + # R4F config args.spectral_norm_classification_head = getattr( args, "spectral_norm_classification_head", False ) +@register_model_architecture("roberta", "roberta_prenorm") +def roberta_prenorm_architecture(args): + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + base_architecture(args) + + @register_model_architecture("roberta", "roberta_base") def roberta_base_architecture(args): base_architecture(args) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 78762ef924..4960fd143d 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -325,6 +325,7 @@ class TransformerEncoder(FairseqEncoder): """ def __init__(self, args, dictionary, embed_tokens): + self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) From 7096ac35870aa24735bd0cc850beefa07784a668 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Tue, 16 Feb 2021 15:50:46 -0800 Subject: [PATCH 463/707] Make validate.py work with model parallel (#1570) 
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1570 Test Plan: Imported from OSS Reviewed By: gwenzek, ngoyal2707 Differential Revision: D25967675 Pulled By: myleott fbshipit-source-id: 7c7f8d25b87ef9b4f0a85331548bb3a2886a1e92 --- fairseq/logging/progress_bar.py | 2 +- fairseq_cli/validate.py | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index dc061a1821..0ae2bc006d 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -123,7 +123,7 @@ def __init__(self, iterable, epoch=None, prefix=None): if epoch is not None: self.prefix += "epoch {:03d}".format(epoch) if prefix is not None: - self.prefix += " | {}".format(prefix) + self.prefix += (" | " if self.prefix != "" else "") + prefix def __len__(self): return len(self.iterable) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index c69bb94142..90d7e4c6a9 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -u -# !/usr/bin/env python3 -u # Copyright (c) Facebook, Inc. and its affiliates. 
# # This source code is licensed under the MIT license found in the @@ -43,6 +42,13 @@ def main(cfg: DictConfig, override_args=None): if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_world_size > 1: + data_parallel_world_size = distributed_utils.get_data_parallel_world_size() + data_parallel_rank = distributed_utils.get_data_parallel_rank() + else: + data_parallel_world_size = 1 + data_parallel_rank = 0 + if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) @@ -91,8 +97,8 @@ def main(cfg: DictConfig, override_args=None): ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, seed=cfg.common.seed, - num_shards=cfg.distributed_training.distributed_world_size, - shard_id=cfg.distributed_training.distributed_rank, + num_shards=data_parallel_world_size, + shard_id=data_parallel_rank, num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) @@ -111,7 +117,7 @@ def main(cfg: DictConfig, override_args=None): progress.log(log_output, step=i) log_outputs.append(log_output) - if cfg.distributed_training.distributed_world_size > 1: + if data_parallel_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, max_size=cfg.common.all_gather_list_size, @@ -132,9 +138,13 @@ def cli_main(): # only override args that are explicitly given on the command line override_parser = options.get_validation_parser() - override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + override_args = options.parse_args_and_arch( + override_parser, suppress_defaults=True + ) - distributed_utils.call_main(convert_namespace_to_omegaconf(args), main, override_args=override_args) + distributed_utils.call_main( + convert_namespace_to_omegaconf(args), main, 
override_args=override_args + ) if __name__ == "__main__": From e0788f7007a8473a76db573985031f3c94201e79 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Wed, 17 Feb 2021 10:54:25 -0800 Subject: [PATCH 464/707] fix bart generation bug (#1629) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1629 Reviewed By: myleott Differential Revision: D26484942 Pulled By: sshleifer fbshipit-source-id: 9dcbab5c404c14d8f35628d823102ad9ce59dffd --- fairseq/models/bart/hub_interface.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 1ff170a782..2ddeb763a3 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -92,22 +92,27 @@ def generate( tokenized_sentences: List[torch.LongTensor], *args, inference_step_args=None, + skip_invalid_size_inputs=False, **kwargs ) -> List[List[Dict[str, torch.Tensor]]]: inference_step_args = inference_step_args or {} if "prefix_tokens" in inference_step_args: raise NotImplementedError("prefix generation not implemented for BART") - else: - bsz = len(tokenized_sentences) - inference_step_args["prefix_tokens"] = tokenized_sentences[0].new_full( - (bsz, 1), 
fill_value=self.task.source_dictionary.bos() + res = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + src_tokens = batch['net_input']['src_tokens'] + inference_step_args["prefix_tokens"] =src_tokens.new_full( + (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos() ).to(device=self.device) - return super().generate( - tokenized_sentences, - *args, - inference_step_args=inference_step_args, - **kwargs - ) + results = super().generate( + src_tokens, + *args, + inference_step_args=inference_step_args, + skip_invalid_size_inputs=skip_invalid_size_inputs, + **kwargs + ) + res.extend(results) + return res def extract_features( self, tokens: torch.LongTensor, return_all_hiddens: bool = False From 7040ce71f3e0e84730adc267df764f48dc483dac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Onur=20=C3=87elebi?= <celebio@fb.com> Date: Thu, 18 Feb 2021 03:09:14 -0800 Subject: [PATCH 465/707] LASER training code (#1207) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Integrating LASER (Language-Agnostic SEntence Representations) training code - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ N/A] Did you make sure to update the docs? - [ Y] Did you write any new necessary tests? => an additional test in `test_iterators.py` ## What does this PR do? This diff introduces the training code for LASER. It includes a specific `laser` task in `laser_task.py` which reads a json configuration file describing the binarized datasets of language pairs. `multitask_data_utils.py` defines dataset wrappers and iterators used by `laser` task. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Yes. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1207 Reviewed By: myleott Differential Revision: D26454296 Pulled By: Celebio fbshipit-source-id: c987672aa66abf31b039ee11867b06912d3486e5 --- examples/laser/README.md | 144 +++++ examples/laser/laser_src/__init__.py | 8 + examples/laser/laser_src/laser_lstm.py | 585 ++++++++++++++++++ examples/laser/laser_src/laser_task.py | 326 ++++++++++ examples/laser/laser_src/laser_transformer.py | 354 +++++++++++ .../laser/laser_src/multitask_data_utils.py | 143 +++++ tests/test_binaries.py | 60 ++ tests/utils.py | 38 ++ 8 files changed, 1658 insertions(+) create mode 100644 examples/laser/README.md create mode 100644 examples/laser/laser_src/__init__.py create mode 100644 examples/laser/laser_src/laser_lstm.py create mode 100644 examples/laser/laser_src/laser_task.py create mode 100644 examples/laser/laser_src/laser_transformer.py create mode 100644 examples/laser/laser_src/multitask_data_utils.py diff --git a/examples/laser/README.md b/examples/laser/README.md new file mode 100644 index 0000000000..66acada04f --- /dev/null +++ b/examples/laser/README.md @@ -0,0 +1,144 @@ +# LASER Language-Agnostic SEntence Representations + +LASER is a library to calculate and use multilingual sentence embeddings. + +You can find more information about LASER and how to use it on the official [LASER repository](https://github.com/facebookresearch/LASER). + +This folder contains source code for training LASER embeddings. + + +## Prepare data and configuration file + +Binarize your data with fairseq, as described [here](https://fairseq.readthedocs.io/en/latest/getting_started.html#data-pre-processing). 
+ +Create a json config file with this format: +``` +{ + "src_vocab": "/path/to/spm.src.cvocab", + "tgt_vocab": "/path/to/spm.tgt.cvocab", + "train": [ + { + "type": "translation", + "id": 0, + "src": "/path/to/srclang1-tgtlang0/train.srclang1", + "tgt": "/path/to/srclang1-tgtlang0/train.tgtlang0" + }, + { + "type": "translation", + "id": 1, + "src": "/path/to/srclang1-tgtlang1/train.srclang1", + "tgt": "/path/to/srclang1-tgtlang1/train.tgtlang1" + }, + { + "type": "translation", + "id": 0, + "src": "/path/to/srclang2-tgtlang0/train.srclang2", + "tgt": "/path/to/srclang2-tgtlang0/train.tgtlang0" + }, + { + "type": "translation", + "id": 1, + "src": "/path/to/srclang2-tgtlang1/train.srclang2", + "tgt": "/path/to/srclang2-tgtlang1/train.tgtlang1" + }, + ... + ], + "valid": [ + { + "type": "translation", + "id": 0, + "src": "/unused", + "tgt": "/unused" + } + ] +} +``` +where paths are paths to binarized indexed fairseq dataset files. +`id` represents the target language id. + + +## Training Command Line Example + +``` +fairseq-train \ + /path/to/configfile_described_above.json \ + --user-dir examples/laser/laser_src \ + --log-interval 100 --log-format simple \ + --task laser --arch laser_lstm \ + --save-dir . \ + --optimizer adam \ + --lr 0.001 \ + --lr-scheduler inverse_sqrt \ + --clip-norm 5 \ + --warmup-updates 90000 \ + --update-freq 2 \ + --dropout 0.0 \ + --encoder-dropout-out 0.1 \ + --max-tokens 2000 \ + --max-epoch 50 \ + --encoder-bidirectional \ + --encoder-layers 5 \ + --encoder-hidden-size 512 \ + --decoder-layers 1 \ + --decoder-hidden-size 2048 \ + --encoder-embed-dim 320 \ + --decoder-embed-dim 320 \ + --decoder-lang-embed-dim 32 \ + --warmup-init-lr 0.001 \ + --disable-validation +``` + + +## Applications + +We showcase several applications of multilingual sentence embeddings +with code to reproduce our results (in the directory "tasks"). 
+ +* [**Cross-lingual document classification**](https://github.com/facebookresearch/LASER/tree/master/tasks/mldoc) using the + [*MLDoc*](https://github.com/facebookresearch/MLDoc) corpus [2,6] +* [**WikiMatrix**](https://github.com/facebookresearch/LASER/tree/master/tasks/WikiMatrix) + Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia [7] +* [**Bitext mining**](https://github.com/facebookresearch/LASER/tree/master/tasks/bucc) using the + [*BUCC*](https://comparable.limsi.fr/bucc2018/bucc2018-task.html) corpus [3,5] +* [**Cross-lingual NLI**](https://github.com/facebookresearch/LASER/tree/master/tasks/xnli) + using the [*XNLI*](https://www.nyu.edu/projects/bowman/xnli/) corpus [4,5,6] +* [**Multilingual similarity search**](https://github.com/facebookresearch/LASER/tree/master/tasks/similarity) [1,6] +* [**Sentence embedding of text files**](https://github.com/facebookresearch/LASER/tree/master/tasks/embed) + an example of how to calculate sentence embeddings for arbitrary text files in any of the supported languages. + +**For all tasks, we use exactly the same multilingual encoder, without any task-specific optimization or fine-tuning.** + + + +## References + +[1] Holger Schwenk and Matthijs Douze, + [*Learning Joint Multilingual Sentence Representations with Neural Machine Translation*](https://aclanthology.info/papers/W17-2619/w17-2619), + ACL workshop on Representation Learning for NLP, 2017 + +[2] Holger Schwenk and Xian Li, + [*A Corpus for Multilingual Document Classification in Eight Languages*](http://www.lrec-conf.org/proceedings/lrec2018/pdf/658.pdf), + LREC, pages 3548-3551, 2018. + +[3] Holger Schwenk, + [*Filtering and Mining Parallel Data in a Joint Multilingual Space*](http://aclweb.org/anthology/P18-2037) + ACL, July 2018 + +[4] Alexis Conneau, Guillaume Lample, Ruty Rinott, Adina Williams, Samuel R. 
Bowman, Holger Schwenk and Veselin Stoyanov, + [*XNLI: Cross-lingual Sentence Understanding through Inference*](https://aclweb.org/anthology/D18-1269), + EMNLP, 2018. + +[5] Mikel Artetxe and Holger Schwenk, + [*Margin-based Parallel Corpus Mining with Multilingual Sentence Embeddings*](https://arxiv.org/abs/1811.01136) + arXiv, Nov 3 2018. + +[6] Mikel Artetxe and Holger Schwenk, + [*Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond*](https://arxiv.org/abs/1812.10464) + arXiv, Dec 26 2018. + +[7] Holger Schwenk, Vishrav Chaudhary, Shuo Sun, Hongyu Gong and Paco Guzman, + [*WikiMatrix: Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia*](https://arxiv.org/abs/1907.05791) + arXiv, July 11 2019. + +[8] Holger Schwenk, Guillaume Wenzek, Sergey Edunov, Edouard Grave and Armand Joulin + [*CCMatrix: Mining Billions of High-Quality Parallel Sentences on the WEB*](https://arxiv.org/abs/1911.04944) diff --git a/examples/laser/laser_src/__init__.py b/examples/laser/laser_src/__init__.py new file mode 100644 index 0000000000..9ffbd656d8 --- /dev/null +++ b/examples/laser/laser_src/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .laser_task import * # noqa +from .laser_lstm import * # noqa +from .laser_transformer import * # noqa diff --git a/examples/laser/laser_src/laser_lstm.py b/examples/laser/laser_src/laser_lstm.py new file mode 100644 index 0000000000..10df90e002 --- /dev/null +++ b/examples/laser/laser_src/laser_lstm.py @@ -0,0 +1,585 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils + +from fairseq.models import ( + FairseqEncoder, + FairseqIncrementalDecoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) + + +@register_model("laser_lstm") +class LSTMModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens=None, + tgt_tokens=None, + tgt_lengths=None, + target_language_id=None, + dataset_name="", + ): + assert target_language_id is not None + + src_encoder_out = self.encoder(src_tokens, src_lengths, dataset_name) + return self.decoder( + prev_output_tokens, src_encoder_out, lang_id=target_language_id + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", + default=0.1, + type=float, + metavar="D", + help="dropout probability", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-embed-path", + default=None, + type=str, + metavar="STR", + help="path to pre-trained encoder embedding", + ) + parser.add_argument( + "--encoder-hidden-size", type=int, metavar="N", help="encoder hidden size" + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="number of encoder layers" + ) + parser.add_argument( + "--encoder-bidirectional", + action="store_true", + help="make all layers of encoder bidirectional", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-embed-path", + default=None, + type=str, + metavar="STR", + help="path to pre-trained decoder embedding", + ) + parser.add_argument( + "--decoder-hidden-size", type=int, metavar="N", help="decoder hidden size" + ) + 
parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="number of decoder layers" + ) + parser.add_argument( + "--decoder-out-embed-dim", + type=int, + metavar="N", + help="decoder output embedding dimension", + ) + parser.add_argument( + "--decoder-zero-init", + type=str, + metavar="BOOL", + help="initialize the decoder hidden/cell state to zero", + ) + parser.add_argument( + "--decoder-lang-embed-dim", + type=int, + metavar="N", + help="decoder language embedding dimension", + ) + parser.add_argument( + "--fixed-embeddings", + action="store_true", + help="keep embeddings fixed (ENCODER ONLY)", + ) # TODO Also apply to decoder embeddings? + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument( + "--encoder-dropout-in", + type=float, + metavar="D", + help="dropout probability for encoder input embedding", + ) + parser.add_argument( + "--encoder-dropout-out", + type=float, + metavar="D", + help="dropout probability for encoder output", + ) + parser.add_argument( + "--decoder-dropout-in", + type=float, + metavar="D", + help="dropout probability for decoder input embedding", + ) + parser.add_argument( + "--decoder-dropout-out", + type=float, + metavar="D", + help="dropout probability for decoder output", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + pretrained_encoder_embed = None + if args.encoder_embed_path: + pretrained_encoder_embed = 
load_pretrained_embedding_from_file( + args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim + ) + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim + ) + + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + encoder = LSTMEncoder( + dictionary=task.source_dictionary, + embed_dim=args.encoder_embed_dim, + hidden_size=args.encoder_hidden_size, + num_layers=args.encoder_layers, + dropout_in=args.encoder_dropout_in, + dropout_out=args.encoder_dropout_out, + bidirectional=args.encoder_bidirectional, + pretrained_embed=pretrained_encoder_embed, + fixed_embeddings=args.fixed_embeddings, + ) + decoder = LSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + zero_init=options.eval_bool(args.decoder_zero_init), + encoder_embed_dim=args.encoder_embed_dim, + encoder_output_units=encoder.output_units, + pretrained_embed=pretrained_decoder_embed, + num_langs=num_langs, + lang_embed_dim=args.decoder_lang_embed_dim, + ) + return cls(encoder, decoder) + + +class LSTMEncoder(FairseqEncoder): + """LSTM encoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + left_pad=True, + pretrained_embed=None, + padding_value=0.0, + fixed_embeddings=False, + ): + super().__init__(dictionary) + self.num_layers = num_layers + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.bidirectional = bidirectional + self.hidden_size = hidden_size + + num_embeddings = len(dictionary) + self.padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = 
Embedding(num_embeddings, embed_dim, self.padding_idx) + else: + self.embed_tokens = pretrained_embed + if fixed_embeddings: + self.embed_tokens.weight.requires_grad = False + + self.lstm = LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=self.dropout_out if num_layers > 1 else 0.0, + bidirectional=bidirectional, + ) + self.left_pad = left_pad + self.padding_value = padding_value + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward(self, src_tokens, src_lengths, dataset_name): + if self.left_pad: + # convert left-padding to right-padding + src_tokens = utils.convert_padding_direction( + src_tokens, + self.padding_idx, + left_to_right=True, + ) + + bsz, seqlen = src_tokens.size() + + # embed tokens + x = self.embed_tokens(src_tokens) + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + try: + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + except BaseException: + raise Exception(f"Packing failed in dataset {dataset_name}") + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.data.new(*state_size).zero_() + c0 = x.data.new(*state_size).zero_() + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_value + ) + x = F.dropout(x, p=self.dropout_out, training=self.training) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + + def combine_bidir(outs): + return torch.cat( + [ + torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view( + 1, bsz, self.output_units + ) + for i in range(self.num_layers) + ], + dim=0, + ) + + final_hiddens = 
combine_bidir(final_hiddens) + final_cells = combine_bidir(final_cells) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() + + # Set padded outputs to -inf so they are not selected by max-pooling + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + return { + "sentemb": sentemb, + "encoder_out": (x, final_hiddens, final_cells), + "encoder_padding_mask": encoder_padding_mask + if encoder_padding_mask.any() + else None, + } + + def reorder_encoder_out(self, encoder_out_dict, new_order): + encoder_out_dict["sentemb"] = encoder_out_dict["sentemb"].index_select( + 0, new_order + ) + encoder_out_dict["encoder_out"] = tuple( + eo.index_select(1, new_order) for eo in encoder_out_dict["encoder_out"] + ) + if encoder_out_dict["encoder_padding_mask"] is not None: + encoder_out_dict["encoder_padding_mask"] = encoder_out_dict[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out_dict + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return int(1e5) # an arbitrary large number + + +class LSTMDecoder(FairseqIncrementalDecoder): + """LSTM decoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + zero_init=False, + encoder_embed_dim=512, + encoder_output_units=512, + pretrained_embed=None, + num_langs=1, + lang_embed_dim=0, + ): + super().__init__(dictionary) + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.hidden_size = hidden_size + + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.layers = 
nn.ModuleList( + [ + LSTMCell( + input_size=encoder_output_units + embed_dim + lang_embed_dim + if layer == 0 + else hidden_size, + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) + if hidden_size != out_embed_dim: + self.additional_fc = Linear(hidden_size, out_embed_dim) + self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) + + if zero_init: + self.sentemb2init = None + else: + self.sentemb2init = Linear( + encoder_output_units, 2 * num_layers * hidden_size + ) + + if lang_embed_dim == 0: + self.embed_lang = None + else: + self.embed_lang = nn.Embedding(num_langs, lang_embed_dim) + nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1) + + def forward( + self, prev_output_tokens, encoder_out_dict, incremental_state=None, lang_id=0 + ): + sentemb = encoder_out_dict["sentemb"] + encoder_out = encoder_out_dict["encoder_out"] + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + bsz, seqlen = prev_output_tokens.size() + + # get outputs from encoder + encoder_outs, _, _ = encoder_out[:3] + srclen = encoder_outs.size(0) + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # embed language identifier + if self.embed_lang is not None: + lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id) + langemb = self.embed_lang(lang_ids) + # TODO Should we dropout here??? 
+ + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # initialize previous states (or get from cache during incremental generation) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is not None: + prev_hiddens, prev_cells, input_feed = cached_state + else: + num_layers = len(self.layers) + if self.sentemb2init is None: + prev_hiddens = [ + x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers) + ] + prev_cells = [ + x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers) + ] + else: + init = self.sentemb2init(sentemb) + prev_hiddens = [ + init[:, (2 * i) * self.hidden_size : (2 * i + 1) * self.hidden_size] + for i in range(num_layers) + ] + prev_cells = [ + init[ + :, + (2 * i + 1) * self.hidden_size : (2 * i + 2) * self.hidden_size, + ] + for i in range(num_layers) + ] + input_feed = x.data.new(bsz, self.hidden_size).zero_() + + attn_scores = x.data.new(srclen, seqlen, bsz).zero_() + outs = [] + for j in range(seqlen): + if self.embed_lang is None: + input = torch.cat((x[j, :, :], sentemb), dim=1) + else: + input = torch.cat((x[j, :, :], sentemb, langemb), dim=1) + + for i, rnn in enumerate(self.layers): + # recurrent cell + hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) + + # hidden state becomes the input to the next layer + input = F.dropout(hidden, p=self.dropout_out, training=self.training) + + # save state for next time step + prev_hiddens[i] = hidden + prev_cells[i] = cell + + out = hidden + out = F.dropout(out, p=self.dropout_out, training=self.training) + + # input feeding + input_feed = out + + # save final output + outs.append(out) + + # cache previous states (no-op except during incremental generation) + utils.set_incremental_state( + self, + incremental_state, + "cached_state", + (prev_hiddens, prev_cells, input_feed), + ) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) + + # T x B x C -> B x 
T x C + x = x.transpose(1, 0) + + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen + attn_scores = attn_scores.transpose(0, 2) + + # project back to size of vocabulary + if hasattr(self, "additional_fc"): + x = self.additional_fc(x) + x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.fc_out(x) + + return x, attn_scores + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is None: + return + + def reorder_state(state): + if isinstance(state, list): + return [reorder_state(state_i) for state_i in state] + return state.index_select(0, new_order) + + new_state = tuple(map(reorder_state, cached_state)) + utils.set_incremental_state(self, incremental_state, "cached_state", new_state) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return int(1e5) # an arbitrary large number + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.uniform_(m.weight, -0.1, 0.1) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def LSTM(input_size, hidden_size, **kwargs): + m = nn.LSTM(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def LSTMCell(input_size, hidden_size, **kwargs): + m = nn.LSTMCell(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def Linear(in_features, out_features, bias=True, dropout=0): + """Weight-normalized Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.uniform_(-0.1, 0.1) + if bias: + m.bias.data.uniform_(-0.1, 0.1) + return 
m + + +@register_model_architecture("laser_lstm", "laser_lstm") +def base_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_hidden_size = getattr( + args, "encoder_hidden_size", args.encoder_embed_dim + ) + args.encoder_layers = getattr(args, "encoder_layers", 1) + args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.decoder_zero_init = getattr(args, "decoder_zero_init", "0") + args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0) + args.fixed_embeddings = getattr(args, "fixed_embeddings", False) diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py new file mode 100644 index 0000000000..c8ac805f54 --- /dev/null +++ b/examples/laser/laser_src/laser_task.py @@ -0,0 +1,326 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from collections import OrderedDict, defaultdict +import json +import os +import logging + +from fairseq import options, models +from fairseq.data import ( + data_utils, + Dictionary, + LanguagePairDataset, + IndexedDataset, + FairseqDataset, +) +from .multitask_data_utils import ( + MultitaskDatasetWrapper, + MultidatasetEpochBatchIterator, +) + + +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("laser") +class LaserTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "configfile", metavar="PATH", help="dataset configuration file in json" + ) + parser.add_argument( + "--weighting-alpha", + type=float, + default=None, + help="alpha for automatic weighting", + ) + parser.add_argument( + "--raw-text", action="store_true", help="load raw text dataset" + ) + parser.add_argument( + "--left-pad-source", + default="True", + type=str, + metavar="BOOL", + help="pad the source on the left (default: True)", + ) + parser.add_argument( + "--left-pad-target", + default="False", + type=str, + metavar="BOOL", + help="pad the target on the left (default: False)", + ) + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks): + super().__init__(args) + self.config = config + self.src_dictionary = src_dictionary + self.tgt_dictionary = tgt_dictionary + self.num_tasks = num_tasks + + @classmethod + def setup_task(cls, args, **kwargs): + with open(args.configfile, "r") as f: + config = json.load(f) + num_tasks = max(dataset["id"] for dataset in config["train"]) + 1 + + args.left_pad_source = 
options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + src_dictionary = Dictionary.load(config["src_vocab"]) + tgt_dictionary = Dictionary.load(config["tgt_vocab"]) + + logger.info( + "| src Dictionary {} : {} types".format( + config["src_vocab"], len(src_dictionary) + ) + ) + logger.info( + "| tgt Dictionary {} : {} types".format( + config["tgt_vocab"], len(tgt_dictionary) + ) + ) + + return cls(args, config, src_dictionary, tgt_dictionary, num_tasks) + + # Experimental overriding for backtranslation + def build_model(self, args): + model = models.build_model(args, self) + return model + + def dataset(self, split): + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + return self.datasets[split] + + def load_dataset(self, split, epoch=1, **kwargs): + """Load a dataset split.""" + + def indexed_dataset(path, dictionary): + if self.args.raw_text: + raise Exception("Unable to handle raw text.") + dataset = IndexedDataset(path, fix_lua_indexing=True) + + return dataset + + pair_datasets = OrderedDict() + + if split == "valid": + self.datasets[split] = pair_datasets + return + + if split not in self.config: + raise FileNotFoundError( + "Dataset not found in config file: {}".format(split) + ) + + size_by_corpus = defaultdict(int) + size_sum = 0 + size_sum_with_subsampling = 0 + init_pair_datasets = {} + + for dataset_config in self.config[split]: + src_path = os.path.dirname(dataset_config["src"]) + corpus_name = src_path.split("/")[-2] + language_pair_name = src_path.split("/")[-1] + pair_datasets_key = corpus_name + "-" + language_pair_name + + logger.info(f"loading... 
{pair_datasets_key}") + if "src" in dataset_config: + src_dataset = indexed_dataset( + dataset_config["src"], self.src_dictionary + ) + else: + src_dataset = None + + if "tgt" in dataset_config: + tgt_dataset = indexed_dataset( + dataset_config["tgt"], self.tgt_dictionary + ) + else: + tgt_dataset = None + + dataset = LanguagePairDataset( + src_dataset, + src_dataset.sizes, + self.src_dictionary, + tgt_dataset, + tgt_dataset.sizes, + self.tgt_dictionary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ) + + if pair_datasets_key in init_pair_datasets: + logger.warning( + f"Ignoring already added {pair_datasets_key}. " + f"Consider using `sample` key in order to upsample." + ) + else: + init_pair_datasets[pair_datasets_key] = { + "dataset": dataset, + "sample": dataset_config.get("sample", None), + "id": dataset_config.get("id", None), + "len": len(dataset), + } + + length_sum = 0 + weighted_freqs_sum = 0 + freq_per_dataset = {} + vmax = 0 + vmin = 1 + weighted_freq_per_dataset = {} + + if self.args.weighting_alpha: + for key in init_pair_datasets: + if init_pair_datasets[key]["sample"] is None: + length_sum += len(init_pair_datasets[key]["dataset"]) + + for key in init_pair_datasets: + if init_pair_datasets[key]["sample"] is None: + val = float(init_pair_datasets[key]["len"]) / length_sum + freq_per_dataset[key] = val + weighted_freqs_sum += val ** self.args.weighting_alpha + + for key in freq_per_dataset: + val = ( + freq_per_dataset[key] ** self.args.weighting_alpha + / weighted_freqs_sum + ) + vmin = min(vmin, val) + vmax = max(vmax, val) + weighted_freq_per_dataset[key] = val + + for pair_datasets_key in init_pair_datasets: + dataset_config = init_pair_datasets[pair_datasets_key] + dataset = dataset_config["dataset"] + sample = dataset_config["sample"] + if sample is None: + sample = 1.0 + + if pair_datasets_key in weighted_freq_per_dataset: + w = vmax / weighted_freq_per_dataset[pair_datasets_key] + sample = w + + 
sample = round(sample) + + initial_sample = sample + initial_pair_datasets_key = pair_datasets_key + + while sample >= 1.0: + assert ( + pair_datasets_key not in pair_datasets + ), f"{pair_datasets_key} already in" + size_sum_with_subsampling += len(dataset) + pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper( + dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key + ) + size_sum += len(dataset) + sample -= 1.0 + pair_datasets_key += "-up" + + assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}" + + logger.info( + f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}" + ) + size_by_corpus[corpus_name] += len(dataset) + + self.datasets[split] = pair_datasets + logger.info( + f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}" + ) + + @property + def source_dictionary(self): + return self.src_dictionary + + @property + def target_dictionary(self): + return self.tgt_dictionary + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + ): + + assert isinstance(dataset, OrderedDict) + assert len(dataset) + assert isinstance(dataset[next(iter(dataset))], FairseqDataset) + + # initialize the dataset with the correct starting epoch + for _, dt in dataset.items(): + dt.set_epoch(epoch) + + indices = OrderedDict() + batch_sampler = OrderedDict() + + with data_utils.numpy_seed(seed + epoch): + for key, dt in dataset.items(): + logger.info(f"\t ordered_indices {key}") + indices[key] = dt.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + for key, dt in dataset.items(): + logger.info(f"\t filter_by_size {key}") + indices[key], ignored = 
dt.filter_indices_by_size( + indices[key], max_positions + ) + + for key, dt in dataset.items(): + logger.info(f"\t batch_by_size {key}") + batch_sampler[key] = data_utils.batch_by_size( + indices[key], + dt.num_tokens, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + epoch_iter = MultidatasetEpochBatchIterator( + dataset=dataset, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + ) + + return epoch_iter diff --git a/examples/laser/laser_src/laser_transformer.py b/examples/laser/laser_src/laser_transformer.py new file mode 100644 index 0000000000..0be030994f --- /dev/null +++ b/examples/laser/laser_src/laser_transformer.py @@ -0,0 +1,354 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import Any, Dict, List, Optional +from torch import Tensor + +import torch +import torch.nn as nn + +from fairseq.models import ( + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + base_architecture, + Embedding, + TransformerModel, + TransformerEncoder, + TransformerDecoder, +) +from fairseq.modules import ( + TransformerDecoderLayer, +) + +logger = logging.getLogger(__name__) + + +@register_model("laser_transformer") +class LaserTransformerModel(FairseqEncoderDecoderModel): + """Train Transformer for LASER task + + Requires --task laser + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens=None, + tgt_tokens=None, + tgt_lengths=None, + target_language_id=-1, + dataset_name="", + ): + laser_encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder( + 
prev_output_tokens, laser_encoder_out, lang_id=target_language_id + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + TransformerModel.add_args(parser) + parser.add_argument( + "--decoder-lang-embed-dim", + type=int, + metavar="N", + help="decoder language embedding dimension", + ) + + @classmethod + def build_model(cls, args, task): + base_laser_transformer_architecture(args) + + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + def load_embed_tokens(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + return Embedding(num_embeddings, embed_dim, padding_idx) + + encoder_embed_tokens = load_embed_tokens( + task.source_dictionary, args.encoder_embed_dim + ) + decoder_embed_tokens = load_embed_tokens( + task.target_dictionary, args.decoder_embed_dim + ) + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + encoder = LaserTransformerEncoder( + args, task.source_dictionary, encoder_embed_tokens + ) + + decoder = LaserTransformerDecoder( + args, + task.target_dictionary, + decoder_embed_tokens, + num_langs=num_langs, + lang_embed_dim=args.decoder_lang_embed_dim, + ) + + return cls(encoder, decoder) + + +class LaserTransformerEncoder(TransformerEncoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, src_tokens, *args, **kwargs): + encoder_out = super().forward(src_tokens, *args, **kwargs) + + x = encoder_out["encoder_out"][0] # T x B x C + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + # The PyTorch Mobile lite interpreter does not support returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. 
+ # The empty list is equivalent to None. + return {"sentemb": [sentemb]} # B x C + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Same as the one in transformer.py, with new_sentemb + """ + if len(encoder_out["sentemb"]) == 0: + new_sentemb = [] + else: + new_sentemb = [encoder_out["sentemb"][0].index_select(0, new_order)] + + return { + "sentemb": new_sentemb, # B x C + } + + +class LaserTransformerDecoder(TransformerDecoder): + def __init__(self, args, dictionary, *kargs, **kwargs): + self.num_langs = kwargs.get("num_langs", 1) + self.lang_embed_dim = kwargs.get("lang_embed_dim", 0) + kwargs.pop("num_langs", None) + kwargs.pop("lang_embed_dim", None) + + super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True) + + if self.lang_embed_dim == 0: + self.embed_lang = None + else: + self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim) + nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1) + + if self.output_projection is not None: + laser_output_embed_dim = ( + self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim + ) + self.output_projection = nn.Linear( + laser_output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=laser_output_embed_dim ** -0.5, + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + decoder_embed_dim = args.decoder_embed_dim + args.decoder_embed_dim = ( + decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim + ) + res = TransformerDecoderLayer(args, no_encoder_attn=True) + args.decoder_embed_dim = decoder_embed_dim + + return res + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + lang_id: 
Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + bsz, seqlen = prev_output_tokens.size() + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + if self.embed_lang is not None: + lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id) + langemb = self.embed_lang(lang_ids) + langemb = langemb.unsqueeze(0) + repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * ( + len(langemb.shape) - 1 + ) + x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1) + + sentemb = encoder_out["sentemb"][0] + sentemb = sentemb.unsqueeze(0) + + repeat_vals = [x.shape[0] // 
sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1) + x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + None, + None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + lang_id: Optional[int] = None, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + 
incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + assert lang_id is not None + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + lang_id=lang_id, + ) + if not features_only: + x = self.output_layer(x) + return x, extra + + +@register_model_architecture("laser_transformer", "laser_transformer") +def base_laser_transformer_architecture(args): + base_architecture(args) + args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0) diff --git a/examples/laser/laser_src/multitask_data_utils.py b/examples/laser/laser_src/multitask_data_utils.py new file mode 100644 index 0000000000..b05caea267 --- /dev/null +++ b/examples/laser/laser_src/multitask_data_utils.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections import OrderedDict + +import numpy as np + +from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators + + +class MultiItr(object): + def __init__(self, itr): + self.itr = itr + self._counts = [0 for x in itr] + + def __len__(self): + return sum(len(itr) for itr in self.itr) + + def __iter__(self): + return self + + def __next__(self): + ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)] + idx = ratios.index(min(ratios)) + self._counts[idx] += 1 + return next(self.itr[idx]) + + +class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating): + """A wrapper around multiple epoch batch iterators.""" + + def __init__( + self, + dataset, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + ): + + assert isinstance(dataset, OrderedDict) + assert len(dataset) + assert isinstance(dataset[next(iter(dataset))], FairseqDataset) + + self.iterators = [] + + self.epoch = epoch + for key, dt in dataset.items(): + epoch_iter = iterators.EpochBatchIterator( + dataset=dt, + collate_fn=dt.collater, + batch_sampler=batch_sampler[key], + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=0, + epoch=epoch, + ) + self.iterators.append(epoch_iter) + + def __len__(self): + return sum(len(itr) for itr in self.iterators) + + def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s. 
+ return MultiItr( + [ + itr.next_epoch_itr( + shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus + ) + for itr in self.iterators + ] + ) + + def end_of_epoch(self): + return all(itr.end_of_epoch() for itr in self.iterators) + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + + epochs = [itr.next_epoch_idx for itr in self.iterators] + self.epoch = epochs[0] + assert all(epoch == self.epoch for epoch in epochs) + + return self.epoch + + @property + def iterations_in_epoch(self): + return sum(itr.iterations_in_epoch for itr in self.iterators) + + def state_dict(self): + return { + "iterators": [it.state_dict() for it in self.iterators], + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + for it, d in zip(self.iterators, state_dict["iterators"]): + it.load_state_dict(d) + + +class MultitaskDatasetWrapper(BaseWrapperDataset): + """A wrapper for a multitask dataset.""" + + def __init__(self, dataset, target_language_id, sample=1.0, name=""): + super().__init__(dataset) + self.target_language_id = target_language_id + self.sample = sample + self.name = name + + def collater(self, *args, **kwargs): + ans = self.dataset.collater(*args, **kwargs) + if "net_input" in ans: + ans["net_input"]["target_language_id"] = self.target_language_id + ans["net_input"]["dataset_name"] = self.name + return ans + + def num_tokens(self, *args, **kwargs): + return self.dataset.num_tokens(*args, **kwargs) + + def ordered_indices(self, *args, **kwargs): + indices = self.dataset.ordered_indices(*args, **kwargs) + # Hacky solution for sampling + size = int(self.sample * indices.shape[0]) + + return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size])) + + def size(self, index: int): + return self.dataset.size(index) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return getattr(self.dataset, "supports_prefetch", 
False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 981ffd49cd..3cb98897bf 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -22,6 +22,7 @@ preprocess_lm_data, preprocess_summarization_data, preprocess_translation_data, + create_laser_data_and_config_json, train_translation_model, ) @@ -935,6 +936,65 @@ def test_alignment(self): ) generate_main(data_dir) + def test_laser_lstm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_lstm") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_lstm", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-bidirectional", + "--encoder-hidden-size", + "512", + "--encoder-layers", + "5", + "--decoder-layers", + "1", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + + def test_laser_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_transformer") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_transformer", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + def test_alignment_full_context(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_alignment") as data_dir: diff --git a/tests/utils.py b/tests/utils.py index 178df5763e..1bf6f8d7f3 100644 --- a/tests/utils.py +++ 
b/tests/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import json import os import random import sys @@ -274,6 +275,43 @@ def preprocess_summarization_data(data_dir, extra_flags=None): preprocess.main(preprocess_args) +def create_laser_data_and_config_json(data_dir): + src_langs = ["de", "fr", "ru", "tr", "zh"] + tgt_langs = ["en", "es"] + config_json = {} + config_train_json = [] + src_vocab = None + tgt_vocab = None + + for src_lang in src_langs: + for tgt_lang in tgt_langs: + langpair_folder = f"{src_lang}-{tgt_lang}" + + langpair_path = os.path.join(data_dir, langpair_folder) + os.mkdir(langpair_path) + create_dummy_data(langpair_path) + preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"]) + + src_vocab = os.path.join(langpair_path, "dict.in.txt") + tgt_vocab = os.path.join(langpair_path, "dict.out.txt") + config_train_json.append( + { + "id": 0 if tgt_lang == "en" else 1, + "src": os.path.join(langpair_path, "train.in-out.in"), + "tgt": os.path.join(langpair_path, "train.in-out.out"), + } + ) + + config_json["src_vocab"] = src_vocab + config_json["tgt_vocab"] = tgt_vocab + config_json["train"] = config_train_json + + with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file: + json.dump(config_json, config_file) + + return config_file + + def train_translation_model( data_dir, arch, From 3bc43c17d14c4b9f6b052a915f9589cd538bc8b6 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 18 Feb 2021 13:10:02 -0800 Subject: [PATCH 466/707] Fix speed regression after RobertaEncoder refactor (#1626) Summary: Add back a couple speed optimizations in the original roberta code that got lost in the refactor Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1626 Reviewed By: gwenzek Differential Revision: D26478534 Pulled By: myleott fbshipit-source-id: b945de5e9bffd51cd63630cc3aa1f0078a41cca8 --- fairseq/models/transformer.py | 9 ++++++--- 
fairseq/modules/transformer_layer.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 4960fd143d..605cfa65e8 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -434,10 +434,11 @@ def forward( hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ - x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) - # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) + has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) + + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # account for padding while computing the representation if encoder_padding_mask is not None: @@ -453,7 +454,9 @@ def forward( # encoder layers for layer in self.layers: - x = layer(x, encoder_padding_mask) + x = layer( + x, encoder_padding_mask=encoder_padding_mask if has_pads else None + ) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 03e70f4279..f9ada37bde 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -103,7 +103,7 @@ def upgrade_state_dict_named(self, state_dict, name): state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] del state_dict[k] - def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): + def forward(self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor] = None): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` @@ -135,6 +135,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): key=x, value=x, key_padding_mask=encoder_padding_mask, + need_weights=False, attn_mask=attn_mask, ) x = self.dropout_module(x) From 
da9eaba12d82b9bfc1442f0e2c6fc1b895f4d35d Mon Sep 17 00:00:00 2001 From: Elizabeth Salesky <elizabeth.salesky@gmail.com> Date: Thu, 18 Feb 2021 13:58:56 -0800 Subject: [PATCH 467/707] Add support for multi-channel audio and example for mTEDx data (#3253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? - updates audio_utils to handle multi-channel audio as well as mono, with no change needed for existing recipes - adds speech-to-text example for Multilingual TEDx (http://openslr.org/100) data ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3253 Reviewed By: yuntang Differential Revision: D26514419 Pulled By: kahne fbshipit-source-id: 699e428affda5b1347f96a8310691ab152dd6769 --- examples/speech_to_text/README.md | 2 + examples/speech_to_text/docs/mtedx_example.md | 200 +++++++++++++++ examples/speech_to_text/prep_mtedx_data.py | 235 ++++++++++++++++++ fairseq/data/audio/audio_utils.py | 10 +- .../models/speech_to_text/s2t_transformer.py | 9 + 5 files changed, 455 insertions(+), 1 deletion(-) create mode 100644 examples/speech_to_text/docs/mtedx_example.md create mode 100644 examples/speech_to_text/prep_mtedx_data.py diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 4b6f89d105..988ed83d77 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -36,6 +36,8 @@ audio paths (one per line) as inputs. - [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md) +- [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md) + ## Updates - 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples: [ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding) diff --git a/examples/speech_to_text/docs/mtedx_example.md b/examples/speech_to_text/docs/mtedx_example.md new file mode 100644 index 0000000000..c0e17db9a2 --- /dev/null +++ b/examples/speech_to_text/docs/mtedx_example.md @@ -0,0 +1,200 @@ +[[Back]](..) + +# S2T Example: Speech Translation (ST) on Multilingual TEDx + +[Multilingual TEDx](https://arxiv.org/abs/2102.01757) is a multilingual corpus for speech recognition and +speech translation. The data is derived from TEDx talks in 8 source languages +with translations to a subset of 5 target languages. 
+ +## Data Preparation +[Download](http://openslr.org/100/) and unpack Multilingual TEDx data to a path +`${MTEDX_ROOT}/${LANG_PAIR}`, then preprocess it with +```bash +# additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +# Generate TSV manifests, features, vocabulary +# and configuration for each language +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task asr \ + --vocab-type unigram --vocab-size 1000 +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task st \ + --vocab-type unigram --vocab-size 1000 + +# Add vocabulary and configuration for joint data +# (based on the manifests and features generated above) +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task asr --joint \ + --vocab-type unigram --vocab-size 8000 +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task st --joint \ + --vocab-type unigram --vocab-size 8000 +``` +The generated files (manifest, features, vocabulary and data configuration) will be added to +`${MTEDX_ROOT}/${LANG_PAIR}` (per-language data) and `MTEDX_ROOT` (joint data). 
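Before running the preprocessing commands, it may help to confirm that the unpacked data has the layout the script expects. The sketch below is not part of fairseq: `missing_mtedx_dirs` is a hypothetical helper, and it only assumes the `data/<split>/{wav,txt}` directory layout that `prep_mtedx_data.py` asserts on for each language pair.

```python
from pathlib import Path
from typing import Dict, List

# Split names used by the preparation script.
SPLITS = ["train", "valid", "test"]


def missing_mtedx_dirs(root: str, lang_pairs: List[str]) -> Dict[str, List[str]]:
    """Map each language pair to the expected sub-directories that are absent.

    Hypothetical helper for illustration only; it mirrors the directory
    checks done by the dataset class but does not read any audio or text.
    """
    missing: Dict[str, List[str]] = {}
    for pair in lang_pairs:
        for split in SPLITS:
            base = Path(root) / pair / "data" / split
            for sub in ("wav", "txt"):
                if not (base / sub).is_dir():
                    missing.setdefault(pair, []).append(f"data/{split}/{sub}")
    return missing
```

An empty result for a language pair means all six directories are present; any listed entry points at a download or unpacking step that probably needs to be repeated.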
+ + +## ASR +#### Training +Spanish as example: +```bash +fairseq-train ${MTEDX_ROOT}/es-es \ + --config-yaml config_asr.yaml --train-subset train_asr --valid-subset valid_asr \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 +``` +For joint model (using ASR data from all 8 languages): +```bash +fairseq-train ${MTEDX_ROOT} \ + --config-yaml config_asr.yaml \ + --train-subset train_es-es_asr,train_fr-fr_asr,train_pt-pt_asr,train_it-it_asr,train_ru-ru_asr,train_el-el_asr,train_ar-ar_asr,train_de-de_asr \ + --valid-subset valid_es-es_asr,valid_fr-fr_asr,valid_pt-pt_asr,valid_it-it_asr,valid_ru-ru_asr,valid_el-el_asr,valid_ar-ar_asr,valid_de-de_asr \ + --save-dir ${MULTILINGUAL_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 \ + --ignore-prefix-size 1 +``` +where `MULTILINGUAL_ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs +with 1 GPU. You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend target language ID token as target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. 
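The note on `--update-freq` can be turned into a small calculation. This is a rough sketch under the assumption that the effective batch per update scales with (number of GPUs) x (`--update-freq`), so the product should stay at 8 to match the recipe; `update_freq_for` is a hypothetical helper, not a fairseq function.

```python
def update_freq_for(num_gpus: int, simulated_gpus: int = 8) -> int:
    """Return an --update-freq value that keeps num_gpus * update_freq constant.

    Hypothetical helper: `simulated_gpus` is the number of GPUs the recipe
    emulates (8 in the commands above).
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if simulated_gpus % num_gpus != 0:
        raise ValueError("num_gpus should evenly divide simulated_gpus")
    return simulated_gpus // num_gpus


if __name__ == "__main__":
    for n in (1, 2, 4, 8):
        print(f"{n} GPU(s): --update-freq {update_freq_for(n)}")
```

With 2 GPUs, for instance, this suggests `--update-freq 4`; GPU counts that do not divide 8 would need a different `--max-tokens` or an approximate setting, which the sketch deliberately refuses to guess.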
+ +#### Inference & Evaluation +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${MTEDX_ROOT}/es-es \ + --config-yaml config_asr.yaml --gen-subset test --task speech_to_text \ + --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe + +# For models trained on joint data +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${MULTILINGUAL_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +for LANG in es fr pt it ru el ar de; do + fairseq-generate ${MTEDX_ROOT} \ + --config-yaml config_asr.yaml --gen-subset test_${LANG}-${LANG}_asr --task speech_to_text \ + --prefix-size 1 --path ${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe +done +``` +#### Results +| Data | --arch | Params | Es | Fr | Pt | It | Ru | El | Ar | De | +|--------------|--------------------|--------|------|------|------|------|------|-------|-------|-------| +| Monolingual | s2t_transformer_xs | 10M | 46.4 | 45.6 | 54.8 | 48.0 | 74.7 | 109.5 | 104.4 | 111.1 | + + +## ST +#### Training +Es-En as example: +```bash +fairseq-train ${MTEDX_ROOT}/es-en \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset valid_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 
1 --dropout 0.3 --label-smoothing 0.1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 +``` +For multilingual model (all 12 directions): +```bash +fairseq-train ${MTEDX_ROOT} \ + --config-yaml config_st.yaml \ + --train-subset train_el-en_st,train_es-en_st,train_es-fr_st,train_es-it_st,train_es-pt_st,train_fr-en_st,train_fr-es_st,train_fr-pt_st,train_it-en_st,train_it-es_st,train_pt-en_st,train_pt-es_st,train_ru-en_st \ + --valid-subset valid_el-en_st,valid_es-en_st,valid_es-fr_st,valid_es-it_st,valid_es-pt_st,valid_fr-en_st,valid_fr-es_st,valid_fr-pt_st,valid_it-en_st,valid_it-es_st,valid_pt-en_st,valid_pt-es_st,valid_ru-en_st \ + --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 \ + --ignore-prefix-size 1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} +``` +where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR +for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set +`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend target language ID token as target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. 
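To make the effect of `--ignore-prefix-size 1` concrete, here is a minimal sketch in plain Python. It is an illustration only, not fairseq's criterion code: `nll_ignoring_prefix` is a hypothetical function, and the per-token log-probabilities are made-up numbers for a target that starts with a `<lang:en>` tag.

```python
import math
from typing import Sequence


def nll_ignoring_prefix(token_log_probs: Sequence[float], prefix_size: int = 1) -> float:
    """Sum the negative log-likelihood over all target positions except
    the first `prefix_size` ones (e.g. the prepended language ID token)."""
    return -sum(token_log_probs[prefix_size:])


if __name__ == "__main__":
    # Illustrative log-probabilities for the target ["<lang:en>", "hello", "world"].
    log_probs = [math.log(0.5), math.log(0.25), math.log(0.5)]
    print(nll_ignoring_prefix(log_probs, prefix_size=0))  # language tag counted
    print(nll_ignoring_prefix(log_probs, prefix_size=1))  # language tag excluded
```

The language tag is always given to the decoder rather than predicted, so counting it would only add a constant-like term to the loss; dropping the first position keeps the objective focused on the actual transcript or translation tokens.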
+ +#### Inference & Evaluation +Average the last 10 checkpoints and evaluate on the `test` split: +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${MTEDX_ROOT}/es-en \ + --config-yaml config_st.yaml --gen-subset test --task speech_to_text \ + --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 --scoring sacrebleu --remove-bpe + +# For multilingual models +python scripts/average_checkpoints.py \ + --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +for LANGPAIR in es-en es-fr es-pt fr-en fr-es fr-pt pt-en pt-es it-en it-es ru-en el-en; do + fairseq-generate ${MTEDX_ROOT} \ + --config-yaml config_st.yaml --gen-subset test_${LANGPAIR}_st --task speech_to_text \ + --prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring sacrebleu --remove-bpe +done +``` +For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`. 
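Forced-prefix decoding can be pictured with a toy example. The sketch below is not fairseq's sequence generator: `greedy_decode` and `toy_next_token` are hypothetical names, and the "model" is a lookup table. It only shows that the first token is copied from the given prefix (the language tag) and everything after it is predicted.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    next_token: Callable[[Sequence[str]], str],
    forced_prefix: Sequence[str],
    eos: str = "</s>",
    max_len: int = 10,
) -> List[str]:
    """Greedy decoding where the first tokens are forced rather than predicted."""
    hyp = list(forced_prefix)  # forced tokens are copied as-is
    while len(hyp) < max_len:
        tok = next_token(hyp)
        hyp.append(tok)
        if tok == eos:
            break
    return hyp


def toy_next_token(history: Sequence[str]) -> str:
    """Toy stand-in for a multilingual model: the output depends on the tag."""
    table = {
        ("<lang:en>",): "hello",
        ("<lang:es>",): "hola",
    }
    return table.get(tuple(history), "</s>")


if __name__ == "__main__":
    print(greedy_decode(toy_next_token, ["<lang:en>"]))
    print(greedy_decode(toy_next_token, ["<lang:es>"]))
```

The same toy model yields different output depending only on the forced tag, which is the behavior the multilingual evaluation loop relies on when it selects the target language per test subset.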
+ +#### Results +| Data | --arch | Params | Es-En | Es-Pt | Es-Fr | Fr-En | Fr-Es | Fr-Pt | Pt-En | Pt-Es | It-En | It-Es | Ru-En | El-En | +|--------------|--------------------|-----|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| +| Bilingual | s2t_transformer_xs | 10M | 7.0 | 12.2 | 1.7 | 8.9 | 10.6 | 7.9 | 8.1 | 8.7 | 6.4 | 1.0 | 0.7 | 0.6 | +| Multilingual | s2t_transformer_s | 31M | 12.3 | 17.4 | 6.1 | 12.0 | 13.6 | 13.2 | 12.0 | 13.7 | 10.7 | 13.1 | 0.6 | 0.8 | + + +## Citation +Please cite as: +``` +@misc{salesky2021mtedx, + title={Multilingual TEDx Corpus for Speech Recognition and Translation}, + author={Elizabeth Salesky and Matthew Wiesner and Jacob Bremerman and Roldano Cattoni and Matteo Negri and Marco Turchi and Douglas W. Oard and Matt Post}, + year={2021}, +} + +@inproceedings{wang2020fairseqs2t, + title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq}, + author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino}, + booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations}, + year = {2020}, +} + +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` + +[[Back]](..) diff --git a/examples/speech_to_text/prep_mtedx_data.py b/examples/speech_to_text/prep_mtedx_data.py new file mode 100644 index 0000000000..6c37398fcc --- /dev/null +++ b/examples/speech_to_text/prep_mtedx_data.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +import os +from pathlib import Path +import shutil +from itertools import groupby +from tempfile import NamedTemporaryFile +from typing import Tuple + +import pandas as pd +import torchaudio +from examples.speech_to_text.data_utils import ( + create_zip, + extract_fbank_features, + filter_manifest_df, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + load_df_from_tsv, + save_df_to_tsv, +) +from torch import Tensor +from torch.utils.data import Dataset +from tqdm import tqdm + + +log = logging.getLogger(__name__) + + +MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"] + + +class mTEDx(Dataset): + """ + Create a Dataset for Multilingual TEDx. + Each item is a tuple of the form: waveform, sample_rate, source utterance, + target utterance, speaker_id, target language, utterance_id + """ + + SPLITS = ["train", "valid", "test"] + LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar", "de-de", + "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es", "fr-pt", + "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"] + + def __init__(self, root: str, lang: str, split: str) -> None: + assert split in self.SPLITS and lang in self.LANGPAIRS + _root = Path(root) / f"{lang}" / "data" / split + wav_root, txt_root = _root / "wav", _root / "txt" + assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir() + # Load audio segments + try: + import yaml + except ImportError: + raise ImportError("Please install PyYAML to load the Multilingual TEDx YAML files") + with open(txt_root / f"{split}.yaml") as f: + segments = yaml.load(f, Loader=yaml.BaseLoader) + # Load source and target utterances + src, tgt = lang.split("-") + for _lang in [src, tgt]: + with open(txt_root / f"{split}.{_lang}") as f: + utterances = [r.strip() for r in f] + assert len(segments) == len(utterances) + for i, u in enumerate(utterances): + segments[i][_lang] = u + # Gather info + self.data = [] + for wav_filename, _seg_group in groupby(segments, 
lambda x: x["wav"]): + wav_filename = wav_filename.replace(".wav", ".flac") + wav_path = wav_root / wav_filename + sample_rate = torchaudio.info(wav_path.as_posix())[0].rate + seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) + for i, segment in enumerate(seg_group): + offset = int(float(segment["offset"]) * sample_rate) + n_frames = int(float(segment["duration"]) * sample_rate) + _id = f"{wav_path.stem}_{i}" + self.data.append( + ( + wav_path.as_posix(), + offset, + n_frames, + sample_rate, + segment[src], + segment[tgt], + segment["speaker_id"], + tgt, + _id, + ) + ) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str, str]: + wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = self.data[n] + waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) + return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id + + def __len__(self) -> int: + return len(self.data) + + +def process(args): + root = Path(args.data_root).absolute() + for lang in mTEDx.LANGPAIRS: + cur_root = root / f"{lang}" + if not cur_root.is_dir(): + print(f"{cur_root.as_posix()} does not exist. 
Skipped.") + continue + # Extract features + feature_root = cur_root / "fbank80" + feature_root.mkdir(exist_ok=True) + for split in mTEDx.SPLITS: + print(f"Fetching split {split}...") + dataset = mTEDx(root.as_posix(), lang, split) + print("Extracting log mel filter bank features...") + for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset): + extract_fbank_features( + waveform, sample_rate, feature_root / f"{utt_id}.npy" + ) + # Pack features into ZIP + zip_path = cur_root / "fbank80.zip" + print("ZIPing features...") + create_zip(feature_root, zip_path) + print("Fetching ZIP manifest...") + zip_manifest = get_zip_manifest(zip_path) + # Generate TSV manifest + print("Generating manifest...") + train_text = [] + for split in mTEDx.SPLITS: + is_train_split = split.startswith("train") + manifest = {c: [] for c in MANIFEST_COLUMNS} + dataset = mTEDx(args.data_root, lang, split) + for wav, sr, src_utt, tgt_utt, speaker_id, tgt_lang, utt_id in tqdm(dataset): + manifest["id"].append(utt_id) + manifest["audio"].append(zip_manifest[utt_id]) + duration_ms = int(wav.size(1) / sr * 1000) + manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) + manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt) + manifest["speaker"].append(speaker_id) + manifest["tgt_lang"].append(tgt_lang) + if is_train_split: + train_text.extend(manifest["tgt_text"]) + df = pd.DataFrame.from_dict(manifest) + df = filter_manifest_df(df, is_train_split=is_train_split) + save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv") + # Generate vocab + v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for t in train_text: + f.write(t + "\n") + gen_vocab( + Path(f.name), + cur_root / spm_filename_prefix, + args.vocab_type, + args.vocab_size, + ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + 
yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="lb", + ) + # Clean up + shutil.rmtree(feature_root) + + +def process_joint(args): + cur_root = Path(args.data_root) + assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \ + "do not have downloaded data available for all languages" + # Generate vocab + vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for lang in mTEDx.LANGPAIRS: + tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv" + df = load_df_from_tsv(tsv_path) + for t in df["tgt_text"]: + f.write(t + "\n") + special_symbols = None + if args.joint: + # Add tgt_lang tags to dict + special_symbols = list({f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS}) + gen_vocab( + Path(f.name), + cur_root / spm_filename_prefix, + args.vocab_type, + args.vocab_size, + special_symbols=special_symbols + ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="ld", + prepend_tgt_lang_tag=(args.joint), + ) + # Make symbolic links to manifests + for lang in mTEDx.LANGPAIRS: + for split in mTEDx.SPLITS: + src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv" + desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv" + if not desc_path.is_symlink(): + os.symlink(src_path, desc_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument( + "--vocab-type", + default="unigram", + required=True, + type=str, + choices=["bpe", "unigram", "char"], + ), + parser.add_argument("--vocab-size", default=8000, type=int) + parser.add_argument("--task", type=str, choices=["asr", "st"]) + parser.add_argument("--joint", action="store_true", help="") + args = parser.parse_args() + + if args.joint: + 
process_joint(args) + else: + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index de08669851..f0e75b1d65 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -56,8 +56,16 @@ def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarr try: import torch import torchaudio.compliance.kaldi as ta_kaldi + import torchaudio.sox_effects as ta_sox + + waveform = torch.from_numpy(waveform) + if len(waveform.shape) == 1: + # Mono channel: D -> 1 x D + waveform = waveform.unsqueeze(0) + else: + # Merge multiple channels to one: C x D -> 1 x D + waveform, _ = ta_sox.apply_effects_tensor(waveform, sample_rate, ['channels', '1']) - waveform = torch.from_numpy(waveform).unsqueeze(0) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate ) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 1f556107a2..814924ec97 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -422,6 +422,15 @@ def s2t_transformer_s(args): base_architecture(args) +@register_model_architecture("s2t_transformer", "s2t_transformer_xs") +def s2t_transformer_xs(args): + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4) + args.dropout = getattr(args, "dropout", 0.3) + s2t_transformer_s(args) + + @register_model_architecture("s2t_transformer", "s2t_transformer_sp") def s2t_transformer_sp(args): args.encoder_layers = getattr(args, "encoder_layers", 16) From 284a86a49a054dcace1e66ee4c65dfb4adb5a39f Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Thu, 18 Feb 2021 16:35:02 -0800 Subject: [PATCH 468/707] remove the missing _device property Summary: after 
D26382917 (https://github.com/pytorch/fairseq/commit/02803a1be45642b4c2f9c2970a4f4ae645a2dccf) shipped somehow the self._device was removed in optimizer, (or maybe I didn't test it the right way in the previous diff?) fortunately OSS doesn't need it any way. Reviewed By: myleott Differential Revision: D26523538 fbshipit-source-id: 637c1e344670340ae40b32635ef51f5501966b0c --- fairseq/optim/shard.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index 3c1b34ae60..9d7f2eb9e5 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -45,7 +45,6 @@ def broadcast_global_state_dict( state_dict, src_rank=0, group=self.group, - dist_device=self._device, ) torch_optimizer = optimizer.optimizer From d2ee5883e774700c41b1eaddd0326e9afa6d3cd2 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutai_ma@jhu.edu> Date: Thu, 18 Feb 2021 22:41:32 -0800 Subject: [PATCH 469/707] Simultaneous Speech Translation Model (#1607) Summary: This is the pull request for the code for the paper [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58/) The model will also be used for [IWSLT 2021 shared task on simultaneous translation ](https://iwslt.org/2021/simultaneous) This pull request includes - Convtransformer offline model - Convtransformer simultaneous translation model with fixed pre-decision module - The agent files for inference for the convtransformer simultaneous translation model jmp84 The README is still missing. Just curious where should I place it? Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1607 Test Plan: Imported from GitHub, without a `Test Plan:` line. 
********** One of the failing landing integration tests ``` buck test mode/dev //multimo/fb/models/test:multimo_fb_model_test https://fburl.com/testinfra/oxq2cn5n ``` Reviewed By: jmp84 Differential Revision: D26439663 Pulled By: sravyapopuri388 fbshipit-source-id: b127cb4962756af221b65e3ccb6598a42fc75f7f --- .../models/transformer_monotonic_attention.py | 28 +- .../modules/fixed_pre_decision.py | 170 +++++++ .../modules/monotonic_multihead_attention.py | 343 ++++++------- .../modules/monotonic_transformer_layer.py | 8 + .../utils/data_utils.py | 100 ++++ examples/speech_to_text/README.md | 1 + .../docs/simulst_mustc_example.md | 52 ++ .../agents/fairseq_simul_st_agent.py | 331 +++++++++++++ .../agents/simul_trans_agent.py | 200 ++++++++ fairseq/models/speech_to_text/__init__.py | 2 + .../models/speech_to_text/convtransformer.py | 452 ++++++++++++++++++ .../convtransformer_simul_trans.py | 49 ++ 12 files changed, 1560 insertions(+), 176 deletions(-) create mode 100644 examples/simultaneous_translation/modules/fixed_pre_decision.py create mode 100644 examples/simultaneous_translation/utils/data_utils.py create mode 100644 examples/speech_to_text/docs/simulst_mustc_example.md create mode 100644 examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py create mode 100644 examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py create mode 100644 fairseq/models/speech_to_text/convtransformer.py create mode 100644 fairseq/models/speech_to_text/convtransformer_simul_trans.py diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index ab8adf3aab..dd3895f0eb 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -10,17 +10,20 @@ TransformerMonotonicDecoderLayer, TransformerMonotonicEncoderLayer, ) 
-from fairseq.models import register_model, register_model_architecture +from fairseq.models import ( + register_model, + register_model_architecture, +) from fairseq.models.transformer import ( - TransformerDecoder, - TransformerEncoder, TransformerModel, + TransformerEncoder, + TransformerDecoder, base_architecture, transformer_iwslt_de_en, transformer_vaswani_wmt_en_de_big, + transformer_vaswani_wmt_en_fr_big, ) - DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -33,7 +36,7 @@ def build_encoder(cls, args, src_dict, embed_tokens): @register_model("transformer_monotonic") -class TransformerMonotonicModel(TransformerModel): +class TransformerModelSimulTrans(TransformerModel): @classmethod def build_encoder(cls, args, src_dict, embed_tokens): return TransformerMonotonicEncoder(args, src_dict, embed_tokens) @@ -178,13 +181,18 @@ def pre_attention( if positions is not None: x += positions + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) - encoder_out = encoder_out_dict.encoder_out - encoder_padding_mask = encoder_out_dict.encoder_padding_mask + encoder_out = encoder_out_dict["encoder_out"][0] + encoder_padding_mask = ( + encoder_out_dict["encoder_padding_mask"][0] + if len(encoder_out_dict["encoder_padding_mask"]) > 0 + else None + ) return x, encoder_out, encoder_padding_mask @@ -236,7 +244,7 @@ def extract_features( attn_list.append(attn) if incremental_state is not None: - curr_steps = layer.get_steps(incremental_state) + curr_steps = layer.get_head_steps(incremental_state) step_list.append(curr_steps) if incremental_state.get("online", False): @@ -287,7 +295,7 @@ def reorder_incremental_state(self, incremental_state, new_order): @register_model_architecture("transformer_monotonic", "transformer_monotonic") -def base_monotonic_rchitecture(args): +def base_monotonic_architecture(args): base_architecture(args) args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False) @@ -297,7 +305,7 @@ def 
base_monotonic_rchitecture(args): ) def transformer_monotonic_iwslt_de_en(args): transformer_iwslt_de_en(args) - base_monotonic_rchitecture(args) + base_monotonic_architecture(args) # parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py new file mode 100644 index 0000000000..2cde55b35e --- /dev/null +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -0,0 +1,170 @@ +from functools import partial + +import torch +import torch.nn.functional as F + +from . import register_monotonic_attention +from .monotonic_multihead_attention import ( + MonotonicMultiheadAttentionWaitK, + MonotonicMultiheadAttentionHardAligned, + MonotonicMultiheadAttentionInfiniteLookback, +) + + +def fixed_pooling_monotonic_attention(monotonic_attention): + def create_model(monotonic_attention, klass): + class FixedStrideMonotonicAttention(monotonic_attention): + def __init__(self, args): + super().__init__(args) + self.pre_decision_type = args.fixed_pre_decision_type + self.pre_decision_ratio = args.fixed_pre_decision_ratio + self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold + if self.pre_decision_ratio == 1: + return + + if args.fixed_pre_decision_type == "average": + self.pooling_layer = torch.nn.AvgPool1d( + kernel_size=self.pre_decision_ratio, + stride=self.pre_decision_ratio, + ceil_mode=True, + ) + elif args.fixed_pre_decision_type == "last": + + def last(key): + if key.size(2) < self.pre_decision_ratio: + return key + else: + k = key[ + :, + :, + self.pre_decision_ratio - 1 :: self.pre_decision_ratio, + ].contiguous() + if key.size(-1) % self.pre_decision_ratio != 0: + k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous() + return k + + self.pooling_layer = last + else: + raise NotImplementedError + + @staticmethod + def add_args(parser): + super( + 
FixedStrideMonotonicAttention, FixedStrideMonotonicAttention + ).add_args(parser) + parser.add_argument( + "--fixed-pre-decision-ratio", + type=int, + required=True, + help=( + "Ratio for the fixed pre-decision, " + "indicating how many encoder steps will start " + "simultaneous decision making process." + ), + ) + parser.add_argument( + "--fixed-pre-decision-type", + default="average", + choices=["average", "last"], + help="Pooling type", + ) + parser.add_argument( + "--fixed-pre-decision-pad-threshold", + type=float, + default=0.3, + help="If a part of the sequence has pad, " + "the threshold above which the pooled part is a pad.", + ) + + def insert_zeros(self, x): + bsz_num_heads, tgt_len, src_len = x.size() + stride = self.pre_decision_ratio + weight = F.pad(x.new_ones(1, 1, 1), (stride - 1, 0)) + x_upsample = F.conv_transpose1d( + x.view(-1, src_len).unsqueeze(1), + weight, + stride=stride, + padding=0, + ) + return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1) + + def p_choose( + self, + query, + key, + key_padding_mask=None, + incremental_state=None, + **extra_args + ): + + if self.pre_decision_ratio == 1: + # No pooling needed: defer to the parent with the caller's arguments + return super().p_choose( + query, + key, + key_padding_mask=key_padding_mask, + incremental_state=incremental_state, + **extra_args + ) + + key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2) + + if key_padding_mask is not None: + key_padding_mask_pool = ( + self.pooling_layer(key_padding_mask.unsqueeze(0).float()) + .squeeze(0) + .gt(self.pre_decision_pad_threshold) + ) + # Make sure at least one element is not pad + key_padding_mask_pool[:, 0] = 0 + else: + key_padding_mask_pool = None + + p_choose_pooled = super().p_choose( + query, + key_pool, + key_padding_mask_pool, + incremental_state=incremental_state, + ) + + # Upsample, interpolate zeros + p_choose = self.insert_zeros(p_choose_pooled) + + # can be larger than src_len because we used ceil before + src_len = key.size(0) + p_choose = p_choose[:, :, :src_len] + p_choose[:, :, -1] = p_choose_pooled[:, :, 
-1] + + tgt_len = query.size(0) + batch_size = query.size(1) + + assert list(p_choose.size()) == [ + batch_size * self.num_heads, + tgt_len, + src_len, + ] + + return p_choose + + FixedStrideMonotonicAttention.__name__ = klass.__name__ + return FixedStrideMonotonicAttention + + return partial(create_model, monotonic_attention) + + +@register_monotonic_attention("waitk_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionWaitK) +class MonotonicMultiheadAttentionWaitkFixedStride: + pass + + +@register_monotonic_attention("hard_aligned_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionHardAligned) +class MonotonicMultiheadAttentionHardFixedStride: + pass + + +@register_monotonic_attention("infinite_lookback_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionInfiniteLookback) +class MonotonicMultiheadAttentionInfiniteLookbackFixedStride: + pass diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index c09725ac9a..5423f26c34 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn + import torch.nn.functional as F from examples.simultaneous_translation.utils.functions import ( exclusive_cumprod, @@ -30,6 +31,7 @@ def __init__(self, args): self.eps = args.attention_eps self.mass_preservation = args.mass_preservation + self.noise_type = args.noise_type self.noise_mean = args.noise_mean self.noise_var = args.noise_var @@ -43,23 +45,26 @@ def __init__(self, args): @staticmethod def add_args(parser): # fmt: off - parser.add_argument('--no-mass-preservation', action="store_false", dest="mass_preservation", + parser.add_argument('--no-mass-preservation', action="store_false", + dest="mass_preservation", 
help='Do not stay on the last token when decoding') - parser.add_argument('--mass-preservation', action="store_true", dest="mass_preservation", + parser.add_argument('--mass-preservation', action="store_true", + dest="mass_preservation", help='Stay on the last token when decoding') parser.set_defaults(mass_preservation=True) - parser.add_argument('--noise-var', type=float, default=1.0, help='Variance of discretness noise') parser.add_argument('--noise-mean', type=float, default=0.0, help='Mean of discretness noise') - parser.add_argument('--energy-bias', action="store_true", default=False, + parser.add_argument('--noise-type', type=str, default="flat", + help='Type of discretness noise') + parser.add_argument('--energy-bias', action="store_true", + default=False, help='Bias for energy') parser.add_argument('--energy-bias-init', type=float, default=-2.0, help='Initial value of the bias for energy') parser.add_argument('--attention-eps', type=float, default=1e-6, help='Epsilon when calculating expected attention') - # fmt: on def p_choose(self, *args): raise NotImplementedError @@ -67,7 +72,9 @@ def p_choose(self, *args): def input_projections(self, *args): raise NotImplementedError - def attn_energy(self, q_proj, k_proj, key_padding_mask=None): + def attn_energy( + self, q_proj, k_proj, key_padding_mask=None, attn_mask=None + ): """ Calculating monotonic energies @@ -82,7 +89,13 @@ def attn_energy(self, q_proj, k_proj, key_padding_mask=None): bsz = bsz // self.num_heads src_len = k_proj.size(1) - attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + attn_energy = ( + torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + ) + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_energy += attn_mask attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len) @@ -102,7 +115,7 @@ def expected_alignment_train(self, p_choose, key_padding_mask): q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} a_ij = p_ij q_ij - parellel 
solution: + Parallel solution: ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) ============================================================ @@ -139,21 +152,40 @@ def expected_alignment_train(self, p_choose, key_padding_mask): if self.mass_preservation: # Last token has the residual probabilities - alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) - - assert not torch.isnan(alpha).any(), "NaN detected in alpha." + if key_padding_mask is not None and key_padding_mask[:, -1].any(): + # right padding + batch_size = key_padding_mask.size(0) + residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0) + src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True) + src_lens = src_lens.expand( + batch_size, self.num_heads + ).contiguous().view(-1, 1) + src_lens = src_lens.expand(-1, tgt_len).contiguous() + # add back the last value + residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1) + alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals) + else: + residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) + alpha[:, :, -1] = residuals + + if torch.isnan(alpha).any(): + # Something is wrong + raise RuntimeError("NaN in alpha.") return alpha - def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state): + def expected_alignment_infer( + self, p_choose, encoder_padding_mask, incremental_state + ): + # TODO modify this function """ Calculating mo alignment for MMA during inference time ============================================================ Expected input size p_choose: bsz * num_heads, tgt_len, src_len - key_padding_mask: bsz * src_len incremental_state: dict + encoder_padding_mask: bsz * src_len """ # p_choose: bsz * self.num_heads, src_len bsz_num_heads, tgt_len, src_len = p_choose.size() @@ -166,7 +198,8 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # prev_monotonic_step: bsz, num_heads bsz = bsz_num_heads // self.num_heads prev_monotonic_step = 
monotonic_cache.get( - "step", p_choose.new_zeros([bsz, self.num_heads]).long() + "head_step", + p_choose.new_zeros([bsz, self.num_heads]).long() ) bsz, num_heads = prev_monotonic_step.size() assert num_heads == self.num_heads @@ -175,8 +208,9 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # p_choose: bsz, num_heads, src_len p_choose = p_choose.view(bsz, num_heads, src_len) - if key_padding_mask is not None: - src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long() + if encoder_padding_mask is not None: + src_lengths = src_len - \ + encoder_padding_mask.sum(dim=1, keepdim=True).long() else: src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len @@ -186,16 +220,16 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state new_monotonic_step = prev_monotonic_step step_offset = 0 - if key_padding_mask is not None: - if key_padding_mask[:, 0].any(): + if encoder_padding_mask is not None: + if encoder_padding_mask[:, 0].any(): # left_pad_source = True: - step_offset = key_padding_mask.sum(dim=-1, keepdim=True) + step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True) max_steps = src_lengths - 1 if self.mass_preservation else src_lengths # finish_read: bsz, num_heads finish_read = new_monotonic_step.eq(max_steps) - + p_choose_i = 1 while finish_read.sum().item() < bsz * self.num_heads: # p_choose: bsz * self.num_heads, src_len # only choose the p at monotonic steps @@ -224,23 +258,34 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state new_monotonic_step += action finish_read = new_monotonic_step.eq(max_steps) | (action == 0) - # finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read - monotonic_cache["step"] = new_monotonic_step + if p_choose_i is None: + import pdb;pdb.set_trace() + + monotonic_cache["head_step"] = new_monotonic_step + # Whether a head is looking for new input + monotonic_cache["head_read"] = ( + 
new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5) + ) # alpha: bsz * num_heads, 1, src_len # new_monotonic_step: bsz, num_heads - alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter( - 1, - (step_offset + new_monotonic_step) - .view(bsz * self.num_heads, 1) - .clamp(0, src_len - 1), - 1, + alpha = ( + p_choose + .new_zeros([bsz * self.num_heads, src_len]) + .scatter( + 1, + (step_offset + new_monotonic_step) + .view(bsz * self.num_heads, 1).clamp(0, src_len - 1), + 1 + ) ) if not self.mass_preservation: alpha = alpha.masked_fill( - (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0 + (new_monotonic_step == max_steps) + .view(bsz * self.num_heads, 1), + 0 ) alpha = alpha.unsqueeze(1) @@ -249,18 +294,28 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state return alpha + def _get_monotonic_buffer(self, incremental_state): + return utils.get_incremental_state( + self, + incremental_state, + 'monotonic', + ) or {} + + def _set_monotonic_buffer(self, incremental_state, buffer): + utils.set_incremental_state( + self, + incremental_state, + 'monotonic', + buffer, + ) + def v_proj_output(self, value): raise NotImplementedError def forward( - self, - query, - key, - value, - key_padding_mask=None, - incremental_state=None, - *args, - **kwargs, + self, query, key, value, + key_padding_mask=None, attn_mask=None, incremental_state=None, + need_weights=True, static_kv=False, *args, **kwargs ): tgt_len, bsz, embed_dim = query.size() @@ -268,26 +323,31 @@ def forward( # stepwise prob # p_choose: bsz * self.num_heads, tgt_len, src_len - p_choose = self.p_choose(query, key, key_padding_mask) + p_choose = self.p_choose( + query, key, key_padding_mask, incremental_state, + ) # expected alignment alpha # bsz * self.num_heads, tgt_len, src_len if incremental_state is not None: alpha = self.expected_alignment_infer( - p_choose, key_padding_mask, incremental_state - ) + p_choose, key_padding_mask, incremental_state) else: 
- alpha = self.expected_alignment_train(p_choose, key_padding_mask) + alpha = self.expected_alignment_train( + p_choose, key_padding_mask) # expected attention beta # bsz * self.num_heads, tgt_len, src_len beta = self.expected_attention( - alpha, query, key, value, key_padding_mask, incremental_state + alpha, query, key, value, + key_padding_mask, attn_mask, + incremental_state ) attn_weights = beta v_proj = self.v_proj_output(value) + attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) @@ -298,67 +358,17 @@ def forward( alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) - return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose} - - def reorder_incremental_state(self, incremental_state, new_order): - """Reorder buffered internal state (for incremental generation).""" - super().reorder_incremental_state(incremental_state, new_order) - input_buffer = self._get_monotonic_buffer(incremental_state) - if input_buffer is not None: - for k in input_buffer.keys(): - input_buffer[k] = input_buffer[k].index_select(0, new_order) - self._set_monotonic_buffer(incremental_state, input_buffer) - - def _get_monotonic_buffer(self, incremental_state): - return ( - utils.get_incremental_state( - self, - incremental_state, - "monotonic", - ) - or {} - ) - - def _set_monotonic_buffer(self, incremental_state, buffer): - utils.set_incremental_state( - self, - incremental_state, - "monotonic", - buffer, - ) - - def get_pointer(self, incremental_state): - return ( - utils.get_incremental_state( - self, - incremental_state, - "monotonic", - ) - or {} - ) - - def get_fastest_pointer(self, incremental_state): - return self.get_pointer(incremental_state)["step"].max(0)[0] - - def set_pointer(self, incremental_state, p_choose): - curr_pointer = self.get_pointer(incremental_state) - if len(curr_pointer) == 0: - buffer = 
torch.zeros_like(p_choose) - else: - buffer = self.get_pointer(incremental_state)["step"] - - buffer += (p_choose < 0.5).type_as(buffer) - - utils.set_incremental_state( - self, - incremental_state, - "monotonic", - {"step": buffer}, - ) + return attn, { + "alpha": alpha, + "beta": beta, + "p_choose": p_choose, + } @register_monotonic_attention("hard_aligned") -class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention): +class MonotonicMultiheadAttentionHardAligned( + MonotonicAttention, MultiheadAttention +): def __init__(self, args): MultiheadAttention.__init__( self, @@ -392,39 +402,36 @@ def input_projections(self, query, key, value, name): bsz = query.size(1) q = self.q_in_proj[name](query) q *= self.scaling - q = ( - q.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + q = q.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: q = None if key is not None: bsz = key.size(1) k = self.k_in_proj[name](key) - k = ( - k.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + k = k.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: k = None if value is not None: bsz = value.size(1) v = self.v_in_proj[name](value) - v = ( - v.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + v = v.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: v = None return q, k, v - def p_choose(self, query, key, key_padding_mask=None): + def p_choose( + self, query, key, key_padding_mask=None, + incremental_state=None, *extra_args + ): """ Calculating step wise prob for reading and writing 1 to read, 0 to write @@ -440,7 +447,9 @@ def p_choose(self, query, key, key_padding_mask=None): """ # prepare inputs - q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic") + q_proj, k_proj, _ = self.input_projections( + query, key, None, "monotonic" + ) # 
attention energy attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) @@ -473,7 +482,9 @@ def v_proj_output(self, value): @register_monotonic_attention("infinite_lookback") -class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard): +class MonotonicMultiheadAttentionInfiniteLookback( + MonotonicMultiheadAttentionHardAligned +): def __init__(self, args): super().__init__(args) self.init_soft_attention() @@ -498,30 +509,33 @@ def init_soft_attention(self): nn.init.xavier_uniform_(self.q_in_proj["soft"].weight) def expected_attention( - self, alpha, query, key, value, key_padding_mask, incremental_state + self, alpha, query, key, value, + key_padding_mask, attn_mask, incremental_state ): # monotonic attention, we will calculate milk here bsz_x_num_heads, tgt_len, src_len = alpha.size() bsz = int(bsz_x_num_heads / self.num_heads) q, k, _ = self.input_projections(query, key, None, "soft") - soft_energy = self.attn_energy(q, k, key_padding_mask) + soft_energy = self.attn_energy(q, k, key_padding_mask, attn_mask) - assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len] + assert list(soft_energy.size()) == \ + [bsz, self.num_heads, tgt_len, src_len] soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len) if incremental_state is not None: monotonic_cache = self._get_monotonic_buffer(incremental_state) - monotonic_step = monotonic_cache["step"] + 1 + monotonic_length = monotonic_cache["head_step"] + 1 step_offset = 0 if key_padding_mask is not None: if key_padding_mask[:, 0].any(): # left_pad_source = True: step_offset = key_padding_mask.sum(dim=-1, keepdim=True) - monotonic_step += step_offset + monotonic_length += step_offset mask = lengths_to_mask( - monotonic_step.view(-1), soft_energy.size(2), 1 + monotonic_length.view(-1), + soft_energy.size(2), 1 ).unsqueeze(1) soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf")) @@ -531,84 +545,81 @@ def expected_attention( beta = exp_soft_energy / 
exp_soft_energy_sum.unsqueeze(2) else: - # bsz * num_heads, tgt_len, src_len soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] - exp_soft_energy = torch.exp(soft_energy) - exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2) + exp_soft_energy = torch.exp(soft_energy) + self.eps + inner_items = alpha / (torch.cumsum(exp_soft_energy, dim=2)) + + beta = ( + exp_soft_energy + * torch.cumsum(inner_items.flip(dims=[2]), dim=2) + .flip(dims=[2]) + ) + + beta = beta.view(bsz, self.num_heads, tgt_len, src_len) if key_padding_mask is not None: - if key_padding_mask.any(): - exp_soft_energy_cumsum = ( - exp_soft_energy_cumsum.view( - -1, self.num_heads, tgt_len, src_len - ) - .masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps - ) - .view(-1, tgt_len, src_len) - ) - - inner_items = alpha / exp_soft_energy_cumsum - - beta = exp_soft_energy * torch.cumsum( - inner_items.flip(dims=[2]), dim=2 - ).flip(dims=[2]) + beta = beta.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).bool(), 0) + beta = beta / beta.sum(dim=3, keepdim=True) + beta = beta.view(bsz * self.num_heads, tgt_len, src_len) beta = self.dropout_module(beta) - assert not torch.isnan(beta).any(), "NaN detected in beta." + if torch.isnan(beta).any(): + # Something is wrong + raise RuntimeError("NaN in beta.") return beta @register_monotonic_attention("waitk") -class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback): +class MonotonicMultiheadAttentionWaitK( + MonotonicMultiheadAttentionInfiniteLookback +): def __init__(self, args): super().__init__(args) self.q_in_proj["soft"] = self.q_in_proj["monotonic"] self.k_in_proj["soft"] = self.k_in_proj["monotonic"] self.waitk_lagging = args.waitk_lagging - assert ( - self.waitk_lagging > 0 - ), f"Lagging has to been larger than 0, get {self.waitk_lagging}." + assert self.waitk_lagging > 0, ( + f"Lagging has to been larger than 0, get {self.waitk_lagging}." 
+ ) @staticmethod def add_args(parser): super( - MonotonicMultiheadAttentionWaitk, - MonotonicMultiheadAttentionWaitk, + MonotonicMultiheadAttentionWaitK, + MonotonicMultiheadAttentionWaitK, ).add_args(parser) parser.add_argument( - "--waitk-lagging", type=int, required=True, help="Wait k lagging" + "--waitk-lagging", type=int, required=True, help="Wait K lagging" ) def p_choose( - self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None + self, query, key, key_padding_mask=None, + incremental_state=None, *extra_args ): """ query: bsz, tgt_len key: bsz, src_len key_padding_mask: bsz, src_len """ - src_len, bsz, _ = key.size() - tgt_len, bsz, _ = query.size() + if incremental_state is not None: + tgt_len = int(incremental_state["steps"]["tgt"]) + src_len = int(incremental_state["steps"]["src"]) + bsz = 1 + else: + src_len, bsz, _ = key.size() + tgt_len, bsz, _ = query.size() + p_choose = query.new_ones(bsz, tgt_len, src_len) p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) - if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any(): - # Left pad source - # add -1 to the end - p_choose = p_choose.masked_fill( - key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1 - ) - p_choose = convert_padding_direction( - p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True - ) - p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query) - # remove -1 - p_choose[p_choose.eq(-1)] = 0 + if incremental_state is not None: + p_choose = p_choose[:, -1:] + tgt_len = 1 # Extend to each head p_choose = ( diff --git a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py index 442b7d487d..e6e1850a18 100644 --- a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py +++ 
b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py @@ -26,11 +26,19 @@ def __init__( add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) + + assert args.simul_type is not None, "A --simul-type is needed." + self.encoder_attn = build_monotonic_attention(args) self.encoder_attn_layer_norm = LayerNorm( self.embed_dim, export=getattr(args, "char_inputs", False) ) + def get_head_steps(self, incremental_state): + return self.encoder_attn._get_monotonic_buffer(incremental_state).get( + "head_step" + ) + def prune_incremental_state(self, incremental_state): def prune(module): input_buffer = module._get_input_buffer(incremental_state) diff --git a/examples/simultaneous_translation/utils/data_utils.py b/examples/simultaneous_translation/utils/data_utils.py new file mode 100644 index 0000000000..cc4729e63c --- /dev/null +++ b/examples/simultaneous_translation/utils/data_utils.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + + +def calc_mean_invstddev(feature): + if len(feature.size()) != 2: + raise ValueError("We expect the input feature to be 2-D tensor") + mean = feature.mean(0) + var = feature.var(0) + # avoid division by ~zero + eps = 1e-8 + if (var < eps).any(): + return mean, 1.0 / (torch.sqrt(var) + eps) + return mean, 1.0 / torch.sqrt(var) + + +def apply_mv_norm(features): + # If there is less than 2 spectrograms, the variance cannot be computed (is NaN) + # and normalization is not possible, so return the item as it is + if features.size(0) < 2: + return features + mean, invstddev = calc_mean_invstddev(features) + res = (features - mean) * invstddev + return res + + +def lengths_to_encoder_padding_mask(lengths, batch_first=False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = 0 for t < lengths[b] and 1 otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) >= lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +def encoder_padding_mask_to_lengths( + encoder_padding_mask, max_lengths, batch_size, device +): + """ + convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor + + Conventionally, encoder output contains a encoder_padding_mask, which is + a 2-D mask in a shape (T, B), whose (t, b) element indicate whether + encoder_out[t, b] is a valid output (=0) or not (=1). 
Occasionally, we + need to convert this mask tensor to a 1-D tensor in shape (B, ), where + [b] denotes the valid length of b-th sequence + + Args: + encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None, + indicating all are valid + Return: + seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the + number of valid elements of b-th sequence + + max_lengths: maximum length of all sequence, if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(0) + + batch_size: batch size; if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(1) + + device: which device to put the result on + """ + if encoder_padding_mask is None: + return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device) + + assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match" + assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match" + + return max_lengths - torch.sum(encoder_padding_mask, dim=0) diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 988ed83d77..f639d300d3 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -37,6 +37,7 @@ audio paths (one per line) as inputs. - [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md) - [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md) +- [Simultaneous Speech-to-Text Translation (SimulST) on MuST-C](docs/simulst_mustc_example.md) ## Updates - 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. 
Examples: diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md new file mode 100644 index 0000000000..5dea0d8475 --- /dev/null +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -0,0 +1,52 @@ +# Simultaneous Speech Translation (SimulST) on MuST-C + +[MuST-C](https://www.aclweb.org/anthology/N19-1202) is a multilingual speech-to-text translation corpus with 8-language translations on English TED talks. + +## Data Preparation & ASR +Please follow the steps in the offline [speech-to-text](../mustc_example.md) translation example for data preparation and ASR pretraining. + +## Training + +#### Wait-K (K=3) with fixed pre-decision module +``` + fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 8 \ + --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ + --criterion label_smoothed_cross_entropy \ + --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --task speech_to_text \ + --arch convtransformer_simul_trans_espnet \ + --simul-type waitk_fixed_pre_decision \ + --waitk-lagging 3 \ + --fixed-pre-decision-ratio 7 +``` +#### Monotonic multihead attention with fixed pre-decision module +``` + fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 8 \ + --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ + --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --task speech_to_text \ + --criterion latency_augmented_label_smoothed_cross_entropy \ + --latency-weight-avg 0.1 \ + --arch convtransformer_simul_trans_espnet \ + --simul-type infinite_lookback_fixed_pre_decision \ + 
--fixed-pre-decision-ratio 7 +``` +## Inference & Evaluation +[SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. +``` +simuleval \ + --agent ${FAIRSEQ}/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py + --src-file ${SRC_LIST_OF_AUDIO} + --tgt-file ${TGT_FILE} + --data-bin ${MUSTC_ROOT}/en-de \ + --model-path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --tgt-splitter-type SentencePieceModel \ + --tgt-splitter-path ${MUSTC_ROOT}/en-de/spm.model \ + --scores +``` diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py new file mode 100644 index 0000000000..cbe8bc4322 --- /dev/null +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -0,0 +1,331 @@ +import math +import os + +import numpy as np +import torch +import torchaudio.compliance.kaldi as kaldi +import yaml +from fairseq import checkpoint_utils, tasks + +try: + from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS + from simuleval.agents import SpeechAgent + from simuleval.states import ListEntry +except ImportError: + print("Please install simuleval 'pip install simuleval'") + + +SHIFT_SIZE = 10 +WINDOW_SIZE = 25 +SAMPLE_RATE = 16000 +FEATURE_DIM = 80 +BOW_PREFIX = "\u2581" + + +class OnlineFeatureExtractor: + """ + Extract speech feature on the fly. 
+ """ + + def __init__( + self, + shift_size=SHIFT_SIZE, + window_size=WINDOW_SIZE, + sample_rate=SAMPLE_RATE, + feature_dim=FEATURE_DIM, + global_cmvn=None, + ): + self.shift_size = shift_size + self.window_size = window_size + assert self.window_size >= self.shift_size + + self.sample_rate = sample_rate + self.feature_dim = feature_dim + self.num_samples_per_shift = int(SHIFT_SIZE * SAMPLE_RATE / 1000) + self.num_samples_per_window = int(WINDOW_SIZE * SAMPLE_RATE / 1000) + self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 + self.previous_residual_samples = [] + self.global_cmvn = global_cmvn + + def clear_cache(self): + self.previous_residual_samples = [] + + def __call__(self, new_samples): + samples = self.previous_residual_samples + new_samples + if len(samples) < self.num_samples_per_window: + self.previous_residual_samples = samples + return + + # num_frames is the number of frames from the new segment + num_frames = math.floor( + (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size)) + / self.num_samples_per_shift + ) + + # the number of frames used for feature extraction + # including some part of thte previous segment + effective_num_samples = int( + num_frames * self.len_ms_to_samples(self.shift_size) + + self.len_ms_to_samples(self.window_size - self.shift_size) + ) + + input_samples = samples[:effective_num_samples] + self.previous_residual_samples = samples[ + num_frames * self.num_samples_per_shift : + ] + + torch.manual_seed(1) + output = kaldi.fbank( + torch.FloatTensor(input_samples).unsqueeze(0), + num_mel_bins=self.feature_dim, + frame_length=self.window_size, + frame_shift=self.shift_size, + ).numpy() + + output = self.transform(output) + + return torch.from_numpy(output) + + def transform(self, input): + if self.global_cmvn is None: + return input + + mean = self.global_cmvn["mean"] + std = self.global_cmvn["std"] + + x = np.subtract(input, mean) + x = np.divide(x, std) + return x + + +class 
TensorListEntry(ListEntry): + """ + Data structure to store a list of tensors. + """ + + def append(self, value): + + if len(self.value) == 0: + self.value = value + return + + self.value = torch.cat([self.value] + [value], dim=0) + + def info(self): + return { + "type": str(self.new_value_type), + "length": self.__len__(), + "value": "" if type(self.value) is list else self.value.size(), + } + + +class FairseqSimulSTAgent(SpeechAgent): + + speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size + + def __init__(self, args): + super().__init__(args) + + self.eos = DEFAULT_EOS + + self.gpu = getattr(args, "gpu", False) + + self.args = args + + self.load_model_vocab(args) + + config_yaml = os.path.join(args.data_bin, "config.yaml") + with open(config_yaml, "r") as f: + config = yaml.safe_load(f) + + if "global_cmvn" in config: + global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) + else: + global_cmvn = None + + self.feature_extractor = OnlineFeatureExtractor(global_cmvn=global_cmvn) + + self.max_len = args.max_len + + self.force_finish = args.force_finish + + torch.set_grad_enabled(False) + + def to_device(self, tensor): + if self.gpu: + return tensor.cuda() + else: + return tensor.cpu() + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--model-path', type=str, required=True, + help='path to your pretrained model.') + parser.add_argument("--data-bin", type=str, required=True, + help="Path of data binary") + parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", + help="Subword splitter type for target text") + parser.add_argument("--tgt-splitter-path", type=str, default=None, + help="Subword splitter model path for target text") + parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation", + help="User directory for simultaneous translation") + parser.add_argument("--max-len", type=int, default=200, + help="Max length of translation") + 
parser.add_argument("--force-finish", default=False, action="store_true", + help="") + # fmt: on + return parser + + def load_model_vocab(self, args): + + filename = args.model_path + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + + state = checkpoint_utils.load_checkpoint_to_cpu(filename) + + task_args = state["cfg"]["task"] + task_args.data = args.data_bin + + task = tasks.setup_task(task_args) + + # build model for ensemble + self.model = task.build_model(state["cfg"]["model"]) + self.model.load_state_dict(state["model"], strict=True) + self.model.eval() + self.model.share_memory() + + if self.gpu: + self.model.cuda() + + # Set dictionary + self.dict = {} + self.dict["tgt"] = task.target_dictionary + + def initialize_states(self, states): + self.feature_extractor.clear_cache() + states.units.source = TensorListEntry() + states.units.target = ListEntry() + states.incremental_states = dict() + + def segment_to_units(self, segment, states): + # Convert speech samples to features + features = self.feature_extractor(segment) + if features is not None: + return [features] + else: + return [] + + def units_to_segment(self, units, states): + # Merge sub word to full word. 
+ if self.model.decoder.dictionary.eos() == units[0]: + return DEFAULT_EOS + + segment = [] + if None in units.value: + units.value.remove(None) + + for index in units: + if index is None: + units.pop() + token = self.model.decoder.dictionary.string([index]) + if token.startswith(BOW_PREFIX): + if len(segment) == 0: + segment += [token.replace(BOW_PREFIX, "")] + else: + for j in range(len(segment)): + units.pop() + + string_to_return = ["".join(segment)] + + if self.model.decoder.dictionary.eos() == units[0]: + string_to_return += [DEFAULT_EOS] + + return string_to_return + else: + segment += [token.replace(BOW_PREFIX, "")] + + if ( + len(units) > 0 + and self.model.decoder.dictionary.eos() == units[-1] + or len(states.units.target) > self.max_len + ): + tokens = [self.model.decoder.dictionary.string([unit]) for unit in units] + return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS] + + return None + + def update_model_encoder(self, states): + if len(states.units.source) == 0: + return + src_indices = self.to_device(states.units.source.value.unsqueeze(0)) + src_lengths = self.to_device( + torch.LongTensor([states.units.source.value.size(0)]) + ) + print(src_lengths) + + states.encoder_states = self.model.encoder(src_indices, src_lengths) + torch.cuda.empty_cache() + + def update_states_read(self, states): + # Happens after a read action. 
+ self.update_model_encoder(states) + + def policy(self, states): + if not getattr(states, "encoder_states", None): + return READ_ACTION + + tgt_indices = self.to_device( + torch.LongTensor( + [self.model.decoder.dictionary.eos()] + + [x for x in states.units.target.value if x is not None] + ).unsqueeze(0) + ) + + states.incremental_states["steps"] = { + "src": states.encoder_states["encoder_out"][0].size(0), + "tgt": 1 + len(states.units.target), + } + + states.incremental_states["online"] = True + + x, outputs = self.model.decoder.forward( + prev_output_tokens=tgt_indices, + encoder_out=states.encoder_states, + incremental_state=states.incremental_states, + # features_only=True, + ) + + states.decoder_out = x + + states.decoder_out_extra = outputs + + torch.cuda.empty_cache() + + if outputs["action"] == 0: + return READ_ACTION + else: + return WRITE_ACTION + + def predict(self, states): + decoder_states = states.decoder_out + + lprobs = self.model.get_normalized_probs( + [decoder_states[:, -1:]], log_probs=True + ) + + index = lprobs.argmax(dim=-1) + + torch.cuda.empty_cache() + + index = index[0, 0].item() + + if ( + self.force_finish + and index == self.model.decoder.dictionary.eos() + and not states.finish_read() + ): + index = None + + return index diff --git a/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py new file mode 100644 index 0000000000..45df5fa227 --- /dev/null +++ b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + +from fairseq import checkpoint_utils, utils, tasks + +from . 
import DEFAULT_EOS, GET, SEND +from .agent import Agent + + +class SimulTransAgent(Agent): + def __init__(self, args): + # Load Model + self.load_model(args) + + # build word spliter + self.build_word_splitter(args) + + self.max_len = args.max_len + + self.eos = DEFAULT_EOS + + @staticmethod + def add_args(parser): + parser.add_argument( + "--model-path", + type=str, + required=True, + help="path to your pretrained model.", + ) + parser.add_argument( + "--data-bin", type=str, required=True, help="Path of data binary" + ) + parser.add_argument( + "--user-dir", + type=str, + default="example/simultaneous_translation", + help="User directory for simultaneous translation", + ) + parser.add_argument( + "--src-splitter-type", + type=str, + default=None, + help="Subword splitter type for source text", + ) + parser.add_argument( + "--tgt-splitter-type", + type=str, + default=None, + help="Subword splitter type for target text", + ) + parser.add_argument( + "--src-splitter-path", + type=str, + default=None, + help="Subword splitter model path for source text", + ) + parser.add_argument( + "--tgt-splitter-path", + type=str, + default=None, + help="Subword splitter model path for target text", + ) + parser.add_argument( + "--max-len", + type=int, + default=150, + help="Maximum length difference between source and target prediction", + ) + parser.add_argument( + "--model-overrides", + default="{}", + type=str, + metavar="DICT", + help="A dictionary used to override model args at generation " + "that were used during model training", + ) + # fmt: on + return parser + + def load_dictionary(self, task): + raise NotImplementedError + + def load_model(self, args): + args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..") + utils.import_user_module(args) + filename = args.model_path + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + + state = checkpoint_utils.load_checkpoint_to_cpu( + filename, 
json.loads(args.model_overrides) + ) + + saved_args = state["args"] + saved_args.data = args.data_bin + + task = tasks.setup_task(saved_args) + + # build model for ensemble + self.model = task.build_model(saved_args) + self.model.load_state_dict(state["model"], strict=True) + + # Set dictionary + self.load_dictionary(task) + + def init_states(self): + return { + "indices": {"src": [], "tgt": []}, + "tokens": {"src": [], "tgt": []}, + "segments": {"src": [], "tgt": []}, + "steps": {"src": 0, "tgt": 0}, + "finished": False, + "finish_read": False, + "model_states": {}, + } + + def update_states(self, states, new_state): + raise NotImplementedError + + def policy(self, states): + # Read and Write policy + action = None + + while action is None: + if states["finished"]: + # Finish the hypo by sending eos to server + return self.finish_action() + + # Model make decision given current states + decision = self.model.decision_from_states(states) + + if decision == 0 and not self.finish_read(states): + # READ + action = self.read_action(states) + else: + # WRITE + action = self.write_action(states) + + # None means we make decision again but not sending server anything + # This happened when read a buffered token + # Or predict a subword + return action + + def finish_read(self, states): + raise NotImplementedError + + def write_action(self, states): + token, index = self.model.predict_from_states(states) + + if ( + index == self.dict["tgt"].eos() + or len(states["tokens"]["tgt"]) > self.max_len + ): + # Finish this sentence is predict EOS + states["finished"] = True + end_idx_last_full_word = self._target_length(states) + + else: + states["tokens"]["tgt"] += [token] + end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( + states["tokens"]["tgt"] + ) + self._append_indices(states, [index], "tgt") + + if end_idx_last_full_word > states["steps"]["tgt"]: + # Only sent detokenized full words to the server + word = self.word_splitter["tgt"].merge( + 
states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] + ) + states["steps"]["tgt"] = end_idx_last_full_word + states["segments"]["tgt"] += [word] + + return {"key": SEND, "value": word} + else: + return None + + def read_action(self, states): + return {"key": GET, "value": None} + + def finish_action(self): + return {"key": SEND, "value": DEFAULT_EOS} + + def reset(self): + pass + + def finish_eval(self, states, new_state): + if len(new_state) == 0 and len(states["indices"]["src"]) == 0: + return True + return False + + def _append_indices(self, states, new_indices, key): + states["indices"][key] += new_indices + + def _target_length(self, states): + return len(states["tokens"]["tgt"]) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index 5d7f59b3a6..28e3bb720f 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -4,4 +4,6 @@ # LICENSE file in the root directory of this source tree. 
from .berard import * # noqa +from .convtransformer import * # noqa +from .convtransformer_simul_trans import * # noqa from .s2t_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py new file mode 100644 index 0000000000..512ee78be0 --- /dev/null +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 + +import logging +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from examples.simultaneous_translation.utils.data_utils import ( + lengths_to_encoder_padding_mask, +) +from fairseq import checkpoint_utils, utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerEncoderLayer +from torch import Tensor + +logger = logging.getLogger(__name__) + + +@register_model("convtransformer") +class ConvTransformerModel(FairseqEncoderDecoderModel): + """ + Transformer-based Speech translation model from ESPNet-ST + https://arxiv.org/abs/2004.10234 + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + 
"--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--decoder-output-dim", + type=int, + metavar="N", + help="decoder output dimension (extra linear layer if different from decoder embed dim)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder 
weights from (for initialization)", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + parser.add_argument( + "--conv-out-channels", + type=int, + metavar="INT", + help="the number of output channels of conv layer", + ) + + @classmethod + def build_encoder(cls, args): + encoder = ConvTransformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens) + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + return cls(encoder, decoder) + + @staticmethod + @torch.jit.unused + def set_batch_first(lprobs): + lprobs.batch_first = True + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) + if 
self.training: + self.set_batch_first(lprobs) + return lprobs + + def output_layout(self): + return "BTD" + + """ + The forward method inherited from the base class has a **kwargs argument in + its input, which is not supported in torchscript. This method overrites the forward + method definition without **kwargs. + """ + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) + return decoder_out + + +class ConvTransformerEncoder(FairseqEncoder): + """Conv + Transformer encoder""" + + def __init__(self, args): + """Construct an Encoder object.""" + super().__init__(None) + + self.dropout = args.dropout + self.embed_scale = ( + 1.0 if args.no_scale_embedding else math.sqrt(args.encoder_embed_dim) + ) + self.padding_idx = 1 + self.in_channels = 1 + self.input_dim = args.input_feat_per_channel + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, args.conv_out_channels, 3, stride=2, padding=3 // 2), + torch.nn.ReLU(), + torch.nn.Conv2d( + args.conv_out_channels, + args.conv_out_channels, + 3, + stride=2, + padding=3 // 2, + ), + torch.nn.ReLU(), + ) + transformer_input_dim = self.infer_conv_output_dim( + self.in_channels, self.input_dim, args.conv_out_channels + ) + self.out = torch.nn.Linear(transformer_input_dim, args.encoder_embed_dim) + self.embed_positions = PositionalEmbedding( + args.max_source_positions, + args.encoder_embed_dim, + self.padding_idx, + learned=False, + ) + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def pooling_ratio(self): + return 4 + + def infer_conv_output_dim(self, in_channels, input_dim, out_channels): + sample_seq_len = 
200 + sample_bsz = 10 + x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) + x = torch.nn.Conv2d(1, out_channels, 3, stride=2, padding=3 // 2)(x) + x = torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=3 // 2)(x) + x = x.transpose(1, 2) + mb, seq = x.size()[:2] + return x.contiguous().view(mb, seq, -1).size(-1) + + def forward(self, src_tokens, src_lengths): + """Encode input sequence. + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) + + input_lengths = min( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([1]), + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + for layer in self.transformer_layers: + x = layer(x, encoder_padding_mask) + + if not encoder_padding_mask.any(): + maybe_encoder_padding_mask = None + else: + maybe_encoder_padding_mask = encoder_padding_mask + + return { + "encoder_out": [x], + "encoder_padding_mask": [maybe_encoder_padding_mask] + if maybe_encoder_padding_mask is not None + else [], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Reorder encoder output according to 
*new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + (encoder_out["encoder_padding_mask"][0]).index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + (encoder_out["encoder_embedding"][0]).index_select(0, new_order) + ] + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, + "encoder_padding_mask": new_encoder_padding_mask, + "encoder_embedding": new_encoder_embedding, + "encoder_states": encoder_states, + "src_tokens": [], + "src_lengths": [], + } + + +class TransformerDecoderNoExtra(TransformerDecoder): + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + # call scriptable method from parent class + x, _ = self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + return x, None + + +@register_model_architecture(model_name="convtransformer", arch_name="convtransformer") +def base_architecture(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + 
args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.max_source_positions = getattr(args, "max_source_positions", 3000) + args.max_target_positions = getattr(args, "max_target_positions", 1024) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.conv_out_channels 
= getattr(args, "conv_out_channels", args.encoder_embed_dim) + + +@register_model_architecture("convtransformer", "convtransformer_espnet") +def convtransformer_espnet(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) diff --git a/fairseq/models/speech_to_text/convtransformer_simul_trans.py b/fairseq/models/speech_to_text/convtransformer_simul_trans.py new file mode 100644 index 0000000000..e5dd771e03 --- /dev/null +++ b/fairseq/models/speech_to_text/convtransformer_simul_trans.py @@ -0,0 +1,49 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +from examples.simultaneous_translation.models.transformer_monotonic_attention import ( + TransformerMonotonicDecoder, +) +from fairseq import checkpoint_utils +from fairseq.models import ( + register_model, + register_model_architecture, +) + +from .convtransformer import ConvTransformerModel, convtransformer_espnet + + +@register_model("convtransformer_simul_trans") +class SimulConvTransformerModel(ConvTransformerModel): + @staticmethod + def add_args(parser): + super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser) + parser.add_argument( + "--train-monotonic-only", + action="store_true", + default=False, + help="Only train monotonic attention", + ) + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + tgt_dict = task.tgt_dict + + decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens) + + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + +@register_model_architecture( + "convtransformer_simul_trans", "convtransformer_simul_trans_espnet" +) +def convtransformer_simul_trans_espnet(args): + convtransformer_espnet(args) From 523fe83828e6374439a6203330ed0e8c13e86b62 Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Thu, 18 Feb 2021 22:41:32 -0800 Subject: [PATCH 470/707] Integrate Simul ST model into pyspeech Summary: This diff integrates simul ST training into pyspeech with very minor modifications to the open sourced code. Specific changes made are - In fixed_pre_decision.py remove self as argument to p_choose function as it is already called with super in line 101 - In monotonic_multihead_attention.py remove pdb.set_trace() - Move label_smoothed_cross_entropy_latency_augmented.py to fairseq/criterions folder and add missing arguments to parser - In fairseq/data/data_utils.py type cast max_tokens to int to avoid type error. 
- Update fairseq/convtransformer.py to pyspeech/convtransformer.py # Next steps: - Verify decoding using the model trained - Support everstore handle based decoding in simuleval and integrate it into pyspeech. Reviewed By: jmp84 Differential Revision: D26478861 fbshipit-source-id: 3b02b2aee757e5464b71dbdd7ebdba42659faee5 --- .../modules/fixed_pre_decision.py | 1 - .../modules/monotonic_multihead_attention.py | 2 -- ...moothed_cross_entropy_latency_augmented.py | 22 ++++++++++++++++++- fairseq/data/data_utils.py | 5 ++++- .../models/speech_to_text/convtransformer.py | 6 +---- 5 files changed, 26 insertions(+), 10 deletions(-) rename {examples/simultaneous_translation => fairseq}/criterions/label_smoothed_cross_entropy_latency_augmented.py (86%) diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py index 2cde55b35e..725be1a983 100644 --- a/examples/simultaneous_translation/modules/fixed_pre_decision.py +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -99,7 +99,6 @@ def p_choose( if self.pre_decision_ratio == 1: return super().p_choose( - self, query, key, key_padding_mask=None, diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 5423f26c34..3e25957cd6 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -259,8 +259,6 @@ def expected_alignment_infer( finish_read = new_monotonic_step.eq(max_steps) | (action == 0) - if p_choose_i is None: - import pdb;pdb.set_trace() monotonic_cache["head_step"] = new_monotonic_step # Whether a head is looking for new input diff --git a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py 
b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py similarity index 86% rename from examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py rename to fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py index 761cfe61a1..aa3dba31e2 100644 --- a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -53,8 +53,28 @@ def add_args(parser): LatencyAugmentedLabelSmoothedCrossEntropyCriterion, LatencyAugmentedLabelSmoothedCrossEntropyCriterion, ).add_args(parser) - """Add criterion-specific arguments to the parser.""" # fmt: off + + """Add criterion-specific arguments to the parser.""" + parser.add_argument( + "--label-smoothing", + default=0.0, + type=float, + metavar="D", + help="epsilon for label smoothing, 0 means no label smoothing", + ) + parser.add_argument( + "--ignore_prefix_size", + default=0, + type=int, + help="ignore first N tokens", + ) + parser.add_argument( + "--report-accuracy", + default=False, + type=bool, + help="report accuracy metric", + ) parser.add_argument("--latency-weight-avg", default=0., type=float, metavar='D', help="Average loss weight") parser.add_argument("--latency-weight-var", default=0., type=float, metavar='D', diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 47d8492ec9..3042358f2f 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -313,7 +313,10 @@ def batch_by_size( " --editable .` or `python setup.py build_ext --inplace`." 
) - max_tokens = max_tokens if max_tokens is not None else -1 + # added int() to avoid TypeError: an integer is required + max_tokens = ( + int(max_tokens) if max_tokens is not None else -1 + ) max_sentences = max_sentences if max_sentences is not None else -1 bsz_mult = required_batch_size_multiple diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index 512ee78be0..06276e636a 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -307,11 +307,7 @@ def forward(self, src_tokens, src_lengths): subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) - input_lengths = min( - (src_lengths.float() / subsampling_factor).ceil().long(), - x.size(0) * src_lengths.new_ones([1]), - ) - + input_lengths = (src_lengths.float() / subsampling_factor).ceil().long() encoder_padding_mask, _ = lengths_to_encoder_padding_mask( input_lengths, batch_first=True ) From 675f608915a216ac32777928a0b1e8210cb66df6 Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Fri, 19 Feb 2021 08:59:37 -0800 Subject: [PATCH 471/707] Fix LibriSpeech data prep script Summary: Fix LibriSpeech data prep script * Lowercasing transcript to be consistent with the pre-trained models Reviewed By: jmp84 Differential Revision: D26538845 fbshipit-source-id: 0885f99e2c85f0e722a24f3cb83f2635ce9429bc --- examples/speech_to_text/prep_librispeech_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py index 6a6f55ded4..7b08447190 100644 --- a/examples/speech_to_text/prep_librispeech_data.py +++ b/examples/speech_to_text/prep_librispeech_data.py @@ -71,7 +71,7 @@ def process(args): manifest["audio"].append(zip_manifest[sample_id]) duration_ms = int(wav.size(1) / sample_rate * 1000) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) - 
manifest["tgt_text"].append(utt) + manifest["tgt_text"].append(utt.lower()) manifest["speaker"].append(spk_id) save_df_to_tsv( pd.DataFrame.from_dict(manifest), out_root / f"{split}.tsv" From 2909ee1852cdae7ad4115a1a04520b0522265dd2 Mon Sep 17 00:00:00 2001 From: "joseph.suh" <joseph.suh@netmarble.com> Date: Fri, 19 Feb 2021 10:07:43 -0800 Subject: [PATCH 472/707] Fix bug for issue (#3211) (#3212) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes KeyError mentioned in # (3211). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3212 Reviewed By: alexeib Differential Revision: D26513255 Pulled By: myleott fbshipit-source-id: 5a11cb369c9d4202fab6998d269e7da5f3d3e534 --- fairseq/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index e860fb1832..f66dc25e40 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -401,7 +401,7 @@ def load_checkpoint( self.lr_step(epoch) - if itr_state["version"] >= 2 and itr_state["iterations_in_epoch"] == 0: + if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0: # reset meters at start of epoch reset_meters = True From 3ef18886d0a802a8c8d90b57d858df3da7e75202 Mon Sep 17 00:00:00 2001 From: Alex Gaziev <alex.gaziev@gmail.com> Date: Fri, 19 Feb 2021 10:10:13 -0800 Subject: [PATCH 473/707] Remove extra arg min_length and fix min_sample_size behavior (#3249) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3178 (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 (I did ;) Pull Request resolved: https://github.com/pytorch/fairseq/pull/3249 Reviewed By: alexeib Differential Revision: D26513275 Pulled By: myleott fbshipit-source-id: 2785098a945404c07eb72c079177654b1739a7a2 --- fairseq/data/audio/raw_audio_dataset.py | 10 +++------- fairseq/tasks/audio_pretraining.py | 5 ++--- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index ac5acd03bb..1d92e4966b 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -23,9 +23,8 @@ def __init__( self, sample_rate, max_sample_size=None, - min_sample_size=None, + min_sample_size=0, shuffle=True, - min_length=0, pad=False, normalize=False, ): @@ -37,7 +36,6 @@ def __init__( max_sample_size if max_sample_size is not None else sys.maxsize ) self.min_sample_size = min_sample_size - self.min_length = min_length self.pad = pad self.shuffle = shuffle self.normalize = normalize @@ -136,9 +134,8 @@ def __init__( manifest_path, sample_rate, max_sample_size=None, - min_sample_size=None, + min_sample_size=0, shuffle=True, - min_length=0, pad=False, normalize=False, ): @@ -147,7 +144,6 @@ def __init__( max_sample_size=max_sample_size, min_sample_size=min_sample_size, shuffle=shuffle, - min_length=min_length, pad=pad, normalize=normalize, ) @@ -162,7 +158,7 @@ def __init__( items = line.strip().split("\t") assert len(items) == 2, line sz = int(items[1]) - if min_length is not None and sz < min_length: + if min_sample_size is not None and sz < min_sample_size: skipped += 1 continue self.fnames.append(items[0]) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 92685160d4..b7b5429819 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -57,7 +57,7 @@ class AudioPretrainingConfig(FairseqDataclass): default=None, metadata={"help": "max sample size to crop to 
for batching"} ) min_sample_size: Optional[int] = field( - default=None, metadata={"help": "min sample size to crop to for batching"} + default=None, metadata={"help": "min sample size to skip small examples"} ) # Options for reporting WER metrics during validation. Only applicable to @@ -135,8 +135,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): manifest, sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate), max_sample_size=self.cfg.max_sample_size, - min_sample_size=self.cfg.max_sample_size, - min_length=self.cfg.min_sample_size, + min_sample_size=self.cfg.min_sample_size, pad=task_cfg.labels is not None or task_cfg.enable_padding, normalize=task_cfg.normalize, ) From c6b5c00312dc23f473c66ba3016cc9e3decfd317 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Fri, 19 Feb 2021 10:31:08 -0800 Subject: [PATCH 474/707] fix criterion name check when resuming from checkpoint Summary: I tried resuming a run from a checkpoint in f250883864, but ran into: AssertionError: Criterion does not match; please reset the optimizer (--reset-optimizer). DistributedTimeoutWrapper vs ContrastiveLabelsCriterion Based on this, I believe since D25836853 (https://github.com/pytorch/fairseq/commit/d68a3530dda7f8275e490864b28974ef30fe854b) we are no longer saving the actual criterion's name, but DistributedTimeoutWrapper in the checkpoint. This is kind of weird though, as I would expect more people to run into this issue. Not sure if I am doing something wrong, let me know if so, thanks! 
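The mismatch described above can be reproduced in isolation. The sketch below is only an illustration under stated assumptions: `TimeoutWrapper`, the stand-in `ContrastiveLabelsCriterion` class, and the `unwrap` helper are hypothetical and are not fairseq code. It shows how the class name recorded for a checkpoint changes depending on whether the wrapper or the wrapped object is inspected.

```python
class ContrastiveLabelsCriterion:
    """Hypothetical stand-in for a real criterion (illustration only)."""


class TimeoutWrapper:
    """Hypothetical wrapper holding the real criterion in `.module`.

    It only mimics the role that DistributedTimeoutWrapper plays in the
    error message above; it is not the fairseq implementation.
    """

    def __init__(self, module):
        self.module = module


def unwrap(obj):
    """Return the wrapped object if there is one, else the object itself."""
    return getattr(obj, "module", obj)


criterion = TimeoutWrapper(ContrastiveLabelsCriterion())

# Name recorded when the wrapper itself is passed to the checkpoint writer.
saved_name_before = criterion.__class__.__name__

# Name recorded when the wrapper is peeled off first.
saved_name_after = unwrap(criterion).__class__.__name__

print(saved_name_before)
print(saved_name_after)
```

If the name check at load time compares against the unwrapped criterion, only the second name would match, which is consistent with the assertion failure quoted above.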
Reviewed By: myleott Differential Revision: D26478656 fbshipit-source-id: bc3c7c925f5505140d9df4438af3a73d65d4f531 --- fairseq/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index f66dc25e40..891155f162 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -284,7 +284,7 @@ def save_checkpoint(self, filename, extra_state): filename, self.cfg, self.model.state_dict(), - self.criterion, + self.get_criterion(), self.optimizer, self.lr_scheduler, self.get_num_updates(), @@ -375,10 +375,10 @@ def load_checkpoint( last_optim = self._optim_history[-1] assert ( last_optim["criterion_name"] == self.get_criterion().__class__.__name__ - ), "Criterion does not match; please reset the optimizer (--reset-optimizer)." + ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}" assert ( last_optim["optimizer_name"] == self.optimizer.__class__.__name__ - ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)." + ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). 
{last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) From ae22da652d63bd6e05a9a035f6a9dcabb1a39c73 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Fri, 19 Feb 2021 21:52:31 -0800 Subject: [PATCH 475/707] Correct the estimation of cnn output lengths in convtransformer (#1636) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1636 Reviewed By: xutaima Differential Revision: D26562816 Pulled By: jmp84 fbshipit-source-id: 4e6efd0b4236d7187bd365d790f260bd5297aed5 --- fairseq/models/speech_to_text/convtransformer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index 06276e636a..622b5e6df8 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -30,7 +30,6 @@ class ConvTransformerModel(FairseqEncoderDecoderModel): Transformer-based Speech translation model from ESPNet-ST https://arxiv.org/abs/2004.10234 """ - def __init__(self, encoder, decoder): super().__init__(encoder, decoder) @@ -307,7 +306,11 @@ def forward(self, src_tokens, src_lengths): subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) - input_lengths = (src_lengths.float() / subsampling_factor).ceil().long() + input_lengths = torch.min( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long() + ) + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( input_lengths, batch_first=True ) From 61e46bb99758e05bc990e3687c69b507a8ebf185 Mon Sep 17 00:00:00 2001 From: Frankie Robertson <frankie@robertson.name> Date: Sat, 20 Feb 2021 06:21:45 -0800 Subject: [PATCH 476/707] Fix attempt to unlink directory copied into source package (Python 3.9) (#3235) MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [N/A] Did you make sure to update the docs? - [N/A] Did you write any new necessary tests? ## What does this PR do? Currently when installing the newest source package from PyPI I get an error like so: ``` Collecting fairseq Using cached fairseq-0.10.2.tar.gz (938 kB) Installing build dependencies ... done Getting requirements to build wheel ... error ERROR: Command errored out with exit status 1: command: /home/frankier/sources/datasets/.venv/bin/python3 /tmp/tmp_ujftsgi_in_process.py get_requires_for_build_wheel /tmp/tmpmn0eumq2 cwd: /tmp/pip-install-dg5d6q9y/fairseq Complete output (31 lines): Traceback (most recent call last): File "setup.py", line 214, in <module> do_setup(package_data) File "setup.py", line 136, in do_setup setup( File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/__init__.py", line 152, in setup _install_setup_requires(attrs) File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/__init__.py", line 147, in _install_setup_requires dist.fetch_build_eggs(dist.setup_requires) File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 60, in fetch_build_eggs raise SetupRequirementsError(specifier_list) setuptools.build_meta.SetupRequirementsError: ['cython', 'numpy', 'setuptools>=18.0'] During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmp/tmp_ujftsgi_in_process.py", line 280, in <module> main() File "/tmp/tmp_ujftsgi_in_process.py", line 263, in main json_out['return_val'] = hook(**hook_input['kwargs']) File "/tmp/tmp_ujftsgi_in_process.py", line 114, in get_requires_for_build_wheel return 
hook(config_settings) File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 149, in get_requires_for_build_wheel return self._get_build_requires( File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 130, in _get_build_requires self.run_setup() File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 145, in run_setup exec(compile(code, __file__, 'exec'), locals()) File "setup.py", line 217, in <module> os.unlink(fairseq_examples) IsADirectoryError: [Errno 21] Is a directory: 'fairseq/examples' ---------------------------------------- ERROR: Command errored out with exit status 1: /home/frankier/sources/datasets/.venv/bin/python3 /tmp/tmp_ujftsgi_in_process.py get_requires_for_build_wheel /tmp/tmpmn0eumq2 Check the logs for full command output. ``` I believe the reason for this is that the source package contains the examples directory because it was put there during package creation (it seems the symlink became a directory). Now, when setup.py is run again, it seems the setup.py attempts to unlink the directory, which is not possible because only symlinks can be unlinked. This PR therefore only attempts to unlink it if it is a symlink. I have not thoroughly tested whether my proposed cause is the true cause, but this should fix it in any case. Note that the source package is fetched because there is no wheel for Python 3.9, so most users will not see this because they will use the wheel. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3235 Reviewed By: alexeib Differential Revision: D26513259 Pulled By: myleott fbshipit-source-id: 775d6c636a5867b9983bb6419829f13ee414e2fd --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d1a976104e..3670ff3cfc 100644 --- a/setup.py +++ b/setup.py @@ -256,5 +256,5 @@ def get_files(path, relative_to="fairseq"): } do_setup(package_data) finally: - if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples): + if "build_ext" not in sys.argv[1:] and os.path.islink(fairseq_examples): os.unlink(fairseq_examples) From 4cf7d76114d50008cdd98a7fde250d4ef99b66fe Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Sat, 20 Feb 2021 06:23:41 -0800 Subject: [PATCH 477/707] Hydra Integration doc should refer to non legacy task (#1619) Summary: # Before submitting - [NO] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [YES] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [YES] Did you make sure to update the docs? - [NO] Did you write any new necessary tests? ## What does this PR do? This is a typo fix to the Hydra Integration doc where the example with dataclass config should use `FairseqTask` and not `LegacyFairseqTask`. Didn't make an issue for this as it's a trivial doc change for the example to match the actual doc. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1619 Reviewed By: huihuifan Differential Revision: D26448855 Pulled By: Mortimerp9 fbshipit-source-id: 467323101b8425370f6bd7c0532e70abb319b337 --- docs/hydra_integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 04c797fe50..6a15298382 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -120,7 +120,7 @@ class LanguageModelingConfig(FairseqDataclass): ... @register_task("language_modeling", dataclass=LanguageModelingConfig) -class LanguageModelingTask(LegacyFairseqTask): +class LanguageModelingTask(FairseqTask): ... @classmethod def setup_task(cls, cfg: LanguageModelingConfig): From 38258a79a42f3ccfa596cc51bbf269cf13c3d799 Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Mon, 22 Feb 2021 13:55:06 -0800 Subject: [PATCH 478/707] Update FairseqSimulSTAgent to make it generic and reusable internally Summary: This diff 1. Updates FairseqSimulSTAgent to make it generic and reusable internally [Touches OSS] 2. Adds FBFairseqSimulSTAgent inheriting FairseqSimulSTAgent 3. Add TARGETS file in examples/speech_to_text 4. Update simuleval TARGETS and add a bento kernel for easy testing Reviewed By: jmp84 Differential Revision: D26573214 fbshipit-source-id: f4b71f90693cc878cc771b46a006bcbc83a50124 --- .../agents/fairseq_simul_st_agent.py | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index cbe8bc4322..32cd0a1f61 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -27,25 +27,18 @@ class OnlineFeatureExtractor: Extract speech feature on the fly. 
""" - def __init__( - self, - shift_size=SHIFT_SIZE, - window_size=WINDOW_SIZE, - sample_rate=SAMPLE_RATE, - feature_dim=FEATURE_DIM, - global_cmvn=None, - ): - self.shift_size = shift_size - self.window_size = window_size + def __init__(self, args): + self.shift_size = args.shift_size + self.window_size = args.window_size assert self.window_size >= self.shift_size - self.sample_rate = sample_rate - self.feature_dim = feature_dim - self.num_samples_per_shift = int(SHIFT_SIZE * SAMPLE_RATE / 1000) - self.num_samples_per_window = int(WINDOW_SIZE * SAMPLE_RATE / 1000) + self.sample_rate = args.sample_rate + self.feature_dim = args.feature_dim + self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000) + self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000) self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 self.previous_residual_samples = [] - self.global_cmvn = global_cmvn + self.global_cmvn = args.global_cmvn def clear_cache(self): self.previous_residual_samples = [] @@ -134,16 +127,15 @@ def __init__(self, args): self.load_model_vocab(args) - config_yaml = os.path.join(args.data_bin, "config.yaml") - with open(config_yaml, "r") as f: + with open(args.config, "r") as f: config = yaml.load(f) if "global_cmvn" in config: - global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) + args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) else: - global_cmvn = None + args.global_cmvn = None - self.feature_extractor = OnlineFeatureExtractor(global_cmvn=global_cmvn) + self.feature_extractor = OnlineFeatureExtractor(args) self.max_len = args.max_len @@ -164,6 +156,8 @@ def add_args(parser): help='path to your pretrained model.') parser.add_argument("--data-bin", type=str, required=True, help="Path of data binary") + parser.add_argument("--config", type=str, required=True, + help="Path to config yaml file") parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", help="Subword 
splitter type for target text") parser.add_argument("--tgt-splitter-path", type=str, default=None, @@ -174,9 +168,21 @@ def add_args(parser): help="Max length of translation") parser.add_argument("--force-finish", default=False, action="store_true", help="") + parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, + help="") + parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, + help="") + parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, + help="") + parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, + help="") + # fmt: on return parser + def set_up_task(self, task_args): + return tasks.setup_task(task_args) + def load_model_vocab(self, args): filename = args.model_path @@ -188,7 +194,7 @@ def load_model_vocab(self, args): task_args = state["cfg"]["task"] task_args.data = args.data_bin - task = tasks.setup_task(task_args) + task = self.set_up_task(task_args) # build model for ensemble self.model = task.build_model(state["cfg"]["model"]) From 808b751597d85c098990080d21fd450877dcb242 Mon Sep 17 00:00:00 2001 From: Miguel Del-Agua <miguel.delagua@nuance.com> Date: Mon, 22 Feb 2021 14:21:36 -0800 Subject: [PATCH 479/707] Improve torchscript compatibility of transfomer and transformer pg (#3247) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3246 Fixes https://github.com/pytorch/fairseq/issues/3248 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. 
## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3247 Reviewed By: myleott Differential Revision: D26513267 Pulled By: lematt1991 fbshipit-source-id: 958de0b3a58a0dd2a56bd6c6d7fb2644a89f6746 --- .../pointer_generator_src/transformer_pg.py | 80 +++++++++++++++---- fairseq/models/fairseq_decoder.py | 13 +++ fairseq/models/transformer.py | 47 +++++++++-- tests/test_export.py | 13 +++ 4 files changed, 133 insertions(+), 20 deletions(-) diff --git a/examples/pointer_generator/pointer_generator_src/transformer_pg.py b/examples/pointer_generator/pointer_generator_src/transformer_pg.py index fb40a80836..e109a8e269 100644 --- a/examples/pointer_generator/pointer_generator_src/transformer_pg.py +++ b/examples/pointer_generator/pointer_generator_src/transformer_pg.py @@ -4,13 +4,12 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List, Tuple import torch import torch.nn as nn from fairseq import metrics, utils from fairseq.models import register_model, register_model_architecture -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS, @@ -155,7 +154,13 @@ class TransformerPointerGeneratorEncoder(TransformerEncoder): to the decoder. """ - def forward(self, src_tokens, src_lengths, **kwargs): + def forward( + self, + src_tokens, + src_lengths: Optional[Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[Tensor] = None + ): """ Runs the `forward()` method of the parent Transformer class. Then adds the source tokens into the encoder output tuple. 
@@ -169,6 +174,10 @@ def forward(self, src_tokens, src_lengths, **kwargs): shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings Returns: namedtuple: @@ -184,7 +193,15 @@ def forward(self, src_tokens, src_lengths, **kwargs): - **src_tokens** (Tensor): input token ids of shape `(batch, src_len)` """ - encoder_out = super().forward(src_tokens, src_lengths, **kwargs) + encoder_out = self.forward_scriptable(src_tokens, + src_lengths, + return_all_hiddens, + token_embeddings) + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. return { "encoder_out": encoder_out["encoder_out"], # T x B x C "encoder_padding_mask": encoder_out["encoder_padding_mask"], # B x T @@ -236,7 +253,7 @@ def __init__(self, args, dictionary, embed_tokens): def forward( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, alignment_layer: Optional[int] = 0, @@ -248,8 +265,8 @@ def forward( Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing - encoder_out (EncoderOut, optional): output from the encoder, used - for encoder-side attention + encoder_out (optional): output from the encoder, used for + encoder-side attention incremental_state (dict, optional): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without @@ -284,10 +301,21 @@ 
def forward( predictors = torch.cat((prev_output_embed, x), 2) p_gens = self.project_p_gens(predictors) p_gens = torch.sigmoid(p_gens) - x = self.output_layer(x, extra["attn"][0], encoder_out["src_tokens"][0], p_gens) + # Torchscript complains if encoder_out or attn are None because + # `output_layer()` signature expects tensors instead + attn: Optional[Tensor] = extra["attn"][0] + assert encoder_out is not None + assert attn is not None + x = self.output_layer(x, attn, encoder_out["src_tokens"][0], p_gens) return x, extra - def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): + def output_layer( + self, + features: Tensor, + attn: Tensor, + src_tokens: Tensor, + p_gens: Tensor + ) -> Tensor: """ Project features to the vocabulary size and mix with the attention distributions. @@ -296,7 +324,10 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): p_gens = self.force_p_gen # project back to size of vocabulary - logits = super().output_layer(features, **kwargs) + if self.adaptive_softmax is None: + logits = self.output_projection(features) + else: + logits = features batch_size = logits.shape[0] output_length = logits.shape[1] @@ -306,7 +337,7 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): # The final output distribution will be a mixture of the normal output # distribution (softmax of logits) and attention weights. - gen_dists = super().get_normalized_probs( + gen_dists = self.get_normalized_probs_scriptable( (logits, None), log_probs=False, sample=None ) gen_dists = torch.mul(gen_dists, p_gens) @@ -330,7 +361,12 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): # Final distributions, [batch_size, output_length, num_types]. 
return gen_dists + attn_dists - def get_normalized_probs(self, net_output, log_probs, sample): + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): """ Get normalized probabilities (or log probs) from a net's output. Pointer-generator network output is already normalized. @@ -375,8 +411,19 @@ class Embedding(nn.Embedding): """ __constants__ = ["unk_idx"] - def __init__(self, num_embeddings, embedding_dim, padding_idx, unk_idx): - super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) + # Torchscript: Inheriting from Embedding class produces an error when exporting to Torchscript + # -> RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details + # It's happening because max_norm attribute from nn.Embedding is None by default and it cannot be + # cast to a C++ type + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int], + unk_idx: int, + max_norm: Optional[float] = float("inf"), + ): + super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, max_norm=max_norm) self.unk_idx = unk_idx nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5) nn.init.constant_(self.weight[padding_idx], 0) @@ -385,7 +432,10 @@ def forward(self, input): input = torch.where( input >= self.num_embeddings, torch.ones_like(input) * self.unk_idx, input ) - return super().forward(input) + return nn.functional.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) @register_model_architecture( diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py index 7eeb5c652f..4f1e8b52a2 100644 --- a/fairseq/models/fairseq_decoder.py +++ b/fairseq/models/fairseq_decoder.py @@ -64,6 +64,19 @@ def get_normalized_probs( sample: Optional[Dict[str, Tensor]] = None, ): """Get 
normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: if sample is not None: diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 605cfa65e8..f2f36baf3e 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -422,6 +422,45 @@ def forward( token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + return self.forward_scriptable(src_tokens, + src_lengths, + return_all_hiddens, + token_embeddings) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. 
+ def forward_scriptable( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of @@ -787,13 +826,11 @@ def extract_features_scriptable( alignment_layer = self.num_layers - 1 # embed positions - positions = ( - self.embed_positions( + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state ) - if self.embed_positions is not None - else None - ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] diff --git a/tests/test_export.py b/tests/test_export.py index 87e52bd7c1..b380697b9a 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -103,6 +103,19 @@ def test_export_transformer(self): scripted = torch.jit.script(model) _test_save_and_load(scripted) + @unittest.skipIf( + torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" + ) + def test_export_transformer_no_token_pos_emb(self): + task, parser = get_dummy_task_and_parser() + TransformerModel.add_args(parser) + args = parser.parse_args([]) + args.no_token_positional_embeddings = True + model = TransformerModel.build_model(args, task) + scripted = torch.jit.script(model) + _test_save_and_load(scripted) + + if __name__ == "__main__": unittest.main() From 89cd70c0f0c096bdbfcfb2ab339a9c8f23540bc0 Mon Sep 17 00:00:00 2001 From: m_fomicheva <mari.fomicheva@gmail.com> Date: Mon, 22 Feb 
2021 14:55:33 -0800 Subject: [PATCH 480/707] Fixed scripts and instructions for reproducing the results. (#3264) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [N] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [Y] Did you make sure to update the docs? - [N] Did you write any new necessary tests? ## What does this PR do? Small fixes in the script and documentation for correctly reproducing the results in the corresponding paper. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3264 Reviewed By: lematt1991 Differential Revision: D26587397 Pulled By: myleott fbshipit-source-id: 3675ec4d4388cafa224d395e08b53667f142cb27 --- examples/unsupervised_quality_estimation/README.md | 6 +++--- examples/unsupervised_quality_estimation/meteor.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/unsupervised_quality_estimation/README.md b/examples/unsupervised_quality_estimation/README.md index aeb96a14b1..e86a0d13b8 100644 --- a/examples/unsupervised_quality_estimation/README.md +++ b/examples/unsupervised_quality_estimation/README.md @@ -55,7 +55,7 @@ Translate ``` CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out -grep ^H $TMP/fairseq.out | cut -f3- > $TMP/mt.out +grep ^H $TMP/fairseq.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/mt.out ``` Post-process @@ -88,7 +88,7 @@ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate 
${TMP}/bin-repeated --path ${MODEL_ --retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]' TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out -grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores +grep ^H $TMP/dropout.scoring.out | cut -d- -f2- | sort -n | cut -f2 > $TMP/dropout.scores ``` @@ -112,7 +112,7 @@ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_ --unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out -grep ^H $TMP/dropout.generation.out | cut -f3- > $TMP/dropout.hypotheses_ +grep ^H $TMP/dropout.generation.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/dropout.hypotheses_ sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl -l $TGT_LANG > $TMP/dropout.hypotheses diff --git a/examples/unsupervised_quality_estimation/meteor.py b/examples/unsupervised_quality_estimation/meteor.py index 4a214e794d..2ee0448cf1 100644 --- a/examples/unsupervised_quality_estimation/meteor.py +++ b/examples/unsupervised_quality_estimation/meteor.py @@ -85,19 +85,19 @@ def read_output(meteor_output_path, n_repeats): def main(): parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input") + parser.add_argument("-i", "--infile") parser.add_argument("-n", "--repeat_times", type=int) parser.add_argument("-m", "--meteor") parser.add_argument("-o", "--output") args = parser.parse_args() - translations = read_translations(args.infile, args.repetitions) + translations = read_translations(args.infile, args.repeat_times) sys.stderr.write("\nGenerating input for Meteor...") - ref_path, mt_path = generate_input(translations, args.repetitions) + ref_path, mt_path = generate_input(translations, args.repeat_times) sys.stderr.write("\nRunning Meteor...") out_path = run_meteor(ref_path, 
mt_path, args.meteor) sys.stderr.write("\nReading output...") - scores = read_output(out_path, args.repetitions) + scores = read_output(out_path, args.repeat_times) sys.stderr.write("\nWriting results...") with open(args.output, "w") as o: for scr in scores: From b9778da42643f5b20fa0a555834d49537ce165c0 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Mon, 22 Feb 2021 15:00:15 -0800 Subject: [PATCH 481/707] Small fixes for flow-cli usage Summary: - Use `PathManager.ls` instead of `os.listdir` - Add version.txt to fairseq TARGETS Reviewed By: vishrav Differential Revision: D26579091 fbshipit-source-id: 20d57dc19335a3006cd5fa6d1a3d5e878b105874 --- fairseq/data/data_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 3042358f2f..6f7561afbe 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -17,6 +17,8 @@ import numpy as np import torch +from fairseq.file_io import PathManager + logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ def infer_language_pair(path): """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx""" src, dst = None, None - for filename in os.listdir(path): + for filename in PathManager.ls(path): parts = filename.split(".") if len(parts) >= 3 and len(parts[1].split("-")) == 2: return parts[1].split("-") From ab560669cd9baaa4009e1fd01c970f8ffccd1ee0 Mon Sep 17 00:00:00 2001 From: freewym <freewym@gmail.com> Date: Mon, 22 Feb 2021 15:36:56 -0800 Subject: [PATCH 482/707] Fixes circular import as complained by python (#3257) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? 
- [x] Did you write any new necessary tests? ## What does this PR do? fixes circular import as complained by python ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3257 Reviewed By: jmp84 Differential Revision: D26587382 Pulled By: myleott fbshipit-source-id: a8a6e7bee4dcfa6baf934c257958b7d7592205c8 --- .../models/speech_to_text/convtransformer_simul_trans.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fairseq/models/speech_to_text/convtransformer_simul_trans.py b/fairseq/models/speech_to_text/convtransformer_simul_trans.py index e5dd771e03..7e77330a0c 100644 --- a/fairseq/models/speech_to_text/convtransformer_simul_trans.py +++ b/fairseq/models/speech_to_text/convtransformer_simul_trans.py @@ -5,9 +5,6 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
-from examples.simultaneous_translation.models.transformer_monotonic_attention import ( - TransformerMonotonicDecoder, -) from fairseq import checkpoint_utils from fairseq.models import ( register_model, @@ -33,6 +30,10 @@ def add_args(parser): def build_decoder(cls, args, task, embed_tokens): tgt_dict = task.tgt_dict + from examples.simultaneous_translation.models.transformer_monotonic_attention import ( + TransformerMonotonicDecoder, + ) + decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens) if getattr(args, "load_pretrained_decoder_from", None): From c3d2beec96bd609f87d8da14cc2dffdbbd843b54 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Tue, 23 Feb 2021 23:32:40 -0800 Subject: [PATCH 483/707] efficient batch level sampling Summary: Batch level sampling (each batch comes from a dataset sampled from some distribution) is useful in cases where we have a criterion that makes this assumption or a unique collator per dataset. However, the current implementation in fairseq `MultiCorpusSampledDataset` is inefficient, because it packs batches by assuming the size of item i is `max(dataset.size(i % len(dataset)) for dataset in datasets)`, which often significantly overestimates the actual sampled item's size, especially with many datasets. We can make this more efficient by modifying `MultiCorpusDataset`, which can do efficient batch sampling by: 1. Every epoch, sampling the indices/dataset to train on. 2. 
When creating batches, create per-dataset batches and merge them together Reviewed By: jay-mahadeokar Differential Revision: D26601515 fbshipit-source-id: a3273f88d86d7922f9ba004e7324e909ecc6ecf7 --- fairseq/data/multi_corpus_dataset.py | 49 +++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 7207174bf3..6563713489 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -35,6 +35,7 @@ class MultiCorpusDataset(FairseqDataset): corresponding dataset seed: random seed for sampling the datsets sort_indices: if true, will sort the ordered indices by size + batch_sample: if true, will ensure each batch is from a single dataset """ def __init__( @@ -43,6 +44,7 @@ def __init__( distribution: List[float], seed: int, sort_indices: bool = False, + batch_sample: bool = False, ): super().__init__() assert isinstance(datasets, OrderedDict) @@ -51,6 +53,7 @@ def __init__( self.distribution = distribution self.seed = seed self.sort_indices = sort_indices + self.batch_sample = batch_sample # Avoid repeated conversions to list later self.dataset_list = list(datasets.values()) @@ -80,6 +83,7 @@ def ordered_indices(self): ] if self.sort_indices: sampled_indices.sort(key=lambda i: self.num_tokens(i)) + return np.array(sampled_indices, dtype=np.int64) def _sample(self, indices, counters): @@ -125,22 +129,26 @@ def __len__(self): return self.total_num_instances def __getitem__(self, index): - index, key = self._map_index(index) + new_index, key = self._map_index(index) try: - return self.datasets[key][index] + item = self.datasets[key][new_index] + item["full_id"] = index + return item except Exception as e: e.args = (f"Error from {key} dataset", *e.args) raise def collater(self, samples): """ - Since we enforce all datsets to be the same, collating is just - picking the first one and doing collate. 
+ If we are doing batch sampling, then pick the right collater to use. + + Otherwise we assume all collaters are the same. """ if len(samples) == 0: return None + _, key = self._map_index(samples[0]["full_id"]) - return list(self.datasets.values())[0].collater(samples) + return self.datasets[key].collater(samples) def num_tokens(self, index: int): index, key = self._map_index(index) @@ -168,3 +176,34 @@ def supports_fetch_outside_dataloader(self): self.datasets[key].supports_fetch_outside_dataloader for key in self.datasets ) + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + if not self.batch_sample: + return super().batch_by_size( + indices, max_tokens, max_sentences, required_batch_size_multiple + ) + + dataset_indices = {key: [] for key in self.datasets} + for i in indices: + _, key = self._map_index(i) + dataset_indices[key].append(i) + + batches = [] + for key in dataset_indices: + cur_batches = super().batch_by_size( + np.array(dataset_indices[key], dtype=np.int64), + max_tokens, + max_sentences, + required_batch_size_multiple, + ) + logger.info(f"Created {len(cur_batches)} batches for dataset {key}") + batches += cur_batches + + # Assume shuffling is handled in fairseq/data/iterators.py + return batches From 55e48f18fee765fc4d528650570b8af0133ac074 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Wed, 24 Feb 2021 11:22:27 -0800 Subject: [PATCH 484/707] downcast indices in TokenBlockDataset (#1647) Summary: ### Measurements TLDR: This saves ~8% CPU RAM for training tiny model on medium sized dataset (11GB on disk) Command below: ``` +---------------------+----------------+---------+--------+ | fname | cpu_mem_used | wps | ppl | +=====================+================+=========+========+ +---------------------+----------------+---------+--------+ | branch_nw8_2gpu.log | 25.41 | 54721 | 429.1 | +---------------------+----------------+---------+--------+ 
+---------------------+----------------+---------+--------+ | master_nw8_2gpu.log | 27.53 | 52833.1 | 429.1 | +---------------------+----------------+---------+--------+ ``` ### Command ``` base_cmd () { dd=$1 shift fairseq-train --fp16 $dd \ --task language_modeling \ --arch transformer_lm_gpt2_tiny \ --sample-break-mode complete --tokens-per-sample 512 \ --optimizer adam --clip-norm 0.0 --lr 0.0005 \ --batch-size 1 \ --max-update 200 --max-epoch 1 \ --log-format simple --log-interval 100 \ --restore-file x.pt --no-save \ --skip-invalid-size-inputs-valid-test --disable-validation $@ } CUDA_VISIBLE_DEVICES=0,1 base_cmd /private/home/sshleifer/data-bin/stories_mmap --num-workers 8 ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1647 Reviewed By: myleott Differential Revision: D26628861 Pulled By: sshleifer fbshipit-source-id: 142afe0358d1c4cae448828ba811b211406509d7 --- fairseq/data/indexed_dataset.py | 37 +++++++++++++++++++---------- fairseq/data/token_block_dataset.py | 11 +++++---- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index a821417321..066f4dcd4f 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -15,12 +15,21 @@ from . 
import FairseqDataset +from typing import Union -def __best_fitting_dtype(vocab_size=None): - if vocab_size is not None and vocab_size < 65500: + +def best_fitting_uint_dtype( + max_int_to_represent, +) -> Union[np.uint16, np.uint32, np.uint64]: + + if max_int_to_represent is None: + return np.uint32 # Safe guess + elif max_int_to_represent < 65500: return np.uint16 + elif max_int_to_represent < 4294967295: + return np.uint32 else: - return np.int32 + return np.uint64 def get_available_dataset_impl(): @@ -48,7 +57,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): if impl == "mmap": return MMapIndexedDatasetBuilder( - out_file, dtype=__best_fitting_dtype(vocab_size) + out_file, dtype=best_fitting_uint_dtype(vocab_size) ) elif impl == "fasta": raise NotImplementedError @@ -92,7 +101,7 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) -dtypes = { +_code_to_dtype = { 1: np.uint8, 2: np.int8, 3: np.int16, @@ -101,12 +110,14 @@ def write_longs(f, a): 6: np.float, 7: np.double, 8: np.uint16, + 9: np.uint32, + 10: np.uint64, } -def code(dtype): - for k in dtypes.keys(): - if dtypes[k] == dtype: +def _dtype_header_code(dtype) -> int: + for k in _code_to_dtype.keys(): + if _code_to_dtype[k] == dtype: return k raise ValueError(dtype) @@ -141,7 +152,7 @@ def read_index(self, path): version = f.read(8) assert struct.unpack("<Q", version) == (1,) code, self.element_size = struct.unpack("<QQ", f.read(16)) - self.dtype = dtypes[code] + self.dtype = _code_to_dtype[code] self._len, self.s = struct.unpack("<QQ", f.read(16)) self.dim_offsets = read_longs(f, self._len + 1) self.data_offsets = read_longs(f, self._len + 1) @@ -348,7 +359,9 @@ def finalize(self, index_file): index = open(index_file, "wb") index.write(b"TNTIDX\x00\x00") index.write(struct.pack("<Q", 1)) - index.write(struct.pack("<QQ", code(self.dtype), self.element_size)) + index.write( + struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size) + ) 
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes))) write_longs(index, self.dim_offsets) write_longs(index, self.data_offsets) @@ -374,7 +387,7 @@ def __enter__(self): self._file.write(cls._HDR_MAGIC) self._file.write(struct.pack("<Q", 1)) - self._file.write(struct.pack("<B", code(dtype))) + self._file.write(struct.pack("<B", _dtype_header_code(dtype))) return self @@ -419,7 +432,7 @@ def __init__(self, path): assert (1,) == version (dtype_code,) = struct.unpack("<B", stream.read(1)) - self._dtype = dtypes[dtype_code] + self._dtype = _code_to_dtype[dtype_code] self._dtype_size = self._dtype().itemsize self._len = struct.unpack("<Q", stream.read(8))[0] diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index aa33f9d06f..038f1c81d7 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -6,7 +6,7 @@ import numpy as np import torch from fairseq.data import FairseqDataset, plasma_utils - +from fairseq.data.indexed_dataset import best_fitting_uint_dtype class TokenBlockDataset(FairseqDataset): """Break a Dataset of tokens into blocks. 
@@ -98,9 +98,12 @@ def __init__( sizes, slice_indices, ) - self._slice_indices = plasma_utils.PlasmaArray(slice_indices) - self._sizes = plasma_utils.PlasmaArray(self._sizes) - self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index) + size_dtype = np.uint16 if block_size < 65535 else np.uint32 + slice_indices_dtype = best_fitting_uint_dtype(slice_indices[-1].max()) + + self._slice_indices = plasma_utils.PlasmaArray(slice_indices.astype(slice_indices_dtype)) + self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype)) + self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index.astype(slice_indices_dtype)) @property def slice_indices(self): From 5c008e0c339ba932b551d18b0801c201e8fdf5a9 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Wed, 24 Feb 2021 11:25:41 -0800 Subject: [PATCH 485/707] make LanguageModelingTask 1% simpler (#1641) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1641 Reviewed By: myleott Differential Revision: D26607648 Pulled By: sshleifer fbshipit-source-id: 9d7f9d7a0825e3124c181b651a126842e5de6109 --- fairseq/tasks/language_modeling.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 4a44d967b3..579bf69785 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -184,7 +184,9 @@ def build_model(self, args): return model - def load_dataset(self, split, epoch=1, combine=False, **kwargs): + def load_dataset( + self, split: str, epoch=1, combine=False, **kwargs + ) -> MonolingualDataset: """Load a given dataset split. 
Args: @@ -228,7 +230,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): and self.args.sample_break_mode != "none" ) - self.datasets[split] = self._initialize_dataset( + self.datasets[split] = MonolingualDataset( dataset=dataset, sizes=dataset.sizes, src_vocab=self.dictionary, @@ -239,9 +241,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): add_bos_token=self.args.add_bos_token, ) - def _initialize_dataset(self, **kwargs): - return MonolingualDataset(**kwargs) - def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): """ Generate batches for inference. We prepend an eos token to src_tokens From 52daa1b29b35c93ffb950e56507c9c1d17aa2369 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Wed, 24 Feb 2021 14:21:24 -0800 Subject: [PATCH 486/707] move code to .py files, document usage (#1637) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1637 Test Plan: ```bash python examples/bart/summarize.py --model-dir pytorch/fairseq --model-file bart.large.cnn --src $HOME/data-bin/cnn_dm/test.source --n 12 --out hub_hypo.txt python examples/bart/summarize.py \ --model-dir pytorch/fairseq \ --model-file bart.large.cnn \ --src cnn_dm/test.source \ --out cnn_dm/test.hypo --xsum-kwargs ``` Reviewed By: ngoyal2707 Differential Revision: D26581703 Pulled By: sshleifer fbshipit-source-id: 80eb28012f7770eee01ed50a1163c5a2c5cc6d37 --- examples/bart/README.md | 47 +++++------- examples/bart/README.summarization.md | 55 +++++--------- examples/bart/summarize.py | 100 ++++++++++++++++++++++++++ fairseq/sequence_generator.py | 6 +- 4 files changed, 136 insertions(+), 72 deletions(-) create mode 100644 examples/bart/summarize.py diff --git a/examples/bart/README.md b/examples/bart/README.md index 013a809be6..4050a724ee 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -179,38 +179,23 @@ with open('glue_data/MNLI/dev_matched.tsv') as fin: ``` #### Evaluating the 
`bart.large.cnn` model: -Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample. +- Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample. +- For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores +- `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search. + In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`. -```python -bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn') -bart.cuda() -bart.eval() -bart.half() -count = 1 -bsz = 32 -with open('test.source') as source, open('test.hypo', 'w') as fout: - sline = source.readline().strip() - slines = [sline] - for sline in source: - if count % bsz == 0: - with torch.no_grad(): - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() - slines = [] - - slines.append(sline.strip()) - count += 1 - if slines != []: - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() -``` - -Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge). 
+In `fairseq`, summaries can be generated using: + +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir pytorch/fairseq \ + --model-file bart.large.cnn \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo +``` + +For calculating rouge, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge). ```bash export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar diff --git a/examples/bart/README.summarization.md b/examples/bart/README.summarization.md index d7fecc9ce6..8727584f2b 100644 --- a/examples/bart/README.summarization.md +++ b/examples/bart/README.summarization.md @@ -80,42 +80,23 @@ Expected training time is about `5 hours`. Training time can be reduced with dis Use TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2 for Xsum task ### Inference for CNN-DM test data using above trained checkpoint. -After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet: +After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using `eval_cnn.py`, for example -```python -import torch -from fairseq.models.bart import BARTModel - -bart = BARTModel.from_pretrained( - 'checkpoints/', - checkpoint_file='checkpoint_best.pt', - data_name_or_path='cnn_dm-bin' -) - -bart.cuda() -bart.eval() -bart.half() -count = 1 -bsz = 32 -with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout: - sline = source.readline().strip() - slines = [sline] - for sline in source: - if count % bsz == 0: - with torch.no_grad(): - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() - slines = [] - - slines.append(sline.strip()) - count += 1 - if slines != []: - hypotheses_batch = 
bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir checkpoints \ + --model-file checkpoint_best.pt \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo +``` +For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10: +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir checkpoints \ + --model-file checkpoint_best.pt \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo \ + --xsum-kwargs ``` -Use beam=6, lenpen=1.0, max_len_b=60, min_len=10 for Xsum Generation diff --git a/examples/bart/summarize.py b/examples/bart/summarize.py new file mode 100644 index 0000000000..04435f80e3 --- /dev/null +++ b/examples/bart/summarize.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from fairseq.models.bart import BARTModel +import argparse + +XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3) +CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) + + +@torch.no_grad() +def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs): + count = 1 + + # if n_obs is not None: bsz = min(bsz, n_obs) + + with open(infile) as source, open(outfile, "w") as fout: + sline = source.readline().strip() + slines = [sline] + for sline in source: + if n_obs is not None and count > n_obs: + break + if count % bsz == 0: + hypotheses_batch = bart.sample(slines, **eval_kwargs) + for hypothesis in hypotheses_batch: + fout.write(hypothesis + "\n") + fout.flush() + slines = [] + + slines.append(sline.strip()) + count += 1 + + if slines != []: + hypotheses_batch = bart.sample(slines, **eval_kwargs) + for hypothesis in hypotheses_batch: + fout.write(hypothesis + "\n") + fout.flush() + + +def main(): + """ + Usage:: + + python examples/bart/summarize.py \ + --model-dir $HOME/bart.large.cnn \ + --model-file model.pt \ + --src $HOME/data-bin/cnn_dm/test.source + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", + required=True, + type=str, + default="bart.large.cnn/", + help="path containing model file and src_dict.txt", + ) + parser.add_argument( + "--model-file", + default="checkpoint_best.pt", + help="where in model_dir are weights saved", + ) + parser.add_argument( + "--src", default="test.source", help="text to summarize", type=str + ) + parser.add_argument( + "--out", default="test.hypo", help="where to save summaries", type=str + ) + parser.add_argument("--bsz", default=32, help="where to save summaries", type=int) + parser.add_argument( + "--n", default=None, help="how many examples to summarize", type=int + ) + parser.add_argument( + "--xsum-kwargs", + action="store_true", + default=False, + help="if true use XSUM_KWARGS 
else CNN_KWARGS", + ) + args = parser.parse_args() + eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS + if args.model_dir == "pytorch/fairseq": + bart = torch.hub.load("pytorch/fairseq", args.model_file) + else: + bart = BARTModel.from_pretrained( + args.model_dir, + checkpoint_file=args.model_file, + data_name_or_path=args.model_dir, + ) + bart = bart.eval() + if torch.cuda.is_available(): + bart = bart.cuda().half() + generate( + bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs + ) + + +if __name__ == "__main__": + main() diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 117c6116fb..2574ab13f0 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -214,7 +214,7 @@ def _generate( raise Exception("expected src_tokens or source in net input") # bsz: total number of sentences in beam - # Note that src_tokens may have more than 2 dimenions (i.e. audio features) + # Note that src_tokens may have more than 2 dimensions (i.e. 
audio features) bsz, src_len = src_tokens.size()[:2] beam_size = self.beam_size @@ -376,9 +376,7 @@ def _generate( self.search.set_src_lengths(src_lengths) if self.repeat_ngram_blocker is not None: - lprobs = self.repeat_ngram_blocker( - tokens, lprobs, bsz, beam_size, step - ) + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( From fb3fadbb159d8af6d83a5680674d20f7b7635766 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 24 Feb 2021 15:41:02 -0800 Subject: [PATCH 487/707] Set DynamicLossScaler class defaults to match CLI defaults (#1649) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1649 Reviewed By: stephenroller Differential Revision: D26639303 Pulled By: myleott fbshipit-source-id: 7def925cd7885cfe85d542464316cbc0f2ba6d2c --- fairseq/optim/dynamic_loss_scaler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py index c5da604220..43f9be37b9 100644 --- a/fairseq/optim/dynamic_loss_scaler.py +++ b/fairseq/optim/dynamic_loss_scaler.py @@ -10,7 +10,7 @@ def __init__( init_scale=2.0 ** 15, scale_factor=2.0, scale_window=2000, - tolerance=0.05, + tolerance=0.0, threshold=None, min_loss_scale=1e-4, ): From b8651bc984413e7e45f44294dffcc85692ba89c1 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Wed, 24 Feb 2021 15:48:38 -0800 Subject: [PATCH 488/707] actually checking gradnorm consistency Summary: D24849271 (https://github.com/pytorch/fairseq/commit/3c5647cebf454c07b52a0fb899c920789381ebda) fixed finite check, but the 'or' condition means as long as all gradients are finite, the check will pass. This diff adds back the consistency check, the norm can't differ from each other much. 
Reviewed By: myleott Differential Revision: D26640459 fbshipit-source-id: 3e23e13841372aa04461dcde245b893715480c3c --- fairseq/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 891155f162..680a7ee953 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1113,7 +1113,7 @@ def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) return ( torch.isfinite(tensor).all() - or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() ) if not is_consistent(self._grad_norm_buf): From d3890e593398c485f6593ab8512ac51d37dedc9c Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Wed, 24 Feb 2021 22:55:37 -0800 Subject: [PATCH 489/707] Add HiveScorer to read data from hive and EverstoreAudioInstance to load audio from everstore Summary: This diff - Refactors utils/agent_finder.py to reduce the complexity of find_agent_cls function - Refactors cli.py and server.py to remove unnecessary argument parser function calls - Adds fb_hive_scorer.py with HiveScorer to read data from hive and process everstore handles - Adds fb_options.py to add necessary arguments for HiveScorer - Updates other parts of the code to include the new scorer Reviewed By: jmp84 Differential Revision: D26575148 fbshipit-source-id: ae6e12d2adf5f393f807d5238f0d78a2f64a77a3 --- .../simultaneous_translation/agents/fairseq_simul_st_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 32cd0a1f61..5793609095 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -128,7 +128,7 @@ def __init__(self, args): self.load_model_vocab(args) 
with open(args.config, "r") as f: - config = yaml.load(f) + config = yaml.load(f, Loader=yaml.BaseLoader) if "global_cmvn" in config: args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) From f569c024ae6ee3e8c37c3b9dca975a3df50f7a03 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Thu, 25 Feb 2021 22:33:48 -0800 Subject: [PATCH 490/707] Relocate simultaneous translation code (#1639) Summary: Relocate simultaneous translation code from example/simultaneous_translation to fairseq/model/simultaneous_translation, only keep the documents Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1639 Reviewed By: jmp84 Differential Revision: D26599346 Pulled By: xutaima fbshipit-source-id: 4f708d172696a430bd4e7b14871f5c8862a20489 --- examples/simultaneous_translation/__init__.py | 2 +- .../criterions/__init__.py | 15 --------------- .../models}/convtransformer_simul_trans.py | 10 +++++++++- fairseq/models/speech_to_text/__init__.py | 1 - fairseq/models/speech_to_text/convtransformer.py | 8 ++------ 5 files changed, 12 insertions(+), 24 deletions(-) delete mode 100644 examples/simultaneous_translation/criterions/__init__.py rename {fairseq/models/speech_to_text => examples/simultaneous_translation/models}/convtransformer_simul_trans.py (83%) diff --git a/examples/simultaneous_translation/__init__.py b/examples/simultaneous_translation/__init__.py index 446fc86c8a..5835316ba9 100644 --- a/examples/simultaneous_translation/__init__.py +++ b/examples/simultaneous_translation/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import criterions, eval, models # noqa +from . 
import models # noqa diff --git a/examples/simultaneous_translation/criterions/__init__.py b/examples/simultaneous_translation/criterions/__init__.py deleted file mode 100644 index 08791bfff3..0000000000 --- a/examples/simultaneous_translation/criterions/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import importlib -import os - - -for file in os.listdir(os.path.dirname(__file__)): - if file.endswith(".py") and not file.startswith("_"): - criterion_name = file[: file.find(".py")] - importlib.import_module( - "examples.simultaneous_translation.criterions." + criterion_name - ) diff --git a/fairseq/models/speech_to_text/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py similarity index 83% rename from fairseq/models/speech_to_text/convtransformer_simul_trans.py rename to examples/simultaneous_translation/models/convtransformer_simul_trans.py index 7e77330a0c..84ba4d0d3f 100644 --- a/fairseq/models/speech_to_text/convtransformer_simul_trans.py +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -11,11 +11,19 @@ register_model_architecture, ) -from .convtransformer import ConvTransformerModel, convtransformer_espnet +from fairseq.models.speech_to_text import ConvTransformerModel, convtransformer_espnet @register_model("convtransformer_simul_trans") class SimulConvTransformerModel(ConvTransformerModel): + """ + Implementation of the paper: + + SimulMT to SimulST: Adapting Simultaneous Text Translation to + End-to-End Simultaneous Speech Translation + + https://www.aclweb.org/anthology/2020.aacl-main.58.pdf + """ @staticmethod def add_args(parser): super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser) diff --git a/fairseq/models/speech_to_text/__init__.py 
b/fairseq/models/speech_to_text/__init__.py index 28e3bb720f..c6ae9b17ba 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -5,5 +5,4 @@ from .berard import * # noqa from .convtransformer import * # noqa -from .convtransformer_simul_trans import * # noqa from .s2t_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index 622b5e6df8..a4cbbcdeeb 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -7,9 +7,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from examples.simultaneous_translation.utils.data_utils import ( - lengths_to_encoder_padding_mask, -) +from fairseq.data.data_utils import lengths_to_padding_mask from fairseq import checkpoint_utils, utils from fairseq.models import ( FairseqEncoder, @@ -311,9 +309,7 @@ def forward(self, src_tokens, src_lengths): x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long() ) - encoder_padding_mask, _ = lengths_to_encoder_padding_mask( - input_lengths, batch_first=True - ) + encoder_padding_mask = lengths_to_padding_mask(input_lengths) positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) x += positions From 4f881a760e1cd7e11ecce2332b6ee9a435f233a5 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Fri, 26 Feb 2021 20:59:22 -0800 Subject: [PATCH 491/707] TokenBlockDataset np type promotion issue (#1658) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1658 Reviewed By: jxmsML Differential Revision: D26701840 Pulled By: sshleifer fbshipit-source-id: 90d631c3cd775ab847366fe7a05136c29d90cd63 --- fairseq/data/indexed_dataset.py | 10 ++++++---- fairseq/data/token_block_dataset.py | 16 ++++++++++------ fairseq/models/transformer_lm.py | 12 ++++++++++++ tests/test_token_block_dataset.py | 13 +++++++++++++ 4 files changed, 
41 insertions(+), 10 deletions(-) diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 066f4dcd4f..802e37a7ff 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -18,9 +18,9 @@ from typing import Union -def best_fitting_uint_dtype( +def best_fitting_int_dtype( max_int_to_represent, -) -> Union[np.uint16, np.uint32, np.uint64]: +) -> Union[np.uint16, np.uint32, np.int64]: if max_int_to_represent is None: return np.uint32 # Safe guess @@ -29,7 +29,9 @@ def best_fitting_uint_dtype( elif max_int_to_represent < 4294967295: return np.uint32 else: - return np.uint64 + return np.int64 + # we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly + # https://github.com/numpy/numpy/issues/5745 def get_available_dataset_impl(): @@ -57,7 +59,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): if impl == "mmap": return MMapIndexedDatasetBuilder( - out_file, dtype=best_fitting_uint_dtype(vocab_size) + out_file, dtype=best_fitting_int_dtype(vocab_size) ) elif impl == "fasta": raise NotImplementedError diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index 038f1c81d7..4617466234 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -6,7 +6,8 @@ import numpy as np import torch from fairseq.data import FairseqDataset, plasma_utils -from fairseq.data.indexed_dataset import best_fitting_uint_dtype +from fairseq.data.indexed_dataset import best_fitting_int_dtype + class TokenBlockDataset(FairseqDataset): """Break a Dataset of tokens into blocks. 
@@ -95,15 +96,18 @@ def __init__( ) else: block_to_dataset_index = _get_block_to_dataset_index_fast( - sizes, - slice_indices, + sizes, slice_indices, ) size_dtype = np.uint16 if block_size < 65535 else np.uint32 - slice_indices_dtype = best_fitting_uint_dtype(slice_indices[-1].max()) + slice_indices_dtype = best_fitting_int_dtype(slice_indices[-1].max()) - self._slice_indices = plasma_utils.PlasmaArray(slice_indices.astype(slice_indices_dtype)) + self._slice_indices = plasma_utils.PlasmaArray( + slice_indices.astype(slice_indices_dtype) + ) self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype)) - self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index.astype(slice_indices_dtype)) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index.astype(slice_indices_dtype) + ) @property def slice_indices(self): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index edf62b12b3..f12470d033 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -394,6 +394,18 @@ def transformer_lm_gpt2_small(args): base_lm_architecture(args) +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny") +def transformer_lm_gpt2_tiny(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 64) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 64) + args.decoder_layers = getattr(args, "decoder_layers", 2) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 1) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + @register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium") def transformer_lm_gpt2_medium(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280) diff --git a/tests/test_token_block_dataset.py 
b/tests/test_token_block_dataset.py index ea315b4e67..c4d7b76dcd 100644 --- a/tests/test_token_block_dataset.py +++ b/tests/test_token_block_dataset.py @@ -74,6 +74,19 @@ def test_complete_break_mode(self): self.assertEqual(ds[1].tolist(), [5, 1, 1]) self.assertEqual(ds[2].tolist(), [6, 1]) + def test_4billion_tokens(self): + """Regression test for numpy type promotion issue https://github.com/numpy/numpy/issues/5745""" + data = [torch.tensor(list(range(10000)), dtype=torch.long)] * 430000 + ds = self._build_dataset( + data, block_size=6, pad=0, eos=1, break_mode="complete" + ) + ds[-1] # __getitem__ works + start, end = ds.slice_indices[-1] + assert end > 4294967295 # data must be sufficiently large to overflow uint32 + assert not isinstance( + end + 1, float + ) # this would also raise, since np.uint64(1) + 1 => 2.0 + if __name__ == "__main__": unittest.main() From 5354aa3a6ec80092cc7bb9aecfad7077bb50b47e Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@fb.com> Date: Sun, 28 Feb 2021 12:44:23 -0800 Subject: [PATCH 492/707] github CI install pyarrow Reviewed By: myleott Differential Revision: D26643358 fbshipit-source-id: 8d7e1082c6e11f9bbab4b34de078cf05197297a5 --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 29e5254d33..0af8bad95d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,7 +39,7 @@ jobs: - name: Install optional test requirements run: | - python -m pip install fairscale iopath transformers + python -m pip install fairscale iopath transformers pyarrow - name: Lint with flake8 run: | From e5e8b3fee1e57a7abf35ad1a3ff223a2b7190c65 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Sun, 28 Feb 2021 12:49:20 -0800 Subject: [PATCH 493/707] Fix nearly all unit-test warnings (#1652) Summary: 2 types of warnings fixed: ``` `np.long` is a deprecated alias for `np.compat.long`. 
Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1652 Reviewed By: myleott Differential Revision: D26643344 Pulled By: sshleifer fbshipit-source-id: 960bccc94f299bd8a8c58a87acd80694e9d5c363 --- fairseq/data/language_pair_dataset.py | 12 +++--------- fairseq/data/token_block_dataset.py | 2 +- fairseq/optim/lr_scheduler/cosine_lr_scheduler.py | 2 +- .../lr_scheduler/inverse_square_root_schedule.py | 2 +- 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 8858cec84e..9d36cbd4ce 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -114,10 +114,7 @@ def compute_alignment_weights(alignments): "id": id, "nsentences": len(samples), "ntokens": ntokens, - "net_input": { - "src_tokens": src_tokens, - "src_lengths": src_lengths, - }, + "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,}, "target": target, } if prev_output_tokens is not None: @@ -289,7 +286,7 @@ def __init__( # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to BucketPadLengthDataset) - num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) @@ -470,8 +467,5 @@ def filter_indices_by_size(self, indices, max_sizes): list: list of removed indices """ return data_utils.filter_paired_dataset_indices_by_size( - self.src_sizes, - self.tgt_sizes, - indices, - max_sizes, + self.src_sizes, self.tgt_sizes, indices, max_sizes, ) diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index 
4617466234..ce0a0d1114 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -88,7 +88,7 @@ def __init__( [ np.arange(len(sizes)), # starting index in dataset np.zeros( - len(sizes), dtype=np.long + len(sizes), dtype=np.compat.long ), # starting offset within starting index np.arange(len(sizes)), # ending index in dataset ], diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 38b57fe54c..51f58359ed 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import math -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index d9321577bb..0f87bb5d7e 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List From 39e55139ea05da36e9ab9837c4943f660b79dcbe Mon Sep 17 00:00:00 2001 From: Hiromu Yakura <hiromu1996@gmail.com> Date: Mon, 1 Mar 2021 12:36:18 -0800 Subject: [PATCH 494/707] Fix the order of constraints in LanguagePairDataset (#3280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? 
- [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3279. This change modifies the output of `echo -e "Ja, wer hat, wenn du willst, Götter gebildet, uns zu ihnen erhoben, sie zu uns herniedergebracht, als der Dichter?\tbard\nZu vollenden ist nicht die Sache des Schülers, es ist genug, wenn er sich übt\tstudent" | python normalize.py | python tok.py | fairseq-interactive --constraints -s de -t en --beam 10 --batch-size 2 --buffer-size 2 --bpe fastbpe --bpe-codes ../../../models/ende30k.fastbpe.code --path ../../../models/wmt19.de-en.ffn8192.pt ../../../models/` as follows. Before: ``` S-0 Ja , wer hat , wenn du will@@ st , Gö@@ tter gebildet , uns zu ihnen erhoben , sie zu uns her@@ nieder@@ gebracht , als der Dich@@ ter ? W-0 1.755 seconds C-0 student H-0 -1.1425577402114868 Yes , who , if you will , has formed go@@ ds , raised us up to them , brought them down to us , but the po@@ et student ? D-0 -1.1425577402114868 Yes , who , if you will , has formed gods , raised us up to them , brought them down to us , but the poet student ? 
P-0 -1.8768 -0.2214 -0.4671 -1.2521 -0.2101 -0.3053 -1.2077 -0.1496 -1.8780 -1.4195 -0.4071 -0.1347 -0.3726 -1.1306 -0.1665 -1.4588 -0.2837 -0.1722 -0.2330 -0.2840 -0.1806 -0.1432 -0.2263 -0.1395 -0.7261 -1.4593 -0.3639 -0.4030 -0.1083 -18.7577 -0.2396 -0.1837 S-1 Zu voll@@ enden ist nicht die Sache des Sch@@ ül@@ ers , es ist genug , wenn er sich übt W-1 1.755 seconds C-1 b@@ ard H-1 -1.9625756740570068 It is not up to the b@@ ard to complete , it is enough if he practi@@ ses D-1 -1.9625756740570068 It is not up to the bard to complete , it is enough if he practises P-1 -1.2630 -0.3364 -0.1634 -2.7070 -0.1734 -0.2815 -17.3978 -6.0238 -0.4888 -1.7563 -0.8708 -0.6773 -0.2027 -0.2456 -1.6366 -0.2911 -2.0235 -0.1961 -0.5538 ``` After: ``` S-0 Ja , wer hat , wenn du will@@ st , Gö@@ tter gebildet , uns zu ihnen erhoben , sie zu uns her@@ nieder@@ gebracht , als der Dich@@ ter ? W-0 1.740 seconds C-0 b@@ ard H-0 -1.2060465812683105 Yes , who , if you will , formed go@@ ds , raised us up to them , brought them down to us , but the b@@ ard ? D-0 -1.2060465812683105 Yes , who , if you will , formed gods , raised us up to them , brought them down to us , but the bard ? P-0 -1.8768 -0.2214 -0.4671 -1.2521 -0.2101 -0.3053 -1.2077 -0.1496 -2.2551 -0.5702 -0.1331 -0.3940 -1.0268 -0.1750 -1.4635 -0.2821 -0.1725 -0.2404 -0.3575 -0.1833 -0.1441 -0.2250 -0.1419 -0.7020 -1.5215 -0.3700 -16.8578 -2.7290 -0.3405 -0.2060 S-1 Zu voll@@ enden ist nicht die Sache des Sch@@ ül@@ ers , es ist genug , wenn er sich übt W-1 1.740 seconds C-1 student H-1 -0.8064212203025818 It is not up to the student to complete , it is enough if he practi@@ ses D-1 -0.8064212203025818 It is not up to the student to complete , it is enough if he practises P-1 -1.2630 -0.3364 -0.1634 -2.7070 -0.1734 -0.2815 -1.5556 -0.2831 -1.3885 -0.7310 -0.6367 -0.1824 -0.2386 -1.5320 -0.2728 -2.0003 -0.2163 -0.5536 ``` ## PR review Anyone in the community is free to review the PR once the tests have passed. 
If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3280 Reviewed By: myleott Differential Revision: D26725013 Pulled By: lematt1991 fbshipit-source-id: 2275fcf146cb8cd9ca21f847e10a4dacdee996f9 --- fairseq/data/language_pair_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 9d36cbd4ce..ff3e14bf14 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -157,7 +157,7 @@ def compute_alignment_weights(alignments): constraints = torch.zeros((len(samples), max(lens))).long() for i, sample in enumerate(samples): constraints[i, 0 : lens[i]] = samples[i].get("constraints") - batch["constraints"] = constraints + batch["constraints"] = constraints.index_select(0, sort_order) return batch From 1c0439b7dabe62d39c6e7f1c8ebc86311e042b5a Mon Sep 17 00:00:00 2001 From: freewym <freewym@gmail.com> Date: Mon, 1 Mar 2021 16:21:05 -0800 Subject: [PATCH 495/707] fixes circular imports incurred by a recent commit (#3286) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes circular imports incurred by a recent commit ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3286 Reviewed By: lematt1991 Differential Revision: D26725255 Pulled By: myleott fbshipit-source-id: 5572f733b83bdfadcce3188c0789fc6d70a3bad3 --- fairseq/models/fairseq_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 244cbc0c66..186f3d2464 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -14,7 +14,6 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils -from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, @@ -111,6 +110,9 @@ def load_state_dict( model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) @@ -450,6 +452,9 @@ def load_state_dict( model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) From 3100d0b8e5bb5e61b4d73b9c058389aa2c06784a Mon Sep 17 00:00:00 2001 From: Eric Lou <ericlou@fb.com> Date: Tue, 2 Mar 2021 09:24:03 -0800 Subject: [PATCH 496/707] ioPath async - opt-in Fairseq integration (#1635) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1635 **Summary:** Integrate ioPath's async writes feature into Fairseq checkpoint writing. **Details:** - Created new checkpoint config param `--write-checkpoints-asynchronously` with default value `False`. Aliased to `--save-async`.
- Added to `PathManager` class in `file_io.py` to include `PathManager.opena(...)` and `PathManager.async_close()`. These new methods use ioPath's async `PathManager`. **Usage:** ``` python train.py --save-async ``` --------- NOTE: **QUESTIONS** 1. In the current implementation, we don't save `checkpoint_best` and `checkpoint_latest` since ioPath doesn't yet have a "wait until a file is written and then copy/move it to another path" feature. Is this okay for now? 2. Should I mimic the atomic vs non-atomic save structure that synchronous Fairseq checkpoint writes have? **Note to Eric:** Keep this integration in check with D26375501. Reviewed By: myleott Differential Revision: D26467815 fbshipit-source-id: 50068ef7bf9a6d5cea4d5e0d13d672604dc4a6b0 --- fairseq/checkpoint_utils.py | 42 ++++++++++++++++++++++---------- fairseq/dataclass/configs.py | 10 ++++++++ fairseq/file_io.py | 46 ++++++++++++++++++++++++++++++++++++ fairseq_cli/train.py | 20 ++++++++++++++++ 4 files changed, 105 insertions(+), 13 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 55a546356e..d6618fbb62 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -93,9 +93,17 @@ def is_better(a, b): if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) for cp in checkpoints[1:]: - assert PathManager.copy( - checkpoints[0], cp, overwrite=True - ), f"Failed to copy {checkpoints[0]} to {cp}" + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." 
+ ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" write_timer.stop() logger.info( @@ -383,7 +391,23 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def torch_persistent_save(obj, f): +def torch_persistent_save(cfg: CheckpointConfig, obj, filename): + if cfg.write_checkpoints_asynchronously: + with PathManager.opena(filename, "wb") as f: + _torch_persistent_save(obj, f) + else: + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + _torch_persistent_save(obj, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + _torch_persistent_save(obj, f) + + +def _torch_persistent_save(obj, f): if isinstance(f, str): with PathManager.open(f, "wb") as h: torch_persistent_save(obj, h) @@ -448,15 +472,7 @@ def save_state( # keep everything on CPU state_dict = utils.move_to_cpu(state_dict) - if PathManager.supports_rename(filename): - # do atomic save - with PathManager.open(filename + ".tmp", "wb") as f: - torch_persistent_save(state_dict, f) - PathManager.rename(filename + ".tmp", filename) - else: - # fallback to non-atomic save - with PathManager.open(filename, "wb") as f: - torch_persistent_save(state_dict, f) + torch_persistent_save(cfg.checkpoint, state_dict, filename) def _upgrade_state_dict(state): diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index f66e98fe83..39355b1caf 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -607,6 +607,16 @@ class CheckpointConfig(FairseqDataclass): "(default: only load on rank 0 and broadcast to other devices)" }, ) + write_checkpoints_asynchronously: bool = field( + default=False, + metadata={ + "help": ( + "Write checkpoints asynchronously in a separate " + "thread. 
NOTE: This feature is currently being tested." + ), + "argparse_alias": "--save-async", + }, + ) model_parallel_size: int = II("common.model_parallel_size") distributed_rank: int = II("distributed_training.distributed_rank") diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 7d6c28dccd..731fef3570 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -32,6 +32,8 @@ except ImportError: FVCorePathManager = None +IOPathPathManager = None + class PathManager: """ @@ -148,3 +150,47 @@ def supports_rename(path: str) -> bool: @staticmethod def rename(src: str, dst: str): os.rename(src, dst) + + """ + ioPath async PathManager methods: + """ + @staticmethod + def opena( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + """ + Return file descriptor with asynchronous write operations. + """ + global IOPathPathManager + if not IOPathPathManager: + logging.info("ioPath is initializing PathManager.") + try: + from iopath import PathManager + IOPathPathManager = PathManager() + except Exception: + logging.exception("Failed to initialize ioPath PathManager object.") + return IOPathPathManager.opena( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def async_close() -> bool: + """ + Wait for files to be written and clean up asynchronous PathManager. + NOTE: `PathManager.async_close()` must be called at the end of any + script that uses `PathManager.opena(...)`. 
+ """ + global IOPathPathManager + if IOPathPathManager: + return IOPathPathManager.async_close() + return False diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index ec4890b9e6..80ad57acd1 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -28,6 +28,7 @@ from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.distributed_utils import is_master +from fairseq.file_io import PathManager from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.trainer import Trainer @@ -67,6 +68,16 @@ def main(cfg: FairseqConfig) -> None: # Print args logger.info(cfg) + if cfg.checkpoint.write_checkpoints_asynchronously: + try: + import iopath # noqa: F401 + except ImportError: + logging.exception( + "Asynchronous checkpoint writing is specified but iopath is " + "not installed: `pip install iopath`" + ) + return + # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) @@ -157,6 +168,15 @@ def main(cfg: FairseqConfig) -> None: train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) + # ioPath implementation to wait for all asynchronous file writes to complete. + if cfg.checkpoint.write_checkpoints_asynchronously: + logger.info( + "ioPath PathManager waiting for all asynchronous checkpoint " + "writes to finish." 
+ ) + PathManager.async_close() + logger.info("ioPath PathManager finished waiting.") + def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch From 12e21b9a6e7262fa1af2090e22c301bc0b5d1399 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Tue, 2 Mar 2021 13:28:53 -0800 Subject: [PATCH 497/707] Add global cmvn for mustc data preparation (#1660) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1660 Reviewed By: jmp84, kahne Differential Revision: D26708521 Pulled By: xutaima fbshipit-source-id: c53e9052298c559706ceffeb359dadfede2f1a09 --- examples/speech_to_text/data_utils.py | 32 +++++++++++++++++-- examples/speech_to_text/prep_mustc_data.py | 37 ++++++++++++++++++++-- 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 0d7c034419..fa0d459611 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -126,7 +126,9 @@ def gen_config_yaml( specaugment_policy: str = "lb", prepend_tgt_lang_tag: bool = False, sampling_alpha: float = 1.0, - audio_root: str = "" + audio_root: str = "", + cmvn_type: str = "utterance", + gcmvn_path: Optional[Path] = None, ): manifest_root = manifest_root.absolute() writer = S2TDataConfigWriter(manifest_root / yaml_filename) @@ -151,8 +153,19 @@ def gen_config_yaml( if prepend_tgt_lang_tag: writer.set_prepend_tgt_lang_tag(True) writer.set_sampling_alpha(sampling_alpha) - writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"]) - writer.set_feature_transforms("*", ["utterance_cmvn"]) + + if cmvn_type not in ["global", "utterance"]: + raise NotImplementedError + + writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"]) + writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"]) + + if cmvn_type == "global": + assert gcmvn_path is not None, ( + 'Please provide path of global cmvn file.'
+ ) + writer.set_global_cmvn(gcmvn_path) + if len(audio_root) > 0: writer.set_audio_root(audio_root) writer.flush() @@ -206,6 +219,16 @@ def filter_manifest_df( return df[valid] +def cal_gcmvn_stats(features_list): + features = np.concatenate(features_list) + square_sums = (features ** 2).sum(axis=0) + mean = features.mean(axis=0) + features = np.subtract(features, mean) + var = square_sums / features.shape[0] - mean ** 2 + std = np.sqrt(np.maximum(var, 1e-8)) + return {"mean": mean.astype("float32"), "std": std.astype("float32")} + + class S2TDataConfigWriter(object): DEFAULT_VOCAB_FILENAME = "dict.txt" DEFAULT_INPUT_FEAT_PER_CHANNEL = 80 @@ -297,6 +320,9 @@ def set_input_feat_per_channel(self, input_feat_per_channel: int = 80): def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): self.config["bpe_tokenizer"] = bpe_tokenizer + def set_global_cmvn(self, stats_npz_path: str): + self.config["stats_npz_path"] = stats_npz_path + def set_feature_transforms(self, split: str, transforms: List[str]): if "transforms" not in self.config: self.config["transforms"] = {} diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 520968401c..4e410bcb18 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -13,6 +13,7 @@ from tempfile import NamedTemporaryFile from typing import Tuple +import numpy as np import pandas as pd import torchaudio from examples.speech_to_text.data_utils import ( @@ -24,6 +25,7 @@ get_zip_manifest, load_df_from_tsv, save_df_to_tsv, + cal_gcmvn_stats, ) from torch import Tensor from torch.utils.data import Dataset @@ -111,10 +113,28 @@ def process(args): print(f"Fetching split {split}...") dataset = MUSTC(root.as_posix(), lang, split) print("Extracting log mel filter bank features...") + if split == 'train' and args.cmvn_type == "global": + print("And estimating cepstral mean and variance stats...") + gcmvn_feature_list = [] + for waveform, 
sample_rate, _, _, _, utt_id in tqdm(dataset): - extract_fbank_features( - waveform, sample_rate, feature_root / f"{utt_id}.npy" + features = extract_fbank_features(waveform, sample_rate) + + np.save( + (feature_root / f"{utt_id}.npy").as_posix(), + features ) + + if split == 'train' and args.cmvn_type == "global": + if len(gcmvn_feature_list) < args.gcmvn_max_num: + gcmvn_feature_list.append(features) + + if split == 'train' and args.cmvn_type == "global": + # Estimate and save cmv + stats = cal_gcmvn_stats(gcmvn_feature_list) + with open(cur_root / "gcmvn.npz", "wb") as f: + np.savez(f, mean=stats["mean"], std=stats["std"]) + # Pack features into ZIP zip_path = cur_root / "fbank80.zip" print("ZIPing features...") @@ -158,6 +178,11 @@ def process(args): spm_filename_prefix + ".model", yaml_filename=f"config_{args.task}.yaml", specaugment_policy="lb", + cmvn_type=args.cmvn_type, + gcmvn_cmvn_path=( + cur_root / "gcmvn.npz" if args.cmvn_type == "global" + else None + ), ) # Clean up shutil.rmtree(feature_root) @@ -216,6 +241,14 @@ def main(): parser.add_argument("--vocab-size", default=8000, type=int) parser.add_argument("--task", type=str, choices=["asr", "st"]) parser.add_argument("--joint", action="store_true", help="") + parser.add_argument("--cmvn-type", default="utterance", + choices=["global", "utterance"], + help="The type of cepstral mean and variance normalization") + parser.add_argument("--gcmvn-max-num", default=150000, type=int, + help=( + "Maximum number of sentences to use to estimate" + "global mean and variance" + )) args = parser.parse_args() if args.joint: From c58af189957eb15b47e507473b4da3e83dfbdf2e Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Tue, 2 Mar 2021 17:08:45 -0800 Subject: [PATCH 498/707] Several update on simultaneous translation inference. (#1655) Summary: Fix some issues in some corner cases. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1655 Reviewed By: jmp84 Differential Revision: D26651362 Pulled By: sravyapopuri388 fbshipit-source-id: 160d75be8d49f8263c14af225c90fe7997171a43 --- .../models/convtransformer_simul_trans.py | 2 +- .../models/transformer_monotonic_attention.py | 90 ++++--------------- .../modules/fixed_pre_decision.py | 38 ++++++-- .../modules/monotonic_multihead_attention.py | 7 +- .../agents/fairseq_simul_st_agent.py | 32 ++++--- 5 files changed, 76 insertions(+), 93 deletions(-) diff --git a/examples/simultaneous_translation/models/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py index 84ba4d0d3f..760a48168d 100644 --- a/examples/simultaneous_translation/models/convtransformer_simul_trans.py +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -10,7 +10,6 @@ register_model, register_model_architecture, ) - from fairseq.models.speech_to_text import ConvTransformerModel, convtransformer_espnet @@ -24,6 +23,7 @@ class SimulConvTransformerModel(ConvTransformerModel): https://www.aclweb.org/anthology/2020.aacl-main.58.pdf """ + @staticmethod def add_args(parser): super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser) diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index dd3895f0eb..65c12c6f5b 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -65,60 +65,6 @@ def _indices_from_states(self, states): return src_indices, None, tgt_indices - def predict_from_states(self, states): - decoder_states = self.decoder.output_layer(states["decoder_features"]) - lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True) - - index = lprobs.argmax(dim=-1) - - token = 
self.decoder.dictionary.string(index) - - return token, index[0, 0].item() - - def decision_from_states(self, states): - """ - This funcion take states dictionary as input, and gives the agent - a decision of whether read a token from server. Moreover, the decoder - states are also calculated here so we can directly generate a target - token without recompute every thing - """ - - self.eval() - - if len(states["tokens"]["src"]) == 0: - return 0 - - src_indices, src_lengths, tgt_indices = self._indices_from_states(states) - - # Update encoder states if needed - if ( - "encoder_states" not in states - or states["encoder_states"][0].size(1) <= states["steps"]["src"] - ): - encoder_out_dict = self.encoder(src_indices, src_lengths) - states["encoder_states"] = encoder_out_dict - else: - encoder_out_dict = states["encoder_states"] - - # online means we still need tokens to feed the model - states["model_states"]["online"] = not ( - states["finish_read"] - and len(states["tokens"]["src"]) == states["steps"]["src"] - ) - - states["model_states"]["steps"] = states["steps"] - - x, outputs = self.decoder.forward( - prev_output_tokens=tgt_indices, - encoder_out=encoder_out_dict, - incremental_state=states["model_states"], - features_only=True, - ) - - states["decoder_features"] = x - - return outputs["action"] - class TransformerMonotonicEncoder(TransformerEncoder): def __init__(self, args, dictionary, embed_tokens): @@ -208,6 +154,18 @@ def post_attention(self, x): return x + def clear_cache(self, incremental_state, end_id=None): + """ + Clear cache in the monotonic layers. + The cache is generated because of a forward pass of decode but no prediction. 
+ end_id is the last idx of the layers + """ + if end_id is None: + end_id = len(self.layers) + + for j in range(end_id): + self.layers[j].prune_incremental_state(incremental_state) + def extract_features( self, prev_output_tokens, encoder_out, incremental_state=None, **unused ): @@ -247,9 +205,13 @@ def extract_features( curr_steps = layer.get_head_steps(incremental_state) step_list.append(curr_steps) - if incremental_state.get("online", False): + if incremental_state.get("online", True): + # Online indicates that the encoder states are still changing p_choose = ( - attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t()) + attn["p_choose"] + .squeeze(0) + .squeeze(1) + .gather(1, curr_steps.t()) ) new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps) @@ -258,24 +220,10 @@ def extract_features( # We need to prune the last self_attn saved_state # if model decide not to read # otherwise there will be duplicated saved_state - for j in range(i + 1): - self.layers[j].prune_incremental_state(incremental_state) + self.clear_cache(incremental_state, i + 1) return x, {"action": 0} - if incremental_state is not None and not incremental_state.get("online", False): - # Here is for fast evaluation - fastest_step = ( - torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1 - ) - - if "fastest_step" in incremental_state: - incremental_state["fastest_step"] = torch.cat( - [incremental_state["fastest_step"], fastest_step], dim=1 - ) - else: - incremental_state["fastest_step"] = fastest_step - x = self.post_attention(x) return x, { diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py index 725be1a983..cc5e7ad532 100644 --- a/examples/simultaneous_translation/modules/fixed_pre_decision.py +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -1,6 +1,7 @@ from functools import partial import torch +import math import torch.nn.functional as F 
from . import register_monotonic_attention @@ -96,6 +97,9 @@ def p_choose( incremental_state=None, **extra_args ): + src_len = key.size(0) + tgt_len = query.size(0) + batch_size = query.size(1) if self.pre_decision_ratio == 1: return super().p_choose( @@ -119,6 +123,16 @@ def p_choose( else: key_padding_mask_pool = None + if incremental_state is not None: + # The floor instead of ceil is used for inference + # But make sure the length key_pool at least 1 + if ( + max(1, math.floor(key.size(0) / self.pre_decision_ratio)) + ) < key_pool.size(0): + key_pool = key_pool[:-1] + if key_padding_mask_pool is not None: + key_padding_mask_pool = key_padding_mask_pool[:-1] + p_choose_pooled = super().p_choose( query, key_pool, @@ -129,13 +143,23 @@ def p_choose( # Upsample, interpolate zeros p_choose = self.insert_zeros(p_choose_pooled) - # can be larger than src_len because we used ceil before - src_len = key.size(0) - p_choose = p_choose[:, :, :src_len] - p_choose[:, :, -1] = p_choose_pooled[:, :, -1] - - tgt_len = query.size(0) - batch_size = query.size(1) + if p_choose.size(-1) < src_len: + # Append zeros if the upsampled p_choose is shorter than src_len + p_choose = torch.cat( + [ + p_choose, + p_choose.new_zeros( + p_choose.size(0), + tgt_len, + src_len - p_choose.size(-1) + ) + ], + dim=2 + ) + else: + # can be larger than src_len because we used ceil before + p_choose = p_choose[:, :, :src_len] + p_choose[:, :, -1] = p_choose_pooled[:, :, -1] assert list(p_choose.size()) == [ batch_size * self.num_heads, diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 3e25957cd6..49882afcd8 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -604,13 +604,14 @@ def p_choose( key_padding_mask: bsz, src_len """ if incremental_state is not None: + # 
Retrieve target length from incremental states + # For inference the length of query is always 1 tgt_len = int(incremental_state["steps"]["tgt"]) - src_len = int(incremental_state["steps"]["src"]) - bsz = 1 else: - src_len, bsz, _ = key.size() tgt_len, bsz, _ = query.size() + src_len, bsz, _ = key.size() + p_choose = query.new_ones(bsz, tgt_len, src_len) p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 5793609095..f944203785 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -127,6 +127,15 @@ def __init__(self, args): self.load_model_vocab(args) + if getattr( + self.model.decoder.layers[0].encoder_attn, + 'pre_decision_ratio', + None + ) is not None: + self.speech_segment_size *= ( + self.model.decoder.layers[0].encoder_attn.pre_decision_ratio + ) + with open(args.config, "r") as f: config = yaml.load(f, Loader=yaml.BaseLoader) @@ -167,15 +176,15 @@ def add_args(parser): parser.add_argument("--max-len", type=int, default=200, help="Max length of translation") parser.add_argument("--force-finish", default=False, action="store_true", - help="") + help="Force the model to finish the hypothesis if the source is not finished") parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, - help="") + help="Shift size of feature extraction window.") parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, - help="") + help="Window size of feature extraction window.") parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, - help="") + help="Sample rate") parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, - help="") + help="Acoustic
feature dimension.") # fmt: on return parser @@ -265,11 +274,12 @@ def units_to_segment(self, units, states): def update_model_encoder(self, states): if len(states.units.source) == 0: return - src_indices = self.to_device(states.units.source.value.unsqueeze(0)) + src_indices = self.to_device( + states.units.source.value.unsqueeze(0) + ) src_lengths = self.to_device( torch.LongTensor([states.units.source.value.size(0)]) ) - print(src_lengths) states.encoder_states = self.model.encoder(src_indices, src_lengths) torch.cuda.empty_cache() @@ -294,13 +304,12 @@ def policy(self, states): "tgt": 1 + len(states.units.target), } - states.incremental_states["online"] = True + states.incremental_states["online"] = not states.finish_read() x, outputs = self.model.decoder.forward( prev_output_tokens=tgt_indices, encoder_out=states.encoder_states, incremental_state=states.incremental_states, - # features_only=True, ) states.decoder_out = x @@ -323,8 +332,6 @@ def predict(self, states): index = lprobs.argmax(dim=-1) - torch.cuda.empty_cache() - index = index[0, 0].item() if ( @@ -332,6 +339,9 @@ def predict(self, states): and index == self.model.decoder.dictionary.eos() and not states.finish_read() ): + # If we want to force finish the translation + # (don't stop before finish reading), return a None + # self.model.decoder.clear_cache(states.incremental_states) index = None return index From ddc483ff3d3a70f3abc33fc4d10bb29871c73d73 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Tue, 2 Mar 2021 17:08:45 -0800 Subject: [PATCH 499/707] Streaming models for simul ST (#1552) Summary: `fairseq/models/speech_to_text/modules/emformer.py` mostly contains the code from Yangyang. I did a little modification to make it run on fairseq. 
`fairseq/models/speech_to_text/modules/augmented_memory_attention.py` contains code for the old streaming models. `fairseq/models/speech_to_text/modules/convtransformer_simul_trans.py` contains three convtransformer-based simultaneous translation models. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1552 Reviewed By: jmp84 Differential Revision: D26563864 Pulled By: sravyapopuri388 fbshipit-source-id: a91a6247559861977cbc22db00ba9511f6b21c69 --- .../modules/augmented_memory_attention.py | 486 +++++ .../models/speech_to_text/modules/emformer.py | 1838 +++++++++++++++++ fairseq/models/speech_to_text/utils.py | 564 +++++ 3 files changed, 2888 insertions(+) create mode 100644 fairseq/models/speech_to_text/modules/augmented_memory_attention.py create mode 100644 fairseq/models/speech_to_text/modules/emformer.py create mode 100644 fairseq/models/speech_to_text/utils.py diff --git a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py new file mode 100644 index 0000000000..5d31524b76 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py @@ -0,0 +1,486 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +from typing import Tuple, List + +import torch +import torch.nn.functional as F +from fairseq.models import FairseqEncoder +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.speech_to_text import ( + ConvTransformerEncoder, +) +from fairseq.models.speech_to_text.utils import attention_suppression +from fairseq.models.speech_to_text.utils import ( + lengths_to_encoder_padding_mask, + segments_to_sequence, + sequence_to_segments, +) +from fairseq.modules import MultiheadAttention, TransformerEncoderLayer +from torch import nn, Tensor + +# ------------------------------------------------------------------------------ +# AugmentedMemoryConvTransformerEncoder +# ------------------------------------------------------------------------------ + + +class AugmentedMemoryConvTransformerEncoder(ConvTransformerEncoder): + def __init__(self, args): + super().__init__(args) + + args.encoder_stride = self.stride() + + self.left_context = args.left_context // args.encoder_stride + + self.right_context = args.right_context // args.encoder_stride + + self.left_context_after_stride = args.left_context // args.encoder_stride + self.right_context_after_stride = args.right_context // args.encoder_stride + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [ + AugmentedMemoryTransformerEncoderLayer(args) + for i in range(args.encoder_layers) + ] + ) + + def stride(self): + # Hard coded here. Should infer from convs in future + stride = 4 + return stride + + def forward(self, src_tokens, src_lengths, states=None): + """Encode input sequence. 
+ :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = 1.0 * max_seq_len / output_seq_len + input_lengths = (src_lengths.float() / subsampling_factor).round().long() + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + + # TODO: fix positional embedding + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # State to store memory banks etc. + if states is None: + states = [ + {"memory_banks": None, "encoder_states": None} + for i in range(len(self.transformer_layers)) + ] + + for i, layer in enumerate(self.transformer_layers): + # x size: + # (self.left_size + self.segment_size + self.right_size) + # / self.stride, num_heads, dim + # TODO: Consider mask here + x = layer(x, states[i]) + states[i]["encoder_states"] = x[ + self.left_context_after_stride : -self.right_context_after_stride + ] + + lengths = ( + ( + ~encoder_padding_mask[ + :, self.left_context_after_stride : -self.right_context_after_stride + ] + ) + .sum(dim=1, keepdim=True) + .long() + ) + + return states[-1]["encoder_states"], lengths, states + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryTransformerEncoderLayer +# ------------------------------------------------------------------------------ +class AugmentedMemoryTransformerEncoderLayer(TransformerEncoderLayer): + def __init__(self, args): + super().__init__(args) + 
+ + self.left_context = args.left_context // args.encoder_stride + self.right_context = args.right_context // args.encoder_stride + + def forward(self, x, state): + + length, batch_size, x_dim = x.size() + + residual = x + + if self.normalize_before: + x = self.self_attn_layer_norm(x) + + # init_state + if state.get("memory_banks", None) is None: + state["memory_banks"] = [] + + # TODO research new sum_query method + seg_start = self.left_context + seg_end = length - self.right_context + if seg_start < seg_end: + summarization_query = torch.mean(x[seg_start:seg_end], keepdim=True, dim=0) + else: + summarization_query = x.new_zeros(1, batch_size, x_dim) + + x = torch.cat([x, summarization_query], dim=0) + + x = self.self_attn(input_and_summary=x, state=state) + + x = self.dropout_module(x) + x = residual + x + + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + if not self.normalize_before: + x = self.final_layer_norm(x) + + return x + + def build_self_attention(self, embed_dim, args): + return AugmentedMemoryMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + tanh_on_mem=True, + max_memory_size=args.max_memory_size, + ) + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryMultiheadAttention +# ------------------------------------------------------------------------------ +class AugmentedMemoryMultiheadAttention(MultiheadAttention): + """ + Augmented Memory Attention from + Streaming Transformer-based Acoustic Models + Using Self-attention with Augmented Memory + https://arxiv.org/abs/2005.08042 + """ + + def __init__( +
self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + tanh_on_mem=False, + memory_dim=None, + std_scale=0.5, # 0.5 based on https://arxiv.org/abs/2005.09137 + max_memory_size=-1, + disable_mem_on_mem_attn=True, + ): + super().__init__( + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn, + self_attention, + encoder_decoder_attention, + q_noise, + qn_block_size, + ) + + self.memory_dim = memory_dim if memory_dim is not None else embed_dim + self.std_scale = std_scale + self.disable_mem_on_mem_attn = disable_mem_on_mem_attn + + # This Operator was used for factorization in PySpeech + self.v2e = lambda x: x + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = lambda x: x + self.nonlinear_squash_mem = False + + self.max_memory_size = max_memory_size + + def forward(self, input_and_summary, state): + """ + input: Encoder states of current segment with left or right context, + plus one summarization query + + """ + + length, batch_size, _ = input_and_summary.shape + length = length - 1 # not include sum_query, last index + + memory = state["memory_banks"] + # TODO: positional embedding on memory + + if self.max_memory_size > -1 and len(memory) > self.max_memory_size: + # TODO: need to fix here + if self.max_memory_size == 0: + memory = memory.new_zeros(1, memory.size(1), self.memory_dim) + else: + memory = memory[-self.max_memory_size :] + + memory_and_input = torch.cat(memory + [input_and_summary[:-1]], dim=0) + input_and_sum_query = input_and_summary + + q = self.q_proj(self.v2e(input_and_sum_query)) + k = self.k_proj(self.v2e(memory_and_input)) + v = self.v_proj(self.v2e(memory_and_input)) + + q = ( + q.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + * 
self.scaling + ) + k = ( + k.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + v.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + + if self.disable_mem_on_mem_attn: + attention_weights = self.suppress_mem_on_mem_attention( + batch_size, self.num_heads, len(memory), attention_weights + ) + + if self.std_scale is not None: + attention_weights = attention_suppression(attention_weights, self.std_scale) + + assert list(attention_weights.shape) == [ + batch_size * self.num_heads, + length + 1, + length + len(memory), + ] + + attention_weights = torch.nn.functional.softmax( + attention_weights.float(), dim=-1 + ).type_as(attention_weights) + + attention_probs = self.dropout_module(attention_weights) + + # [T, T, B, n_head] + [T, B, n_head, d_head] -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + + assert list(attention.shape) == [ + batch_size * self.num_heads, + length + 1, + self.head_dim, + ] + + attention = ( + attention.transpose(0, 1) + .contiguous() + .view(length + 1, batch_size, self.embed_dim) + ) + + output_and_memory = self.out_proj(attention) + + next_m = output_and_memory[-1:] + next_m = self.squash_mem(next_m) + output = output_and_memory[:-1] + + state["memory_banks"].append(next_m) + + return output + + def suppress_mem_on_mem_attention( + self, B: int, num_heads: int, mem_size: int, attention_weight: Tensor + ): + """ + Arguments: + - B: batch size + - num_heads: number of attention heads + - mem_size: size of memory bank + - attention_weight: a [B*num_heads, T + 1, T + mem_size] vector + + Return: + modified attention_weight with [B*num_heads, -1, :mem_size] = -inf + """ + attention_weight[:, -1, :mem_size] = float("-inf") + return attention_weight + + +# ------------------------------------------------------------------------------ +# SequenceEncoder +# 
------------------------------------------------------------------------------ +class SequenceEncoder(FairseqEncoder): + """ + SequenceEncoder encodes sequences. + + More specifically, `src_tokens` and `src_lengths` in `forward()` should + describe a batch of "complete" sequences rather than segments. + + Segment-by-segment inference can be triggered by `segment_size`: + 1) `segment_size` is None: + SequenceEncoder treats the input sequence as one single segment. + 2) `segment_size` is not None (some int instead): + SequenceEncoder does the following: + 1. breaks the input sequence into several segments + 2. runs inference on each segment and collects the outputs + 3. concatenates segment outputs into the output sequence. + Note that `segment_size` here shouldn't include additional left/right + contexts needed, for example if we wish to infer with LC-BLSTM where the + middle chunk size is 100 and right context is 20, `segment_size` should be + 100. + """ + + def __init__(self, args, module): + super().__init__(None) + + self.module = module + self.input_time_axis = 1 + self.output_time_axis = 0 + self.segment_size = args.segment_size + self.left_context = args.left_context + self.right_context = args.right_context + + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + states=None, + ): + + seg_src_tokens_lengths = sequence_to_segments( + sequence=src_tokens, + time_axis=self.input_time_axis, + lengths=src_lengths, + segment_size=self.segment_size, + extra_left_context=self.left_context, + extra_right_context=self.right_context, + ) + + seg_encoder_states_lengths: List[Tuple[Tensor, Tensor]] = [] + + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + + seg_encoder_states_lengths.append((seg_encoder_states, seg_enc_lengths)) + + encoder_out, enc_lengths = segments_to_sequence( + segments=seg_encoder_states_lengths,
time_axis=self.output_time_axis + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + enc_lengths, batch_first=True + ) + + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + return EncoderOut( + encoder_out=encoder_out, + encoder_padding_mask=encoder_padding_mask, + encoder_embedding=None, + encoder_states=states, + src_tokens=None, + src_lengths=None, + ) + + def incremental_encode( + self, + seg_src_tokens: Tensor, + seg_src_lengths: Tensor, + states=None, + ): + """ + Different from forward function, this function takes segmented speech + as input, and append encoder states to previous states + """ + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + return seg_encoder_states, seg_enc_lengths, states + + +# ------------------------------------------------------------------------------ +# Augmented memory model decorator +# ------------------------------------------------------------------------------ +def augmented_memory(klass): + class StreamSeq2SeqModel(klass): + @staticmethod + def add_args(parser): + super(StreamSeq2SeqModel, StreamSeq2SeqModel).add_args(parser) + parser.add_argument( + "--segment-size", type=int, required=True, help="Length of the segment." 
+ ) + parser.add_argument( + "--left-context", + type=int, + default=0, + help="Left context for the segment.", + ) + parser.add_argument( + "--right-context", + type=int, + default=0, + help="Right context for the segment.", + ) + parser.add_argument( + "--max-memory-size", + type=int, + default=-1, + help="Maximum number of memory banks to keep; -1 means no limit.", + ) + + StreamSeq2SeqModel.__name__ = klass.__name__ + return StreamSeq2SeqModel diff --git a/fairseq/models/speech_to_text/modules/emformer.py b/fairseq/models/speech_to_text/modules/emformer.py new file mode 100644 index 0000000000..42b157b766 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/emformer.py @@ -0,0 +1,1838 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import math +import re +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from fairseq.models import ( + FairseqEncoder, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.speech_to_text.utils import ( + NoOp, + lengths_to_padding_mask, + segments_to_sequence, +) +from fairseq.models.speech_to_text.utils import ( + attention_suppression, + layer_norm_backward_hook, +) +from torch import Tensor, device as Device +from torch.quantization.qconfig import ( + default_dynamic_qconfig, + per_channel_dynamic_qconfig, +) + + +class RelativePositionEmbedding(nn.Module): + """ + Implementation according to https://arxiv.org/abs/1803.02155 + """ + + def __init__(self, head_dim, max_position, norm_init=True): + super().__init__() + self.head_dim = head_dim + self.max_position = max_position + self.embeddings = nn.Parameter(torch.Tensor(max_position * 2 + 1, head_dim)) + if norm_init: +
nn.init.xavier_normal_(self.embeddings) + else: + nn.init.xavier_uniform_(self.embeddings) + + def forward(self, input: Tensor): + output = nn.functional.embedding(input.long(), self.embeddings) + return output + + +class Fp32LayerNorm(nn.Module): + def __init__( + self, + input_dim, + clamp_grad=True, + max_grad_value=256, + eps=1e-5, + elementwise_affine=True, + ): + super().__init__() + self.torch_module = torch.nn.LayerNorm( + input_dim, eps=eps, elementwise_affine=elementwise_affine + ) + if clamp_grad: + hook = partial(layer_norm_backward_hook, clamp_value=max_grad_value) + self.torch_module.register_backward_hook(hook) + + def forward(self, input): + output = torch.nn.functional.layer_norm( + input.float(), + self.torch_module.normalized_shape, + self.torch_module.weight.float() + if self.torch_module.weight is not None + else None, + self.torch_module.bias.float() + if self.torch_module.bias is not None + else None, + self.torch_module.eps, + ).type_as(input) + return output + + +# ------------------------------------------------------------------------------ +# PositionwiseFF +# ------------------------------------------------------------------------------ + + +class PositionwiseFF(nn.Module): + """ + FFN layer in transformer. + + Args: + input_dim: input embedding dimension + ffn_dim: FFN layer inner dimension + dropout_on_fc1: dropout for first linear layer + dropout_on_fc2: dropout for second linear layer + activation_fn: activation function used after first linear layer. \ + Only relu or gelu is supported.
+ + """ + + def __init__( + self, input_dim, ffn_dim, dropout_on_fc1, dropout_on_fc2, activation_fn + ): + super(PositionwiseFF, self).__init__() + + self.input_dim = input_dim + self.ffn_dim = ffn_dim + if activation_fn == "relu": + ac = nn.ReLU() + elif activation_fn == "gelu": + ac = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(activation_fn)) + + # fc1 -> ac -> dropout -> fc2 -> dropout + self.module = nn.Sequential( + nn.Linear(input_dim, ffn_dim), + ac, + nn.Dropout(dropout_on_fc1), + nn.Linear(ffn_dim, input_dim), + nn.Dropout(dropout_on_fc2), + ) + + self.layer_norm = Fp32LayerNorm(input_dim) + + def forward(self, input): + module_out = self.module(self.layer_norm(input)) + output = module_out + input + + return output + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# SummarizationLayer +# ------------------------------------------------------------------------------ + + +class SummarizationLayer(nn.Module): + def __init__(self, method, segment_size, embedding_dim): + super(SummarizationLayer, self).__init__() + self.segment_size = segment_size + self.embedding_dim = embedding_dim + nonlin_match = re.match(r"nonlinear\((?P<act>[a-z]+),(?P<dim>[0-9]+)\)", method) + self.method = method + if method == "mean": + self.module = nn.AvgPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "max": + self.module = nn.MaxPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "linear": + self.module = nn.Linear(segment_size, 1) + elif nonlin_match: + nonlin_args = nonlin_match.groupdict() + act_type = 
nonlin_args["act"] + hid_dim = int(nonlin_args["dim"]) + if act_type == "relu": + act = nn.ReLU() + elif act_type == "gelu": + act = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(act_type)) + self.module = nn.Sequential( + nn.Linear(segment_size, hid_dim), + act, + nn.Linear(hid_dim, 1), + ) + else: + raise ValueError("Unsupported summarization method = ({})".format(method)) + + def forward(self, input): + # T, B, D -> B, D, T + input = input.permute(1, 2, 0) + + if self.method == "mean" or self.method == "max": + output = self.module(input) + output = output.permute(2, 0, 1) + return output + + full_seg_length = input.size(2) // self.segment_size * self.segment_size + if full_seg_length > 0: + # at least one seg is full + B = input.size(0) + D = input.size(1) + input_todo = ( + input[:, :, :full_seg_length] + .contiguous() + .view(B, -1, self.segment_size) + ) + output = self.module(input_todo) + output = output.view(B, D, -1) + else: + output = input.new_zeros(input.size(0), input.size(1), 0) + left = input.size(2) - full_seg_length + if left > 0: + # when last seg is not full, use zeros as last memory placeholder + zeros = input.new_zeros(input.size(0), input.size(1), 1) + output = torch.cat([output, zeros], dim=2) + output = output.permute(2, 0, 1) + return output + + +# ------------------------------------------------------------------------------ +# NoSegAugmentedMemoryMultiheadAttentionBmm +# ------------------------------------------------------------------------------ + + +class NoSegAugmentedMemoryMultiheadAttentionBmm(nn.Module): + """ + Whole utterance augmented memory multihead attention using BMM. + + Different with previous augmented memory multihead attention where + the utterance is chunked into segments. Here we use attention mask + achieve so. The input embedding [right_context, utterance, summary] + is a concatenation of right context, utterance and summary. 
+ + Right context block is the concatenation of all the right context for + each segments. [right_context_0, right_context_1, ..., right_context_n] + For example, if we have utterance = [v0, v1, v2, ...., v20]. segment + size 8, right_context size 4. Then the right context blocks = + [v8, v9, v10, v11, v16, v17, v18, v19, 0, 0, 0, 0], where v8, v9, v10, + and v11 are the right context for first segment. v16, v17, v18 and v19 + are the right context for second segment. 0, 0, 0 and 0 are right context + for the last segment. + + utterance is corresponding to input embedding sequence + + summary is concatenation of average of each segments. [summary_0, + summary_1, ..., ]. + + In augmented memory multihead attention, the query is [right_context, + utterance, summary], key is [memory, right_context, utterance]. Different + with AugmentedMemoryMultiheadAttentionBmm, memory here is passed from + previous attention layer. For the first attention layer, memory is average + of each segment. + + Memory is a concatenation of memory from each segments in previous attention + layer. For example, current layer is i, then memory is [m_0, m_1, ..., m_n]. + Each m_k is the output from seg_k in layer i-1. + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + dropout: attention dropout + std_scale: if std_scale is not None. The weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + scaled_init: whether to use scaled init for linear weight + tanh_on_mem: whether to use tanh on memory output + use_mem: whether to use memory or not. When max_memory_size is 0, then + we don't have memory anymore. + layer_index: current self-attention layer index that is used in depth + initialization + max_relative_position: max relative position used in relative position + embedding + rpe_old_option: To be compatible with previous model. 
The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + + """ + + def __init__( + self, + input_dim, + num_heads, + dropout=0.0, + std_scale=None, + scaled_init=False, + tanh_on_mem=False, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + max_relative_position=0, + rpe_old_option=True, + ): + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + super().__init__() + + embed_dim = input_dim + self.e2h_kv = torch.nn.Linear(input_dim, 2 * input_dim, bias=True) + self.e2h_q = torch.nn.Linear(input_dim, input_dim, bias=True) + self.rpe_old_option = rpe_old_option + if max_relative_position > 0: + self.use_rpe = True + self.rpe_k = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + self.rpe_v = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + else: + self.use_rpe = False + self.rpe_k = None + self.rpe_v = None + if scaled_init: + if layer_index == -1: + gain = 1.0 / math.sqrt(2) + else: + # https://arxiv.org/abs/2005.09684 depthwise initialization + # stabilizes the training greatly. Use depthwise initialization to + # replace incremental loss.
+ gain = 1.0 / math.sqrt(layer_index + 1) + torch.nn.init.xavier_uniform_(self.e2h_kv.weight, gain=gain) + torch.nn.init.xavier_uniform_(self.e2h_q.weight, gain=gain) + + self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True) + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + + self.std_scale = std_scale + self.use_mem = use_mem + self.mini_batches = mini_batches + self.negative_inf = negative_inf + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = NoOp() + self.nonlinear_squash_mem = False + + def prepare_qkv( + self, + input: Tensor, + mems: Tensor, + lengths: Tensor, + summary_length: int, + lc_length: int, + ): + # T: right_context length + utterance_length + summary_length + T, B, D = input.shape + mem_length = mems.size(0) + utterance_length = torch.max(lengths) + + right_context_blocks_length = T - utterance_length - summary_length + rc_block = input[:right_context_blocks_length, :, :] + utterance_block = input[right_context_blocks_length : T - summary_length, :, :] + + if B == 1: + padding_mask = None + else: + klengths = lengths + mem_length + right_context_blocks_length + lc_length + padding_mask = lengths_to_padding_mask(lengths=klengths) + + mem_rc_input = torch.cat([mems, rc_block, utterance_block], dim=0) + + # In training lc_length = 0 + key_length = mem_rc_input.size(0) + lc_length + rc_input_sum = input + q = self.e2h_q(rc_input_sum) + kv = self.e2h_kv(mem_rc_input) + k, v = kv.chunk(chunks=2, dim=2) + result_qkv = (q, k, v) + input_shape = (T, B, D) + result_lengths_info = ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) + if padding_mask is not None: + assert padding_mask.size(0) == B + assert padding_mask.size(1) == key_length + + return result_qkv, input_shape, result_lengths_info, padding_mask + + def 
prepare_attention_weights( + self, + q: Tensor, + new_k: Tensor, + new_v: Tensor, + input_shape: Tuple[int, int, int], + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor]: + T, B, D = input_shape + q = ( + q.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1) + * self.scaling + ) + + k = ( + new_k.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + new_v.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_k = self.rpe_k(rpe) + # [q, B*h, d] * [q, k, d] -> [B*h, q, k] + attention_weights_rpe = torch.matmul( + q.transpose(0, 1), r_k.transpose(1, 2) + ).transpose(0, 1) + attention_weights = attention_weights + attention_weights_rpe + attention_weights_float = attention_weights.float() + + return attention_weights, attention_weights_float, v + + def prepare_attention_output( + self, + attention_weights: Tensor, + attention_weights_float: Tensor, + v: Tensor, + input_shape: Tuple[int, int, int], + key_length: int, + padding_mask: Optional[Tensor], + rpe: Optional[Tensor], + ) -> Tensor: + T, B, D = input_shape + if padding_mask is not None: + attention_weights_float = attention_weights_float.view( + B, self.num_heads, T, key_length + ) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attention_weights_float = attention_weights_float.view( + B * self.num_heads, T, key_length + ) + + if self.std_scale is not None: + attention_weights_float = attention_suppression( + attention_weights_float, self.std_scale + ) + + attention_weights_float = torch.nn.functional.softmax( + attention_weights_float, dim=-1 + ) + attention_weights = attention_weights_float.type_as(attention_weights) + + attention_probs = torch.nn.functional.dropout( + attention_weights, 
p=self.dropout, training=self.training + ) + + # [T, key_length, B, n_head]+ [key_length, B, n_head, d_head] + # -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_v = self.rpe_v(rpe) + attention_rpe = torch.matmul( + attention_probs.transpose(0, 1), r_v + ).transpose(0, 1) + + if self.rpe_old_option: + attention += attention + attention_rpe + else: + attention = attention + attention_rpe + + assert list(attention.shape) == [B * self.num_heads, T, self.head_dim] + + attention = attention.transpose(0, 1).contiguous().view(T, B, self.embed_dim) + + rc_output_memory = self.out_proj(attention) + return rc_output_memory + + @torch.jit.unused + def forward( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + attention_mask: Tensor, + pre_mems: Optional[Tensor] = None, + left_context_key: Optional[Tensor] = None, + left_context_val: Optional[Tensor] = None, + rpe: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in training. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + attention_mask: attention mask for query = [right_context, query, summary] + key = [mem, right_context, query]. This is only used for traing. 
+ + """ + if self.use_mem: + mem_length = mems.size(0) + summary_length = mem_length + 1 + if pre_mems is not None: + mems = torch.cat([pre_mems, mems], dim=0) + else: + mem_length = 0 + summary_length = 0 + + # In training, lc_length = 0 + if left_context_key is not None: + lc_length = left_context_key.size(0) + else: + lc_length = 0 + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + if left_context_key is not None: + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + else: + new_k = k + new_v = v + next_k = None + next_v = None + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + + # mask attention + attention_mask = attention_mask.unsqueeze(0) + attention_weights_float = attention_weights_float.masked_fill( + attention_mask, float(self.negative_inf) + ) + + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + if self.use_mem: + # next_m length equals to summary length - 1 + # last memory is ignored + if self.mini_batches: + next_m = rc_output_memory[-summary_length:] + 
else: + next_m = rc_output_memory[-summary_length:-1] + + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-summary_length] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + next_m = mems + rc_output = rc_output_memory + + return rc_output, next_m, next_k, next_v + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in decoding. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + left_context_key: left_context for key part. This is only used for online + decoding. In training, this is empty tensor + left_context_val: left_context for value part. This is only used for online + decoding. 
In training, this is empty tensor + + """ + lc_length = left_context_key.size(0) + + # In decoding, summary_length = 1 or 0 + if self.use_mem: + summary_length = 1 + else: + summary_length = 0 + + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + # In online decoding, we don't have attention mask. 
But we still need + # to disable the attention from summary query to memory + attention_weights_float[:, -1, :mem_length] = float(self.negative_inf) + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + # In decoding, summary length is 1 + if self.use_mem: + next_m = rc_output_memory[-1:] + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-1] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + rc_output = rc_output_memory + # empty tensor as input mems + next_m = mems + + return rc_output, next_m, next_k, next_v + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +class NoSegAugmentedMemoryTransformer(nn.Module): + """ + Whole utterance augmented memory transformer. + + This is not pyspeech nn layer. It is used as a module in a master layer where + multiple transformers is used. 
+ """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + dropout_in_attn=0.0, + dropout_on_attn=None, + dropout_on_fc1=None, + dropout_on_fc2=None, + activation_fn="relu", + tanh_on_mem=False, + std_scale=None, + scaled_init=False, + segment_size=128, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super(NoSegAugmentedMemoryTransformer, self).__init__() + + self.attention = NoSegAugmentedMemoryMultiheadAttentionBmm( + input_dim=input_dim, + num_heads=num_heads, + dropout=dropout_in_attn, + scaled_init=scaled_init, + tanh_on_mem=tanh_on_mem, + std_scale=std_scale, + use_mem=use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + max_relative_position=max_relative_position, + ) + self.dropout = nn.Dropout(dropout_on_attn) + self.pos_ff = PositionwiseFF( + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + activation_fn=activation_fn, + ) + self.layer_norm_pre = Fp32LayerNorm(input_dim) + self.layer_norm = Fp32LayerNorm(input_dim) + self.segment_size = segment_size + self.use_mem = use_mem + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + def set_mini_batches(self, mini_batches): + self.attention.mini_batches = mini_batches + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def pre_attention_ops(self, input, right_context_blocks): + rc_length = right_context_blocks.size(0) + input_length = input.size(0) + + rc_and_input = torch.cat([right_context_blocks, input], dim=0) + residual_input = rc_and_input + rc_and_input = self.layer_norm_pre(rc_and_input) + + query_input = rc_and_input[-input_length:, :, :] + return rc_length, input_length, residual_input, query_input, rc_and_input + + def after_attention_ops(self, attention_output, residual_input): + 
output = self.dropout(attention_output) + output = output + residual_input + output = self.pos_ff(output) + output = self.layer_norm(output) + return output + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + right_context_blocks: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + + # In online decoding, the summary query size is always 1 or 0 + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + summary_query = summary_query[0:1, :, :] + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention.forward_jit( + input=rc_qu_su, + lengths=lengths, + mems=mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + return results + + @torch.jit.unused + def forward( + self, + input, + lengths, + mems, + right_context_blocks, + attention_mask, + pre_mems, + left_context_key, + left_context_val, + rpe, + ): + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention( + input=rc_qu_su, + lengths=lengths, + mems=mems, + attention_mask=attention_mask, + pre_mems=pre_mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + 
+ + # [TODO] Note memory did not go through pos_ff. What happens if we pass + # memory through the pos_ff as well? + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + + return results + + +class NoSegAugmentedMemoryTransformerEncoderLayer(FairseqEncoder): + """ + Whole utterance augmented memory transformer encoder layer. This is a master layer + where we can define multiple augmented memory transformers. There are two reasons + to set up the master layer. + 1. We only need to define the attention mask once. All the layers in the master + layer share the same mask. + 2. pyspeech nn layer has special input and output format. Defining one master layer makes it + easier to pass memory between different layers inside the master layer + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + ffn_dim: ffn dimension in FFN layer + num_layers: number of augmented memory transformer layers + dropout_in_attn: dropout used in multi-head self-attention + dropout_on_attn: dropout used for output from the multihead self-attention + dropout_on_fc1: dropout used in FFN layer for the first linear layer + dropout_on_fc2: dropout used in FFN layer for the second linear layer + segment_size: segment size for each segment + context_config: (left_context_size, right_context_size) defines the surrounding context size + for each segment + max_memory_size: maximum memory size used for each segment + scaled_init: whether to use scaled init for weight initialization in attention layer + std_scale: if std_scale is not None, the weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + activation_fn: activation function used in FFN layer.
[ReLU, GELU] supported + tanh_on_mem: whether to use tanh on memory + mini_batches: use mini-batch training + negative_inf: the negative infinity value used in attention masking. default is "-inf". + For some situations, e.g. LM, it is better to use "-1e8" to avoid nan issue. + summarization_method: method to generate segment summarization embedding + max_relative_position: max relative position for relative position embedding + rpe_old_option: To be compatible with previous model. The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + [TODO]: remove the rpe_old_option by the end of 2021 Q1. + + """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + num_layers=1, + dropout_in_attn=0.0, + dropout_on_attn=0.0, + dropout_on_fc1=0.0, + dropout_on_fc2=0.0, + segment_size=128, + context_config=(0, 0), + max_memory_size=0, + scaled_init=True, + std_scale=None, + activation_fn="relu", + tanh_on_mem=False, + mini_batches=False, + negative_inf="-inf", + deep_init=True, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super().__init__(None) + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + # we used to support growing memory size. However, it will cause + # cross stream batching failure. Now we need to have exact max memory size + if max_memory_size < 0: + raise ValueError("max_memory_size must be >= 0") + + # Only assign right_context. In decoding, left context will be cached.
+ # No need to let the online decoder to re-assign the left context + self.left_context, self.right_context = context_config + self.segment_size = segment_size + self.memory_dim = input_dim + self.max_memory_size = max_memory_size + self.mini_batches = mini_batches + if self.max_memory_size != 0: + self.use_mem = True + else: + self.use_mem = False + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + self.layers = torch.nn.ModuleList() + self.num_layers = num_layers + self.max_relative_position = max_relative_position + if self.max_relative_position > 0: + self.use_rpe = True + else: + self.use_rpe = False + for i in range(self.num_layers): + if deep_init: + layer_index = i + else: + layer_index = -1 + + self.layers.append( + NoSegAugmentedMemoryTransformer( + num_heads=num_heads, + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_in_attn=dropout_in_attn, + dropout_on_attn=dropout_on_attn, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + segment_size=segment_size, + std_scale=std_scale, + activation_fn=activation_fn, + tanh_on_mem=tanh_on_mem, + scaled_init=scaled_init, + use_mem=self.use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + summarization_method=summarization_method, + max_relative_position=max_relative_position, + rpe_old_option=rpe_old_option, + ) + ) + + def set_mini_batches(self, mini_batches): + # handy function only used for unit test + self.mini_batches = mini_batches + for layer in self.layers: + layer.set_mini_batches(mini_batches) + + def _get_relative_position( + self, + input: Tensor, + max_relative_position: int, + left_context_length: int, + past_length: int, + is_decoding: bool, + ): + # For training, we copy the right context to the start of the utterance + # First dimension in distance is corresponding to query. + # [right context, utterance, summary vector] + # Second dimension in distance is corresponding to key. 
+ # [Memory bank, right context, utterance] + # For summary vector in query part, the distance with + # all other position is 2*max_position. For memory bank in key, + # the distance with all other positions is 0. + + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + + # utterance + u_st = past_length * self.segment_size + u_ed = u_st + T + utterance_ranges = torch.arange(u_st, u_ed - self.right_context) + + # left context. Only in minibatch or decoding + left_context_ranges = torch.arange(u_st - left_context_length, u_st) + + # Right context block + # right context + utterance + right_context_blocks = [] + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + u_st + ed = st + self.right_context + assert ed < u_ed + temp = torch.arange(st, ed) + right_context_blocks.append(temp) + right_context_blocks.append(torch.arange(u_ed - self.right_context, u_ed)) + right_context_ranges = torch.cat(right_context_blocks) + + if self.use_mem: + # Memory bank + # The position for memory -n, .., -1 + if is_decoding: + memory_size = min(past_length, self.max_memory_size) + else: + memory_size = num_segs + past_length - 1 + memory_bank_ranges = torch.arange( + -max_relative_position - 1, -max_relative_position - 1 - memory_size, -1 + ) + + # summary vector + # The position for summary vector as the T+max_relative_position+1. 
+ # After the clamping, the relative position is max_relative_position + summary_pos_st = u_ed + max_relative_position + 1 + summary_vector_ranges = torch.arange( + summary_pos_st, summary_pos_st + num_segs + ) + + key_ranges = torch.cat( + [ + memory_bank_ranges, + right_context_ranges, + left_context_ranges, + utterance_ranges, + ] + ) + + query_ranges = torch.cat( + [right_context_ranges, utterance_ranges, summary_vector_ranges] + ) + else: + key_ranges = torch.cat( + [right_context_ranges, left_context_ranges, utterance_ranges] + ) + + query_ranges = torch.cat([right_context_ranges, utterance_ranges]) + + distance = key_ranges[None, :] - query_ranges[:, None] + distance_clamp = ( + torch.clamp(distance, -max_relative_position, max_relative_position) + + max_relative_position + ) + distance_clamp = distance_clamp.to(input.device).long().detach() + return distance_clamp + + def _get_attention_mask(self, input, past_length=0, left_context_cache=0): + # attention mask for each query contains three parts: + # 1. memory part + # 2. left_context + segment + # 3. right_context_block + # so for each segment and its correspoinding right context block, + # the attention matrix is formed by 9 parts: + # [0, m, 0, 0, right_context, 0, 0, seg, 0] + # [before memory, memory, after memory, before right context, right_context, + # after right context, before seg, seg, after seg] + # + # Query is formed in the way as [right_context_blocks, utterance, summary] + # + # Note: put m and right_context before segment is convenient + # for padding_mask operation. + # Key lengths = m_length + right_context_block_length + lengths + utterance_length, batch_size, _ = input.shape + summary_length = math.ceil(utterance_length / self.segment_size) + num_segs = summary_length + rc_length = self.right_context * num_segs + rc = self.right_context + lc = self.left_context + + # using mini-batches, there is left context cache available for current + # sequence. 
+ lcc = left_context_cache + + # max_memory_size is 0 then we don't have memory and summary + # past_length is the memory carry from previous sequence + if self.use_mem: + mem_length = num_segs - 1 + past_length + else: + mem_length = 0 + rc_mask = [] + query_mask = [] + summary_mask = [] + for j in range(0, num_segs): + ssize = min(self.segment_size, utterance_length - j * self.segment_size) + + rc_size = rc + rc_mat = [] + q_mat = [] + s_mat = [] + m_start = max(j + past_length - self.max_memory_size, 0) + + # max_memory_size is 0, then we don't use memory + if self.use_mem: + # part 0: before memory + rc_mat.append(input.new_zeros(rc_size, m_start)) + q_mat.append(input.new_zeros(ssize, m_start)) + s_mat.append(input.new_zeros(1, m_start)) + + # part 1: memory + col_1 = j + past_length - m_start + rc_mat.append(torch.ones(rc_size, col_1, device=input.device)) + q_mat.append(torch.ones(ssize, col_1, device=input.device)) + # based on D22875746, disable summary query attention + # on memeory is better for long form utterance + s_mat.append(input.new_zeros(1, col_1)) + + # part 2: after memory + col_2 = mem_length - (j + past_length) + rc_mat.append(input.new_zeros(rc_size, col_2)) + q_mat.append(input.new_zeros(ssize, col_2)) + s_mat.append(input.new_zeros(1, col_2)) + + # part 3: before right context + rc_start = j * rc + rc_mat.append(input.new_zeros(rc_size, rc_start)) + q_mat.append(input.new_zeros(ssize, rc_start)) + s_mat.append(input.new_zeros(1, rc_start)) + + # part 4: right context + rc_end = rc_start + rc + col_4 = rc + rc_mat.append(torch.ones(rc_size, col_4, device=input.device)) + q_mat.append(torch.ones(ssize, col_4, device=input.device)) + s_mat.append(torch.ones(1, col_4, device=input.device)) + + # part 5: after right context + col_5 = rc_length - rc_end + rc_mat.append(input.new_zeros(rc_size, col_5)) + q_mat.append(input.new_zeros(ssize, col_5)) + s_mat.append(input.new_zeros(1, col_5)) + + # part 6: before query segment + seg_start = max(j * 
self.segment_size + lcc - lc, 0) + rc_mat.append(input.new_zeros(rc_size, seg_start)) + q_mat.append(input.new_zeros(ssize, seg_start)) + s_mat.append(input.new_zeros(1, seg_start)) + + # part 7: query segment + # note: right context is put in right context block + # here we only need to consider about left context + seg_end = min((j + 1) * self.segment_size + lcc, utterance_length + lcc) + col_7 = seg_end - seg_start + rc_mat.append(torch.ones(rc_size, col_7, device=input.device)) + q_mat.append(torch.ones(ssize, col_7, device=input.device)) + s_mat.append(torch.ones(1, col_7, device=input.device)) + + # part 8: after query segment + col_8 = utterance_length + lcc - seg_end + rc_mat.append(input.new_zeros(rc_size, col_8)) + q_mat.append(input.new_zeros(ssize, col_8)) + s_mat.append(input.new_zeros(1, col_8)) + + rc_mask.append(torch.cat(rc_mat, dim=1)) + query_mask.append(torch.cat(q_mat, dim=1)) + summary_mask.append(torch.cat(s_mat, dim=1)) + + # no memory, then we don't need summary either + if self.use_mem: + attention_mask = ( + 1 + - torch.cat( + [ + torch.cat(rc_mask, dim=0), + torch.cat(query_mask, dim=0), + torch.cat(summary_mask, dim=0), + ], + dim=0, + ) + ).to(torch.bool) + else: + attention_mask = ( + 1 + - torch.cat( + [torch.cat(rc_mask, dim=0), torch.cat(query_mask, dim=0)], dim=0 + ) + ).to(torch.bool) + + return attention_mask + + @torch.jit.export + def init_state( + self, batch_size: int, device: Optional[Device] = None + ) -> List[Tensor]: + empty_memory = torch.zeros( + self.num_layers, + self.max_memory_size, + batch_size, + self.memory_dim, + device=device, + ) + left_context_key = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + left_context_val = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) + + return [empty_memory, left_context_key, 
left_context_val, past_length] + + @torch.jit.export + def batch_state(self, states: List[List[Tensor]]) -> List[Tensor]: + if len(states) == 0: + return [] + batched_m = [] + batched_lc_key = [] + batched_lc_val = [] + batched_past_length = [] + for state in states: + if len(state) == 0: + continue + m, lc_key, lc_val, past_length = state + batched_m.append(m) + batched_lc_key.append(lc_key) + batched_lc_val.append(lc_val) + batched_past_length.append(past_length) + + if ( + (len(batched_m) == 0) + or (len(batched_lc_key) == 0) + or (len(batched_lc_val) == 0) + or (len(batched_past_length) == 0) + ): + return [ + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + ] + + batched_m = torch.cat(batched_m, dim=2) + batched_lc_key = torch.cat(batched_lc_key, dim=2) + batched_lc_val = torch.cat(batched_lc_val, dim=2) + batched_past_length = torch.cat(batched_past_length, dim=1) + return [batched_m, batched_lc_key, batched_lc_val, batched_past_length] + + @torch.jit.export + def reorder_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + if len(state) == 0: + return [] + m, lc_key, lc_val, past_length = state + indices = indices.to(device=m.device) + reord_m = torch.index_select(m, 2, indices) + reord_lc_key = torch.index_select(lc_key, 2, indices) + reord_lc_val = torch.index_select(lc_val, 2, indices) + reord_past_length = torch.index_select(past_length, 1, indices) + return [reord_m, reord_lc_key, reord_lc_val, reord_past_length] + + @torch.jit.export + def reset_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + m, lc_key, lc_val, past_length = state + m = m.index_fill(dim=2, index=indices, value=0.0) + lc_key = lc_key.index_fill(dim=2, index=indices, value=0.0) + lc_val = lc_val.index_fill(dim=2, index=indices, value=0.0) + past_length = past_length.index_fill(dim=1, index=indices, value=0) + + return [m, lc_key, lc_val, past_length] + + @torch.jit.export + def state_size(self) -> int: + return 4 + + 
@torch.jit.export + def batch_size_in_state( + self, state: Optional[List[Tensor]], sloppy: bool = True + ) -> Optional[int]: + if state is None: + return None + return state[0].size(2) + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def _gen_right_context_padded_input(self, input): + # This function deals with input that is already + # padded with right context (e.g. minibatch training) + right_context_blocks = [] + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + ed = st + self.right_context + assert ed < T + temp = input[st:ed, :, :] + right_context_blocks.append(temp) + + # last segment right context is already available + right_context_blocks.append(input[T - self.right_context :, :, :]) + return torch.cat(right_context_blocks, dim=0) + + def _gen_segs_right_context(self, input, lengths): + segments = [] + T, B, D = input.size() + nT = T - self.right_context + + # assume input is right context padded + num_segs = math.ceil(nT / self.segment_size) + # pad zeros to the utterance to make sure each + # segment has the same right context. For the + for i in range(0, num_segs - 1): + st = i * self.segment_size + ed = min(T, st + self.segment_size + self.right_context) + temp = input[st:ed, :, :] + rest_lengths = torch.clamp( + lengths - self.segment_size, min=0, max=nT - (i + 1) * self.segment_size + ) + segments.append((temp, lengths - rest_lengths + self.right_context)) + lengths = rest_lengths + + last_seg = input[st + self.segment_size :, :, :] + segments.append((last_seg, rest_lengths + self.right_context)) + + return segments + + @torch.jit.unused + def forward( + self, input: Tensor, padding_masks: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + # Xutai: originally the second argument is lengths. 
+ lengths = (~padding_masks).sum(dim=1).long() + # mini batch training. + if self.mini_batches: + return self.forward_mini_batches(input, lengths, state) + + # regular full sequence training. Note, assume the right context is provided + # in the input. + T, B, D = input.size() + right_context_blocks = self._gen_right_context_padded_input(input) + + # generate the relative positional embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=0, + past_length=0, + is_decoding=False, + ) + else: + rpe = None + input = input[: T - self.right_context, :, :] + + attention_mask = self._get_attention_mask(input) + + # first layer uses each segment mean as memory + # ignore the last one seg average + if self.use_mem: + mems = self.gen_summary_queries(input)[:-1, :, :] + else: + mems = torch.zeros(0, input.size(1), input.size(2), device=input.device) + mems = mems.type_as(input) + + output = input + all_outputs = [] + + for layer in self.layers: + output, mems, right_context_blocks, _, _ = layer( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=None, + left_context_key=None, + left_context_val=None, + rpe=rpe, + ) + all_outputs.append(output) + return output, padding_masks, [], all_outputs + + def forward_jit_mini_batch_init( + self, + seg: Tensor, + state: Optional[List[Tensor]] = None, + is_decoding: bool = False, + ): + # Prepare state. In whole sequence training, state is ignored. + # For minibatch training, we need to prepare state + if state is None: + state = self.init_state(batch_size=seg.size(1), device=seg.device) + if seg.dtype == torch.half: + state = [state[0].half(), state[1].half(), state[2].half(), state[3]] + + if self.use_mem: + # note input average only on seg, not on right context + # first layer uses each segment mean as memory.
the last + # one segment average is used in state + full_mems = self.gen_summary_queries(seg) + if is_decoding: + mems = full_mems[0:1, :, :] + state_mems = torch.cat([state[0][0], mems], dim=0) + else: + mems = full_mems[:-1, :, :] + state_mems = torch.cat([state[0][0], full_mems], dim=0) + else: + mems = state[0][0] + state_mems = mems + + # track processed segment number or memory number + # all sequences in the same batch have the same past length + past_length = state[3][0][0].item() + past_left_context = min(past_length * self.segment_size, self.left_context) + past_length = min(self.max_memory_size, past_length) + + return state, mems, state_mems, past_length, past_left_context + + def state_update_before( + self, layer: int, state: List[Tensor], past_length: int, past_left_context: int + ): + pre_mems = state[0][layer][self.max_memory_size - past_length :, :, :] + lc_key = state[1][layer][self.left_context - past_left_context :, :, :] + lc_val = state[2][layer][self.left_context - past_left_context :, :, :] + return pre_mems, lc_key, lc_val + + def state_update_after( + self, + layer: int, + state: List[Tensor], + mems: Tensor, + next_key: Tensor, + next_val: Tensor, + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + ): + # mems is used for next layer + if layer < self.num_layers - 1: + state_mems = torch.cat([state[0][layer + 1], mems], dim=0) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + + # when mems pass to next sequence, we need the last memory.
when mems + # use for the next layer, we can ignore the last memory + mems = mems[:-1, :, :] + + # note state[1][i] and state[2][i] original length equals to self.left_context + new_k = torch.cat([state[1][layer], next_key], dim=0) + new_v = torch.cat([state[2][layer], next_val], dim=0) + lc_key_list.append(new_k[-self.left_context :, :, :]) + lc_val_list.append(new_v[-self.left_context :, :, :]) + return mems_list, lc_key_list, lc_val_list, mems + + def state_update_after_loop( + self, + state: List[Tensor], + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + update_length: int, + ): + state[0] = torch.stack(mems_list, dim=0) + state[1] = torch.stack(lc_key_list, dim=0) + state[2] = torch.stack(lc_val_list, dim=0) + state[3] = state[3] + update_length + return state + + @torch.jit.unused + def forward_mini_batches( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + T, B, D = input.size() + + # input without right context + seg = input[: T - self.right_context, :, :] + + # get right context blocks + right_context_blocks = self._gen_right_context_padded_input(input) + + mems_list = [] + lc_key_list = [] + lc_val_list = [] + results = self.forward_jit_mini_batch_init(seg, state, False) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=False, + ) + else: + rpe = None + + # get attention mask based on seg (not include right context) and available + # left context + attention_mask = self._get_attention_mask(seg, past_length, past_left_context) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + all_outputs = [] + for layer in self.layers: + # In order to make cross 
stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finishing the forward for the current segment + # we add the new memory, left context key and left context value into the + # state and trim out the oldest part to keep the shape consistent. + pre_mems, lc_key, lc_val = self.state_update_before( + i, state, past_length, past_left_context + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=pre_mems, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + all_outputs.append(output) + mems_list, lc_key_list, lc_val_list, mems = self.state_update_after( + layer=i, + state=state, + mems=mems, + next_key=next_key, + next_val=next_val, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + + i += 1 + + # update state + update_length = math.ceil((T - self.right_context) / self.segment_size) + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=update_length, + ) + + return output, lengths, state, all_outputs + + def forward_jit_test( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + This simulates the sequence encoder's forward_jit. It is for unit tests only. + It is not used in training or decoding. Note, extra_right_context is set in + the model. In unit test, input = [utterance, right_context], lengths = + [utterance_length]. + args: + input: input utterance + lengths: utterance input length + state: None here.
input is whole utterance + """ + # [TODO] sequence_to_segment has bug in lengths. + seg_src_tokens_lengths = self._gen_segs_right_context(input, lengths) + + seg_enc_tokens_lengths: List[Tuple[Tensor, Tensor]] = [] + state: Optional[List[Tensor]] = None + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + seg_enc_tokens, seg_enc_lengths, state = self.forward_jit( + input=seg_src_tokens, lengths=seg_src_lengths, state=state + ) + seg_enc_tokens_lengths.append((seg_enc_tokens, seg_enc_lengths)) + + enc_tokens, enc_lengths = segments_to_sequence( + segments=seg_enc_tokens_lengths, time_axis=0 + ) + + state = [] # returns trivial state + + return enc_tokens, enc_lengths, state + + @torch.jit.export + def forward_jit( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Forward helper for online decoding. + + args: + input: [seg, right_context]. We assume in online we + always padding the right context to the preset right context size. + For the last segment, we may have short segment size, but right + context size is the same as other segments + lengths: utterance input length is the utterance segment length and + right context size + state: [memory, left_context_key, left_context_val]. 
To improve throughput, + in addition to memory, we also cache key and value for left_context in + multihead self-attention + """ + # In online decoding, input = [segment, right_context] + # Lengths = [segment_length, right_context_length] + # so we need to strip the right context from the output + T, B, D = input.size() + rc_str = T - self.right_context + rc_end = T + right_context_blocks = input[rc_str:rc_end, :, :] + seg = input[:rc_str, :, :] + lengths = torch.clamp(lengths - self.right_context, min=0) + mems_list = [] + lc_key_list = [] + lc_val_list = [] + + results = self.forward_jit_mini_batch_init(seg, state, True) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=True, + ) + else: + rpe = None + + # memory for first layer. + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + for layer in self.layers: + # In order to make cross stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finishing the forward for the current segment + # we add the new memory, left context key and left context value into the + # state and trim out the oldest part to keep the shape consistent.
+ true_mems, lc_key, lc_val = self.state_update_before( + layer=i, + state=state, + past_length=past_length, + past_left_context=past_left_context, + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward_jit( + input=output, + lengths=lengths, + mems=true_mems, + right_context_blocks=right_context_blocks, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + # mems is used for next layer + mems_list, lc_key_list, lc_val_list, _ = self.state_update_after( + layer=i, + state=state, + mems_list=mems_list, + mems=mems, + next_key=next_key, + next_val=next_val, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + i += 1 + + # update state + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=1, + ) + + return output, lengths, state + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# Emformer encoder for seq2seq model +# This is a wrapper over the original emformer +# ------------------------------------------------------------------------------ +def emformer_encoder(klass): + class SpeechEncoder(klass): + def __init__(self, args): + super().__init__(args) + stride = SpeechEncoder.conv_layer_stride(args) + trf_left_context = args.segment_left_context // stride + trf_right_context = args.segment_right_context // stride + context_config = [trf_left_context, trf_right_context] + self.transformer_layers = nn.ModuleList( + [ + NoSegAugmentedMemoryTransformerEncoderLayer( + input_dim=args.encoder_embed_dim, + num_heads=args.encoder_attention_heads, + ffn_dim=args.encoder_ffn_embed_dim, + 
num_layers=args.encoder_layers, + dropout_in_attn=args.dropout, + dropout_on_attn=args.dropout, + dropout_on_fc1=args.dropout, + dropout_on_fc2=args.dropout, + activation_fn=args.activation_fn, + context_config=context_config, + segment_size=args.segment_length, + max_memory_size=args.max_memory_size, + scaled_init=True, # TODO: use constant for now. + tanh_on_mem=args.amtrf_tanh_on_mem, + ) + ] + ) + + def forward(self, *args, **kwargs): + encoder_out = super().forward(*args, **kwargs) + (output, encoder_padding_masks, [], all_outputs) = encoder_out.encoder_out + + # This is because that in the original implementation + # the output didn't consider the last segment as right context. + encoder_padding_masks = encoder_padding_masks[:, : output.size(0)] + # import pdb;pdb.set_trace() + + return EncoderOut( + encoder_out=output, + encoder_padding_mask=encoder_padding_masks, + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=None, + ) + + @staticmethod + def conv_layer_stride(args): + # TODO: make it configurable from the args + return 4 + + SpeechEncoder.__name__ = klass.__name__ + return SpeechEncoder diff --git a/fairseq/models/speech_to_text/utils.py b/fairseq/models/speech_to_text/utils.py new file mode 100644 index 0000000000..573f8537c9 --- /dev/null +++ b/fairseq/models/speech_to_text/utils.py @@ -0,0 +1,564 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ + +import logging +from collections.abc import Iterable +from itertools import repeat +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + + +# ------------------------------------------------------------------------------ +# assert_equal() +# ------------------------------------------------------------------------------ + + +def assert_equal(value1, value2, name1=None, name2=None): + """Asserts two values are equal otherwise raise an error.""" + + str_name1 = "" if name1 is None else "{} ".format(name1) + str_name2 = "" if name2 is None else "{} ".format(name2) + if value1 != value2: + str_value1 = "{}" if name1 is None else "({})" + str_value1 = str_value1.format(value1) + str_value2 = "{}" if name2 is None else "({})" + str_value2 = str_value2.format(value2) + raise ValueError( + "Expected {}{} == {}{}".format(str_name1, str_value1, str_name2, str_value2) + ) + + +def fill_config(config, key, value): + if value is not None: + if key not in config or config[key] is None: + config[key] = value + assert_equal(value, config[key], "value", f'config["{key}"]') + + +# ------------------------------------------------------------------------------ +# check_and_return_expected() +# ------------------------------------------------------------------------------ + + +def check_and_return_expected(value, undefined_value, expected_value, name=None): + """ + Return the expected value while checking if the given value is undefined or + equal to the expected value. 
+ """ + if (undefined_value is None and value is None) or (undefined_value == value): + return expected_value + if value != expected_value: + str_name = "" if name is None else "{} ".format(name) + str_value = "{}" if name is None else "({})" + str_value = str_value.format(value) + raise ValueError( + "Expected {}{} == {}".format(str_name, str_value, expected_value) + ) + return expected_value + + +# ------------------------------------------------------------------------------ +# get_time_axis() +# ------------------------------------------------------------------------------ + + +def get_time_axis(layout): + """ + Extract the time axis from the layout, for example for breaking sequence into + segments. + """ + if layout in ["TB", "TBD"]: + return 0 + if layout in ["BT", "BTD"]: + return 1 + if layout in ["BCTD"]: + return 2 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# get_batch_axis() +# ------------------------------------------------------------------------------ + + +def get_batch_axis(layout): + """ + Extract the batch axis from the layout + """ + if layout in ["TB", "TBD"]: + return 1 + if layout in ["BT", "BTD", "BCTD"]: + return 0 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# monotonically_increasing_and_bounded() +# ------------------------------------------------------------------------------ + + +def monotonically_increasing_and_bounded(iterable, min=None, max=None): + """ + Check if the elements in the given iterable are monotonically increasing and + bounded by upper/lower bounds. 
+ """ + if not isinstance(iterable, Iterable): + raise TypeError( + "Expected iterable to be of type Iterable, got ({})".format( + iterable.__class__.__name__ + ) + ) + for i in range(len(iterable)): + if min is not None and iterable[i] < min: + return False + if max is not None and iterable[i] > max: + return False + if i > 0 and iterable[i] <= iterable[i - 1]: + return False + return True + + +# ------------------------------------------------------------------------------ +# to_pair() +# ------------------------------------------------------------------------------ + + +def to_pair(value, name): + """Make a pair (of type tuple) of given value.""" + if isinstance(value, Iterable): + if len(value) != 2: + raise ValueError( + "Expected `{}` to have exactly 2 elements, got: ({})".format( + name, value + ) + ) + return value + return tuple(repeat(value, 2)) + + +# ------------------------------------------------------------------------------ +# infer_conv_output_attrs() +# ------------------------------------------------------------------------------ + + +# TODO(cfyeh): figure out if we can get `output_dim` without calling the module. +def infer_conv_output_attrs( + module, input_channels, input_dim, batch_size=1, max_length=8 +): + """Get output attributes of a module with input.""" + input = torch.randn(batch_size, input_channels, max_length, input_dim) + output = module(input) + output_channels = output.shape[1] + output_dim = output.shape[-1] + return output_channels, output_dim + + +# ------------------------------------------------------------------------------ +# NoOp +# ------------------------------------------------------------------------------ + + +class NoOp(torch.nn.Module): + """ + NoOp simply passes the input as the output. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +# ------------------------------------------------------------------------------ +# Permute: a torch.nn.Module applies permutation on the input tensor. +# ------------------------------------------------------------------------------ + + +class Permute(torch.nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, input: Tensor) -> Tensor: + return input.permute(self.dims).contiguous() + + +# ------------------------------------------------------------------------------ +# lengths_to_padding_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_padding_mask(lengths: Tensor) -> Tensor: + """Convert lengths of shape (B, ) to padding mask.""" + batch_size = lengths.shape[0] + max_length = int(torch.max(lengths).item()) + padding_mask = torch.arange( # [0, ..., T-1] + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(batch_size, max_length) >= lengths.unsqueeze(1) + + return padding_mask + + +# ------------------------------------------------------------------------------ +# lengths_to_attention_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_attention_mask( + lengths: Tensor, + left_context: Optional[int] = None, + right_context: Optional[int] = None, +) -> Optional[Tensor]: + """ + Generate attention mask based on (lengths, left_context, right_context). + left_context is None means unlimited left context. + right_context is None means unlimited right context. 
+ """ + + if left_context is None and right_context is None: + return None + + max_length = int(torch.max(lengths).item()) + + # For example, with `max_length` == 6, + # indices = tensor([ + # [ 0, 1, 2, 3, 4, 5], + # [-1, 0, 1, 2, 3, 4], + # [-2, -1, 0, 1, 2, 3], + # [-3, -2, -1, 0, 1, 2], + # [-4, -3, -2, -1, 0, 1], + # [-5, -4, -3, -2, -1, 0], + # ]) + + # In some cases the second torch.arange is created on cpu which causes a + # failure. Adding the device option to guard against it. + indices = torch.arange( + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(max_length, max_length) - torch.arange( + max_length, device=lengths.device + ).view( + max_length, -1 + ) + + # For example, with `max_length` == 5, + # bool_mask = tensor([ + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + bool_mask = ( + torch.tensor([True]).to(device=lengths.device).expand(max_length, max_length) + ) + + # For example, with `max_length` == 5, left_context == 2 + # left_mask = tensor([ + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [False, True, True, True, True], + # [False, False, True, True, True], + # ]) + if left_context is not None: + left_mask = indices >= -left_context + bool_mask = bool_mask & left_mask + + # For example, with `max_length` == 5, right_context == 1 + # right_mask = tensor([ + # [True, True, False, False, False], + # [True, True, True, False, False], + # [True, True, True, True, False], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + if right_context is not None: + right_mask = indices <= right_context + bool_mask = bool_mask & right_mask + + bool_mask = (~bool_mask).to(device=lengths.device) + return bool_mask + + +# ------------------------------------------------------------------------------ +#
infer_output_norm() +# ------------------------------------------------------------------------------ + + +def infer_output_norm(module, output_norm=None): + """ + Infer the output norm (string and module) needed on the module given the desired + output normalization. + """ + if output_norm == module.output_norm(): + # output_norm already matches module.output_norm(). + return (None, NoOp()) + + if output_norm is None and module.output_norm() is not None: + logger = logging.getLogger("infer_output_norm()") + logger.warning( + "trying to set output_norm ({}) ".format(output_norm) + + "but got module.output_norm() ({}), ".format(module.output_norm()) + + "the combined output_norm() will be ({})".format(module.output_norm()) + ) + return (None, NoOp()) + + if output_norm == "log_softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("log_softmax", torch.nn.LogSoftmax(dim=-1)) + + if output_norm == "softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("softmax", torch.nn.Softmax(dim=-1)) + + raise ValueError( + "output_norm ({}) not in ".format(output_norm) + + "supported list = [None, softmax, log_softmax]" + ) + + +# ------------------------------------------------------------------------------ +# infer_channels_from_layout() +# ------------------------------------------------------------------------------ + + +def infer_channels_from_layout(layout, channels): + """Extract the number of channels from the layout.""" + if layout in ("TBD", "BTD"): + if channels is not None and channels != 1: + raise ValueError( + "Expected channels ({}) to be 1 for layout = {}".format( + channels, layout + ) + ) + if channels is None: + return 1 + return channels + + +#
------------------------------------------------------------------------------ +# pad_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def pad_sequence( + sequence: Tensor, + time_axis: int, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> Tensor: + """Pad extra left/right contexts to the sequence.""" + + if extra_left_context == 0 and extra_right_context == 0: + return sequence + + tensors_to_concat = [] + + if extra_left_context: + size = (extra_left_context,) + fill_value = 0 + indices = torch.full( + size=size, + fill_value=fill_value, + dtype=torch.long, + device=sequence.device, + ) + left_padding = torch.index_select(sequence, time_axis, indices) + tensors_to_concat.append(left_padding) + + tensors_to_concat.append(sequence) + + # NOTE(cfyeh): for efficiency reason we pad 0 instead of the last frame for + # extra right contexts. + if extra_right_context: + size = list(sequence.shape) + size[time_axis] = extra_right_context + right_padding = torch.zeros(size, dtype=sequence.dtype, device=sequence.device) + tensors_to_concat.append(right_padding) + + padded_sequence = torch.cat(tensors_to_concat, dim=time_axis) + return padded_sequence + + +# ------------------------------------------------------------------------------ +# sequence_to_segments() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def sequence_to_segments( + sequence: Tensor, + time_axis: int, + lengths: Tensor, + segment_size: Optional[int] = None, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> List[Tuple[Tensor, Tensor]]: + """Breaks sequence into segments.""" + + sequence = pad_sequence( + sequence=sequence, + time_axis=time_axis, + extra_left_context=extra_left_context, + extra_right_context=extra_right_context, + ) + + lengths = lengths + extra_left_context + extra_right_context + + segments: List[Tuple[Tensor, Tensor]] = [] + + 
if segment_size is None: + segments.append((sequence, lengths)) + return segments + + offset = 0 + end = sequence.shape[time_axis] + step = segment_size + size = extra_left_context + segment_size + extra_right_context + + while offset + extra_left_context + extra_right_context < end: + clamped_size = min(size, end - offset) + segment_lengths = torch.clamp(lengths - offset, min=0, max=clamped_size) + indices = torch.arange( + start=offset, + end=(offset + clamped_size), + step=1, + dtype=torch.long, + device=sequence.device, + ) + segment_tensor = torch.index_select(sequence, time_axis, indices) + segments.append((segment_tensor, segment_lengths)) + offset = offset + step + + return segments + + +# ------------------------------------------------------------------------------ +# segments_to_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def segments_to_sequence( + segments: List[Tuple[Tensor, Tensor]], time_axis: int +) -> Tuple[Tensor, Tensor]: + """Concatenate segments into a full sequence.""" + if len(segments) == 1: + return segments[0] + + tensors_to_concat: List[Tensor] = [] + lengths_to_stack: List[Tensor] = [] + + for tensor, lengths in segments: + tensors_to_concat.append(tensor) + lengths_to_stack.append(lengths) + + sequence = torch.cat(tensors_to_concat, dim=time_axis) + lengths = torch.stack(lengths_to_stack, dim=0) + lengths = torch.sum(lengths, dim=0) + + return sequence, lengths + + +def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + batch_first: whether to return a (B, T) tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = False for t < lengths[b] and True otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = 
torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) >= lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +# ------------------------------------------------------------------------------ +# attention suppression +# ------------------------------------------------------------------------------ + + +def attention_suppression(attention_weights: Tensor, scale: float): + # B, H, qlen, klen -> B, H, qlen, 1 + attention_prob = torch.nn.functional.softmax(attention_weights.float(), dim=-1) + attention_nozeros = attention_prob.to(torch.bool) + nozeros_sum = torch.sum(attention_nozeros.to(torch.float), dim=-1, keepdim=True) + + # For very sparse attention, we need to work around the zeros + key_sum = torch.sum(attention_prob, dim=-1, keepdim=True) + + # nozeros_sum should be > 1 + key_mean = key_sum / (nozeros_sum + 1e-8) + + # std calculation + dis = (attention_prob - key_mean) * (attention_prob - key_mean) + + # if attention_prob[i] < threshold, then dis_masked[i] = 0; for all i + dis_masked = torch.where( + attention_nozeros, dis, attention_prob.new_zeros(attention_prob.size()) + ) + + key_var = torch.sum(dis_masked, dim=-1, keepdim=True) + key_var = key_var / (nozeros_sum - 1.0 + 1e-8) + key_std = torch.sqrt(key_var) + key_thread = key_mean - scale * key_std + + # if attention_prob[i] >= key_thread, then attention_prob[i] + # , otherwise "-inf" + inf_tensor = attention_prob.new_zeros(attention_prob.size()).detach() + inf_tensor[:] = float("-inf") + attention_weights_float = torch.where( + attention_prob < key_thread, + inf_tensor, + attention_weights.float(), + ) + + return
attention_weights_float.type_as(attention_weights) + + +def layer_norm_backward_hook(module, grad_input, grad_output, clamp_value): + return tuple(torch.clamp(v, min=-clamp_value, max=clamp_value) for v in grad_input) From b8786dc2aadb56bb549f92ed542875096868bdd5 Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Tue, 2 Mar 2021 17:08:45 -0800 Subject: [PATCH 500/707] Integrate Augmented memory transformer and emformer based augmented memory transformer into fbcode Summary: Integrate Augmented memory transformer and emformer based augmented memory transformer into fbcode. This diff - Modifies the way encoder_out_dict variable is accessed in transformer_monotonic_attention.py - Fix dimension issues in augmented_memory_attention.py - Modifies the way encoder_out is accessed in emformer.py Reviewed By: jmp84 Differential Revision: D26567899 fbshipit-source-id: 9b298ad0bdf78de00b1182001813b0513d32a119 --- .../models/convtransformer_simul_trans.py | 99 ++++++++++++++++++- .../modules/augmented_memory_attention.py | 22 +++-- .../models/speech_to_text/modules/emformer.py | 26 +++-- 3 files changed, 122 insertions(+), 25 deletions(-) diff --git a/examples/simultaneous_translation/models/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py index 760a48168d..0b15e93fea 100644 --- a/examples/simultaneous_translation/models/convtransformer_simul_trans.py +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -10,7 +10,17 @@ register_model, register_model_architecture, ) -from fairseq.models.speech_to_text import ConvTransformerModel, convtransformer_espnet +from fairseq.models.speech_to_text import ( + ConvTransformerModel, + convtransformer_espnet, + ConvTransformerEncoder, +) +from fairseq.models.speech_to_text.modules.augmented_memory_attention import ( + augmented_memory, + SequenceEncoder, + AugmentedMemoryConvTransformerEncoder, +) +from 
fairseq.models.speech_to_text.modules.emformer import emformer_encoder @register_model("convtransformer_simul_trans") @@ -56,3 +66,90 @@ def build_decoder(cls, args, task, embed_tokens): ) def convtransformer_simul_trans_espnet(args): convtransformer_espnet(args) + + +@register_model("convtransformer_augmented_memory") +@augmented_memory +class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel): + @classmethod + def build_encoder(cls, args): + encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args)) + + if getattr(args, "load_pretrained_encoder_from", None) is not None: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + + return encoder + + +@register_model_architecture( + "convtransformer_augmented_memory", "convtransformer_augmented_memory" +) +def augmented_memory_convtransformer_espnet(args): + convtransformer_espnet(args) + + +# ============================================================================ # +# Convtransformer +# with monotonic attention decoder +# with emformer encoder +# ============================================================================ # + + +@emformer_encoder +class ConvTransformerEmformerEncoder(ConvTransformerEncoder): + pass + + +@register_model("convtransformer_emformer") +class ConvtransformerEmformer(SimulConvTransformerModel): + @staticmethod + def add_args(parser): + super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser) + + parser.add_argument( + "--segment-length", + type=int, + metavar="N", + help="length of each segment (not including left context / right context)", + ) + parser.add_argument( + "--segment-left-context", + type=int, + help="length of left context in a segment", + ) + parser.add_argument( + "--segment-right-context", + type=int, + help="length of right context in a segment", + ) + parser.add_argument( + "--max-memory-size", + type=int, + default=-1, + help="Right 
context for the segment.", + ) + parser.add_argument( + "--amtrf-tanh-on-mem", + default=False, + action="store_true", + help="whether to use tanh on memory vector", + ) + + @classmethod + def build_encoder(cls, args): + encoder = ConvTransformerEmformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + +@register_model_architecture( + "convtransformer_emformer", + "convtransformer_emformer", +) +def convtransformer_emformer_base(args): + convtransformer_espnet(args) diff --git a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py index 5d31524b76..e7465bc889 100644 --- a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py +++ b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py @@ -8,7 +8,6 @@ import torch import torch.nn.functional as F from fairseq.models import FairseqEncoder -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.speech_to_text import ( ConvTransformerEncoder, ) @@ -72,7 +71,10 @@ def forward(self, src_tokens, src_lengths, states=None): x = self.embed_scale * x subsampling_factor = 1.0 * max_seq_len / output_seq_len - input_lengths = (src_lengths.float() / subsampling_factor).round().long() + input_lengths = torch.max( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long(), + ) encoder_padding_mask, _ = lengths_to_encoder_padding_mask( input_lengths, batch_first=True @@ -425,14 +427,14 @@ def forward( if not encoder_padding_mask.any(): encoder_padding_mask = None - return EncoderOut( - encoder_out=encoder_out, - encoder_padding_mask=encoder_padding_mask, - encoder_embedding=None, - encoder_states=states, - src_tokens=None, - src_lengths=None, - ) + return { + 
"encoder_out": [encoder_out], + "encoder_padding_mask": [encoder_padding_mask], + "encoder_embedding": [], + "encoder_states": [states], + "src_tokens": [], + "src_lengths": [], + } def incremental_encode( self, diff --git a/fairseq/models/speech_to_text/modules/emformer.py b/fairseq/models/speech_to_text/modules/emformer.py index 42b157b766..e026b86847 100644 --- a/fairseq/models/speech_to_text/modules/emformer.py +++ b/fairseq/models/speech_to_text/modules/emformer.py @@ -17,7 +17,6 @@ from fairseq.models import ( FairseqEncoder, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.speech_to_text.utils import ( NoOp, lengths_to_padding_mask, @@ -1811,23 +1810,22 @@ def __init__(self, args): ] ) - def forward(self, *args, **kwargs): - encoder_out = super().forward(*args, **kwargs) - (output, encoder_padding_masks, [], all_outputs) = encoder_out.encoder_out + def forward(self, src_tokens, src_lengths): + encoder_out = super().forward(src_tokens, src_lengths) + (output, encoder_padding_masks, [], _) = encoder_out["encoder_out"][0] # This is because that in the original implementation # the output didn't consider the last segment as right context. 
encoder_padding_masks = encoder_padding_masks[:, : output.size(0)] - # import pdb;pdb.set_trace() - - return EncoderOut( - encoder_out=output, - encoder_padding_mask=encoder_padding_masks, - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=None, - ) + + return { + "encoder_out": [output], + "encoder_padding_mask": [encoder_padding_masks], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } @staticmethod def conv_layer_stride(args): From 0c32e251e29dc6f10755addd37c5f9d963693df9 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Wed, 3 Mar 2021 09:59:23 -0800 Subject: [PATCH 501/707] Update Simultaneous Translation doc (#1659) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1659 Reviewed By: jmp84, kahne Differential Revision: D26708524 Pulled By: xutaima fbshipit-source-id: 0f34e5e9e3bec2360e098c9c272105c793bfa7b7 --- .../simultaneous_translation/docs/baseline.md | 178 ------------------ .../docs/evaluation.md | 115 ----------- .../docs/simulst_mustc_example.md | 60 +++++- 3 files changed, 54 insertions(+), 299 deletions(-) delete mode 100644 examples/simultaneous_translation/docs/baseline.md delete mode 100644 examples/simultaneous_translation/docs/evaluation.md diff --git a/examples/simultaneous_translation/docs/baseline.md b/examples/simultaneous_translation/docs/baseline.md deleted file mode 100644 index d9bf1a1117..0000000000 --- a/examples/simultaneous_translation/docs/baseline.md +++ /dev/null @@ -1,178 +0,0 @@ -# **Baseline Simultaneous Translation** ---- - -This is an instruction of training and evaluating a *wait-k* simultanoes LSTM model on MUST-C English-Gernam Dataset. - -[STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework](https://https://www.aclweb.org/anthology/P19-1289/) - - -## **Requirements** -Install fairseq (make sure to use the correct branch): -``` -git clone --branch simulastsharedtask git@github.com:pytorch/fairseq.git -cd fairseq -pip install -e . -``` - -Assuming that fairseq is installed in a directory called `FAIRSEQ`. - -Install SentencePiece. One easy way is to use anaconda: - -``` -conda install -c powerai sentencepiece -``` - -Download the MuST-C data for English-German available at https://ict.fbk.eu/must-c/. -We will assume that the data is downloaded in a directory called `DATA_ROOT`. 
- - -## **Text-to-text Model** ---- -### Data Preparation -Train a SentencePiece model: -```shell -for lang in en de; do - python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \ - --data-path $DATA_ROOT/data \ - --vocab-size 10000 \ - --max-frame 3000 \ - --model-type unigram \ - --lang $lang \ - --out-path . -``` - -Process the data with the SentencePiece model: -```shell -proc_dir=proc -mkdir -p $proc_dir -for split in train dev tst-COMMON tst-HE; do - for lang in en de; do - spm_encode \ - --model unigram-$lang-10000-3000/spm.model \ - < $DATA_ROOT/data/$split/txt/$split.$lang \ - > $proc_dir/$split.spm.$lang - done -done -``` - -Binarize the data: - -```shell -proc_dir=proc -fairseq-preprocess \ - --source-lang en --target-lang de \ - --trainpref $proc_dir/train.spm \ - --validpref $proc_dir/dev.spm \ - --testpref $proc_dir/tst-COMMON.spm \ - --thresholdtgt 0 \ - --thresholdsrc 0 \ - --workers 20 \ - --destdir ./data-bin/mustc_en_de \ -``` - -### Training - - -```shell -mkdir -p checkpoints -CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ - --save-dir checkpoints \ - --arch berard_simul_text_iwslt \ - --simul-type waitk \ - --waitk-lagging 2 \ - --optimizer adam \ - --max-epoch 100 \ - --lr 0.001 \ - --clip-norm 5.0 \ - --batch-size 128 \ - --log-format json \ - --log-interval 10 \ - --criterion cross_entropy_acc \ - --user-dir $FAIRSEQ/examples/simultaneous_translation -``` - -## **Speech-to-text Model** ---- -### Data Preparation -First, segment wav files. -```shell -python $FAIRSEQ/examples/simultaneous_translation/data/segment_wav.py \ - --datapath $DATA_ROOT -``` -Similar to text-to-text model, train a Sentencepiecemodel, but only train on German -```Shell -python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \ - --data-path $DATA_ROOT/data \ - --vocab-size 10000 \ - --max-frame 3000 \ - --model-type unigram \ - --lang $lang \ - --out-path . 
-``` -## Training -```shell -mkdir -p checkpoints -CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ - --save-dir checkpoints \ - --arch berard_simul_text_iwslt \ - --waitk-lagging 2 \ - --waitk-stride 10 \ - --input-feat-per-channel 40 \ - --encoder-hidden-size 512 \ - --output-layer-dim 128 \ - --decoder-num-layers 3 \ - --task speech_translation \ - --user-dir $FAIRSEQ/examples/simultaneous_translation - --optimizer adam \ - --max-epoch 100 \ - --lr 0.001 \ - --clip-norm 5.0 \ - --batch-size 128 \ - --log-format json \ - --log-interval 10 \ - --criterion cross_entropy_acc \ - --user-dir $FAIRSEQ/examples/simultaneous_translation -``` - -## Evaluation ---- -### Evaluation Server -For text translation models, the server is set up as follow give input file and reference file. - -``` shell -python ./eval/server.py \ - --hostname localhost \ - --port 12321 \ - --src-file $DATA_ROOT/data/dev/txt/dev.en \ - --ref-file $DATA_ROOT/data/dev/txt/dev.de -``` -For speech translation models, the input is the data direcrory. -``` shell -python ./eval/server.py \ - --hostname localhost \ - --port 12321 \ - --ref-file $DATA_ROOT \ - --data-type speech -``` - -### Decode and Evaluate with Client -Once the server is set up, run client to evaluate translation quality and latency. 
-```shell -# TEXT -python $fairseq_dir/examples/simultaneous_translation/evaluate.py \ - data-bin/mustc_en_de \ - --user-dir $FAIRSEQ/examples/simultaneous_translation \ - --src-spm unigram-en-10000-3000/spm.model\ - --tgt-spm unigram-de-10000-3000/spm.model\ - -s en -t de \ - --path checkpoints/checkpoint_best.pt - -# SPEECH -python $fairseq_dir/examples/simultaneous_translation/evaluate.py \ - data-bin/mustc_en_de \ - --user-dir $FAIRSEQ/examples/simultaneous_translation \ - --data-type speech \ - --tgt-spm unigram-de-10000-3000/spm.model\ - -s en -t de \ - --path checkpoints/checkpoint_best.pt -``` diff --git a/examples/simultaneous_translation/docs/evaluation.md b/examples/simultaneous_translation/docs/evaluation.md deleted file mode 100644 index c53407354e..0000000000 --- a/examples/simultaneous_translation/docs/evaluation.md +++ /dev/null @@ -1,115 +0,0 @@ -# Introduction to evaluation interface -The simultaneous translation models from sharedtask participents are evaluated under a server-client protocol. The participents are requisted to plug in their own model API in the protocol, and submit a docker file. - -## Server-Client Protocol -An server-client protocol that will be used in evaluation. For example, when a *wait-k* model (k=3) translate the English sentence "Alice and Bob are good friends" to Genman sentence "Alice und Bob sind gute Freunde." , the evaluation process is shown as following figure. - -While every time client needs to read a new state (word or speech utterence), a "GET" request is supposed to sent over to server. Whenever a new token is generated, a "SEND" request with the word predicted (untokenized word) will be sent to server immediately. The server can hence calculate both latency and BLEU score of the sentence. - -### Server -The server code is provided and can be set up directly locally for development purpose. 
For example, to evaluate a text simultaneous test set, - -```shell - - python fairseq/examples/simultaneous_translation/eval/server.py \ - --hostname local_host \ - --port 1234 \ - --src-file SRC_FILE \ - --ref-file REF_FILE \ - --data-type text \ -``` -The state that server sent to client is has the following format -```json -{ - 'sent_id': Int, - 'segment_id': Int, - 'segment': String -} -``` - -### Client -The client will handle the evaluation process mentioned above. It should be out-of-box as well. The client's protocol is as following table - -|Action|Content| -|:---:|:---:| -|Request new word / utterence| ```{key: "Get", value: None}```| -|Predict word "W"| ```{key: "SEND", value: "W"}```| - - - -The core of the client module is the agent, which needs to be modified to different models accordingly. The abstract class of agent is as follow, the evaluation process happens in the `decode()` function. -```python -class Agent(object): - "an agent needs to follow this pattern" - def __init__(self, *args, **kwargs): - ... - - def init_states(self): - # Initializing states - ... - - def update_states(self, states, new_state): - # Update states with given new state from server - # TODO (describe the states) - ... - - def finish_eval(self, states, new_state): - # Check if evaluation is finished - ... - - def policy(self, state: list) -> dict: - # Provide a action given current states - # The action can only be either - # {key: "GET", value: NONE} - # or - # {key: "SEND", value: W} - ... - - def reset(self): - # Reset agent - ... 
- - def decode(self, session): - - states = self.init_states() - self.reset() - - # Evaluataion protocol happens here - while True: - # Get action from the current states according to self.policy() - action = self.policy(states) - - if action['key'] == GET: - # Read a new state from server - new_state = session.get_src() - states = self.update_states(states, new_state) - - if self.finish_eval(states, new_state): - # End of document - break - - elif action['key'] == SEND: - # Send a new prediction to server - session.send_hypo(action['value']) - - # Clean the history, wait for next sentence - if action['value'] == DEFAULT_EOS: - states = self.init_states() - self.reset() - else: - raise NotImplementedError - - -``` -Here an implementation of agent of text [*wait-k* model](somelink). Notice that the tokenization is not considered. - -## Quality -The quality is measured by detokenized BLEU. So make sure that the predicted words sent to server are detokenized. An implementation is can be find [here](some link) - -## Latency -The latency metrics are -* Average Proportion -* Average Lagging -* Differentiable Average Lagging -Again Thery will also be evaluated on detokenized text. - diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index 5dea0d8475..0144fcb766 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -1,13 +1,46 @@ # Simultaneous Speech Translation (SimulST) on MuST-C +These are instructions for training and evaluating a transformer *wait-k* simultaneous model on the MuST-C English-German dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf). + [MuST-C](https://www.aclweb.org/anthology/N19-1202) is multilingual speech-to-text translation corpus with 8-language translations on English TED talks. 
-## Data Preparation & ASR -Please follow the steps in offline [speech-to-text](../mustc_example.md) translation for data preparation and ASR pretraining. +## Data Preparation +[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path +`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with: +```bash +# Additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +# Generate TSV manifests, features, vocabulary, +# global cepstral mean and variance statistics, +# and configuration for each language +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task asr \ + --vocab-type unigram --vocab-size 10000 \ + --cmvn-type global +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task st \ + --vocab-type unigram --vocab-size 10000 \ + --cmvn-type global +``` + +## ASR Pretraining +We only need a pretrained offline ASR model: +``` +fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 +``` -## Training +## Simultaneous Speech Translation Training -#### Wait-K(K=3) with fixed pre-decision module +### Wait-K with fixed pre-decision module +Fixed pre-decision means that the model applies its simultaneous policy only at the boundaries of fixed-size chunks. 
+Here is an example with a fixed pre-decision ratio of 7 (the simultaneous decision is made every 7 encoder states) and +a wait-3 policy: ``` fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ @@ -21,8 +54,9 @@ Please follow the steps in offline [speech-to-text](../mustc_example.md) transla --simul-type waitk_fixed_pre_decision \ --waitk-lagging 3 \ --fixed-pre-decision-ratio 7 + ``` -#### Monotonic multihead attention with fixed pre-decision module +### Monotonic multihead attention with fixed pre-decision module ``` fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ @@ -39,9 +73,13 @@ Please follow the steps in offline [speech-to-text](../mustc_example.md) transla ``` ## Inference & Evaluation [SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. +The source file is a list of paths to audio files, +while the target file contains the corresponding translations. ``` +pip install simuleval + simuleval \ - --agent ${FAIRSEQ}/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py + --agent examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py --src-file ${SRC_LIST_OF_AUDIO} --tgt-file ${TGT_FILE} --data-bin ${MUSTC_ROOT}/en-de \ @@ -50,3 +88,13 @@ simuleval \ --tgt-splitter-path ${MUSTC_ROOT}/en-de/spm.model \ --scores ``` + +A pretrained checkpoint, a wait-5 model with a pre-decision of 280 ms, can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7). The databin (containing the dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin). + +Quality is measured by detokenized BLEU, so make sure that the predicted words sent to the server are detokenized. 
+ +The latency metrics are: +* Average Proportion +* Average Lagging +* Differentiable Average Lagging +These are also computed on detokenized text. From 7d2394b56f1cbdcdede9c7a8cf6de1df022e0a17 Mon Sep 17 00:00:00 2001 From: Eric Lou <ericlou@fb.com> Date: Wed, 3 Mar 2021 10:48:42 -0800 Subject: [PATCH 502/707] ioPath async - Fairseq unittests (#1669) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1669 Unit tests for async writes integration done in D26467815 (https://github.com/pytorch/fairseq/commit/3100d0b8e5bb5e61b4d73b9c058389aa2c06784a). Ongoing performance tests: https://fb.quip.com/kjM7Atb1kKbO Reviewed By: myleott Differential Revision: D26732660 fbshipit-source-id: faf8cac67b9167af4195358c1a2592804c13562c --- fairseq/file_io.py | 2 +- tests/test_checkpoint_utils.py | 15 +++++++++++++++ tests/test_file_io.py | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 731fef3570..9a78ab505d 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -170,7 +170,7 @@ def opena( if not IOPathPathManager: logging.info("ioPath is initializing PathManager.") try: - from iopath import PathManager + from iopath.common.file_io import PathManager IOPathPathManager = PathManager() except Exception: logging.exception("Failed to initialize ioPath PathManager object.") diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 617a5f7c84..3278de6b9f 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -9,8 +9,10 @@ import tempfile import unittest from io import StringIO +from unittest.mock import patch from fairseq import checkpoint_utils +from omegaconf import OmegaConf from tests.utils import ( create_dummy_data, @@ -87,6 +89,19 @@ def test_prune_state_dict(self): self.assertEqual(len(ensemble[0].encoder.layers), 2) self.assertEqual(len(ensemble[0].decoder.layers), 1) + def 
test_torch_persistent_save_async(self): + cfg = OmegaConf.create() + cfg.dataset = OmegaConf.create() + cfg.dataset.write_checkpoints_asynchronously = True + state_dict = {} + filename = "async_checkpoint.pt" + + with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: + with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: + checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename) + mock_opena.assert_called_with(filename, "wb") + mock_save.assert_called() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_file_io.py b/tests/test_file_io.py index aef5b80d18..8ebbba4a2e 100644 --- a/tests/test_file_io.py +++ b/tests/test_file_io.py @@ -45,3 +45,18 @@ def test_file_io_oss(self): with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: s = f.read() self.assertEqual(s, self._tmpfile_contents) + + def test_file_io_async(self): + # ioPath `PathManager` is initialized after the first `opena` call. + try: + from fairseq.file_io import IOPathPathManager, PathManager + + self.assertIsNone(IOPathPathManager) + _asyncfile = os.path.join(self._tmpdir, "async.txt") + f = PathManager.opena(_asyncfile, "wb") + f.close() + + from fairseq.file_io import IOPathPathManager + self.assertIsNotNone(IOPathPathManager) + finally: + self.assertTrue(PathManager.async_close()) From 1fed7a8426e8c548196add0d65d77857ab224705 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Wed, 3 Mar 2021 19:29:55 -0800 Subject: [PATCH 503/707] add unit test for multi_corpus_dataset Reviewed By: vimalmanohar Differential Revision: D26220694 fbshipit-source-id: ed13f8527a1b203e1a9d004fa8a86e1ad6423d60 --- tests/test_multi_corpus_dataset.py | 69 ++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests/test_multi_corpus_dataset.py diff --git a/tests/test_multi_corpus_dataset.py b/tests/test_multi_corpus_dataset.py new file mode 100644 index 0000000000..a1fafe489b --- /dev/null 
+++ b/tests/test_multi_corpus_dataset.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from collections import OrderedDict + +import torch +from fairseq.data import LanguagePairDataset, TokenBlockDataset +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +from tests.test_train import mock_dict + + +class TestMultiCorpusDataset(unittest.TestCase): + def setUp(self): + d = mock_dict() + tokens_1 = torch.LongTensor([i for i in range(1, 5000, 2)]).view(1, -1) + tokens_ds1 = TokenBlockDataset( + tokens_1, + sizes=[tokens_1.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_1 = LanguagePairDataset( + tokens_ds1, tokens_ds1.sizes, d, shuffle=False + ) + tokens_2 = torch.LongTensor([i for i in range(2, 5000, 2)]).view(1, -1) + tokens_ds2 = TokenBlockDataset( + tokens_2, + sizes=[tokens_2.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_2 = LanguagePairDataset( + tokens_ds2, tokens_ds2.sizes, d, shuffle=False + ) + + def _test_sample_helper( + self, + distribution, + ): + m = MultiCorpusDataset( + OrderedDict({0: self.dataset_1, 1: self.dataset_2}), + distribution=distribution, + seed=0, + sort_indices=True, + ) + m.set_epoch(1) + indices = m.ordered_indices() + count_sample_from_first_dataset = 0 + for i in indices: + if m[i]["source"].item() % 2 == 1: + count_sample_from_first_dataset += 1 + sample_from_first_ds_percentage = ( + 1.0 * count_sample_from_first_dataset / len(indices) + ) + self.assertLess( + abs(sample_from_first_ds_percentage - distribution[0]), + 0.01, + ) + + def test_multi_corpus_dataset(self): + for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]: + self._test_sample_helper(distribution=distribution) From fc2840de58b06f381626332153203fb32588c23d Mon Sep 17 00:00:00 2001 From: Alex Xiao 
<axiao@fb.com> Date: Wed, 3 Mar 2021 19:29:55 -0800 Subject: [PATCH 504/707] optimize sampling process of multi_corpus_dataset Summary: The sampling process in multi_corpus_dataset is very inefficient. It turns out we can significantly optimize it by sampling in batches rather than one by one. This allows: 1. fast local development and iteration with corpus sampling, as the turnaround time was long before 2. less time before our jobs start training, enabling earlier signal if, for example, there is a configuration issue Reviewed By: zhengwy888 Differential Revision: D26187821 fbshipit-source-id: b4f7f6b7c187b3785499308226e2af671a6c354f --- fairseq/data/multi_corpus_dataset.py | 85 +++++++++++++++++----------- tests/test_multi_corpus_dataset.py | 14 ++++- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 6563713489..00e464ed31 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging +import time from collections import OrderedDict from typing import Dict, List @@ -12,7 +13,6 @@ from . 
import FairseqDataset - logger = logging.getLogger(__name__) @@ -49,6 +49,7 @@ def __init__( super().__init__() assert isinstance(datasets, OrderedDict) assert len(datasets) == len(distribution) + assert sum(distribution) == 1 self.datasets = datasets self.distribution = distribution self.seed = seed @@ -69,43 +70,61 @@ def __init__( self.total_num_instances += len(dataset) def ordered_indices(self): + start = time.time() with data_utils.numpy_seed(self.seed, self.epoch): - # Used to store the order of indices of each dataset to use - indices = [ - np.random.permutation(len(dataset)) - for dataset in self.datasets.values() - ] - # Keep track of which samples we've used for each dataset - counters = [0 for _ in self.datasets] - - sampled_indices = [ - self._sample(indices, counters) for _ in range(self.total_num_instances) - ] + sampled_indices = [] + num_selected_instances = 0 + + # For each dataset i, sample self.distribution[i] * self.total_num_instances + for i, key in enumerate(self.datasets): + + if i < len(self.datasets) - 1: + num_instances = int(self.distribution[i] * self.total_num_instances) + high = self.dataset_offsets[i + 1] + else: + num_instances = self.total_num_instances - num_selected_instances + high = self.total_num_instances + + logger.info(f"sampling {num_instances} from {key} dataset") + num_selected_instances += num_instances + + # First, add k copies of the dataset where k = num_instances // len(dataset). + # This ensures an equal distribution of the data points as much as possible. 
+ # For the remaining entries randomly sample them + dataset_size = len(self.datasets[key]) + num_copies = num_instances // dataset_size + dataset_indices = ( + np.random.permutation(high - self.dataset_offsets[i]) + + self.dataset_offsets[i] + )[: num_instances - num_copies * dataset_size] + if num_copies > 0: + sampled_indices += list( + np.concatenate( + ( + np.repeat( + np.arange(self.dataset_offsets[i], high), num_copies + ), + dataset_indices, + ) + ) + ) + else: + sampled_indices += list(dataset_indices) + + assert ( + len(sampled_indices) == self.total_num_instances + ), f"{len(sampled_indices)} vs {self.total_num_instances}" + + np.random.shuffle(sampled_indices) if self.sort_indices: sampled_indices.sort(key=lambda i: self.num_tokens(i)) - return np.array(sampled_indices, dtype=np.int64) - - def _sample(self, indices, counters): - # First pick dataset - dataset_idx = np.random.choice(len(self.distribution), p=self.distribution) - - # Then get dataset internal index - idx = indices[dataset_idx][counters[dataset_idx]] - - # Convert to multi-datasets index - idx += self.dataset_offsets[dataset_idx] - - counters[dataset_idx] += 1 - - # Reset if we reach end - if counters[dataset_idx] == len(self.dataset_list[dataset_idx]): - counters[dataset_idx] = 0 - indices[dataset_idx] = np.random.permutation( - len(self.dataset_list[dataset_idx]) + logger.info( + "multi_corpus_dataset ordered_indices took {}s".format( + time.time() - start + ) ) - - return idx + return np.array(sampled_indices, dtype=np.int64) def _map_index(self, index: int): """ diff --git a/tests/test_multi_corpus_dataset.py b/tests/test_multi_corpus_dataset.py index a1fafe489b..5a79f4b680 100644 --- a/tests/test_multi_corpus_dataset.py +++ b/tests/test_multi_corpus_dataset.py @@ -27,7 +27,7 @@ def setUp(self): self.dataset_1 = LanguagePairDataset( tokens_ds1, tokens_ds1.sizes, d, shuffle=False ) - tokens_2 = torch.LongTensor([i for i in range(2, 5000, 2)]).view(1, -1) + tokens_2 = torch.LongTensor([i 
for i in range(0, 5000, 2)]).view(1, -1) tokens_ds2 = TokenBlockDataset( tokens_2, sizes=[tokens_2.size(-1)], @@ -53,9 +53,13 @@ def _test_sample_helper( m.set_epoch(1) indices = m.ordered_indices() count_sample_from_first_dataset = 0 + items = set() for i in indices: - if m[i]["source"].item() % 2 == 1: + item = m[i]["source"].item() + if item % 2 == 1: count_sample_from_first_dataset += 1 + + items.add(item) sample_from_first_ds_percentage = ( 1.0 * count_sample_from_first_dataset / len(indices) ) @@ -63,6 +67,12 @@ def _test_sample_helper( abs(sample_from_first_ds_percentage - distribution[0]), 0.01, ) + self.assertEqual( + len(items), + int(min(len(self.dataset_1), len(indices) * distribution[0]) + + min(len(self.dataset_1), len(indices) * distribution[1])) + ) + print(distribution) def test_multi_corpus_dataset(self): for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]: From f6d60e2fee9fe8982e3c9de1e6bb77680978e749 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Wed, 3 Mar 2021 21:15:01 -0800 Subject: [PATCH 505/707] minor fixes and improvements (#1671) Summary: there are a few changes here: - convert config persisted in checkpoints into a plain dict when saving and back to omegaconf config when loading: this helps avoid compatibility issues between different versions of python, omegaconf, etc - update checkpoints that have old print_alignment saved - add lr_float to composite optimizer to enable sweeping on lr with auto sweepers like ax - fixing some edge cases for config loading Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1671 Reviewed By: myleott Differential Revision: D26791583 Pulled By: alexeib fbshipit-source-id: 124dec74932052925c43b6a93130f4428803cb46 --- fairseq/checkpoint_utils.py | 54 +++++++++++++++++++++++++++++++------ fairseq/dataclass/utils.py | 16 ++++++----- fairseq/optim/composite.py | 13 ++++++--- 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/fairseq/checkpoint_utils.py 
b/fairseq/checkpoint_utils.py index d6618fbb62..97f22041bc 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -21,7 +21,7 @@ ) from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder -from omegaconf import DictConfig, open_dict +from omegaconf import Container, DictConfig, open_dict, OmegaConf logger = logging.getLogger(__name__) @@ -275,8 +275,22 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): for arg_name, arg_val in arg_overrides.items(): setattr(args, arg_name, arg_val) - if "cfg" in state and state["cfg"] is not None and arg_overrides is not None: - overwrite_args_by_name(state["cfg"], arg_overrides) + if "cfg" in state and state["cfg"] is not None: + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + state["cfg"] = OmegaConf.create(state["cfg"]) + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(state["cfg"], True) + + if arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) state = _upgrade_state_dict(state) return state @@ -440,7 +454,7 @@ def save_state( if extra_state is None: extra_state = {} state_dict = { - "cfg": cfg, + "cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg, "args": kwargs.get("args", None), "model": model_state_dict or {}, "optimizer_history": optim_history @@ -453,7 +467,7 @@ def save_state( } ], "extra_state": extra_state, - "task_state": task.state_dict() if task is not None else {} + "task_state": task.state_dict() if task is not None else {}, } if utils.has_parameters(criterion): state_dict["criterion"] = criterion.state_dict() @@ -568,15 +582,39 @@ def _upgrade_state_dict(state): if hasattr(state["args"], 
"lr") and isinstance(state["args"].lr, float): state["args"].lr = [state["args"].lr] # convert task data arg to a string instead of List[string] - if hasattr(state["args"], "data") and isinstance(state["args"].data, list) and len(state["args"].data) > 0: + if ( + hasattr(state["args"], "data") + and isinstance(state["args"].data, list) + and len(state["args"].data) > 0 + ): state["args"].data = state["args"].data[0] state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: - with open_dict(state["cfg"]): + cfg = state["cfg"] + with open_dict(cfg): # any upgrades for Hydra-based configs - pass + if ( + "task" in cfg + and "eval_wer_config" in cfg.task + and isinstance(cfg.task.eval_wer_config.print_alignment, bool) + ): + cfg.task.eval_wer_config.print_alignment = "hard" + if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): + cfg.generation.print_alignment = "hard" + if ( + "model" in cfg + and "w2v_args" in cfg.model + and cfg.model.w2v_args is not None + and ( + hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args + ) + and isinstance( + cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool + ) + ): + cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" return state diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index a4d4a412dd..27c9006fdb 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -43,7 +43,9 @@ def interpret_dc_type(field_type): return str typestring = str(field_type) - if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring) or typestring.startswith("typing.Optional"): + if re.match( + r"(typing.|^)Union\[(.*), NoneType\]$", typestring + ) or typestring.startswith("typing.Optional"): return field_type.__args__[0] return field_type @@ -235,15 +237,17 @@ def get_default(f): and not (isinstance(val, str) and val.startswith("${")) ): # if type is int but val is float, then we will crash later - 
try to convert here - if hasattr(v.type, '__args__'): + if hasattr(v.type, "__args__"): t_args = v.type.__args__ - if len(t_args) == 1: + if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int): val = list(map(t_args[0], val)) - elif val is not None and (field_type is int or field_type is bool or field_type is float): + elif val is not None and ( + field_type is int or field_type is bool or field_type is float + ): try: val = field_type(val) except: - pass # ignore errors here, they are often from interpolation args + pass # ignore errors here, they are often from interpolation args if val is None: overrides.append("{}.{}=null".format(sub_node, k)) @@ -430,7 +434,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): if k in cfg and isinstance(cfg[k], DictConfig): if k in overrides and isinstance(overrides[k], dict): for ok, ov in overrides[k].items(): - if isinstance(ov, dict): + if isinstance(ov, dict) and cfg[k][ok] is not None: overwrite_args_by_name(cfg[k][ok], ov) else: cfg[k][ok] = ov diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py index 1a581bc010..a5366d6243 100644 --- a/fairseq/optim/composite.py +++ b/fairseq/optim/composite.py @@ -22,12 +22,13 @@ class OptimizerAndSchedulerConfig(FairseqDataclass): optimizer: Any = None lr_scheduler: Optional[Any] = None - lr: List[float] = II("optimization.lr") + lr: List = II("optimization.lr") + lr_float: Optional[float] = None # this makes it easier to sweep on learning rate with auto sweepers @dataclass class CompositeOptimizerConfig(FairseqDataclass): - groups: Dict[str, OptimizerAndSchedulerConfig] = field( + groups: Dict[str, Any] = field( default_factory=lambda: {}, metadata={ "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. 
" @@ -64,8 +65,12 @@ def __init__(self, cfg: CompositeOptimizerConfig, params): for group, group_params in groupped_params.items(): group_cfg = cfg.groups[group] with open_dict(group_cfg): - group_cfg.optimizer.lr = group_cfg.lr - group_cfg.lr_scheduler.lr = group_cfg.lr + if group_cfg.lr_float is not None: + group_cfg.optimizer.lr = [group_cfg.lr_float] + group_cfg.lr_scheduler.lr = [group_cfg.lr_float] + else: + group_cfg.optimizer.lr = group_cfg.lr + group_cfg.lr_scheduler.lr = group_cfg.lr self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params) if group_cfg.lr_scheduler is not None: self.lr_schedulers[group] = build_lr_scheduler( From f1c595beb8acd2a6dc8c9fa9f7fb60ca23c61899 Mon Sep 17 00:00:00 2001 From: Kaushik Rangadurai <krangadu@fb.com> Date: Thu, 4 Mar 2021 11:48:27 -0800 Subject: [PATCH 506/707] Ability to pass attn_mask to TransformerSentenceEncoder Summary: Provide the ability to pass attn_mask to TransformerSentenceEncoder. The default is None, so this is backwards compatible. The attention mask can be either a 2D tensor of shape (tgt_seq_len, src_seq_len) or a 3D tensor of shape (bsz * num_heads, tgt_seq_len, src_seq_len). In the case of self-attention, tgt_seq_len = src_seq_len. 
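The two accepted mask layouts can be illustrated with a minimal, framework-free sketch. It assumes the additive convention (0.0 where attention is allowed, -inf where it is blocked), which is not stated in this summary and should be checked against the attention layer in use. The helper names `causal_mask` and `expand_per_head` are hypothetical, and nested lists stand in for tensors:

```python
import math


def causal_mask(tgt_len, src_len):
    """Build a 2D additive mask of shape (tgt_len, src_len).

    Assumption: 0.0 means "may attend", -inf means "blocked".
    """
    return [
        [0.0 if src <= tgt else -math.inf for src in range(src_len)]
        for tgt in range(tgt_len)
    ]


def expand_per_head(mask_2d, bsz, num_heads):
    """Copy a 2D mask into the 3D layout (bsz * num_heads, tgt_len, src_len)."""
    return [[row[:] for row in mask_2d] for _ in range(bsz * num_heads)]


# Self-attention case: tgt_seq_len == src_seq_len.
mask = causal_mask(3, 3)
mask_3d = expand_per_head(mask, bsz=2, num_heads=4)
```

With real tensors the same mask would be passed as the new `attn_mask` argument of `forward`; the 3D form is only needed when the mask differs per sample or per head.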
Reviewed By: myleott Differential Revision: D26790767 fbshipit-source-id: 937d6c6cf08790c7d43d33fda97a30425f31ea06 --- fairseq/modules/transformer_sentence_encoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 6e9c32f467..a7fb198779 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -226,6 +226,7 @@ def forward( last_state_only: bool = False, positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: is_tpu = tokens.device.type == "xla" @@ -268,7 +269,7 @@ def forward( inner_states.append(x) for layer in self.layers: - x, _ = layer(x, self_attn_padding_mask=padding_mask) + x, _ = layer(x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask) if not last_state_only: inner_states.append(x) From 6d23cc7e7c32d1a6aa1d2d4a4c94abe50c980126 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 4 Mar 2021 13:31:02 -0800 Subject: [PATCH 507/707] Move checkpoint state_dict creation into Trainer (#1666) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1666 Context: the checkpoint saving call stack has become a bit convoluted: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.save_state + checkpoint_utils.torch_persistent_save ``` This diff slightly simplifies the checkpoint saving logic by exposing a `state_dict` method inside the Trainer. 
This simplifies the call stack to: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.torch_persistent_save ``` This new structure is important for the FullyShardedDataParallel diff (next diff in the stack), since it enables the Trainer to save multiple checkpoints for the different optimizer state shards. Test Plan: - unit tests - trained WMT En-De models; confirmed checkpoints save/load properly, resuming from a checkpoint gives identical results - `buck test fblearner/flow/projects/langtech/translation:tests` (2 failures are in trunk too): https://www.internalfb.com/intern/testinfra/testconsole/testrun/2533274840914654/ Reviewed By: zhengwy888 Differential Revision: D26771146 Pulled By: myleott fbshipit-source-id: 10f91979cd42205c1d8abcaa9ab56f63eba31e93 --- fairseq/checkpoint_utils.py | 71 ++++------------------------------ fairseq/dataclass/configs.py | 1 - fairseq/trainer.py | 67 +++++++++++++++++++++++++------- tests/test_checkpoint_utils.py | 7 ++-- tests/test_train.py | 1 + 5 files changed, 64 insertions(+), 83 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 97f22041bc..5a98dad2aa 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -31,7 +31,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): from fairseq import meters # only one worker should attempt to create the required dir - if cfg.distributed_rank == 0: + if trainer.data_parallel_rank == 0: os.makedirs(cfg.save_dir, exist_ok=True) prev_best = getattr(save_checkpoint, "best", val_loss) @@ -44,7 +44,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): trainer.consolidate_optimizer() - if not trainer.is_data_parallel_master: + if not trainer.should_save_checkpoint_on_current_rank: return write_timer = meters.StopwatchMeter() @@ -59,7 +59,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): def is_better(a, b): return 
a >= b if cfg.maximize_best_checkpoint_metric else a <= b - suffix = cfg.checkpoint_suffix or "" + suffix = trainer.checkpoint_suffix checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 @@ -165,7 +165,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - suffix = cfg.checkpoint_suffix + suffix = trainer.checkpoint_suffix if ( cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' @@ -190,7 +190,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): raise ValueError( f"--funetune-from-model {cfg.finetune_from_model} does not exist" ) - elif cfg.model_parallel_size > 1: + elif suffix is not None: checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") else: checkpoint_path = cfg.restore_file @@ -405,8 +405,8 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def torch_persistent_save(cfg: CheckpointConfig, obj, filename): - if cfg.write_checkpoints_asynchronously: +def torch_persistent_save(obj, filename, async_write: bool = False): + if async_write: with PathManager.opena(filename, "wb") as f: _torch_persistent_save(obj, f) else: @@ -434,61 +434,6 @@ def _torch_persistent_save(obj, f): logger.error(traceback.format_exc()) -def save_state( - filename, - cfg: FairseqConfig, - model_state_dict, - criterion, - optimizer, - lr_scheduler, - num_updates, - optim_history=None, - extra_state=None, - task=None, - **kwargs, -): - from fairseq import utils - - if optim_history is None: - optim_history = [] - if extra_state is None: - extra_state = {} - state_dict = { - "cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg, - "args": kwargs.get("args", None), - "model": 
model_state_dict or {}, - "optimizer_history": optim_history - + [ - { - "criterion_name": criterion.__class__.__name__, - "optimizer_name": optimizer.__class__.__name__, - "lr_scheduler_state": lr_scheduler.state_dict(), - "num_updates": num_updates, - } - ], - "extra_state": extra_state, - "task_state": task.state_dict() if task is not None else {}, - } - if utils.has_parameters(criterion): - state_dict["criterion"] = criterion.state_dict() - - if cfg is None: - cfg = state_dict["args"] - assert cfg is not None, "must provide cfg or args" - - if isinstance(cfg, DictConfig): - no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state - else: - no_save_optimizer_state = cfg.no_save_optimizer_state - if not no_save_optimizer_state: - state_dict["last_optimizer_state"] = optimizer.state_dict() - - # keep everything on CPU - state_dict = utils.move_to_cpu(state_dict) - - torch_persistent_save(cfg.checkpoint, state_dict, filename) - - def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" from fairseq import models, registry, tasks @@ -529,7 +474,7 @@ def _upgrade_state_dict(state): if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 # old model checkpoints may not have separate source/target positions - if hasattr(state["args"], "max_positions") and not hasattr( + if "args" in state and hasattr(state["args"], "max_positions") and not hasattr( state["args"], "max_source_positions" ): state["args"].max_source_positions = state["args"].max_positions diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 39355b1caf..4d3c60bfd6 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -618,7 +618,6 @@ class CheckpointConfig(FairseqDataclass): }, ) model_parallel_size: int = II("common.model_parallel_size") - distributed_rank: int = II("distributed_training.distributed_rank") @dataclass diff --git a/fairseq/trainer.py b/fairseq/trainer.py 
index 680a7ee953..45d9591d7c 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -25,6 +25,8 @@ from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from omegaconf import OmegaConf + logger = logging.getLogger(__name__) @@ -171,6 +173,16 @@ def use_distributed_wrapper(self) -> bool: and not self.cfg.optimization.use_bmuf ) + @property + def should_save_checkpoint_on_current_rank(self) -> bool: + """Indicates whether to save checkpoints on the current DDP rank.""" + return self.is_data_parallel_master + + @property + def checkpoint_suffix(self) -> str: + """Suffix to add to the checkpoint file name.""" + return self.cfg.checkpoint.checkpoint_suffix or "" + @property def criterion(self): if self._wrapped_criterion is None: @@ -274,25 +286,50 @@ def consolidate_optimizer(self): if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() + def state_dict(self): + state_dict = { + "args": None, # legacy + "cfg": ( + OmegaConf.to_container(self.cfg) + if OmegaConf.is_config(self.cfg) else self.cfg + ), + "model": self.model.state_dict(), + "criterion": ( + self.criterion.state_dict() + if utils.has_parameters(self.criterion) else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "criterion_name": self.get_criterion().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), + } + } + if not self.cfg.checkpoint.no_save_optimizer_state: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() + return state_dict + def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" - if 
self.is_data_parallel_master: # only save one checkpoint - logger.info(f"Saving checkpoint to {filename}") - extra_state["metrics"] = metrics.state_dict() - extra_state["previous_training_time"] = self.cumulative_training_time() - checkpoint_utils.save_state( + logger.info(f"Saving checkpoint to {filename}") + # call state_dict on all ranks in case it needs internal communication + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + if self.should_save_checkpoint_on_current_rank: + checkpoint_utils.torch_persistent_save( + state_dict, filename, - self.cfg, - self.model.state_dict(), - self.get_criterion(), - self.optimizer, - self.lr_scheduler, - self.get_num_updates(), - optim_history=self._optim_history, - extra_state=extra_state, - task=self.task, + async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, ) - logger.info(f"Finished saving checkpoint to {filename}") + logger.info(f"Finished saving checkpoint to {filename}") def load_checkpoint( self, diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 3278de6b9f..0f28222633 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -90,15 +90,14 @@ def test_prune_state_dict(self): self.assertEqual(len(ensemble[0].decoder.layers), 1) def test_torch_persistent_save_async(self): - cfg = OmegaConf.create() - cfg.dataset = OmegaConf.create() - cfg.dataset.write_checkpoints_asynchronously = True state_dict = {} filename = "async_checkpoint.pt" with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: - checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename) + checkpoint_utils.torch_persistent_save( + state_dict, filename, async_write=True + ) mock_opena.assert_called_with(filename, "wb") mock_save.assert_called() diff --git a/tests/test_train.py b/tests/test_train.py index 
57daa194b2..65f4683bc6 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -68,6 +68,7 @@ def get_mock_cfg(finetune_from_model): "reset_lr_scheduler": False, "finetune_from_model": finetune_from_model, "model_parallel_size": 1, + "restore_file": "checkpoint_last.pt", }, "common": { "model_parallel_size": 1, From 656d7e5779a9ec4ccf0ad45d86a4ce589c597588 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Thu, 4 Mar 2021 13:31:02 -0800 Subject: [PATCH 508/707] Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) (#1667) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1667 Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) This enables full parameter + optimizer state sharding by using FullyShardedDataParallel (FSDP) from fairscale. The user just needs to provide `--ddp-backend=fully_sharded` to enable it. Other common options work out-of-the-box (e.g., `--fp16`, `--memory-efficient-fp16`, `--update-freq`, etc.). This should be a drop-in replacement for the "c10d" backend. This yields pretty big speedups for small models and enables training ~13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs, without model parallelism. This also adds a new option `--cpu-offload` that offloads the optimizer state and FP32 model copy to CPU, which is particularly useful when combined with `--optimizer=cpu_adam`. Note: after enabling this, each GPU will save a checkpoint file, since the optimizer state is sharded. Each checkpoint will contain a single shard of the optimizer state and the rank 0 checkpoint will contain the full model weights. Note: a known limitation of the current implementation is that you cannot resume training on a different world_size. This constraint will be relaxed in future iterations. 
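For reference, the option constraints that this patch adds to the Trainer constructor can be sketched on their own. The function name `check_fsdp_options` and its plain keyword arguments are hypothetical; the actual checks read the same values from `cfg.common` and `cfg.distributed_training`:

```python
def check_fsdp_options(ddp_backend, bf16=False, zero_sharding="none", cpu_offload=False):
    """Hypothetical stand-alone version of the Trainer's option checks."""
    if ddp_backend == "fully_sharded":
        # FSDP manages its own precision and sharding, so these options clash.
        if bf16:
            raise ValueError(
                "FullyShardedDataParallel is not compatible with --bf16 or "
                "--memory-efficient-bf16"
            )
        if zero_sharding != "none":
            raise ValueError(
                "FullyShardedDataParallel is not compatible with --zero-sharding "
                "option (it's already built in)"
            )
    elif cpu_offload:
        # CPU offload is implemented by FSDP, so it needs that backend.
        raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")


# Accepted: FSDP together with CPU offload.
check_fsdp_options("fully_sharded", cpu_offload=True)

# Rejected: CPU offload with the c10d backend.
try:
    check_fsdp_options("c10d", cpu_offload=True)
except ValueError as err:
    rejected = str(err)
```

Options such as `--fp16` and `--update-freq` are deliberately absent from the sketch, since the summary above says they work unchanged.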
Test Plan: Imported from OSS Reviewed By: sshleifer Differential Revision: D26771144 Pulled By: myleott fbshipit-source-id: 74c2f46f57719e24e2dcfc9d9ee7c2fc0aeedb46 --- fairseq/dataclass/configs.py | 15 +++ fairseq/dataclass/constants.py | 1 + fairseq/distributed/__init__.py | 4 + .../fully_sharded_data_parallel.py | 122 ++++++++++++++++++ fairseq/models/distributed_fairseq_model.py | 21 ++- fairseq/models/fairseq_model.py | 20 ++- fairseq/models/transformer.py | 6 + fairseq/optim/cpu_adam.py | 4 + fairseq/optim/fp16_optimizer.py | 14 +- fairseq/trainer.py | 84 ++++++++++-- fairseq_cli/train.py | 15 ++- tests/test_binaries.py | 10 +- tests/test_dataset.py | 7 + 13 files changed, 292 insertions(+), 31 deletions(-) create mode 100644 fairseq/distributed/fully_sharded_data_parallel.py diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 4d3c60bfd6..5d6aee157a 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -355,7 +355,22 @@ class DistributedTrainingConfig(FairseqDataclass): zero_sharding: ZERO_SHARDING_CHOICES = field( default="none", metadata={"help": "ZeRO sharding"} ) + fp16: bool = II("common.fp16") + memory_efficient_fp16: bool = II("common.memory_efficient_fp16") tpu: bool = II("common.tpu") + # configuration for --ddp-backend=fully_sharded + no_reshard_after_forward: bool = field( + default=False, + metadata={"help": "don't reshard parameters after forward pass"}, + ) + fp32_reduce_scatter: bool = field( + default=False, + metadata={"help": "reduce-scatter grads in FP32"}, + ) + cpu_offload: bool = field( + default=False, + metadata={"help": "offload FP32 params to CPU"} + ) @dataclass diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 93bc6d03cb..faba0862fa 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -37,6 +37,7 @@ def ChoiceEnum(choices: List[str]): LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) 
DDP_BACKEND_CHOICES = ChoiceEnum([ "c10d", # alias for pytorch_ddp + "fully_sharded", # FullyShardedDataParallel from fairscale "legacy_ddp", "no_c10d", # alias for legacy_ddp "pytorch_ddp", diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py index 7f4016e38c..d0b96b734c 100644 --- a/fairseq/distributed/__init__.py +++ b/fairseq/distributed/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .distributed_timeout_wrapper import DistributedTimeoutWrapper +from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel from .legacy_distributed_data_parallel import LegacyDistributedDataParallel from .module_proxy_wrapper import ModuleProxyWrapper from .tpu_distributed_data_parallel import TPUDistributedDataParallel @@ -11,6 +12,9 @@ __all__ = [ "DistributedTimeoutWrapper", + "fsdp_enable_wrap", + "fsdp_wrap", + "FullyShardedDataParallel", "LegacyDistributedDataParallel", "ModuleProxyWrapper", "TPUDistributedDataParallel", diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py new file mode 100644 index 0000000000..9d74398325 --- /dev/null +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Optional + +import torch + +from fairseq.dataclass.configs import DistributedTrainingConfig +from fairseq.distributed import utils as dist_utils + + +try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + has_FSDP = True +except ImportError: + FSDP = torch.nn.Module + has_FSDP = False + + +class FullyShardedDataParallel(FSDP): + """ + A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some + fairseq-specific checkpoint saving/loading logic. 
+ + Args: + use_sharded_state (bool): if True, then ``state_dict`` will return + ``FSDP.local_state_dict`` and ``load_state_dict`` will call + ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will + return the full model weights on data parallel rank 0 (empty on + other ranks) and ``load_state_dict`` will broadcast model weights + from rank 0 to other ranks. + """ + + def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + super().__init__(*args, **kwargs) + self.use_sharded_state = use_sharded_state + + def state_dict(self, destination=None, prefix='', keep_vars=False): + if self.use_sharded_state: + return super().local_state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + if self.rank == 0: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + # We must call state_dict() due to use of communication + # primitives. But we don't use the result. + super().state_dict() + return destination or {} + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + if self.use_sharded_state: + return super().load_local_state_dict(state_dict, strict=strict) + else: + state_dict = dist_utils.broadcast_object( + state_dict, src_rank=0, group=self.process_group + ) + return super().load_state_dict(state_dict, strict=strict) + + +@contextlib.contextmanager +def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False): + try: + from fairscale.nn import enable_wrap + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. 
" + "Please install fairscale with: pip install fairscale" + ) + if cfg.memory_efficient_fp16: + assert cfg.fp16 # memory_efficient_fp16 should imply fp16 + group = dist_utils.get_data_parallel_group() + if group is None and cfg.distributed_world_size == 1: + from fairscale.utils.testing import DummyProcessGroup + group = DummyProcessGroup(rank=0, size=1) + fsdp_config = { + "process_group": group, + "reshard_after_forward": not cfg.no_reshard_after_forward, + "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16, + "fp32_reduce_scatter": cfg.fp32_reduce_scatter, + "flatten_parameters": True, + "cpu_offload": cfg.cpu_offload, + "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, + "bucket_cap_mb": cfg.bucket_cap_mb, + } + with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config): + yield + + +def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): + """ + Helper to wrap layers/modules in FSDP. This falls back to a no-op if + fairscale is not available. 
+ + Args: + module (nn.Module): module to (maybe) wrap + min_num_params (int, Optional): minimum number of layer params to wrap + """ + try: + from fairscale.nn import wrap + cls = FullyShardedDataParallel + if min_num_params is not None: + num_params = sum(p.numel() for p in module.parameters()) + if num_params >= min_num_params: + return wrap(module, cls=cls, **kwargs) + else: + return module + else: + return wrap(module, cls=cls, **kwargs) + except ImportError: + return module diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index ca157f06e9..3422faea74 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -105,12 +105,27 @@ def DistributedFairseqModel(args, model, process_group, device): ) # forward missing getattr and state_dict/load_state_dict to orig model wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "fully_sharded": + try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. 
" + "Please install fairscale with: pip install fairscale" + ) + assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP" + wrapped_model = model + if args.memory_efficient_fp16: + wrapped_model = wrapped_model.half() + if not args.cpu_offload: + wrapped_model = wrapped_model.to(device=device) else: raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) # kill hung distributed jobs after a timeout - wrapped_model = DistributedTimeoutWrapper( - wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) - ) + if getattr(args, "heartbeat_timeout", -1) > 0: + wrapped_model = DistributedTimeoutWrapper( + wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) + ) return wrapped_model diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 186f3d2464..d393c02ae6 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -27,6 +27,13 @@ logger = logging.getLogger(__name__) +def check_type(module, expected_type): + if hasattr(module, "unwrapped_module"): + assert isinstance(module.unwrapped_module, expected_type) + else: + assert isinstance(module, expected_type) + + class BaseFairseqModel(nn.Module): """Base class for fairseq models.""" @@ -284,8 +291,9 @@ def __init__(self, encoder, decoder): self.encoder = encoder self.decoder = decoder - assert isinstance(self.encoder, FairseqEncoder) - assert isinstance(self.decoder, FairseqDecoder) + + check_type(self.encoder, FairseqEncoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ @@ -365,8 +373,8 @@ def __init__(self, encoders, decoders): assert encoders.keys() == decoders.keys() self.keys = list(encoders.keys()) for key in self.keys: - assert isinstance(encoders[key], FairseqEncoder) - assert isinstance(decoders[key], FairseqDecoder) + check_type(encoders[key], FairseqEncoder) + check_type(decoders[key], FairseqDecoder) self.models = nn.ModuleDict( { @@ 
-469,7 +477,7 @@ class FairseqLanguageModel(BaseFairseqModel): def __init__(self, decoder): super().__init__() self.decoder = decoder - assert isinstance(self.decoder, FairseqDecoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, **kwargs): """ @@ -530,7 +538,7 @@ class FairseqEncoderModel(BaseFairseqModel): def __init__(self, encoder): super().__init__() self.encoder = encoder - assert isinstance(self.encoder, FairseqEncoder) + check_type(self.encoder, FairseqEncoder) def forward(self, src_tokens, src_lengths, **kwargs): """ diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f2f36baf3e..a0a0b8dcd5 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from fairseq import utils +from fairseq.distributed import fsdp_wrap from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -240,6 +241,9 @@ def build_model(cls, args, task): args.checkpoint_activations = True # offloading implies checkpointing encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + if not args.share_all_embeddings: + encoder = fsdp_wrap(encoder, min_num_params=1e8) + decoder = fsdp_wrap(decoder, min_num_params=1e8) return cls(args, encoder, decoder) @classmethod @@ -386,6 +390,7 @@ def build_encoder_layer(self, args): if getattr(args, "checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + layer = fsdp_wrap(layer, min_num_params=1e8) return layer def forward_embedding( @@ -726,6 +731,7 @@ def build_decoder_layer(self, args, no_encoder_attn=False): if getattr(args, "checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + layer = fsdp_wrap(layer, min_num_params=1e8) 
return layer def forward( diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index fad5a64ecb..5e935df1a5 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -107,6 +107,10 @@ def __init__( self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode ) + @property + def supports_flat_params(self): + return True + @torch.no_grad() def step(self, closure=None): loss = None diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index e0b069f172..00ea1bbb76 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -322,6 +322,10 @@ def set_lr(self, lr): def all_reduce_grads(self, module): self.fp32_optimizer.all_reduce_grads(module) + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params + class _MemoryEfficientFP16OptimizerMixin(object): def __init__(self, *args, **kwargs): @@ -442,6 +446,10 @@ def zero_grad(self): else: self._multiply_factor = 1.0 + @property + def supports_flat_params(self): + return self.wrapped_optimizer.supports_flat_params + class MemoryEfficientFP16Optimizer( _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer @@ -461,8 +469,10 @@ class MemoryEfficientFP16Optimizer( *supports_memory_efficient_fp16* property. 
""" - def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): - if not optimizer.supports_memory_efficient_fp16: + def __init__( + self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs + ): + if not allow_unsupported and not optimizer.supports_memory_efficient_fp16: raise ValueError( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 45d9591d7c..4d47d39897 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -63,15 +63,31 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): else: self.device = torch.device("cpu") + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.cfg.common.bf16: + raise ValueError( + "FullyShardedDataParallel is not compatible with --bf16 or " + "--memory-efficient-bf16" + ) + if self.cfg.distributed_training.zero_sharding != "none": + raise ValueError( + "FullyShardedDataParallel is not compatible with --zero-sharding " + "option (it's already built in)" + ) + else: + if self.cfg.distributed_training.cpu_offload: + raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") + # copy model and criterion to current device/dtype self._criterion = criterion self._model = model - if cfg.common.fp16: - self._criterion = self._criterion.half() - self._model = self._model.half() - elif cfg.common.bf16: - self._criterion = self._criterion.to(dtype=torch.bfloat16) - self._model = self._model.to(dtype=torch.bfloat16) + if cfg.distributed_training.ddp_backend != "fully_sharded": + if cfg.common.fp16: + self._criterion = self._criterion.half() + self._model = self._model.half() + elif cfg.common.bf16: + self._criterion = self._criterion.to(dtype=torch.bfloat16) + self._model = self._model.to(dtype=torch.bfloat16) if ( not cfg.distributed_training.pipeline_model_parallel # the DistributedFairseqModel wrapper will handle moving to device, @@ -171,17 +187,26 @@ def 
use_distributed_wrapper(self) -> bool: return ( self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf + ) or ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.distributed_training.cpu_offload ) @property def should_save_checkpoint_on_current_rank(self) -> bool: """Indicates whether to save checkpoints on the current DDP rank.""" - return self.is_data_parallel_master + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return True + else: + return self.is_data_parallel_master @property def checkpoint_suffix(self) -> str: """Suffix to add to the checkpoint file name.""" - return self.cfg.checkpoint.checkpoint_suffix or "" + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank) + else: + return self.cfg.checkpoint.checkpoint_suffix or "" @property def criterion(self): @@ -234,7 +259,20 @@ def _build_optimizer(self): ) ) - if self.cfg.common.fp16 or self.cfg.common.bf16: + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.common.fp16 + ): + # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, + # mostly for the grad scaling. But if we don't have the + # --memory-efficient-fp16 flag set, then we're effectively doing + # regular --fp16 and can allow the use of optimizers that would + # otherwise be unsupported by MemoryEfficientFP16Optimizer. 
+ allow_unsupported = not self.cfg.common.memory_efficient_fp16 + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params, allow_unsupported=allow_unsupported + ) + elif self.cfg.common.fp16 or self.cfg.common.bf16: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: logger.info( "NOTE: your device does NOT support faster training with --fp16, " @@ -254,6 +292,16 @@ def _build_optimizer(self): logger.info("NOTE: your device may support faster training with --fp16") self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + assert not self.cfg.optimization.use_bmuf, \ + "--ddp-backend=fully_sharded is not compatible with BMUF" + assert self._optimizer.supports_flat_params, ( + "--ddp-backend=fully_sharded is only compatible with pointwise " + "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). " + "However, the sharding will result in slightly different results when " + "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)" + ) + if self.cfg.optimization.use_bmuf: self._optimizer = optim.FairseqBMUF( self.cfg.bmuf, @@ -355,6 +403,8 @@ def load_checkpoint( # TPUs don't support broadcast yet, so load checkpoints # on every worker for now or self.tpu + # FSDP requires loading checkpoint shards on all ranks + or self.cfg.distributed_training.ddp_backend == "fully_sharded" ) if load_on_all_ranks or self.data_parallel_rank == 0: @@ -965,7 +1015,21 @@ def set_num_updates(self, num_updates): metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) def clip_grad_norm(self, clip_norm): - return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None) + + def agg_norm_fn(total_norm): + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + total_norm = total_norm ** 2 + if ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ): + total_norm = 
distributed_utils.all_reduce( + total_norm.cuda(), group=self.data_parallel_process_group + ) + total_norm = total_norm ** 0.5 + return total_norm + + return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn) def cumulative_training_time(self): if self._cumulative_training_time is None: diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 80ad57acd1..d770e4e4ec 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -18,7 +18,6 @@ import torch from fairseq import ( checkpoint_utils, - distributed_utils, options, quantization_utils, tasks, @@ -27,7 +26,7 @@ from fairseq.data import iterators from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.distributed_utils import is_master +from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer @@ -50,7 +49,7 @@ def main(cfg: FairseqConfig) -> None: utils.import_user_module(cfg.common) - if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: + if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) @@ -87,7 +86,11 @@ def main(cfg: FairseqConfig) -> None: assert cfg.criterion, "Please specify criterion to train a model" # Build model and criterion - model = task.build_model(cfg.model) + if cfg.distributed_training.ddp_backend == "fully_sharded": + with fsdp_enable_wrap(cfg.distributed_training): + model = fsdp_wrap(task.build_model(cfg.model)) + else: + model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) logger.info("task: 
{}".format(task.__class__.__name__)) @@ -95,8 +98,8 @@ def main(cfg: FairseqConfig) -> None: logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. model params: {:,} (num. trained: {:,})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad), ) ) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 3cb98897bf..e10cc767b8 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1697,8 +1697,9 @@ def test_activation_offloading_does_not_change_metrics(self): """Neither ----checkpoint-activations nor --offload-activations should change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) offload_logs = self._train(data_dir, ["--offload-activations"]) baseline_logs = self._train(data_dir, []) @@ -1720,8 +1721,9 @@ def test_activation_checkpointing_does_not_change_metrics(self): """--checkpoint-activations should not change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) ckpt_logs = self._train(data_dir, ["--checkpoint-activations"]) baseline_logs = self._train(data_dir, []) assert len(baseline_logs) == len(ckpt_logs) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9fb69a5f77..a3e3970028 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT 
license found in the # LICENSE file in the root directory of this source tree. +import logging import unittest from typing import Sequence @@ -20,6 +21,12 @@ def sample(id: int, length: int): class TestDataset(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + def test_round_robin_zip_datasets(self): long_dataset = lang_pair_dataset([10, 9, 8, 11]) short_dataset = lang_pair_dataset([11, 9]) From 73886ac228f8f0368871237f7498ec8b07444322 Mon Sep 17 00:00:00 2001 From: Ning Dong <dnn@fb.com> Date: Thu, 4 Mar 2021 14:20:00 -0800 Subject: [PATCH 509/707] Refactor FairseqSimulSTAgent Summary: 1. In fblearner flow we are dumping cmvn stats into json file (e.g. f253830726) Previously there's only --config option taking .npz path from a yaml file, and it's the only usage for the config. This diff adds an option --global-stats to import from json. 2. Inherit FairseqSimulSTAgent from nn.Module instead of SpeechAgent whose root class is object to prepare for scripting methods. Copy over / simplify all the necessary methods from SpeechAgent/Agent. 
Reviewed By: jmp84 Differential Revision: D26800957 fbshipit-source-id: 74be527f8473c13405a60bb16ce6da5a7dc0b888 --- .../agents/fairseq_simul_st_agent.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index f944203785..2b5fdc2d3f 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -1,19 +1,20 @@ import math import os - +import json import numpy as np import torch import torchaudio.compliance.kaldi as kaldi import yaml from fairseq import checkpoint_utils, tasks +from fairseq.file_io import PathManager try: from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS - from simuleval.agents import SpeechAgent - from simuleval.states import ListEntry + from simuleval.states import ListEntry, SpeechStates except ImportError: print("Please install simuleval 'pip install simuleval'") +from torch import nn SHIFT_SIZE = 10 WINDOW_SIZE = 25 @@ -112,12 +113,12 @@ def info(self): } -class FairseqSimulSTAgent(SpeechAgent): +class FairseqSimulSTAgent(nn.Module): speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size def __init__(self, args): - super().__init__(args) + super().__init__() self.eos = DEFAULT_EOS @@ -136,13 +137,18 @@ def __init__(self, args): self.model.decoder.layers[0].encoder_attn.pre_decision_ratio ) - with open(args.config, "r") as f: - config = yaml.load(f, Loader=yaml.BaseLoader) + args.global_cmvn = None + if args.config: + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.BaseLoader) - if "global_cmvn" in config: - args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) - else: - args.global_cmvn = None + if "global_cmvn" in config: + args.global_cmvn = 
np.load(config["global_cmvn"]["stats_npz_path"]) + + if args.global_stats: + with PathManager.open(args.global_stats, "r") as f: + global_cmvn = json.loads(f.read()) + self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]} self.feature_extractor = OnlineFeatureExtractor(args) @@ -152,6 +158,13 @@ def __init__(self, args): torch.set_grad_enabled(False) + def build_states(self, args, client, sentence_id): + # Initialize states here, for example add customized entry to states + # This function will be called at beginning of every new sentence + states = SpeechStates(args, client, sentence_id, self) + self.initialize_states(states) + return states + def to_device(self, tensor): if self.gpu: return tensor.cuda() @@ -165,8 +178,10 @@ def add_args(parser): help='path to your pretrained model.') parser.add_argument("--data-bin", type=str, required=True, help="Path of data binary") - parser.add_argument("--config", type=str, required=True, + parser.add_argument("--config", type=str, default=None, help="Path to config yaml file") + parser.add_argument("--global-stats", type=str, default=None, + help="Path to json file containing cmvn stats") parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", help="Subword splitter type for target text") parser.add_argument("--tgt-splitter-path", type=str, default=None, From 7c95746a7e5e4a087399d186590815e45ae775c8 Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Thu, 4 Mar 2021 17:17:11 -0800 Subject: [PATCH 510/707] fix bug on converting stereo audio in audio_utils.py Summary: Fix bug on converting stereo audio in audio_utils.py - Github issue: https://github.com/pytorch/fairseq/issues/3303 Reviewed By: jmp84 Differential Revision: D26825964 fbshipit-source-id: 26905e71540bc52e98d76996b199ac0fbe78357b --- fairseq/data/audio/audio_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/data/audio/audio_utils.py 
b/fairseq/data/audio/audio_utils.py index f0e75b1d65..f8cc80f5e2 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -63,8 +63,8 @@ def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarr # Mono channel: D -> 1 x D waveform = waveform.unsqueeze(0) else: - # Merge multiple channels to one: C x D -> 1 x D - waveform, _ = ta_sox.apply_effects_tensor(waveform, sample_rate, ['channels', '1']) + # Merge multiple channels to one: D x C -> 1 x D + waveform, _ = ta_sox.apply_effects_tensor(waveform.T, sample_rate, [['channels', '1']]) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate From 16c1a200f87a2adb6395e353345c19bbe990d1dd Mon Sep 17 00:00:00 2001 From: sarapapi <57095209+sarapapi@users.noreply.github.com> Date: Mon, 8 Mar 2021 14:10:29 -0800 Subject: [PATCH 511/707] Fix Global CMVN path of MustC data preprocessing (#3307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fix a typo in gcmv_path given for config yaml generation (actual: gcvmn_cvmn_path, correct: gcmvn_path) ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3307 Reviewed By: jmp84 Differential Revision: D26826231 Pulled By: kahne fbshipit-source-id: 6b60f2a8a8b4ba1c0c088299a08ef04fdfe870a8 --- examples/speech_to_text/prep_mustc_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 4e410bcb18..45fd43533d 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -179,7 +179,7 @@ def process(args): yaml_filename=f"config_{args.task}.yaml", specaugment_policy="lb", cmvn_type=args.cmvn_type, - gcmvn_cmvn_path=( + gcmvn_path=( cur_root / "gcmvn.npz" if args.cmvn_type == "global" else None ), From 00d5b7adbeaf64e02c53a591d637efe4c8cad923 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Tue, 9 Mar 2021 06:28:23 -0800 Subject: [PATCH 512/707] Add README/tutorial for Fully Sharded Data Parallel (#3327) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3327 Reviewed By: sshleifer Differential Revision: D26899416 Pulled By: myleott fbshipit-source-id: bbb493a5c4e0a51f3b26fe8f94e3962b6206d6f6 --- .github/workflows/build.yml | 3 +- README.md | 11 +- .../fully_sharded_data_parallel/README.md | 164 ++++++++++++++++++ .../fully_sharded_data_parallel.py | 18 +- fairseq/models/fairseq_model.py | 5 +- fairseq/models/roberta/model.py | 26 ++- fairseq/models/transformer.py | 40 ++++- fairseq/models/transformer_lm.py | 121 +++++++++++-- fairseq/trainer.py | 16 +- 9 files changed, 363 insertions(+), 41 deletions(-) create mode 100644 examples/fully_sharded_data_parallel/README.md diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0af8bad95d..105c42a503 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,7 +39,8 @@ jobs: - name: Install optional test requirements run: | - python -m pip install 
fairscale iopath transformers pyarrow + python -m pip install iopath transformers pyarrow + python -m pip install git+https://github.com/facebookresearch/fairscale.git@master - name: Lint with flake8 run: | diff --git a/README.md b/README.md index 5fedac7eec..839dd8e1de 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,9 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* March 2021: [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) +* February 2021: [Added LASER training code](examples/laser/README.md) +* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) * December 2020: [GottBERT model and code released](examples/gottbert/README.md) * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) @@ -68,14 +71,14 @@ We provide reference implementations of various sequence modeling papers: * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) * October 2020: [Added CRISS models and code](examples/criss/README.md) + +<details><summary>Previous updates</summary><p> + * September 2020: [Added Linformer code](examples/linformer/README.md) * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - -<details><summary>Previous updates</summary><p> - * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) * April 2020: [Monotonic
Multihead Attention code released](examples/simultaneous_translation/README.md) * April 2020: [Quant-Noise code released](examples/quant_noise/README.md) @@ -108,6 +111,8 @@ We provide reference implementations of various sequence modeling papers: * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration +* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md) +* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md) We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) with a convenient `torch.hub` interface: diff --git a/examples/fully_sharded_data_parallel/README.md b/examples/fully_sharded_data_parallel/README.md new file mode 100644 index 0000000000..bc98670968 --- /dev/null +++ b/examples/fully_sharded_data_parallel/README.md @@ -0,0 +1,164 @@ +# Fully Sharded Data Parallel (FSDP) + +## Overview +Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and +[Google](https://arxiv.org/abs/2004.13336) has shown that data parallel +training can be made significantly more efficient by sharding the model +parameters and optimizer state across data parallel workers. These ideas are +encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper provided +by [fairscale](https://github.com/facebookresearch/fairscale/). 
+ +Compared to PyTorch DDP: +* FSDP produces the same results as PyTorch DDP (it's still synchronous data parallel training) +* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs +* FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass +* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs + +FSDP is fully supported in fairseq via the following new arguments: +* `--ddp-backend=fully_sharded`: enables full sharding via FSDP +* `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`) +* `--no-reshard-after-forward`: increases training speed for some models and is similar to ZeRO stage 2 +* other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal + +<details><summary>Limitations</summary><p> + +FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP): +* while FSDP is fully compatible with pointwise optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB, etc.) +* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported + +See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed +explanation of these and other limitations. + +</p></details> + +<details><summary>How it works</summary><p> + +<img width="800" alt="Fully Sharded Data Parallel" src="https://user-images.githubusercontent.com/231798/110406775-c2de0000-8050-11eb-9718-fbfc4510a76a.png"> + +See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed +explanation of how FSDP works. 
+ +</p></details> + +## Example usage + +The following examples illustrate how to train a very large language model with +13 billion parameters on 1 GPU by offloading parameters and optimizer states to +CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs. + +These examples use the WikiText-103 dataset for demonstration purposes, but +in practice a much larger dataset will be needed to achieve good results. +Follow the [instructions here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md#1-preprocess-the-data) +to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary. + +### 13B params on 1 V100 GPU (with CPU offloading) + +The following command trains a 13B parameter GPT-3 model on a single V100 GPU +using the `--cpu-offload` feature to offload parameters and optimizer states to +CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the +`--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)), +which further saves memory in exchange for a small increase in computation. + +Requirements: +- You'll need 32GB of GPU memory and 256GB of system memory. +- We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command. + +Some notes: +- The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow. +- The `--cpu-offload` feature requires training in mixed precision (`--fp16`). +- Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading. +- The example command below stops training after 10 steps (`--max-update 10`) and does not save checkpoints (`--no-save`). 
+ +```bash +OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \ + fairseq-train data-bin/wikitext-103-roberta-bpe-bin \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 2048 --batch-size 8 \ + --arch transformer_lm_gpt3_13 \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 10 --no-save --log-format json --log-interval 1 + +# Example output: +# (...) +# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +# (...) +# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) +# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +# (...) +# Adam Optimizer #0 is created with AVX2 arithmetic capability. +# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +# (...) 
+# 2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} +# 2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} +# 2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +# 2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +# 2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} +# 2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} +# 2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} +# 2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", 
"train_wall": "53", "gb_free": "9.3", "wall": "456"} +# 2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} +# 2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} +# 2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} +# 2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} +# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset +# 2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} +# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +# 2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", 
"train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} +# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds +``` + +### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding) + +FSDP can also shard the parameters and optimizer states across multiple GPUs, +reducing memory requirements significantly. On 8 GPUs, sharding enables +training the same 13B parameter model *without offloading the parameters to +CPU*. However, without CPU offloading we'd only be able to fit a batch size of +1 per GPU, which would cause training speed to suffer. + +We obtain the best performance on 8 GPUs by combining full sharding and CPU +offloading. The following command trains the same 13B parameter GPT-3 model as +before on 8 GPUs; training speed increases from ~310 -> ~3200 words per second. + +```bash +OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + fairseq-train data-bin/wikitext-103-roberta-bpe-bin \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 2048 --batch-size 8 \ + --arch transformer_lm_gpt3_13 \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 10 --no-save --log-format json --log-interval 1 + +# Example output: +# (...) +# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +# (...) +# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) +# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +# (...) +# Adam Optimizer #0 is created with AVX2 arithmetic capability. +# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +# (...) 
+# 2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} +# 2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} +# 2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +# 2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +# 2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} +# 2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} +# 2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} +# 2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": 
"1", "train_wall": "40", "gb_free": "9.3", "wall": "330"} +# 2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} +# 2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} +# 2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} +# 2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} +# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset +# 2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} +# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +# 2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", 
"train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} +# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds +``` diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py index 9d74398325..9c290b3fda 100644 --- a/fairseq/distributed/fully_sharded_data_parallel.py +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -43,6 +43,13 @@ def __init__(self, *args, use_sharded_state: bool = False, **kwargs): super().__init__(*args, **kwargs) self.use_sharded_state = use_sharded_state + @property + def unwrapped_module(self) -> torch.nn.Module: + if self.flatten_parameters: + return self.module.module + else: + return self.module + def state_dict(self, destination=None, prefix='', keep_vars=False): if self.use_sharded_state: return super().local_state_dict( @@ -94,7 +101,11 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = F "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, "bucket_cap_mb": cfg.bucket_cap_mb, } - with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config): + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + use_sharded_state=use_sharded_state, + **fsdp_config, + ): yield @@ -109,14 +120,13 @@ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): """ try: from fairscale.nn import wrap - cls = FullyShardedDataParallel if min_num_params is not None: num_params = sum(p.numel() for p in module.parameters()) if num_params >= min_num_params: - return wrap(module, cls=cls, **kwargs) + return wrap(module, **kwargs) else: return module else: - return wrap(module, cls=cls, **kwargs) + return wrap(module, **kwargs) except ImportError: return module diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index d393c02ae6..171a8a40f1 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -29,9 +29,10 @@ 
def check_type(module, expected_type): if hasattr(module, "unwrapped_module"): - assert isinstance(module.unwrapped_module, expected_type) + assert isinstance(module.unwrapped_module, expected_type), \ + f"{type(module.unwrapped_module)} != {expected_type}" else: - assert isinstance(module, expected_type) + assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" class BaseFairseqModel(nn.Module): diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index a2a40ba6e2..c79d4faf79 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -18,7 +18,7 @@ register_model, register_model_architecture, ) -from fairseq.models.transformer import TransformerEncoder +from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, TransformerEncoder from fairseq.modules import LayerNorm from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -122,6 +122,11 @@ def add_args(parser): action="store_true", help="(re-)register and load heads when loading checkpoints", ) + parser.add_argument( + "--untie-weights-roberta", + action="store_true", + help="Untie weights between embeddings and classifiers in RoBERTa", + ) # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) parser.add_argument( "--encoder-layerdrop", @@ -157,17 +162,26 @@ def add_args(parser): default=0, help="scalar quantization noise and scalar quantization at training time", ) - parser.add_argument( - "--untie-weights-roberta", - action="store_true", - help="Untie weights between embeddings and classifiers in RoBERTa", - ) + # args for "Better Fine-Tuning by Reducing Representational Collapse" (Aghajanyan et al. 
2020) parser.add_argument( "--spectral-norm-classification-head", action="store_true", default=False, help="Apply spectral normalization on the classification head", ) + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + "--min-params-to-wrap", + type=int, + metavar="D", + default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes." + ) + ) @classmethod def build_model(cls, args, task): diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index a0a0b8dcd5..d39e9ec7ed 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -36,6 +36,9 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + + @register_model("transformer") class TransformerModel(FairseqEncoderDecoderModel): """ @@ -191,6 +194,16 @@ def add_args(parser): help='block size of quantization noise at training time') parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, help='scalar quantization noise and scalar quantization at training time') + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + 'minimum number of params for a layer to be wrapped with FSDP() when ' + 'training with --ddp-backend=fully_sharded. Smaller values will ' + 'improve memory efficiency, but may make torch.distributed ' + 'communication less efficient due to smaller input sizes.' 
+ ) + ) # fmt: on @classmethod @@ -242,8 +255,11 @@ def build_model(cls, args, task): encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) if not args.share_all_embeddings: - encoder = fsdp_wrap(encoder, min_num_params=1e8) - decoder = fsdp_wrap(decoder, min_num_params=1e8) + min_params_to_wrap = getattr( + args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP + ) + encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) + decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) return cls(args, encoder, decoder) @classmethod @@ -387,10 +403,16 @@ def __init__(self, args, dictionary, embed_tokens): def build_encoder_layer(self, args): layer = TransformerEncoderLayer(args) - if getattr(args, "checkpoint_activations", False): + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - layer = fsdp_wrap(layer, min_num_params=1e8) + # checkpointing requires alignment to FSDP wrap boundaries + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward_embedding( @@ -728,10 +750,16 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): def build_decoder_layer(self, args, no_encoder_attn=False): layer = TransformerDecoderLayer(args, no_encoder_attn) - if getattr(args, "checkpoint_activations", False): + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - layer = fsdp_wrap(layer, min_num_params=1e8) + # checkpointing requires alignment to FSDP wrap boundaries + min_params_to_wrap = ( + getattr(args, 
"min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward( diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index f12470d033..09c99b96f6 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -14,7 +14,9 @@ register_model, register_model_architecture, ) -from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.models.transformer import ( + DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder +) from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder from omegaconf import II @@ -126,15 +128,6 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "use learned positional embeddings in the decoder"}, ) - decoder_layerdrop: float = field( - default=0.0, metadata={"help": "LayerDrop probability for decoder"} - ) - decoder_layers_to_keep: Optional[str] = field( - default=None, - metadata={ - "help": "which layers to *keep* when pruning as a comma-separated list" - }, - ) layernorm_embedding: bool = field( default=False, metadata={"help": "add layernorm to embedding"} ) @@ -148,6 +141,17 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "move checkpointed activations to CPU after they are used."}, ) + # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + # config for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) quant_noise_pq: float = field( default=0.0, metadata={"help": "iterative PQ quantization noise at training time"}, @@ 
-156,13 +160,25 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=8, metadata={"help": "block size of quantization noise at training time"}, ) - # TODO common var add to parent quant_noise_scalar: float = field( default=0.0, metadata={ "help": "scalar quantization noise and scalar quantization at training time" }, ) + # config for Fully Sharded Data Parallel (FSDP) training + min_params_to_wrap: int = field( + default=DEFAULT_MIN_PARAMS_TO_WRAP, + metadata={ + "help": ( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes." + ) + } + ) + # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") @@ -289,7 +305,7 @@ def base_lm_architecture(args): args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) args.activation_fn = getattr(args, "activation_fn", "relu") args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) @@ -428,3 +444,84 @@ def transformer_lm_gpt2_big(args): args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) + + +def base_gpt3_architecture(args): + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.dropout = getattr(args, "dropout", 0.0) + 
args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_small") +def transformer_lm_gpt3_small(args): + # 125M params + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_medium") +def transformer_lm_gpt3_medium(args): + # 350M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_large") +def transformer_lm_gpt3_large(args): + # 760M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_xl") +def transformer_lm_gpt3_xl(args): + # 1.3B params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 24) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_2_7") +def transformer_lm_gpt3_2_7(args): + # 2.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + 
base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_6_7") +def transformer_lm_gpt3_6_7(args): + # 6.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_13") +def transformer_lm_gpt3_13(args): + # 13B params + args.decoder_layers = getattr(args, "decoder_layers", 40) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 40) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_175") +def transformer_lm_gpt3_175(args): + # 175B params + args.decoder_layers = getattr(args, "decoder_layers", 96) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 12288) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 96) + base_gpt3_architecture(args) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 4d47d39897..9435558157 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1017,15 +1017,17 @@ def set_num_updates(self, num_updates): def clip_grad_norm(self, clip_norm): def agg_norm_fn(total_norm): - if self.cfg.distributed_training.ddp_backend == "fully_sharded": - total_norm = total_norm ** 2 - if ( + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and ( self.data_parallel_process_group is not None or torch.distributed.is_initialized() - ): - total_norm = distributed_utils.all_reduce( - total_norm.cuda(), group=self.data_parallel_process_group - ) + ) + ): + total_norm = total_norm.cuda().float() ** 2 + total_norm = distributed_utils.all_reduce( + total_norm, group=self.data_parallel_process_group + ) total_norm = total_norm ** 0.5 return total_norm From 
c6006678261bf5d52e2c744508b5ddd306cafebd Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Tue, 9 Mar 2021 09:38:01 -0800 Subject: [PATCH 513/707] Update README for Fully Sharded Data Parallel (#3331) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3331 Reviewed By: sshleifer Differential Revision: D26912554 Pulled By: myleott fbshipit-source-id: b45a161fbd52a12da13d7e011d562d35a5b5a1a7 --- .../fully_sharded_data_parallel/README.md | 137 ++++++++++-------- fairseq/models/roberta/model.py | 4 +- fairseq/models/transformer.py | 11 +- fairseq/models/transformer_lm.py | 4 +- fairseq/trainer.py | 29 ++-- 5 files changed, 104 insertions(+), 81 deletions(-) diff --git a/examples/fully_sharded_data_parallel/README.md b/examples/fully_sharded_data_parallel/README.md index bc98670968..d620f0e4f1 100644 --- a/examples/fully_sharded_data_parallel/README.md +++ b/examples/fully_sharded_data_parallel/README.md @@ -17,7 +17,7 @@ Compared to PyTorch DDP: FSDP is fully supported in fairseq via the following new arguments: * `--ddp-backend=fully_sharded`: enables full sharding via FSDP * `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`) -* `--no-reshard-after-forward`: increases training speed for some models and is similar to ZeRO stage 2 +* `--no-reshard-after-forward`: increases training speed for large models (1B+ params) and is similar to ZeRO stage 2 * other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal <details><summary>Limitations</summary><p> @@ -59,11 +59,13 @@ CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the `--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)), which further saves memory in exchange for a small increase in computation. 
-Requirements: -- You'll need 32GB of GPU memory and 256GB of system memory. +**Requirements:** +- Install the latest master version of fairscale: `pip install git+https://github.com/facebookresearch/fairscale.git@master` +- You'll need 32GB of GPU memory and ~256GB of system memory to train the 13B param model. +- If you have less system memory, the 6.7B param model can be trained with ~128GB of system memory, just set `--arch transformer_lm_gpt3_6_7` - We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command. -Some notes: +**Notes:** - The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow. - The `--cpu-offload` feature requires training in mixed precision (`--fp16`). - Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading. @@ -79,48 +81,54 @@ OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ --max-update 10 --no-save --log-format json --log-interval 1 +``` + +<details><summary>Example output</summary><p> -# Example output: -# (...) -# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) -# (...) -# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) -# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 -# (...) -# Adam Optimizer #0 is created with AVX2 arithmetic capability. -# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 -# (...) 
-# 2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} -# 2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} -# 2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 -# 2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 -# 2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} -# 2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} -# 2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} -# 2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", 
"train_wall": "53", "gb_free": "9.3", "wall": "456"} -# 2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} -# 2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} -# 2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} -# 2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} -# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 -# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset -# 2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} -# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) -# 2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", 
"train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} -# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds ``` +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) +2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} +2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} +2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} +2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", 
"loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} +2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} +2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"} +2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} +2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} +2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} +2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} +2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 12:41:04 | INFO | fairseq_cli.train | 
begin validation on "valid" subset +2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds +``` + +</p></details> ### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding) FSDP can also shard the parameters and optimizer states across multiple GPUs, -reducing memory requirements significantly. On 8 GPUs, sharding enables +reducing memory requirements significantly. On 8 x 32GB GPUs, sharding enables training the same 13B parameter model *without offloading the parameters to CPU*. However, without CPU offloading we'd only be able to fit a batch size of 1 per GPU, which would cause training speed to suffer. We obtain the best performance on 8 GPUs by combining full sharding and CPU offloading. The following command trains the same 13B parameter GPT-3 model as -before on 8 GPUs; training speed increases from ~310 -> ~3200 words per second. +before on 8 x 32GB V100 GPUs; training speed increases superlinearly from ~310 +words per second to ~3200 words per second. 
```bash OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ @@ -132,33 +140,38 @@ OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ --max-update 10 --no-save --log-format json --log-interval 1 +``` + +<details><summary>Example output</summary><p> -# Example output: -# (...) -# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) -# (...) -# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) -# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 -# (...) -# Adam Optimizer #0 is created with AVX2 arithmetic capability. -# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 -# (...) -# 2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} -# 2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} -# 2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 -# 2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 -# 2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", 
"loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} -# 2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} -# 2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} -# 2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"} -# 2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} -# 2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} -# 2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} -# 2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": 
"0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} -# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 -# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset -# 2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} -# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) -# 2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} -# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds ``` +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) 
+2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} +2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} +2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} +2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} +2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} +2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", 
"train_wall": "40", "gb_free": "9.3", "wall": "330"} +2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} +2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} +2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} +2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} +2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset +2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", 
"train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds +``` + +</p></details> diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index c79d4faf79..5d2ed4902d 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -179,7 +179,9 @@ def add_args(parser): "minimum number of params for a layer to be wrapped with FSDP() when " "training with --ddp-backend=fully_sharded. Smaller values will " "improve memory efficiency, but may make torch.distributed " - "communication less efficient due to smaller input sizes." + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." ) ) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index d39e9ec7ed..297807c31a 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -201,7 +201,9 @@ def add_args(parser): 'minimum number of params for a layer to be wrapped with FSDP() when ' 'training with --ddp-backend=fully_sharded. Smaller values will ' 'improve memory efficiency, but may make torch.distributed ' - 'communication less efficient due to smaller input sizes.' + 'communication less efficient due to smaller input sizes. This option ' + 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' + '--offload-activations are passed.' 
) ) # fmt: on @@ -258,6 +260,7 @@ def build_model(cls, args, task): min_params_to_wrap = getattr( args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP ) + # fsdp_wrap is a no-op when --ddp-backend != fully_sharded encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) return cls(args, encoder, decoder) @@ -407,7 +410,8 @@ def build_encoder_layer(self, args): if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - # checkpointing requires alignment to FSDP wrap boundaries + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) if not checkpoint else 0 @@ -754,7 +758,8 @@ def build_decoder_layer(self, args, no_encoder_attn=False): if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - # checkpointing requires alignment to FSDP wrap boundaries + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) if not checkpoint else 0 diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 09c99b96f6..fca9470e5e 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -174,7 +174,9 @@ class TransformerLanguageModelConfig(FairseqDataclass): "minimum number of params for a layer to be wrapped with FSDP() when " "training with --ddp-backend=fully_sharded. Smaller values will " "improve memory efficiency, but may make torch.distributed " - "communication less efficient due to smaller input sizes." + "communication less efficient due to smaller input sizes. 
This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." ) } ) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 9435558157..dcf5305455 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1017,21 +1017,22 @@ def set_num_updates(self, num_updates): def clip_grad_norm(self, clip_norm): def agg_norm_fn(total_norm): - if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and ( - self.data_parallel_process_group is not None - or torch.distributed.is_initialized() - ) - ): - total_norm = total_norm.cuda().float() ** 2 - total_norm = distributed_utils.all_reduce( - total_norm, group=self.data_parallel_process_group - ) - total_norm = total_norm ** 0.5 - return total_norm + total_norm = total_norm.cuda().float() ** 2 + total_norm = distributed_utils.all_reduce( + total_norm, group=self.data_parallel_process_group + ) + return total_norm ** 0.5 - return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn) + should_agg_norm = ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ) + ) + return self.optimizer.clip_grad_norm( + clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None + ) def cumulative_training_time(self): if self._cumulative_training_time is None: From 05255f96410e5b1eaf3bf59b767d5b4b7e2c3a35 Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Tue, 9 Mar 2021 16:26:05 -0800 Subject: [PATCH 514/707] update audio_utils and fix mTEDx example Summary: update audio_utils and fix mTEDx example - Updated `audio_utils` - Added support for OGG Vorbis (the only supported lossy compressed format) - Added a separate `convert_to_mono()` helper function - Updated `get_waveform()` - added new arguments `frames` and `start` for reading part of audios - added new argument `mono` for auto conversion to mono-channel audio - 
unified returned waveform shape to channels x length (same as torchaudio default) - Updated mTEDx and MUST-C data prep scripts - Replaced `torchaudio.info()` with `soundfile.info()` (the latter is faster and the former has incompatible interface between <0.8 and the latest 0.8) - Replaced `torchaudio.load()` with `get_waveform` for auto conversion to mono channel Reviewed By: jmp84 Differential Revision: D26901114 fbshipit-source-id: fa9560c9714d51a91157d5141564574d4eee454d --- examples/speech_to_text/data_utils.py | 12 ++- examples/speech_to_text/docs/mtedx_example.md | 2 +- examples/speech_to_text/docs/mustc_example.md | 2 +- examples/speech_to_text/prep_mtedx_data.py | 13 +-- examples/speech_to_text/prep_mustc_data.py | 13 +-- fairseq/data/audio/audio_utils.py | 83 ++++++++++++++----- fairseq/data/audio/speech_to_text_dataset.py | 5 +- 7 files changed, 89 insertions(+), 41 deletions(-) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index fa0d459611..3f96ffc427 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -14,7 +14,10 @@ import numpy as np import pandas as pd import sentencepiece as sp -from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank +from fairseq.data.audio.audio_utils import ( + _convert_to_mono, _get_kaldi_fbank, _get_torchaudio_fbank +) +import torch from tqdm import tqdm @@ -66,7 +69,7 @@ def gen_vocab( def extract_fbank_features( - waveform, + waveform: torch.FloatTensor, sample_rate: int, output_path: Optional[Path] = None, n_mel_bins: int = 80, @@ -75,8 +78,9 @@ def extract_fbank_features( if output_path is not None and output_path.is_file() and not overwrite: return - _waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers - _waveform = _waveform.squeeze().numpy() + _waveform = _convert_to_mono(waveform, sample_rate) + _waveform = _waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers + 
_waveform = _waveform.numpy() features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins) if features is None: diff --git a/examples/speech_to_text/docs/mtedx_example.md b/examples/speech_to_text/docs/mtedx_example.md index c0e17db9a2..25b4556aff 100644 --- a/examples/speech_to_text/docs/mtedx_example.md +++ b/examples/speech_to_text/docs/mtedx_example.md @@ -11,7 +11,7 @@ with translations to a subset of 5 target languages. `${MTEDX_ROOT}/${LANG_PAIR}`, then preprocess it with ```bash # additional Python packages for S2T data processing/model training -pip install pandas torchaudio sentencepiece +pip install pandas torchaudio soundfile sentencepiece # Generate TSV manifests, features, vocabulary # and configuration for each language diff --git a/examples/speech_to_text/docs/mustc_example.md b/examples/speech_to_text/docs/mustc_example.md index 7628dc77ef..79df0aafdc 100644 --- a/examples/speech_to_text/docs/mustc_example.md +++ b/examples/speech_to_text/docs/mustc_example.md @@ -11,7 +11,7 @@ `${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with ```bash # additional Python packages for S2T data processing/model training -pip install pandas torchaudio sentencepiece +pip install pandas torchaudio soundfile sentencepiece # Generate TSV manifests, features, vocabulary # and configuration for each language diff --git a/examples/speech_to_text/prep_mtedx_data.py b/examples/speech_to_text/prep_mtedx_data.py index 6c37398fcc..34b1c398c8 100644 --- a/examples/speech_to_text/prep_mtedx_data.py +++ b/examples/speech_to_text/prep_mtedx_data.py @@ -14,7 +14,7 @@ from typing import Tuple import pandas as pd -import torchaudio +import soundfile as sf from examples.speech_to_text.data_utils import ( create_zip, extract_fbank_features, @@ -25,10 +25,12 @@ load_df_from_tsv, save_df_to_tsv, ) -from torch import Tensor +import torch from torch.utils.data import Dataset from tqdm import tqdm +from fairseq.data.audio.audio_utils import get_waveform + log = 
logging.getLogger(__name__) @@ -73,7 +75,7 @@ def __init__(self, root: str, lang: str, split: str) -> None: for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): wav_filename = wav_filename.replace(".wav", ".flac") wav_path = wav_root / wav_filename - sample_rate = torchaudio.info(wav_path.as_posix())[0].rate + sample_rate = sf.info(wav_path.as_posix()).samplerate seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) for i, segment in enumerate(seg_group): offset = int(float(segment["offset"]) * sample_rate) @@ -93,9 +95,10 @@ def __init__(self, root: str, lang: str, split: str) -> None: ) ) - def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str, str]: + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str, str]: wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = self.data[n] - waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) + waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) + waveform = torch.from_numpy(waveform) return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id def __len__(self) -> int: diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 45fd43533d..0ee204e651 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -15,7 +15,7 @@ import numpy as np import pandas as pd -import torchaudio +import soundfile as sf from examples.speech_to_text.data_utils import ( create_zip, extract_fbank_features, @@ -27,10 +27,12 @@ save_df_to_tsv, cal_gcmvn_stats, ) -from torch import Tensor +import torch from torch.utils.data import Dataset from tqdm import tqdm +from fairseq.data.audio.audio_utils import get_waveform + log = logging.getLogger(__name__) @@ -71,7 +73,7 @@ def __init__(self, root: str, lang: str, split: str) -> None: self.data = [] for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): 
wav_path = wav_root / wav_filename - sample_rate = torchaudio.info(wav_path.as_posix())[0].rate + sample_rate = sf.info(wav_path.as_posix()).samplerate seg_group = sorted(_seg_group, key=lambda x: x["offset"]) for i, segment in enumerate(seg_group): offset = int(float(segment["offset"]) * sample_rate) @@ -90,9 +92,10 @@ def __init__(self, root: str, lang: str, split: str) -> None: ) ) - def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]: + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]: wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] - waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) + waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) + waveform = torch.from_numpy(waveform) return waveform, sr, src_utt, tgt_utt, spk_id, utt_id def __len__(self) -> int: diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index f8cc80f5e2..ddd5642c7e 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -1,35 +1,80 @@ -import os.path as op +from pathlib import Path from typing import BinaryIO, Optional, Tuple, Union import numpy as np +import torch + + +SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"} + + +def _convert_to_mono( + waveform: torch.FloatTensor, sample_rate: int +) -> torch.FloatTensor: + if waveform.shape[0] > 1: + try: + import torchaudio.sox_effects as ta_sox + except ImportError: + raise ImportError( + "Please install torchaudio to convert multi-channel audios" + ) + effects = [['channels', '1']] + return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0] + return waveform + + +def convert_to_mono(waveform: np.ndarray, sample_rate: int) -> np.ndarray: + if waveform.shape[0] > 1: + _waveform = torch.from_numpy(waveform) + return _convert_to_mono(_waveform, sample_rate).numpy() + return waveform def get_waveform( - path_or_fp: Union[str, BinaryIO], 
normalization=True + path_or_fp: Union[str, BinaryIO], normalization=True, mono=True, + frames=-1, start=0, always_2d=True ) -> Tuple[np.ndarray, int]: - """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. + """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. Args: path_or_fp (str or BinaryIO): the path or file-like object normalization (bool): Normalize values to [-1, 1] (Default: True) + mono (bool): convert multi-channel audio to mono-channel one + frames (int): the number of frames to read. (-1 for reading all) + start (int): Where to start reading. A negative value counts from the end. + always_2d (bool): always return 2D array even for mono-channel audios + Returns: + waveform (numpy.ndarray): 1D or 2D waveform (channels x length) + sample_rate (float): sample rate """ if isinstance(path_or_fp, str): - ext = op.splitext(op.basename(path_or_fp))[1] - if ext not in {".flac", ".wav"}: + ext = Path(path_or_fp).suffix + if ext not in SF_AUDIO_FILE_EXTENSIONS: raise ValueError(f"Unsupported audio format: {ext}") try: import soundfile as sf except ImportError: - raise ImportError("Please install soundfile to load WAV/FLAC file") + raise ImportError( + "Please install soundfile to load WAV/FLAC/OGG Vorbis audios" + ) - waveform, sample_rate = sf.read(path_or_fp, dtype="float32") + waveform, sample_rate = sf.read( + path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start + ) + waveform = waveform.T # T x C -> C x T + if mono and waveform.shape[0] > 1: + waveform = convert_to_mono(waveform, sample_rate) if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers + if not always_2d: + waveform = waveform.squeeze(axis=0) return waveform, sample_rate -def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: +def _get_kaldi_fbank( + waveform: np.ndarray, sample_rate: int, n_bins=80 +) -> Optional[np.ndarray]: """Get mel-filter bank features via PyKaldi.""" try: from 
kaldi.feat.mel import MelBanksOptions @@ -45,27 +90,19 @@ def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: opts.mel_opts = mel_opts opts.frame_opts = frame_opts fbank = Fbank(opts=opts) - features = fbank.compute(Vector(waveform), 1.0).numpy() + features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() return features except ImportError: return None -def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: +def _get_torchaudio_fbank( + waveform: np.ndarray, sample_rate, n_bins=80 +) -> Optional[np.ndarray]: """Get mel-filter bank features via TorchAudio.""" try: - import torch import torchaudio.compliance.kaldi as ta_kaldi - import torchaudio.sox_effects as ta_sox - waveform = torch.from_numpy(waveform) - if len(waveform.shape) == 1: - # Mono channel: D -> 1 x D - waveform = waveform.unsqueeze(0) - else: - # Merge multiple channels to one: D x C -> 1 x D - waveform, _ = ta_sox.apply_effects_tensor(waveform.T, sample_rate, [['channels', '1']]) - features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate ) @@ -79,11 +116,11 @@ def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: (faster CPP implementation) to TorchAudio (Python implementation). 
Note that Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the waveform should not be normalized.""" - sound, sample_rate = get_waveform(path_or_fp, normalization=False) + waveform, sample_rate = get_waveform(path_or_fp, normalization=False) - features = _get_kaldi_fbank(sound, sample_rate, n_bins) + features = _get_kaldi_fbank(waveform, sample_rate, n_bins) if features is None: - features = _get_torchaudio_fbank(sound, sample_rate, n_bins) + features = _get_torchaudio_fbank(waveform, sample_rate, n_bins) if features is None: raise ImportError( "Please install pyKaldi or torchaudio to enable " diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 39d22c7a5e..c6c64db084 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -153,7 +153,8 @@ def get_features_or_waveform_from_uncompressed_zip( if is_npy_data(data): features_or_waveform = np.load(f) elif is_flac_or_wav_data(data): - features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f) + features_or_waveform = \ + get_waveform(f, always_2d=False)[0] if need_waveform else get_fbank(f) else: raise ValueError(f'Unknown file format for "{path}"') return features_or_waveform @@ -178,7 +179,7 @@ def get_features_or_waveform(path: str, need_waveform=False): if len(extra) == 0: if need_waveform: - return get_waveform(_path) + return get_waveform(_path, always_2d=False) return get_features_from_npy_or_audio(_path) elif len(extra) == 2: extra = [int(i) for i in extra] From d031611ce49cb231653cf9246667ac237cbbdaff Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Wed, 10 Mar 2021 20:32:49 -0800 Subject: [PATCH 515/707] Update simul trans doc (#1683) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1683 Reviewed By: jmp84 Differential Revision: D26914869 Pulled By: xutaima fbshipit-source-id: 
a5d2efdcff1852e56304e77838840b3aad5124b0 --- .../docs/simulst_mustc_example.md | 39 +++++++++++++++---- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index 0144fcb766..3452806a1c 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -1,6 +1,6 @@ # Simultaneous Speech Translation (SimulST) on MuST-C -This is an instruction of training and evaluating a transformer *wait-k* simultaneous model on MUST-C English-Germen Dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf). +This is a tutorial of training and evaluating a transformer *wait-k* simultaneous model on MUST-C English-Germen Dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf). [MuST-C](https://www.aclweb.org/anthology/N19-1202) is multilingual speech-to-text translation corpus with 8-language translations on English TED talks. @@ -14,18 +14,21 @@ pip install pandas torchaudio sentencepiece # Generate TSV manifests, features, vocabulary, # global cepstral and mean estimation, # and configuration for each language +cd fairseq + python examples/speech_to_text/prep_mustc_data.py \ --data-root ${MUSTC_ROOT} --task asr \ --vocab-type unigram --vocab-size 10000 \ --cmvn-type global + python examples/speech_to_text/prep_mustc_data.py \ --data-root ${MUSTC_ROOT} --task st \ - --vocab-type unigram --vocab-size 10000 + --vocab-type unigram --vocab-size 10000 \ --cmvn-type global ``` ## ASR Pretraining -We just need a pretrained offline ASR model +We need a pretrained offline ASR model. 
Assuming the save directory of the ASR model is `${ASR_SAVE_DIR}` ``` fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ @@ -34,21 +37,22 @@ fairseq-train ${MUSTC_ROOT}/en-de \ --arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 ``` +A pretrained ASR checkpoint can be downloaded [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1_en_de_pretrained_asr) ## Simultaneous Speech Translation Training ### Wait-K with fixed pre-decision module Fixed pre-decision indicates that the model operate simultaneous policy on the boundaries of fixed chunks. Here is a example of fixed pre-decision ratio 7 (the simultaneous decision is made every 7 encoder states) and -a wait-3 policy model -``` +a wait-3 policy model. Assuming the save directory is `${ST_SAVE_DIR}` +```bash fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ --save-dir ${ST_SAVE_DIR} --num-workers 8 \ --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ --criterion label_smoothed_cross_entropy \ --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ - --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/checkpoint_best.pt \ --task speech_to_text \ --arch convtransformer_simul_trans_espnet \ --simul-type waitk_fixed_pre_decision \ @@ -76,7 +80,9 @@ a wait-3 policy model The source file is a list of paths of audio files, while target file is the corresponding translations. ``` -pip install simuleval +git clone https://github.com/facebookresearch/SimulEval.git +cd SimulEval +pip install -e . 
simuleval \ --agent examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -89,7 +95,24 @@ simuleval \ --scores ``` -A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The databin (containing dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin). +A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The databin (containing dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). + +The output should be similar as follow: +```bash +{ + "Quality": { + "BLEU": 12.79214535384013 + }, + "Latency": { + "AL": 1669.5778120018108, + "AL_CA": 2077.9027656104813, + "AP": 0.7652936521983029, + "AP_CA": 0.8891561507382866, + "DAL": 2028.1566141735727, + "DAL_CA": 2497.336430059716 + } +} +``` The quality is measured by detokenized BLEU. So make sure that the predicted words sent to the server are detokenized. From 2235f86b40da5915cd801c4f2f29de4c17c9804b Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Fri, 12 Mar 2021 12:29:40 -0800 Subject: [PATCH 516/707] PlasmaView: don't materialize array in memory (#1645) Summary: ### Changes: - `PlasmaArray` saves the underlying data to `self.array`, `PlasmaView` never does that, instead it fetches the data from `plasma_store` shared memory when it is needed. - `PlasmaArray` starts a new, ephemeral plasma_store and puts a new array in it when it is pickled. If `--use-plasma-view`, there is one server started before `spawn` and arrays are only put into it once, in `PlasmaArray.__init__` to accommodate this. 
- user can now pass `--plasma-path` to explicitly control where the server is started.
- We now make plasma keys based on `(split_path, (block_size, document_sep_len, str(break_mode), len(dataset)))`, so two jobs sharing a plasma server but with different datasets, or the same dataset but different command-line args, will not read each other's arrays.

### Results [pre March 1]
This saves some CPU memory (5-15%), according to both `psutil` and `psrecord`: here we run base_cmd (below) with num_workers=0,2,8 on 2 GPUs and collect the logs. `branch` refers to `--use-plasma-view`; `master` uses `PlasmaArray`.

```
+-------------------------+----------------+---------+-------+
| setting                 |   cpu_mem_used |     wps |   ppl |
+=========================+================+=========+=======+
| branch_nw0_gpu2_ddm.log |             12 | 55143.2 | 429.1 |
+-------------------------+----------------+---------+-------+
| branch_nw2_gpu2_ddm.log |          13.67 | 43377.6 | 429.1 |
+-------------------------+----------------+---------+-------+
| branch_nw8_gpu2_ddm.log |          18.36 | 53019.9 | 429.1 |
+-------------------------+----------------+---------+-------+
| master_nw0_gpu2_ddm.log |          12.26 |   56733 | 429.1 |
+-------------------------+----------------+---------+-------+
| master_nw2_gpu2_ddm.log |          14.58 | 53337.9 | 429.1 |
+-------------------------+----------------+---------+-------+
| master_nw8_gpu2_ddm.log |           21.1 | 53217.2 | 429.1 |
+-------------------------+----------------+---------+-------+
```

### Replication 1) get this branch ```bash git fetch && git checkout share-plasma-server ``` 2) Train tiny model and save logs ```bash base_cmd () { fairseq-train --fp16 /private/home/sshleifer/data-bin/stories_mmap \ --task language_modeling \ --arch transformer_lm_gpt2_tiny \ --sample-break-mode complete --tokens-per-sample 512 \ --optimizer adam --clip-norm 0.0 --lr 0.0005 \ --batch-size 1 \ --max-update 200 --max-epoch 1 \ --log-format simple --log-interval 100 \ --restore-file x.pt --no-save \ --skip-invalid-size-inputs-valid-test
--disable-validation $@ } USE_LOCK=1 CUDA_VISIBLE_DEVICES=0,1 base_cmd --num-workers 0 --use-plasma-view | tee branch_nw0_gpu2_ddm.log ``` ### TODO: - [x] test larger dataset - [x] make it optional, cleanup - [x] 1 GPU - [x] unit-tests - [x] ask hashing Q on stackoverflow https://stackoverflow.com/questions/66354598/deterministic-method-to-hash-np-array-int - [ ] measure whether `PlasmaArray` disable for small array's logic helps - [ x] test with fb_sweep - [ x] measure 4 GPU savings Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1645 Test Plan: Read github PR description: https://github.com/fairinternal/fairseq-py/pull/1645 Reviewed By: myleott Differential Revision: D26630365 Pulled By: sshleifer fbshipit-source-id: b0c4163fbc97a7aefb116de70265fba11f6d7b42 --- fairseq/data/plasma_utils.py | 134 +++++++++++++++++++++++++--- fairseq/data/token_block_dataset.py | 67 +++++++++----- fairseq/dataclass/configs.py | 83 +++++++---------- fairseq/tasks/language_modeling.py | 11 ++- fairseq/trainer.py | 5 +- fairseq_cli/train.py | 9 +- tests/test_plasma_utils.py | 127 ++++++++++++++++++++++++++ 7 files changed, 343 insertions(+), 93 deletions(-) create mode 100644 tests/test_plasma_utils.py diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py index f4bb6472d7..b9fab3b739 100644 --- a/fairseq/data/plasma_utils.py +++ b/fairseq/data/plasma_utils.py @@ -3,11 +3,23 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + import subprocess +import json import tempfile +import hashlib +from typing import Hashable + +try: + import pyarrow.plasma as plasma + + PYARROW_AVAILABLE = True +except ImportError: + plasma = None + PYARROW_AVAILABLE = False -class PlasmaArray(object): +class PlasmaArray: """ Wrapper around numpy arrays that automatically moves the data to shared memory upon serialization. 
This is particularly helpful when passing numpy @@ -31,12 +43,7 @@ def __init__(self, array): @property def plasma(self): if self._plasma is None and not self.disable: - try: - import pyarrow.plasma as plasma - - self._plasma = plasma - except ImportError: - self._plasma = None + self._plasma = plasma return self._plasma def start_server(self): @@ -47,13 +54,7 @@ def start_server(self): self._server_tmp = tempfile.NamedTemporaryFile() self.path = self._server_tmp.name self._server = subprocess.Popen( - [ - "plasma_store", - "-m", - str(int(1.05 * self.array.nbytes)), - "-s", - self.path, - ] + ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] ) @property @@ -64,6 +65,7 @@ def client(self): return self._client def __getstate__(self): + """Called on pickle save""" if self.plasma is None: return self.__dict__ if self.object_id is None: @@ -78,6 +80,7 @@ def __getstate__(self): return state def __setstate__(self, state): + """Called on pickle load""" self.__dict__.update(state) if self.plasma is None: return @@ -89,3 +92,106 @@ def __del__(self): self._server = None self._server_tmp.close() self._server_tmp = None + + +DEFAULT_PLASMA_PATH = "/tmp/plasma" + + +class PlasmaView: + """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, + PlasmaView writes to shared memory on instantiation.""" + + def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): + """ + Args: + array: numpy array to store. This can be read with ``PlasmaView().array`` + split_path: the path whence the data was read, used for hashing + hash_data: other metadata about the array that can be used to create a unique key.
+ as of writing, the 3 callers in ``TokenBlockDataset`` use:: + + hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) + + + """ + assert PYARROW_AVAILABLE + assert split_path is not None + if plasma_path is None: + plasma_path = DEFAULT_PLASMA_PATH + + self.path = plasma_path + self.split_path = split_path + self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. + self._n = None + + self.object_id = self.get_object_id(self.split_path, hash_data) + try: + self.client.put(array, object_id=self.object_id) + except plasma.PlasmaObjectExists: + pass + + @property + def client(self): + if self._client is None: + self._client = plasma.connect(self.path, num_retries=200) + return self._client + + @property + def array(self): + """Fetch a read only view of an np.array, stored in plasma.""" + ret = self.client.get(self.object_id) + return ret + + @staticmethod + def get_object_id(split_path: str, hash_data: Hashable): + """Returns plasma.ObjectID from hashing split_path and hash_data.""" + hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) + harg = json.dumps(hash_data).encode("utf-8") + hash.update(harg) + return plasma.ObjectID(hash.digest()) + + def __getstate__(self): + """Called on pickle save""" + self.disconnect() + state = self.__dict__.copy() + assert state["_client"] is None + assert "object_id" in state + return state + + def __setstate__(self, state): + """Called on pickle load""" + self.__dict__.update(state) + + def __del__(self): + self.disconnect() + + def disconnect(self): + if self._client is not None: + self._client.disconnect() + self._client = None + + def __len__(self): + """Save reads by caching len""" + if self._n is None: + self._n = len(self.array) + return self._n + + +GB100 = (1024 ** 3) * 100 + + +class PlasmaStore: + def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): + + self.server = self.start(path, nbytes) + + def
__del__(self): + self.server.kill() + + @staticmethod + def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: + if not PYARROW_AVAILABLE: + raise ImportError("please run pip install pyarrow to use --use_plasma_view") + # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm + _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) + plasma.connect(path, num_retries=200) # If we can't connect we fail immediately + return _server diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index ce0a0d1114..d2c65fd7e0 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -7,6 +7,7 @@ import torch from fairseq.data import FairseqDataset, plasma_utils from fairseq.data.indexed_dataset import best_fitting_int_dtype +from typing import Tuple class TokenBlockDataset(FairseqDataset): @@ -42,7 +43,46 @@ def __init__( break_mode=None, include_targets=False, document_sep_len=1, + use_plasma_view=False, + split_path=None, + plasma_path=None, ): + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) > 0 + + assert len(dataset) == len(sizes) + _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) + if use_plasma_view: + plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset)) + self._slice_indices = plasma_utils.PlasmaView( + slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path + ) + self._sizes = plasma_utils.PlasmaView( + _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path + ) + self._block_to_dataset_index = plasma_utils.PlasmaView( + block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path, + ) + else: + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = 
plasma_utils.PlasmaArray(_sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index + ) + + @staticmethod + def _build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) -> Tuple[np.ndarray]: + """Use token_block_utils_fast to build arrays for indexing into self.dataset""" try: from fairseq.data.token_block_utils_fast import ( _get_slice_indices_fast, @@ -54,15 +94,6 @@ def __init__( "or `python setup.py build_ext --inplace`" ) - super().__init__() - self.dataset = dataset - self.pad = pad - self.eos = eos - self.include_targets = include_targets - - assert len(dataset) == len(sizes) - assert len(dataset) > 0 - if isinstance(sizes, list): sizes = np.array(sizes, dtype=np.int64) else: @@ -79,7 +110,7 @@ def __init__( slice_indices = _get_slice_indices_fast( sizes, str(break_mode), block_size, document_sep_len ) - self._sizes = slice_indices[:, 1] - slice_indices[:, 0] + _sizes = slice_indices[:, 1] - slice_indices[:, 0] # build index mapping block indices to the underlying dataset indices if break_mode == "eos": @@ -99,15 +130,12 @@ def __init__( sizes, slice_indices, ) size_dtype = np.uint16 if block_size < 65535 else np.uint32 - slice_indices_dtype = best_fitting_int_dtype(slice_indices[-1].max()) - - self._slice_indices = plasma_utils.PlasmaArray( - slice_indices.astype(slice_indices_dtype) - ) - self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype)) - self._block_to_dataset_index = plasma_utils.PlasmaArray( - block_to_dataset_index.astype(slice_indices_dtype) - ) + num_tokens = slice_indices[-1].max() + slice_indices_dtype = best_fitting_int_dtype(num_tokens) + slice_indices = slice_indices.astype(slice_indices_dtype) + _sizes = _sizes.astype(size_dtype) + block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype) + return _sizes, block_to_dataset_index, slice_indices @property def slice_indices(self): @@ -131,7 +159,6 @@ def __getitem__(self, index): buffer = torch.cat( 
[self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] ) - slice_s, slice_e = self.slice_indices[index] length = slice_e - slice_s s, e = start_offset, start_offset + length diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 5d6aee157a..3c29be9197 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -104,15 +104,10 @@ class CommonConfig(FairseqDataclass): ) wandb_project: Optional[str] = field( default=None, - metadata={ - "help": "Weights and Biases project name to use for logging" - }, + metadata={"help": "Weights and Biases project name to use for logging"}, ) azureml_logging: Optional[bool] = field( - default=False, - metadata={ - "help": "Log scalars to AzureML context" - }, + default=False, metadata={"help": "Log scalars to AzureML context"}, ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} @@ -192,6 +187,15 @@ class CommonConfig(FairseqDataclass): "main method can return a value (useful for sweeps)" }, ) + use_plasma_view: bool = field( + default=False, metadata={"help": "Store indices and sizes in shared memory"} + ) + plasma_path: Optional[str] = field( + default="/tmp/plasma", + metadata={ + "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail." 
+ }, + ) @dataclass @@ -263,7 +267,7 @@ class DistributedTrainingConfig(FairseqDataclass): metadata={ "help": "kill the job if no progress is made in N seconds; " "set to -1 to disable" - } + }, ) broadcast_buffers: bool = field( default=False, @@ -360,16 +364,13 @@ class DistributedTrainingConfig(FairseqDataclass): tpu: bool = II("common.tpu") # configuration for --ddp-backend=fully_sharded no_reshard_after_forward: bool = field( - default=False, - metadata={"help": "don't reshard parameters after forward pass"}, + default=False, metadata={"help": "don't reshard parameters after forward pass"}, ) fp32_reduce_scatter: bool = field( - default=False, - metadata={"help": "reduce-scatter grads in FP32"}, + default=False, metadata={"help": "reduce-scatter grads in FP32"}, ) cpu_offload: bool = field( - default=False, - metadata={"help": "offload FP32 params to CPU"} + default=False, metadata={"help": "offload FP32 params to CPU"} ) @@ -665,12 +666,10 @@ class FairseqBMUFConfig(FairseqDataclass): @dataclass class GenerationConfig(FairseqDataclass): beam: int = field( - default=5, - metadata={"help": "beam size"}, + default=5, metadata={"help": "beam size"}, ) nbest: int = field( - default=1, - metadata={"help": "number of hypotheses to output"}, + default=1, metadata={"help": "number of hypotheses to output"}, ) max_len_a: float = field( default=0, @@ -685,24 +684,19 @@ class GenerationConfig(FairseqDataclass): }, ) min_len: int = field( - default=1, - metadata={"help": "minimum generation length"}, + default=1, metadata={"help": "minimum generation length"}, ) match_source_len: bool = field( - default=False, - metadata={"help": "generations should match the source length"}, + default=False, metadata={"help": "generations should match the source length"}, ) unnormalized: bool = field( - default=False, - metadata={"help": "compare unnormalized hypothesis scores"}, + default=False, metadata={"help": "compare unnormalized hypothesis scores"}, ) no_early_stop: bool = field( 
- default=False, - metadata={"help": "deprecated"}, + default=False, metadata={"help": "deprecated"}, ) no_beamable_mm: bool = field( - default=False, - metadata={"help": "don't use BeamableMM in attention layers"}, + default=False, metadata={"help": "don't use BeamableMM in attention layers"}, ) lenpen: float = field( default=1, @@ -724,12 +718,10 @@ class GenerationConfig(FairseqDataclass): }, ) sacrebleu: bool = field( - default=False, - metadata={"help": "score with sacrebleu"}, + default=False, metadata={"help": "score with sacrebleu"}, ) score_reference: bool = field( - default=False, - metadata={"help": "just score the reference translation"}, + default=False, metadata={"help": "just score the reference translation"}, ) prefix_size: int = field( default=0, @@ -763,12 +755,10 @@ class GenerationConfig(FairseqDataclass): }, ) temperature: float = field( - default=1.0, - metadata={"help": "temperature for generation"}, + default=1.0, metadata={"help": "temperature for generation"}, ) diverse_beam_groups: int = field( - default=-1, - metadata={"help": "number of groups for Diverse Beam Search"}, + default=-1, metadata={"help": "number of groups for Diverse Beam Search"}, ) diverse_beam_strength: float = field( default=0.5, @@ -787,16 +777,13 @@ class GenerationConfig(FairseqDataclass): }, ) print_step: bool = field( - default=False, - metadata={"help": "print steps"}, + default=False, metadata={"help": "print steps"}, ) lm_path: Optional[str] = field( - default=None, - metadata={"help": "path to lm checkpoint for lm fusion"}, + default=None, metadata={"help": "path to lm checkpoint for lm fusion"}, ) lm_weight: float = field( - default=0.0, - metadata={"help": "weight for lm probs for lm fusion"}, + default=0.0, metadata={"help": "weight for lm probs for lm fusion"}, ) # arguments for iterative refinement generator @@ -805,8 +792,7 @@ class GenerationConfig(FairseqDataclass): metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, ) 
iter_decode_max_iter: int = field( - default=10, - metadata={"help": "maximum iterations for iterative refinement."}, + default=10, metadata={"help": "maximum iterations for iterative refinement."}, ) iter_decode_force_max_iter: bool = field( default=False, @@ -833,8 +819,7 @@ class GenerationConfig(FairseqDataclass): }, ) retain_dropout: bool = field( - default=False, - metadata={"help": "Use dropout at inference time"}, + default=False, metadata={"help": "Use dropout at inference time"}, ) # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed # retain_dropout_modules: Optional[List[str]] = field( @@ -859,8 +844,7 @@ class GenerationConfig(FairseqDataclass): @dataclass class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, - metadata={"help": "path(s) to model file(s), colon separated"}, + default=None, metadata={"help": "path(s) to model file(s), colon separated"}, ) post_process: Optional[str] = field( default=None, @@ -922,8 +906,7 @@ class InteractiveConfig(FairseqDataclass): }, ) input: str = field( - default="-", - metadata={"help": "file to read from; use - for stdin"}, + default="-", metadata={"help": "file to read from; use - for stdin"}, ) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 579bf69785..a3847733a1 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -91,6 +91,8 @@ class LanguageModelingConfig(FairseqDataclass): ) data_buffer_size: int = II("dataset.data_buffer_size") tpu: bool = II("common.tpu") + use_plasma_view: bool = II("common.use_plasma_view") + plasma_path: str = II("common.plasma_path") @register_task("language_modeling", dataclass=LanguageModelingConfig) @@ -198,13 +200,12 @@ def load_dataset( data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) + # each process has its own copy of the raw data (likely to be an np.memmap) dataset = 
data_utils.load_indexed_dataset( split_path, self.dictionary, self.args.dataset_impl, combine=combine ) if dataset is None: - raise FileNotFoundError( - "Dataset not found: {} ({})".format(split, split_path) - ) + raise FileNotFoundError(f"Dataset not found: {split} ({split_path})") dataset = maybe_shorten_dataset( dataset, @@ -214,7 +215,6 @@ def load_dataset( self.args.tokens_per_sample, self.args.seed, ) - dataset = TokenBlockDataset( dataset, dataset.sizes, @@ -223,6 +223,9 @@ def load_dataset( eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, + use_plasma_view=self.args.use_plasma_view, + split_path=split_path, + plasma_path=self.args.plasma_path, ) add_eos_for_other_targets = ( diff --git a/fairseq/trainer.py b/fairseq/trainer.py index dcf5305455..ee29ed65a8 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1165,10 +1165,7 @@ def _all_gather_list_sync( return logging_outputs, extra_stats_to_sum def _fast_stat_sync_sum( - self, - logging_outputs: List[Dict[str, Any]], - *extra_stats_to_sum, - ignore=False, + self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, ): """ Sync logging outputs across workers. 
fast_stat_sync_sum is diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index d770e4e4ec..d618817e46 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -24,6 +24,7 @@ utils, ) from fairseq.data import iterators +from fairseq.data.plasma_utils import PlasmaStore from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils @@ -118,7 +119,6 @@ def main(cfg: FairseqConfig) -> None: trainer = Trainer(cfg, task, model, criterion, quantizer) else: trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( "training on {} devices (GPUs/TPUs)".format( cfg.distributed_training.distributed_world_size @@ -465,6 +465,10 @@ def cli_main( cfg = convert_namespace_to_omegaconf(args) + if cfg.common.use_plasma_view: + server = PlasmaStore(path=cfg.common.plasma_path) + logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}") + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): @@ -472,6 +476,9 @@ def cli_main( else: distributed_utils.call_main(cfg, main) + # if cfg.common.use_plasma_view: + # server.server.kill() + if __name__ == "__main__": cli_main() diff --git a/tests/test_plasma_utils.py b/tests/test_plasma_utils.py new file mode 100644 index 0000000000..5737530e3d --- /dev/null +++ b/tests/test_plasma_utils.py @@ -0,0 +1,127 @@ +import contextlib +import unittest +import tempfile +from io import StringIO + +import numpy as np + +from tests.test_binaries import train_language_model +from tests.utils import create_dummy_data, preprocess_lm_data + +try: + from pyarrow import plasma + from fairseq.data.plasma_utils import PlasmaView, PlasmaStore + + PYARROW_AVAILABLE = True +except ImportError: + PYARROW_AVAILABLE = False + +dummy_path = 'dummy' + + +@unittest.skipUnless(PYARROW_AVAILABLE, "") +class TestPlasmaView(unittest.TestCase): 
+ def setUp(self) -> None: + self.tmp_file = tempfile.NamedTemporaryFile() # noqa: P201 + self.path = self.tmp_file.name + self.server = PlasmaStore.start(path=self.path) + self.client = plasma.connect(self.path, num_retries=10) + + def tearDown(self) -> None: + self.client.disconnect() + self.tmp_file.close() + self.server.kill() + + def test_two_servers_do_not_share_object_id_space(self): + data_server_1 = np.array([0, 1]) + data_server_2 = np.array([2, 3]) + server_2_path = self.path + with tempfile.NamedTemporaryFile() as server_1_path: + server = PlasmaStore.start(path=server_1_path.name, nbytes=10000) + arr1 = PlasmaView( + data_server_1, dummy_path, 1, plasma_path=server_1_path.name + ) + assert len(arr1.client.list()) == 1 + assert (arr1.array == data_server_1).all() + arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=server_2_path) + assert (arr2.array == data_server_2).all() + assert (arr1.array == data_server_1).all() + server.kill() + + def test_hash_collision(self): + data_server_1 = np.array([0, 1]) + data_server_2 = np.array([2, 3]) + arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path) + assert len(arr1.client.list()) == 1 + arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path) + assert len(arr1.client.list()) == 1 + assert len(arr2.client.list()) == 1 + assert (arr2.array == data_server_1).all() + # New hash key based on tuples + arr3 = PlasmaView( + data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path + ) + assert ( + len(arr2.client.list()) == 2 + ), "No new object was created by using a novel hash key" + assert ( + arr3.object_id in arr2.client.list() + ), "No new object was created by using a novel hash key" + assert ( + arr3.object_id in arr3.client.list() + ), "No new object was created by using a novel hash key" + del arr3, arr2, arr1 + + @staticmethod + def _assert_view_equal(pv1, pv2): + np.testing.assert_array_equal(pv1.array, pv2.array) + + def 
test_putting_same_array_twice(self): + data = np.array([4, 4, 4]) + arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path) + assert len(self.client.list()) == 1 + arr1b = PlasmaView( + data, dummy_path, 1, plasma_path=self.path + ) # should not change contents of store + arr1c = PlasmaView( + None, dummy_path, 1, plasma_path=self.path + ) # should not change contents of store + + assert len(self.client.list()) == 1 + self._assert_view_equal(arr1, arr1b) + self._assert_view_equal(arr1, arr1c) + PlasmaView( + data, dummy_path, 2, plasma_path=self.path + ) # new object id, adds new entry + assert len(self.client.list()) == 2 + + new_client = plasma.connect(self.path) + assert len(new_client.list()) == 2 # new client can access same objects + assert isinstance(arr1.object_id, plasma.ObjectID) + del arr1b + del arr1c + + def test_plasma_store_full_raises(self): + with tempfile.NamedTemporaryFile() as new_path: + server = PlasmaStore.start(path=new_path.name, nbytes=10000) + with self.assertRaises(plasma.PlasmaStoreFull): + # 10000 float64 values (80000 bytes) do not fit in a 10000-byte store + PlasmaView( + np.random.rand(10000, 1), dummy_path, 1, plasma_path=new_path.name + ) + server.kill() + + def test_object_id_overflow(self): + PlasmaView.get_object_id("", 2 ** 21) + + def test_training_lm_plasma(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + ["--use-plasma-view", "--plasma-path", self.path], + run_validation=True, + ) From 252d5a9ae93e68254cfb1896fb5624cf11cda15e Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Fri, 12 Mar 2021 16:45:40 -0800 Subject: [PATCH 517/707] Fix a bug that FairseqSimulSTAgent is not an agent (#1690) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1690 Reviewed By: jmp84 Differential Revision: D27025669 Pulled By: xutaima
fbshipit-source-id: 8125365adedfdc938813d08e911e1f6ebe4f584b --- .../agents/fairseq_simul_st_agent.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 2b5fdc2d3f..8b8003e1d5 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -10,12 +10,11 @@ try: from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS + from simuleval.agents import SpeechAgent from simuleval.states import ListEntry, SpeechStates except ImportError: print("Please install simuleval 'pip install simuleval'") -from torch import nn - SHIFT_SIZE = 10 WINDOW_SIZE = 25 SAMPLE_RATE = 16000 @@ -65,7 +64,7 @@ def __call__(self, new_samples): input_samples = samples[:effective_num_samples] self.previous_residual_samples = samples[ - num_frames * self.num_samples_per_shift : + num_frames * self.num_samples_per_shift: ] torch.manual_seed(1) @@ -113,12 +112,12 @@ def info(self): } -class FairseqSimulSTAgent(nn.Module): +class FairseqSimulSTAgent(SpeechAgent): speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size def __init__(self, args): - super().__init__() + super().__init__(args) self.eos = DEFAULT_EOS @@ -218,6 +217,9 @@ def load_model_vocab(self, args): task_args = state["cfg"]["task"] task_args.data = args.data_bin + if args.config is not None: + task_args.config_yaml = args.config + task = self.set_up_task(task_args) # build model for ensemble From 965240c784910895b05e66d7ef7e15321050b414 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Sun, 14 Mar 2021 20:55:46 -0700 Subject: [PATCH 518/707] optimize memory when loading large checkpoints by deleting state dict early Summary: I had some issues with loading checkpoints from 5B parameter models 
(60 GB checkpoint files) due to OOM. Reviewed By: myleott Differential Revision: D27027616 fbshipit-source-id: 2b816e8e46ec80f0ec721aa7a6702cee531b94eb --- fairseq/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index ee29ed65a8..1c4c532dd0 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -442,10 +442,14 @@ def load_checkpoint( self.model.load_state_dict( state["model"], strict=True, model_cfg=self.cfg.model ) + # save memory for later steps + del state["model"] if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict( state["criterion"], strict=True ) + del state["criterion"] + except Exception: raise Exception( "Cannot load model parameters from checkpoint {}; " From dd74992d0d143155998e9ed4076826bcea80fb06 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Mon, 15 Mar 2021 23:44:19 -0700 Subject: [PATCH 519/707] Several updates for simul speech translation example (#1703) Summary: Fix several issues in the simultaneous speech translation example, including - Load the pretrained encoder when loading the model. - Fix the broken config.yaml generated when using gcmvn. - Fix the preprocessed databin. - Fix some errors in the instructions. - Add detailed instructions on evaluating a pretrained model.
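The config.yaml item in the list above corresponds to the `set_global_cmvn` change in this patch, which nests `stats_npz_path` under a `global_cmvn` key and passes the path through `str()`. A minimal sketch of the resulting layout, assuming a plain dict stands in for the config writer; `global_cmvn_section` and the example path are hypothetical and not part of fairseq:

```python
from pathlib import PurePosixPath
from typing import Any, Dict


def global_cmvn_section(stats_npz_path) -> Dict[str, Any]:
    # Hypothetical helper (not part of fairseq): builds the nested layout that
    # the patched `set_global_cmvn` stores in the config dict.
    # str() mirrors the `str(gcmvn_path)` call in the patch, so a path object
    # ends up in the config as a plain string.
    return {"global_cmvn": {"stats_npz_path": str(stats_npz_path)}}


# Illustrative values only; the path below is made up for the example.
config: Dict[str, Any] = {"input_feat_per_channel": 80}
config.update(global_cmvn_section(PurePosixPath("/data/mustc/en-de/gcmvn.npz")))
print(config)
```

Before the patch the path was stored as a top-level `stats_npz_path` entry, which is the layout the summary describes as broken.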
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1703 Reviewed By: jmp84 Differential Revision: D27071600 Pulled By: xutaima fbshipit-source-id: bfe72005190d7936caeef4f805bd99c8d2cf9c37 --- examples/speech_to_text/data_utils.py | 4 +- .../docs/simulst_mustc_example.md | 72 ++++++++++++++++--- .../agents/fairseq_simul_st_agent.py | 9 ++- 3 files changed, 70 insertions(+), 15 deletions(-) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 3f96ffc427..2bcff046f7 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -168,7 +168,7 @@ def gen_config_yaml( assert gcmvn_path is not None, ( 'Please provide path of global cmvn file.' ) - writer.set_global_cmvn(gcmvn_path) + writer.set_global_cmvn(str(gcmvn_path)) if len(audio_root) > 0: writer.set_audio_root(audio_root) @@ -325,7 +325,7 @@ def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): self.config["bpe_tokenizer"] = bpe_tokenizer def set_global_cmvn(self, stats_npz_path: str): - self.config["stats_npz_path"] = stats_npz_path + self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path} def set_feature_transforms(self, split: str, transforms: List[str]): if "transforms" not in self.config: diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index 3452806a1c..22f359abe3 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -5,6 +5,9 @@ This is a tutorial of training and evaluating a transformer *wait-k* simultaneou [MuST-C](https://www.aclweb.org/anthology/N19-1202) is multilingual speech-to-text translation corpus with 8-language translations on English TED talks. ## Data Preparation +This section introduces the data preparation for training and evaluation. 
+If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference-&-evaluation) + [Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path `${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with ```bash @@ -77,25 +80,74 @@ a wait-3 policy model. Assuming the save directory is `${ST_SAVE_DIR}` ``` ## Inference & Evaluation [SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. -The source file is a list of paths of audio files, -while target file is the corresponding translations. +The following command is for evaluation. + ``` git clone https://github.com/facebookresearch/SimulEval.git cd SimulEval pip install -e . simuleval \ - --agent examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py - --src-file ${SRC_LIST_OF_AUDIO} - --tgt-file ${TGT_FILE} + --agent ${FAIRSEQ}/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py + --source ${SRC_LIST_OF_AUDIO} + --target ${TGT_FILE} --data-bin ${MUSTC_ROOT}/en-de \ + --config config_st.yaml \ --model-path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ - --tgt-splitter-type SentencePieceModel \ - --tgt-splitter-path ${MUSTC_ROOT}/en-de/spm.model \ + --output ${OUTPUT} \ --scores ``` -A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The databin (containing dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). +The source file `${SRC_LIST_OF_AUDIO}` is a list of paths of audio files. Assuming your audio files stored at `/home/user/data`, +it should look like this + +```bash +/home/user/data/audio-1.wav +/home/user/data/audio-2.wav +``` + +Each line of target file `${TGT_FILE}` is the translation for each audio file input. 
+```bash +Translation_1 +Translation_2 +``` + +The `--data-bin` and `--config` should be the same in previous section if you prepare the data from the scratch. +If only for evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). It contains +- `spm_unigram10000_st.model`: a sentencepiece model binary. +- `spm_unigram10000_st.txt`: the dictionary file generated by the sentencepiece model. +- `gcmvn.npz`: the binary for global cepstral mean and variance. +- `config_st.yaml`: the config yaml file. It looks like this. +You will need to set the absolute paths for `sentencepiece_model` and `stats_npz_path` if the data directory is downloaded. +```yaml +bpe_tokenizer: + bpe: sentencepiece + sentencepiece_model: ABS_PATH_TO_SENTENCEPIECE_MODEL +global_cmvn: + stats_npz_path: ABS_PATH_TO_GCMVN_FILE +input_channels: 1 +input_feat_per_channel: 80 +sampling_alpha: 1.0 +specaugment: + freq_mask_F: 27 + freq_mask_N: 1 + time_mask_N: 1 + time_mask_T: 100 + time_mask_p: 1.0 + time_wrap_W: 0 +transforms: + '*': + - global_cmvn + _train: + - global_cmvn + - specaugment +vocab_filename: spm_unigram10000_st.txt +``` + +Notice that once a `--data-bin` is set, the `--config` is the base name of the config yaml, not the full path. + +Set `--model-path` to the model checkpoint. +A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The output should be similar as follow: ```bash @@ -114,10 +166,14 @@ The output should be similar as follow: } ``` +If `--output ${OUTPUT}` option is used, the detailed log and scores will be stored under the `${OUTPUT}` directory. + + The quality is measured by detokenized BLEU. So make sure that the predicted words sent to the server are detokenized. 
The latency metrics are * Average Proportion * Average Lagging * Differentiable Average Lagging + Again they will also be evaluated on detokenized text. diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 8b8003e1d5..9ff07775da 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -138,7 +138,7 @@ def __init__(self, args): args.global_cmvn = None if args.config: - with open(args.config, "r") as f: + with open(os.path.join(args.data_bin, args.config), "r") as f: config = yaml.load(f, Loader=yaml.BaseLoader) if "global_cmvn" in config: @@ -203,9 +203,6 @@ def add_args(parser): # fmt: on return parser - def set_up_task(self, task_args): - return tasks.setup_task(task_args) - def load_model_vocab(self, args): filename = args.model_path @@ -220,9 +217,11 @@ def load_model_vocab(self, args): if args.config is not None: task_args.config_yaml = args.config - task = self.set_up_task(task_args) + task = tasks.setup_task(task_args) # build model for ensemble + state["cfg"]["model"].load_pretrained_encoder_from = None + state["cfg"]["model"].load_pretrained_decoder_from = None self.model = task.build_model(state["cfg"]["model"]) self.model.load_state_dict(state["model"], strict=True) self.model.eval() From edcef1306b48e7fa9bf84dcbec25171a1e57a5dc Mon Sep 17 00:00:00 2001 From: Jongsoo Park <jongsoo@fb.com> Date: Wed, 17 Mar 2021 20:02:30 -0700 Subject: [PATCH 520/707] make deepspeed cpu_adam works in fbcode Summary: To test cpu-offload + fsdp in fairseq Differential Revision: D26873232 fbshipit-source-id: 8d4dee874713a055bb6b6541cddcbfd722eef9f8 --- fairseq/optim/cpu_adam.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index 
5e935df1a5..9637a78666 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -12,12 +12,14 @@ from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from omegaconf import II, DictConfig +import logging try: - from deepspeed.ops.op_builder import CPUAdamBuilder + import deepspeed.op_extensions.cpu_adam as ds_opt_adam has_deepspeed_cpu_adam = True -except ImportError: +except ImportError as e: + logging.warning(e) has_deepspeed_cpu_adam = False @@ -101,7 +103,7 @@ def __init__( self.opt_id = CPUAdam.optimizer_id CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1 - self.ds_opt_adam = CPUAdamBuilder().load() + self.ds_opt_adam = ds_opt_adam adamw_mode = True self.ds_opt_adam.create_adam( self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode From 53b781caad3d58cd7fefed80e17faf147d15da66 Mon Sep 17 00:00:00 2001 From: Juan Miguel Pino <juancarabina@fb.com> Date: Fri, 19 Mar 2021 09:35:09 -0700 Subject: [PATCH 521/707] Add --update-freq 8 to simulst tutorial (#3374) Summary: Clarify that training is done on 1 GPU. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3374 Reviewed By: kahne Differential Revision: D27183474 Pulled By: jmp84 fbshipit-source-id: 330ec9b6510dcbd1f38a7c1c954d7504c6de3dda --- examples/speech_to_text/docs/simulst_mustc_example.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index 22f359abe3..d83ec086f9 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -31,7 +31,8 @@ python examples/speech_to_text/prep_mustc_data.py \ ``` ## ASR Pretraining -We need a pretrained offline ASR model. Assuming the save directory of the ASR model is `${ASR_SAVE_DIR}` +We need a pretrained offline ASR model. 
Assuming the save directory of the ASR model is `${ASR_SAVE_DIR}`. +The following command (and the subsequent training commands in this tutorial) assume training on 1 GPU (you can also train on 8 GPUs and remove the `--update-freq 8` option). ``` fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ @@ -60,7 +61,8 @@ a wait-3 policy model. Assuming the save directory is `${ST_SAVE_DIR}` --arch convtransformer_simul_trans_espnet \ --simul-type waitk_fixed_pre_decision \ --waitk-lagging 3 \ - --fixed-pre-decision-ratio 7 + --fixed-pre-decision-ratio 7 \ + --update-freq 8 ``` ### Monotonic multihead attention with fixed pre-decision module @@ -76,7 +78,8 @@ a wait-3 policy model. Assuming the save directory is `${ST_SAVE_DIR}` --latency-weight-avg 0.1 \ --arch convtransformer_simul_trans_espnet \ --simul-type infinite_lookback_fixed_pre_decision \ - --fixed-pre-decision-ratio 7 + --fixed-pre-decision-ratio 7 \ + --update-freq 8 ``` ## Inference & Evaluation [SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. From 5c87bb5ce81dbc051a37e50bca3da40633149f26 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Sat, 20 Mar 2021 15:00:56 -0700 Subject: [PATCH 522/707] Fix RoBERTa + FSDP (also minor fix for GPT-3 configs) (#1724) Summary: There will also be a fix on the fairscale side to fix the segfault with FSDP. 
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1724 Reviewed By: sshleifer Differential Revision: D27189594 Pulled By: myleott fbshipit-source-id: 7a0ccadf8e2104cc782faccb55756deddb2dd346 --- fairseq/models/roberta/model.py | 3 --- fairseq/models/transformer.py | 2 +- fairseq/models/transformer_lm.py | 5 ++++- fairseq/modules/fairseq_dropout.py | 2 +- fairseq/modules/transformer_sentence_encoder.py | 17 ++++++++++++----- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 5d2ed4902d..5b9ba8105f 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -458,9 +458,6 @@ def build_encoder(self, args, dictionary, embed_tokens): def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) - def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): - return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) - def forward( self, src_tokens, diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 297807c31a..1e47d102f9 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -511,7 +511,7 @@ def forward_scriptable( x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # account for padding while computing the representation - if encoder_padding_mask is not None: + if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index fca9470e5e..d2c0cff493 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -449,11 +449,14 @@ def transformer_lm_gpt2_big(args): def base_gpt3_architecture(args): + args.decoder_input_dim = args.decoder_embed_dim + args.decoder_output_dim = args.decoder_embed_dim args.decoder_ffn_embed_dim = 
getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) args.dropout = getattr(args, "dropout", 0.0) args.attention_dropout = getattr(args, "attention_dropout", 0.0) args.activation_fn = getattr(args, "activation_fn", "gelu") + args.share_decoder_input_output_embed = True base_lm_architecture(args) @@ -489,7 +492,7 @@ def transformer_lm_gpt3_xl(args): # 1.3B params args.decoder_layers = getattr(args, "decoder_layers", 24) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 24) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) base_gpt3_architecture(args) diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py index f070a804e6..3cddca7718 100644 --- a/fairseq/modules/fairseq_dropout.py +++ b/fairseq/modules/fairseq_dropout.py @@ -21,7 +21,7 @@ def __init__(self, p, module_name=None): self.apply_during_inference = False def forward(self, x, inplace: bool = False): - if self.training or self.apply_during_inference: + if self.p > 0 and (self.training or self.apply_during_inference): return F.dropout(x, p=self.p, training=True, inplace=inplace) else: return x diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index a7fb198779..d0540d6922 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -32,18 +32,25 @@ def init_bert_params(module): the normal distribution (to be validated). 
""" + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=0.02) + normal_(module.weight.data) if module.bias is not None: module.bias.data.zero_() if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=0.02) + normal_(module.weight.data) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() if isinstance(module, MultiheadAttention): - module.q_proj.weight.data.normal_(mean=0.0, std=0.02) - module.k_proj.weight.data.normal_(mean=0.0, std=0.02) - module.v_proj.weight.data.normal_(mean=0.0, std=0.02) + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) class TransformerSentenceEncoder(nn.Module): From 8f77e24cf184c9762ed48acb43e9bd5daba550b1 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Sat, 20 Mar 2021 16:40:08 -0700 Subject: [PATCH 523/707] Deepspeed can be used outside of fbcode (#1727) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1727 Reviewed By: myleott Differential Revision: D27213955 Pulled By: sshleifer fbshipit-source-id: be84e7f7c1c55c407ee7445fad9b3026a79763fb --- fairseq/optim/cpu_adam.py | 20 ++++++++++++++------ scripts/test_fsdp.sh | 13 +++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) create mode 100755 scripts/test_fsdp.sh diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index 9637a78666..21336ef59b 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -16,13 +16,21 @@ try: - import deepspeed.op_extensions.cpu_adam as ds_opt_adam - has_deepspeed_cpu_adam = True + import deepspeed + has_deepspeed = True except ImportError as e: - logging.warning(e) - has_deepspeed_cpu_adam = False + 
has_deepspeed = False +def _get_cpu_adam(): + try: + from deepspeed.ops.op_builder import CPUAdamBuilder + return CPUAdamBuilder().load() + except ImportError: + # fbcode + from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam + return ds_opt_adam + @dataclass class FairseqCPUAdamConfig(FairseqDataclass): adam_betas: str = field( @@ -97,13 +105,13 @@ def __init__( self.use_fp16_stats = use_fp16_stats self.FLOAT16_MAX = 65504.0 - if not has_deepspeed_cpu_adam: + if not has_deepspeed: raise ImportError("Please install DeepSpeed: pip install deepspeed") self.opt_id = CPUAdam.optimizer_id CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1 - self.ds_opt_adam = ds_opt_adam + self.ds_opt_adam = _get_cpu_adam() adamw_mode = True self.ds_opt_adam.create_adam( self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode diff --git a/scripts/test_fsdp.sh b/scripts/test_fsdp.sh new file mode 100755 index 0000000000..0f4d6c420b --- /dev/null +++ b/scripts/test_fsdp.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +rm -rf fsdp_dummy +mkdir -p fsdp_dummy +fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 256 --batch-size 8 \ + --arch transformer_lm_gpt2_tiny \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 10 --log-format json --log-interval 1 \ + --save-interval-updates 10 --save-dir fsdp_dummy \ + --restore-file x.pt "$@" From 5273bbb7c18a9b147e3f0cfc97121cc945a962bd Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Sat, 20 Mar 2021 19:25:54 -0700 Subject: [PATCH 524/707] Fix transformer LM arg upgrade logic (#1717) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1717 Reviewed By: alexeib Differential Revision: D27156919 Pulled By: myleott fbshipit-source-id: 
af3c2e41464c04a7808f40894e7b0106798e4822 --- fairseq/models/transformer_lm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index d2c0cff493..70354a228d 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -230,9 +230,6 @@ def __init__(self, decoder): def build_model(cls, args, task): """Build a new model instance.""" - # make sure all arguments are present in older models - base_lm_architecture(args) - if args.decoder_layers_to_keep: args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) From bd0109cdc66edbd01e7362d41e5997f85afbde7d Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Mon, 22 Mar 2021 13:45:48 -0700 Subject: [PATCH 525/707] Revert change in defaults for LMs to learned pos embeddings (#1734) Summary: This was premature. Leaving it to ``True`` for GPT-3 configs, but reverting back to ``False`` in general. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1734 Reviewed By: shruti-bh Differential Revision: D27233834 Pulled By: myleott fbshipit-source-id: 597b36f94f28d59834f1d68ab1dd2991e82c1e32 --- fairseq/models/transformer_lm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 70354a228d..b616a923d4 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -304,7 +304,7 @@ def base_lm_architecture(args): args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) args.activation_fn = getattr(args, "activation_fn", "relu") args.decoder_layerdrop = getattr(args, 
"decoder_layerdrop", 0) @@ -449,6 +449,7 @@ def base_gpt3_architecture(args): args.decoder_input_dim = args.decoder_embed_dim args.decoder_output_dim = args.decoder_embed_dim args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) + # GPT-3 used learned positional embeddings, rather than sinusoidal args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) args.dropout = getattr(args, "dropout", 0.0) args.attention_dropout = getattr(args, "attention_dropout", 0.0) From 8c14a8f7dfcd18abd983491fc2207ac634d25759 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Mon, 22 Mar 2021 20:51:32 -0700 Subject: [PATCH 526/707] --nval only validate for a few steps (#1735) Summary: Afaict, it's easy to set --max-update 4 to run training quickly, but it's hard to control validation without changing the data. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1735 Reviewed By: myleott Differential Revision: D27246062 Pulled By: sshleifer fbshipit-source-id: 30a210cbbb45791647a050f49e6f38fbacd0d988 --- fairseq/dataclass/configs.py | 2 ++ fairseq_cli/train.py | 4 +++- tests/test_binaries.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 3c29be9197..be9f7c5af3 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -448,6 +448,8 @@ class DatasetConfig(FairseqDataclass): "argparse_alias": "--max-sentences-valid", }, ) + max_valid_steps: Optional[int] = field(default=None, metadata={'help': 'How many batches to evaluate', + "argparse_alias": "--nval"}) curriculum: int = field( default=0, metadata={"help": "don't shuffle batches for first N epochs"} ) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index d618817e46..8b5ca89cee 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -432,7 +432,9 @@ def validate( # create a new root metrics aggregator so validation 
metrics # don't pollute other aggregators (e.g., train meters) with metrics.aggregate(new_root=True) as agg: - for sample in progress: + for i, sample in enumerate(progress): + if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps: + break trainer.valid_step(sample) # log validation stats diff --git a/tests/test_binaries.py b/tests/test_binaries.py index e10cc767b8..49e6dcd9f8 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1160,7 +1160,7 @@ def test_transformer_lm(self): train_language_model( data_dir, "transformer_lm", - ["--add-bos-token"], + ["--add-bos-token", '--nval', '1'], run_validation=True, ) eval_lm_main(data_dir) From 1bba712622b8ae4efb3eb793a8a40da386fe11d0 Mon Sep 17 00:00:00 2001 From: Taylan Bilal <taylanbil@gmail.com> Date: Mon, 22 Mar 2021 21:29:42 -0700 Subject: [PATCH 527/707] Enable w2v2 tpu (#3328) Summary: This enables training wav2vec 2.0 models on TPUs courtesy of taylanbil Pull Request resolved: https://github.com/pytorch/fairseq/pull/3328 Reviewed By: myleott Differential Revision: D27127542 Pulled By: alexeib fbshipit-source-id: b402c58f812c3c36edaaa88fdbe20e37fae3d4f3 --- examples/wav2vec/README.md | 52 +++++++ .../wav2vec2_large_librivox_tpu-pod.yaml | 71 +++++++++ .../wav2vec2_large_librivox_tpu.yaml | 71 +++++++++ fairseq/criterions/wav2vec_criterion.py | 65 ++++++--- fairseq/data/audio/raw_audio_dataset.py | 123 +++++++++++++++- fairseq/data/bucket_pad_length_dataset.py | 46 +++--- fairseq/data/data_utils.py | 22 +++ fairseq/distributed/utils.py | 6 +- fairseq/logging/metrics.py | 8 + fairseq/models/wav2vec/wav2vec2.py | 137 +++++++++++------- fairseq/models/wav2vec/wav2vec2_asr.py | 24 ++- fairseq/tasks/audio_pretraining.py | 79 +++++++++- fairseq/trainer.py | 22 ++- fairseq/utils.py | 36 +++++ fairseq_cli/train.py | 5 +- 15 files changed, 658 insertions(+), 109 deletions(-) create mode 100644 examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml create mode 100644 
examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index e95f292b51..bfed3913cf 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -222,6 +222,58 @@ $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 - --max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test ``` +### Run wav2vec2 pre-training on Google Cloud TPUs: + +Wav2Vec2 is now supported on TPUs! It's currently pre-training only. + +#### Using hydra on a v3-8: + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu.yaml +``` + +#### Using command line arguments on a v3-8: + +``` +$ OMP_NUM_THREADS=1 python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size 8 --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + +#### Using hydra on a pod slice (v3-N with N > 8): + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir 
/PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu-pod.yaml # edit distributed-world-size accordingly +``` + +#### Using command line arguments on a pod slice (v3-N with N > 8): + + +``` +$ python -m torch_xla.distributed.xla_dist \ + --tpu ${TPUNAME} --conda-env=torch-xla-${TORCH_XLA_VERSION} --env OMP_NUM_THREADS=1 \ + -- \ +python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size ${WORLD_SIZE} --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + ### Extract embeddings from the downstream task data: ``` diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml new file mode 100644 index 0000000000..676c9fe339 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml @@ -0,0 +1,71 @@ +# @package _group_ + +common: + tpu: true + fp16: false + log_format: json + log_interval: 10 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? 
+ max_sample_size: 250000 + min_sample_size: 32000 + normalize: true + num_batch_buckets: 3 + precompute_mask_indices: true + enable_padding: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 128 + ddp_backend: legacy_ddp + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 256 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + mask_channel_prob: 0.1 + mask_prob: 0.1 + + feature_grad_mult: 1.0 + diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml new file mode 100644 index 0000000000..c45c4d9117 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml @@ -0,0 +1,71 @@ +# @package _group_ + +common: + tpu: true + fp16: false + log_format: json + log_interval: 10 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? 
+ max_sample_size: 250000 + min_sample_size: 32000 + normalize: true + num_batch_buckets: 3 + precompute_mask_indices: true + enable_padding: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 8 + ddp_backend: legacy_ddp + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 256 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + mask_channel_prob: 0.1 + mask_prob: 0.1 + + feature_grad_mult: 1.0 + diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 859177f2b6..f682508cb1 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -31,7 +31,7 @@ class Wav2VecCriterionConfig(FairseqDataclass): default_factory=lambda: [], metadata={"help": "output keys to log"}, ) - +from fairseq.utils import index_put, is_xla_tensor @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) class Wav2vecCriterion(FairseqCriterion): @@ -52,7 +52,9 @@ def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) + self.xla = is_xla_tensor(logits) + # XXX: handle weights on xla. 
weights = None if hasattr(model, "get_target_weights") and not self.infonce: weights = model.get_target_weights(target, net_output) @@ -61,21 +63,31 @@ def forward(self, model, sample, reduce=True): losses = [] + reduction = "none" if ((not reduce) or self.xla) else "sum" if self.infonce: - loss = F.cross_entropy( - logits, - target, - reduction="sum" if reduce else "none", - ) + loss = F.cross_entropy(logits, target, reduction=reduction) else: loss = F.binary_cross_entropy_with_logits( - logits, - target.float(), - weights, - reduction="sum" if reduce else "none", + logits, target.float(), weights, reduction=reduction + ) + + if self.xla: + # tpu-comment: since dynamic shapes lead to recompilations on xla, + # we don't shrink tensors using mask_indices. + # Instead, we use mask indices to adjust loss. + mi = ( + sample['net_input']['mask_indices'] + .transpose(0, 1) # logits are transposed in `model.get_logits` + .reshape(logits.size(0)) ) + loss = (loss * mi).sum() if reduce else (loss * mi) - sample_size = target.numel() if self.infonce else target.long().sum().item() + if 'sample_size' in sample and self.infonce: + sample_size = sample['sample_size'] + elif 'mask_indices' in sample['net_input']: + sample_size = sample['net_input']['mask_indices'].sum() + else: + sample_size = target.numel() if self.infonce else target.long().sum().item() losses.append(loss.detach().clone()) if self.loss_weights is not None: @@ -95,7 +107,7 @@ def forward(self, model, sample, reduce=True): losses.append(p) logging_output = { - "loss": loss.item() if reduce else loss, + "loss": loss.item() if (reduce and not self.xla) else loss.detach(), "ntokens": sample_size, "nsentences": sample["id"].numel(), "sample_size": sample_size, @@ -111,11 +123,14 @@ def forward(self, model, sample, reduce=True): if not self.training: logging_output["target"] = target.cpu().numpy() elif lk in net_output: - logging_output[lk] = float(net_output[lk]) + value = net_output[lk] + if not 
is_xla_tensor(value): + value = float(value) + logging_output[lk] = value if len(losses) > 1: for i, l in enumerate(losses): - logging_output[f"loss_{i}"] = l.item() + logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach() if self.infonce: with torch.no_grad(): @@ -126,9 +141,15 @@ def forward(self, model, sample, reduce=True): assert logits.dim() > 1, logits.shape max = logits.argmax(-1) == 0 min = logits.argmin(-1) == 0 - both = max & min - corr = max.long().sum().item() - both.long().sum().item() - count = max.numel() + if is_xla_tensor(logits): + max, min = max * mi, min * mi + both = max & min + corr = max.long().sum() - both.long().sum() + count = mi.sum() + else: + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = float(max.numel()) logging_output["correct"] = corr logging_output["count"] = count @@ -188,11 +209,15 @@ def reduce_metrics(logging_outputs) -> None: else: metrics.log_scalar(k, val / len(logging_outputs), round=3) - @staticmethod - def logging_outputs_can_be_summed() -> bool: + # FIXME: revert when gather based xla reduction is implemented + #@staticmethod + #def logging_outputs_can_be_summed() -> bool: + def logging_outputs_can_be_summed(self) -> bool: """ Whether the logging outputs returned by `forward` can be summed across workers prior to calling `reduce_metrics`. Setting this to True will improves distributed training speed. """ - return False + # XXX: Gather based reduction not implemented for xla yet. + # So we fall to sum based reduction for xla. + return self.xla diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 1d92e4966b..d0ff604e2b 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -12,7 +12,8 @@ import torch import torch.nn.functional as F -from .. import FairseqDataset +from .. 
import FairseqDataset, BaseWrapperDataset +from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes logger = logging.getLogger(__name__) @@ -27,6 +28,8 @@ def __init__( shuffle=True, pad=False, normalize=False, + compute_mask_indices=False, + **mask_compute_kwargs, ): super().__init__() @@ -39,6 +42,14 @@ def __init__( self.pad = pad self.shuffle = shuffle self.normalize = normalize + self.compute_mask_indices = compute_mask_indices + if self.compute_mask_indices: + self.mask_compute_kwargs = mask_compute_kwargs + self._features_size_map = {} + self._C = mask_compute_kwargs['encoder_embed_dim'] + self._conv_feature_layers = eval( + mask_compute_kwargs['conv_feature_layers'] + ) def __getitem__(self, index): raise NotImplementedError() @@ -70,6 +81,45 @@ def crop_to_max_size(self, wav, target_size): end = size - diff + start return wav[start:end] + def _compute_mask_indices(self, dims, padding_mask): + B, T, C = dims + mask_indices, mask_channel_indices = None, None + if self.mask_compute_kwargs['mask_prob'] > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_compute_kwargs['mask_prob'], + self.mask_compute_kwargs['mask_length'], + self.mask_compute_kwargs['mask_selection'], + self.mask_compute_kwargs['mask_other'], + min_masks=2, + no_overlap=self.mask_compute_kwargs['no_mask_overlap'], + min_space=self.mask_compute_kwargs['mask_min_space'], + ) + mask_indices = torch.from_numpy(mask_indices) + if self.mask_compute_kwargs['mask_channel_prob'] > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_compute_kwargs['mask_channel_prob'], + self.mask_compute_kwargs['mask_channel_length'], + self.mask_compute_kwargs['mask_channel_selection'], + self.mask_compute_kwargs['mask_channel_other'], + no_overlap=self.mask_compute_kwargs['no_mask_channel_overlap'], + min_space=self.mask_compute_kwargs['mask_channel_min_space'], + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + 
.unsqueeze(1) + .expand(-1, T, -1) + ) + + return mask_indices, mask_channel_indices + + @staticmethod + def _bucket_tensor(tensor, num_pad, value): + return F.pad(tensor, (0, num_pad), value=value) + def collater(self, samples): samples = [s for s in samples if s["source"] is not None] if len(samples) == 0: @@ -101,9 +151,55 @@ def collater(self, samples): collated_sources[i] = self.crop_to_max_size(source, target_size) input = {"source": collated_sources} + out = {"id": torch.LongTensor([s["id"] for s in samples])} if self.pad: input["padding_mask"] = padding_mask - return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} + + if hasattr(self, 'num_buckets') and self.num_buckets > 0: + assert self.pad, "Cannot bucket without padding first." + bucket = max(self._bucketed_sizes[s['id']] for s in samples) + num_pad = bucket - collated_sources.size(-1) + if num_pad: + input['source'] = self._bucket_tensor( + collated_sources, num_pad, 0 + ) + input['padding_mask'] = self._bucket_tensor( + padding_mask, num_pad, True + ) + + if self.compute_mask_indices: + B = input['source'].size(0) + T = self._get_mask_indices_dims(input['source'].size(-1)) + padding_mask_reshaped = input['padding_mask'].clone() + extra = padding_mask_reshaped.size(1) % T + if extra > 0: + padding_mask_reshaped = padding_mask_reshaped[:, :-extra] + padding_mask_reshaped = padding_mask_reshaped.view( + padding_mask_reshaped.size(0), T, -1 + ) + padding_mask_reshaped = padding_mask_reshaped.all(-1) + input['padding_count'] = ( + padding_mask_reshaped.sum(-1).max().item() + ) + mask_indices, mask_channel_indices = self._compute_mask_indices( + (B, T, self._C), padding_mask_reshaped, + ) + input["mask_indices"] = mask_indices + input["mask_channel_indices"] = mask_channel_indices + out['sample_size'] = mask_indices.sum().item() + + out["net_input"] = input + return out + + def _get_mask_indices_dims(self, size, padding=0, dilation=1): + if size not in self._features_size_map: + 
L_in = size + for (_, kernel_size, stride) in self._conv_feature_layers: + L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1 + L_out = 1 + L_out // stride + L_in = L_out + self._features_size_map[size] = L_out + return self._features_size_map[size] def num_tokens(self, index): return self.size(index) @@ -138,6 +234,9 @@ def __init__( shuffle=True, pad=False, normalize=False, + num_buckets=0, + compute_mask_indices=False, + **mask_compute_kwargs, ): super().__init__( sample_rate=sample_rate, @@ -146,6 +245,8 @@ def __init__( shuffle=shuffle, pad=pad, normalize=normalize, + compute_mask_indices=compute_mask_indices, + **mask_compute_kwargs, ) self.fnames = [] @@ -164,8 +265,26 @@ def __init__( self.fnames.append(items[0]) self.line_inds.add(i) self.sizes.append(sz) + self.set_bucket_info(num_buckets) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + def set_bucket_info(self, num_buckets): + self.num_buckets = num_buckets + if self.num_buckets > 0: + self._collated_sizes = np.minimum( + np.array(self.sizes), self.max_sample_size, + ) + self.buckets = get_buckets( + self._collated_sizes, self.num_buckets, + ) + self._bucketed_sizes = get_bucketed_sizes( + self._collated_sizes, self.buckets + ) + logger.info( + f"{len(self.buckets)} bucket(s) for the audio dataset: " + f"{self.buckets}" + ) + def __getitem__(self, index): import soundfile as sf diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index cda8834ac8..0f94100148 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -6,6 +6,7 @@ import numpy as np import torch.nn.functional as F from fairseq.data import BaseWrapperDataset +from fairseq.data.data_utils import get_buckets, get_bucketed_sizes class BucketPadLengthDataset(BaseWrapperDataset): @@ -29,42 +30,43 @@ def __init__( num_buckets, pad_idx, left_pad, + tensor_key=None, ): super().__init__(dataset) self.pad_idx = pad_idx 
self.left_pad = left_pad assert num_buckets > 0 - self.buckets = np.unique( - np.percentile( - sizes, - np.linspace(0, 100, num_buckets + 1), - interpolation="lower", - )[1:] - ) + self.buckets = get_buckets(sizes, num_buckets) + self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + self._tensor_key = tensor_key - def get_bucketed_sizes(orig_sizes, buckets): - sizes = np.copy(orig_sizes) - assert np.min(sizes) >= 0 - start_val = -1 - for end_val in buckets: - mask = (sizes > start_val) & (sizes <= end_val) - sizes[mask] = end_val - start_val = end_val - return sizes + def _set_tensor(self, item, val): + if self._tensor_key is None: + return val + item[self._tensor_key] = val + return item - self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + def _get_tensor(self, item): + if self._tensor_key is None: + return item + return item[self._tensor_key] - def __getitem__(self, index): - item = self.dataset[index] - bucket_size = self._bucketed_sizes[index] - num_pad = bucket_size - item.size(-1) + def _pad(self, tensor, bucket_size, dim=-1): + num_pad = bucket_size - tensor.size(dim) return F.pad( - item, + tensor, (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), value=self.pad_idx, ) + def __getitem__(self, index): + item = self.dataset[index] + bucket_size = self._bucketed_sizes[index] + tensor = self._get_tensor(item) + padded = self._pad(tensor, bucket_size) + return self._set_tensor(item, padded) + @property def sizes(self): return self._bucketed_sizes diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 6f7561afbe..01c743c3e8 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -524,3 +524,25 @@ def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor: return ~lengths_to_padding_mask(lens) + + +def get_buckets(sizes, num_buckets): + buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, 
num_buckets + 1), + interpolation='lower', + )[1:] + ) + return buckets + + +def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 710ca18628..970b784915 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -281,7 +281,6 @@ def distributed_init(cfg: FairseqConfig): cfg.distributed_training.device_id = xm.get_local_ordinal() cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers - xm.mark_step() if is_master(cfg.distributed_training): logging.getLogger().setLevel(logging.INFO) @@ -357,7 +356,10 @@ def call_main(cfg: FairseqConfig, main, **kwargs): xmp.spawn( fn=distributed_main, args=(main, cfg, kwargs), - nprocs=8, # use all 8 TPU cores + # tpu-comment: + # 8 devices in one TPU VM, is the max processes to be spawned. 
+ # The rest is driven by xm.distributed.xla_dist + nprocs=min(cfg.distributed_training.distributed_world_size, 8), ) else: # single GPU main diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 7b56e31592..2bb1da086f 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -286,3 +286,11 @@ def load_state_dict(state_dict): for name, agg_state in state_dict.items(): _aggregators[name] = MetersDict() _aggregators[name].load_state_dict(agg_state) + + +def xla_metrics_report(): + try: + import torch_xla.debug.metrics as met + print(met.metrics_report()) + except ImportError: + return diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 644add7b17..6999dca2d9 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -26,7 +26,7 @@ TransposeLast, ) from fairseq.modules.transformer_sentence_encoder import init_bert_params -from fairseq.utils import buffered_arange +from fairseq.utils import buffered_arange, index_put, is_xla_tensor EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) @@ -330,47 +330,52 @@ def build_model(cls, cfg: Wav2Vec2Config, task=None): return cls(cfg) - def apply_mask(self, x, padding_mask): + def apply_mask( + self, x, padding_mask, + mask_indices=None, mask_channel_indices=None, + ): B, T, C = x.shape if self.mask_prob > 0: - mask_indices = compute_mask_indices( - (B, T), - padding_mask, - self.mask_prob, - self.mask_length, - self.mask_selection, - self.mask_other, - min_masks=2, - no_overlap=self.no_mask_overlap, - min_space=self.mask_min_space, - ) - mask_indices = torch.from_numpy(mask_indices).to(x.device) - x[mask_indices] = self.mask_emb + if mask_indices is None: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = 
torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) else: mask_indices = None if self.mask_channel_prob > 0: - mask_channel_indices = compute_mask_indices( - (B, C), - None, - self.mask_channel_prob, - self.mask_channel_length, - self.mask_channel_selection, - self.mask_channel_other, - no_overlap=self.no_mask_channel_overlap, - min_space=self.mask_channel_min_space, - ) - mask_channel_indices = ( - torch.from_numpy(mask_channel_indices) - .to(x.device) - .unsqueeze(1) - .expand(-1, T, -1) - ) - x[mask_channel_indices] = 0 + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel_indices, 0) return x, mask_indices - def sample_negatives(self, y, num): + def sample_negatives(self, y, num, padding_count=None): if self.n_negatives == 0 and self.cross_sample_negatives == 0: return y.new(0) @@ -378,8 +383,9 @@ def sample_negatives(self, y, num): bsz, tsz, fsz = y.shape y = y.view(-1, fsz) # BTC => (BxT)C + # FIXME: what happens if padding_count is specified? 
cross_high = tsz * bsz - high = tsz + high = tsz - (padding_count or 0) with torch.no_grad(): assert high > 1, f"{bsz,tsz,fsz}" @@ -436,10 +442,17 @@ def compute_preds(self, x, y, negatives): logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) - logits /= self.logit_temp + logits = logits / self.logit_temp - if neg_is_pos.any(): - logits[1:][neg_is_pos] = float("-inf") + if is_xla_tensor(logits) or neg_is_pos.any(): + fillval = -float(2**30) + if not hasattr(self, '_inftensor'): + self._inftensor = ( + torch.tensor(fillval).to(x.device) + if is_xla_tensor(logits) else + float("-inf") + ) + logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) return logits @@ -458,7 +471,11 @@ def _conv_out_length(input_length, kernel_size, stride): return input_lengths.to(torch.long) - def forward(self, source, padding_mask=None, mask=True, features_only=False): + def forward( + self, source, padding_mask=None, mask=True, features_only=False, + mask_indices=None, mask_channel_indices=None, + padding_count=None, + ): if self.feature_grad_mult > 0: features = self.feature_extractor(source) @@ -509,8 +526,14 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): features = self.project_inp(features) if mask: - x, mask_indices = self.apply_mask(features, padding_mask) - if mask_indices is not None: + x, mask_indices = self.apply_mask( + features, padding_mask, + mask_indices=mask_indices, + mask_channel_indices=mask_channel_indices, + ) + if not is_xla_tensor(x) and mask_indices is not None: + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. 
y = unmasked_features[mask_indices].view( unmasked_features.size(0), -1, unmasked_features.size(-1) ) @@ -537,12 +560,18 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): y = self.project_q(y) if self.negatives_from_everywhere: - neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) - negs, _ = self.sample_negatives(neg_cands, y.size(1)) + neg_cands = self.quantizer( + unmasked_features, produce_targets=False + )["x"] + negs, _ = self.sample_negatives( + neg_cands, y.size(1), padding_count=padding_count, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, y.size(1), padding_count=padding_count, + ) if self.codebook_negatives > 0: cb_negs = self.quantizer.sample_from_codebook( @@ -557,12 +586,20 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): y = self.project_q(y) if self.negatives_from_everywhere: - negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs, _ = self.sample_negatives( + unmasked_features, y.size(1), + padding_count=padding_count, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, y.size(1), padding_count=padding_count, + ) - x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + if not is_xla_tensor(x): + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. 
+ x = x[mask_indices].view(x.size(0), -1, x.size(-1)) if self.target_glu: y = self.target_glu(y) @@ -571,7 +608,9 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): x = self.final_proj(x) x = self.compute_preds(x, y, negs) - result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen} + result = { + "x": x, "padding_mask": padding_mask, "features_pen": features_pen, + } if prob_ppl is not None: result["prob_perplexity"] = prob_ppl @@ -759,11 +798,11 @@ def forward(self, x, padding_mask=None): def extract_features(self, x, padding_mask=None): if padding_mask is not None: - x[padding_mask] = 0 + x = index_put(x, padding_mask, 0) x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) - x += x_conv + x = x + x_conv if not self.layer_norm_first: x = self.layer_norm(x) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index afa51299b6..e8a1d03eb2 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from dataclasses import dataclass, field from omegaconf import MISSING, II, open_dict -from typing import Any +from typing import Optional, Any from fairseq import checkpoint_utils, tasks, utils from fairseq.dataclass import FairseqDataclass @@ -127,7 +127,27 @@ class Wav2Vec2AsrConfig(FairseqDataclass): @dataclass class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): - pass + mask_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + mask_channel_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + conv_feature_layers: Optional[str] = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": ( + "string describing convolutional feature extraction " + "layers in form of a python list that 
contains " + "[(dim, kernel_size, stride), ...]" + ), + }, + ) + encoder_embed_dim: Optional[int] = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) @register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index b7b5429819..df073a1814 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -5,6 +5,7 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. +import logging import os import sys import torch @@ -12,7 +13,7 @@ from argparse import Namespace from dataclasses import dataclass, field from typing import Optional, Any -from omegaconf import MISSING +from omegaconf import MISSING, II from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders from fairseq.dataclass import FairseqDataclass @@ -23,6 +24,9 @@ from ..logging import metrics +logger = logging.getLogger(__name__) + + class LabelEncoder(object): def __init__(self, dictionary): self.dictionary = dictionary @@ -86,6 +90,37 @@ class AudioPretrainingConfig(FairseqDataclass): "adds 'prev_output_tokens' to input and appends eos to target" }, ) + num_batch_buckets: int = field( + default=0, + metadata={ + "help": "number of buckets" + }, + ) + precompute_mask_indices: bool = field( + default=False, + metadata={ + "help": "flag to compute mask indices in data preparation.", + }, + ) + # The following are needed to precompute mask and mask channel indices + # before model's forward. 
+ mask_length: Optional[int] = II("model.mask_length") + mask_prob: Optional[float] = II("model.mask_prob") + mask_selection: Optional[str] = II("model.mask_selection") + mask_other: Optional[float] = II("model.mask_other") + no_mask_overlap: Optional[bool] = II("model.no_mask_overlap") + mask_min_space: Optional[int] = II("model.mask_min_space") + mask_channel_length: Optional[int] = II("model.mask_channel_length") + mask_channel_prob: Optional[float] = II("model.mask_channel_prob") + mask_channel_selection: Optional[str] = II("model.mask_channel_selection") + mask_channel_other: Optional[float] = II("model.mask_channel_other") + no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap") + mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space") + + conv_feature_layers: Optional[str] = II("model.conv_feature_layers") + encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim") + + tpu: bool = II("common.tpu") @register_task("audio_pretraining", dataclass=AudioPretrainingConfig) @@ -117,11 +152,37 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): def load_target_dictionary(self): if self.cfg.labels: - dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + dict_path = os.path.join( + self.cfg.data, f"dict.{self.cfg.labels}.txt" + ) return Dictionary.load(dict_path) return None - def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + def _get_mask_precompute_kwargs(self, cfg): + if self.cfg.precompute_mask_indices or self.cfg.tpu: + args = [ + 'mask_length', + 'mask_prob', + 'mask_selection', + 'mask_other', + 'no_mask_overlap', + 'mask_min_space', + 'mask_channel_length', + 'mask_channel_prob', + 'mask_channel_selection', + 'mask_channel_other', + 'no_mask_channel_overlap', + 'mask_channel_min_space', + 'encoder_embed_dim', + 'conv_feature_layers', + ] + return {arg: cfg[arg] for arg in args} + else: + return {} + + def load_dataset( + self, split: str, task_cfg: 
FairseqDataclass = None, **kwargs + ): data_path = self.cfg.data task_cfg = task_cfg or self.cfg @@ -138,8 +199,20 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): min_sample_size=self.cfg.min_sample_size, pad=task_cfg.labels is not None or task_cfg.enable_padding, normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + compute_mask_indices=( + self.cfg.precompute_mask_indices or self.cfg.tpu + ), + **self._get_mask_precompute_kwargs(task_cfg), ) + if self.cfg.tpu and task_cfg['mask_channel_prob'] == 0.0: + logger.info( + "Pretraining on TPUs may suffer convergence " + "issues when training with `mask_channel_prob` value of " + "0. You may want to set this to a low value close to 0." + ) + if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") with open(label_path, "r") as f: diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 1c4c532dd0..f7897070c7 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -532,6 +532,7 @@ def get_train_iterator( epoch=epoch, combine=combine, data_selector=data_selector, + tpu=self.tpu, ) batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(self.cfg.dataset.train_subset), @@ -684,9 +685,7 @@ def maybe_no_sync(): # before marking step can lead to OOM errors. 
# To handle gradient accumulation use case, we explicitly # mark step here for every forward pass without a backward pass - import torch_xla.core.xla_model as xm - - xm.mark_step() + self._xla_markstep_and_send_to_cpu() if is_dummy_batch: if torch.is_tensor(sample_size): @@ -806,10 +805,10 @@ def maybe_no_sync(): self.set_num_updates(self.get_num_updates() + 1) if self.tpu: - # mark step on TPUs import torch_xla.core.xla_model as xm - xm.mark_step() + # mark step on TPUs + self._xla_markstep_and_send_to_cpu() # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 @@ -825,7 +824,7 @@ def maybe_no_sync(): metrics.log_scalar( "gb_total", gb_total, priority=1600, round=1, weight=0 ) - + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm ) @@ -878,9 +877,7 @@ def valid_step(self, sample, raise_oom=False): """Do forward pass in evaluation mode.""" if self.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous("valid_step") # wait for all workers - xm.mark_step() with torch.no_grad(): self.model.eval() @@ -923,6 +920,8 @@ def valid_step(self, sample, raise_oom=False): ) # log validation stats + if self.tpu: + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) return logging_output @@ -1300,6 +1299,13 @@ def _check_xla_compilation(self): ) self._num_xla_compiles = num_xla_compiles + def _xla_markstep_and_send_to_cpu(self, data=None): + import torch_xla.core.xla_model as xm + xm.mark_step() + if data is not None: + from fairseq.utils import xla_device_to_cpu + return xla_device_to_cpu(data) + def _catalog_shared_params(module, memo=None, prefix=""): if memo is None: diff --git a/fairseq/utils.py b/fairseq/utils.py index d4bf73648b..90bb8369f2 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -110,6 +110,7 @@ def 
_move_to_cuda(tensor): def move_to_cpu(sample): + def _move_to_cpu(tensor): # PyTorch has poor support for half tensors (float16) on CPU. # Move any such tensors to float32. @@ -120,6 +121,17 @@ def _move_to_cpu(tensor): return apply_to_sample(_move_to_cpu, sample) +def move_to_tpu(sample): + + import torch_xla.core.xla_model as xm + device = xm.xla_device() + + def _move_to_tpu(tensor): + return tensor.to(device) + + return apply_to_sample(_move_to_tpu, sample) + + def get_incremental_state( module: MultiheadAttention, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], @@ -289,6 +301,9 @@ def convert_padding_direction( def item(tensor): + # tpu-comment: making this a no-op for xla devices. + if torch.is_tensor(tensor) and tensor.device.type == 'xla': + return tensor.detach() if hasattr(tensor, "item"): return tensor.item() if hasattr(tensor, "__getitem__"): @@ -679,6 +694,27 @@ def tpu_data_loader(itr): ) +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == 'xla' + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + +def xla_device_to_cpu(dat): + import torch_xla.core.xla_model as xm + return xm._maybe_convert_to_cpu(dat) + + class CudaEnvironment(object): def __init__(self): cur_device = torch.cuda.current_device() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 8b5ca89cee..6924dfe5c8 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -125,7 +125,7 @@ def main(cfg: FairseqConfig) -> None: ) ) logger.info( - "max tokens per GPU = {} and batch size per GPU = {}".format( + "max tokens per device = {} and max sentences per device = {}".format( cfg.dataset.max_tokens, cfg.dataset.batch_size, ) @@ 
-139,6 +139,9 @@ def main(cfg: FairseqConfig) -> None: # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) + if cfg.common.tpu: + import torch_xla.core.xla_model as xm + xm.rendezvous("load_checkpoint") # wait for all workers max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() From 25dd9266809429c41833c151c5c09398f9f9ca9e Mon Sep 17 00:00:00 2001 From: Liang Tan <liangtan@fb.com> Date: Wed, 24 Mar 2021 20:28:29 -0700 Subject: [PATCH 528/707] Fix the issue that manifold raises error when reading non-existing file Summary: 1. Fairseq's expected reading behavior is "return None when accessing non-existing file" 2. However, when reading from manifold, manifold raises an error when accessing a non-existing file 3. Here I add a try-except block to bypass the manifold error Reviewed By: myleott Differential Revision: D27300619 fbshipit-source-id: 252606c82e9810516ccfb0705e08297f646b3708 --- fairseq/data/data_utils.py | 8 +++++++- fairseq/tasks/sentence_prediction.py | 19 +++++++++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 01c743c3e8..097ea09f76 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -88,7 +88,13 @@ def load_indexed_dataset( datasets = [] for k in itertools.count(): path_k = path + (str(k) if k > 0 else "") - path_k = indexed_dataset.get_indexed_dataset_to_local(path_k) + try: + path_k = indexed_dataset.get_indexed_dataset_to_local(path_k) + except Exception as e: + if "StorageException: [404] Path not found" in str(e): + logger.warning(f"path_k: {e} not found") + else: + raise e dataset_impl_k = dataset_impl if dataset_impl_k is None: diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 67acf7d377..6732728de9 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -140,12 +140,19 @@ def
get_path(key, split): def make_dataset(key, dictionary): split_path = get_path(key, split) - dataset = data_utils.load_indexed_dataset( - split_path, - dictionary, - self.args.dataset_impl, - combine=combine, - ) + try: + dataset = data_utils.load_indexed_dataset( + split_path, + dictionary, + self.args.dataset_impl, + combine=combine, + ) + except Exception as e: + if "StorageException: [404] Path not found" in str(e): + logger.warning(f"dataset {e} not found") + dataset = None + else: + raise e return dataset input0 = make_dataset("input0", self.source_dictionary) From 06c9cefed73cfd43f9453c616e1b9d3ef63f58cf Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Thu, 25 Mar 2021 15:25:07 -0700 Subject: [PATCH 529/707] update CoVoST2 recipe; pretrained S2T encoder loading Summary: - Update CoVoST2 recipe - Add back label smoothing and increase dropout for En ASR - Update pretrained S2T encoder loading - Omit non-existing pre-training checkpoint (this occurs when we distribute only the ST checkpoint without the ASR checkpoint it leverages for pre-training) Related github issue: https://github.com/pytorch/fairseq/issues/3364 Reviewed By: jmp84 Differential Revision: D27159799 fbshipit-source-id: 3759926d68bd3f41f9f8e5fe5004f508eb24c2c0 --- .../speech_to_text/docs/covost_example.md | 18 ++++---- .../models/speech_to_text/s2t_transformer.py | 46 +++++++++++++++---- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/examples/speech_to_text/docs/covost_example.md b/examples/speech_to_text/docs/covost_example.md index 55cd134c16..16447f041e 100644 --- a/examples/speech_to_text/docs/covost_example.md +++ b/examples/speech_to_text/docs/covost_example.md @@ -32,10 +32,10 @@ We train an En ASR model for encoder pre-training of all ST models: ```bash fairseq-train ${COVOST_ROOT}/en \ --config-yaml config_asr_en.yaml --train-subset train_asr_en --valid-subset dev_asr_en \ - --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 
60000 \ - --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ - --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ - --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 50000 --max-update 60000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --report-accuracy --arch s2t_transformer_s --dropout 0.15 --optimizer adam --lr 2e-3 \ + --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 ``` where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU. @@ -62,10 +62,10 @@ Fr-En as example: ```bash fairseq-train ${COVOST_ROOT}/fr \ --config-yaml config_st_fr_en.yaml --train-subset train_st_fr_en --valid-subset dev_st_fr_en \ - --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 60000 \ - --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ - --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ - --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ + --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-update 30000 --max-tokens 40000 \ # --max-tokens 50000 for en-* + --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \ + --arch s2t_transformer_s --encoder-freezing-updates 1000 --optimizer adam --lr 2e-3 \ + --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} ``` where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by En ASR for faster training and better @@ -97,6 +97,6 @@ Type in WAV/FLAC/OGG audio paths (one per line) after the prompt. 
#### Results | --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | Model | |---|---|---|---|---|---|---|---|---|---|---| -| s2t_transformer_s | 31M | [26.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_transformer_s.pt) | [17.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_transformer_s.pt) | [23.0](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_transformer_s.pt) | [18.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_transformer_s.pt) | [16.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt) | [21.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_transformer_s.pt) | [13.0](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_transformer_s.pt) | [13.2](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_transformer_s.pt) | (<-Download) | +| s2t_transformer_s | 31M | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_transformer_s.pt) | [23.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_transformer_s.pt) | [19.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_transformer_s.pt) | [16.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt) | [21.6](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_transformer_s.pt) | [12.9](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_transformer_s.pt) | [12.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_transformer_s.pt) | (<-Download) | [[Back]](..) 
diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 814924ec97..7480dc7967 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -3,7 +3,9 @@ import logging import math from typing import Dict, List, Optional, Tuple +from pathlib import Path +import torch import torch.nn as nn from fairseq import checkpoint_utils, utils from fairseq.data.data_utils import lengths_to_padding_mask @@ -199,18 +201,28 @@ def add_args(parser): metavar="STR", help="model to take encoder weights from (for initialization)", ) + parser.add_argument( + '--encoder-freezing-updates', + default=None, + type=int, + metavar='N', + help='freeze encoder for first N updates' + ) @classmethod def build_encoder(cls, args): encoder = S2TTransformerEncoder(args) - if getattr(args, "load_pretrained_encoder_from", None): - encoder = checkpoint_utils.load_pretrained_component_from_model( - component=encoder, checkpoint=args.load_pretrained_encoder_from - ) - logger.info( - f"loaded pretrained encoder from: " - f"{args.load_pretrained_encoder_from}" - ) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") return encoder @classmethod @@ -267,6 +279,9 @@ class S2TTransformerEncoder(FairseqEncoder): def __init__(self, args): super().__init__(None) + self.encoder_freezing_updates = args.encoder_freezing_updates + self.num_updates = 0 + self.dropout_module = FairseqDropout( p=args.dropout, module_name=self.__class__.__name__ ) @@ -294,7 +309,7 @@ def __init__(self, args): else: self.layer_norm = None - def 
forward(self, src_tokens, src_lengths): + def _forward(self, src_tokens, src_lengths): x, input_lengths = self.subsample(src_tokens, src_lengths) x = self.embed_scale * x @@ -318,6 +333,14 @@ def forward(self, src_tokens, src_lengths): "src_lengths": [], } + def forward(self, src_tokens, src_lengths): + if self.num_updates < self.encoder_freezing_updates: + with torch.no_grad(): + x = self._forward(src_tokens, src_lengths) + else: + x = self._forward(src_tokens, src_lengths) + return x + def reorder_encoder_out(self, encoder_out, new_order): new_encoder_out = ( [] if len(encoder_out["encoder_out"]) == 0 @@ -348,6 +371,10 @@ def reorder_encoder_out(self, encoder_out, new_order): "src_lengths": [], # B x 1 } + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self.num_updates = num_updates + class TransformerDecoderScriptable(TransformerDecoder): def extract_features( @@ -373,6 +400,7 @@ def extract_features( @register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer") def base_architecture(args): + args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) # Convolutional subsampler args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") args.conv_channels = getattr(args, "conv_channels", 1024) From 2d5eaa0e7b795633b7efd2645f26ab40b9b520d5 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Thu, 25 Mar 2021 16:04:59 -0700 Subject: [PATCH 530/707] restore original module forward() during jit Summary: checkpoint_wrapper replaces m.forward with functools.partial(), but functools.partial is not compatible with jit, so we need a way to restore the original forward function for jit to pick it up correctly. I decided to add one extra attribute, 'precheckpoint_forward', to the module; it is then up to the jit workflow to restore this 'precheckpoint_forward' function to the original forward location for jitting. 
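A minimal sketch of the save-and-restore idea described above, using a plain Python class in place of a torch.nn.Module so that it runs without PyTorch. The names `TinyModule`, `_wrapped_forward`, `wrap`, and `restore_forward` are hypothetical and not part of fairseq; the real checkpoint_wrapper routes the call through activation checkpointing instead of calling the original forward directly, and how a jit workflow actually restores forward is an assumption here.

```python
import functools


class TinyModule:
    """Hypothetical stand-in for an nn.Module; only forward() matters here."""

    def forward(self, x):
        return x * 2


def _wrapped_forward(original_forward, offload_to_cpu, *args, **kwargs):
    # Placeholder for the checkpointing logic: simply call through.
    return original_forward(*args, **kwargs)


def wrap(m, offload_to_cpu=False):
    # Refuse to wrap twice; a second call would overwrite the saved
    # forward with the already-wrapped functools.partial object.
    assert not hasattr(m, "precheckpoint_forward"), "already wrapped"
    m.precheckpoint_forward = m.forward
    m.forward = functools.partial(
        _wrapped_forward, m.precheckpoint_forward, offload_to_cpu
    )
    return m


def restore_forward(m):
    # What a jit workflow could do before scripting (assumed, not fairseq API):
    # put the saved forward back and drop the extra attribute.
    if hasattr(m, "precheckpoint_forward"):
        m.forward = m.precheckpoint_forward
        del m.precheckpoint_forward
    return m
```

After `restore_forward(m)`, `m.forward` is a bound method again rather than a `functools.partial` object, and calling `wrap(m)` on an already wrapped module fails the assertion.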
Also added a check that checkpoint_wrapper can't be called twice on the same module. Reviewed By: myleott Differential Revision: D27303905 fbshipit-source-id: 75bf3b5858ed130598dcc21d85a48b43c661818c --- fairseq/modules/checkpoint_activations.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index ae07dcfaa0..f4a277f349 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -8,7 +8,6 @@ import torch import torch.utils.checkpoint as checkpoint - from fairseq import utils @@ -26,9 +25,14 @@ def checkpoint_wrapper(m, offload_to_cpu=False): checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True) a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) """ + # should I check whether original_forward has already been set? + assert not hasattr( + m, "precheckpoint_forward" + ), "checkpoint function has already been applied?" 
+ m.precheckpoint_forward = m.forward m.forward = functools.partial( _checkpointed_forward, - m.forward, # original_forward + m.precheckpoint_forward, # original_forward offload_to_cpu, ) return m From 6e91e226441fc3c68adf91bdef9f39d9d9dc2c9c Mon Sep 17 00:00:00 2001 From: Ning Dong <dnn@fb.com> Date: Thu, 25 Mar 2021 23:29:35 -0700 Subject: [PATCH 531/707] Script monotonic attention and related modules Summary: Add types and rewrite some part of the model so JIT compiler likes it Reviewed By: jmp84, sravyapopuri388 Differential Revision: D27194261 fbshipit-source-id: 594532212c907ed97fc711e4f6a2a211d7e2b67e --- .../models/transformer_monotonic_attention.py | 84 ++-- .../modules/fixed_pre_decision.py | 68 +++- .../modules/monotonic_multihead_attention.py | 374 ++++++++++++++++-- 3 files changed, 463 insertions(+), 63 deletions(-) diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index 65c12c6f5b..d7aeca5ea5 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Dict, List, NamedTuple, Optional + import torch import torch.nn as nn import torch.nn.functional as F @@ -23,10 +25,22 @@ transformer_vaswani_wmt_en_de_big, transformer_vaswani_wmt_en_fr_big, ) +from torch import Tensor DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 +TransformerMonotonicDecoderOut = NamedTuple( + "TransformerMonotonicDecoderOut", + [ + ("action", int), + ("attn_list", Optional[List[Optional[Dict[str, Tensor]]]]), + ("step_list", Optional[List[Optional[Tensor]]]), + ("encoder_out", Optional[Dict[str, List[Tensor]]]), + ("encoder_padding_mask", Optional[Tensor]), + ], +) + @register_model("transformer_unidirectional") class TransformerUnidirectionalModel(TransformerModel): @@ -103,7 +117,10 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): ) def pre_attention( - self, prev_output_tokens, encoder_out_dict, incremental_state=None + self, + prev_output_tokens, + encoder_out_dict: Dict[str, List[Tensor]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): positions = ( self.embed_positions( @@ -118,7 +135,6 @@ def pre_attention( prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] - # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) @@ -143,7 +159,7 @@ def pre_attention( return x, encoder_out, encoder_padding_mask def post_attention(self, x): - if self.layer_norm: + if self.layer_norm is not None: x = self.layer_norm(x) # T x B x C -> B x T x C @@ -154,7 +170,11 @@ def post_attention(self, x): return x - def clear_cache(self, incremental_state, end_id=None): + def clear_cache( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + end_id: Optional[int] = None, + ): """ Clear cache in the monotonic layers. The cache is generated because of a forward pass of decode but no prediction. 
@@ -163,11 +183,18 @@ def clear_cache(self, incremental_state, end_id=None): if end_id is None: end_id = len(self.layers) - for j in range(end_id): - self.layers[j].prune_incremental_state(incremental_state) + for index, layer in enumerate(self.layers): + if index < end_id: + layer.prune_incremental_state(incremental_state) def extract_features( - self, prev_output_tokens, encoder_out, incremental_state=None, **unused + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, # unused + alignment_layer: Optional[int] = None, # unused + alignment_heads: Optional[int] = None, # unsed ): """ Similar to *forward* but only return features. @@ -178,13 +205,14 @@ def extract_features( - a dictionary with any model-specific outputs """ # incremental_state = None + assert encoder_out is not None (x, encoder_outs, encoder_padding_mask) = self.pre_attention( prev_output_tokens, encoder_out, incremental_state ) attn = None inner_states = [x] - attn_list = [] - step_list = [] + attn_list: List[Optional[Dict[str, Tensor]]] = [] + step_list: List[Optional[Tensor]] = [] for i, layer in enumerate(self.layers): @@ -204,35 +232,43 @@ def extract_features( if incremental_state is not None: curr_steps = layer.get_head_steps(incremental_state) step_list.append(curr_steps) - - if incremental_state.get("online", True): + if_online = incremental_state["online"]["only"] + assert if_online is not None + if if_online.to(torch.bool): # Online indicates that the encoder states are still changing + assert attn is not None + assert curr_steps is not None p_choose = ( - attn["p_choose"] - .squeeze(0) - .squeeze(1) - .gather(1, curr_steps.t()) + attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t()) ) new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps) + src = incremental_state["steps"]["src"] + assert src is not None - if (new_steps >= 
incremental_state["steps"]["src"]).any(): + if (new_steps >= src).any(): # We need to prune the last self_attn saved_state # if model decide not to read # otherwise there will be duplicated saved_state self.clear_cache(incremental_state, i + 1) - return x, {"action": 0} + return x, TransformerMonotonicDecoderOut( + action=0, + attn_list=None, + step_list=None, + encoder_out=None, + encoder_padding_mask=None, + ) x = self.post_attention(x) - return x, { - "action": 1, - "attn_list": attn_list, - "step_list": step_list, - "encoder_out": encoder_out, - "encoder_padding_mask": encoder_padding_mask, - } + return x, TransformerMonotonicDecoderOut( + action=1, + attn_list=attn_list, + step_list=step_list, + encoder_out=encoder_out, + encoder_padding_mask=encoder_padding_mask, + ) def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py index cc5e7ad532..0e9dfb6dfd 100644 --- a/examples/simultaneous_translation/modules/fixed_pre_decision.py +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -1,6 +1,7 @@ from functools import partial import torch +from torch import Tensor import math import torch.nn.functional as F @@ -10,7 +11,7 @@ MonotonicMultiheadAttentionHardAligned, MonotonicMultiheadAttentionInfiniteLookback, ) - +from typing import Dict, Optional def fixed_pooling_monotonic_attention(monotonic_attention): def create_model(monotonic_attention, klass): @@ -80,7 +81,7 @@ def add_args(parser): def insert_zeros(self, x): bsz_num_heads, tgt_len, src_len = x.size() stride = self.pre_decision_ratio - weight = F.pad(x.new_ones(1, 1, 1), (stride - 1, 0)) + weight = F.pad(torch.ones(1, 1, 1).to(x), (stride - 1, 0)) x_upsample = F.conv_transpose1d( x.view(-1, src_len).unsqueeze(1), weight, @@ -89,25 +90,64 @@ def insert_zeros(self, x): ) 
return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1) + def p_choose_waitk( + self, query, key, key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None + ): + """ + query: bsz, tgt_len + key: bsz, src_len + key_padding_mask: bsz, src_len + """ + if incremental_state is not None: + # Retrieve target length from incremental states + # For inference the length of query is always 1 + tgt = incremental_state["steps"]["tgt"] + assert tgt is not None + tgt_len = int(tgt) + else: + tgt_len, bsz, _ = query.size() + + src_len, bsz, _ = key.size() + + p_choose = torch.ones(bsz, tgt_len, src_len).to(query) + p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) + p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) + + if incremental_state is not None: + p_choose = p_choose[:, -1:] + tgt_len = 1 + + # Extend to each head + p_choose = ( + p_choose.contiguous() + .unsqueeze(1) + .expand(-1, self.num_heads, -1, -1) + .contiguous() + .view(-1, tgt_len, src_len) + ) + + return p_choose + def p_choose( self, - query, - key, - key_padding_mask=None, - incremental_state=None, - **extra_args + query: Optional[Tensor], + key: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): + assert key is not None + assert query is not None src_len = key.size(0) tgt_len = query.size(0) batch_size = query.size(1) if self.pre_decision_ratio == 1: - return super().p_choose( + return self.p_choose_waitk( query, key, - key_padding_mask=None, - incremental_state=None, - **extra_args + key_padding_mask, + incremental_state=incremental_state, ) key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2) @@ -133,7 +173,7 @@ def p_choose( if key_padding_mask_pool is not None: key_padding_mask_pool = key_padding_mask_pool[:-1] - p_choose_pooled = super().p_choose( + p_choose_pooled = self.p_choose_waitk( query, 
key_pool, key_padding_mask_pool, @@ -148,11 +188,11 @@ def p_choose( p_choose = torch.cat( [ p_choose, - p_choose.new_zeros( + torch.zeros( p_choose.size(0), tgt_len, src_len - p_choose.size(-1) - ) + ).to(p_choose) ], dim=2 ) diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 49882afcd8..2e3ce8742f 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -6,6 +6,7 @@ import math import torch +from torch import Tensor import torch.nn as nn import torch.nn.functional as F @@ -19,6 +20,7 @@ from fairseq.utils import convert_padding_direction from . import register_monotonic_attention +from typing import Dict, Optional @with_incremental_state @@ -101,13 +103,13 @@ def attn_energy( if key_padding_mask is not None: attn_energy = attn_energy.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).bool(), + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), ) return attn_energy - def expected_alignment_train(self, p_choose, key_padding_mask): + def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]): """ Calculating expected alignment for MMA Mask is not need because p_choose will be 0 if masked @@ -175,7 +177,7 @@ def expected_alignment_train(self, p_choose, key_padding_mask): return alpha def expected_alignment_infer( - self, p_choose, encoder_padding_mask, incremental_state + self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] ): # TODO modify this function """ @@ -201,6 +203,7 @@ def expected_alignment_infer( "head_step", p_choose.new_zeros([bsz, self.num_heads]).long() ) + assert prev_monotonic_step is not None bsz, num_heads = prev_monotonic_step.size() assert num_heads == self.num_heads assert bsz * num_heads == 
bsz_num_heads @@ -292,16 +295,14 @@ def expected_alignment_infer( return alpha - def _get_monotonic_buffer(self, incremental_state): - return utils.get_incremental_state( - self, + def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): + return self.get_incremental_state( incremental_state, 'monotonic', ) or {} - def _set_monotonic_buffer(self, incremental_state, buffer): - utils.set_incremental_state( - self, + def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]): + self.set_incremental_state( incremental_state, 'monotonic', buffer, @@ -312,8 +313,8 @@ def v_proj_output(self, value): def forward( self, query, key, value, - key_padding_mask=None, attn_mask=None, incremental_state=None, - need_weights=True, static_kv=False, *args, **kwargs + key_padding_mask=None, attn_mask=None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights=True, static_kv=False ): tgt_len, bsz, embed_dim = query.size() @@ -384,7 +385,328 @@ def __init__(self, args): self.q_in_proj = {"monotonic": self.q_proj} self.v_in_proj = {"output": self.v_proj} - def input_projections(self, query, key, value, name): + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--no-mass-preservation', action="store_false", + dest="mass_preservation", + help='Do not stay on the last token when decoding') + parser.add_argument('--mass-preservation', action="store_true", + dest="mass_preservation", + help='Stay on the last token when decoding') + parser.set_defaults(mass_preservation=True) + parser.add_argument('--noise-var', type=float, default=1.0, + help='Variance of discretness noise') + parser.add_argument('--noise-mean', type=float, default=0.0, + help='Mean of discretness noise') + parser.add_argument('--noise-type', type=str, default="flat", + help='Type of discretness noise') + parser.add_argument('--energy-bias', 
action="store_true", + default=False, + help='Bias for energy') + parser.add_argument('--energy-bias-init', type=float, default=-2.0, + help='Initial value of the bias for energy') + parser.add_argument('--attention-eps', type=float, default=1e-6, + help='Epsilon when calculating expected attention') + + def p_choose(self, *args): + raise NotImplementedError + + def attn_energy( + self, q_proj: Optional[Tensor], k_proj: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None + ): + """ + Calculating monotonic energies + + ============================================================ + Expected input size + q_proj: bsz * num_heads, tgt_len, self.head_dim + k_proj: bsz * num_heads, src_len, self.head_dim + key_padding_mask: bsz, src_len + attn_mask: tgt_len, src_len + """ + assert q_proj is not None # Optional[Tensor] annotations in the signature above are to make the JIT compiler happy + assert k_proj is not None + bsz, tgt_len, embed_dim = q_proj.size() + bsz = bsz // self.num_heads + src_len = k_proj.size(1) + + attn_energy = ( + torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + ) + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_energy += attn_mask + + attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len) + + if key_padding_mask is not None: + attn_energy = attn_energy.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + + return attn_energy + + def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]): + """ + Calculating expected alignment for MMA + Mask is not need because p_choose will be 0 if masked + + q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} + a_ij = p_ij q_ij + + Parallel solution: + ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) + + ============================================================ + Expected input size + p_choose: bsz * num_heads, tgt_len, src_len + """ + + # p_choose: bsz * 
num_heads, tgt_len, src_len + bsz_num_heads, tgt_len, src_len = p_choose.size() + + # cumprod_1mp : bsz * num_heads, tgt_len, src_len + cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps) + cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0) + + init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len]) + init_attention[:, :, 0] = 1.0 + + previous_attn = [init_attention] + + for i in range(tgt_len): + # p_choose: bsz * num_heads, tgt_len, src_len + # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len + # previous_attn[i]: bsz * num_heads, 1, src_len + # alpha_i: bsz * num_heads, src_len + alpha_i = ( + p_choose[:, i] + * cumprod_1mp[:, i] + * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1) + ).clamp(0, 1.0) + previous_attn.append(alpha_i.unsqueeze(1)) + + # alpha: bsz * num_heads, tgt_len, src_len + alpha = torch.cat(previous_attn[1:], dim=1) + + if self.mass_preservation: + # Last token has the residual probabilities + if key_padding_mask is not None and key_padding_mask[:, -1].any(): + # right padding + batch_size = key_padding_mask.size(0) + residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0) + src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True) + src_lens = src_lens.expand( + batch_size, self.num_heads + ).contiguous().view(-1, 1) + src_lens = src_lens.expand(-1, tgt_len).contiguous() + # add back the last value + residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1) + alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals) + else: + residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) + alpha[:, :, -1] = residuals + + if torch.isnan(alpha).any(): + # Something is wrong + raise RuntimeError("NaN in alpha.") + + return alpha + + def expected_alignment_infer( + self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ): + # TODO modify this function + """ + Calculating mo alignment for MMA during 
inference time + + ============================================================ + Expected input size + p_choose: bsz * num_heads, tgt_len, src_len + incremental_state: dict + encodencoder_padding_mask: bsz * src_len + """ + # p_choose: bsz * self.num_heads, src_len + bsz_num_heads, tgt_len, src_len = p_choose.size() + # One token at a time + assert tgt_len == 1 + p_choose = p_choose[:, 0, :] + + monotonic_cache = self._get_monotonic_buffer(incremental_state) + + # prev_monotonic_step: bsz, num_heads + bsz = bsz_num_heads // self.num_heads + prev_monotonic_step = monotonic_cache.get( + "head_step", + p_choose.new_zeros([bsz, self.num_heads]).long() + ) + assert prev_monotonic_step is not None + bsz, num_heads = prev_monotonic_step.size() + assert num_heads == self.num_heads + assert bsz * num_heads == bsz_num_heads + + # p_choose: bsz, num_heads, src_len + p_choose = p_choose.view(bsz, num_heads, src_len) + + if encoder_padding_mask is not None: + src_lengths = src_len - \ + encoder_padding_mask.sum(dim=1, keepdim=True).long() + else: + src_lengths = torch.ones(bsz, 1).to(prev_monotonic_step) * src_len + + # src_lengths: bsz, num_heads + src_lengths = src_lengths.expand_as(prev_monotonic_step) + # new_monotonic_step: bsz, num_heads + new_monotonic_step = prev_monotonic_step + + step_offset = torch.tensor(0) + if encoder_padding_mask is not None: + if encoder_padding_mask[:, 0].any(): + # left_pad_source = True: + step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True) + + max_steps = src_lengths - 1 if self.mass_preservation else src_lengths + + # finish_read: bsz, num_heads + finish_read = new_monotonic_step.eq(max_steps) + p_choose_i = torch.tensor(1) + while finish_read.sum().item() < bsz * self.num_heads: + # p_choose: bsz * self.num_heads, src_len + # only choose the p at monotonic steps + # p_choose_i: bsz , self.num_heads + p_choose_i = ( + p_choose.gather( + 2, + (step_offset + new_monotonic_step) + .unsqueeze(2) + .clamp(0, src_len - 1), + ) + 
).squeeze(2) + + action = ( + (p_choose_i < 0.5) + .type_as(prev_monotonic_step) + .masked_fill(finish_read, 0) + ) + # 1 x bsz + # sample actions on unfinished seq + # 1 means stay, finish reading + # 0 means leave, continue reading + # dist = torch.distributions.bernoulli.Bernoulli(p_choose) + # action = dist.sample().type_as(finish_read) * (1 - finish_read) + + new_monotonic_step += action + + finish_read = new_monotonic_step.eq(max_steps) | (action == 0) + + monotonic_cache["head_step"] = new_monotonic_step + # Whether a head is looking for new input + monotonic_cache["head_read"] = ( + new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5) + ) + + # alpha: bsz * num_heads, 1, src_len + # new_monotonic_step: bsz, num_heads + alpha = ( + p_choose + .new_zeros([bsz * self.num_heads, src_len]) + .scatter( + 1, + (step_offset + new_monotonic_step) + .view(bsz * self.num_heads, 1).clamp(0, src_len - 1), + 1 + ) + ) + + if not self.mass_preservation: + alpha = alpha.masked_fill( + (new_monotonic_step == max_steps) + .view(bsz * self.num_heads, 1), + 0 + ) + + alpha = alpha.unsqueeze(1) + + self._set_monotonic_buffer(incremental_state, monotonic_cache) + + return alpha + + def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): + maybe_incremental_state = self.get_incremental_state( + incremental_state, + 'monotonic', + ) + if maybe_incremental_state is None: + typed_empty_dict: Dict[str, Optional[Tensor]] = {} + return typed_empty_dict + else: + return maybe_incremental_state + + def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]): + self.set_incremental_state( + incremental_state, + 'monotonic', + buffer, + ) + + def forward( + self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, 
Optional[Tensor]]]] = None, + need_weights: bool = True, static_kv: bool = False, need_head_weights: bool = False, + ): + assert query is not None + assert value is not None + tgt_len, bsz, embed_dim = query.size() + src_len = value.size(0) + + # stepwise prob + # p_choose: bsz * self.num_heads, tgt_len, src_len + p_choose = self.p_choose( + query, key, key_padding_mask, incremental_state, + ) + + # expected alignment alpha + # bsz * self.num_heads, tgt_len, src_len + if incremental_state is not None: + alpha = self.expected_alignment_infer( + p_choose, key_padding_mask, incremental_state) + else: + alpha = self.expected_alignment_train( + p_choose, key_padding_mask) + + # expected attention beta + # bsz * self.num_heads, tgt_len, src_len + beta = self.expected_attention( + alpha, query, key, value, + key_padding_mask, attn_mask, + incremental_state + ) + + attn_weights = beta + + v_proj = self.v_proj_output(value) + assert v_proj is not None + + attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) + + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + + attn = self.out_proj(attn) + + beta = beta.view(bsz, self.num_heads, tgt_len, src_len) + alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) + p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) + + return attn, { + "alpha": alpha, + "beta": beta, + "p_choose": p_choose, + } + + def input_projections(self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], name: str): """ Prepare inputs for multihead attention @@ -398,7 +720,7 @@ def input_projections(self, query, key, value, name): if query is not None: bsz = query.size(1) - q = self.q_in_proj[name](query) + q = self.q_proj(query) q *= self.scaling q = q.contiguous().view( -1, bsz * self.num_heads, self.head_dim @@ -408,7 +730,7 @@ def input_projections(self, query, key, value, name): if key is not None: bsz = key.size(1) - k = self.k_in_proj[name](key) + k = self.k_proj(key) k = k.contiguous().view( 
-1, bsz * self.num_heads, self.head_dim ).transpose(0, 1) @@ -417,7 +739,7 @@ def input_projections(self, query, key, value, name): if value is not None: bsz = value.size(1) - v = self.v_in_proj[name](value) + v = self.v_proj(value) v = v.contiguous().view( -1, bsz * self.num_heads, self.head_dim ).transpose(0, 1) @@ -427,8 +749,8 @@ def input_projections(self, query, key, value, name): return q, k, v def p_choose( - self, query, key, key_padding_mask=None, - incremental_state=None, *extra_args + self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None ): """ Calculating step wise prob for reading and writing @@ -507,8 +829,8 @@ def init_soft_attention(self): nn.init.xavier_uniform_(self.q_in_proj["soft"].weight) def expected_attention( - self, alpha, query, key, value, - key_padding_mask, attn_mask, incremental_state + self, alpha, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], + key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] ): # monotonic attention, we will calculate milk here bsz_x_num_heads, tgt_len, src_len = alpha.size() @@ -524,7 +846,9 @@ def expected_attention( if incremental_state is not None: monotonic_cache = self._get_monotonic_buffer(incremental_state) - monotonic_length = monotonic_cache["head_step"] + 1 + head_step = monotonic_cache["head_step"] + assert head_step is not None + monotonic_length = head_step + 1 step_offset = 0 if key_padding_mask is not None: if key_padding_mask[:, 0].any(): @@ -536,7 +860,7 @@ def expected_attention( soft_energy.size(2), 1 ).unsqueeze(1) - soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf")) + soft_energy = soft_energy.masked_fill(~mask.to(torch.bool), float("-inf")) soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] exp_soft_energy = torch.exp(soft_energy) 
exp_soft_energy_sum = exp_soft_energy.sum(dim=2) @@ -557,7 +881,7 @@ def expected_attention( if key_padding_mask is not None: beta = beta.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).bool(), 0) + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 0) beta = beta / beta.sum(dim=3, keepdim=True) beta = beta.view(bsz * self.num_heads, tgt_len, src_len) @@ -595,8 +919,8 @@ def add_args(parser): ) def p_choose( - self, query, key, key_padding_mask=None, - incremental_state=None, *extra_args + self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): """ query: bsz, tgt_len @@ -612,7 +936,7 @@ def p_choose( src_len, bsz, _ = key.size() - p_choose = query.new_ones(bsz, tgt_len, src_len) + p_choose = torch.ones(bsz, tgt_len, src_len).to(query) p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) From 93bb1905ac56a687d9a550537704cb339b226a9e Mon Sep 17 00:00:00 2001 From: Ning Dong <dnn@fb.com> Date: Thu, 25 Mar 2021 23:29:35 -0700 Subject: [PATCH 532/707] Script SimulConvTransformerModel Summary: Copy & paste forward function from parent classes to ConvTransformerEmformerEncoder / TransformerMonotonicDecoderLayer, and modify slightly to make JIT compiler happy. TS in general doesn't work well with polymorphism+inheritance. 
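For illustration only, here is a minimal sketch of the kind of adjustment the JIT compiler tends to require. It is not code from this diff: `HeadStepReader` is a hypothetical module, and only the annotate-then-assert pattern mirrors what the patch does. The assumption is that TorchScript must see an `Optional[Tensor]` narrowed to `Tensor` before the value is used in arithmetic.

```python
from typing import Dict, Optional

import torch
from torch import Tensor


class HeadStepReader(torch.nn.Module):
    # Hypothetical module, not part of fairseq.
    def forward(self, cache: Dict[str, Optional[Tensor]]) -> Tensor:
        head_step = cache["head_step"]
        # TorchScript refines Optional[Tensor] to Tensor after this assert;
        # without it, `head_step + 1` is rejected at compile time.
        assert head_step is not None
        return head_step + 1


scripted = torch.jit.script(HeadStepReader())
out = scripted({"head_step": torch.zeros(2)})
```

Binding the dictionary lookup to a local name before asserting matters here, since the refinement applies to the local variable rather than to the subscript expression.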
Reviewed By: jmp84, sravyapopuri388 Differential Revision: D27194275 fbshipit-source-id: 5248f017ac86adcb09f398038c10a0d20bc03453 --- .../models/convtransformer_simul_trans.py | 55 +++++- .../modules/monotonic_transformer_layer.py | 168 ++++++++++++++++-- 2 files changed, 207 insertions(+), 16 deletions(-) diff --git a/examples/simultaneous_translation/models/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py index 0b15e93fea..4a26422f65 100644 --- a/examples/simultaneous_translation/models/convtransformer_simul_trans.py +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -20,8 +20,10 @@ SequenceEncoder, AugmentedMemoryConvTransformerEncoder, ) -from fairseq.models.speech_to_text.modules.emformer import emformer_encoder +from torch import nn, Tensor +from typing import Dict, List +from fairseq.models.speech_to_text.modules.emformer import NoSegAugmentedMemoryTransformerEncoderLayer @register_model("convtransformer_simul_trans") class SimulConvTransformerModel(ConvTransformerModel): @@ -97,9 +99,56 @@ def augmented_memory_convtransformer_espnet(args): # ============================================================================ # -@emformer_encoder class ConvTransformerEmformerEncoder(ConvTransformerEncoder): - pass + def __init__(self, args): + super().__init__(args) + stride = self.conv_layer_stride(args) + trf_left_context = args.segment_left_context // stride + trf_right_context = args.segment_right_context // stride + context_config = [trf_left_context, trf_right_context] + self.transformer_layers = nn.ModuleList( + [ + NoSegAugmentedMemoryTransformerEncoderLayer( + input_dim=args.encoder_embed_dim, + num_heads=args.encoder_attention_heads, + ffn_dim=args.encoder_ffn_embed_dim, + num_layers=args.encoder_layers, + dropout_in_attn=args.dropout, + dropout_on_attn=args.dropout, + dropout_on_fc1=args.dropout, + dropout_on_fc2=args.dropout, + activation_fn=args.activation_fn, + 
context_config=context_config, + segment_size=args.segment_length, + max_memory_size=args.max_memory_size, + scaled_init=True, # TODO: use constant for now. + tanh_on_mem=args.amtrf_tanh_on_mem, + ) + ] + ) + self.conv_transformer_encoder = ConvTransformerEncoder(args) + + def forward(self, src_tokens, src_lengths): + encoder_out: Dict[str, List[Tensor]] = self.conv_transformer_encoder(src_tokens, src_lengths.to(src_tokens.device)) + output = encoder_out["encoder_out"][0] + encoder_padding_masks = encoder_out["encoder_padding_mask"] + + return { + "encoder_out": [output], + # This is because that in the original implementation + # the output didn't consider the last segment as right context. + "encoder_padding_mask": [encoder_padding_masks[0][:, : output.size(0)]] if len(encoder_padding_masks) > 0 + else [], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + @staticmethod + def conv_layer_stride(args): + # TODO: make it configurable from the args + return 4 @register_model("convtransformer_emformer") diff --git a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py index e6e1850a18..bcd45aa8a6 100644 --- a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py +++ b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py @@ -7,6 +7,11 @@ from . 
import build_monotonic_attention +from typing import Dict, List, Optional + +import torch +from torch import Tensor + class TransformerMonotonicEncoderLayer(TransformerEncoderLayer): def forward(self, x, encoder_padding_mask): @@ -34,23 +39,160 @@ def __init__( self.embed_dim, export=getattr(args, "char_inputs", False) ) - def get_head_steps(self, incremental_state): + def get_head_steps(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): return self.encoder_attn._get_monotonic_buffer(incremental_state).get( "head_step" ) - def prune_incremental_state(self, incremental_state): - def prune(module): - input_buffer = module._get_input_buffer(incremental_state) - for key in ["prev_key", "prev_value"]: - if input_buffer[key].size(2) > 1: - input_buffer[key] = input_buffer[key][:, :, :-1, :] - else: - input_buffer = {} - break - module._set_input_buffer(incremental_state, input_buffer) - - prune(self.self_attn) + def prune_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): + input_buffer = self.self_attn._get_input_buffer(incremental_state) + for key in ["prev_key", "prev_value"]: + input_buffer_key = input_buffer[key] + assert input_buffer_key is not None + if input_buffer_key.size(2) > 1: + input_buffer[key] = input_buffer_key[:, :, :-1, :] + else: + typed_empty_dict: Dict[str, Optional[Tensor]] = {} + input_buffer = typed_empty_dict + break + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, input_buffer) def get_steps(self, incremental_state): return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0) + + def forward( + self, + x, + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[torch.Tensor]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: 
Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + if need_head_weights: + need_attn = True + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = 
self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + assert self.encoder_attn is not None + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + if self.onnx_trace and incremental_state is not None: + saved_state = self.self_attn._get_input_buffer(incremental_state) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] + return x, attn, 
self_attn_state + return x, attn, None From a28511d43d2310226086c5fee82042fcdcf4cff0 Mon Sep 17 00:00:00 2001 From: Ning Dong <dnn@fb.com> Date: Thu, 25 Mar 2021 23:29:35 -0700 Subject: [PATCH 533/707] Ad-hoc changes Summary: Some ad-hoc changes to address errors when torchscripting simul ST models. Not sure if all of them are necessary and what the best practice is (yet) but having this diff to temporarily unblock. Reviewed By: sravyapopuri388 Differential Revision: D27194569 fbshipit-source-id: 8936f3edb408df7e1a3fd97c0e07b2356ff0d9b4 --- .../utils/data_utils.py | 2 +- .../agents/fairseq_simul_st_agent.py | 2 +- fairseq/data/data_utils.py | 20 +++++++++++-------- .../models/speech_to_text/convtransformer.py | 7 ++++--- .../models/speech_to_text/modules/emformer.py | 3 ++- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/simultaneous_translation/utils/data_utils.py b/examples/simultaneous_translation/utils/data_utils.py index cc4729e63c..a763ea6686 100644 --- a/examples/simultaneous_translation/utils/data_utils.py +++ b/examples/simultaneous_translation/utils/data_utils.py @@ -28,7 +28,7 @@ def apply_mv_norm(features): return res -def lengths_to_encoder_padding_mask(lengths, batch_first=False): +def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False): """ convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 9ff07775da..58e38963b5 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -320,7 +320,7 @@ def policy(self, states): "tgt": 1 + len(states.units.target), } - states.incremental_states["online"] = not states.finish_read() + states.incremental_states["online"] = {"only": torch.tensor(not 
states.finish_read())} x, outputs = self.model.decoder.forward( prev_output_tokens=tgt_indices, diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 097ea09f76..de1d2edf11 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -512,23 +512,27 @@ def arrange(s, e, length, keep_length): def get_mem_usage(): - try: - import psutil + # try: + import psutil - mb = 1024 * 1024 - return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" - except ImportError: - return "N/A" + mb = 1024 * 1024 + return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" + # except ImportError: + # return "N/A" -def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: +# lens: torch.LongTensor +# returns: torch.BoolTensor +def lengths_to_padding_mask(lens): bsz, max_lens = lens.size(0), torch.max(lens).item() mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) return mask -def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor: +# lens: torch.LongTensor +# returns: torch.BoolTensor +def lengths_to_mask(lens): return ~lengths_to_padding_mask(lens) diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index a4cbbcdeeb..40e6dd3f4e 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -303,10 +303,10 @@ def forward(self, src_tokens, src_lengths): x = self.embed_scale * x subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) - + input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long() + input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to(input_len_0.device) input_lengths = torch.min( - (src_lengths.float() / subsampling_factor).ceil().long(), - x.size(0) * 
src_lengths.new_ones([src_lengths.size(0)]).long() + input_len_0, input_len_1 ) encoder_padding_mask = lengths_to_padding_mask(input_lengths) @@ -323,6 +323,7 @@ def forward(self, src_tokens, src_lengths): else: maybe_encoder_padding_mask = encoder_padding_mask + return { "encoder_out": [x], "encoder_padding_mask": [maybe_encoder_padding_mask] diff --git a/fairseq/models/speech_to_text/modules/emformer.py b/fairseq/models/speech_to_text/modules/emformer.py index e026b86847..6ef76bd012 100644 --- a/fairseq/models/speech_to_text/modules/emformer.py +++ b/fairseq/models/speech_to_text/modules/emformer.py @@ -1812,7 +1812,8 @@ def __init__(self, args): def forward(self, src_tokens, src_lengths): encoder_out = super().forward(src_tokens, src_lengths) - (output, encoder_padding_masks, [], _) = encoder_out["encoder_out"][0] + output = encoder_out["encoder_out"][0] + encoder_padding_masks = encoder_out["encoder_padding_mask"][0] # This is because that in the original implementation # the output didn't consider the last segment as right context. From be1d186fa59aa19d7a0735a32af88b5a2bacc5ae Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Fri, 26 Mar 2021 07:18:08 -0700 Subject: [PATCH 534/707] FSDP uses new optimizer gathering to save optimizer state (#1744) Summary: - Full unflattened optimizer state dict is in `checkpoints/shard_0.pt`, other checkpoint files do not have the `last_optimizer_state` key. 
- requires master version of fairscale (eventually fairscale>=0.3.3) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1744 Reviewed By: myleott Differential Revision: D27342305 Pulled By: sshleifer fbshipit-source-id: 7442b8c6ed01599d8ab0050213e84051f4e98acd --- fairseq/checkpoint_utils.py | 2 +- fairseq/data/data_utils.py | 7 +++---- fairseq/trainer.py | 15 ++++++++++++++- scripts/test_fsdp.sh | 17 ++++++++++++++--- 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 5a98dad2aa..86e00a7714 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -42,7 +42,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): if cfg.no_save: return - trainer.consolidate_optimizer() + trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state if not trainer.should_save_checkpoint_on_current_rank: return diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index de1d2edf11..4424d1dc53 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -312,13 +312,12 @@ def batch_by_size( ) except ImportError: raise ImportError( - "Please build Cython components with: `pip install --editable .` " - "or `python setup.py build_ext --inplace`" + "Please build Cython components with: " + "`python setup.py build_ext --inplace`" ) except ValueError: raise ValueError( - "Please build (or rebuild) Cython components with: `pip install " - " --editable .` or `python setup.py build_ext --inplace`." + "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`." 
) # added int() to avoid TypeError: an integer is required diff --git a/fairseq/trainer.py b/fairseq/trainer.py index f7897070c7..4535e9bda7 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -331,9 +331,15 @@ def _build_optimizer(self): def consolidate_optimizer(self): """For OSS, we need to consolidate the state dict.""" + self._gathered_optim_state = None if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() + elif self.cfg.distributed_training.ddp_backend == 'fully_sharded': + self._gathered_optim_state = self.model.gather_full_optim_state_dict(self.optimizer, + recipient_rank=0) + + def state_dict(self): state_dict = { "args": None, # legacy @@ -362,7 +368,11 @@ def state_dict(self): } } if not self.cfg.checkpoint.no_save_optimizer_state: - state_dict["last_optimizer_state"] = self.optimizer.state_dict() + if self._gathered_optim_state is not None: + state_dict["last_optimizer_state"] = self._gathered_optim_state + self._gathered_optim_state = None + else: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() return state_dict def save_checkpoint(self, filename, extra_state): @@ -478,6 +488,9 @@ def load_checkpoint( last_optim_state = self.optimizer.broadcast_global_state_dict( last_optim_state ) + elif self.cfg.distributed_training.ddp_backend == 'fully_sharded': + last_optim_state = self.model.get_shard_from_optim_state_dict(last_optim_state) + self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) self.set_num_updates(last_optim["num_updates"]) diff --git a/scripts/test_fsdp.sh b/scripts/test_fsdp.sh index 0f4d6c420b..1f428a035e 100755 --- a/scripts/test_fsdp.sh +++ b/scripts/test_fsdp.sh @@ -1,13 +1,24 @@ #!/usr/bin/env bash rm -rf fsdp_dummy mkdir -p fsdp_dummy -fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ --ddp-backend fully_sharded --fp16 
--fp16-init-scale 4 \ --cpu-offload --checkpoint-activations \ --task language_modeling --tokens-per-sample 256 --batch-size 8 \ --arch transformer_lm_gpt2_tiny \ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ - --max-update 10 --log-format json --log-interval 1 \ - --save-interval-updates 10 --save-dir fsdp_dummy \ + --max-update 5 --log-format json --log-interval 1 \ + --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \ --restore-file x.pt "$@" + +# Now we try to load the checkpoint +CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 256 --batch-size 8 \ + --arch transformer_lm_gpt2_tiny \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 2 --log-format json --log-interval 1 \ + --save-interval-updates 2 --save-dir fsdp_dummy From 0975816a853eed81059db9ca02ddd3c73ea64926 Mon Sep 17 00:00:00 2001 From: Ning Dong <dnn@fb.com> Date: Fri, 26 Mar 2021 21:59:51 -0700 Subject: [PATCH 535/707] Add back try/except in data_utils Summary: Accidentally removed try/except block in D27194569 (https://github.com/pytorch/fairseq/commit/a28511d43d2310226086c5fee82042fcdcf4cff0) and caused OSS unit test failure. 
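For reference, a minimal sketch of the guarded-import pattern being restored. It is a simplified variant rather than the exact fairseq function: the rounding in the format string and the placement of the `return` outside the `try` block are assumptions made for this example.

```python
def get_mem_usage() -> str:
    """Report host memory usage, or "N/A" when psutil is not installed."""
    try:
        import psutil  # optional dependency, imported lazily
    except ImportError:
        # Without the guard, a missing psutil raises here and breaks
        # environments (such as OSS CI) that do not install it.
        return "N/A"

    mb = 1024 * 1024
    vm = psutil.virtual_memory()
    return f"used={vm.used / mb:.0f}Mb; avail={vm.available / mb:.0f}Mb"


result = get_mem_usage()
print(result)
```

Keeping the import inside the function means the dependency is only needed by callers that actually ask for memory statistics.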
Reviewed By: myleott Differential Revision: D27374072 fbshipit-source-id: 154c92eae106aa3bf6b406d52a8647db140dc6b3 --- fairseq/data/data_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 4424d1dc53..79df6f3769 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -511,13 +511,13 @@ def arrange(s, e, length, keep_length): def get_mem_usage(): - # try: - import psutil + try: + import psutil - mb = 1024 * 1024 - return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" - # except ImportError: - # return "N/A" + mb = 1024 * 1024 + return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" + except ImportError: + return "N/A" # lens: torch.LongTensor From 1c9738c6e9ae2e6612e3f8e4f841d22fbfbfa68c Mon Sep 17 00:00:00 2001 From: Wei-Ning Hsu <31931787+wnhsu@users.noreply.github.com> Date: Mon, 29 Mar 2021 12:09:59 -0700 Subject: [PATCH 536/707] fix speech_recognition hydra decoder bugs (#1742) Summary: ## What does this PR do? 
Bug 1: generated hypothesis and reference transcripts are incomplete when using data_parallel_world_size > 1 Reason: lack of a barrier to ensure all workers finish dumping transcripts before merging starts Bug 2: the program fails when using data_parallel_world_size == 1 Reason: an unnecessary reduce operation is introduced; transcripts do not need to be merged Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1742 Reviewed By: alexeib Differential Revision: D27403362 Pulled By: wnhsu fbshipit-source-id: b74889660c7253264b986ea35c248d80e0e32358 --- examples/speech_recognition/hydra/decoder.py | 2 +- examples/speech_recognition/hydra/infer.py | 52 +++++++++++--------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/examples/speech_recognition/hydra/decoder.py b/examples/speech_recognition/hydra/decoder.py index 41fcbd7087..d182b95a32 100644 --- a/examples/speech_recognition/hydra/decoder.py +++ b/examples/speech_recognition/hydra/decoder.py @@ -250,7 +250,7 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: for token in spelling ] assert tgt_dict.unk() not in spelling_idxs, \ - f"{spelling} {spelling_idxs}" + f"{word} {spelling} {spelling_idxs}" self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) diff --git a/examples/speech_recognition/hydra/infer.py b/examples/speech_recognition/hydra/infer.py index b1c985bc0d..1b49823553 100644 --- a/examples/speech_recognition/hydra/infer.py +++ b/examples/speech_recognition/hydra/infer.py @@ -155,27 +155,29 @@ def merge_shards(self) -> None: shard_id = self.data_parallel_rank num_shards = self.data_parallel_world_size - def merge_shards_with_root(fname: str) -> None: - logger.info("Merging %s on shard %d", fname, shard_id) - base_fpath = Path(f"{fname}.0") - with open(base_fpath, "a") as out_file: - for s in range(1, num_shards): - shard_fpath = Path(f"{fname}.{s}") - with open(shard_fpath, "r") as in_file: - for line in
in_file: - out_file.write(line) - shard_fpath.unlink() - shutil.move(f"{fname}.0", fname) - - if shard_id == (0 % num_shards): - merge_shards_with_root("hypo.word") - if shard_id == (1 % num_shards): - merge_shards_with_root("hypo.units") - if shard_id == (2 % num_shards): - merge_shards_with_root("ref.word") - if shard_id == (3 % num_shards): - merge_shards_with_root("ref.units") - dist.barrier() + if self.data_parallel_world_size > 1: + def merge_shards_with_root(fname: str) -> None: + logger.info("Merging %s on shard %d", fname, shard_id) + base_fpath = Path(f"{fname}.0") + with open(base_fpath, "a") as out_file: + for s in range(1, num_shards): + shard_fpath = Path(f"{fname}.{s}") + with open(shard_fpath, "r") as in_file: + for line in in_file: + out_file.write(line) + shard_fpath.unlink() + shutil.move(f"{fname}.0", fname) + + dist.barrier() # ensure all shards finished writing + if shard_id == (0 % num_shards): + merge_shards_with_root("hypo.word") + if shard_id == (1 % num_shards): + merge_shards_with_root("hypo.units") + if shard_id == (2 % num_shards): + merge_shards_with_root("ref.word") + if shard_id == (3 % num_shards): + merge_shards_with_root("ref.units") + dist.barrier() def optimize_model(self, model: FairseqModel) -> None: gcfg = self.cfg.generation @@ -370,7 +372,7 @@ def main(cfg: InferConfig) -> float: if cfg.common.cpu: logger.warning("Merging WER requires CUDA.") - else: + elif processor.data_parallel_world_size > 1: stats = torch.LongTensor([errs_t, leng_t]).cuda() dist.all_reduce(stats, op=dist.ReduceOp.SUM) errs_t, leng_t = stats[0].item(), stats[1].item() @@ -379,7 +381,11 @@ def main(cfg: InferConfig) -> float: if distributed_utils.is_master(cfg.distributed_training): with open(wer_file, "w") as f: - f.write(f"WER: {wer}\n\n{yaml_str}") + f.write(( + f"WER: {wer}\n" + f"err / num_ref_words = {errs_t} / {leng_t}\n\n" + f"{yaml_str}" + )) return wer From 7dafb05754fe268bb5f76a1c97cf3a14062f44e5 Mon Sep 17 00:00:00 2001 From: Michael Lewis 
<mikelewis@fb.com> Date: Mon, 29 Mar 2021 18:02:07 -0700 Subject: [PATCH 537/707] BASE layers (#1654) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1654 Reviewed By: myleott Differential Revision: D27128074 Pulled By: shruti-bh fbshipit-source-id: ac86d383cd53c9c9bdd946fea839a37b719d95e3 --- fairseq/checkpoint_utils.py | 8 +- fairseq/clib/libbase/balanced_assignment.cpp | 95 ++++++++++++ fairseq/data/data_utils.py | 5 +- fairseq/data/monolingual_dataset.py | 27 +++- .../legacy_distributed_data_parallel.py | 5 + fairseq/distributed/utils.py | 3 + fairseq/models/transformer.py | 5 + fairseq/models/transformer_lm.py | 14 ++ fairseq/modules/__init__.py | 2 + fairseq/modules/base_layer.py | 135 ++++++++++++++++++ fairseq/optim/fp16_optimizer.py | 2 + fairseq/tasks/language_modeling.py | 18 +++ fairseq/trainer.py | 3 +- fairseq/utils.py | 8 +- fairseq_cli/train.py | 13 +- setup.py | 11 ++ 16 files changed, 341 insertions(+), 13 deletions(-) create mode 100644 fairseq/clib/libbase/balanced_assignment.cpp create mode 100644 fairseq/modules/base_layer.py diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 86e00a7714..7e1b8479d1 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -115,7 +115,7 @@ def is_better(a, b): if not
end_of_epoch and cfg.keep_interval_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) ) for old_chk in checkpoints[cfg.keep_interval_updates :]: if os.path.lexists(old_chk): @@ -123,7 +123,7 @@ def is_better(a, b): if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt") + checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)) for old_chk in checkpoints[cfg.keep_last_epochs :]: if os.path.lexists(old_chk): os.remove(old_chk) @@ -132,8 +132,8 @@ def is_better(a, b): # only keep the best N checkpoints according to validation metric checkpoints = checkpoint_paths( cfg.save_dir, - pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( - cfg.best_checkpoint_metric + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix ), ) if not cfg.maximize_best_checkpoint_metric: diff --git a/fairseq/clib/libbase/balanced_assignment.cpp b/fairseq/clib/libbase/balanced_assignment.cpp new file mode 100644 index 0000000000..296f03b6ae --- /dev/null +++ b/fairseq/clib/libbase/balanced_assignment.cpp @@ -0,0 +1,95 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* +C++ code for solving the linear assignment problem. +Based on the Auction Algorithm from https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the implementation from: +https://github.com/bkj/auction-lap +Adapted to be more efficient when each worker is looking for k jobs instead of 1. 
+*/ +#include <torch/extension.h> +#include <iostream> +using namespace torch::indexing; +torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) { + int max_iterations = 100; + torch::Tensor epsilon = (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50; + epsilon.clamp_min_(1e-04); + torch::Tensor worker_and_job_to_score = job_and_worker_to_score.detach().transpose(0,1).contiguous(); + int num_workers = worker_and_job_to_score.size(0); + int num_jobs = worker_and_job_to_score.size(1); + auto device = worker_and_job_to_score.device(); + int jobs_per_worker = num_jobs / num_workers; + torch::Tensor value = worker_and_job_to_score.clone(); + int counter = 0; + torch::Tensor max_value = worker_and_job_to_score.max(); + + torch::Tensor bid_indices; + torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs}); + torch::Tensor bids = worker_and_job_to_score.new_empty({num_workers, num_jobs}); + torch::Tensor bid_increments = worker_and_job_to_score.new_empty({num_workers, jobs_per_worker}); + torch::Tensor top_values = worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1}); + torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs}); + + torch::Tensor top_index = top_values.to(torch::kLong); + torch::Tensor high_bidders = top_index.new_empty({num_jobs}); + torch::Tensor have_bids = high_bidders.to(torch::kBool); + torch::Tensor jobs_indices = torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device)); + torch::Tensor true_tensor = torch::ones({1}, torch::dtype(torch::kBool).device(device)); + + while (true) { + bids.zero_(); + torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1); + + // Each worker bids the difference in value between that job and the k+1th job + torch::sub_out(bid_increments, + top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}), + top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1)); + + bid_increments.add_(epsilon); + 
bids.scatter_(1, + top_index.index({Slice(None, None),Slice(0, jobs_per_worker)}), + bid_increments); + + if (counter < max_iterations && counter > 0) { + // Put in a minimal bid to retain items from the last round if no-one else bids for them this round + bids.view(-1).index_put_({bid_indices}, epsilon); + } + + // Find the highest bidding worker per job + torch::max_out(high_bids, high_bidders, bids, 0); + torch::gt_out(have_bids, high_bids, 0); + + if (have_bids.all().item<bool>()) { + // All jobs were bid for + break; + } + + // Make popular items more expensive + cost.add_(high_bids); + torch::sub_out(value, worker_and_job_to_score, cost); + + bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids}); + + if (counter < max_iterations) { + // Make sure that this item will be in the winning worker's top-k next time. + value.view(-1).index_put_({bid_indices}, max_value); + } + else { + // Suboptimal approximation that converges quickly from current solution + value.view(-1).index_put_({bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices})); + } + + counter += 1; + } + + return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}).reshape(-1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment"); +} diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 79df6f3769..63c7fcd118 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -41,13 +41,16 @@ def collate_tokens( move_eos_to_beginning=False, pad_to_length=None, pad_to_multiple=1, + pad_to_bsz=None, ): """Convert a list of 1d tensors into a padded 2d tensor.""" size = max(v.size(0) for v in values) size = size if pad_to_length is None else max(size, pad_to_length) if pad_to_multiple != 1 and size % pad_to_multiple != 0: size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) - res = values[0].new(len(values), size).fill_(pad_idx) + + batch_size = 
len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz) + res = values[0].new(batch_size, size).fill_(pad_idx) def copy_tensor(src, dst): assert dst.numel() == src.numel() diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index bf7aa86f6c..54fd583b64 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -9,7 +9,7 @@ from . import FairseqDataset, data_utils -def collate(samples, pad_idx, eos_idx): +def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None): if len(samples) == 0: return {} @@ -23,6 +23,8 @@ def merge(key, is_list=False): pad_idx, eos_idx, left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, ) ) return res @@ -32,6 +34,8 @@ def merge(key, is_list=False): pad_idx, eos_idx, left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, ) src_tokens = merge("source") @@ -75,6 +79,10 @@ def __init__( shuffle=False, targets=None, add_bos_token=False, + fixed_pad_length=None, + pad_to_bsz=None, + src_lang_idx=None, + tgt_lang_idx=None, ): self.dataset = dataset self.sizes = np.array(sizes) @@ -83,6 +91,10 @@ def __init__( self.add_eos_for_other_targets = add_eos_for_other_targets self.shuffle = shuffle self.add_bos_token = add_bos_token + self.fixed_pad_length = fixed_pad_length + self.pad_to_bsz = pad_to_bsz + self.src_lang_idx = src_lang_idx + self.tgt_lang_idx = tgt_lang_idx assert targets is None or all( t in {"self", "future", "past"} for t in targets @@ -165,6 +177,11 @@ def _maybe_add_bos(self, source, target): target = torch.cat([target.new([self.tgt_vocab.bos()]), target]) return source, target + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. 
+ This value is used to enforce ``--max-tokens`` during batching.""" + return self.sizes[indices] + def _filter_vocab(self, target): if len(self.tgt_vocab) != len(self.vocab): @@ -200,7 +217,13 @@ def collater(self, samples): target sentence of shape `(bsz, tgt_len)`. Padding will appear on the right. """ - return collate(samples, self.vocab.pad(), self.vocab.eos()) + return collate( + samples, + self.vocab.pad(), + self.vocab.eos(), + self.fixed_pad_length, + self.pad_to_bsz, + ) def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index b586e76b7f..f2308f87c5 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -136,6 +136,11 @@ def reduction_fn(): continue if param.grad is None: param.grad = torch.zeros_like(param) + + if hasattr(param, 'expert'): + # Skip gradient sync for unshared parameters + continue + if param.grad.requires_grad: raise RuntimeError( "DistributedDataParallel only works " diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 970b784915..b09e87fe09 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -306,6 +306,9 @@ def distributed_init(cfg: FairseqConfig): model_part_number = get_model_parallel_rank() cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + if getattr(cfg.model, "base_layers", 0) > 0: + cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}" + return cfg.distributed_training.distributed_rank diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 1e47d102f9..8da5beb3aa 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -19,6 +19,7 @@ ) from fairseq.modules import ( AdaptiveSoftmax, + BaseLayer, FairseqDropout, 
LayerDropModuleList, LayerNorm, @@ -751,6 +752,10 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): nn.init.normal_( self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 ) + num_base_layers = getattr(args, "base_layers", 0) + for i in range(num_base_layers): + self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args)) + def build_decoder_layer(self, args, no_encoder_attn=False): layer = TransformerDecoderLayer(args, no_encoder_attn) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index b616a923d4..a546776912 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -180,6 +180,16 @@ class TransformerLanguageModelConfig(FairseqDataclass): ) } ) + # config for "BASE Layers: Simplifying Training of Large, Sparse Models" + base_layers: Optional[int] = field( + default=0, metadata={"help": "number of BASE layers in total"} + ) + base_sublayers: Optional[int] = field( + default=1, metadata={"help": "number of sublayers in each BASE layer"} + ) + base_shuffle: Optional[int] = field( + default=1, metadata={"help": "shuffle tokens between workers before computing assignment"} + ) # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") @@ -313,6 +323,10 @@ def base_lm_architecture(args): args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + args.base_layers = getattr(args, "base_layers", 0) + args.base_sublayers = getattr(args, "base_sublayers", 1) + args.base_shuffle = getattr(args, "base_shuffle", False) + args.add_bos_token = getattr(args, "add_bos_token", False) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index e2326ac6e3..81930aa71c 
100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -6,6 +6,7 @@ from .adaptive_input import AdaptiveInput from .adaptive_softmax import AdaptiveSoftmax +from .base_layer import BaseLayer from .beamable_mm import BeamableMM from .character_token_embedder import CharacterTokenEmbedder from .conv_tbc import ConvTBC @@ -39,6 +40,7 @@ __all__ = [ "AdaptiveInput", "AdaptiveSoftmax", + "BaseLayer", "BeamableMM", "CharacterTokenEmbedder", "ConvTBC", diff --git a/fairseq/modules/base_layer.py b/fairseq/modules/base_layer.py new file mode 100644 index 0000000000..e7ef155b25 --- /dev/null +++ b/fairseq/modules/base_layer.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch +import sys +from fairseq import utils +from fairseq.distributed import utils as distributed_utils +from fairseq.modules.layer_norm import LayerNorm + + +class BaseLayer(nn.Module): + + def __init__(self, args): + super().__init__() + self.num_workers = distributed_utils.get_data_parallel_world_size() + expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim) + torch.nn.init.orthogonal_(expert_centroids, gain=0.1) + self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids)) + self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)])) + self.expert_id = distributed_utils.get_data_parallel_rank() + self.shuffle = args.base_shuffle + self.cpp = self.load_assignment() + + # Add a special attribute to the expert parameters, so we know not to sync their gradients + for param in self.expert_network.parameters(): + param.expert = True + + def forward(self, input_features, *args, **kwargs): + features = input_features.reshape(-1, input_features.size(-1)) + is_training = input_features.requires_grad + + if self.shuffle and 
is_training: + # Send each token to a random worker, to break correlations within the batch + shuffle_sort = torch.randperm(features.size(0), device=features.device) + features = All2All.apply(features[shuffle_sort]) + + with torch.no_grad(): + # Compute similarity of each token to each expert, for routing + token_expert_affinities = features.matmul(self.expert_centroids.transpose(0, 1)) + + # Compute which token goes to which expert + sort_by_expert, input_splits, output_splits = self.balanced_assignment(token_expert_affinities) \ + if is_training else self.greedy_assignment(token_expert_affinities) + # Swap these tokens for the right ones for our expert + routed_features = All2All.apply(features[sort_by_expert], output_splits, input_splits) + + if routed_features.size(0) > 0: + # Mix in the expert network based on how appropriate it is for these tokens + alpha = torch.sigmoid(routed_features.mv(self.expert_centroids[self.expert_id])).unsqueeze(1) + routed_features = alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features + # Return to original worker and ordering + result = All2All.apply(routed_features, input_splits, output_splits)[self.inverse_sort(sort_by_expert)] + + if self.shuffle and is_training: + # Undo shuffling + result = All2All.apply(result)[self.inverse_sort(shuffle_sort)] + + # Return additional Nones for compatibility with TransformerDecoderLayer + return result.view(input_features.size()), None, None + + def inverse_sort(self, order): + # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)] + return torch.empty_like(order).scatter_(0, order, torch.arange(0, order.size(0), device=order.device)) + + def balanced_assignment(self, scores): + ok = scores.isfinite() + if not ok.all(): + # NaNs here can break the assignment algorithm + scores[~ok] = scores[ok].min() + return self.cpp.balanced_assignment(scores), None, None + + # Assigns each token to the top k experts + def greedy_assignment(self, scores, k=1): + 
token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1) + token_to_workers, sort_ordering = torch.sort(token_to_workers) + worker2token = sort_ordering // k + + # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers) + output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device) + workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True) + output_splits[workers] = counts + # Tell other workers how many tokens to expect from us + input_splits = All2All.apply(output_splits) + return worker2token, input_splits.tolist(), output_splits.tolist() + + def load_assignment(self): + try: + from fairseq import libbase + + return libbase + + except ImportError as e: + sys.stderr.write( + "ERROR: missing libbase. run `python setup.py build_ext --inplace`\n" + ) + raise e + + +class BaseSublayer(nn.Module): + def __init__(self, args): + super().__init__() + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, 'activation_fn', 'relu') or "relu" + ) + self.norm = LayerNorm(args.decoder_embed_dim, export=False) + self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim) + self.ff2 = torch.nn.Linear(args.decoder_ffn_embed_dim, args.decoder_embed_dim) + self.ff2.weight.data.zero_() + + def forward(self, xs): + return xs + self.ff2(self.activation_fn(self.ff1(self.norm(xs)))) + + +# Wraps torch.distributed.all_to_all_single as a function that supports autograd +class All2All(torch.autograd.Function): + @staticmethod + def forward(ctx, xs, input_splits=None, output_splits=None): + ctx.input_splits = input_splits + ctx.output_splits = output_splits + + ys = torch.empty_like(xs) if output_splits is None else \ + xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:])) + torch.distributed.all_to_all_single(ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits) + return ys + + @staticmethod + def 
backward(ctx, grad_output): + result = torch.empty_like(grad_output) if ctx.input_splits is None else \ + grad_output.new_empty(size=[sum(ctx.input_splits)] + list(grad_output.size()[1:])) + torch.distributed.all_to_all_single(result, grad_output, + output_split_sizes=ctx.input_splits, input_split_sizes=ctx.output_splits) + return result, None, None diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 00ea1bbb76..370a910102 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -64,6 +64,8 @@ def build_fp32_params(cls, args, params, flatten=True): fp32_params = [] for p in params: p32 = torch.nn.Parameter(p.data.float()) + if hasattr(p, 'expert'): + p32.expert = True p32.grad = torch.zeros_like(p32.data) if hasattr(p, "param_group"): p32.param_group = p.param_group diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index a3847733a1..3069490fdc 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -84,8 +84,17 @@ class LanguageModelingConfig(FairseqDataclass): 'e.g., "train,valid" (default: all dataset splits)' }, ) + pad_to_fixed_length: Optional[bool] = field( + default=False, metadata={"help": "pad to fixed length"}, + ) + pad_to_fixed_bsz: Optional[bool] = field( + default=False, metadata={"help": "boolean to pad to fixed batch size"}, + ) + # TODO common vars below add to parent seed: int = II("common.seed") + batch_size: Optional[int] = II("dataset.batch_size") + batch_size_valid: Optional[int] = II("dataset.batch_size_valid") dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( "dataset.dataset_impl" ) @@ -232,6 +241,13 @@ def load_dataset( self.args.sample_break_mode is not None and self.args.sample_break_mode != "none" ) + fixed_pad_length = None + if self.args.pad_to_fixed_length: + fixed_pad_length = self.args.tokens_per_sample + + pad_to_bsz = None + if self.args.pad_to_fixed_bsz: + pad_to_bsz = 
self.args.batch_size_valid if 'valid' in split else self.args.batch_size self.datasets[split] = MonolingualDataset( dataset=dataset, @@ -242,6 +258,8 @@ def load_dataset( shuffle=True, targets=self.targets, add_bos_token=self.args.add_bos_token, + fixed_pad_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, ) def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 4535e9bda7..6195afb4a6 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -195,7 +195,7 @@ def use_distributed_wrapper(self) -> bool: @property def should_save_checkpoint_on_current_rank(self) -> bool: """Indicates whether to save checkpoints on the current DDP rank.""" - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.cfg.distributed_training.ddp_backend == "fully_sharded" or getattr(self.cfg.model, "base_layers", 0) > 0: return True else: return self.is_data_parallel_master @@ -415,6 +415,7 @@ def load_checkpoint( or self.tpu # FSDP requires loading checkpoint shards on all ranks or self.cfg.distributed_training.ddp_backend == "fully_sharded" + or getattr(self.cfg.model, "base_layers", 0) > 0 ) if load_on_all_ranks or self.data_parallel_rank == 0: diff --git a/fairseq/utils.py b/fairseq/utils.py index 90bb8369f2..03826d18d0 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -339,10 +339,14 @@ def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: @torch.no_grad() def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: + def grad_exists(p): + return p is not None and getattr(p, "grad", None) is not None if isinstance(params, torch.Tensor): params = [params] params = list(params) - grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)] + grads = [p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, 'expert')] + expert_grads = [p.grad.detach() for p in params if grad_exists(p) and hasattr(p, 'expert')] + if 
len(grads) == 0: if len(params) > 0: return params[0].new_tensor(0.0) @@ -377,7 +381,7 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if max_norm > 0: max_norm = float(max_norm) clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) - for g in grads: + for g in grads + expert_grads: g.mul_(clip_coef) return total_norm diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 6924dfe5c8..1cca64d988 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -98,9 +98,16 @@ def main(cfg: FairseqConfig) -> None: logger.info("model: {}".format(model.__class__.__name__)) logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( - "num. model params: {:,} (num. trained: {:,})".format( - sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()), - sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad), + "num. shared model params: {:,} (num. trained: {:,})".format( + sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)), + sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad) + ) + ) + + logger.info( + "num. expert model params: {} (num. 
trained: {})".format( + sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)), + sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad), ) ) diff --git a/setup.py b/setup.py index 3670ff3cfc..51e555229c 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,17 @@ def include_dirs(self, dirs): # torch is not available when generating docs from torch.utils import cpp_extension + extensions.extend( + [ + cpp_extension.CppExtension( + "fairseq.libbase", + sources=[ + "fairseq/clib/libbase/balanced_assignment.cpp", + ], + ) + ] + ) + extensions.extend( [ cpp_extension.CppExtension( From c2e8904b6072d8eddab362ac50b324e374b5951d Mon Sep 17 00:00:00 2001 From: Guillaume Wenzek <guw@fb.com> Date: Tue, 30 Mar 2021 09:54:22 -0700 Subject: [PATCH 538/707] Obt 2 (#1614) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.m)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? too many of them actually ^^ ## What does this PR do? This is a rewrite of https://github.com/fairinternal/fairseq-py/issues/1538 following the discussion there, and taking into account the proposed https://github.com/fairinternal/fairseq-py/issues/1560 from Myle. it brings online backtranslation to fairseq. It adds a RobertaEncDec to fairseq. RobertaEncDec can be built from a pretrained Roberta model allowing to do transfer learning. This is crucial for backtranslation. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1614 Reviewed By: myleott Differential Revision: D27157296 Pulled By: gwenzek fbshipit-source-id: 43020bc27743419bd4b138716165bf5764117c21 --- fairseq/data/noising.py | 2 + fairseq/data/round_robin_zip_datasets.py | 2 +- .../data/transform_eos_lang_pair_dataset.py | 3 + fairseq/models/roberta/__init__.py | 1 + fairseq/models/roberta/enc_dec.py | 192 +++++ fairseq/models/roberta/model.py | 6 +- fairseq/models/transformer.py | 40 +- fairseq/modules/multihead_attention.py | 11 +- fairseq/options.py | 18 +- fairseq/sequence_generator.py | 14 +- fairseq/tasks/online_backtranslation.py | 677 ++++++++++++++++++ fairseq/tasks/translation.py | 4 + .../tasks/translation_from_pretrained_bart.py | 2 +- fairseq_cli/train.py | 8 +- tests/test_online_backtranslation.py | 206 ++++++ tests/test_roberta.py | 314 ++++++++ 16 files changed, 1472 insertions(+), 28 deletions(-) create mode 100644 fairseq/models/roberta/enc_dec.py create mode 100644 fairseq/tasks/online_backtranslation.py create mode 100644 tests/test_online_backtranslation.py create mode 100644 tests/test_roberta.py diff --git a/fairseq/data/noising.py b/fairseq/data/noising.py index 9643d1aa6a..2b1cc34720 100644 --- a/fairseq/data/noising.py +++ b/fairseq/data/noising.py @@ -296,6 +296,8 @@ def __init__( **kwargs, ) ) + self.sizes = src_dataset.sizes + def __getitem__(self, index): """ diff --git a/fairseq/data/round_robin_zip_datasets.py b/fairseq/data/round_robin_zip_datasets.py index d710335b81..2cb7447ea9 100644 --- a/fairseq/data/round_robin_zip_datasets.py +++ b/fairseq/data/round_robin_zip_datasets.py @@ -141,7 +141,7 @@ def _deep_until_language_pair(dataset): f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, " f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}" ) - # Since we are modifiying in place the _ordered_indices, + # Since we are modifying in place the 
_ordered_indices, # it's not possible anymore to return valid ignored indices. # Hopefully the extra debug information print above should be enough to debug. # Ideally we would receive ignore_invalid_inputs so that we could have diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py index 1dd3d93d2b..07ebdd5f38 100644 --- a/fairseq/data/transform_eos_lang_pair_dataset.py +++ b/fairseq/data/transform_eos_lang_pair_dataset.py @@ -50,6 +50,9 @@ def __len__(self): def collater(self, samples, **extra_args): samples = self.dataset.collater(samples, **extra_args) + if 'net_input' not in samples: + return samples + if self.new_src_eos is not None: if self.dataset.left_pad_source: assert ( diff --git a/fairseq/models/roberta/__init__.py b/fairseq/models/roberta/__init__.py index cf16914fbc..4cd723ae96 100644 --- a/fairseq/models/roberta/__init__.py +++ b/fairseq/models/roberta/__init__.py @@ -5,6 +5,7 @@ from .hub_interface import * # noqa from .model import * # noqa +from .enc_dec import * # noqa from .model_camembert import * # noqa from .model_gottbert import * # noqa from .model_xlmr import * # noqa diff --git a/fairseq/models/roberta/enc_dec.py b/fairseq/models/roberta/enc_dec.py new file mode 100644 index 0000000000..e538dee0aa --- /dev/null +++ b/fairseq/models/roberta/enc_dec.py @@ -0,0 +1,192 @@ +import argparse +import logging + +import torch.nn as nn +import fairseq.checkpoint_utils +from fairseq.models import ( + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import TransformerDecoder +from fairseq.models.roberta import model as roberta + +logger = logging.getLogger(__name__) + + +@register_model("roberta_enc_dec") +class RobertaEncDecModel(FairseqEncoderDecoderModel): + @staticmethod + def add_args(parser): + parser.add_argument( + "--pretrained-mlm-checkpoint", + default=None, + type=str, + metavar="PRETRAINED", + help="path to pretrained 
mlm checkpoint", + ) + parser.add_argument( + "--pretrained-decoder", action="store_true", help="reload decoder" + ) + parser.add_argument( + "--hack-layernorm-embedding", + action="store_true", + help="hack to reload old models trained with encoder-normalize-before=False (no equivalent to encoder-normalize-before=False and layernorm_embedding=False)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--share-all-embeddings", + action="store_true", + help="share encoder, decoder and output embeddings" + " (requires shared dictionary and embed dim)", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present + base_enc_dec_architecture(args) + if args.pretrained_mlm_checkpoint: + arg_overrides = None + if args.hack_layernorm_embedding: + arg_overrides = {"layernorm_embedding": False} + loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides + ) + ([roberta_enc], _cfg, _task) = loaded + else: + # Do we need to edit untie_weights here ? 
+ share_in_out = ( + args.share_decoder_input_output_embed or args.share_all_embeddings + ) + args.untie_weights_roberta = not share_in_out + if args.hack_layernorm_embedding: + args.layernorm_embedding = False + args.encoder_normalize_before = False + roberta_enc = roberta.RobertaModel.build_model(args, task) + + return cls.from_roberta(roberta_enc, args, task.source_dictionary) + + @staticmethod + def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary): + encoder = roberta_enc.encoder.sentence_encoder + vocab_size, embed_dim = encoder.embed_tokens.weight.shape + + if args.share_all_embeddings: + lm_head = roberta_enc.encoder.lm_head + assert encoder.embed_tokens.weight is lm_head.weight, ( + "Can't use --share-all-embeddings with a model " + "that was pretrained with --untie-weights-roberta_enc" + ) + else: + lm_head = roberta.RobertaLMHead( + embed_dim, vocab_size, roberta_enc.args.activation_fn + ) + + dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad()) + if args.share_all_embeddings or args.share_decoder_input_output_embed: + # Note: I wasn't able to use Embedding _weight parameter to achieve this sharing. + dec_embs.weight = lm_head.weight + + decoder = TransformerDecoder( + RobertaEncDecModel.read_args_from_roberta(roberta_enc.args), + dictionary, + dec_embs, + no_encoder_attn=False, + output_projection=lm_head, + ) + if getattr(args, "pretrained_decoder", False): + decoder_dict = encoder.state_dict() + + # TODO: hide setting "encoder_attn" layers behind a flag. + for k, w in list(decoder_dict.items()): + if ".self_attn" in k: + k_enc_attn = k.replace(".self_attn", ".encoder_attn") + decoder_dict[k_enc_attn] = w.detach().clone() + + for k, w in lm_head.state_dict().items(): + decoder_dict["output_projection." 
+ k] = w + + missing_keys, unexpected_keys = decoder.load_state_dict( + decoder_dict, strict=False + ) + # missing_keys = [m for m in missing_keys if ".encoder_attn" not in m] + assert not missing_keys and not unexpected_keys, ( + "Failed to load state dict. " + f"Missing keys: {missing_keys}. " + f"Unexpected keys: {unexpected_keys}." + ) + + if args.share_all_embeddings: + assert decoder.output_projection.weight is decoder.embed_tokens.weight + assert encoder.embed_tokens.weight is decoder.embed_tokens.weight + elif args.share_decoder_input_output_embed: + assert decoder.output_projection.weight is decoder.embed_tokens.weight + assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight + else: + assert decoder.output_projection.weight is not decoder.embed_tokens.weight + assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight + + return RobertaEncDecModel(encoder, decoder) + + @staticmethod + def read_args_from_roberta(roberta_args: argparse.Namespace): + # TODO: this would become easier if encoder/decoder were using a similar + # TransformerConfig object + args = argparse.Namespace(**vars(roberta_args)) + attr_map = [ + ("encoder_attention_heads", "decoder_attention_heads"), + ("encoder_embed_dim", "decoder_embed_dim"), + ("encoder_embed_dim", "decoder_output_dim"), + ("encoder_normalize_before", "decoder_normalize_before"), + ("encoder_layers_to_keep", "decoder_layers_to_keep"), + ("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"), + ("encoder_layerdrop", "decoder_layerdrop"), + ("encoder_layers", "decoder_layers"), + ("encoder_learned_pos", "decoder_learned_pos"), + # should this be set from here ? 
+ ("max_positions", "max_target_positions"), + ] + for k1, k2 in attr_map: + setattr(args, k2, getattr(roberta_args, k1)) + + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta + return args + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + super().upgrade_state_dict_named(state_dict, name) + old_keys = list(state_dict.keys()) + + # rename decoder -> encoder before upgrading children modules + for k in old_keys: + if k.startswith(prefix + "encoder.lm_head"): + state_dict.pop(k) + continue + new_k = k + new_k = new_k.replace(".sentence_encoder.", ".") + new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.") + if k == new_k: + continue + # print(k, "->", new_k) + state_dict[new_k] = state_dict.pop(k) + + +@register_model_architecture("roberta_enc_dec", "roberta_enc_dec") +def base_enc_dec_architecture(args): + args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False) + args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None) + args.pretrained_decoder = getattr(args, "pretrained_decoder", None) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + + roberta.base_architecture(args) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 5b9ba8105f..d9d0f324cf 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -204,7 +204,7 @@ def forward( features_only=False, return_all_hiddens=False, classification_head_name=None, - **kwargs + **kwargs, ): if classification_head_name is not None: features_only = True @@ -259,7 +259,7 @@ def from_pretrained( checkpoint_file="model.pt", 
data_name_or_path=".", bpe="gpt2", - **kwargs + **kwargs, ): from fairseq import hub_utils @@ -464,7 +464,7 @@ def forward( features_only=False, return_all_hiddens=False, masked_tokens=None, - **unused + **unused, ): """ Args: diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 8da5beb3aa..eff5ba7b8f 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -645,7 +645,14 @@ class TransformerDecoder(FairseqIncrementalDecoder): (default: False). """ - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) @@ -727,7 +734,11 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): ) self.adaptive_softmax = None - self.output_projection = None + self.output_projection = output_projection + if self.output_projection is None: + self.build_output_projection(args, dictionary, embed_tokens) + + def build_output_projection(self, args, dictionary, embed_tokens): if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), @@ -789,7 +800,7 @@ def forward( prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for - encoder-side attention + encoder-side attention, should be of size T x B x C incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without @@ -802,6 +813,7 @@ def forward( - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ + x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, @@ -810,6 +822,7 @@ def forward( alignment_layer=alignment_layer, 
alignment_heads=alignment_heads, ) + if not features_only: x = self.output_layer(x) return x, extra @@ -866,9 +879,19 @@ def extract_features_scriptable( - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ + bs, slen = prev_output_tokens.size() if alignment_layer is None: alignment_layer = self.num_layers - 1 + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + if encoder_out is not None: + enc = encoder_out["encoder_out"][0] + padding_mask = encoder_out["encoder_padding_mask"][0] + assert ( + enc.size()[1] == bs + ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" + # embed positions positions = None if self.embed_positions is not None: @@ -916,15 +939,8 @@ def extract_features_scriptable( x, layer_attn, _ = layer( x, - encoder_out["encoder_out"][0] - if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) - else None, - encoder_out["encoder_padding_mask"][0] - if ( - encoder_out is not None - and len(encoder_out["encoder_padding_mask"]) > 0 - ) - else None, + enc, + padding_mask, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 6ab86245d2..d84c7e078d 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -147,8 +147,16 @@ def forward( is_tpu = query.device.type == "xla" tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, key_embed_dim = key.size() + if not torch.jit.is_scripting(): + assert (key_bsz, key_embed_dim) == (bsz, embed_dim) + assert value is not None + assert (src_len, bsz, embed_dim) == value.shape + if ( not self.onnx_trace @@ -262,6 +270,7 @@ def forward( else: assert k is not None k = torch.cat([prev_key, k], dim=1) + 
src_len = k.size(1) if "prev_value" in saved_state: _prev_value = saved_state["prev_value"] assert _prev_value is not None @@ -290,7 +299,7 @@ def forward( assert incremental_state is not None incremental_state = self._set_input_buffer(incremental_state, saved_state) assert k is not None - src_len = k.size(1) + assert k.size(1) == src_len # This is part of a workaround to get around fork/join parallelism # not supporting Optional types. diff --git a/fairseq/options.py b/fairseq/options.py index b79443a177..7558264fce 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import argparse -from typing import Callable, List, Optional +from pathlib import Path +from typing import Callable, List, Optional, Union import torch from fairseq import utils @@ -361,3 +362,18 @@ def add_model_args(parser): help='model architecture') # fmt: on return group + + +def get_args( + data: Union[str, Path], + task: str = "translation", + arch: str = "transformer", + **overrides +): + parser = get_training_parser(task) + args = parse_args_and_arch(parser, [str(data), "--task", task, "--arch", arch]) + + for k, v in overrides.items(): + setattr(args, k, v) + + return args diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 2574ab13f0..ddef3d58d2 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -23,6 +23,7 @@ def __init__( beam_size=1, max_len_a=0, max_len_b=200, + max_len=0, min_len=1, normalize_scores=True, len_penalty=1.0, @@ -44,6 +45,8 @@ def __init__( beam_size (int, optional): beam width (default: 1) max_len_a/b (int, optional): generate sequences of maximum length ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) min_len (int, optional): the minimum length of the generated output (not including end-of-sentence) normalize_scores (bool, optional): normalize 
scores by the length @@ -79,6 +82,7 @@ def __init__( self.max_len_a = max_len_a self.max_len_b = max_len_b self.min_len = min_len + self.max_len = max_len or self.model.max_decoder_positions() self.normalize_scores = normalize_scores self.len_penalty = len_penalty @@ -166,7 +170,7 @@ def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None) yield id, src, ref, hypos[i] @torch.no_grad() - def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]: """Generate translations. Match the api of other fairseq generators. Args: @@ -232,8 +236,7 @@ def _generate( else: max_len = min( int(self.max_len_a * src_len + self.max_len_b), - # exclude the EOS marker - self.model.max_decoder_positions() - 1, + self.max_len - 1, ) assert ( self.min_len <= max_len @@ -275,9 +278,8 @@ def _generate( [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step - finished = [ - False for i in range(bsz) - ] # a boolean array indicating if the sentence at the index is finished or not + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] num_remaining_sent = bsz # number of sentences remaining # number of candidate hypos per step diff --git a/fairseq/tasks/online_backtranslation.py b/fairseq/tasks/online_backtranslation.py new file mode 100644 index 0000000000..2545624cd4 --- /dev/null +++ b/fairseq/tasks/online_backtranslation.py @@ -0,0 +1,677 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import contextlib +import json +import logging +import math +import os +from argparse import Namespace +from collections import OrderedDict, defaultdict +from pathlib import Path +from typing import Dict, Sequence, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import fairseq +from fairseq import metrics, options, utils +from fairseq.data import ( + FairseqDataset, + LanguagePairDataset, + NoisingDataset, + PrependTokenDataset, + RoundRobinZipDatasets, + TransformEosLangPairDataset, + data_utils, + encoders, +) +from fairseq.sequence_generator import SequenceGenerator +from fairseq.tasks import register_task +from fairseq.tasks.translation import TranslationTask, load_langpair_dataset + +logger = logging.getLogger(__name__) + + +class PiecewiseLinearFn: + """Piecewise linear function. Can be configured with a string.""" + + def __init__(self, pieces: Sequence[Tuple[int, float]]): + assert pieces == sorted( + pieces + ), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}" + + self.pieces = pieces + + def __call__(self, x: int) -> float: + for i, (x_a, y_a) in enumerate(self.pieces[:-1]): + x_b, y_b = self.pieces[i + 1] + if x_a <= x <= x_b: + return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a) + + return self.pieces[-1][1] + + @staticmethod + def from_string(configuration: str) -> "PiecewiseLinearFn": + """ + Parse the configuration of lambda coefficient (for scheduling). 
+ x = "3" # lambda will be a constant equal to x + x = "0:1,1000:0" # lambda will start from 1 and linearly decrease + # to 0 during the first 1000 iterations + x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 + # iterations, then will linearly increase to 1 until iteration 2000 + """ + if isinstance(configuration, float): + return PiecewiseLinearFn([(0, configuration)]) + + try: + parts = configuration.split(",") + if len(parts) == 1: + v = float(configuration) + return PiecewiseLinearFn([(0, v)]) + + split = [s.split(":") for s in parts] + pieces = [(int(t), float(v)) for t, v in split] + return PiecewiseLinearFn(pieces) + except Exception: + raise ValueError( + f"Invalid PiecewiseLinearFn configuration: {configuration!r}" + ) + + @staticmethod + def one() -> "PiecewiseLinearFn": + return PiecewiseLinearFn([(0, 1.0)]) + + +@register_task("online_backtranslation") +class OnlineBackTranslationTask(TranslationTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + # Generic translation args + parser.add_argument('data', help='colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner; \ + however, valid and test data are always in the first directory to \ + avoid the need for repeating them in all directories') + parser.add_argument('--mono-langs', metavar='MONO_LANGS', + help='monolingual languages for training') + parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS', + help='language pairs for validation') + parser.add_argument('--load-alignments', action='store_true', + help='load the binarized alignments') + parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', + help='pad the source on the left') + parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', + help='pad the target on the left') + parser.add_argument('--upsample-primary', default=1, 
type=int, + help='amount to upsample primary dataset') + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + parser.add_argument('--truncate-source', action='store_true', default=False, + help='truncate source to max-source-positions') + parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N', + help='if >0, then bucket source and target lengths into N ' + 'buckets and pad accordingly; this is useful on TPUs ' + 'to minimize the number of compilations') + + # Denoising args + parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N', + help='maximum word shuffle distance for denoising autoencoding data generation') + parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N', + help='word dropout probability for denoising autoencoding data generation') + parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N', + help='word blanking probability for denoising autoencoding data generation') + + # Backtranslation args + parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N', + help='back-translation weight') + parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N', + help='denoising auto-encoder weight') + + # Evaluation args + parser.add_argument('--generate-one-by-one', action='store_true', + help='generate one sentence at a time for backtranslation') + + parser.add_argument('--eval-bleu', action='store_true', + help='evaluation with BLEU scores') + parser.add_argument('--eval-bleu-detok', type=str, default="space", + help='detokenize before computing BLEU (e.g., "moses"); ' + 'required if using --eval-bleu; use "space" to ' + 'disable detokenization; see fairseq.data.encoders ' + 'for other options') + 
parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON', + help='args for building the tokenizer, if needed') + parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False, + help='compute tokenized BLEU instead of sacrebleu') + parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None, + help='remove BPE before computing BLEU') + parser.add_argument('--eval-bleu-args', type=str, metavar='JSON', + help='generation args for BLEU scoring, ' + 'e.g., \'{"beam": 4, "lenpen": 0.6}\'') + parser.add_argument('--eval-bleu-print-samples', action='store_true', + help='print sample generations during validation') + # fmt: on + + def __init__(self, args, common_dict, mono_langs, valid_lang_pairs): + super().__init__(args, common_dict, common_dict) + self.common_dict = common_dict + self.mono_langs = mono_langs + self.valid_lang_pairs = valid_lang_pairs + + self.SHOW_SAMPLES_INTERVAL = 1000 + # Start by showing samples + self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL + self.SHOW_SAMPLES_NUMBER = 5 + self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt) + self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae) + + self.args = args + self.data = utils.split_paths(self.args.data) + if len(self.data) == 1: + shards = list(Path(self.data[0]).glob("shard*")) + if len(shards) > 0: + # keep this as strings, since it can also be a manifold path + old_data = self.data + self.data = [str(shard) for shard in shards] + logging.warning(f"Expanded data directory {old_data} to {self.data}") + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries).
+ + Args: + args (argparse.Namespace): parsed command-line arguments + """ + args.left_pad_source = options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + paths = utils.split_paths(args.data) + assert len(paths) > 0 + assert args.mono_langs is not None + + mono_langs = args.mono_langs.split(",") + valid_lang_pairs = args.valid_lang_pairs.split(",") + + # load dictionary + dict_path = os.path.join(paths[0], "dict.txt") + common_dict = cls.load_dictionary(dict_path) + + return cls(args, common_dict, mono_langs, valid_lang_pairs) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset: + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if split == "train": + data_path = self.data[(epoch - 1) % len(self.data)] + dataset = self.load_train_dataset(data_path) + else: + # valid/test should always be the same. + dataset = self.load_translation_dataset(split, self.data[0]) + + self.datasets[split] = dataset + return dataset + + def load_train_dataset(self, data_path: str) -> FairseqDataset: + """The training dataset is made of backtranslation dataset and denoising dataset.""" + data = [] + for lang in self.mono_langs: + train_path = os.path.join(data_path, lang, "train") + # TODO: could we do the BT using denoise sample ? + # this would half the data loading work + data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang))) + data.append( + (f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang)) + ) + + return RoundRobinZipDatasets(OrderedDict(data)) + + def _langpair_dataset( + self, src: FairseqDataset, tgt: FairseqDataset + ) -> LanguagePairDataset: + return LanguagePairDataset( + src, + src.sizes, + self.dictionary, + tgt=tgt, + tgt_sizes=tgt.sizes, + tgt_dict=self.dictionary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + # TODO: should we shuffle ? 
we are already sorting batch by sizes so ? + # shuffle=True, + ) + + def _prepend_lang_bos_to_target( + self, dataset: LanguagePairDataset, lang: str + ) -> LanguagePairDataset: + bos = _lang_token_index(self.dictionary, lang) + return TransformEosLangPairDataset( + dataset, + src_eos=self.dictionary.eos(), + new_src_eos=self.dictionary.eos(), + tgt_bos=self.dictionary.eos(), + new_tgt_bos=bos, + ) + + def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset: + """The BT dataset is generated with (tgt, tgt) pairs. + The actual translation to a (generated_src, tgt) pair + is done on the fly during training. + """ + mono_dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + assert mono_dataset is not None, f"No dataset found for {lang}" + + mono_dataset_src = PrependTokenDataset( + mono_dataset, _lang_token_index(self.dictionary, lang) + ) + + mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset) + logger.info( + f"mono_lang = {lang} " + f"lang token index = {_lang_token_index(self.dictionary, lang)} " + f"lang token = {_lang_token(lang)}" + ) + + mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang) + return mono_dataset_bt + + def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset: + """Classic denoising dataset""" + dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + noisy_dataset = NoisingDataset( + dataset, + self.dictionary, + seed=1, + max_word_shuffle_distance=self.args.max_word_shuffle_distance, + word_dropout_prob=self.args.word_dropout_prob, + word_blanking_prob=self.args.word_blanking_prob, + ) + noisy_dataset = PrependTokenDataset( + noisy_dataset, _lang_token_index(self.dictionary, lang) + ) + + clean_dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset) + 
denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang) + return denoising_dataset + + def load_translation_dataset( + self, split: str, data_path: str, combine: bool = False + ): + # only judging with one language pair for the moment, + # since ConcatDataset doesn't work as expected + assert len(self.valid_lang_pairs) == 1, "For now..." + valid_lang_pair = self.valid_lang_pairs[0] + src, tgt = valid_lang_pair.split("-") + + # use the same function than TranslationTask + src_tgt_dt = load_langpair_dataset( + data_path, + split, + src, + self.common_dict, + tgt, + self.common_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + load_alignments=self.args.load_alignments, + truncate_source=self.args.truncate_source, + num_buckets=self.args.num_batch_buckets, + shuffle=(split != "test"), + prepend_bos_src=_lang_token_index(self.dictionary, src), + ) + + src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt) + src_tgt_eos_dt.args = self.args + return src_tgt_eos_dt + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + raise NotImplementedError + + def build_model(self, args): + # torch.autograd.set_detect_anomaly(True) + model = super().build_model(args) + + add_secial_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs) + + self.sequence_generators = {} + for mono_lang in self.mono_langs: + self.sequence_generators[mono_lang] = SequenceGenerator( + [model], + tgt_dict=self.dictionary, + beam_size=1, + max_len_a=1.3, + max_len_b=5, + min_len=5, + # keep 1 to be able to prepend bos + max_len=model.max_decoder_positions() - 1, + ) + + if getattr(args, "eval_bleu", False): + assert getattr(args, "eval_bleu_detok", None) is not None, ( + 
"--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" + ) + detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}") + self.tokenizer = encoders.build_tokenizer( + Namespace( + tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args + ) + ) + + gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}") + self.bleu_sequence_generator = self.build_generator( + [model], Namespace(**gen_args) + ) + + return model + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + @property + def dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.common_dict + + def display_samples_once_in_a_while(self, smp, mono_lang, other_lang): + self._show_samples_ctr += 1 + if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL: + return + self._show_samples_ctr = 0 + + ln = smp["net_input"]["src_tokens"].shape[0] + + logger.info( + f"(r:{self.args.distributed_rank}) : " + f"{other_lang} ---> {mono_lang} " + f"({other_lang} was generated by back-translation.) {ln} samples" + ) + + for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)): + src_tokens = smp["net_input"]["src_tokens"][i] + tgt_tokens = smp["target"][i] + + src_str = self.dictionary.string(src_tokens, "sentencepiece") + tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece") + logger.info( + f"\n{i}\t\t[{other_lang} generated] {src_str}\n" + f"\t\t[{mono_lang} original ] {tgt_str}\n" + f"\t\t[ src tokens] {src_tokens}\n" + ) + + def backtranslate_sample(self, smp, orig_lang, other_lang) -> None: + """ + * WARNING: smp is modified in place. 
+ * At the start of this function, `smp` has the same input and target: + |--------------------------------------------------------| + | smp['net_input']['src_tokens'] | smp['target'] | + | (from data) __en__ hello world | __en__ hello world | + |--------------------------------------------------------| + + * We call generator.generate(smp, bos_token = token("ro")), + and copy the result as input + * At the end, `smp` has the translation to other language. + |--------------------------------------------------------| + | smp['net_input']['src_tokens'] | smp['target'] | + | (generated) __ro__ salut lume | __en__ hello world | + |--------------------------------------------------------| + + """ + bos_token = _lang_token_index(self.dictionary, other_lang) + generated = self.sequence_generators[orig_lang].generate( + models=[], sample=smp, bos_token=bos_token + ) + + max_lngth = max([gn[0]["tokens"].size(0) for gn in generated]) + net_input = smp["net_input"] + n_src_tokens = torch.empty( + size=(len(generated), max_lngth + 1), dtype=net_input["src_tokens"].dtype + ) + n_src_lengths = torch.empty( + len(generated), dtype=net_input["src_lengths"].dtype + ) + + for i, gn in enumerate(generated): + tokens = gn[0]["tokens"] + tokens_size = tokens.size(0) + padding_needed = max_lngth - tokens_size + tokens = torch.cat([tokens.new([bos_token]), tokens]) + tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad()) + n_src_tokens[i] = tokens + n_src_lengths[i] = tokens_size + 1 + + device = net_input["src_tokens"].device + # This seems to be important + del net_input["src_tokens"] + del net_input["src_lengths"] + net_input["src_tokens"] = n_src_tokens.to(device) + net_input["src_lengths"] = n_src_lengths.to(device) + + def generate(self, smp, model): + model.eval() + orig_lang = ( + self.dictionary[smp["net_input"]["src_tokens"][0][0]] + .replace(" ", "") + .replace("_", "") + ) + bos_token = smp["net_input"]["prev_output_tokens"][0][0] + with torch.no_grad(): + 
generated = self.sequence_generators[orig_lang].generate( + models=[model], sample=smp, bos_token=bos_token + ) + return generated + + def get_other_lang(self, lang): + # TODO: allow more complex mapping + if lang != self.mono_langs[0]: + return self.mono_langs[0] + if len(self.mono_langs) == 2: + return self.mono_langs[1] + return self.mono_langs[np.random.randint(1, len(self.mono_langs))] + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + + model.train() + model.set_num_updates(update_num) + + agg_loss, agg_sample_size = 0.0, 0.0 + agg_logging_output: Dict[str, float] = defaultdict(float) + + dataset_keys = self.datasets["train"].datasets.keys() + + weights = { + "BT": self.lambda_bt(update_num), + "DENOISE": self.lambda_dae(update_num), + } + log_keys = {"BT": "bt_", "DENOISE": "dae_"} + + for dataset_key in dataset_keys: + smp = sample[dataset_key] + mono_lang, task_subtype = dataset_key.split("-") + if weights[task_subtype] == 0: + continue + + if task_subtype == "BT": + with torch.autograd.profiler.record_function("backtranslation"): + model.eval() + # TODO: Could we translate to several language at once ? + # this would allow to share encoder_out and maximize GPU usage. 
+ other_lang = self.get_other_lang(mono_lang) + self.backtranslate_sample(smp, mono_lang, other_lang) + self.display_samples_once_in_a_while(smp, mono_lang, other_lang) + model.train() + + # Like in FairseqTask.train_step + with torch.autograd.profiler.record_function("forward"): + loss, sample_size, logging_output = criterion(model, smp) + loss *= weights[task_subtype] + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + + agg_loss += loss.item() + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[log_keys[task_subtype] + k] += logging_output[k] + agg_logging_output[k] += logging_output[k] + + return agg_loss, agg_sample_size, agg_logging_output + + def get_bos_token_from_sample(self, sample): + net_input = sample["net_input"] + source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item() + source_lang_token = self.dictionary[source_lang_token_id].replace("_", "") + target_lang_token_id = _lang_token_index( + self.dictionary, self.get_other_lang(source_lang_token) + ) + + return target_lang_token_id + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs) + if bt_sample_size: + bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs) + bt_loss_sum *= 1 / bt_sample_size / math.log(2) + metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3) + + bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs) + bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs) + bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2) + metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3) + metrics.log_derived( + "bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg) + ) + + dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs) + if dae_sample_size: + 
dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs) + dae_loss_sum *= 1 / dae_sample_size / math.log(2) + metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3) + + dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs) + dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs) + dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2) + metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3) + metrics.log_derived( + "dae_ppl", + lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg), + ) + + +@torch.no_grad() +def extend_embedding( + emb: nn.Module, new_vocab_size: int, copy_from_token_id: int +) -> None: + old_emb_data = emb.weight.data + (old_vocab_size, dim) = old_emb_data.shape + assert new_vocab_size >= old_vocab_size + + if new_vocab_size > old_vocab_size: + emb.weight.data = torch.zeros((new_vocab_size, dim)) + emb.weight.data[:old_vocab_size, :] = old_emb_data + # initialize new embeddings + emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id] + if hasattr(emb, "num_embeddings"): + emb.num_embeddings = new_vocab_size + if hasattr(emb, "out_features"): + emb.out_features = new_vocab_size + + if getattr(emb, "bias", None) is None: + return + + # Fix the bias. + # Bias shape can be different from the previous vocab size + # if the weight matrix was shared and already extended but not the bias.
+ (old_vocab_size,) = emb.bias.shape + assert new_vocab_size >= old_vocab_size + if new_vocab_size > old_vocab_size: + old_bias = emb.bias.data + new_bias = torch.zeros( + (new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device + ) + new_bias[:old_vocab_size] = old_bias + emb.bias.data = new_bias + + +def add_secial_tokens_to_dict_and_model( + dictionary: "fairseq.data.Dictionary", + model: nn.Module, + mono_langs: Sequence[str], +) -> None: + embs = model.encoder.embed_tokens + vocab_size, embedding_dim = embs.weight.shape + + # The model may or may not have a '<mask>' embedding yet + assert ( + len(dictionary) <= vocab_size <= len(dictionary) + 1 + ), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})" + # TODO: we should reuse the pretrained model dict which already has <mask> + dictionary.add_symbol("<mask>") + + for lang in mono_langs: + lang_token = _lang_token(lang) + dictionary.add_symbol(lang_token) + logger.info( + f"dictionary: {len(dictionary)} -> {vocab_size} tokens " + f"after adding {len(mono_langs)} lang tokens." 
+ ) + + if len(dictionary) <= vocab_size: + return + + extend_embedding(embs, len(dictionary), dictionary.bos()) + dec_embs = model.decoder.embed_tokens + extend_embedding(dec_embs, len(dictionary), dictionary.bos()) + lm_head = model.decoder.output_projection + extend_embedding(lm_head, len(dictionary), dictionary.bos()) + assert lm_head.weight.shape == (len(dictionary), embedding_dim) + + +def _lang_token(lang: str) -> str: + return f"__{lang}__" + + +def _lang_token_index(dictionary, lang: str) -> int: + return dictionary.index(_lang_token(lang)) + + +@contextlib.contextmanager +def assert_weights_have_changed(model: nn.Module): + def checksum(model: nn.Module) -> float: + return sum(p.sum().item() for p in model.parameters()) + + initial_checksum = checksum(model) + yield model + final_checksum = checksum(model) + logger.info( + f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}" + ) + assert initial_checksum != final_checksum, "Model hasn't changed !" diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 331f685495..ea80fa2e73 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -57,6 +57,7 @@ def load_langpair_dataset( num_buckets=0, shuffle=True, pad_to_multiple=1, + prepend_bos_src=None, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) @@ -128,6 +129,9 @@ def split_exists(split, src, tgt, lang, data_path): src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + elif prepend_bos_src is not None: + logger.info(f"prepending src bos: {prepend_bos_src}") + src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src) eos = None if append_source_id: diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index 8710b7fe7d..0fd7a5b29f 100644 --- 
a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -38,7 +38,7 @@ def add_args(parser): """Add task-specific arguments to the parser.""" # fmt: off TranslationTask.add_args(parser) - parser.add_argument('--langs', required=True, metavar='LANG', + parser.add_argument('--langs', type=str, metavar='LANG', help='comma-separated list of monolingual language, ' 'for example, "en,de,fr". These should match the ' 'langs from pretraining (and be in the same order). ' diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 1cca64d988..f736e67d0d 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -80,9 +80,6 @@ def main(cfg: FairseqConfig) -> None: # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) - # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in cfg.dataset.valid_subset.split(","): - task.load_dataset(valid_sub_split, combine=False, epoch=1) assert cfg.criterion, "Please specify criterion to train a model" @@ -111,6 +108,11 @@ def main(cfg: FairseqConfig) -> None: ) ) + # Load valid dataset (we load training data below, based on the latest checkpoint) + # We load the valid dataset AFTER building the model + for valid_sub_split in cfg.dataset.valid_subset.split(","): + task.load_dataset(valid_sub_split, combine=False, epoch=1) + # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: quantizer = quantization_utils.Quantizer( diff --git a/tests/test_online_backtranslation.py b/tests/test_online_backtranslation.py new file mode 100644 index 0000000000..0ae7e773da --- /dev/null +++ b/tests/test_online_backtranslation.py @@ -0,0 +1,206 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import tempfile +import unittest +from pathlib import Path +from typing import Any, Dict, Sequence + +import fairseq.data.indexed_dataset as indexed_dataset +import fairseq.options +import fairseq.tasks.online_backtranslation as obt +import torch +from tests import utils + + +def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]: + batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size) + sample = { + "net_input": { + "src_tokens": batch, + "prev_output_tokens": batch, + "src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long), + }, + "target": batch[:, 1:], + } + return sample + + +def mk_dataset(num_samples: int, max_len: int, output: Path): + output.parent.mkdir(exist_ok=True) + idx = indexed_dataset.IndexedDatasetBuilder(str(output)) + data = torch.randint(5, 100, (num_samples, max_len)) + lengths = torch.randint(3, max_len, (num_samples,)) + for d, l in zip(data, lengths): + d[0] = 0 + idx.add_item(d[:l]) + idx.finalize(output.with_suffix(".idx")) + assert output.exists() + assert output.with_suffix(".idx").exists() + + +class OnlineBacktranslationTest(unittest.TestCase): + + tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest")) + + @classmethod + def obt_task( + cls, languages: Sequence[str], data: Path = None, language_mapping: str = None + ): + dict_path = cls.tmp_dir / "dict.txt" + if not dict_path.exists(): + dictionary = utils.dummy_dictionary(100) + dictionary.save(str(dict_path)) + + if data is not None: + (data / "dict.txt").write_text(dict_path.read_text()) + else: + data = cls.tmp_dir + assert len(languages) >= 2 + + kwargs = { + "arch": "transformer", + # --max-sentences=1 for better predictability of batches + "max_sentences": 1, + # Use characteristics dimensions + "encoder_layers": 3, + "encoder_embed_dim": 12, + "encoder_ffn_embed_dim": 14, + "encoder_attention_heads": 4, + "decoder_layers": 3, + "decoder_embed_dim": 12, + "decoder_output_dim": 12, + 
"decoder_ffn_embed_dim": 14, + "decoder_attention_heads": 4, + # Disable dropout so we have comparable tests. + "dropout": 0, + "attention_dropout": 0, + "activation_dropout": 0, + "encoder_layerdrop": 0, + } + + args = fairseq.options.get_args( + data, + task="online_backtranslation", + mono_langs=",".join(languages), + valid_lang_pairs=f"{languages[0]}-{languages[1]}", + tokens_per_sample=256, + language_mapping=language_mapping, + **kwargs, + ) + task = obt.OnlineBackTranslationTask.setup_task(args) + # we need to build the model to have the correct dictionary + model = task.build_model(task.args) + return task, model + + def tmp_path(self, test_case: str) -> Path: + return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir)) + + def test_lang_tokens(self): + task, model = self.obt_task(["en", "ro", "zh"]) + assert obt._lang_token("en") in task.dictionary + assert obt._lang_token("ro") in task.dictionary + assert obt._lang_token("zh") in task.dictionary + + en_bos = obt._lang_token_index(task.common_dict, "en") + assert "en" == task.common_dict[en_bos].strip("_") + zh_bos = obt._lang_token_index(task.common_dict, "zh") + assert "zh" == task.common_dict[zh_bos].strip("_") + zh_sample = mk_sample([zh_bos, 16, 14, 12, 10]) + + # we expect to receive the bos token for translation + assert task.get_bos_token_from_sample(zh_sample) == en_bos + + def test_backtranslate_sample(self): + task, model = self.obt_task(["en", "ro", "zh"]) + + en_bos = obt._lang_token_index(task.common_dict, "en") + zh_bos = obt._lang_token_index(task.common_dict, "zh") + sample = mk_sample([zh_bos, 16, 14, 12, 10]) + + task.backtranslate_sample(sample, "zh", "en") + target_zh = list(sample["target"][0]) + assert target_zh == [16, 14, 12, 10] # original zh sentence + generated_en = sample["net_input"]["src_tokens"][0] + assert generated_en[0] == en_bos + + def test_train_dataset(self): + data = self.tmp_path("test_train_dataset") + mk_dataset(20, 10, data / "en" / "train.bin") + mk_dataset(10, 
10, data / "zh" / "train.bin") + task, model = self.obt_task(["en", "zh"], data) + task.load_dataset("train") + + en_bos = obt._lang_token_index(task.common_dict, "en") + zh_bos = obt._lang_token_index(task.common_dict, "zh") + + train = task.datasets["train"] + train.ordered_indices() + train.prefetch([0, 19]) + sample_0 = train[0] + sample_19 = train[19] + self.assertEqual( + set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"} + ) + for sample in (sample_0, sample_19): + self.assertEqual(sample["en-BT"]["source"][0], en_bos) + # bt target isn't ready to look at. + self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos) + # TODO What could we check on the target side ? + + for i in range(10): + # Zh dataset is shorter, and is wrapped around En dataset. + train.prefetch([i, i + 10]) + self.assertEqual( + list(train[i]["zh-DENOISE"]["source"]), + list(train[i + 10]["zh-DENOISE"]["source"]), + ) + self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos) + + # Sorted by increasing len + self.assertLess( + len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"]) + ) + + def test_valid_dataset(self): + data = self.tmp_path("test_valid_dataset") + mk_dataset(10, 21, data / "valid.en-zh.en.bin") + mk_dataset(10, 21, data / "valid.en-zh.zh.bin") + + task, model = self.obt_task(["en", "zh"], data) + valid = task.load_dataset("valid") + en_bos = obt._lang_token_index(task.common_dict, "en") + + assert valid is not None + valid.prefetch(range(10)) + sample_0 = valid[0] + sample_9 = valid[9] + self.assertEqual(sample_0["id"], 0) + self.assertEqual(sample_9["id"], 9) + self.assertEqual(sample_0["source"][0], en_bos) + self.assertEqual(sample_9["source"][0], en_bos) + # TODO: could we test the target side ? 
+ + def assertFnMatch(self, fn, values): + for x, y in values.items(): + fn_x = fn(x) + self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}") + + def test_piecewise_linear_fn(self): + self.assertFnMatch( + obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1} + ) + self.assertFnMatch( + obt.PiecewiseLinearFn.from_string("0:1,1000:0"), + {0: 1, 500: 0.5, 1000: 0, 2000: 0}, + ) + self.assertFnMatch( + obt.PiecewiseLinearFn.from_string("0:0,1000:1"), + {0: 0, 500: 0.5, 1000: 1, 2000: 1}, + ) + self.assertFnMatch( + obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"), + {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0}, + ) diff --git a/tests/test_roberta.py b/tests/test_roberta.py new file mode 100644 index 0000000000..b0b9cfd31e --- /dev/null +++ b/tests/test_roberta.py @@ -0,0 +1,314 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import functools +import unittest +from typing import Any, Dict, Sequence + +import fairseq +import fairseq.options +import fairseq.tasks +import torch +from tests.utils import dummy_dictionary + +VOCAB_SIZE = 100 + + +@fairseq.tasks.register_task("fake_task") +class FakeTask(fairseq.tasks.LegacyFairseqTask): + def __init__(self, args): + super().__init__(args) + self.dictionary = dummy_dictionary(VOCAB_SIZE - 4) + assert len(self.dictionary) == VOCAB_SIZE + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +@functools.lru_cache() +def get_toy_model( + device: str, + architecture: str = "roberta_enc_dec", + **extra_args: Any, +): + assert device in ("gpu", "cpu") + kwargs = { + "arch": architecture, + # Use characteristics dimensions + "encoder_layers": 3, + "encoder_embed_dim": 12, + "encoder_ffn_embed_dim": 14, + "encoder_attention_heads": 4, + "decoder_layers": 3, + "decoder_embed_dim": 12, + "decoder_ffn_embed_dim": 14, + "decoder_attention_heads": 4, + # Disable dropout so we have comparable tests. 
+ "dropout": 0, + "attention_dropout": 0, + "activation_dropout": 0, + "encoder_layerdrop": 0, + # required args + "tokens_per_sample": 256, + "data": "/tmp/test_roberta", + } + kwargs.update(extra_args) + fake_task = FakeTask(kwargs) + args = fairseq.options.get_args( + task="online_backtranslation", + mono_langs="en,ro", + valid_lang_pairs="en-ro", + **kwargs, + ) + torch.manual_seed(0) + model = fake_task.build_model(args) + if device == "gpu": + model.cuda() + return fake_task, model + + +def mk_sample( + lang: str, device: str, tok: Sequence[int] = None, batch_size: int = 2 +) -> Dict[str, Any]: + assert device in ("gpu", "cpu") + if not tok: + if lang == "en": + tok = [10, 11, 12, 13, 14, 15, 2] + else: + tok = [20, 21, 22, 23, 24, 25, 26, 27, 2] + + batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size) + if device == "gpu": + batch = batch.cuda() + sample = { + "net_input": { + "src_tokens": batch, + "prev_output_tokens": batch, + "src_lengths": torch.tensor( + [len(tok)] * batch_size, dtype=torch.long, device=batch.device + ), + }, + "target": batch[:, 1:], + } + return sample + + +def cpu_gpu(fn): + def helper(self): + fn(self, "cpu") + if torch.cuda.is_available(): + fn(self, "gpu") + + return helper + + +def architectures(fn): + def helper(self): + for arch in ["roberta_enc_dec", "transformer"]: + fn(self, arch) + + return helper + + +class RobertaTest(unittest.TestCase): + def assertTensorEqual(self, t1, t2, delta: float = 1e-6): + self.assertEqual(t1.size(), t2.size(), "size mismatch") + if delta == 0.0: + self.assertEqual(t1.ne(t2).long().sum(), 0) + else: + self.assertEqual(((t2 - t1).abs() > delta).long().sum(), 0) + + def assertSharing(self, model, link_groups: Sequence[Sequence[str]]): + ids = {} + for group in link_groups: + group_ids = {name: id(params(model, name)) for name in group} + shared_id = group_ids[group[0]] + self.assertEqual(group_ids, {name: shared_id for name in group}) + self.assertNotIn(shared_id, ids) + 
ids[shared_id] = group + + def test_roberta_shared_params(self): + _, roberta = get_toy_model("cpu", architecture="roberta") + self.assertSharing( + roberta, + [ + [ + "encoder.sentence_encoder.embed_tokens.weight", + "encoder.lm_head.weight", + ] + ], + ) + + _, roberta = get_toy_model( + "cpu", architecture="roberta", untie_weights_roberta=True + ) + self.assertSharing( + roberta, + [ + ["encoder.sentence_encoder.embed_tokens.weight"], + ["encoder.lm_head.weight"], + ], + ) + + def test_roberta_enc_dec_shared_params(self): + # 3 distinct embeddings + _, enc_dec = get_toy_model("cpu", architecture="roberta_enc_dec") + self.assertSharing( + enc_dec, + [ + ["encoder.embed_tokens.weight"], + ["decoder.embed_tokens.weight"], + ["decoder.output_projection.weight"], + ], + ) + + # 2 distinct embeddings, one for encoder, one for decoder + _, enc_dec = get_toy_model( + "cpu", architecture="roberta_enc_dec", share_decoder_input_output_embed=True + ) + self.assertSharing( + enc_dec, + [ + ["encoder.embed_tokens.weight"], + [ + "decoder.embed_tokens.weight", + "decoder.output_projection.weight", + ], + ], + ) + + # shared embeddings + _, enc_dec = get_toy_model( + "cpu", architecture="roberta_enc_dec", share_all_embeddings=True + ) + self.assertSharing( + enc_dec, + [ + [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "decoder.output_projection.weight", + ] + ], + ) + + def test_roberta_max_positions_is_correctly_set(self): + device = "cpu" + task, model = get_toy_model(device) + max_pos = model.max_decoder_positions() + self.assertEqual(max_pos, 256) + self.assertEqual(max_pos, model.decoder.max_positions()) + self.assertEqual(max_pos, model.encoder.max_positions()) + self.assertEqual(max_pos, model.encoder.embed_positions.max_positions) + + sentence = [31 for _ in range(max_pos)] + sample = mk_sample("en", device, sentence, batch_size=1) + self.assertEqual(list(sample["net_input"]["src_lengths"]), [max_pos]) + 
self.assertEqual(len(sample["net_input"]["src_tokens"][0]), max_pos) + x, _ = model.forward(**sample["net_input"]) + self.assertEqual(x.shape, (1, max_pos, VOCAB_SIZE)) + + @cpu_gpu + def test_roberta_forward_backward(self, device: str): + _, model = get_toy_model(device) + sample = mk_sample("en", device) + en_tokens = sample["net_input"]["src_tokens"] + (bs, l) = en_tokens.shape + # Forward + logits, _ = model(**sample["net_input"]) + self.assertEqual(logits.shape, (bs, l, VOCAB_SIZE)) + + # Backward + loss = logits.sum() + loss.backward() + + @cpu_gpu + def test_roberta_forward_backward_bs1(self, device: str): + _, model = get_toy_model(device) + sample = mk_sample("en", device, batch_size=1) + o, _ = model.forward(**sample["net_input"]) + loss = o.sum() + sample2 = mk_sample("ro", device, batch_size=1) + o, _ = model.forward(**sample2["net_input"]) + loss += o.sum() + loss.backward() + + @cpu_gpu + def test_roberta_batching(self, device: str): + """ + Checks that the batch of size 2 give twice the same results than the batch of size 1. + """ + _, model = get_toy_model(device) + sample = mk_sample("en", device, batch_size=1) + slen = sample["net_input"]["src_lengths"][0] + sample2 = mk_sample("en", device, batch_size=2) + with torch.no_grad(): + z = model.encoder.forward( + sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"] + ) + z = z["encoder_out"][-1] + logits, _ = model.forward(**sample["net_input"]) + + z2 = model.encoder.forward( + sample2["net_input"]["src_tokens"], sample["net_input"]["src_lengths"] + ) + z2 = z2["encoder_out"][-1] + logits2, _ = model.forward(**sample2["net_input"]) + + self.assertEqual(z.shape, (slen, 1, 12)) + self.assertEqual(z2.shape, (slen, 2, 12)) + self.assertTensorEqual(logits2[0], logits2[1]) + self.assertTensorEqual(logits[0], logits2[0]) + + @cpu_gpu + def test_roberta_incremental_decoder(self, device: str): + """ + Checks that incremental decoding yields the same result than non incremental one. 
+ """ + task, model = get_toy_model(device) + + en_sample = mk_sample("en", device) + en_tokens = en_sample["net_input"]["src_tokens"] + ro_sample = mk_sample("ro", device) + ro_tokens = ro_sample["net_input"]["src_tokens"] + + en_enc = model.encoder.forward( + en_tokens, src_lengths=en_sample["net_input"]["src_lengths"] + ) + (bs, tgt_len) = ro_tokens.shape + + # Decode without incremental state + ro_dec, _ = model.decoder.forward(ro_tokens, encoder_out=en_enc) + self.assertEqual(ro_dec.shape, (bs, tgt_len, VOCAB_SIZE)) + self.assertTensorEqual(ro_dec[0], ro_dec[1]) + + # Decode with incremental state + inc_state = {} + ro_dec_inc = [] + for l in range(tgt_len): + ro, _ = model.decoder.forward( + ro_tokens[:, : l + 1], encoder_out=en_enc, incremental_state=inc_state + ) + self.assertEqual(ro.shape, (bs, 1, VOCAB_SIZE)) + ro_dec_inc.append(ro) + + for l in range(tgt_len): + # Intra-batch + self.assertTensorEqual(ro_dec_inc[l][0], ro_dec_inc[l][1]) + # Incremental vs non-incremental + self.assertTensorEqual(ro_dec_inc[l][:, 0], ro_dec[:, l]) + + +def params(model, name): + if "." 
not in name: + return getattr(model, name) + + prefix, name = name.split(".", 1) + return params(getattr(model, prefix), name) From 229de0087f599e31986c85c8106456f0adf44812 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Tue, 30 Mar 2021 10:09:30 -0700 Subject: [PATCH 539/707] Fix an issue for waitk when left_pad_source is set (#1752) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1752 Reviewed By: jmp84 Differential Revision: D27370799 Pulled By: xutaima fbshipit-source-id: 1ba1ef529af5dbf6608d3029b7545392458bd827 --- .../modules/monotonic_multihead_attention.py | 79 ++++++++++++++++--- 1 file changed, 67 insertions(+), 12 deletions(-) diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 2e3ce8742f..b487f14a98 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -9,15 +9,12 @@ from torch import Tensor import torch.nn as nn -import torch.nn.functional as F from examples.simultaneous_translation.utils.functions import ( exclusive_cumprod, lengths_to_mask, ) -from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules import MultiheadAttention -from fairseq.utils import convert_padding_direction from . 
import register_monotonic_attention from typing import Dict, Optional @@ -262,7 +259,6 @@ def expected_alignment_infer( finish_read = new_monotonic_step.eq(max_steps) | (action == 0) - monotonic_cache["head_step"] = new_monotonic_step # Whether a head is looking for new input monotonic_cache["head_read"] = ( @@ -409,9 +405,6 @@ def add_args(parser): parser.add_argument('--attention-eps', type=float, default=1e-6, help='Epsilon when calculating expected attention') - def p_choose(self, *args): - raise NotImplementedError - def attn_energy( self, q_proj: Optional[Tensor], k_proj: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None ): @@ -934,11 +927,73 @@ def p_choose( else: tgt_len, bsz, _ = query.size() - src_len, bsz, _ = key.size() + max_src_len, bsz, _ = key.size() + + if max_src_len < self.waitk_lagging: + return query.new_zeros( + bsz * self.num_heads, tgt_len, max_src_len + ) + + # Assuming the p_choose looks like this for wait k=3 + # src_len = 6, tgt_len = 5 + # [0, 0, 1, 0, 0, 0, 0] + # [0, 0, 0, 1, 0, 0, 0] + # [0, 0, 0, 0, 1, 0, 0] + # [0, 0, 0, 0, 0, 1, 0] + # [0, 0, 0, 0, 0, 0, 1] + # linearize the p_choose matrix: + # [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0...] + # The indices of linearized matrix that equals 1 is + # 2 + 6 * 0 + # 3 + 6 * 1 + # ... 
+ # n + src_len * n + k - 1 = n * (src_len + 1) + k - 1 + # n from 0 to tgt_len - 1 + # + # First, generate the indices (activate_indices_offset: bsz, tgt_len) + # Second, scatter a zeros tensor (bsz, tgt_len * src_len) + # with activate_indices_offset + # Third, resize the tensor to (bsz, tgt_len, src_len) + + activate_indices_offset = ( + ( + torch.arange(tgt_len) * (max_src_len + 1) + + self.waitk_lagging - 1 + ) + .unsqueeze(0) + .expand(bsz, tgt_len) + .to(query) + .long() + ) + + if key_padding_mask is not None: + if key_padding_mask[:, 0].any(): + # Left padding + activate_indices_offset += ( + key_padding_mask.sum(dim=1, keepdim=True) + ) + + # Need to clamp the indices that are too large + activate_indices_offset = ( + activate_indices_offset + .clamp( + 0, + min( + [ + tgt_len, + max_src_len - self.waitk_lagging + 1 + ] + ) * max_src_len - 1 + ) + ) + + p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query) - p_choose = torch.ones(bsz, tgt_len, src_len).to(query) - p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) - p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) + p_choose = p_choose.scatter( + 1, + activate_indices_offset, + 1.0 + ).view(bsz, tgt_len, max_src_len) if incremental_state is not None: p_choose = p_choose[:, -1:] @@ -950,7 +1005,7 @@ def p_choose( .unsqueeze(1) .expand(-1, self.num_heads, -1, -1) .contiguous() - .view(-1, tgt_len, src_len) + .view(-1, tgt_len, max_src_len) ) return p_choose From 579a48f4be3876082ea646880061a98c94357af1 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Tue, 30 Mar 2021 10:34:40 -0700 Subject: [PATCH 540/707] pull unwrap_checkpoint() into library for reuse Summary: there was not a central place that model.jit() calls into, so we had to pull the logic out of pyspeech, and make it available in fairseq library. 
Reviewed By: myleott Differential Revision: D27349919 fbshipit-source-id: f486c11fc840d4d13a4b9265ec2e7f5cc770216c --- fairseq/modules/checkpoint_activations.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index f4a277f349..b44fc346ce 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -38,6 +38,17 @@ def checkpoint_wrapper(m, offload_to_cpu=False): return m +def unwrap_checkpoint(m: torch.nn.Module): + """ + unwrap a module and its children from checkpoint_wrapper + """ + for module in m.modules(): + if hasattr(module, "precheckpoint_forward"): + module.forward = module.precheckpoint_forward + del module.precheckpoint_forward + return m + + def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs): # Autograd Functions in PyTorch work best with positional args, since # the backward must return gradients (or None) for every input argument. From 14807a361202ba34dbbd3a533899db57a0ebda19 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Wed, 31 Mar 2021 17:08:08 -0700 Subject: [PATCH 541/707] Update simultaneous translation docs (#1767) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1767 Test Plan: Imported from GitHub, without a `Test Plan:` line. 
This pull request contains - An example of EN_JA simul t2t model - Reorganizing simul trans docs - Removal of out-of-date files Reviewed By: jmp84 Differential Revision: D27467907 Pulled By: xutaima fbshipit-source-id: 137165b007cf5301bdc51a0a277ba91cbf733092 --- examples/simultaneous_translation/README.md | 111 +-------- .../simultaneous_translation/docs/ende-mma.md | 74 ++++++ .../docs/enja-waitk.md | 106 ++++++++ .../simultaneous_translation/eval/__init__.py | 4 - .../eval/agents/__init__.py | 24 -- .../eval/agents/agent.py | 67 ------ .../eval/agents/simul_t2t_enja.py | 226 ++++++++++++++++++ .../eval/agents/simul_trans_agent.py | 167 ------------- .../eval/agents/simul_trans_text_agent.py | 81 ------- .../eval/agents/word_splitter.py | 91 ------- .../simultaneous_translation/eval/client.py | 100 -------- .../eval/eval_latency.py | 78 ------ .../simultaneous_translation/eval/evaluate.py | 81 ------- .../eval/scorers/__init__.py | 19 -- .../eval/scorers/scorer.py | 175 -------------- .../eval/scorers/text_scorer.py | 41 ---- .../simultaneous_translation/eval/server.py | 89 ------- .../agents/simul_trans_agent.py | 200 ---------------- ...moothed_cross_entropy_latency_augmented.py | 2 +- fairseq/tasks/simultaneous_translation.py | 42 ++++ 20 files changed, 454 insertions(+), 1324 deletions(-) create mode 100644 examples/simultaneous_translation/docs/ende-mma.md create mode 100644 examples/simultaneous_translation/docs/enja-waitk.md delete mode 100644 examples/simultaneous_translation/eval/__init__.py delete mode 100644 examples/simultaneous_translation/eval/agents/__init__.py delete mode 100644 examples/simultaneous_translation/eval/agents/agent.py create mode 100644 examples/simultaneous_translation/eval/agents/simul_t2t_enja.py delete mode 100644 examples/simultaneous_translation/eval/agents/simul_trans_agent.py delete mode 100644 examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py delete mode 100644 
examples/simultaneous_translation/eval/agents/word_splitter.py delete mode 100644 examples/simultaneous_translation/eval/client.py delete mode 100644 examples/simultaneous_translation/eval/eval_latency.py delete mode 100644 examples/simultaneous_translation/eval/evaluate.py delete mode 100644 examples/simultaneous_translation/eval/scorers/__init__.py delete mode 100644 examples/simultaneous_translation/eval/scorers/scorer.py delete mode 100644 examples/simultaneous_translation/eval/scorers/text_scorer.py delete mode 100644 examples/simultaneous_translation/eval/server.py delete mode 100644 examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py create mode 100644 fairseq/tasks/simultaneous_translation.py diff --git a/examples/simultaneous_translation/README.md b/examples/simultaneous_translation/README.md index bbc6dacdda..62a005e0ec 100644 --- a/examples/simultaneous_translation/README.md +++ b/examples/simultaneous_translation/README.md @@ -1,106 +1,5 @@ -# Simultaneous Machine Translation - -This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS) - -## Prepare Data - -[Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh) - -## Training - -- MMA-IL - -```shell -fairseq-train \ - data-bin/wmt15_en_de_32k \ - --simul-type infinite_lookback \ - --user-dir $FAIRSEQ/example/simultaneous_translation \ - --mass-preservation \ - --criterion latency_augmented_label_smoothed_cross_entropy \ - --latency-weight-avg 0.1 \ - --max-update 50000 \ - --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \ - --optimizer adam --adam-betas '(0.9, 0.98)' \ - --lr-scheduler 'inverse_sqrt' \ - --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ - --dropout 0.3 \ - --label-smoothing 0.1\ - 
--max-tokens 3584 -``` - -- MMA-H - -```shell -fairseq-train \ - data-bin/wmt15_en_de_32k \ - --simul-type hard_aligned \ - --user-dir $FAIRSEQ/example/simultaneous_translation \ - --mass-preservation \ - --criterion latency_augmented_label_smoothed_cross_entropy \ - --latency-weight-var 0.1 \ - --max-update 50000 \ - --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \ - --optimizer adam --adam-betas '(0.9, 0.98)' \ - --lr-scheduler 'inverse_sqrt' \ - --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ - --dropout 0.3 \ - --label-smoothing 0.1\ - --max-tokens 3584 -``` - -- wait-k - -```shell -fairseq-train \ - data-bin/wmt15_en_de_32k \ - --simul-type wait-k \ - --waitk-lagging 3 \ - --user-dir $FAIRSEQ/example/simultaneous_translation \ - --mass-preservation \ - --criterion latency_augmented_label_smoothed_cross_entropy \ - --max-update 50000 \ - --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \ - --optimizer adam --adam-betas '(0.9, 0.98)' \ - --lr-scheduler 'inverse_sqrt' \ - --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ - --dropout 0.3 \ - --label-smoothing 0.1\ - --max-tokens 3584 -``` - - -## Evaluation - -More details on evaluation can be found [here](https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/simultaneous_translation/docs/evaluation.md) - -### Start the server - -```shell -python ./eval/server.py \ - --src-file $SRC_FILE \ - --ref-file $TGT_FILE -``` - -### Run the client - -```shell -python ./evaluate.py \ - --data-bin data-bin/wmt15_en_de_32k \ - --model-path ./checkpoints/checkpoint_best.pt - --scores --output $RESULT_DIR -``` - -### Run evaluation locally without server - -```shell -python ./eval/evaluate.py - --local \ - --src-file $SRC_FILE \ - --tgt-file $TGT_FILE \ - --data-bin data-bin/wmt15_en_de_32k \ - --model-path ./checkpoints/checkpoint_best.pt \ - --scores 
--output $RESULT_DIR -``` +# Simultaneous Translation +Examples of simultaneous translation in fairseq +- [English-to-Japanese text-to-text wait-k model](docs/enja-waitk.md) +- [English-to-German text-to-text monotonic multihead attention model](docs/ende-mma.md) +- [English-to-German speech-to-text simultaneous translation model](../speech_to_text/docs/simulst_mustc_example.md) diff --git a/examples/simultaneous_translation/docs/ende-mma.md b/examples/simultaneous_translation/docs/ende-mma.md new file mode 100644 index 0000000000..241d604a3b --- /dev/null +++ b/examples/simultaneous_translation/docs/ende-mma.md @@ -0,0 +1,74 @@ +# Simultaneous Machine Translation + +This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS) + +## Prepare Data + +[Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh) + +Another example of training an English to Japanese model can be found [here](docs/enja.md) + +## Training + +- MMA-IL + +```shell +fairseq-train \ + data-bin/wmt15_en_de_32k \ + --simul-type infinite_lookback \ + --user-dir $FAIRSEQ/example/simultaneous_translation \ + --mass-preservation \ + --criterion latency_augmented_label_smoothed_cross_entropy \ + --latency-weight-avg 0.1 \ + --max-update 50000 \ + --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \ + --optimizer adam --adam-betas '(0.9, 0.98)' \ + --lr-scheduler 'inverse_sqrt' \ + --warmup-init-lr 1e-7 --warmup-updates 4000 \ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --dropout 0.3 \ + --label-smoothing 0.1\ + --max-tokens 3584 +``` + +- MMA-H + +```shell +fairseq-train \ + data-bin/wmt15_en_de_32k \ + --simul-type hard_aligned \ + --user-dir $FAIRSEQ/example/simultaneous_translation \ + --mass-preservation \ + --criterion latency_augmented_label_smoothed_cross_entropy \ + 
--latency-weight-var 0.1 \ + --max-update 50000 \ + --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \ + --optimizer adam --adam-betas '(0.9, 0.98)' \ + --lr-scheduler 'inverse_sqrt' \ + --warmup-init-lr 1e-7 --warmup-updates 4000 \ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --dropout 0.3 \ + --label-smoothing 0.1\ + --max-tokens 3584 +``` + +- wait-k + +```shell +fairseq-train \ + data-bin/wmt15_en_de_32k \ + --simul-type wait-k \ + --waitk-lagging 3 \ + --user-dir $FAIRSEQ/example/simultaneous_translation \ + --mass-preservation \ + --criterion latency_augmented_label_smoothed_cross_entropy \ + --max-update 50000 \ + --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \ + --optimizer adam --adam-betas '(0.9, 0.98)' \ + --lr-scheduler 'inverse_sqrt' \ + --warmup-init-lr 1e-7 --warmup-updates 4000 \ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --dropout 0.3 \ + --label-smoothing 0.1\ + --max-tokens 3584 +``` diff --git a/examples/simultaneous_translation/docs/enja-waitk.md b/examples/simultaneous_translation/docs/enja-waitk.md new file mode 100644 index 0000000000..fb9d82576f --- /dev/null +++ b/examples/simultaneous_translation/docs/enja-waitk.md @@ -0,0 +1,106 @@ +# An example of English to Japanese Simultaneous Translation System + +This is an example of training and evaluating a transformer *wait-k* English to Japanese simultaneous text-to-text translation model. + +## Data Preparation +This section introduces the data preparation for training and evaluation. +If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference-&-evaluation) + +For illustration, we only use the following subsets of the available data from [WMT20 news translation task](http://www.statmt.org/wmt20/translation-task.html), which results in 7,815,391 sentence pairs. 
+- News Commentary v16 +- Wiki Titles v3 +- WikiMatrix V1 +- Japanese-English Subtitle Corpus +- The Kyoto Free Translation Task Corpus + +We use the WMT20 development data as the development set. Training a `transformer_vaswani_wmt_en_de_big` model on this amount of data results in 17.3 BLEU with greedy search and 19.7 with beam (10) search. Note that better performance can be achieved with the full WMT training data. + +We use the [sentencepiece](https://github.com/google/sentencepiece) toolkit to tokenize the data with a vocabulary size of 32000. +Additionally, we filter out sentences longer than 200 words after tokenization. +Assuming the tokenized text data is saved at `${DATA_DIR}`, +we prepare the data binary with the following command. + +```bash +fairseq-preprocess \ + --source-lang en --target-lang ja \ + --trainpref ${DATA_DIR}/train \ + --validpref ${DATA_DIR}/dev \ + --testpref ${DATA_DIR}/test \ + --destdir ${WMT20_ENJA_DATA_BIN} \ + --nwordstgt 32000 --nwordssrc 32000 \ + --workers 20 +``` + +## Simultaneous Translation Model Training +To train a wait-k `(k=10)` model: +```bash +fairseq-train ${WMT20_ENJA_DATA_BIN} \ + --save-dir ${SAVEDIR} \ + --simul-type waitk \ + --waitk-lagging 10 \ + --max-epoch 70 \ + --arch transformer_monotonic_vaswani_wmt_en_de_big \ + --optimizer adam \ + --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt \ + --warmup-init-lr 1e-07 \ + --warmup-updates 4000 \ + --lr 0.0005 \ + --stop-min-lr 1e-09 \ + --clip-norm 10.0 \ + --dropout 0.3 \ + --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy \ + --label-smoothing 0.1 \ + --max-tokens 3584 +``` +This command is for training on 8 GPUs. Equivalently, the model can be trained on one GPU with `--update-freq 8`. + +## Inference & Evaluation +First of all, install [SimulEval](https://github.com/facebookresearch/SimulEval) for evaluation. + +```bash +git clone https://github.com/facebookresearch/SimulEval.git +cd SimulEval +pip install -e . 
+``` + +The following command is for the evaluation. +Assuming the source and reference files are `${SRC_FILE}` and `${TGT_FILE}`, and the sentencepiece model file for English is saved at `${SRC_SPM_PATH}`: + + +```bash +simuleval \ + --source ${SRC_FILE} \ + --target ${TGT_FILE} \ + --data-bin ${WMT20_ENJA_DATA_BIN} \ + --sacrebleu-tokenizer ja-mecab \ + --eval-latency-unit char \ + --no-space \ + --src-splitter-type sentencepiecemodel \ + --src-splitter-path ${SRC_SPM_PATH} \ + --agent ${FAIRSEQ}/examples/simultaneous_translation/agents/simul_trans_text_agent_enja.py \ + --model-path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --output ${OUTPUT} \ + --scores +``` + +The `--data-bin` should be the same as in the previous sections if you prepare the data from scratch. +If you only want to run the evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_databin.tgz) and a pretrained checkpoint (wait-k=10 model) can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_wait10_ckpt.pt). + +The output should look like this: +```bash +{ + "Quality": { + "BLEU": 11.442253287568398 + }, + "Latency": { + "AL": 8.6587861866951, + "AP": 0.7863304776251316, + "DAL": 9.477850951194764 + } +} +``` +The latency is evaluated by characters (`--eval-latency-unit`) on the target side. BLEU is computed with `sacrebleu` using the `MeCab` tokenizer (`--sacrebleu-tokenizer ja-mecab`). `--no-space` indicates that no space is added when merging the predicted words. + +If the `--output ${OUTPUT}` option is used, the detailed log and scores will be stored under the `${OUTPUT}` directory. diff --git a/examples/simultaneous_translation/eval/__init__.py b/examples/simultaneous_translation/eval/__init__.py deleted file mode 100644 index 6264236915..0000000000 --- a/examples/simultaneous_translation/eval/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates.
-# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. diff --git a/examples/simultaneous_translation/eval/agents/__init__.py b/examples/simultaneous_translation/eval/agents/__init__.py deleted file mode 100644 index 511e7b2474..0000000000 --- a/examples/simultaneous_translation/eval/agents/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import importlib -import os - -from fairseq import registry - - -build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry( - "--agent-type" -) - - -DEFAULT_EOS = "</s>" -GET = 0 -SEND = 1 - -for file in os.listdir(os.path.dirname(__file__)): - if file.endswith(".py") and not file.startswith("_"): - module = file[: file.find(".py")] - importlib.import_module("agents." + module) diff --git a/examples/simultaneous_translation/eval/agents/agent.py b/examples/simultaneous_translation/eval/agents/agent.py deleted file mode 100644 index 997392cf9b..0000000000 --- a/examples/simultaneous_translation/eval/agents/agent.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import time -from functools import partial -from multiprocessing.pool import ThreadPool as Pool - -from . 
import DEFAULT_EOS, GET, SEND - - -class Agent(object): - "an agent needs to follow this pattern" - - def __init__(self, *args, **kwargs): - pass - - def init_states(self, *args, **kwargs): - raise NotImplementedError - - def update_states(self, states, new_state): - raise NotImplementedError - - def finish_eval(self, states, new_state): - raise NotImplementedError - - def policy(self, state): - raise NotImplementedError - - def reset(self): - raise NotImplementedError - - def decode(self, session, low=0, high=100000, num_thread=10): - corpus_info = session.corpus_info() - high = min(corpus_info["num_sentences"] - 1, high) - if low >= high: - return - - t0 = time.time() - if num_thread > 1: - with Pool(10) as p: - p.map( - partial(self._decode_one, session), - [sent_id for sent_id in range(low, high + 1)], - ) - else: - for sent_id in range(low, high + 1): - self._decode_one(session, sent_id) - - print(f"Finished {low} to {high} in {time.time() - t0}s") - - def _decode_one(self, session, sent_id): - action = {} - self.reset() - states = self.init_states() - while action.get("value", None) != DEFAULT_EOS: - # take an action - action = self.policy(states) - - if action["key"] == GET: - new_states = session.get_src(sent_id, action["value"]) - states = self.update_states(states, new_states) - - elif action["key"] == SEND: - session.send_hypo(sent_id, action["value"]) - print(" ".join(states["tokens"]["tgt"])) diff --git a/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py b/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py new file mode 100644 index 0000000000..8f3c8703ca --- /dev/null +++ b/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py @@ -0,0 +1,226 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os + +from fairseq import checkpoint_utils, tasks +import sentencepiece as spm +import torch + +try: + from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS + from simuleval.agents import TextAgent +except ImportError: + print("Please install simuleval 'pip install simuleval'") + + +BOS_PREFIX = "\u2581" + + +class SimulTransTextAgentJA(TextAgent): + """ + Simultaneous Translation + Text agent for Japanese + """ + def __init__(self, args): + + # Whether use gpu + self.gpu = getattr(args, "gpu", False) + + # Max len + self.max_len = args.max_len + + # Load Model + self.load_model_vocab(args) + + # build word splitter + self.build_word_splitter(args) + + self.eos = DEFAULT_EOS + + def initialize_states(self, states): + states.incremental_states = dict() + states.incremental_states["online"] = dict() + + def to_device(self, tensor): + if self.gpu: + return tensor.cuda() + else: + return tensor.cpu() + + def load_model_vocab(self, args): + + filename = args.model_path + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + + state = checkpoint_utils.load_checkpoint_to_cpu(filename) + + task_args = state["cfg"]["task"] + task_args.data = args.data_bin + + task = tasks.setup_task(task_args) + + # build model for ensemble + state["cfg"]["model"].load_pretrained_encoder_from = None + state["cfg"]["model"].load_pretrained_decoder_from = None + + self.model = task.build_model(state["cfg"]["model"]) + self.model.load_state_dict(state["model"], strict=True) + self.model.eval() + self.model.share_memory() + + if self.gpu: + self.model.cuda() + + # Set dictionary + self.dict = {} + self.dict["tgt"] = task.target_dictionary + self.dict["src"] = task.source_dictionary + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--model-path', type=str, required=True, + help='path to your pretrained model.') + parser.add_argument("--data-bin", type=str, required=True, + help="Path of data binary") + 
parser.add_argument("--max-len", type=int, default=100, + help="Max length of translation") + parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", + help="Subword splitter type for target text.") + parser.add_argument("--tgt-splitter-path", type=str, default=None, + help="Subword splitter model path for target text.") + parser.add_argument("--src-splitter-type", type=str, default="SentencePiece", + help="Subword splitter type for source text.") + parser.add_argument("--src-splitter-path", type=str, default=None, + help="Subword splitter model path for source text.") + # fmt: on + return parser + + def build_word_splitter(self, args): + self.spm = {} + for lang in ['src', 'tgt']: + if getattr(args, f'{lang}_splitter_type', None): + path = getattr(args, f'{lang}_splitter_path', None) + if path: + self.spm[lang] = spm.SentencePieceProcessor() + self.spm[lang].Load(path) + + def segment_to_units(self, segment, states): + # Split a full word (segment) into subwords (units) + return self.spm['src'].EncodeAsPieces(segment) + + def update_model_encoder(self, states): + if len(states.units.source) == 0: + return + + src_indices = [ + self.dict['src'].index(x) + for x in states.units.source.value + ] + + if states.finish_read(): + # Append the eos index when the prediction is over + src_indices += [self.dict["tgt"].eos_index] + + src_indices = self.to_device( + torch.LongTensor(src_indices).unsqueeze(0) + ) + src_lengths = self.to_device( + torch.LongTensor([src_indices.size(1)]) + ) + + states.encoder_states = self.model.encoder(src_indices, src_lengths) + + torch.cuda.empty_cache() + + def update_states_read(self, states): + # Happens after a read action. + self.update_model_encoder(states) + + def units_to_segment(self, units, states): + # Merge sub words (units) to full word (segment). 
+ # For Japanese, we can directly send + # the untokenized token to the server except the BOS token + # with the following options + # --sacrebleu-tokenizer ja-mecab + # --eval-latency-unit char + # --no-space + token = units.value.pop() + + if ( + token == self.dict["tgt"].eos_word + or len(states.segments.target) > self.max_len + ): + return DEFAULT_EOS + + if BOS_PREFIX == token: + return None + if token[0] == BOS_PREFIX: + return token[1:] + else: + return token + + def policy(self, states): + + if not getattr(states, "encoder_states", None): + # No encoder states, read a token first + return READ_ACTION + + # encode previous predicted target tokens + tgt_indices = self.to_device( + torch.LongTensor( + [self.model.decoder.dictionary.eos()] + + [ + self.dict['tgt'].index(x) + for x in states.units.target.value + if x is not None + ] + ).unsqueeze(0) + ) + + # Current steps + states.incremental_states["steps"] = { + "src": states.encoder_states["encoder_out"][0].size(0), + "tgt": 1 + len(states.units.target), + } + + # Online only means the reading is not finished + states.incremental_states["online"]["only"] = ( + torch.BoolTensor([not states.finish_read()]) + ) + + x, outputs = self.model.decoder.forward( + prev_output_tokens=tgt_indices, + encoder_out=states.encoder_states, + incremental_state=states.incremental_states, + ) + + states.decoder_out = x + + torch.cuda.empty_cache() + + if outputs.action == 0: + return READ_ACTION + else: + return WRITE_ACTION + + def predict(self, states): + # Predict target token from decoder states + decoder_states = states.decoder_out + + lprobs = self.model.get_normalized_probs( + [decoder_states[:, -1:]], log_probs=True + ) + + index = lprobs.argmax(dim=-1)[0, 0].item() + + if index != self.dict['tgt'].eos_index: + token = self.dict['tgt'].string([index]) + else: + token = self.dict['tgt'].eos_word + + return token diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py
b/examples/simultaneous_translation/eval/agents/simul_trans_agent.py deleted file mode 100644 index 071b9e89ce..0000000000 --- a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os - -from fairseq import checkpoint_utils, tasks, utils - -from . import DEFAULT_EOS, GET, SEND -from .agent import Agent - - -class SimulTransAgent(Agent): - def __init__(self, args): - # Load Model - self.load_model(args) - - # build word spliter - self.build_word_splitter(args) - - self.max_len = args.max_len - - self.eos = DEFAULT_EOS - - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--model-path', type=str, required=True, - help='path to your pretrained model.') - parser.add_argument("--data-bin", type=str, required=True, - help="Path of data binary") - parser.add_argument("--user-dir", type=str, default="example/simultaneous_translation", - help="User directory for simultaneous translation") - parser.add_argument("--src-splitter-type", type=str, default=None, - help="Subword splitter type for source text") - parser.add_argument("--tgt-splitter-type", type=str, default=None, - help="Subword splitter type for target text") - parser.add_argument("--src-splitter-path", type=str, default=None, - help="Subword splitter model path for source text") - parser.add_argument("--tgt-splitter-path", type=str, default=None, - help="Subword splitter model path for target text") - parser.add_argument("--max-len", type=int, default=150, - help="Maximum length difference between source and target prediction") - parser.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', - help='A dictionary used to override model args at generation ' - 'that were used during model training') - # fmt: on - return parser - 
- def load_dictionary(self, task): - raise NotImplementedError - - def load_model(self, args): - args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..") - utils.import_user_module(args) - filename = args.model_path - if not os.path.exists(filename): - raise IOError("Model file not found: {}".format(filename)) - - state = checkpoint_utils.load_checkpoint_to_cpu( - filename, json.loads(args.model_overrides) - ) - - saved_args = state["args"] - saved_args.data = args.data_bin - - task = tasks.setup_task(saved_args) - - # build model for ensemble - self.model = task.build_model(saved_args) - self.model.load_state_dict(state["model"], strict=True) - - # Set dictionary - self.load_dictionary(task) - - def init_states(self): - return { - "indices": {"src": [], "tgt": []}, - "tokens": {"src": [], "tgt": []}, - "segments": {"src": [], "tgt": []}, - "steps": {"src": 0, "tgt": 0}, - "finished": False, - "finish_read": False, - "model_states": {}, - } - - def update_states(self, states, new_state): - raise NotImplementedError - - def policy(self, states): - # Read and Write policy - action = None - - while action is None: - if states["finished"]: - # Finish the hypo by sending eos to server - return self.finish_action() - - # Model make decision given current states - decision = self.model.decision_from_states(states) - - if decision == 0 and not self.finish_read(states): - # READ - action = self.read_action(states) - else: - # WRITE - action = self.write_action(states) - - # None means we make decision again but not sending server anything - # This happened when read a bufffered token - # Or predict a subword - return action - - def finish_read(self, states): - raise NotImplementedError - - def write_action(self, states): - token, index = self.model.predict_from_states(states) - - if ( - index == self.dict["tgt"].eos() - or len(states["tokens"]["tgt"]) > self.max_len - ): - # Finish this sentence is predict EOS - states["finished"] = True - end_idx_last_full_word 
= self._target_length(states) - - else: - states["tokens"]["tgt"] += [token] - end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( - states["tokens"]["tgt"] - ) - self._append_indices(states, [index], "tgt") - - if end_idx_last_full_word > states["steps"]["tgt"]: - # Only sent detokenized full words to the server - word = self.word_splitter["tgt"].merge( - states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] - ) - states["steps"]["tgt"] = end_idx_last_full_word - states["segments"]["tgt"] += [word] - - return {"key": SEND, "value": word} - else: - return None - - def read_action(self, states): - return {"key": GET, "value": None} - - def finish_action(self): - return {"key": SEND, "value": DEFAULT_EOS} - - def reset(self): - pass - - def finish_eval(self, states, new_state): - if len(new_state) == 0 and len(states["indices"]["src"]) == 0: - return True - return False - - def _append_indices(self, states, new_indices, key): - states["indices"][key] += new_indices - - def _target_length(self, states): - return len(states["tokens"]["tgt"]) diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py b/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py deleted file mode 100644 index 7c34817bf6..0000000000 --- a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from . 
import DEFAULT_EOS, GET, register_agent -from .simul_trans_agent import SimulTransAgent -from .word_splitter import SPLITTER_DICT - - -@register_agent("simul_trans_text") -class SimulTransTextAgent(SimulTransAgent): - def build_word_splitter(self, args): - self.word_splitter = {} - - self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type]( - getattr(args, f"src_splitter_path") - ) - self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type]( - getattr(args, f"tgt_splitter_path") - ) - - def load_dictionary(self, task): - self.dict = {} - self.dict["tgt"] = task.target_dictionary - self.dict["src"] = task.source_dictionary - - def update_states(self, states, new_state): - if states["finish_read"]: - return states - - new_word = new_state["segment"] - - # Split words and index the token - if new_word not in [DEFAULT_EOS]: - tokens = self.word_splitter["src"].split(new_word) - # Get indices from dictionary - # You can change to you own dictionary - indices = ( - self.dict["src"] - .encode_line( - tokens, - line_tokenizer=lambda x: x, - add_if_not_exist=False, - append_eos=False, - ) - .tolist() - ) - else: - tokens = [new_word] - indices = [self.dict["src"].eos()] - states["finish_read"] = True - - # Update states - states["segments"]["src"] += [new_word] - states["tokens"]["src"] += tokens - self._append_indices(states, indices, "src") - - return states - - def read_action(self, states): - # Increase source step by one - states["steps"]["src"] += 1 - - # At leat one word is read - if len(states["tokens"]["src"]) == 0: - return {"key": GET, "value": None} - - # Only request new word if there is no buffered tokens - if len(states["tokens"]["src"]) <= states["steps"]["src"]: - return {"key": GET, "value": None} - - return None - - def finish_read(self, states): - # The first means all segments (full words) has been read from server - # The second means all tokens (subwords) has been read locally - return ( - states["finish_read"] - and 
len(states["tokens"]["src"]) == states["steps"]["src"] - ) diff --git a/examples/simultaneous_translation/eval/agents/word_splitter.py b/examples/simultaneous_translation/eval/agents/word_splitter.py deleted file mode 100644 index c3f71200a5..0000000000 --- a/examples/simultaneous_translation/eval/agents/word_splitter.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - - -class SubwordSplitter(object): - def process_line(self, string): - raise NotImplementedError - - def split(self, string): - raise NotImplementedError - - -class NoneWordSplitter(object): - def __init__(self, model): - pass - - def split(self, string): - return [string] - - def process_line(self, string): - return [string] - - def finished_word(self, string): - return True - - def merge(self, list_of_string): - return "".join(list_of_string) - - def last_full_word_step(self, tokens, step): - return len(tokens) - - def end_idx_last_full_word(self, tokens): - return len(tokens) - - -class BPEWordSplitter(object): - # TODO: lock back here - def __init__(self, model_path): - super().__init__() - from subword_nmt.apply_bpe import BPE - - with open(model_path) as f: - self.model = BPE(f) - - def split(self, string): - return self.model.process_line(string).split() - - def end_idx_last_full_word(self, tokens): - # Begin of word indices - bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"] - - if len(bow_indices) < 2: - return 0 - else: - return bow_indices[-1] - - def merge(self, list_of_string): - return " ".join([item.replace("@@", "") for item in list_of_string]) - - -class SentencePieceModelWordSplitter(object): - def __init__(self, model_path): - super().__init__() - import sentencepiece as spm - - self.model = spm.SentencePieceProcessor() - self.model.Load(model_path) - - def split(self, string): - return 
self.model.EncodeAsPieces(string) - - def end_idx_last_full_word(self, tokens): - # Begin of word indices - bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"] - - if len(bow_indices) < 2: - return 0 - else: - return bow_indices[-1] - - def merge(self, list_of_string): - return self.model.DecodePieces(list_of_string) - - -SPLITTER_DICT = { - None: NoneWordSplitter, - "BPE": BPEWordSplitter, - "SentencePieceModel": SentencePieceModelWordSplitter, -} diff --git a/examples/simultaneous_translation/eval/client.py b/examples/simultaneous_translation/eval/client.py deleted file mode 100644 index 3ca4ea73b8..0000000000 --- a/examples/simultaneous_translation/eval/client.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Optional - -import requests -from scorers import build_scorer - - -class SimulSTEvaluationService(object): - DEFAULT_HOSTNAME = "localhost" - DEFAULT_PORT = 12321 - - def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT): - self.hostname = hostname - self.port = port - self.base_url = f"http://{self.hostname}:{self.port}" - - def __enter__(self): - self.new_session() - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - def new_session(self): - # start eval session - url = f"{self.base_url}" - - try: - _ = requests.post(url) - except Exception as e: - print(f"Failed to start an evaluation session: {e}") - - print("Evaluation session started.") - return self - - def get_scores(self): - # end eval session - url = f"{self.base_url}/result" - try: - r = requests.get(url) - print("Scores: {}".format(r.json())) - print("Evaluation session finished.") - except Exception as e: - print(f"Failed to end an evaluation session: {e}") - - def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str: - url = f"{self.base_url}/src" - 
params = {"sent_id": sent_id} - if extra_params is not None: - for key in extra_params.keys(): - params[key] = extra_params[key] - try: - r = requests.get(url, params=params) - except Exception as e: - print(f"Failed to request a source segment: {e}") - return r.json() - - def send_hypo(self, sent_id: int, hypo: str) -> None: - url = f"{self.base_url}/hypo" - params = {"sent_id": sent_id} - - try: - requests.put(url, params=params, data=hypo.encode("utf-8")) - except Exception as e: - print(f"Failed to send a translated segment: {e}") - - def corpus_info(self): - url = f"{self.base_url}" - try: - r = requests.get(url) - except Exception as e: - print(f"Failed to request corpus information: {e}") - - return r.json() - - -class SimulSTLocalEvaluationService(object): - def __init__(self, args): - self.scorer = build_scorer(args) - - def get_scores(self): - return self.scorer.score() - - def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str: - if extra_params is not None: - segment_size = extra_params.get("segment_size", None) - else: - segment_size = None - - return self.scorer.send_src(int(sent_id), segment_size) - - def send_hypo(self, sent_id: int, hypo: str) -> None: - list_of_tokens = hypo.strip().split() - self.scorer.recv_hyp(sent_id, list_of_tokens) - - def corpus_info(self): - return self.scorer.get_info() diff --git a/examples/simultaneous_translation/eval/eval_latency.py b/examples/simultaneous_translation/eval/eval_latency.py deleted file mode 100644 index 50021de47c..0000000000 --- a/examples/simultaneous_translation/eval/eval_latency.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import argparse -import json - -import torch -from examples.simultaneous_translation.utils.latency import LatencyInference - - -LATENCY_METRICS = [ - "differentiable_average_lagging", - "average_lagging", - "average_proportion", -] - - -class LatencyScorer: - def __init__(self, start_from_zero=True): - self.recorder = [] - self.scores = {} - self.scorer = LatencyInference() - self.start_from_zero = start_from_zero - - def update_reorder(self, list_of_dict): - self.recorder = [] - for info in list_of_dict: - delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]] - delays = torch.LongTensor(delays).unsqueeze(0) - src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0) - - self.recorder.append(self.scorer(delays, src_len)) - - def cal_latency(self): - self.scores = {} - for metric in LATENCY_METRICS: - self.scores[metric] = sum( - [x[metric][0, 0].item() for x in self.recorder] - ) / len(self.recorder) - return self.scores - - @classmethod - def score(cls, list_of_dict, start_from_zero=True): - scorer_to_return = cls(start_from_zero) - scorer_to_return.update_reorder(list_of_dict) - scorer_to_return.cal_latency() - return scorer_to_return.scores - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--input", required=True) - parser.add_argument("--start-from-zero", action="store_true") - args = parser.parse_args() - - scorer = LatencyInference() - recorder = [] - with open(args.input, "r") as f: - for line in f: - info = json.loads(line) - - delays = [int(x) - int(not args.start_from_zero) for x in info["delays"]] - - delays = torch.LongTensor(delays).unsqueeze(0) - - src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0) - - recorder.append(scorer(delays, src_len)) - - average_results = {} - - for metric in LATENCY_METRICS: - average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len( - recorder - ) - print(f"{metric}: {average_results[metric]}") diff --git 
a/examples/simultaneous_translation/eval/evaluate.py b/examples/simultaneous_translation/eval/evaluate.py deleted file mode 100644 index 2f7474621a..0000000000 --- a/examples/simultaneous_translation/eval/evaluate.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import argparse - -from agents import build_agent -from client import SimulSTEvaluationService, SimulSTLocalEvaluationService -from fairseq.registry import REGISTRIES - - -DEFAULT_HOSTNAME = "localhost" -DEFAULT_PORT = 12321 - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname" - ) - parser.add_argument( - "--port", type=int, default=DEFAULT_PORT, help="server port number" - ) - parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type") - parser.add_argument("--scorer-type", default="text", help="Scorer type") - parser.add_argument( - "--start-idx", - type=int, - default=0, - help="Start index of the sentence to evaluate", - ) - parser.add_argument( - "--end-idx", - type=int, - default=float("inf"), - help="End index of the sentence to evaluate", - ) - parser.add_argument( - "--scores", action="store_true", help="Request scores from server" - ) - parser.add_argument("--reset-server", action="store_true", help="Reset the server") - parser.add_argument( - "--num-threads", type=int, default=10, help="Number of threads used by agent" - ) - parser.add_argument( - "--local", action="store_true", default=False, help="Local evaluation" - ) - - args, _ = parser.parse_known_args() - - for registry_name, REGISTRY in REGISTRIES.items(): - choice = getattr(args, registry_name, None) - if choice is not None: - cls = REGISTRY["registry"][choice] - if hasattr(cls, "add_args"): - cls.add_args(parser) - args = parser.parse_args() - - 
return args - - -if __name__ == "__main__": - args = get_args() - - if args.local: - session = SimulSTLocalEvaluationService(args) - else: - session = SimulSTEvaluationService(args.hostname, args.port) - - if args.reset_server: - session.new_session() - - if args.agent_type is not None: - agent = build_agent(args) - agent.decode(session, args.start_idx, args.end_idx, args.num_threads) - - if args.scores: - session.get_scores() - print(session.get_scores()) diff --git a/examples/simultaneous_translation/eval/scorers/__init__.py b/examples/simultaneous_translation/eval/scorers/__init__.py deleted file mode 100644 index 0a0e0a0518..0000000000 --- a/examples/simultaneous_translation/eval/scorers/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import importlib -import os - -from fairseq import registry - - -(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry( - "--scorer-type" -) - -for file in os.listdir(os.path.dirname(__file__)): - if file.endswith(".py") and not file.startswith("_"): - module = file[: file.find(".py")] - importlib.import_module("scorers." + module) diff --git a/examples/simultaneous_translation/eval/scorers/scorer.py b/examples/simultaneous_translation/eval/scorers/scorer.py deleted file mode 100644 index d6d3e30aef..0000000000 --- a/examples/simultaneous_translation/eval/scorers/scorer.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import json -import os -from collections import defaultdict - -from examples.simultaneous_translation.eval.eval_latency import LatencyScorer -from vizseq.scorers.bleu import BLEUScorer -from vizseq.scorers.meteor import METEORScorer -from vizseq.scorers.ter import TERScorer - - -DEFAULT_EOS = "</s>" - - -class SimulScorer(object): - def __init__(self, args): - self.tokenizer = args.tokenizer - self.output_dir = args.output - if args.output is not None: - self.output_files = { - "text": os.path.join(args.output, "text"), - "delay": os.path.join(args.output, "delay"), - "scores": os.path.join(args.output, "scores"), - } - else: - self.output_files = None - self.eos = DEFAULT_EOS - self.data = {"tgt": []} - self.reset() - - def get_info(self): - return {"num_sentences": len(self)} - - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--src-file', type=str, required=True, - help='Source input file') - parser.add_argument('--tgt-file', type=str, required=True, - help='Target reference file') - parser.add_argument('--tokenizer', default="13a", choices=["none", "13a"], - help='Tokenizer used for sacrebleu') - parser.add_argument('--output', type=str, default=None, - help='Path for output directory') - # fmt: on - - def send_src(self, sent_id, *args): - raise NotImplementedError - - def recv_hyp(self, sent_id, list_of_tokens): - for token in list_of_tokens: - self.translations[sent_id].append((token, self.steps[sent_id])) - - def reset(self): - self.steps = defaultdict(int) - self.translations = defaultdict(list) - - def src_lengths(self): - raise NotImplementedError - - def score(self): - translations = [] - delays = [] - for i in range(1 + max(self.translations.keys())): - translations += [" ".join(t[0] for t in self.translations[i][:-1])] - delays += [[t[1] for t in self.translations[i]]] - - bleu_score = BLEUScorer( - sent_level=False, - corpus_level=True, - extra_args={"bleu_tokenizer": self.tokenizer}, - ).score(translations, 
[self.data["tgt"]]) - - ter_score = TERScorer(sent_level=False, corpus_level=True).score( - translations, [self.data["tgt"]] - ) - meteor_score = METEORScorer(sent_level=False, corpus_level=True).score( - translations, [self.data["tgt"]] - ) - - latency_score = LatencyScorer().score( - [ - {"src_len": src_len, "delays": delay} - for src_len, delay in zip(self.src_lengths(), delays) - ], - start_from_zero=False, - ) - - scores = { - "BLEU": bleu_score[0], - "TER": ter_score[0], - "METEOR": meteor_score[0], - "DAL": latency_score["differentiable_average_lagging"], - "AL": latency_score["average_lagging"], - "AP": latency_score["average_proportion"], - } - - if self.output_files is not None: - try: - os.makedirs(self.output_dir, exist_ok=True) - self.write_results_to_file(translations, delays, scores) - except BaseException as be: - print(f"Failed to write results to {self.output_dir}.") - print(be) - print("Skip writing predictions") - - return scores - - def write_results_to_file(self, translations, delays, scores): - if self.output_files["text"] is not None: - with open(self.output_files["text"], "w") as f: - for line in translations: - f.write(line + "\n") - - if self.output_files["delay"] is not None: - with open(self.output_files["delay"], "w") as f: - for i, delay in enumerate(delays): - f.write( - json.dumps({"src_len": self.src_lengths()[i], "delays": delay}) - + "\n" - ) - - with open(self.output_files["scores"], "w") as f: - for key, value in scores.items(): - f.write(f"{key}, {value}\n") - - @classmethod - def _load_text_file(cls, file, split=False): - with open(file) as f: - if split: - return [r.strip().split() for r in f] - else: - return [r.strip() for r in f] - - @classmethod - def _load_text_from_json(cls, file): - list_to_return = [] - with open(file) as f: - content = json.load(f) - for item in content["utts"].values(): - list_to_return.append(item["output"]["text"].strip()) - return list_to_return - - @classmethod - def 
_load_wav_info_from_json(cls, file): - list_to_return = [] - with open(file) as f: - content = json.load(f) - for item in content["utts"].values(): - list_to_return.append( - { - "path": item["input"]["path"].strip(), - "length": item["input"]["length_ms"], - } - ) - return list_to_return - - @classmethod - def _load_wav_info_from_list(cls, file): - list_to_return = [] - with open(file) as f: - for line in f: - list_to_return.append( - { - "path": line.strip(), - } - ) - return list_to_return - - def __len__(self): - return len(self.data["tgt"]) diff --git a/examples/simultaneous_translation/eval/scorers/text_scorer.py b/examples/simultaneous_translation/eval/scorers/text_scorer.py deleted file mode 100644 index 649a2c7e5c..0000000000 --- a/examples/simultaneous_translation/eval/scorers/text_scorer.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from . 
import register_scorer -from .scorer import SimulScorer - - -@register_scorer("text") -class SimulTextScorer(SimulScorer): - def __init__(self, args): - super().__init__(args) - self.data = { - "src": self._load_text_file(args.src_file, split=True), - "tgt": self._load_text_file(args.tgt_file, split=False), - } - - def send_src(self, sent_id, *args): - if self.steps[sent_id] >= len(self.data["src"][sent_id]): - dict_to_return = { - "sent_id": sent_id, - "segment_id": self.steps[sent_id], - "segment": self.eos, - } - # Consider EOS - self.steps[sent_id] = len(self.data["src"][sent_id]) + 1 - else: - dict_to_return = { - "sent_id": sent_id, - "segment_id": self.steps[sent_id], - "segment": self.data["src"][sent_id][self.steps[sent_id]], - } - - self.steps[sent_id] += 1 - - return dict_to_return - - def src_lengths(self): - # +1 for eos - return [len(sent) + 1 for sent in self.data["src"]] diff --git a/examples/simultaneous_translation/eval/server.py b/examples/simultaneous_translation/eval/server.py deleted file mode 100644 index e44ceaff85..0000000000 --- a/examples/simultaneous_translation/eval/server.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
-import argparse -import json -import sys - -from scorers import build_scorer -from tornado import ioloop, web - - -DEFAULT_HOSTNAME = "localhost" -DEFAULT_PORT = 12321 - - -class ScorerHandler(web.RequestHandler): - def initialize(self, scorer): - self.scorer = scorer - - -class EvalSessionHandler(ScorerHandler): - def post(self): - self.scorer.reset() - - def get(self): - r = json.dumps(self.scorer.get_info()) - self.write(r) - - -class ResultHandler(ScorerHandler): - def get(self): - r = json.dumps(self.scorer.score()) - self.write(r) - - -class SourceHandler(ScorerHandler): - def get(self): - sent_id = int(self.get_argument("sent_id")) - segment_size = None - if "segment_size" in self.request.arguments: - string = self.get_argument("segment_size") - if len(string) > 0: - segment_size = int(string) - - r = json.dumps(self.scorer.send_src(int(sent_id), segment_size)) - - self.write(r) - - -class HypothesisHandler(ScorerHandler): - def put(self): - sent_id = int(self.get_argument("sent_id")) - list_of_tokens = self.request.body.decode("utf-8").strip().split() - self.scorer.recv_hyp(sent_id, list_of_tokens) - - -def add_args(): - parser = argparse.ArgumentParser() - # fmt: off - parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME, - help='Server hostname') - parser.add_argument('--port', type=int, default=DEFAULT_PORT, - help='Server port number') - - args, _ = parser.parse_known_args() - # fmt: on - return args - - -def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False): - app = web.Application( - [ - (r"/result", ResultHandler, dict(scorer=scorer)), - (r"/src", SourceHandler, dict(scorer=scorer)), - (r"/hypo", HypothesisHandler, dict(scorer=scorer)), - (r"/", EvalSessionHandler, dict(scorer=scorer)), - ], - debug=debug, - ) - app.listen(port, max_buffer_size=1024 ** 3) - sys.stdout.write(f"Evaluation Server Started. 
Listening to port {port}\n") - ioloop.IOLoop.current().start() - - -if __name__ == "__main__": - args = add_args() - scorer = build_scorer(args) - start_server(scorer, args.hostname, args.port, args.debug) diff --git a/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py deleted file mode 100644 index 45df5fa227..0000000000 --- a/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os - -from fairseq import checkpoint_utils, utils, tasks - -from . import DEFAULT_EOS, GET, SEND -from .agent import Agent - - -class SimulTransAgent(Agent): - def __init__(self, args): - # Load Model - self.load_model(args) - - # build word spliter - self.build_word_splitter(args) - - self.max_len = args.max_len - - self.eos = DEFAULT_EOS - - @staticmethod - def add_args(parser): - parser.add_argument( - "--model-path", - type=str, - required=True, - help="path to your pretrained model.", - ) - parser.add_argument( - "--data-bin", type=str, required=True, help="Path of data binary" - ) - parser.add_argument( - "--user-dir", - type=str, - default="example/simultaneous_translation", - help="User directory for simultaneous translation", - ) - parser.add_argument( - "--src-splitter-type", - type=str, - default=None, - help="Subword splitter type for source text", - ) - parser.add_argument( - "--tgt-splitter-type", - type=str, - default=None, - help="Subword splitter type for target text", - ) - parser.add_argument( - "--src-splitter-path", - type=str, - default=None, - help="Subword splitter model path for source text", - ) - parser.add_argument( - "--tgt-splitter-path", - type=str, - default=None, - help="Subword splitter model path 
for target text", - ) - parser.add_argument( - "--max-len", - type=int, - default=150, - help="Maximum length difference between source and target prediction", - ) - parser.add_argument( - "--model-overrides", - default="{}", - type=str, - metavar="DICT", - help="A dictionary used to override model args at generation " - "that were used during model training", - ) - # fmt: on - return parser - - def load_dictionary(self, task): - raise NotImplementedError - - def load_model(self, args): - args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..") - utils.import_user_module(args) - filename = args.model_path - if not os.path.exists(filename): - raise IOError("Model file not found: {}".format(filename)) - - state = checkpoint_utils.load_checkpoint_to_cpu( - filename, json.loads(args.model_overrides) - ) - - saved_args = state["args"] - saved_args.data = args.data_bin - - task = tasks.setup_task(saved_args) - - # build model for ensemble - self.model = task.build_model(saved_args) - self.model.load_state_dict(state["model"], strict=True) - - # Set dictionary - self.load_dictionary(task) - - def init_states(self): - return { - "indices": {"src": [], "tgt": []}, - "tokens": {"src": [], "tgt": []}, - "segments": {"src": [], "tgt": []}, - "steps": {"src": 0, "tgt": 0}, - "finished": False, - "finish_read": False, - "model_states": {}, - } - - def update_states(self, states, new_state): - raise NotImplementedError - - def policy(self, states): - # Read and Write policy - action = None - - while action is None: - if states["finished"]: - # Finish the hypo by sending eos to server - return self.finish_action() - - # Model make decision given current states - decision = self.model.decision_from_states(states) - - if decision == 0 and not self.finish_read(states): - # READ - action = self.read_action(states) - else: - # WRITE - action = self.write_action(states) - - # None means we make decision again but not sending server anything - # This happened when read a 
buffered token - # Or predict a subword - return action - - def finish_read(self, states): - raise NotImplementedError - - def write_action(self, states): - token, index = self.model.predict_from_states(states) - - if ( - index == self.dict["tgt"].eos() - or len(states["tokens"]["tgt"]) > self.max_len - ): - # Finish this sentence is predict EOS - states["finished"] = True - end_idx_last_full_word = self._target_length(states) - - else: - states["tokens"]["tgt"] += [token] - end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( - states["tokens"]["tgt"] - ) - self._append_indices(states, [index], "tgt") - - if end_idx_last_full_word > states["steps"]["tgt"]: - # Only sent detokenized full words to the server - word = self.word_splitter["tgt"].merge( - states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] - ) - states["steps"]["tgt"] = end_idx_last_full_word - states["segments"]["tgt"] += [word] - - return {"key": SEND, "value": word} - else: - return None - - def read_action(self, states): - return {"key": GET, "value": None} - - def finish_action(self): - return {"key": SEND, "value": DEFAULT_EOS} - - def reset(self): - pass - - def finish_eval(self, states, new_state): - if len(new_state) == 0 and len(states["indices"]["src"]) == 0: - return True - return False - - def _append_indices(self, states, new_indices, key): - states["indices"][key] += new_indices - - def _target_length(self, states): - return len(states["tokens"]["tgt"]) diff --git a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py index aa3dba31e2..051785238f 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from examples.simultaneous_translation.utils.latency import LatencyTraining from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import ( LabelSmoothedCrossEntropyCriterion, @@ -31,6 +30,7 @@ def __init__( super().__init__( task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy ) + from examples.simultaneous_translation.utils.latency import LatencyTraining self.eps = label_smoothing self.latency_weight_avg = latency_weight_avg self.latency_weight_avg_type = latency_weight_avg_type diff --git a/fairseq/tasks/simultaneous_translation.py b/fairseq/tasks/simultaneous_translation.py new file mode 100644 index 0000000000..11c7dc1ea9 --- /dev/null +++ b/fairseq/tasks/simultaneous_translation.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.tasks.translation import ( + TranslationTask, TranslationConfig +) + +try: + import examples.simultaneous_translation # noqa + import_successful = True +except BaseException: + import_successful = False + + +logger = logging.getLogger(__name__) + + +def check_import(flag): + if not flag: + raise ImportError( + "'examples.simultaneous_translation' is not correctly imported. " + "Please consider running `pip install -e $FAIRSEQ_DIR`."
+ ) + + +@register_task("simul_speech_to_text") +class SimulSpeechToTextTask(SpeechToTextTask): + def __init__(self, args, tgt_dict): + check_import(import_successful) + super().__init__(args, tgt_dict) + + +@register_task("simul_text_to_text", dataclass=TranslationConfig) +class SimulTextToTextTask(TranslationTask): + def __init__(self, cfg, src_dict, tgt_dict): + check_import(import_successful) + super().__init__(cfg, src_dict, tgt_dict) From a20dc364647c94417f493ba0b0c8d1e1834e67eb Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Fri, 2 Apr 2021 14:44:20 -0700 Subject: [PATCH 542/707] Several updates on simul st (#1774) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1774 Reviewed By: jmp84 Differential Revision: D27529935 Pulled By: xutaima fbshipit-source-id: 35433cc1d862440ea110e084007ba187149bc193 --- .../docs/simulst_mustc_example.md | 24 ++++++--- examples/speech_to_text/seg_mustc_data.py | 54 +++++++++++++++++++ .../agents/fairseq_simul_st_agent.py | 2 +- 3 files changed, 71 insertions(+), 9 deletions(-) create mode 100644 examples/speech_to_text/seg_mustc_data.py diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index d83ec086f9..52ca9ac062 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -114,6 +114,14 @@ Each line of target file `${TGT_FILE}` is the translation for each audio file in Translation_1 Translation_2 ``` +The evaluation runs on the original MUSTC segmentation. +The following command generates the wav list and text file for an evaluation set `${SPLIT}` (choose from `dev`, `tst-COMMON` and `tst-HE`) in MUSTC and writes them to `${EVAL_DATA}`.
+```bash +python ${FAIRSEQ}/examples/speech_to_text/seg_mustc_data.py \ + --data-root ${MUSTC_ROOT} --lang de \ + --split ${SPLIT} --task st \ + --output ${EVAL_DATA} +``` The `--data-bin` and `--config` should be the same in previous section if you prepare the data from the scratch. If only for evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). It contains @@ -152,19 +160,19 @@ Notice that once a `--data-bin` is set, the `--config` is the base name of the c Set `--model-path` to the model checkpoint. A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. -The output should be similar as follow: +The result of this model on `tst-COMMON` is: ```bash { "Quality": { - "BLEU": 12.79214535384013 + "BLEU": 13.94974229366959 }, "Latency": { - "AL": 1669.5778120018108, - "AL_CA": 2077.9027656104813, - "AP": 0.7652936521983029, - "AP_CA": 0.8891561507382866, - "DAL": 2028.1566141735727, - "DAL_CA": 2497.336430059716 + "AL": 1751.8031870037803, + "AL_CA": 2338.5911762796536, + "AP": 0.7931395378788959, + "AP_CA": 0.9405103863210942, + "DAL": 1987.7811616943081, + "DAL_CA": 2425.2751560926167 } } ``` diff --git a/examples/speech_to_text/seg_mustc_data.py b/examples/speech_to_text/seg_mustc_data.py new file mode 100644 index 0000000000..1ee665d639 --- /dev/null +++ b/examples/speech_to_text/seg_mustc_data.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +from pathlib import Path +import soundfile as sf +from examples.speech_to_text.prep_mustc_data import ( + MUSTC +) + +from tqdm import tqdm + +log = logging.getLogger(__name__) + + +def main(args): + root = Path(args.data_root).absolute() + lang = args.lang + split = args.split + + cur_root = root / f"en-{lang}" + assert cur_root.is_dir(), ( + f"{cur_root.as_posix()} does not exist. Skipped." + ) + + dataset = MUSTC(root.as_posix(), lang, split) + output = Path(args.output).absolute() + output.mkdir(exist_ok=True) + f_text = open(output / f"{split}.{lang}", "w") + f_wav_list = open(output / f"{split}.wav_list", "w") + for waveform, sample_rate, _, text, _, utt_id in tqdm(dataset): + sf.write( + output / f"{utt_id}.wav", + waveform.squeeze(0).numpy(), + samplerate=int(sample_rate) + ) + f_text.write(text + "\n") + f_wav_list.write(str(output / f"{utt_id}.wav") + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument("--task", required=True, type=str, choices=["asr", "st"]) + parser.add_argument("--lang", required=True, type=str) + parser.add_argument("--output", required=True, type=str) + parser.add_argument("--split", required=True, choices=MUSTC.SPLITS) + args = parser.parse_args() + + main(args) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 58e38963b5..61617a1739 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -334,7 +334,7 @@ def policy(self, states): torch.cuda.empty_cache() - if outputs["action"] == 0: + if outputs.action == 0: return READ_ACTION else: return WRITE_ACTION From aa5f0119a383e013e56ae5d88e4a7aff0e67f0f9 Mon Sep 17 00:00:00 2001 
From: Jongsoo Park <jongsoo@fb.com> Date: Fri, 2 Apr 2021 19:51:12 -0700 Subject: [PATCH 543/707] remove import logging (#1779) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1779 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3436 Remove residue of changes not meant to be landed in D26873232 (https://github.com/pytorch/fairseq/commit/edcef1306b48e7fa9bf84dcbec25171a1e57a5dc) Reviewed By: myleott Differential Revision: D27543742 fbshipit-source-id: ebe47baba27ec2446ef0b855d700b68373f738d8 --- fairseq/optim/cpu_adam.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index 21336ef59b..e36bccf123 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -12,7 +12,6 @@ from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from omegaconf import II, DictConfig -import logging try: From 8a42c243a7835f1f0da06efd323ddf6201eb1272 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 7 Apr 2021 16:03:55 -0700 Subject: [PATCH 544/707] Fix memory regression when using FSDP (#1788) Summary: FSDP overloads nn.Module.apply and inserts a step where it all-gathers params before calling apply (added in https://github.com/facebookresearch/fairscale/commit/fa1b85fbbe75f058b39f1bcf027de42e6ddbd487). This is important for the typical use case of nn.Module.apply -- weight initialization. But here we were using apply totally unnecessarily. 
It's easier to just loop over all the modules and call set_num_updates directly. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1788 Reviewed By: sshleifer Differential Revision: D27622168 Pulled By: myleott fbshipit-source-id: f5462107ad251cf7834b20a0eaccbe2f685da8f8 --- fairseq/models/fairseq_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 171a8a40f1..e55c7ba1ad 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -153,13 +153,10 @@ def do_upgrade(m, prefix): def set_num_updates(self, num_updates): """State from trainer to pass along to model at every update.""" - - def _apply(m): + for m in self.modules(): if hasattr(m, "set_num_updates") and m != self: m.set_num_updates(num_updates) - self.apply(_apply) - def prepare_for_inference_(self, cfg: DictConfig): """Prepare model for inference.""" kwargs = {} From ee0d5a0f65a25e5f5372776402aac5cb9c4adbf1 Mon Sep 17 00:00:00 2001 From: Guillaume Wenzek <guw@fb.com> Date: Wed, 7 Apr 2021 19:57:12 -0700 Subject: [PATCH 545/707] fixup Obt 2 (#1614) (#1791) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: restore len() checking that was lost in a merge # Before submitting - [x] Was this discussed/approved via a Github issue? This is a fix for https://github.com/fairinternal/fairseq-py/issues/1614. The regression was identified in https://github.com/pytorch/fairseq/issues/3364 - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes regression in https://github.com/fairinternal/fairseq-py/pull/1614/files#diff-6e65327f729a8658d627b762ec14902e25927698f35e5495e6b8e3a1bfcfd7afR886-R943 This change was meant to be a no-op but I forgot the len checking.
## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1791 Reviewed By: jmp84 Differential Revision: D27629177 Pulled By: gwenzek fbshipit-source-id: fe6fdd486b1a61f86547d1214180d7fd3042e51b --- fairseq/models/transformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index eff5ba7b8f..29fdeb70bf 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -885,12 +885,13 @@ def extract_features_scriptable( enc: Optional[Tensor] = None padding_mask: Optional[Tensor] = None - if encoder_out is not None: + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: enc = encoder_out["encoder_out"][0] - padding_mask = encoder_out["encoder_padding_mask"][0] assert ( enc.size()[1] == bs ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] # embed positions positions = None From acf312418e4718996a103d67bd57516938137a7d Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Mon, 12 Apr 2021 09:26:40 -0700 Subject: [PATCH 546/707] 'Fix a bug when src len is smaller than wait k lagging (#1795) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1795 Reviewed By: jmp84 Differential Revision: D27701022 Pulled By: xutaima fbshipit-source-id: 2267402077fbc7c560e260f1ca730ba102d0c7cd --- .../modules/monotonic_multihead_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index
b487f14a98..71c818e5cc 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -930,6 +930,8 @@ def p_choose( max_src_len, bsz, _ = key.size() if max_src_len < self.waitk_lagging: + if incremental_state is not None: + tgt_len = 1 return query.new_zeros( bsz * self.num_heads, tgt_len, max_src_len ) From 57560cb52c3a86fadbf315e24bc4ec174056a444 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Wed, 14 Apr 2021 01:49:14 -0700 Subject: [PATCH 547/707] enable manifold checkpoints with --keep-interval-updates Summary: Useful to enable --keep-interval-updates with Manifold checkpoints Reviewed By: myleott Differential Revision: D27577116 fbshipit-source-id: 6d4ae5aaccc07ecaed8ba6a333b6ab78b148187a --- fairseq/checkpoint_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 7e1b8479d1..4e02883e57 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -120,6 +120,8 @@ def is_better(a, b): for old_chk in checkpoints[cfg.keep_interval_updates :]: if os.path.lexists(old_chk): os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order @@ -394,7 +396,7 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): descending order. """ pt_regexp = re.compile(pattern) - files = os.listdir(path) + files = PathManager.ls(path) entries = [] for i, f in enumerate(files): From dbfb7103bdc41cdaa9bfea4cf5dedd10c9750647 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Wed, 14 Apr 2021 01:49:14 -0700 Subject: [PATCH 548/707] add keep_interval_updates_pattern Summary: Motivation: I want to save checkpoints frequently, due to unreliable jobs in FB cluster that restart frequently. 
I want to do this without spamming Manifold storage, but still save some historical checkpoints (i.e. every 10k updates), so I can track how WER evolves over time. To save frequently, I can use a small --save-interval-updates. To delete old checkpoints to save storage, I can use --keep-interval-updates. However, this deletes all old checkpoints. This is where --keep-interval-updates-pattern comes in. If I now do: ``` --save-interval-updates 1000 --keep-interval-updates 1 --keep-interval-updates-pattern 10000 ``` This will: 1. checkpoint every 1000 updates so that job restarts don't impact us significantly 2. keep only the latest checkpoint to avoid saving a bunch of huge models in manifold 3. make an exception for #2 for every 10k updates so we can track WER over time Reviewed By: myleott Differential Revision: D27578403 fbshipit-source-id: 5aec2dc9a22778015f7a3daa017210190af81240 --- fairseq/checkpoint_utils.py | 20 +++++++++++++++----- fairseq/dataclass/configs.py | 8 ++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 4e02883e57..ac6c7339d4 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -114,9 +114,16 @@ def is_better(a, b): if not end_of_epoch and cfg.keep_interval_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) - ) + if cfg.keep_interval_updates_pattern == -1: + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + ) + else: + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), keep_match=True + ) + checkpoints = [x[0] for x in checkpoints if x[1] % cfg.keep_interval_updates_pattern != 0] + for old_chk in checkpoints[cfg.keep_interval_updates :]: if os.path.lexists(old_chk): os.remove(old_chk) @@ -388,7 +395,7 @@ 
def load_model_ensemble_and_task( return ensemble, cfg, task -def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False): """Retrieves all checkpoints found in `path` directory. Checkpoints are identified by matching filename to the specified pattern. If @@ -404,7 +411,10 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): if m is not None: idx = float(m.group(1)) if len(m.groups()) > 0 else i entries.append((idx, m.group(0))) - return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + if keep_match: + return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)] + else: + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] def torch_persistent_save(obj, filename, async_write: bool = False): diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index be9f7c5af3..f1ca26514e 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -568,6 +568,14 @@ class CheckpointConfig(FairseqDataclass): "help": "keep the last N checkpoints saved with --save-interval-updates" }, ) + keep_interval_updates_pattern: int = field( + default=-1, + metadata={ + "help": "when used with --keep-interval-updates, skips deleting " + "any checkpoints with update X where " + "X % keep_interval_updates_pattern == 0" + }, + ) keep_last_epochs: int = field( default=-1, metadata={"help": "keep last N epoch checkpoints"} ) From 436166a00c2ecd1215df258f022608947cca2aa8 Mon Sep 17 00:00:00 2001 From: Guillaume Wenzek <guw@fb.com> Date: Wed, 14 Apr 2021 04:59:04 -0700 Subject: [PATCH 549/707] fix MultiHeadAttention assert (#1798) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/fairinternal/fairseq-py/issues/1538. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1798 Reviewed By: myleott Differential Revision: D27710902 Pulled By: gwenzek fbshipit-source-id: 2efdf645bb30e4cf6653c48371bfca8df6f94eaf --- fairseq/modules/multihead_attention.py | 7 ++- tests/test_transformer.py | 65 ++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 tests/test_transformer.py diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index d84c7e078d..b168a890ae 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -151,12 +151,11 @@ def forward( assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] if key is not None: - src_len, key_bsz, key_embed_dim = key.size() + src_len, key_bsz, _ = key.size() if not torch.jit.is_scripting(): - assert (key_bsz, key_embed_dim) == (bsz, embed_dim) + assert key_bsz == bsz assert value is not None - assert (src_len, bsz, embed_dim) == value.shape - + assert (src_len, bsz) == value.shape[:2] if ( not self.onnx_trace diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 0000000000..de5c5bdbd4 --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,65 @@ +import argparse +import unittest +from typing import Any, Dict, Sequence + +import torch +from fairseq.models import transformer + +from
tests.test_roberta import FakeTask + + +def mk_sample(tok: Sequence[int] = None, batch_size: int = 2) -> Dict[str, Any]: + if not tok: + tok = [10, 11, 12, 13, 14, 15, 2] + + batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size) + sample = { + "net_input": { + "src_tokens": batch, + "prev_output_tokens": batch, + "src_lengths": torch.tensor( + [len(tok)] * batch_size, dtype=torch.long, device=batch.device + ), + }, + "target": batch[:, 1:], + } + return sample + + +def mk_transformer(**extra_args: Any): + overrides = { + # Use characteristics dimensions + "encoder_embed_dim": 12, + "encoder_ffn_embed_dim": 14, + "decoder_embed_dim": 12, + "decoder_ffn_embed_dim": 14, + # Disable dropout so we have comparable tests. + "dropout": 0, + "attention_dropout": 0, + "activation_dropout": 0, + "encoder_layerdrop": 0, + } + overrides.update(extra_args) + # Overrides the defaults from the parser + args = argparse.Namespace(**overrides) + transformer.tiny_architecture(args) + + torch.manual_seed(0) + task = FakeTask(args) + return transformer.TransformerModel.build_model(args, task) + + +class TransformerTestCase(unittest.TestCase): + def test_forward_backward(self): + model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12) + sample = mk_sample() + o, _ = model.forward(**sample["net_input"]) + loss = o.sum() + loss.backward() + + def test_different_encoder_decoder_embed_dim(self): + model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16) + sample = mk_sample() + o, _ = model.forward(**sample["net_input"]) + loss = o.sum() + loss.backward() From fc90910314ce02fe168bdbe127e08a14e944a6fd Mon Sep 17 00:00:00 2001 From: Sujit Verma <sujitv@fb.com> Date: Wed, 14 Apr 2021 21:53:25 -0700 Subject: [PATCH 550/707] Migrating fairseq-py from fvcore to iopath. Summary: Migrating fairseq-py from fvcore to iopath. 
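A minimal sketch of the optional-backend pattern this migration touches, assuming only that `iopath` may or may not be installed. The wrapper below mirrors the shape of `fairseq/file_io.py` but is reduced to two methods for illustration; it is not the full class from the diff.

```python
import os
from typing import List

try:
    # Optional backend: only available when iopath is installed.
    from iopath.common.file_io import g_pathmgr as IOPathManager
except ImportError:
    IOPathManager = None

# Pitfall: rebinding `IOPathManager = None` at module level after the
# try/except above would discard a successful import, so every call
# below would silently take the builtin fallback.


class PathManager:
    """Illustrative wrapper: delegate to iopath if present, else use builtins."""

    @staticmethod
    def exists(path: str) -> bool:
        if IOPathManager:
            return IOPathManager.exists(path)
        return os.path.exists(path)

    @staticmethod
    def ls(path: str) -> List[str]:
        if IOPathManager:
            return IOPathManager.ls(path)
        return os.listdir(path)
```

Callers only ever touch `PathManager`, so the same code should run in an OSS environment without `iopath`. The module-level rebinding mentioned in the comment appears to be the line that the follow-up patch ("Fix bug from iopath migration") removes.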
Reviewed By: myleott Differential Revision: D27109864 fbshipit-source-id: 041177c1bc9b5793b2ce0ecab87692097f3f353b --- fairseq/file_io.py | 68 +++++++++++++++++++++---------------------- tests/test_file_io.py | 10 +++---- 2 files changed, 38 insertions(+), 40 deletions(-) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 9a78ab505d..93c931093c 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -15,13 +15,13 @@ try: - from fvcore.common.file_io import PathManager as FVCorePathManager + from iopath.common.file_io import g_pathmgr as IOPathManager try: # [FB only - for now] AWS PathHandler for PathManager from .fb_pathhandlers import S3PathHandler - FVCorePathManager.register_handler(S3PathHandler()) + IOPathManager.register_handler(S3PathHandler()) except KeyError: logging.warning("S3PathHandler already registered.") except ImportError: @@ -30,15 +30,15 @@ ) except ImportError: - FVCorePathManager = None + IOPathManager = None -IOPathPathManager = None +IOPathManager = None class PathManager: """ Wrapper for insulating OSS I/O (using Python builtin operations) from - fvcore's PathManager abstraction (for transparently handling various + iopath's PathManager abstraction (for transparently handling various internal backends). 
""" @@ -51,8 +51,8 @@ def open( errors: Optional[str] = None, newline: Optional[str] = None, ): - if FVCorePathManager: - return FVCorePathManager.open( + if IOPathManager: + return IOPathManager.open( path=path, mode=mode, buffering=buffering, @@ -71,46 +71,46 @@ def open( @staticmethod def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool: - if FVCorePathManager: - return FVCorePathManager.copy( + if IOPathManager: + return IOPathManager.copy( src_path=src_path, dst_path=dst_path, overwrite=overwrite ) return shutil.copyfile(src_path, dst_path) @staticmethod def get_local_path(path: str, **kwargs) -> str: - if FVCorePathManager: - return FVCorePathManager.get_local_path(path, **kwargs) + if IOPathManager: + return IOPathManager.get_local_path(path, **kwargs) return path @staticmethod def exists(path: str) -> bool: - if FVCorePathManager: - return FVCorePathManager.exists(path) + if IOPathManager: + return IOPathManager.exists(path) return os.path.exists(path) @staticmethod def isfile(path: str) -> bool: - if FVCorePathManager: - return FVCorePathManager.isfile(path) + if IOPathManager: + return IOPathManager.isfile(path) return os.path.isfile(path) @staticmethod def ls(path: str) -> List[str]: - if FVCorePathManager: - return FVCorePathManager.ls(path) + if IOPathManager: + return IOPathManager.ls(path) return os.listdir(path) @staticmethod def mkdirs(path: str) -> None: - if FVCorePathManager: - return FVCorePathManager.mkdirs(path) + if IOPathManager: + return IOPathManager.mkdirs(path) os.makedirs(path, exist_ok=True) @staticmethod def rm(path: str) -> None: - if FVCorePathManager: - return FVCorePathManager.rm(path) + if IOPathManager: + return IOPathManager.rm(path) os.remove(path) @staticmethod @@ -120,15 +120,15 @@ def chmod(path: str, mode: int) -> None: @staticmethod def register_handler(handler) -> None: - if FVCorePathManager: - return FVCorePathManager.register_handler(handler=handler) + if IOPathManager: + return 
IOPathManager.register_handler(handler=handler) @staticmethod def copy_from_local( local_path: str, dst_path: str, overwrite: bool = False, **kwargs ) -> None: - if FVCorePathManager: - return FVCorePathManager.copy_from_local( + if IOPathManager: + return IOPathManager.copy_from_local( local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs ) return shutil.copyfile(local_path, dst_path) @@ -136,8 +136,8 @@ def copy_from_local( @staticmethod def path_requires_pathmanager(path: str) -> bool: """Do we require PathManager to access given path?""" - if FVCorePathManager: - for p in FVCorePathManager._path_handlers.keys(): + if IOPathManager: + for p in IOPathManager._path_handlers.keys(): if path.startswith(p): return True return False @@ -166,15 +166,15 @@ def opena( """ Return file descriptor with asynchronous write operations. """ - global IOPathPathManager - if not IOPathPathManager: + global IOPathManager + if not IOPathManager: logging.info("ioPath is initializing PathManager.") try: from iopath.common.file_io import PathManager - IOPathPathManager = PathManager() + IOPathManager = PathManager() except Exception: logging.exception("Failed to initialize ioPath PathManager object.") - return IOPathPathManager.opena( + return IOPathManager.opena( path=path, mode=mode, buffering=buffering, @@ -190,7 +190,7 @@ def async_close() -> bool: NOTE: `PathManager.async_close()` must be called at the end of any script that uses `PathManager.opena(...)`. """ - global IOPathPathManager - if IOPathPathManager: - return IOPathPathManager.async_close() + global IOPathManager + if IOPathManager: + return IOPathManager.async_close() return False diff --git a/tests/test_file_io.py b/tests/test_file_io.py index 8ebbba4a2e..f39e2e1d58 100644 --- a/tests/test_file_io.py +++ b/tests/test_file_io.py @@ -38,8 +38,8 @@ def test_file_io(self): self.assertEqual(s, self._tmpfile_contents) def test_file_io_oss(self): - # Mock fvcore to simulate oss environment. 
- sys.modules["fvcore"] = MagicMock() + # Mock iopath to simulate oss environment. + sys.modules["iopath"] = MagicMock() from fairseq.file_io import PathManager with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: @@ -49,14 +49,12 @@ def test_file_io_oss(self): def test_file_io_async(self): # ioPath `PathManager` is initialized after the first `opena` call. try: - from fairseq.file_io import IOPathPathManager, PathManager + from fairseq.file_io import IOPathManager, PathManager - self.assertIsNone(IOPathPathManager) + self.assertIsNone(IOPathManager) _asyncfile = os.path.join(self._tmpdir, "async.txt") f = PathManager.opena(_asyncfile, "wb") f.close() - from fairseq.file_io import IOPathPathManager - self.assertIsNotNone(IOPathPathManager) finally: self.assertTrue(PathManager.async_close()) From 069763813097aa814c8c4e12d4cab4b321575b8d Mon Sep 17 00:00:00 2001 From: Sujit Verma <sujitv@fb.com> Date: Thu, 15 Apr 2021 16:59:08 -0700 Subject: [PATCH 551/707] Fix bug from iopath migration. Summary: Fix bug from iopath migration. Reviewed By: myleott, shuliuncsu Differential Revision: D27808039 fbshipit-source-id: 14b6c9ca3a461b00c2528d55b7269980324b3e10 --- fairseq/file_io.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 93c931093c..dba663d4aa 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -32,8 +32,6 @@ except ImportError: IOPathManager = None -IOPathManager = None - class PathManager: """ From 3a90a859d4dfdbf13f15399be12a1928aa2c54ff Mon Sep 17 00:00:00 2001 From: Shu Liu <shuliu@fb.com> Date: Fri, 16 Apr 2021 15:11:19 -0700 Subject: [PATCH 552/707] Add manifold support for fairseq file_io Summary: Recently fairseq file_io migrated to iopath in D27109864 (https://github.com/pytorch/fairseq/commit/fc90910314ce02fe168bdbe127e08a14e944a6fd). New IOPathManager doesn't have manifold support by default. Add manifold handler to fix the issue. 
Reviewed By: myleott Differential Revision: D27809504 fbshipit-source-id: 5cbf4440ed734132f865096c45cd3e47ccb6142d --- fairseq/file_io.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index dba663d4aa..6266e6a1d8 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -29,6 +29,14 @@ "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." ) + try: + # [FB only] Add extra FB only PathHandlers for PathManager + import fairseq.fb_file_io as fb_file_io + + fb_file_io.update_path_manager(IOPathManager) + except ImportError: + pass + except ImportError: IOPathManager = None From 89371294e54ef8c306f19733f2e8bab8233c401e Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Sat, 17 Apr 2021 02:02:52 -0700 Subject: [PATCH 553/707] TALNet: Consistency loss Summary: During training, for each batch, apply two different data augmentations, and require the two model outputs to be similar. The dissimilarity between the outputs is used as an extra loss term. To apply two different augmentations to the same batch, we need to create batches containing two copies of the same indices, like [a, b, c, a, b, c]. I make this the responsibility of the **batch sampler**. TALNet uses a `DataBalancingEpochBatchIterator` to generate batches. This iterator creates a `_batch_sampler` within itself; this diff adds a `dups` argument to generate duplicate indices. If you're using the generic `EpochBatchIterator`, which accepts a batch sampler directly as input, you will need to modify the code of your own batch sampler to generate duplicate indices. This is actually quite easy: if `batch_sampler(n)` creates a generator that returns n indices at a time, then `(batch * 2 for batch in batch_sampler(n // 2))` creates a generator that returns two copies of n/2 indices at a time. The consistency loss is implemented by the `consistency_loss` function.
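The batch-duplication trick described above can be sketched as follows. This is a toy illustration, not TALNet's sampler: `batch_sampler` here is a hypothetical chunking generator over a list of indices, and `duplicated_batches` is a name made up for this example.

```python
from typing import Iterator, List


def batch_sampler(indices: List[int], n: int) -> Iterator[List[int]]:
    """Hypothetical sampler: yields consecutive chunks of n indices."""
    for start in range(0, len(indices), n):
        yield indices[start:start + n]


def duplicated_batches(indices: List[int], n: int) -> Iterator[List[int]]:
    # Each batch of size n holds two copies of n // 2 indices,
    # e.g. [a, b, c, a, b, c], so each copy can be augmented differently.
    return (batch * 2 for batch in batch_sampler(indices, n // 2))


batches = list(duplicated_batches(list(range(6)), 6))
```

Note that `batch * 2` repeats the indices only when each batch is a Python list; on a NumPy array the same expression would multiply the values elementwise, so such a sampler would need concatenation instead. The model outputs for the two copies are what the `consistency_loss` function then compares.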
To use `consistency_loss`, call it in your model's `forward` method and put the result in the return dict. Then expose the loss in the `get_losses` or `get_extra_losses` method of your model, so criterion objects can read it. Two types of consistency losses are provided: 1. KL divergence; 2. Squared difference of the probabilities. I've observed significant gains in TALNet's MAP with both types; KL divergence is slightly better. Reviewed By: nayansinghal Differential Revision: D27179650 fbshipit-source-id: 76899335b029fab67bbd7941429c9bd1baf52d65 --- fairseq/criterions/wav2vec_criterion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index f682508cb1..521d0cf1ad 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -13,6 +13,7 @@ from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass from fairseq.logging.meters import safe_round +from fairseq.utils import is_xla_tensor @dataclass @@ -31,7 +32,6 @@ class Wav2VecCriterionConfig(FairseqDataclass): default_factory=lambda: [], metadata={"help": "output keys to log"}, ) -from fairseq.utils import index_put, is_xla_tensor @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) class Wav2vecCriterion(FairseqCriterion): From 3dd3940dc6105c8bf50f3a1be810cc035d2f105d Mon Sep 17 00:00:00 2001 From: Ning Dong <dnn@fb.com> Date: Mon, 19 Apr 2021 03:50:13 -0700 Subject: [PATCH 554/707] Support more p_choose strategies in FixedStrideMonotonicAttention Summary: Previously, polymorphism in this class relied on inheritance (calling into `super().p_choose`); however, `super()` is not TorchScript (TS) friendly. This diff moves the p_choose() methods to a separate util file and uses a variable to toggle between different strategies.
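A minimal sketch of the dispatch pattern this summary describes, in pure Python rather than torch so it stays self-contained. The two free functions are simplified stand-ins for `p_choose_strategy.waitk` and `p_choose_strategy.hard_aligned` (no padding mask, no incremental state, no training noise), and the class name `FixedStrideAttentionSketch` is hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def waitk(tgt_len: int, src_len: int, k: int) -> Matrix:
    """Toy wait-k: row n has a single 1.0 at column n + k - 1, if it exists."""
    rows = []
    for n in range(tgt_len):
        col = n + k - 1
        rows.append([1.0 if j == col else 0.0 for j in range(src_len)])
    return rows


def hard_aligned(attn_energy: Matrix) -> Matrix:
    """Toy hard-aligned: elementwise sigmoid of the energies (noise omitted)."""
    return [[1.0 / (1.0 + math.exp(-e)) for e in row] for row in attn_energy]


class FixedStrideAttentionSketch:
    def __init__(self, strategy: str, waitk_lagging: int = 3):
        # A plain attribute selects the strategy; no super() call is needed.
        self.strategy = strategy
        self.waitk_lagging = waitk_lagging

    def p_choose(self, attn_energy: Matrix) -> Matrix:
        tgt_len, src_len = len(attn_energy), len(attn_energy[0])
        if self.strategy == "waitk":
            return waitk(tgt_len, src_len, self.waitk_lagging)
        # hard_aligned or infinite_lookback
        return hard_aligned(attn_energy)
```

Because the strategies are module-level functions that take everything they need as arguments, the class method reduces to a branch on `self.strategy`, which is the shape the diff below gives `FixedStrideMonotonicAttention.p_choose`.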
Reviewed By: jmp84, sravyapopuri388 Differential Revision: D27827168 fbshipit-source-id: 3e85028796e3af0c02f1d77b93a2a3896825b9b1 --- .../modules/fixed_pre_decision.py | 33 ++++- .../modules/monotonic_multihead_attention.py | 109 +-------------- .../utils/p_choose_strategy.py | 124 ++++++++++++++++++ 3 files changed, 154 insertions(+), 112 deletions(-) create mode 100644 examples/simultaneous_translation/utils/p_choose_strategy.py diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py index 0e9dfb6dfd..dd29c031b3 100644 --- a/examples/simultaneous_translation/modules/fixed_pre_decision.py +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -12,11 +12,16 @@ MonotonicMultiheadAttentionInfiniteLookback, ) from typing import Dict, Optional +from examples.simultaneous_translation.utils import p_choose_strategy def fixed_pooling_monotonic_attention(monotonic_attention): def create_model(monotonic_attention, klass): class FixedStrideMonotonicAttention(monotonic_attention): def __init__(self, args): + self.waitk_lagging = 0 + self.num_heads = 0 + self.noise_mean = 0.0 + self.noise_var = 0.0 super().__init__(args) self.pre_decision_type = args.fixed_pre_decision_type self.pre_decision_ratio = args.fixed_pre_decision_ratio @@ -24,6 +29,8 @@ def __init__(self, args): if self.pre_decision_ratio == 1: return + self.strategy = args.simul_type + if args.fixed_pre_decision_type == "average": self.pooling_layer = torch.nn.AvgPool1d( kernel_size=self.pre_decision_ratio, @@ -143,12 +150,26 @@ def p_choose( batch_size = query.size(1) if self.pre_decision_ratio == 1: - return self.p_choose_waitk( - query, - key, - key_padding_mask, - incremental_state=incremental_state, - ) + if self.strategy == "waitk": + return p_choose_strategy.waitk( + query, + key, + self.waitk_lagging, + self.num_heads, + key_padding_mask, + incremental_state=incremental_state, + ) + else: # hard_aligned 
or infinite_lookback + q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic") + attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) + return p_choose_strategy.hard_aligned( + q_proj, + k_proj, + attn_energy, + self.noise_mean, + self.noise_var, + self.training + ) key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2) diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 71c818e5cc..f49b1daa2f 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -19,6 +19,7 @@ from . import register_monotonic_attention from typing import Dict, Optional +from examples.simultaneous_translation.utils import p_choose_strategy @with_incremental_state class MonotonicAttention(nn.Module): @@ -767,21 +768,7 @@ def p_choose( # attention energy attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) - noise = 0 - - if self.training: - # add noise here to encourage discretness - noise = ( - torch.normal(self.noise_mean, self.noise_var, attn_energy.size()) - .type_as(attn_energy) - .to(attn_energy.device) - ) - - p_choose = torch.sigmoid(attn_energy + noise) - _, _, tgt_len, src_len = p_choose.size() - - # p_choose: bsz * self.num_heads, tgt_len, src_len - return p_choose.view(-1, tgt_len, src_len) + return p_choose_strategy.hard_aligned(q_proj, k_proj, attn_energy, self.noise_mean, self.noise_var, self.training) def expected_attention(self, alpha, *args): """ @@ -920,94 +907,4 @@ def p_choose( key: bsz, src_len key_padding_mask: bsz, src_len """ - if incremental_state is not None: - # Retrieve target length from incremental states - # For inference the length of query is always 1 - tgt_len = int(incremental_state["steps"]["tgt"]) - else: - tgt_len, bsz, _ = query.size() - - max_src_len, bsz, _ = key.size() - 
- if max_src_len < self.waitk_lagging: - if incremental_state is not None: - tgt_len = 1 - return query.new_zeros( - bsz * self.num_heads, tgt_len, max_src_len - ) - - # Assuming the p_choose looks like this for wait k=3 - # src_len = 6, tgt_len = 5 - # [0, 0, 1, 0, 0, 0, 0] - # [0, 0, 0, 1, 0, 0, 0] - # [0, 0, 0, 0, 1, 0, 0] - # [0, 0, 0, 0, 0, 1, 0] - # [0, 0, 0, 0, 0, 0, 1] - # linearize the p_choose matrix: - # [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0...] - # The indices of linearized matrix that equals 1 is - # 2 + 6 * 0 - # 3 + 6 * 1 - # ... - # n + src_len * n + k - 1 = n * (src_len + 1) + k - 1 - # n from 0 to tgt_len - 1 - # - # First, generate the indices (activate_indices_offset: bsz, tgt_len) - # Second, scatter a zeros tensor (bsz, tgt_len * src_len) - # with activate_indices_offset - # Third, resize the tensor to (bsz, tgt_len, src_len) - - activate_indices_offset = ( - ( - torch.arange(tgt_len) * (max_src_len + 1) - + self.waitk_lagging - 1 - ) - .unsqueeze(0) - .expand(bsz, tgt_len) - .to(query) - .long() - ) - - if key_padding_mask is not None: - if key_padding_mask[:, 0].any(): - # Left padding - activate_indices_offset += ( - key_padding_mask.sum(dim=1, keepdim=True) - ) - - # Need to clamp the indices that are too large - activate_indices_offset = ( - activate_indices_offset - .clamp( - 0, - min( - [ - tgt_len, - max_src_len - self.waitk_lagging + 1 - ] - ) * max_src_len - 1 - ) - ) - - p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query) - - p_choose = p_choose.scatter( - 1, - activate_indices_offset, - 1.0 - ).view(bsz, tgt_len, max_src_len) - - if incremental_state is not None: - p_choose = p_choose[:, -1:] - tgt_len = 1 - - # Extend to each head - p_choose = ( - p_choose.contiguous() - .unsqueeze(1) - .expand(-1, self.num_heads, -1, -1) - .contiguous() - .view(-1, tgt_len, max_src_len) - ) - - return p_choose + return p_choose_strategy.waitk(query, key, self.waitk_lagging, self.num_heads, key_padding_mask, incremental_state) diff 
--git a/examples/simultaneous_translation/utils/p_choose_strategy.py b/examples/simultaneous_translation/utils/p_choose_strategy.py new file mode 100644 index 0000000000..308227ed96 --- /dev/null +++ b/examples/simultaneous_translation/utils/p_choose_strategy.py @@ -0,0 +1,124 @@ +from typing import Optional, Dict +from torch import Tensor +import torch + + +def waitk( + query, key, waitk_lagging: int, num_heads: int, key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None +): + if incremental_state is not None: + # Retrieve target length from incremental states + # For inference the length of query is always 1 + tgt_len = incremental_state["steps"]["tgt"] + assert tgt_len is not None + tgt_len = int(tgt_len) + else: + tgt_len, bsz, _ = query.size() + + max_src_len, bsz, _ = key.size() + + if max_src_len < waitk_lagging: + if incremental_state is not None: + tgt_len = 1 + return query.new_zeros( + bsz * num_heads, tgt_len, max_src_len + ) + + # Assuming the p_choose looks like this for wait k=3 + # src_len = 6, tgt_len = 5 + # [0, 0, 1, 0, 0, 0, 0] + # [0, 0, 0, 1, 0, 0, 0] + # [0, 0, 0, 0, 1, 0, 0] + # [0, 0, 0, 0, 0, 1, 0] + # [0, 0, 0, 0, 0, 0, 1] + # linearize the p_choose matrix: + # [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0...] + # The indices of linearized matrix that equals 1 is + # 2 + 6 * 0 + # 3 + 6 * 1 + # ... 
+ # n + src_len * n + k - 1 = n * (src_len + 1) + k - 1 + # n from 0 to tgt_len - 1 + # + # First, generate the indices (activate_indices_offset: bsz, tgt_len) + # Second, scatter a zeros tensor (bsz, tgt_len * src_len) + # with activate_indices_offset + # Third, resize the tensor to (bsz, tgt_len, src_len) + + activate_indices_offset = ( + ( + torch.arange(tgt_len) * (max_src_len + 1) + + waitk_lagging - 1 + ) + .unsqueeze(0) + .expand(bsz, tgt_len) + .to(query) + .long() + ) + + if key_padding_mask is not None: + if key_padding_mask[:, 0].any(): + # Left padding + activate_indices_offset += ( + key_padding_mask.sum(dim=1, keepdim=True) + ) + + # Need to clamp the indices that are too large + activate_indices_offset = ( + activate_indices_offset + .clamp( + 0, + min( + [ + tgt_len, + max_src_len - waitk_lagging + 1 + ] + ) * max_src_len - 1 + ) + ) + + p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query) + + p_choose = p_choose.scatter( + 1, + activate_indices_offset, + 1.0 + ).view(bsz, tgt_len, max_src_len) + + if incremental_state is not None: + p_choose = p_choose[:, -1:] + tgt_len = 1 + + # Extend to each head + p_choose = ( + p_choose.contiguous() + .unsqueeze(1) + .expand(-1, num_heads, -1, -1) + .contiguous() + .view(-1, tgt_len, max_src_len) + ) + + return p_choose + + +def hard_aligned(q_proj: Optional[Tensor], k_proj: Optional[Tensor], attn_energy, noise_mean: float = 0.0, noise_var: float = 0.0, training: bool = True): + """ + Calculating step wise prob for reading and writing + 1 to read, 0 to write + """ + + noise = 0 + if training: + # add noise here to encourage discretness + noise = ( + torch.normal(noise_mean, noise_var, attn_energy.size()) + .type_as(attn_energy) + .to(attn_energy.device) + ) + + p_choose = torch.sigmoid(attn_energy + noise) + _, _, tgt_len, src_len = p_choose.size() + + # p_choose: bsz * self.num_heads, tgt_len, src_len + return p_choose.view(-1, tgt_len, src_len) From 4fc9f2be952d8f0cd476832911dc71f59b4079da Mon Sep 17 
00:00:00 2001 From: Robin Jia <robinjia@fb.com> Date: Mon, 19 Apr 2021 15:51:59 -0700 Subject: [PATCH 555/707] Fix order issue in batched BART generation (#1785) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1785 Reviewed By: robinjia Differential Revision: D27831673 Pulled By: sshleifer fbshipit-source-id: 1acf142151853d24138889c956df4531dae794a2 --- fairseq/models/bart/hub_interface.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 2ddeb763a3..9afe385b9d 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -111,7 +111,9 @@ def generate( skip_invalid_size_inputs=skip_invalid_size_inputs, **kwargs ) - res.extend(results) + for id, hypos in zip(batch['id'].tolist(), results): + res.append((id, hypos)) + res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])] return res def extract_features( From f6f220e917a0745bad5cc2dffdb35590f5feed8e Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Mon, 19 Apr 2021 16:30:32 -0700 Subject: [PATCH 556/707] Delete line that breaks gh ci (#1814) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1814 Reviewed By: myleott Differential Revision: D27867552 Pulled By: sshleifer fbshipit-source-id: ed30e02c962b31797e003cb810c085934a53202c --- tests/test_file_io.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_file_io.py b/tests/test_file_io.py index f39e2e1d58..425812bf16 100644 --- a/tests/test_file_io.py +++ b/tests/test_file_io.py @@ -50,8 +50,6 @@ def test_file_io_async(self): # ioPath `PathManager` is initialized after the first `opena` call. 
try: from fairseq.file_io import IOPathManager, PathManager - - self.assertIsNone(IOPathManager) _asyncfile = os.path.join(self._tmpdir, "async.txt") f = PathManager.opena(_asyncfile, "wb") f.close() From 801a64683164680562c77b688d9ca77fc3e0cea7 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Tue, 20 Apr 2021 12:17:11 -0700 Subject: [PATCH 557/707] allow subclasses of multi corpus dataset to not specify full_id Summary: This diff allows subclasses of multi corpus dataset to not specify full_id, which is only useful for sampling batches between datasets. Specifying it is not possible in certain cases, such as minibatch CE training. Reviewed By: zdavid1995 Differential Revision: D27833104 fbshipit-source-id: e60f6ad200c0ca69915b9588405320ba2ecfbd0a --- fairseq/data/multi_corpus_dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 00e464ed31..acb91f3df6 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -165,9 +165,12 @@ def collater(self, samples): """ if len(samples) == 0: return None - _, key = self._map_index(samples[0]["full_id"]) - - return self.datasets[key].collater(samples) + if "full_id" in samples[0]: + _, key = self._map_index(samples[0]["full_id"]) + return self.datasets[key].collater(samples) + else: + # Subclasses may override __getitem__ to not specify full_id + return list(self.datasets.values())[0].collater(samples) def num_tokens(self, index: int): index, key = self._map_index(index) From da0432a3cd1ddf8f9797af83880570822265620b Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Wed, 21 Apr 2021 06:38:03 -0700 Subject: [PATCH 558/707] MultiGPU test and --log-file workaround (#1793) Summary: The initial problem I set out to solve was that it's not easy to add a multigpu test.
I solved that problem but it ruined log capturing, both with `self.assertLogs` and `with contextlib.redirect_stdout(StringIO())`. After some brief digging, I gave up on trying to get those to work, and added support for `--log-file AGI_v0.log` which will write the `progress_bar.log()` statements to `log-file` as well as `stdout`. This functionality is used by the resumption test. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1793 Reviewed By: myleott Differential Revision: D27671192 Pulled By: sshleifer fbshipit-source-id: bcba5f9df7a965889a4cd6993f7eeb0f14b770c6 --- fairseq/dataclass/configs.py | 3 ++ fairseq/logging/progress_bar.py | 15 ++++-- fairseq_cli/train.py | 5 ++ tests/gpu/test_binaries_gpu.py | 90 ++++++++++++++++++++++++++++----- tests/test_binaries.py | 68 +------------------------ tests/test_plasma_utils.py | 5 +- tests/utils.py | 81 ++++++++++++++++++++++++++--- 7 files changed, 175 insertions(+), 92 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index f1ca26514e..89d83b5b6b 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -95,6 +95,9 @@ class CommonConfig(FairseqDataclass): log_format: Optional[LOG_FORMAT_CHOICES] = field( default=None, metadata={"help": "log format to use"} ) + log_file: Optional[str] = field( + default=None, metadata={"help": "log file to copy metrics to."} + ) tensorboard_logdir: Optional[str] = field( default=None, metadata={ diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 0ae2bc006d..061082caef 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -29,6 +29,7 @@ def progress_bar( iterator, log_format: Optional[str] = None, log_interval: int = 100, + log_file: Optional[str] = None, epoch: Optional[int] = None, prefix: Optional[str] = None, tensorboard_logdir: Optional[str] = None, @@ -39,6 +40,10 @@ def progress_bar( ): if log_format is None: log_format = 
default_log_format + if log_file is not None: + handler = logging.FileHandler(filename=log_file) + logger.addHandler(handler) + if log_format == "tqdm" and not sys.stderr.isatty(): log_format = "simple" @@ -473,13 +478,13 @@ def _log_to_azureml(self, stats, tag=None, step=None): if Run is None: return if step is None: - step = stats['num_updates'] + step = stats["num_updates"] - prefix = '' if tag is None else tag + '/' + prefix = "" if tag is None else tag + "/" - for key in stats.keys() - {'num_updates'}: + for key in stats.keys() - {"num_updates"}: name = prefix + key if isinstance(stats[key], AverageMeter): - self.run.log_row(name=name, **{'step': step, key: stats[key].val}) + self.run.log_row(name=name, **{"step": step, key: stats[key].val}) elif isinstance(stats[key], Number): - self.run.log_row(name=name, **{'step': step, key: stats[key]}) + self.run.log_row(name=name, **{"step": step, key: stats[key]}) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index f736e67d0d..c1f2fbb4c7 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -59,6 +59,10 @@ def main(cfg: FairseqConfig) -> None: ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() + if cfg.common.log_file is not None: + handler = logging.FileHandler(filename=cfg.common.log_file) + logger.addHandler(handler) + np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) @@ -242,6 +246,7 @@ def train( progress = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, + log_file=cfg.common.log_file, log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 5690e73752..5f879a7a27 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -5,6 +5,7 @@ import contextlib import logging +import json import os import tempfile import unittest @@ -22,6 +23,7 @@ ) +@unittest.skipIf(not torch.cuda.is_available(), 
"test requires a GPU") class TestTranslationGPU(unittest.TestCase): def setUp(self): logging.disable(logging.CRITICAL) @@ -29,16 +31,80 @@ def setUp(self): def tearDown(self): logging.disable(logging.NOTSET) - @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") - def test_fp16(self): + def test_fp16_multigpu(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_fp16") as data_dir: + log = os.path.join(data_dir, "train.log") create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, "fconv_iwslt_de_en", ["--fp16"]) + train_translation_model( + data_dir, + "fconv_iwslt_de_en", + ["--fp16", "--log-file", log], + world_size=min(torch.cuda.device_count(), 2), + ) generate_main(data_dir) + assert os.path.exists(log) + + @staticmethod + def parse_logs(logfile): + logs = [] + for ln in open(logfile, "r").readlines(): + try: + logs.append(json.loads(ln)) + except json.JSONDecodeError: + continue + return logs + + def test_resume_training(self): + flags = [ + "--fp16", + "--log-format", + "json", + "--max-update", + "10", + "--save-interval-updates", + "2", + "--log-interval", + "1", + "--log-file", + ] + world_size = min(torch.cuda.device_count(), 2) + arch = "fconv_iwslt_de_en" + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fp16") as data_dir: + log = os.path.join(data_dir, "train.log") + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, arch, flags + [log], world_size=world_size, + ) + log2 = os.path.join(data_dir, "resume.log") + restore_file = os.path.join(data_dir, "checkpoint_1_2.pt") + assert os.path.exists( + restore_file + ), f"{restore_file} not written. 
Choices: {os.listdir(data_dir)}" + train_translation_model( + data_dir, + arch, + flags + [log2, "--restore-file", restore_file], + world_size=world_size, + ) + + l1 = self.parse_logs(log) + l2 = self.parse_logs(log2) + assert int(l2[0]["num_updates"]) == 3, f"{l1}\n\n {l2}" + for k in [ + "train_loss", + "train_num_updates", + "train_ppl", + "train_gnorm", + ]: + from_scratch, resumed = l1[-1][k], l2[-1][k] + assert ( + from_scratch == resumed + ), f"difference at {k} {from_scratch} != {resumed}" - @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_memory_efficient_fp16(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_memory_efficient_fp16") as data_dir: @@ -49,7 +115,6 @@ def test_memory_efficient_fp16(self): ) generate_main(data_dir) - @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_transformer_fp16(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_transformer") as data_dir: @@ -73,7 +138,6 @@ def test_transformer_fp16(self): ) generate_main(data_dir) - @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_levenshtein_transformer(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory( @@ -108,10 +172,12 @@ def test_levenshtein_transformer(self): generate_main( data_dir, gen_config, - path=os.pathsep.join([ - os.path.join(data_dir, "checkpoint_last.pt"), - os.path.join(data_dir, "checkpoint_last.pt"), - ]), + path=os.pathsep.join( + [ + os.path.join(data_dir, "checkpoint_last.pt"), + os.path.join(data_dir, "checkpoint_last.pt"), + ] + ), ) @@ -237,6 +303,7 @@ def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=Fa train.main(quantize_args) +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") class TestQuantization(unittest.TestCase): def setUp(self): logging.disable(logging.CRITICAL) @@ -244,7 +311,6 @@ def setUp(self): def 
tearDown(self): logging.disable(logging.NOTSET) - @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_quantization(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_quantization") as data_dir: @@ -254,6 +320,7 @@ def test_quantization(self): _quantize_language_model(data_dir, "transformer_lm") +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") class TestOptimizersGPU(unittest.TestCase): def setUp(self): logging.disable(logging.CRITICAL) @@ -261,7 +328,6 @@ def setUp(self): def tearDown(self): logging.disable(logging.NOTSET) - @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_flat_grads(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_flat_grads") as data_dir: diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 49e6dcd9f8..4e20774262 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -15,7 +15,7 @@ from typing import List, Dict import torch from fairseq import options -from fairseq_cli import eval_lm, train, validate +from fairseq_cli import eval_lm, train from tests.utils import ( create_dummy_data, generate_main, @@ -24,6 +24,7 @@ preprocess_translation_data, create_laser_data_and_config_json, train_translation_model, + train_language_model, ) @@ -1852,71 +1853,6 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): train.main(train_args) -def train_language_model( - data_dir, - arch, - extra_flags=None, - run_validation=False, - extra_valid_flags=None, - task="language_modeling", -): - train_parser = options.get_training_parser() - train_args = options.parse_args_and_arch( - train_parser, - [ - "--task", - task, - data_dir, - "--arch", - arch, - "--optimizer", - "adam", - "--lr", - "0.0001", - "--max-tokens", - "500", - "--tokens-per-sample", - "500", - "--save-dir", - data_dir, - "--max-epoch", - "1", - "--no-progress-bar", - 
"--distributed-world-size", - "1", - "--ddp-backend", - "no_c10d", - "--num-workers", - "0", - ] - + (extra_flags or []), - ) - train.main(train_args) - - if run_validation: - # test validation - validate_parser = options.get_validation_parser() - validate_args = options.parse_args_and_arch( - validate_parser, - [ - "--task", - task, - data_dir, - "--path", - os.path.join(data_dir, "checkpoint_last.pt"), - "--valid-subset", - "valid", - "--max-tokens", - "500", - "--no-progress-bar", - "--num-workers", - "0", - ] - + (extra_valid_flags or []), - ) - validate.main(validate_args) - - def eval_lm_main(data_dir, extra_flags=None): eval_lm_parser = options.get_eval_lm_parser() eval_lm_args = options.parse_args_and_arch( diff --git a/tests/test_plasma_utils.py b/tests/test_plasma_utils.py index 5737530e3d..a5cf386b86 100644 --- a/tests/test_plasma_utils.py +++ b/tests/test_plasma_utils.py @@ -5,8 +5,7 @@ import numpy as np -from tests.test_binaries import train_language_model -from tests.utils import create_dummy_data, preprocess_lm_data +from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model try: from pyarrow import plasma @@ -16,7 +15,7 @@ except ImportError: PYARROW_AVAILABLE = False -dummy_path = 'dummy' +dummy_path = "dummy" @unittest.skipUnless(PYARROW_AVAILABLE, "") diff --git a/tests/utils.py b/tests/utils.py index 1bf6f8d7f3..6e0c709517 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,6 +23,8 @@ from fairseq.models.fairseq_encoder import EncoderOut from fairseq.tasks import LegacyFairseqTask from fairseq_cli import generate, interactive, preprocess, train, validate +import fairseq.distributed.utils as distributed_utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf def dummy_dictionary(vocab_size, prefix="token_"): @@ -35,10 +37,7 @@ def dummy_dictionary(vocab_size, prefix="token_"): def dummy_dataloader( - samples, - padding_idx=1, - eos_idx=2, - batch_size=None, + samples, padding_idx=1, eos_idx=2, 
batch_size=None, ): if batch_size is None: batch_size = len(samples) @@ -320,6 +319,7 @@ def train_translation_model( run_validation=False, lang_flags=None, extra_valid_flags=None, + world_size=1, ): if lang_flags is None: lang_flags = [ @@ -349,14 +349,16 @@ def train_translation_model( "1", "--no-progress-bar", "--distributed-world-size", - "1", + str(world_size), "--num-workers", "0", ] + lang_flags + (extra_flags or []), ) - train.main(train_args) + + cfg = convert_namespace_to_omegaconf(train_args) + distributed_utils.call_main(cfg, train.main) if run_validation: # test validation @@ -646,3 +648,70 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): prev_output_tokens, encoder_out=encoder_out, **kwargs ) return decoder_out + + +def train_language_model( + data_dir, + arch, + extra_flags=None, + run_validation=False, + extra_valid_flags=None, + task="language_modeling", + world_size=1, +): + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + "--task", + task, + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", + data_dir, + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + str(world_size), + "--ddp-backend", + "no_c10d", + "--num-workers", + "0", + ] + + (extra_flags or []), + ) + cfg = convert_namespace_to_omegaconf(train_args) + distributed_utils.call_main(cfg, train.main) + + if run_validation: + # test validation + validate_parser = options.get_validation_parser() + validate_args = options.parse_args_and_arch( + validate_parser, + [ + "--task", + task, + data_dir, + "--path", + os.path.join(data_dir, "checkpoint_last.pt"), + "--valid-subset", + "valid", + "--max-tokens", + "500", + "--no-progress-bar", + "--num-workers", + "0", + ] + + (extra_valid_flags or []), + ) + validate.main(validate_args) From 2af4ffe77b230a0a228af9be09e1c0e9d4731906 
Mon Sep 17 00:00:00 2001 From: Sujit Verma <sujitv@fb.com> Date: Wed, 21 Apr 2021 09:04:56 -0700 Subject: [PATCH 559/707] Supporting PathManager in cached_path util. Summary: Supporting PathManager in cached_path util. Reviewed By: myleott Differential Revision: D27895813 fbshipit-source-id: 68d7345ebb50737c53a72dcd467536f86b61e1dd --- fairseq/file_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/fairseq/file_utils.py b/fairseq/file_utils.py index ec6de37f77..d1d5ea6574 100644 --- a/fairseq/file_utils.py +++ b/fairseq/file_utils.py @@ -139,6 +139,19 @@ def filename_to_url(filename, cache_dir=None): return url, etag +def cached_path_from_pm(url_or_filename): + """ + Tries to cache the specified URL using PathManager class. + Returns the cached path if success otherwise failure. + """ + try: + from fairseq.file_io import PathManager + local_path = PathManager.get_local_path(url_or_filename) + return local_path + except Exception: + return None + + def cached_path(url_or_filename, cache_dir=None): """ Given something that might be a URL (or might be a local path), @@ -165,6 +178,9 @@ def cached_path(url_or_filename, cache_dir=None): # File, but it doesn't exist. 
raise EnvironmentError("file {} not found".format(url_or_filename)) else: + cached_path = cached_path_from_pm(url_or_filename) + if cached_path: + return cached_path # Something unknown raise ValueError( "unable to parse {} as a URL or as a local path".format(url_or_filename) From 207254bf56374831d08c20064ca7a2740a871ce5 Mon Sep 17 00:00:00 2001 From: Michael Anderson <anderso2@fb.com> Date: Wed, 21 Apr 2021 09:08:06 -0700 Subject: [PATCH 560/707] Adding check for filler size (#3495) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3495 Avoid creating size-0 tensor "filler" in case src_len is the same as key_padding_mask_size or prev_key_padding_mask_size Reviewed By: jackm321 Differential Revision: D27897778 fbshipit-source-id: 26fd95852da2cd932717c7abcac3e1fb43deaf77 --- fairseq/modules/multihead_attention.py | 34 +++++++++++++++----------- tests/test_multihead_attention.py | 12 +++++++++ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index b168a890ae..9bdca0f6af 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -403,21 +403,27 @@ def _append_prev_key_padding_mask( # leaves the frame, there will be a time when prev or current # is None elif prev_key_padding_mask is not None: - filler = torch.zeros( - (batch_size, src_len - prev_key_padding_mask.size(1)), - device=prev_key_padding_mask.device, - ) - new_key_padding_mask = torch.cat( - [prev_key_padding_mask.float(), filler.float()], dim=1 - ) + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() elif key_padding_mask is not None: - filler = torch.zeros( - (batch_size, src_len 
- key_padding_mask.size(1)), - device=key_padding_mask.device, - ) - new_key_padding_mask = torch.cat( - [filler.float(), key_padding_mask.float()], dim=1 - ) + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() else: new_key_padding_mask = prev_key_padding_mask return new_key_padding_mask diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py index 9aa9cb2f87..620a2d6791 100644 --- a/tests/test_multihead_attention.py +++ b/tests/test_multihead_attention.py @@ -35,6 +35,18 @@ def test_append_prev_key_padding_mask(self): torch.tensor([[0, 1, 0]]).bool(), torch.tensor([[0, 1, 0, 1]]).bool(), ), + # prev_key_padding_mask already full + ( + torch.tensor([[0, 1, 0, 1]]).bool(), + None, + torch.tensor([[0, 1, 0, 1]]).bool(), + ), + # key_padding_mask already full + ( + None, + torch.tensor([[0, 1, 0, 1]]).bool(), + torch.tensor([[0, 1, 0, 1]]).bool(), + ), ] for c in cases: key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( From 05b86005bcca0155319fa9b81abfd69f63c06906 Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Wed, 21 Apr 2021 15:49:03 -0700 Subject: [PATCH 561/707] Fix FSDP optim state loading (#1819) Summary: ### Problem: - if we consolidate optim state dict on rank 0, rank 1+ save `optimizer.state_dict()`. When they try to load, they call get_shard(last_optim_state), which is wrong since the optim state is already shared. They should find the global consolidated optimizer state dict and load that. ### Possible Solutions: - if world size is the same, you could just reuse the local OSD. - [this PR] rank 1+ load optim state from the rank0 file and call get_shard - separate file for optim_state that every rank loads. 
(like 'shared.pt' on `gshard-azure`). This will save some CPU RAM. ### Note: - I don't think it's possible to pass `--use-sharded-state` from the command line. I think it should be. ### Implementation here + if FSDP saves -1 as state['last_optimizer_state'], it means that, on load, rank 0's optim state must be loaded. + regression test Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1819 Reviewed By: zhengwy888 Differential Revision: D27910281 Pulled By: sshleifer fbshipit-source-id: d34987008f77ce7e0cb28b7224dd2aabed38a70c --- fairseq/trainer.py | 16 +++++++++++----- tests/gpu/test_binaries_gpu.py | 19 ++++++++++--------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 6195afb4a6..5e87b573f1 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -26,7 +26,7 @@ from fairseq.optim import lr_scheduler from omegaconf import OmegaConf - +import re logger = logging.getLogger(__name__) @@ -331,14 +331,17 @@ def _build_optimizer(self): def consolidate_optimizer(self): """For OSS, we need to consolidate the state dict.""" + if self.cfg.checkpoint.no_save_optimizer_state: + return self._gathered_optim_state = None if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() - elif self.cfg.distributed_training.ddp_backend == 'fully_sharded': - self._gathered_optim_state = self.model.gather_full_optim_state_dict(self.optimizer, - recipient_rank=0) - + elif self.cfg.distributed_training.ddp_backend == 'fully_sharded' and not self.model.use_sharded_state: + st = self.model.gather_full_optim_state_dict(self.optimizer) # only returns on rank 0 + if st is None: + st = -1 # sentinel so that workers do not save optimizer.state_dict() + self._gathered_optim_state = st def state_dict(self): state_dict = { @@ -423,6 +426,9 @@ def load_checkpoint( filename, load_on_all_ranks=load_on_all_ranks ) last_optim_state = state.get("last_optimizer_state", 
None) + if last_optim_state == -1: + master_path = re.sub("shard[0-9]+", "shard0", filename) + last_optim_state = torch.load(master_path, map_location='cpu')['last_optimizer_state'] # If doing zero_sharding, do not broadcast global optimizer # state. Later we will broadcast sharded states to each rank diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 5f879a7a27..45417c7eb7 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -56,7 +56,13 @@ def parse_logs(logfile): continue return logs - def test_resume_training(self): + def test_resume_training_fsdp(self): + self._test_resume_training(["--ddp-backend", "fully_sharded"]) + + def test_resume_training_noc10d(self): + self._test_resume_training([]) + + def _test_resume_training(self, extra_clargs, arch="fconv_iwslt_de_en"): flags = [ "--fp16", "--log-format", @@ -67,27 +73,22 @@ def test_resume_training(self): "2", "--log-interval", "1", - "--log-file", - ] + ] + extra_clargs world_size = min(torch.cuda.device_count(), 2) - arch = "fconv_iwslt_de_en" with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_fp16") as data_dir: log = os.path.join(data_dir, "train.log") create_dummy_data(data_dir) preprocess_translation_data(data_dir) train_translation_model( - data_dir, arch, flags + [log], world_size=world_size, + data_dir, arch, flags + ["--log-file", log], world_size=world_size, ) log2 = os.path.join(data_dir, "resume.log") restore_file = os.path.join(data_dir, "checkpoint_1_2.pt") - assert os.path.exists( - restore_file - ), f"{restore_file} not written. 
Choices: {os.listdir(data_dir)}" train_translation_model( data_dir, arch, - flags + [log2, "--restore-file", restore_file], + flags + ["--log-file", log2, "--restore-file", restore_file], world_size=world_size, ) From 40f6c758b361adc7b87fa553f6f439daa4ee9501 Mon Sep 17 00:00:00 2001 From: Jeffrey Karres <jkarres@fb.com> Date: Thu, 22 Apr 2021 21:32:09 -0700 Subject: [PATCH 562/707] sorting os.listdir outputs to ensure consistent import ordering Summary: There's a common idiom of doing imports based on the output of `os.listdir(os.path.dirname(__file__))`. In this idiom, imports are performed based on the order of directories in that output. This is dangerous, because the behavior of your program depends on the order of those files, but the order of those files is not guaranteed. Reviewed By: zhengwy888 Differential Revision: D27951383 fbshipit-source-id: 97dc9a0b7d853886e19a9643c33508f146c71617 --- examples/simultaneous_translation/models/__init__.py | 2 +- examples/simultaneous_translation/modules/__init__.py | 2 +- examples/simultaneous_translation/utils/__init__.py | 2 +- examples/speech_recognition/criterions/__init__.py | 2 +- examples/speech_recognition/models/__init__.py | 2 +- examples/speech_recognition/tasks/__init__.py | 2 +- fairseq/criterions/__init__.py | 2 +- fairseq/data/encoders/__init__.py | 2 +- fairseq/model_parallel/criterions/__init__.py | 2 +- fairseq/optim/__init__.py | 2 +- fairseq/optim/lr_scheduler/__init__.py | 2 +- fairseq/scoring/__init__.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/simultaneous_translation/models/__init__.py b/examples/simultaneous_translation/models/__init__.py index 083da43732..257a96593f 100644 --- a/examples/simultaneous_translation/models/__init__.py +++ b/examples/simultaneous_translation/models/__init__.py @@ -7,7 +7,7 @@ import os -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not 
file.startswith("_"): model_name = file[: file.find(".py")] importlib.import_module( diff --git a/examples/simultaneous_translation/modules/__init__.py b/examples/simultaneous_translation/modules/__init__.py index ad64774de4..c695850c04 100644 --- a/examples/simultaneous_translation/modules/__init__.py +++ b/examples/simultaneous_translation/modules/__init__.py @@ -16,7 +16,7 @@ _, ) = registry.setup_registry("--simul-type") -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): model_name = file[: file.find(".py")] importlib.import_module( diff --git a/examples/simultaneous_translation/utils/__init__.py b/examples/simultaneous_translation/utils/__init__.py index be0ba4d99a..1e9ce844f5 100644 --- a/examples/simultaneous_translation/utils/__init__.py +++ b/examples/simultaneous_translation/utils/__init__.py @@ -8,7 +8,7 @@ # automatically import any Python files in the criterions/ directory -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): module = file[: file.find(".py")] importlib.import_module("examples.simultaneous_translation.utils." 
+ module) diff --git a/examples/speech_recognition/criterions/__init__.py b/examples/speech_recognition/criterions/__init__.py index a667b1c918..579abd2ace 100644 --- a/examples/speech_recognition/criterions/__init__.py +++ b/examples/speech_recognition/criterions/__init__.py @@ -9,7 +9,7 @@ except ImportError: files_to_skip.add("ASG_loss.py") -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip: criterion_name = file[: file.find(".py")] importlib.import_module( diff --git a/examples/speech_recognition/models/__init__.py b/examples/speech_recognition/models/__init__.py index 0ad9663f11..54b5a1c312 100644 --- a/examples/speech_recognition/models/__init__.py +++ b/examples/speech_recognition/models/__init__.py @@ -2,7 +2,7 @@ import os -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): model_name = file[: file.find(".py")] importlib.import_module("examples.speech_recognition.models." + model_name) diff --git a/examples/speech_recognition/tasks/__init__.py b/examples/speech_recognition/tasks/__init__.py index ffa5f3bd8c..7ac3b8dc69 100644 --- a/examples/speech_recognition/tasks/__init__.py +++ b/examples/speech_recognition/tasks/__init__.py @@ -2,7 +2,7 @@ import os -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): task_name = file[: file.find(".py")] importlib.import_module("examples.speech_recognition.tasks." 
+ task_name) diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index 8cc6c0f043..4dbf46a1cb 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -30,7 +30,7 @@ def build_criterion(cfg: DictConfig, task): # automatically import any Python files in the criterions/ directory -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): file_name = file[: file.find(".py")] importlib.import_module("fairseq.criterions." + file_name) diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py index 2e807d8ae7..7cbe00a105 100644 --- a/fairseq/data/encoders/__init__.py +++ b/fairseq/data/encoders/__init__.py @@ -23,7 +23,7 @@ # automatically import any Python files in the encoders/ directory -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): module = file[: file.find(".py")] importlib.import_module("fairseq.data.encoders." + module) diff --git a/fairseq/model_parallel/criterions/__init__.py b/fairseq/model_parallel/criterions/__init__.py index 6239b50362..5fae7bd4c2 100644 --- a/fairseq/model_parallel/criterions/__init__.py +++ b/fairseq/model_parallel/criterions/__init__.py @@ -8,7 +8,7 @@ # automatically import any Python files in the criterions/ directory -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): module = file[: file.find(".py")] importlib.import_module("fairseq.model_parallel.criterions." 
+ module) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 112c8ad10f..01c08c98d2 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -40,7 +40,7 @@ def build_optimizer(cfg: DictConfig, params, *extra_args, **extra_kwargs): # automatically import any Python files in the optim/ directory -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): file_name = file[: file.find(".py")] importlib.import_module("fairseq.optim." + file_name) diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index f07d43c7c3..5b3dbc023a 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -30,7 +30,7 @@ def build_lr_scheduler(cfg: DictConfig, optimizer): # automatically import any Python files in the optim/lr_scheduler/ directory -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): file_name = file[: file.find(".py")] importlib.import_module("fairseq.optim.lr_scheduler." + file_name) diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 2372727883..58f2f563e4 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -49,7 +49,7 @@ def build_scorer(choice, tgt_dict): # automatically import any Python files in the current directory -for file in os.listdir(os.path.dirname(__file__)): +for file in sorted(os.listdir(os.path.dirname(__file__))): if file.endswith(".py") and not file.startswith("_"): module = file[: file.find(".py")] importlib.import_module("fairseq.scoring." 
+ module) From b0ae834d528a4a466202107a22356aed71bb6161 Mon Sep 17 00:00:00 2001 From: ngoyal2707 <ngoyal2707@users.noreply.github.com> Date: Mon, 26 Apr 2021 13:14:41 -0700 Subject: [PATCH 563/707] Flores pretrained model release (#1825) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1825 Reviewed By: huihuifan Differential Revision: D28002947 fbshipit-source-id: 2ba69a72431db5aa7aef890309eefb0fb31c7ad6 --- examples/flores101/README.md | 223 +++++++++++++++++++++++++++++ examples/flores101/flores_logo.png | Bin 0 -> 33184 bytes fairseq/models/transformer.py | 2 + 3 files changed, 225 insertions(+) create mode 100644 examples/flores101/README.md create mode 100644 examples/flores101/flores_logo.png diff --git a/examples/flores101/README.md b/examples/flores101/README.md new file mode 100644 index 0000000000..58d9c05aff --- /dev/null +++ b/examples/flores101/README.md @@ -0,0 +1,223 @@ +<p align="center"> +<img src="flores_logo.png" width="500"> +</p> + +# Flores101: Large-Scale Multilingual Machine Translation + +## Introduction + +Baseline pretrained models for small and large tracks of WMT 21 Large-Scale Multilingual Machine Translation competition. 
+ +Flores Task at WMT 21: http://www.statmt.org/wmt21/large-scale-multilingual-translation-task.html + +Flores announcement blog post: https://ai.facebook.com/blog/flores-researchers-kick-off-multilingual-translation-challenge-at-wmt-and-call-for-compute-grants/ + + + +## Pretrained models + +Model | Num layers | Embed dimension | FFN dimension| Vocab Size | #params | Download +---|---|---|---|---|---|--- +`flores101_mm100_615M` | 12 | 1024 | 4096 | 256,000 | 615M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz +`flores101_mm100_175M` | 6 | 512 | 2048 | 256,000 | 175M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.tgz + + +These models are trained similarly to [M2M-100](https://arxiv.org/abs/2010.11125) with additional support for the languages that are part of the WMT Large-Scale Multilingual Machine Translation track. The full list of languages can be found at the bottom. + + +## Example Generation code + +### Download model, sentencepiece vocab + +```bash +fairseq=/path/to/fairseq +cd $fairseq + +# Download 615M param model. 
+wget https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz + +# Extract +tar -xvzf flores101_mm100_615M.tar.gz +``` + +### Encode using our SentencePiece Model +Note: Install SentencePiece from [here](https://github.com/google/sentencepiece) + + +```bash +fairseq=/path/to/fairseq +cd $fairseq + +# Download example dataset from German to French +sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de +sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr + +for lang in de fr ; do + python scripts/spm_encode.py \ + --model flores101_mm100_615M/sentencepiece.bpe.model \ + --output_format=piece \ + --inputs=raw_input.de-fr.${lang} \ + --outputs=spm.de-fr.${lang} +done +``` + +### Binarization + +```bash +fairseq-preprocess \ + --source-lang de --target-lang fr \ + --testpref spm.de-fr \ + --thresholdsrc 0 --thresholdtgt 0 \ + --destdir data_bin \ + --srcdict flores101_mm100_615M/dict.txt --tgtdict flores101_mm100_615M/dict.txt +``` + +### Generation + + +```bash +fairseq-generate \ + data_bin \ + --batch-size 1 \ + --path flores101_mm100_615M/model.pt \ + --fixed-dictionary flores101_mm100_615M/dict.txt \ + -s de -t fr \ + --remove-bpe 'sentencepiece' \ + --beam 5 \ + --task translation_multi_simple_epoch \ + --lang-pairs flores101_mm100_615M/language_pairs.txt \ + --decoder-langtok --encoder-langtok src \ + --gen-subset test \ + --fp16 \ + --dataset-impl mmap \ + --distributed-world-size 1 --distributed-no-spawn +``` + +### Supported Languages and lang code + +Language | lang code +---|--- +Afrikaans | af +Amharic | am +Arabic | ar +Assamese | as +Asturian | ast +Aymara | ay +Azerbaijani | az +Bashkir | ba +Belarusian | be +Bulgarian | bg +Bengali | bn +Breton | br +Bosnian | bs +Catalan | ca +Cebuano | ceb +Chokwe | cjk +Czech | cs +Welsh | cy +Danish | da +German | de +Dyula| dyu +Greek | el +English | en +Spanish | es +Estonian | et +Persian | fa +Fulah | ff +Finnish | fi +French | fr +Western 
Frisian | fy +Irish | ga +Scottish Gaelic | gd +Galician | gl +Gujarati | gu +Hausa | ha +Hebrew | he +Hindi | hi +Croatian | hr +Haitian Creole | ht +Hungarian | hu +Armenian | hy +Indonesian | id +Igbo | ig +Iloko | ilo +Icelandic | is +Italian | it +Japanese | ja +Javanese | jv +Georgian | ka +Kachin | kac +Kamba | kam +Kabuverdianu | kea +Kongo | kg +Kazakh | kk +Central Khmer | km +Kimbundu | kmb +Northern Kurdish | kmr +Kannada | kn +Korean | ko +Kurdish | ku +Kyrgyz | ky +Luxembourgish | lb +Ganda | lg +Lingala | ln +Lao | lo +Lithuanian | lt +Luo | luo +Latvian | lv +Malagasy | mg +Maori | mi +Macedonian | mk +Malayalam | ml +Mongolian | mn +Marathi | mr +Malay | ms +Maltese | mt +Burmese | my +Nepali | ne +Dutch | nl +Norwegian | no +Northern Sotho | ns +Nyanja | ny +Occitan | oc +Oromo | om +Oriya | or +Punjabi | pa +Polish | pl +Pashto | ps +Portuguese | pt +Quechua | qu +Romanian | ro +Russian | ru +Sindhi | sd +Shan | shn +Sinhala | si +Slovak | sk +Slovenian | sl +Shona | sn +Somali | so +Albanian | sq +Serbian | sr +Swati | ss +Sundanese | su +Swedish | sv +Swahili | sw +Tamil | ta +Telugu | te +Tajik | tg +Thai | th +Tigrinya | ti +Tagalog | tl +Tswana | tn +Turkish | tr +Ukrainian | uk +Umbundu | umb +Urdu | ur +Uzbek | uz +Vietnamese | vi +Wolof | wo +Xhosa | xh +Yiddish | yi +Yoruba | yo +Chinese| zh +Zulu | zu \ No newline at end of file diff --git a/examples/flores101/flores_logo.png b/examples/flores101/flores_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..d4d1455c6eab608ff5317ce885183cd213564273 GIT binary patch literal 33184 zcmeGE_dnME`#+A~jY1hwlD$f#DA`Ucq?D0{c{(*Dd!F_<%Q{O&Q7Y?1QkmIKGYS#T zh%=|1ah}$B(rM4n?d0)#{_y<=zQ4TB%caZZ?7ZJ^>o~5*K5m|x80oVe;yVOE5Zl!& zx~34sgo7Y@MdpLx7iPhu+7KiHUDegT<!3WDbitM1p-*$hOY?jZ<<P|Lq4$sF4kSw0 z`p`Mu5m6LDRlewzE($o%7tVO7sOHJ*`v=ZlT$9V3<Swmw^2LjO^`L6X!)lkLlhL=` zxYL==2Ry7BcY_`Oai))=PTWH5cI|F@fi-ELh7X(RSo;0%&>u#6@Bh9Y!7%K9YY?PU 
z{qHws9Nqr6(O5|Ezuy>1|9x@SN3s9+^?#F8ifb|b_uKy_8RY!GnNk0ng#5pmwf;AW z@Bd~N`M*g}YX6)0|84L8Q4jhbLjE82{{Jr;B4P)T2Y|2_mA)}-=DbjlYgN+o)ji;* z86)^{!GtEcD2n+{i74q^A0@H0{H;y<_4MHBhv!q$LMAmSVMI<36SLBXxtcV+gF3yZ zITY_ij^~Q5$(<dELFZGp?$C5ez!jQ`SR&k0^iPBLM6Gub={;$`*InNG{c_}F_4=%d z2~8I@=+N{Es<W+O%Y72>Gen`OFO(ZlKuabH!s^lJ=!DwbEUYC&dimY?-`TK#;*V8m zQyCkg*m1@=oQs+7;@FK>3JPYQ(P19&z7w7-<Uuae-dHfF_eQ0Cn#v+_E}q$)*VBa+ zuAhTU-hX$NJN}M#0NhDBRju8W{H@t}M-9I8IXFBhM+kcMLan1qPI=^WE=?_>a+Wj! z$-MOygat2uTjbeY^e5{wq3)e~n6P+Hu__j=`Oi`UMtc+Ep^Bs1uq)dPa*+1T44vOS z2<5k3n78Nr0yjGho<p-A7z8n=Kd(O$;A9$E)2weBB0(nMbISEGXr;zIRd0|1-eB_7 z-Sju?*V7f~`wcZ%PAl(C{(v_5*p^WBJ&xjQf)HQD8$))Uk{>a#`vWDkm`Ih58miCH z<5;YB+RzMyLj!U-Ur4sT{LPK?%S|&5+7?{6Fbv-2kz+haq;$EEIR_M1{>IQq{I8;z z&F@~Cfze_C@;bEgMzC1+^E&Bc0W%~dUENtI{&!jZ&U4xfS{J|zg0Hd$5PiMDO%XT9 zP+rqLMU}P!Y8c=V_(;Y;f|xL&3)t;zm#Y{SCXSc<@Ymh?JSza6p$cg0_pFOwt_Yph zQxJ}seFkAV+gevzj4to7T?}ABRbn+s7~eb8;}6(hC{2%Yv-l}d&FOLfff{Z8U|?xT z@vPP9o=Y2dbglci=^F44{6e-FZ{DCi#xMh>yH;5@+4QFZdeFw-j_!XTz*%27hNYn# z$_LQ$){PugLn7mrx;4BZvUC$w35ldV!zqB#?5*ybcJLt%)xxQ)A*>jrn7pvJd5xnO zhGQ~1^AeZy-1NSI?W^ZXGumEc>7s3&qz16`Rqm%v%P`fg|0z0}Hn8iSVn27<9@Kon z#Kmy^Z$9u}M+0AM_R)#_?ONEv7VMd0tPzUc<fEeowl#J^vF2y(EaY7I-MM>7{upgj zk7yLt>J9Qf&e?)QI<0&FfCq#l)hd3-pgq@;rU|vYv|2r)w^w1JcwZh|XAaSk(!V=< z4*!);G_C0h9oGGpvxUzcpVMTEyotXG2BOk{g+6vtt#0IZufii~v4Y%n()y3#-W<sG z{@I5KS!h2CGyG9uvA>Es|B_k1&j-<v_(uzz<kBtbFIV!La~*QJn#VF3oFf6&U}7>e zGpaLQKJg>E-^wt1B)bjXr3c(ey^DP66HbEoCo@N#y6@L?={Vw>^J#=xPzSHPMvA}? 
zTwop)*e@R^6A~%7AfrQV|LQ)2rKziiT7V17-x&)*SJCXa`zqM&yn+H-Ve!9U&G0Gk z0tAbM1Qo9<2&d|Ot6C@t?hRVX%M$nWl}kB+=PnNMXTHDsfbzMp`q|@$slE%?V=9^3 zLFZF!`nYMBZ&=+jhSTBHqBuRQI`pHE5qZXuX%)PeeSdi<2*Zh@kfmvd{y2I~T`UdX zOaB|7F$!p%f`aoD@xSR|V@oD}dwT5V5lquV!ZY*;5S-G2_#FJ#FY_aw@#BLLB(~hC zEPKE}PENNOM>-+QwS==8m%r7^Dc2WB)BH-*CKnhR2=CD6p~p=vNx~2W&P=7;cX4S% zCsJj>@zFXUMap93vonVluj$Us?5PA>K}J%8$3e_xtwIZqCh95C03FZeM#CumU4}{3 zTxrRMDD+y3t+F~lYfk#iGa5!mf$VBMIN^J<Lth^m0W;beTDbs2SVoenY5yy*UdwMO z6vdnpoYn~C=Zpd;`%rKM41+?#AwhF+GDPnB$HRh;QuoP$WO!mMy1_S5M{P&@0Ts7! zB^$9B%Ey{>A5Nr=lGjGD2Tk7wB;;$^mIp+Ei4FhmU2eS-NgBsqn#q5@9eM+dIuYTU z9@_b0o(CC08?|r?oC;sl`hu+bR9LDyQdaD*9o&V(Mekd(fJ&!07*#<LFwZg7%+=)Y zr_o)D9}K8)o;?|oGB;cw3-li$Eig5T*$<OqUBq&K_wjL{s)sIdOf&DvDNUHx((Fzz z`ol$ii=B4Maax&dEr2)oPc>{$Ogx!oOSOjN3%`-Q#`eriKG0aJi$JCdA}}LK;_ab$ zg)JpDP7kM34XcwQU52^>Kyhw;q?snM*qIr=|7-!0-g~r;Jsg9mD5h*7;qwOv05InD zcCps4+*5OpLhsyfjXo~Cxxa=BP%Io@SheT%jQueUb<|B75|TdfZb@+<YNXD$(r@IK zlBuB_kv%l60zT$SbgBy{+~k%99`zdryNs-d)uPeKK;7j@&y}QDyN@p%z?BM&>;RBz z3m<Iy#Bz!@>KI@MtPfy%2uLp#+gSHwlq#oRKFb%9GLK%%e;3y|b7)iYG|!-F5MAl} z_hf2U#I`?e;2(Ez8vQwE%hO=6pxuInDbXgWwk_-gz)kRrgu@}eYxIw9T*c3q0AIpt z&zFD<ftJ;DlvaBhs2+gKp4~G0N=z7diXPb<j!7%bh)%JFa34%n>qfw#sQZU{^b!*E zJ5vaB5TJHiFSk!j0kwIloL}8c(0ag^z0~DG^4xLnBYAkH9NMNx5`YGYHoD)C2zs2O zMu1I+J|5f2Up9AHlhS_oISl1-e?-LuZQ5K>FUXpMN9xfA;pD-f-#&7>vUae!u*r|} zCFnJAG4`1mi|C|1=^w{7zMPzGrDMgM%w{BE_aGw)HfA1`k6tT_R?ELwRQnWdhAhh) zx}#eIe2ZB}jWeefs*+&g9@9w1V;{r2dMbeAlmp8NK7{ihrJ*&tLb)}VYJc^{EwjN7 zTV83?%Cg<sx?6WT^xee&&H=?l!kscZVSS>XJ@Mg%WSHy&vCBmEKQxIgQt`$x%=#SR zR1*9Xy<_cXQ<H_343G^lI$wD_jUnPp#^zQn`F|e3yQ0*<U(rMM!NtBvPR?C4x(iuW z`d3F6rrBSO{Uq&t4d_SxZ^@UiYIJoU!ke6ld%hrq1C+!iPHN2F6Hc*L4t#U>+@v25 zXH<3pU28BA&$(D!a2#Xj!bD1EYMQo#bJEULdKG0PJ7+dNwC!lU^vHf^)@~0~H(qEK zzB;lU3_lm+7YH_S%iXVzpHr!M*S9G|A*2owlaBAzjw=(yAbxYxrqLQ;BrScvp*u)6 z8^5R>aC%Pno6D!bU%;EwNS%WNe`^41**B8|9Kd@mJ4AygBVKxR2iixs#Kd;c%yg>! 
z4m4`neh=R#!L!p^ImYE;e{Ub0{vPSNg}<Y_1wfI5jBxp}utLPDrVSF&Yxd-9SyP~s zChEKnm4gk<n~j&GaJgHXaj4^0#$I#*-eLdaorG3S9GI>nM4@h|uvs7Zy_AduBagwl zw_#!-Z6-;}n{PHdqggG0ZUepw%Y8MGFzNuzyw?YRE^)M|z_#PrmNm^}qGVZc(<MO0 zko9~BirIL%FdsFzh0_B_xr=X0@NJ(<!iRq0$vV1oQ$ibmfpFgSk)!<X?M~Zvgbf!q zKk-_tE-(DDA6BbsqdFy&(ddzt+{MWd1=r273nS%6IhA-#kDt~V8e;YL4{$4?Zrz06 z<@fAuaHH?{1Sf3Tb_`0$TAt62s(0EH`q<jE=PE6nqEk(t(9_-ef{2k_$@O?K-uM)t zd|wd^xC8zrr;;eMj{fnXz&DP$#@YpB(}JuP3OBo^29_rS!gMmQe_&{+nD67xOW591 z``!^ys-d?4XT@gk+VyywJTstME&kdp0VW!s4m6AEzt9o4IsJe#Q(Ewmtv9GNSv+c9 zO%}J051{Q=Q{?G`5Xnd7ue;GC0CSO%<or2*FfOnf5fjZyLGx7c)b_?!xvBJ5^XY%< zLL!5D_Sl*vYeZ8{eg#=o=fTqbAqpDNXiWiE)Fv?4E{JPxv$UKJggHj<tw@tyKv=Lv zKWnq=<3Jy1ewKD_0<I5mRTJ-?UtPVhn#!Znsbt2yMewYUK$x?1x>2EH?x8Nm-d}e2 zT$@f{?o;Q{ekF*Q!qwcxyYLG!`zDf9Q&7a#uN^F@TVKJ!`31<qZR<BLH#=n`N+y3W z7%8oncofkI70k0Q{neoEnPn?R(t;Ti=c3iY>8{yl4!_%hCn)1ByI~l4)Tbt7WMB@q zd&}%ksWrEILN0^kH*m5+uzAi!2@X}dxSo}=lFF`N+S%5Q)2Rjx0nI{VDC4xA3NQ*> ztC@g#Lk~}HRAUKj%M`!{kBqT(<<-=~j}E{1x{lP{nj+F0XZFrc*(#s&w4>n_3Iay5 zj-o8)3b+gi=~^8atv0~97#NwR2gttR4WN;j+P5`Llgu7$rs9MnN~!rt(%pJK&5^=p z9^Mq;-@Od0Q$jS~1|nZYu`ic<QpodQ<GIowF1P`C$pp*?XaHW-rvg`opjdS1MA}*= z@FkLIPG|cD0IfaXB1~+jHSAhLkK;wHieZ}fjxyrJG;ITKAQwQIhKSjaVH2A9zDcrK zibk)QMYblBbjWJ;K;2muK$S7*RG>nn!>A=&{>ne0*`6sP+#`1{kQfaFWOoP0dRWgl z0(4mu_FkXsX)#ye<W#1A2UxHp>y4BishwB|0?`A1|8N})?VWH`V+)3<-%o0IKZ68? 
z+)j&y3=}pKSF$QxrNu&kr}uZo*&e+W6n6PfoqB&Y0sK32=KK%f;Z#vlj!&jhXSXbz z@x2V1G+keqYf<u-{U!|t4Xez!v6kBB=;Uq38Y{C_LT$lRR&@7eto@sy6P%P8M>TTN z;I2&~uysP5cps+<H(*vV1;_p5fNrJ{8Ye~X?Z*cWv>BoxQ@{|wAgKQh8TsM6@_#Hl z#_2ZesJ5q0OR$+ACLir4zh2#Str<ktr9igUdqjc0^8wY&@h_mBkFrDgq2~dw<bMc= z0;1NGz}9GwVqXQYt68BKu3hVeI%*Hi^J@v1d=~8#bpoT{Vr(Zx-hhf+qx&ZoqIXH{ zp$K9RPkMSf@Ikt6j@s1y<6Sx{MtJkiPyK{Z;p{U~oUBb?od3|!RAF}VS9i<b&$LL% z)|=(dx`C~mD-sk7aWPk5bY=p10)X#gOz{H%i$iDas#m!u;L~Up`CVLyRQ|;cpiV#< zKm$Z&E!7^Hy5A~G7hw8p2;FbdH13;OSu=NK++N-OS1m)1K-tjZ$#r)RUaqsdRn?t= zo8O&_w~~Y}TCbJ<8~xWDc8yeD7_@VCce~_Cy%!5EFcUR9vbjZ^DWfFl@R_)sI;r_H zR3mg|)wRw{vp^ucs>2BtcHa9nB-l4DB31x&LB^3<WFerKQrPglkmCU6Nh&s5a%PZs zN7Nilc8VfTZuENWA}R-~4{1yGfKVSM!kUM+%7XrwHwJI6w55;BGF_&Ao!ktv^)K0) z%sAzJ&1aEZz$lO%v{OV^$4q!#SkIOAMIdOg3tb=OgO`qKd*--;l9)tJxKC=7f3o*# z-~2tfxpaGQ+h1cUkgA%K;1}A>Lho&!E+`ss!Dhl{DHEFwjI68N!aTumwgT+%d%o1i zsjMs^@8#7eAllu_wai6FJvS!2=qU=@Le^|?{L7e>yS0GhI#kT{T6R}TGUo<@(FGpN z7Q##Teot||hVPGx$>sFPuSr=!vPsG>yw-(+IYa9-WvTeGRf%j6-~`O;e)HEo#-r$p zNoD<I3^pJ$^DNo3*mqe<?N^tXuqSt3<%=~Bp0xCH#tPj9gZ^X@hQzFSK@#-cg5_sY zJ=Xip@`ZA0b+gCTI?j_<lKFRrw)EY}v82>eTsT_*n`^$n87rT!A?8&p@V;$L)XV(v zsd4t<cGqAJw+$IF8Kx$O3te-fj3gX-eA{trI@DOkX!ejrbw1h;L`2mzDr~_#SvQnd zGwbtm-F^_!k~0h1TUC^;Vj^M#`@!ibH-<tOe#~wXJ~$NchJ#OOd)Px-4fPtX)y|HD zYb|{H*U<T^jTjf!EPqC)yy+EwOrrjv?>~II!1j)-|27E3Dv<TO21RcSUuXC~O`VEc z78A1pfi|^9SlzcFfr}ZJ!7B^!)3KXoeT-YfwJ1qNb^GIc!C|nq9H8Tl(!xjn6%}Bc zkmD~kn+nMBl<aLs=UvzW&z|3M&=xKyFKqv4G;ML9Zq6jMLLgV_dQ{7_LxZ-ycg@AJ ztL;q>3aRJ&?R2e~$&!Xb&8i)%-su#F-GR4hN5wml**(A&lqUuOH)WUdBjRX`9q=@0 zam5*_6B@vW1P~dhagZImwGah+Wb%8q{9RnL%1T{$y*Rm%niZQT3NnA!AMKoc{kSn1 zL_`{RDs#_zvcQ3DEBMg6ZMhRh(dHLl0~n1iz+YWv=?MMfV-yio#q}Y5_Ii$?^w;(2 zgEnS9|7?Jc=wzHO1tQ<k-<B+!b*uf0AZNW9^_TIf7zvs$+gVF;@inWL&A+Az7Cx-) zVgU>lHK^zigsWs_Yv#wj>WQ4(`J>eTHsr3`*{$0`CzX3$jC$_eVN5yHk^bB9&5mfg zueI`(%QqEa%Ddg=BB|xH6u`oJAK(UYp3Yj2R^n|tE?)iDw}hdF%?%)x{o0{D?<cmB zCoQQ(9{3R+AgfgK#!yE5Zyeh16bw<#!?`#M<N#z=JE^1e&Uqk}QTfKOpyV*Z&*=WY 
zGGE}4|4=|RQkOw?2zb@B)XWQ}hqf*tw0F=QoT|;YcL}vTI9*XCaiGtLNmZWW4cKsE z(k5m6RV_(x%h?Ec#dfDR{F^csy8s}FNIk>0^o}zAmVXl@AC72@wZsTin>kR!Y$cuQ zPbrZu{K)X5$cPt{I_Z7zk*?4C*b?#H<wC*sp`XH3i_`+|wCgrjor363WvNW6Bl_Q_ zf%$)^AiXWF6*c8D$&>E^7mXRP=#LJ8n+=A~LBfFiIYC#-jMBybmcI8`p^52~1ETHL z@4JT|=|rkz`;3(ZEIevblh;_6<1jw$o1Cr~RFoLN#OV=#UqRLrk$3*FO<Q1{7ZcjS zw2s{Ac`Gn?b!7|hqmiT9epT7tC!1>SuRN9am{=F=b%e8C_F)2IFtk{P-uqFDhff9% zXZJrIE8rq22y^#F&;p!y<Lw}ArO{19EC2kps0W~Ff6jY^3T^iO`Qcl6v2)RJYsd|G zKz&6+OmQ8egR!KhrG81pH{QHUADkGaJC~V#!GBu~X`A)PN*X>*k^4r?OA2gXU5~JL z(ww-h`y3=!|7^zPJY-TWZIM|j-ZRFm9?|isW4k|#Hl>2m=%L0&Dp92v?q{Qzejj9N zCqk&nFtj*~5^MKnL7)8W9P-lP!FG@W?C27!e&nD_R;uruO<$-{13{}Bu#f>wIqn|D z;0tVo*hm$wNTq0YN%@<C6L)<LQCK;j;S6Gg)Xj?M_^e5ei%TJgC?(XIH{|%i&5-*< z<9-17KMW_K(NA{nA;}3zgB(7|BB@60OpEoKXERw4j70WWB>u5TDg~IkE~@S87xTuH zUXZf(82NFN6Qfg|^6borQ6vduf((BcE#@%)8c72&*HqfF7>GbDLM)E`eBDLlffgRr zF?S;WE=7aMX*0>SmS#qk9Zgw%Es|RQ1OMbh#xmh48i>2;&ypbQH5B{2Dxt9X`*UVu zG6<gVtiI78l2`@~zoU}aa9?jo6AirM!2bZNlmJI}36;qYj*`JN05#k+XJKa~ZlAuA z;|HWmNzC=Y+q~q_p!>|-Pg|QZ-$nA3aC6qnJWS}pURo2PQ!RCp7E6hoiQNk^{@5&e z=x&*sRs8kR-I@gs0dT{mRK#`(u7yTYQws+moS{mcse94e*piDRkQYBj9Xcd8xCXEM zsyew$>0WjlKO44mEvl84w5<ZG$<HEphMKZ~P7bDDa`orrREn6_(~L%^&&<fdG+{ss z<w-jm#5UI}Vn2EXxza&#FSA~BiT~yJ`t|Za+w{!}Udv0l#n%igzY>BxwtKXSs%$3Y zUweTCbPiU6JZ@fxhY@NaWc(|zM|C{SmA;)mWIfp_?JnSXf|yKP=U?Ia5za2J&;ZW8 z+*t?`TW<vj^&*`YS~$ULGIaS{JJwRM8H6%`ElQUMx?i&@+2DJ>-mltO<dzi_o2`J! 
zd!2*HHfrlK|MEXir<(XR><PL58Cr>#o=3E4v5B?athT+AYCN)*AJYaR{W@iz*Im-k zs5a0@t{PfF^CChcMG09hvarGTS5+A-at6IZ#y47&&T_nf1M$2k0slAt&t{5Hnw>cZ zR*<9%Xqk+8qZGW#@*Yb|2*8tydA6lRHwHkIkoBaII)twDVCZ+yC{&g4C9iw3HxTTI z9TwzcKD#Z!0`~gD4sg*e$uX7Lm-zbVK`&Q3K_t7_U*G#fn|DvaR)=Ndj-)biVthv| zLaJ&&i3&tYQm?f>4}f>d6KhETROvuN`hb)y`1<|%leQPDcoq6|_BQ0vI>`H$5Vrbf zYT(KrhN3$vv6~XB)=fIjdf#T}TB#{vWvxx>jb3Skv^s{zybS{OyB%HB{eDvhrY8OF zNNV6w%xcq};l%A1i1X1bKP#1ZddUN4)OBw}{a-JDyX*!MZg>;$>8#lB35z~|Jprz< zonq>k_l4xahpSD|*bxRuQ>;*(=s)D){!-W@Ui*owciJrtC9le;)35OL?dY}M5{{4t zp91+`F~iy_{*7Kq|K=yOnbu_J0(J!^$cZ#P@zY_R{HL(lh{tEdQ#Rm&XaEaXv9amt z@DYr!F8LkE0q#!JsnZ>VoRXw3xfX{VO50l{#HjXDZdjE0v?@=^;w|qi-|Ipm3jUdD zHWOYK*CglyLS5`^ms<epy}!&e<{a7_H&hUwez|rL=1ct}EM$4y^+GaD8yb?|Ee|tb zE`WNFAmN_i>md@mH1&}`bJm-yF5c+{UVU@*LnLHGsWin{Kx97cenRY8nl_xSc0tVe zchGY*6<v9PmWd8I8Z!`sZpVtLj)QE<{Oi)NnyVmPe^&UqxiL;w6?Zt1gWg#D-dqM7 zWZ1r#nZ)bfsOmELt+olRylfq&U+G={^X8AP%)?AAj~d#8YGXirU<kj^A=m$&Y6DOS z!iX7^U}zRYaEUYV1b!E|SJ7MUOnpRUwIZb#3j%#W^aVghcoro7j8y|qFl%M~yGzbr zNQo4*SHjx!4O;oQZ^X7Gw0@GVXX$dl8R=Rle=~gzyK=^FzOH@Cq7lK2Q^EsHQ}@R3 zxt4U4?C!RAjf2`RDh|bTbbG?<vz?0A0$5{k>W$Xp15my>3x~jHCfmUJ_4`?MPK(M_ z>4Zn{5c3e>>dvyaxYEg!uMgivT)3_{9d$0@@j9RQVE>Q96#rO3RpC8v@oAa+cn<<Z z%#r}y0V6T4M1RV<DemQDr6NTZB0PMc7U5q4N`XvGpF~m%+q&|GLwVxRaVne>H+9eI z!$v@0#h=R|B%5Pny=iluE>`@WFF3z{>_{5y`b<crCOeSzeo!RHm44vGjI$uL^;F6N zUT?B?r(z%HP(i7QmWf+l3BKK-Z{Miyr@CH5kreBEevlEv$SY)xl?m=8vgY4u%H54! zs-f6?-%_&?DhNrdPn$f6v6oN}?)BKI6+FbW{wzA}K$<Z1;=}dXKBI}{b_X6Gf!(2& zS7Cezp;2^{00BsJyVG#@-5u(Bo9wD^DB%ZELVvZkWpw&Mf{w7njk6R)q1j0NNT6Z< zV6MzaQj~u>A<Ah!7kE+A*tVBYYN1N&Q6*tBY&Qj2!VrW9<vBrM&z8wzDfeeK-N%{p z+nJc`dDB+M4l_gQ@#68D$92{FqGajw0+|?*WNL$qu(d{4)T?$xvd|e?9?~ZZzav=8 z`_F4u4-Ow2wlDJ7PE1!A9J!}>!J+VR7=QoOT7uzrFC-h*KKZp+<A8767hs$KvOusq zEpeS7d5J?Jxu%QZ3qH<YICOUPEXT#ELK#G$Y+liWtEkRZXM2$q#|z0dG3djm38PU! 
z83C0!Puud;H8=3>@;86D;6iYJ?|DUfR#Z2@@m{cvzrV{?WUywG!Xe{!$5&yBuBMkM zDf!Q*I&MxpP!YG{wGIlvx6(F5HQi*ct=05D;IAuw;6GIT!wGSrtLV-!Pr>8R;?DOL zIfAgtwF`w$d1eDN3kTEFh(uy#R;ZcU<*JgCTEx0oud{>}v@+H1n2?K_aj_ss$7?GH zpQ|@<opC^{iLPh=l@Ik7rZY)lo3EL*MwN9}`SOOs4|rAHz3Eaecg0Wfgi-a0w13{% z(u31OZT7Lj%?DbI+M?MG-P0`ae(&SDjB&En)u;$P&N+@e@P2ExB;(pq<q`YWJ)~{t zx`SO8uWSBv@=RhEueJOs!Bg;q8R!3TeoHRkIA;#L=F=1U8sP5r0JEY2!&Q-RIxLcH z|17bhh$%YKn^%j8aG|ejM}+@{x&iJ6x%+<+xbTicHtru@Vz9p62E_i@Ma7#wS#^22 zPW}Hb+Pe&)U`7WY;Q1HS^Ud(Jw-Hqb$q$vBM+zUydUCaAbiFE@ZYvqQ^MqRXL$m4b zaCDk4$+lZhsPS$X3(h}}memDZ^eSc(nxtiief@<4K$_hZ$8g(Ej{i>wq}2lIpS5)O z&y&jS!x2n~wx}dLrw5kn<nutGBZ5F<9m#ZiZ<-?!(y2M%KceY{1)va^?D|-k9wThV z?&|#0>Ae26^a5)J<G{hUbqS&yYZo^2*UPP|%>sCqiOvCM$*ONwGpnokX7U&)^<l8* z-5a8e1aU2*KCLDM#A70pm7}jem{heoD09LOv(<<kek^+ka{f3sSTpD0x1+MVn(8*= z`$zUkcou8-X2AIsWkioY`;J@WoV)b7(X}@t>9V!#?XO09LHQ2(dw{h=NrSn~X;gV< z>#eWovEYxXuZ)&rYVZ~b-RWUInZG<hvj{j=du-51B^nLFlqLM`Z|(+lXQG~eMZuA^ z=y~?Rpx|ri#t)h-^I-9d!*rM{>#5`q3>+mw<2t%q$2hxyHY>=JKVc*bl=}_ct#tI- zsj#{_l~()XT{DNKBVJ<3!}{Me!9(g-cgI_SUjz8kLVtEG@2V9B@(ZVhO3~B<TTpV; z+{V3j>PAiiyysNh?ymp&6rZW0QCsC&P!Ou-T%=$LuU&}CkJ(OX5kFN*HjUT(bR~DT z{+bUC`gq8xr-w-nTDlVFTUXZ(IhmoBUDidY&gWA!y21`pW<!iyfX}ljOR2iXto-_G z)vrQ9Toc6+l%#WDl^#=AGugt&<8$)!J$L20X4AEX)e;HNvWnQw1t1s2$=k9(M>q`o zzRvrduFiG@Lrny(awCtaU(rjW)=)FhS>QVkaU>w_sIA)_M0MbL7Pedf8RnWfZPgnL z^oKR6&?N*|zHpxMS@-~PsJmogOZF(o%(M0CukIVu(M=D6nY$xMn&gkTzzL9p9$5Zl zu<pNH_~f(}>C%X;4~F{N=ilLW;BYM<)Iee(pxC!J0@Xo&=%*~ru|X*xHL5LNU4ON) z+8u7B*=p;xbRGsO!uET0t$0d1w4myXPS^Y9F+`b0_0|?)<pSSNf^y+{RZj$E>Ag?{ zkwRT2;+l`PnNY+Kr|r*$)WIV1Upb{yBsoBlC6DIV)>b9)vgRZKl~puKIhg}@M2$}a z0~AI3r8|Zn)@lKSDP<1M9XJk>GkXPh_O%}1p_Tl@?7rqOo0%biJjzZz_8qc~{W8Wx z{czU-qSMRJobj5%X`Of(T}n<leXK9%g%!)0*J<HI(8chZieH|&0Um{hzvi$HzE;>s z%?1TaU_N8T{#IZK$LA&jTZr?K7aCKM4ztmRy?gI4NhtXtniwYhaNQ$#ATr=+?KkG{ zkF5vf&!M;VIV7IXB?hd;L^(-~n#Ve6&RC&;`KmK_C8>4#MyPcB5D8ZDeV?T2h>i6i zbC?lG<hWqguB3k#^M9w3vTKh&LP4oWmDVr(4cn=As5f40t^t4UBkDo}P9aB~`4mLT 
z02>EQ2EtFMQ*4nv$*%;b;q2Xrvk#Ao2CS;Pjp1gh2i2r&=HAS?&N%`>77O_@UZdVF zIK5s)Eib4uvEuZ2%e1o|TZ539wjdGFr0tCSt)NX`IQ(#IjhkW|Slw~sFst{_9r6j; z8=pdDB;U@8K;M4^_rck_G(m#`NZWG0KL_$wc;5HBzsFo_?u!Nx1(ho0roDW2Rx}z& zS*%UijgZjfD8=gs&uSf<hNra_r>2q`8z|EG<9;<#t91Y~^7_3=52QhL{27UdPIY#v zG2(>K+{aRxr8y0^00ALyOCLpVg*Ol;=ig?(T~y(X`Ufa?3$rU;#WzX*J3&~Ivc785 z>e-el9|hFVBf+$qVry<&g2CfE62)7CU&9`#t*4C+7+{sw>=-DaafgBsEI6g-;Aj_{ z1%*C*gOjqi{V|8FcOglGNFtui$5A7DwTifP9dG~kPBX!$jGtmCgbFY8Pml6d%%m2& z{oC;U^>ragr2yV)Q3>FY#smVg-s2Bjj=e6F&XfAvBw>`^s-u&=O<BovYXXq17{bzz zx;^WU0L8|ri22RO1@7An7<r-~YnLV{P=Seva3=2>eu`N2t-HMgTdbabIn`Jnzd4}m z?jL0+k!cTmpwGEj7z^Kx<tcgE{I*_t+5~<zbs?s$N$OOL=ZNR1X9-wGDaUhk%z*Hz z_=}RmLYM%I4s!y?3T&0p_eb$KwWc8c-$a*7#r+`?V7mjipmfL&^%n12K6oD~*sk6A zly^S3yDAT>orXw-5l=#IwvDvNYH&PDetw2R5CJVXa%CZ2@$dzBAnp*<@N95?tGJ** z1lt*S?|jN|0f-sCm!}rJPxpp=t$ccYCf_!D;;Zo19C_f017Eb%Sx|`;I4)PW?PU~C zcvRJ#`JUk|{Sup)mi>_T=-2VNThto^N(DxQCh>`+bI>EW2J?J(7^seUe4MR8{3EW% zv7)iIr*jmGszN{m$qK!%dC2KNxS1K+SH<J|SZdmj*^#?<#Xr{DN$YsuG;f>M&@9uK zs5elxE*%dpRsq$RNk_E?*Ufri2&Lb~q?Q!njT>wkVC_;0zHTQw*@aXJ?SwDL<tP>T zhZ6-WJhx|`(2ChF;`8bOyZ^kH9Wl~f+mg;|TFY|!m40aMl3Vq3OeE+zYU?FDZJ7(9 zbf=w%pq&}HmK(lbjl;&g&uWn>1RT`JxGSq(<IYN*UG&j5mE?x8FI_6dL#5GdxG%FK z(bl0wFaAMBEL}SA=$(T1^$A?;7B~pBiB*+b35*vE&_kFk;C@@^_c;WE0vL0iVtd<_ z)$Y|sKKp8m;X@azW%06!M?!+Tn8XK^ZiX~(l`!Dz@%%xkyhPjI%cR5W!5N9GogeBn z6+x#r#@XdSfA-*cdT0TvT3x;khtp~3e)9q&d-f+KA1n7wSoTXF%J4R~*|q8GEpx6a zm?$rZ4`~ZK3iRujf^xeg{s2^7xb0Hvr|ToP+A;>XY);FAW^a@7=9Ws3|D&>}{5c4X z9u_2QZpbh&axf{GAG{K^9w6D@4!YIWR^|Ap88MV0SbnVa?y#lVC2!)9ejn|-b_oWL zBnjW@S2xd0Qa(sK6|ZRnz!kgLV}?$0fVM>JbS6U}VuU@iN1Tp+k!bTLYp(6=+Z*KV zB^5EociSc#%aqeM+09?2t3^dbkz^GYAPbh%?|o`PIzk#!_X2c`)H8@Q2smv#>}3N@ z1ugr&@(nY5aRohNodkksw;<Q)*Bx!RF)a0_(@AV0pkunHuk4ndkl4Vvtt0)Qe&K-q zlEdw9QZOH;Oc9*<e7mjmMT{T9SrxEr{T+L0@(zUo8fBUY>*!vUpRF6OrWtVEty^aJ zOdj6fzc?u2?}qS?L3S_2(m(W;Bf9P?Q>%BnuZ~oxK9B>@vM7nTujCHO-Bt|_*C0sf z!{IC^rO7dAO%cTDQ&A-Q<nE_<FJbY<$7vyw%Nwh}Q<x;Q1ak{z3Ol#b6*s5Vhv-X< 
zxqkr)ncbTYwB1S~5(078DlHhH(X@Y$r?lbl=x`$J%9tqViA=$B-+%>n1!5L})M$%x z2;T5R*Re3Qf}R@r2Fn+l{s=Q1s;*ubH2-;DzWLSX=B_#fZHZ(#P3Jc1Eb#!P^05GY zkLDz)Zh=${cT&Liro;J^Ltp(8S$~c6_LkWmWsfi0zP?n_1Yf<=h_pu4(E<7o;eH!h z@6T0fs*k^{yJh9q8P|&M0yX+`Rjp8*>$(7E^7%-=qlp3AUhdCx&c~uTkb37xn{>tH zg|qC#g9w&TvjOA*?{olX{&$DUuL*`81j;A~WieAbF!0s=%=^Dlzz$L`we-7k4{o<r zq)wA(>wt@Uf>U;8#!>`EygEVreK$jN!#|EAzHI2m+=h4PDbPPa6FQ~5ZHa=f>a*ae zs33QX8J!JZm5^g;x?k1qB0x&<1G_49Tx04nE4J|4Z2Bi@*eB4)av$!}R2~TXku@KC zFJnO+tUZu&rsp#Zrx(Y9qr=|KDHvs;buVMb@^HV(ggxgZVV_ukMbCZ9mByTqIJndR z{1s`DiH(<^r1SJZs`V6O1j4WEjUi0uy$icQkq;1qSP;n|<F0s84?`xELAL8a?H&Lr zcAVdXzK`FGYQnxd|HWNd*O*}^*VXL2I({oua-mB%jum(VtiP^}Uwl>0g0rBtp8Dnm z!Z6^j1|}$5qv%*#$(;kfe^ZzJgq;y24N6q^heBoL%i;jEpxzHaTNI`dQ%LUFea&Xt zXkTbQt$vzzeFJp9Mg@ZA%X4c|YVU3H{~_;Kw)m#6eqH%$KSUe?9mgQMR#{MRx6z9) zA^pDF><r&j$t%TI3qv}nVXrW7LFL&Oa&&ReYY4KFUeunn8arkPv^mYA8;Qu5Q-=SP zAU2kY_1_%f$lewR^>Kgx0Z`OKNTUB!u1o96Jw@JXj`ycPUqfBNqYl^ELnOt$bE$!? zVN9y*{j0&!KK0$tfKh%YdP(XiZ@Jv4biP`e9EZ>%bifGVuXMub+To>s0sVe|t2WiJ z5?XEm;x%m7x&H2}xwQBosIp^%PLUS`18ohWn<M6RG(j}TuAVeGWUJgO*cW^XzN?`o zsgn#GwD)+7f`oC4I^(guj&<mObm`{cnx*A!TlA5L1Hpi%HGmzzgF_}pPq+XFDTx+N zU=OAOB-tbNesf`T?gEj3Ad`<lA{K;?e^#6ErcYQoK>}Ri3QvHF25~mb6`1+9dSHf- za8CKVAZVI8lgi~heJPZy8hJpgWPf-;&yyAF$i=c_BzZB7K5m7N?;G3_Sh~~-il(K7 zYb&HouCCt4;z9>s*Xu||?_JetyWQRN3yr?079n*t_NL5qaOUA!R{ABf^rtMOv$tCI zkK3yCpxyWVEbxqK>X8{b+WRYE8nKi{l8?cw`>(F?O%~iNgEv?B(o&+&K??;ovgRT- z)fBL+<jLhK-}1xi9OASVC+N}sC8M@mmdyCcPO%M&ycuhv7EeHRc<PW@+j19y3yVqI zEICI>685=e2FdT=KZ=kl+T2Qdg7q48^Q`fd61{Zn800E8L<C7iBETP;iV`-pdonyc z6c0A~e%765UKzzSxp$pN%ikaXTE8In+%?A!_?BW~!{D8X%G@mHKrPg;OSX4W$j`+; zxC8bP*YEas>R{lv+P1@Fyd^;kD`3YV(?4W9AdaA)_m!(r#s(WD&d)eB3bE2^vT*>; zVz8Wx|GWwl@y=4Au>~_TK&|axtV+y&kDJ%=S-y07;`8GQutTZo+8U`jO-i}x@I?xI zlk|?}Dc%etl&3~%SAR&nt*0iqy%{#Vca2I*w0M&Yl1O0&EEh*b4S4OUd&h2+!%t~Z zK3e;mv~B7!orB+Q3;Y<p8h6T2c9ECy!2U^93<5^iLV)7UBBg-o`8kPu3#-?%1U*FJ zZ|FMraw{~H#dWsfU0JGY5EB_V8L8>uPz<oo`{L}G*sXF5iy&1}FJW;p_grGPi9>hV 
z?jFT!U<(rQ?pIe!S3RuEa7Cseu|;3@QGgOyw!X9qCpBcv%3V2^eYn<1^pe@D;`-N2 zAQ4kdyGMXW#T5j5!?gi?1fEqTrP$nT{92t2^pxi<Gse8QeGHnIwtkq6(2&eNysf;d zR$}v{i+pBsgGkUvBraGE&)7oN1ns=Ts~k8`#Q|(7J-eza1VE{~s#0rW&J%#Y6fSe- z+}h}l0W(e$_ywTtmGq=0ZqppB&NV#1af7f;xud&g{p2-G`$CxDQ|h*7Gn`n`45Z+s zR6e@_y|A%csA28?49mt1E>Z5_n~mFp(Xf*Q?(74k7okyPQOtq8cDhFhDY^}&A>z_i zpA2xJLBTfXzUU=<Vi6^&MQU=xtgZBrAD7CFjBthrsc#!h7t|MKL9YntaD|=Q-?yYN z+6CZG7o{cq3KzkH#o=el3sC<iTt~nS3pT&xC>M;I!J9mjHqp160cK{nmrkOp@}Ncr z&OCo+@@wXH;ZirqN1U5@>}QW%P;FgXaXmAsJ_hLY%Kdz?(@}GS`AZ{QnL7L004GC7 zpNC*6q_}X@;V+|x4NqC`Y8;1^^M+i}W@^RCKh6#>lN}M$jXDUd6Clf}h!{6nQQPn` zka!0Q>PtZZUH0`95<e)^9FwoB4QGd&=Wx-Faj~}#UK3JXGcm)9f$iz;Z%^vf^}r+! z){HI#<&)iSk7Qh%D<r=+9X@<2Ji17+Dvjs`O8Q6_`fJlnPOn7+q|&o+*Y-(jEX7f0 zz<m7fkbz&S?>Sx4lU3w_W79<i7&`X2FH#1|lBSJ;J9P$Yo1g!y=9OSCje6|C5BCIg z4b;Rh&I^0hC@fiOSKD5D3;V@`?W+C+SWpn`{EUk|9yvuwx@ibBDfeGmIkU6uM8i?g zqzxQDH3MdjOQ?oOsj=nZY?&IFqTQPNz`?%jV_cpe-o;=Jk|P0voWip+ea@#j;(YX` z?+=;@>VC_AL0%}`Uq9CLQYjbqB}VEL<3spiP6K;J-_=C-?Z71XhaxG1%AHpW!}UpE zmYy%%qU|rpc=Cl&hkXbN%FZ0(Teu*6r15fVLtgj=aw-d_D$q#kr@XJeJ|b}`!9e-E zMPuNuD7F%@B_g&1^4<2$*rvFB4bkKo<if(p?3~GEvr1d{Lv*E%91hq;CoPtA2HYFi z`Sq#-*A9cK-P$C9;E`;ZrQ3y`dDS9<LZG=0fT=@zlI|YM&=p)M*OL-c4F=V&V=<jl zV`aO;Yd(s%2JrRyd$_&zM#d*Y)`g{eY;pg*HC*(J;{*LycvFHug1cZck<1Inz%k1N z#V%PH_{}FR@VdK!z_xKFNm%qsl>WY6qOyady2H&NbqQTkR5{0`1@}`mWcco%Tc>JP zuC1~9ee9uc9?^`ji{GI&Ied9zn1<gz_~*esJmR;BCgEH*yV*g<z29CK6f|w!@#-=A zW3KWzdp>nIwR(l-WucV|b7$Xx?!f!t0ux9kdcZF@i|vV=d@lM+N7z_tIK*lBg3fYa z=t1gX&QYdU?%UEC3_OJ2E8B?EW^aIhRFY<-4jLepKp|OW7;dHt?s@@2$>)aS^5)6n zlG+2>mZx>X8wdNiO%k~ZLS5=987`-4m<-m=`aRe1b#x>8`NF5;_sNvUMPI@tly+&r zz)v`F!5?P_D#A37`K+ER{T$D7_EgYunEB4owy7DOp}k<qxJtS<K~r|^#{P+~kVA(+ z0HVe6RkSoeT)kDupyVncq&Bow5LYx!sN-<nEE)afnYj7kx+kFxK3z{luX>3O_n+J! 
zZN$mJ3ObhtIE*ro5j7mo)~=(F(i0#AqWFVDP$E*XPZ|Zeao<NuK{M)G@nJWB_)q6k z>%g4?AQD~%%<YQVP6x?Z^Q%pu8n>j&g0JWLV5G%D02gU4Bzq${0wzvIckZmGHyG*G z*s%1p_+L2`ptp}Isf?RCC@9o%d)u}`l0P%N=+iRph(UaH%|PF@6Jkej=ayLoosZaG zL)|TVz$oDScnrCr3ie^*W#DIS-l%g^Qjdia#l_y&fGyir-G!Yna+`f-8&;|g_}vvQ z3<?FH>30N)<OiL|v`Th5*uP#<4y)F_Tmc$NUSoM|$&!A!$kipv4%;<WNC+`r`bz-b zp0W5}7x+*lsA&VJfcpW1U<VgByQlwyiij_RYj1D9mFRu73}%3tc}R4FOH3;!4EsdE zYd);raZ~+QveSsIvf$xg^$u!cHj{NrPGC6KK2);8nIc@`X2ZFH^Q6r$a=e@_yVUp+ zC`+xYfOZv!)}8&hE-eMBv<6RCsQj%Br(9kzq&dOhp3T+7h4`Xyk6x2^19S0?*9KCn zid;{HpVVJjyyW4t>0JrZF(B1*bz;r4&PC<#w%>*1AT+H&DzH}|6_=+8b;It+@b*MG z(s`5KhWI#`7}KCCUPQnA8}sUN01vv8`8oq5df=JdlU8wb_el|v-^>MrcgC+c7x4%< zFyh}m?mfL=qw_u7!`W?HRLRUGMA$ZE6vKKlu`<_Z`*R;XMv#}Dkxq4MX!(QA+`9Kt zOw4`NvpSRk;bd@YvE9AWJ3E{dxAJ+=K=ZA!egax~dbwrT?20yvFF4*K{Yb;c&cT0= zx!aCd+bi2L%)S)WV){2dF_@FGnNTCilzP0jn<IILJ}-IC8z8*6n{Zyogt`BLviZv~ z6kA^Oh})=_>ZY@#=!UW3_*x$5hSH=?lbV85`F`ePb?&6a?++~3E!3=Y<zsNxjbu<P zk_oAQ+%>&al7Saa*Znqtu5{vSx34`rIiv{&-Pcy7-J+ox0Y$?7{wcwr7h$KMhPLI6 z@vFK5pq$rZUPl6L71mFd64t%npJUG718IZL5w1r2rHygcd^XkK!*Tf)(+m%^jI0)& zCAl#m+6e~39IK2sbmKd*dp6a!BT;~n&+js%r4Oc#tKC^MK`gyuxPbsY;ro9V!2jlx zrK1#!ynS!HPqt~kyZr-Q(RbH(pc@*dKOMc3|B@_C%@Uga?-i6w#k3r`Ex!9IpaI5# zwWzjyk@w53?rr>sFR<23{f%Wn7kReGs68VGH9V|!=7S(_9jWP%W1J&ZF?DjX;(sgF zy}dtas^;!p!$&Q*o0yu=cy%QcFN|2(7*9om8y&ue<NopZY|Rai9<4RLJ&fleB85l| zNpK~QH!+X02@SF)RHZ1yOnwbXGuFpbr&|M07>m(DL&$c(1<R!B-P^aUbWrH~#Y{=N zA<(gIti91z^9;QUIFR+>@q*Oeq~L95^FqxXkt>KQ#8l6(^&r(Y+?{CP%K@hZvTXuC z8TFiiY}JA!uHePD{X@{qr-o<^Dzt%79EJ?27}`-ozZ`o{(fwA9u5{T@3;NY=+8vCO zKeDHgiw*>Af1zRP%&ma;!mE4d1NTe}L!)FjTr>m?-yQriwD%&I(d4mBC=9?C6ns)O zrj1uFBfR%-A%kZSH+YbgU9-6xgHBKM4J`##0#kG+@xPr?3xG60L7x+SHKitph73Ye zKu$~JzdxbC-!~hlsTe~Ct-b3SG{8&OzCDs!&xz^bwvhwJth=WnC@3HE_}tVc9pv5I zoxe1MUMn9osCU{3w7QP&1YHEvhwM6`NAmneGRe)fy=a*M9L|Exl9pS*4a#~kfWDD_ z(@Wkv|M#~QP_eViEkqjTp03JrWgZpP%F#)5XTm&{yG@(Gxz^aAX7{4lupWT>$N*QN z*d^}vbEmcJ=K;lCd)^S=v7I)kWWc@8uevu+6@qSil6X7^miRj7Wufw4Q6K2?5P)mA 
z*|@mC?M0*Mq@CRJ>NXuYSyy{RKwp;;e<HT^dnO~cB{E26N*&V3$q%zec;V!s@;}iZ z=$sxJ8k_Y;(bqL!6non(yZteIdPMJA|8X;9(8cuF*mnni8Pa=WCq6qsoX1b$uJS5Q zIANB!z^SX}bm(q<-iN8SE~vZ@oa?X+0a}M$%uSRoVB66=oGi~~&Y4q2*NTbcCrw>P z&EIMgIFNs>Sjw_#haD=4FsNU;TBxb<r#obt2rygy`}MQ}n@BC6O%dS`u2B4M)CcXS z{t$!1d^$UJ;ko1N!Ho}__8Q@z4?y$n3v>B*W1qO(1~c8t#S8CdIj}+<0o#jFXk~Dh z5J=StPLQ0(PE&4^^>kC5o|eG}wQ{aGhV}>b%G>G@4#lHZO23GXM6j(Ah%C?<0E*kY zs2XfhUqqhi&(N!93?Xj=ur~zu)$p>-bm6kj2xw^};QHuqpQm@u9vjws#`bjXVbfJB zYB-xYRQI5V-vD0EOK<!tgYdG6>}^`8={nw^1o8g{Xgj-Xu=}ZLawrr~XJxxQ9I6em zDSFoL{<p;MgscqioRT2}8ja5VKqtMF>Co_DdE>s9M8iY#XdNTPOBtI4#|tb)Wd?;y z2{xfMpp{sS?|r&F8-|~eRGQJj^Xs#4>t;?^*+`%7O>Whke996b-PozGZQd)#gs#<y z#cC}tW<<-k;YDnF5dwslk8TudrX`AOSHG4m&r3EXTU|og)bqKXFZDlK7}+HrCSzz` zH1Axhe%*@N)U1RM7HOK|ZYXa%?{@geD5X>hwd^>l<rT&^+zYAZ0Or5en1&zfjwANG z@XsD+B)tMacNu~Oy*}=yQpz2lcg;$P=VDDIl>4i1wJM+3gDq9iEL9j$*SD&_2yMY5 zSHXqIvw$kRbw}a-*CT&AhUOgjNHC2bj_4-KCVEmd=mL~}FBCb<M#wkC#1bFP{x-GJ z<|+ztiO6tKPw%^@nDYNKA<~a6sup9&dr4vTBDNrZ^-JNWjnv~9>DLC;yT@pn^k`th zgouQbm=-0H&!&_R5dYv$d_Ziu&7jk63hG{^@2|egx>j8|$VvykqDw9vzFjw(uymMQ zIarv}n_Mt)Ulnf}8xC0rtO70#C|=WH!w?s~<Ubxg&<w~7MDjJ_?zpP_j7*3-$8;*U zukGQT=*Rp;Hu;GGgPMXtHw}*Ve(tchVcWjwWUd6!VA6TyW$o`mjvX-_lxy&1eA&9u ziXL@ZvT0Qi1&l!G+;R)Ru}gMUtQAR60xZCqeBO`{g#Nwtw|~mv&z&bqw)upB;oiC* zxHjUgUr$tYcHt<O*LMYNHV4^r6UR1jZ2{bhlI?0}h6fM}uzf&N)mXJ*@y+L!V!1HB zI&89*F2kNzJU%afS}|sFsja~;1TqA{t-7kq)2_|_^|I`qjT_j2L2EUvOc<+AM-KMC zVu+VbZqB!QIdiN?v35tfT_|y#L*$ETUv2pS$u+|_5vt`cK6@#-wCqa0`iV!<C>1w8 z=smuVT98ldRvqk&Oej}8>9m&O7(hoy`pimkbE!?%_TlYw<Na)^#qXMYuvDbUR|r*T zj*YRo>9&;hf$q0HYM2h_&-=QA3NhC(uk;*Q3>?{+E!R_)w4-L=Y^ySaGhz5PF9$;c z?CdR_6TkZI(nED;bx<}`8WHYpzaD!FxwJ0G{EZaRI%VLtRPpKoB5Z$NL4_Lkvh{nG zUdn@05BoMh&}mt^9BcyoJ=r%Ew19QcqF2b{fmF|3m5$+|PJirJWpDl55&TT{i^hGc z2XA&ALn;pZQLWwaGwpNg2U&RP^cD2*1(3IEEn;@*eprNiSXAn$ae7qLa(#!K&MxPC zMnhkkIv%9@dD6>qk_usNL_V`9`w1^P4!nrulXu@O^A3nLu`RtD9)cuK4Yg-Cu~JMB zUVr3KSY}*x*0ZDSE<>Jg*iI*Lrx<Kx9%snCL$7Q*=^7cx<Km)Gx(u_7S}hsVeY;uz 
z`CctlbU+Rk=dw^VFPG;<3E1|qnhp3<%IbwUxlU$SV79CEj)&>)-qjFv;O-U$TmCrt z%`-Xr_q(RCI~qCNnijmiAG_W=G#o1X2eV92BsTMD<RsgK4)#3;#4YGhnYl*D3p&7S zr>y&O%mT`nRARpPadP1zbgeDz&Wb0Qw=XFV4b_UbHI0`$Skt>`XgP*(x?DtmTF1~i z=PK4$D7Y0=tgDo*4~Bd`%YO$t>2C%!flFPwwj#iSRN4VtCaV(D#h_I>zbA(-d3#-t z0tCp8<SA4Pd!rbZa67OzRCVox1MUp#s*J?b(*mQE;fmXJwcmZXX5BCH8v$?_aMbb` zj&psyn+g?i8?ZD203RSZXLs9>?RqMYo2P$qk=yF5f0X<EB$R7*<{j(1W|vkDz~fie zQOid$zSko49%YrgswCD`C?uVByIAR1A|Ss2a|uCM`P&pe?V6^1Ha`tHzhLc}R2)$e zNno9=sMbA|pK;D2gxC=Zanp=G=b}?jCEJA-Z=m$8<eVE|y2%{ziVTUjov0Bz?;W7{ z#(8&2X2%Sli$>I){En_q9scv5T-#mrp?|FlJ1gDiVCA<Wk*>+Jw_7iv*c%LhtC4jw zUt7a2cAA+kASy9t@j1tH<Fb`ZReqyRX&K{T0`XRxHa^HoxLnUE3JxS%?P{B)#`n$R z?G?Z&%l+WN@&|LE^|TI3=xMufqvc7<dOQDvT22PN5v-fun(}NA<`D0Q9mt=GYEjbo zY#Io4lO|4&7VhnO<EDteEzdK-e!$yrD_E+1D>>+*vFfy*;%F`mIdcc&?%prn>a<G# zl5e_2uIRnTg<6(kDk*;D8Jf)zpafW#)8Rg;LV)q!@;X0AAj%I`S3U<e&!eF^1u#%U zk<BurFHKY0M~UBvka9S8Eyci168iUJAt*#8J$CS=ZL)v@+oP{BE2_?4j?o2VHpxS< zE{~Tp674kzBnqm%&-Ln(V>dwzshDDIW=h!cBCtP<1#Td}2Y%`<#&lau$D%#w(K+;$ zTRSCBj|)Qnb}XEP5DF?rvUfqQRLaq%;}DD8+fB=XWJcRy;77EaWICBtW?}J00rW24 zWRtz;lH}ob&d>4-53sg<$Cq3)bSz!dHLa7G9$N-#Xx?O-rAOu8Iurgt84CXRQ^ZQF zed_ZvuT`$~&JSGUgkF1K0+a`wowSp0t&R_BQ6OO0djE;_J=99DyFVjEV1GnlO-DV< zAJ&j1gpl+8zxKZTAIkRqdqj(bsJOG1+cMVajx1weD#pI=Tbb;GF^rv1l)DH+vdfx% z->FnG%35|NG?px57h`zNx$o!md7j^%=MVULy?V*0>pHLVI+pkGK9&>z-9y&3_VR=D z?J?hPH4>09v#VLdFp;d2*}rZcWf4*!EDe?^)=QR5-ys_zY3OH}g)~2fR7g1*jeoqV z+nWQzpd>i{GU%xcb+M?ZWa}ko-T8<SuKGZ7n@Rg;gOM#^tCdhY<AXgT{P~Es=yL;+ z2N)@6%x&rNk#>5E0x@5Ecj}653vE^l&f8P9!JHB11X~<7OlSD6Bb2~Xx@P^n`-vdu z*aMla_P7zX*}S#IuWzWvU0S)+O6@Vn2rXn%qPO-CXViePNdY44ZuVKEcWOV6*scW` z$+x-IQ`fhJGL@qV!q(fa@0VBwKNLrXy0=?SyBSDMwf;63BwfT<mJ%qG;dgo(M*~t4 zB@N74_0r*rUAD_Kr&aDDUoVCZH>K3I>_s#rcU`Yc|5UGX;Q!lqz159lG?8Ow4HEED zhwOlFGOrkwihxsHAAp6{s#j-tsgVo-Xr<0;J{oPV9LWQ{;3l(R%h#gUy-9PwclJrl zqQZt3Y<JvUO9l$9WWHLrhh`NYIVX)QT!3V)DfC=#kOB5ulgtc<_QFRGFC1gCemZDR z1}6<XMF)u8re6kja_^YU`pwUKGFmqoX@ZD!OVj?WES*w}2#0}@ukXgaw;WF}-e#_` 
zVRa8v*Ff`ck?X}Obfc#tJ3G)?f=S3$$M^m>dHq@fED=<-?+Jm!wa_XnF0{U)WnV+! z6Iyj=9YW89!}aUD#wTX@>vDJOI&^9p+OF^q)%eoDOLL-w8|(c!Nof;Id~RzbN)Wi+ zfRqakxC#)sSGaOjp`!qfos{W(p4eugIHKFz*pE$|N_0WLR)DJbO(=<x1@0n;ue9W? zDaOE>eY&n3n{+VC=+0Xu79dWS;T`i7d1AU<tW)+y{&_rZGyMAZfo(EUIA-$c5JUUO zlh)ngDj*#58ZfCUMl0B~9V&aY#c^P-6T8wsEA+-Mi+1cmIrlmDjXaMQwOzbpLc?gB z{XOtZsSnoOeZ5$%5?o1q+An)jpzMg7HaQq*Ga7iQQsU|IHTH=lBc(#=a)KI+DM<%D z;766hhU^%_&9!ad0Q2j+&LC+U9;?h&jPg2<$DaD}5aIP%j08=Sww;x$0p5`A+=gXZ z2m68&9QZ##CifQlhzC|4p648z2%9x`g8tbrd2T?yBBP<%jU)AK>CNUw05(IWR9rem z%3(_}!R)-Oxb~=umPrj+YjDq2#-m{8@d!_&UN9QuSfZDv*sL`A9h|>4G|cV|-6?HE z)n13IjkaB1(d!XHk0YZMLH_uun7eg@{$ofco_?TFU7T~dTyMR0_NT%iALTD;G(%_; z)~A9OUiLkQ%gohUJa?l=525|sjDM1Pmp10N1muL1$1df<$hBdr1a^oGi|<&?9Q>%o zpROMTZwWoQ5-%CrIW94(OPliq63USLL_}NoZDru}g80r{@zBsu&halIlmT}I3}l>Q z!c<cCW^Qs(25jEiqC+9$Sf!kAD}9a~z1Q6LCcOxc7+x=gpUMwO1i3^k*f)JFJ=zp+ zq6!^;XV9Skv_4g?0M%YdgZTt1>_?=a59ecD)^cwGGqjq`^!(4c)WYhXwi#t^t`zGX zPtIgkXi&ggw+w*Wr2YORHYFPTHZJeDC<vZxXz$%Ah#85WFU*DGTOx?#q<m!K&LiPR zO(pe8-+qEoG*XOiJ^XzBug<7SZ!fzD4emp0zJ`Op$?$c!ay06-A6r;%rT~432l^>7 z7?g0n`C*3u-c_%0K>5COMuy3d(6O-)anGd`=mBr<^_cX+YL(afKzhu&emnwZ0Htv> zCMA)JrS5}qewwsH=R(<(4PDMX<bCnp*2%c1OfOVFWYfIwr*5(N-u;Nm-ph6E8TxlQ z1NCE^@jHrm3vRIf<MvtQy3#d@I{Z#Vhpk&PuHi{|f0qU^kLMS$P`VZot5m!p8Kr}o z=+L4((SeyUe-B#rZ|tWTOVH`xdCjV+CVn$C+|9;UeFc#orITuXQ}J$&VRJ%CL8Kv+ zpvOAS<YRD^=8(;Yfztb+tEZdstfl_$_VKdApXrw1Sys@OY-5??v6ooR5^o?@fI*3h z4VadaNNYupm3X4EX{+cLpNx#d=;*)YtqMTx8T{~<o2U=5pyZagyuXVm+3FWX{&bD! 
zTvCXGgp!?J>2wzEhW8KYAB_SC;kkNRJ(fwoW_jZ%9U-%I2JF@Mh_N+ah;%`6+ZQSO zPFHe*q=I1~N%zxMeS+EhIAxJ4gGV)YdNEq}?gxLr;?mA7x`8>hBW}$0It{@W+Owj& z!wFc9yI1!10KpC)vv~WbM$FxDmjF8*k!Hj8)lu~45>Fvqh#6^0>FiRzj;aIy(f-5E zm-=@>aP-F2CXUNHy1VysTl<ZD!zVg~^vOhJh>3ngN2Xig6M2colL@o?$~*JAICQ+S z`cCF~W?q%_KJflq(=YVz>z|G33A(5FL04|-+Z!V@QR9PtSaWR&?SL+E;g=-<itQpa z0{O;!OgGp3Tq<exeG@(O*Uw(D_*_=U5il-QaXW2cP9F&s?X0Gd#gXElK+RvH!Ln+i z>z@nScBniIN|>eXiBfC-=Dpz}Mr~+-fz~5hm&tJe`*z=O^dgq`o+$}s{iskdm(EQw zIJ9szh|ue8pbsp|5Z6MM)+DdBGMGY1(&IDMLYyr>Xzfi$oXSa=;U8@^Jo17Zbjw~( zf--1Jx~0h#+3_cRu!F!$oKnDw+~k{)ex|*!3q;~IY#E3|JopoZU~*Bho<YGCV0}rR zBGS!qu~FmRLqW1OR<TaMc=rTk+K-!3a(E+zpfVPc<>fMA-t0=ZvY*$OTB6BZSwfTY zecCbiG)1f@n=hR`;*|TF_6kdmSIO4RR8Q4I0nYK9G)AE$(B1E9b5og=&E%hcrresI z!CgKAQtHpnLSO>=zIDO6AsjCIhX10SGzeoxb)(G6r^JepVP7yqToJoH4X!6v!q^(9 zd*idpM&%c4@(iY00Rdusb9W1-diZB&|5_vRJeVi%z3l=<6R?#^&toTnM7Bx9MULP) zTF2Sebo;bR7M_XZ&Nkvq;1Q;+PhvBHn_HpTD8lEyvVZkSVB23h!{g2Bs87wMDSNb; zRSjgoM=OdBT%XUHs)LTVZt`dt;xjKN^6-?-?gc&$SME>nm#caU-fWM7+3)Mp4VsyZ zLQ(j`pw2HtUWL&y@&Y>1o<7o3tz+ElwjLplc1*+H*lLcgL?Y@crl7y8LHL6OO6QU3 zq53{fjKt+pIpi-iyw<Au+xC%&fTF_QO-Gj=f4qM{S<2)hXXj5=77%|{8FBU`Eius? zrieC<x<QRQmhFOLaF5<r0*Mu7lPG{4q69jfH*72qGhsG@!6j~JGm-E*iDmkA*i{qn z8qSC-BC2W)yE~-^nB=+4s5N7eYffZ}FPY11Mb^=@i?{AC$0XJiR6GKeU6lKDc%;YM zJ4%Q(=MqxYN?mxLSK0zkxv#-x!j^Z^6b{$C^b=U?=Lq{9514Y*`9|tZfoT`SWqK%z zUQw1NHcdf4uZLDD`_D5w$RjVXhE08w=ddOzjH?=6iFEjgf0*oX7Rkp)B?$4ShFe(M z3dgY`sH8A`Wg+d)FcktHQ|dR}h`u@Pww{5xzoaek?Vm1E8i!6~AF)dCn4leKd47rT z`AhD|+m8V^0;Y-;`aJQFQm8D$kUuF2L8MFD8j({u8BJR1=?lY#pQ7?mgm*z1{0a(e zEIu726s-oUhOYZ8EZy~A!l{%+hwZEOLvetIsG}|QG(roeWc;@)hAvWWTE>RivYo%R zi;ICB_yHH3kV`mFD^b|g9&XZ*=FSqGW>UEYLg`Px9D(($E7(!&_cy33oWvU7wvK(L zq)@bWD0Fz@n(@Sfle~gI@WJlm`aS$#l@#60jURt-J~wcn4y5+YPi;z266JA<+K$z! 
zAM64e`n9>Hhk3jvXct<g#flwy!dUH+FWxbY00eexMX+o8o{`q_o@Bu!4b-Sy`KkFS zjgj(Xoxb0}W?$Fjivg*pL{$^~asEa^z33YHrrFoYov=P*r?GzaoJQy2?D}0t06|0C z66&OeilisJTmYKMU=#vv-Xl`dPcYYmmg4l<QGd=|Gw$35ErgvO9I2myk*XPl<ZyjD z%morGqt(R<k<Q^N##bHemjL(+oflVy=a|Sg=~%Sv;n(C{Zja=w>z4dY<8HXuDcL7! zUXUjqpIV-%B801-FC2Zw`|1ox5M_2k+g04g^9ccIZ_&OrDc0XA6CW=no$sBiZJ%*o z!SaMN_RF;22#IZF&GOngUDXynV({x&Z`mKB^yiBFWPc=`nYd2?kCwK!`eJg<ItT$6 z557)y!;=qS_X*pYA;kq+PY1%4xR<}ShjF%sG;2iI$<e^WjrH_Ly=xDVk1quf&k%vh z&~joNtkjOp+C`oN#YjN7nIw1R<KbP~QFOBhGFNs><was9AwQUrsLw$O`;A6!m(7rg zI}Ib1D1vI~b5`>k3gb=%IxWu28y25Z{$W?j-#MfxD$*PtA_D{``%5z*nXr-q-b3cA zc)FAuOOqcM4v?1WJ5)Q|6&o-r#byO;;t2;Z7)O@dCHVJXgw7bSpI6!97q#rwl>G|% zp%8nYj5;9?jSm_LZ$*-i-(R&}pK1<eS`&F_J3^X{i*1tHuT~S%Vve6q32(GQjwL<i zpGA!#W%Cn=Z_nt1ez00Z$ygWS)NHools7bYKu_Q-v(e)vAKrZ4yB|754yIp|QHaiM zOe0Q|p}!yCh;<Ly@!#)#l(;iBKCA$<It74~m3AJ@K;%|mieJS={tXWSh45i1lzQFR zLq?GB5PGsd{Ux+aW9xuaymbK6j>y{db{O?gP%$ql4m*YED9%`&3b;X05k|>UInml# zCmA2PkH|{CId%G3BVosajwO!?GHaHzrOgS5hi(+}E7<SI>sL419}ZCf4tv;ofgu#K zr9E@WEa7buNNT+PF#FE5QdtlCmaRea>3%Ql#{Q;Q!I0a?0V#BC+n)Z!*k%IneZ8s_ zFBvc#T!?<g1ln{=k?Mtyk7wdeyFKOn4h$#OQt)SGfFUZc$a)O6dt2F^gTXzah7fpk zHR~kB;riH5K|X((X;KI^G)V3jZ*7g0Ek4qafvj=xhmpW0xDW^CVY>6uKP3VK&Ss9T z#8)OYo=EO*@*|JQ&8}0AF|NnGXhQ(}l&(q?I2|a?U!_};5}}y8HyLJ%YjlbUOxMnw zvkU-}C4_TrM8JZs@rAi@Fo>dolSHm!dG==y@)IxqwiO;f&UJO*#DBUiUX*_psSc*P zI+ak&%p&P{%A(Jvjkb0dSn*b=0%lb9U`@%FH{R&o(P+VlfF<6|*{*X`LcPz?tCE&( zNkQ`;lOwoMs&icn<R2}x^hpE&3Wao|q#c2cGwOL3_RIegODm8~F@|MjVA(symAP#M zCFbkBFR>6nR4;Y*B0l#G2Y&jlMq7D17Y8KWP$UMav?HdOGf0h0d3bDa$Gs{{X`)rp z&G9jBt|?S@V8}u>tK@crPocZih`T3NS{LxRr!YKaqrA+C+3Pj|irLSNk}uc+Ah1Z3 z4%$;crwlk!395rbZNfxz7l4P0`nFgGBm$eBUuV(uJa$y~zRVJ8c_i@F>3hi8gUW(A z28e>Gc;n{HsdLxUwwO-t1B1}wYkUDMjB6d@nmIn@%WH4?ZnjzihqBkK_4@v<Upne? zwyWF^9bgX?MpU<S(;mst1HA;YX@D{!NIxUwfJVQ}I@saa&_T<nY=9G^K2Z8VvKBc! 
z#@Hg#-e>G-Blc+Vs$M3j62{p<e)J$LzjC1FDT0~~mVwxQJJ%BJ0{2`=glQ~?mC!5l z@AUIb@n9^lzvr>&iK6A!GhX~NS2Llm+C7j}^8?}uMMkpZDR+6uHUo<+Ur>o1`C6WO zm{)t6jnZeF;%M^=o3*>Q<Eaq>xG|=h9k_k*4F&!EveP+wL}@$5n3}NO-xYUS!4^|F ziI&KXqq6(p{_ZurJk~e%SvD7s%)4f*vD_WE)0aC-AJ+$8!K)TrNKGi?lvj)Bdrd!R zP7REbYt~7uu$^$x>vnK~og?W@=esBwmg=M?Lg2zPkj)K{=hX8D39puZ{Q#!74Wi<z zZpC9e50toVQ<~Goje|~&Ry}B~u4Zh;h2pKO#9opcriQX^k0@q8DzyvGcr*0lDe}DR z*(`d|BqCONI!ss}v^r`<FO?eNTX<7vW-MizXMosMB<<bql!&?L_FIu3ypfrYsG-C) zkOIEGd}q#9Zle|uyGBaJj}{@`m0$k4RrbOjjLM=n9%ls)UX^%jGIMCnAO@}7kb-zr zscnz(C8P)|Z?7O0I<1%r6tn^#I5j8lzISY2)W(pH8ca^Vrs_u8#ZW(B`VODhmHej` zAZCb{k}aUwfAb*>WXYMe$a&#h>$Xlx5h@a1+6)mGL~ry7zu0(K>*WmI{>MqloUvdI z^zZbjXwsGJ?HW4zg-_x)K=Sm@xp3BsK;^~HXQ1C+Cx$ICRDKDaEI2BICDDcGGBy0% zuqd=%x{JzHPM=&X1Mb3^C9iVCC?Tu5jvP0{u1I{+#QL<ulcRg2B`($rP!7V`G}@Ft z$!1%T?^)8a%z$YK85wG|Vzd(r>Y7kkEgT~4-sB5Iaio8+bd7<0D@Zvn=D0q|1NJ6= z_28R#!0H+sH&5^7oX*(Tq(aq@8KAcajI6*g=}JHF72_$=4=P@L65IgqG}GDsr3B8O zAe_4aQqr!2Fu)^f4JEE$3LQN7`!?_xFbMc0Gn&!|EV;datbV7fh2hDU*jRh;YU)Zn zB1SNucb+A#PlFsX6CyniXl+tc^@AulP7$w)B&&HPV@QPum^$$5ia+PN3#+w`T6AN& z0q+`E<NT4c|6zB!{$2SesB8j`APNUkN@c*PRU0{pHz$;nAGx+bej5vDW&@vcJ0XjB zc`ZG*3FOySr5HUXjqG<gm_Rvkr{Fp^ETL@wJ300@N;-oJ{96&WVBDicuV9J~{yYx^ zV?BN@bW}>>1*oyiHu{zR6lc9C1nR*NsOj6|nw14;+KsEk4GT|7x;?JBU;h4GP%Y8r z6s)8><F%Z1yMa<!F88)dP3Xx;+IU!Y0GPleYO3dYG69^Hd41-(nCc-)BP5ldD-3ja z3()Bl?A%r^{)v6D+awCKTH;t+G+K=Fi8IS)|F}0=mQ(M|^02Yp9avJxV|!oo_r5Sy zh=&fa+7&;mwrjFMesyvQX9Zuf=q+!xh+ecF+P?xX2l-bsvxmspRKROmqD+i((E+Pq z@?KFS<>{kzqh?1=q-gpv{`ydiXv}lKr;PFihjag_;!MQ5)^o6x#9dUhg&f@iu_7Am zPO=d(eTFU=xFo$P_R!7NmnB@xD_BmJvscNPAe8G6ru2ooa<o3c5(Oueox6G*0Dg2i zWJpUsjb?<sj?*2(bjCqEve9F5U8+;v7E_<)D{>?M2ygZ~&VGuT7U%LbB$zL?ZSgcL zht?Ja*iNM7Pr1+c;#E(Wdrq|rT89RdeIZ+lbVwYnLR@RNt|H+c-quWBgV#@CM)Jgr z+?Q@s!0=Gp`8*8Z_kqd*zO=UOqQ_S4F%r<4SxvA*)f%O(vVK#I)dy=>JJhU44dK9( zz7GpD+vm}{11enYqKS@adJ@Z>UIpRzed(RFaXn1JQOz_ayYZ`0^TaDuCam}mI&KrD z`0sz@bwV2lP4AVc6g2kTp9Dji$5FX=H8SKC4w1xKIsLuR-rvC=BtZ*DIxEMjS&iy; 
zZwUO=H`%|wK2Bm5&?E@I8!L?5;RVqWhJnmlxNa6O|G4|#XrhjAP#A0vhQLt4^|`*` z&^x*a`c_?Dm4ISNST4lLWT7w$F*B&NTf-G*HX)t%uaZb;dvDaZhNwBe_g$)cjaY2A zc=G{mG@nJ3hNjlIwfgaB&mN5M9>Uxsbk~Z)d$cKa%9RQXuuBzk<pT1W5q(yG{bmM9 z%ErXuEHAf}xDJnh;eWS)W_W8Z=V&D&Hf&6)%2PXZkZ|i+0UQhZ8iCJ?mkBt3Ai{TS z8oBMpj@E8WllFtm{xt&%lOAqxMU;WyMYj!B_8!dSuy=gFfaAn=vut(}7mZ%mqSdb~ z72}!X#gG1A0d)I|MN_JYb^zWY^?YFZ7%5P~Q~OwCer<NN;QXKAx3rp+10L{JrGuJ= zytagqfw-VS_V$U>uJl7fz%8@nH&P-u!22(Xe!L8lMZmDyUpbQ4x*=0Q2y8J<D<B*M z3d>D20ADxlAu9*a#IjO)H&jtdhPqr3)L1vBLW)(#f5?@E=(e4^v~!mbQNW2!w^*sb z2hYlR#5|8)f`L|IP*10-o%B`M)Lamlf#~Zq>aNyXSRt*^k6zJB;Ud|$Ke;hqWp=1K zbqZr;EV#gdXo0cio2Fpw%_6B~LMU2dZ#d8%N*gUqZz|@)SuLPCDQGoM*sngmxon_l z-avSz8T-oGl1GNrx5fDUGiT1^2qI_Tnk|;U4PYsUuNl;=f#)~__DcPu%sLy2?>*^) z#|zdxW|8duxug`q$eonrH=;rXzYvV&s<EzAjgj;mehZoyJPHGzb+G=>mY{ByC zO@`VZ-<<Hm-_hBP8w3_c4nPOKxyh8?WoJ@cn1PzbuWb}SelME{2go>~7bLvoE?A+{ z4M5pSQ68=0pA~W-8bBpbCy3i~%(&?XUq{o&SqG{O-m*B~7p^_#yH$-s5X@I(#%^-I z+Rv-&yO>`&Zq%B{jN{rk;g~%b&E!Y<)YV@<H0s|jTZ=R@n?>FN9bKR@P)VXOuT~Gi z5N8ZJjw=DAsQ~5#R90I??nvAH!4^6=?CFH-Sc<`ME&Z6uBLw0<_8c$gd_~VU60ulk zV5+eEvIJovgnQE<#uSxlZp+6DfwKvUMef8ST9xhT8K1*1XT)r9EseFYNAB!G$(Opd z*AO)%Ln)Rl8K`?yGWprZ^esigtKBt_MG+^wD`WSl01^H7D=G(|ZIIKLN1}e}&Plas z<KVoK=7~Y%PUhA@CAnZ0lp&hE|A^VSUHM5@W0mS%#)n*gPX2?8*xm&SNuQ*kGMyC- zTfEKf6qvykcVpNf*IKR$iXmFVp)tAGSI4Kq5Ba5j7yS7y@Ld1HM@)Uen_3B3>#W-1 z)Jm`BWpXDseV=;3u{Y5dKee^(*8?3g4YsYXVWa48ksKCi*~>uZ`LkcU-Mqb?J6=ip zdes=t_QRPa1|VDWK=LJ24=33~P8|Wlrb!2w@y19(m#3fEVP+A7n08#GbR?k^j9R7d zli!{*9xvbCXQ`h5cqu}{>ingaI-W2?CfQ>Wzy#U)cVl$J9n(Ry6$L!YO!a_+Y5>#G z?Yrtb(BMmP4T5)V$GdDAhu-|Gq4U2=7E{J;9BqP1F^{p{*x&l}FXf$Gvo{tCt2r0; z4%mKohg0_Tb?@igBA%J^ISu=TW%2p;@wBFJeOTSX5VVoS2lzfvuh*~!l5U;Cie5J* zQV8X>n1f{!!Ff$+d)S-Ahun?a3YD|mmRt$d$N+9!PaKceJ~BMbFbC}~d-&+DTd(x* zo>aPiq2~C)rJm&~A!{ED_mfBpsiXHx`hUO)XQh{=`TG8x)3}vJ*2TyEk*sXn?0Q^M z8Y_D_tJ1DyXrEN*<Gxp6^Cdz^SVj&udSut>>df+2wf|x1R;~S)hqgxt4E@<I02kGn zAh()Ph&zT9Y->&E`wgWQFwH;Tm?y=O_KUtP7_Gq0n*xwg6#Av3?ZaQT8eGVBt>8oQ 
z{$1i755Y;c9H3FN;E9O)rFHxki}w!2_~3(+N5%yiQ$DBD>##^)Fcn^&IwpSQnhL2y z{M;ZritipP%i`qa!qjY%BmIS&<tL-FAbt`>wh4m5ynAsA?UOf2x(%)~)@Dr4xl?i^ z$XOqZr*TpDNAUV4S@gUja0?%^n@>eLN0Hgy)9lG+J%?!Ra|2qXVlINXwgGcRtx#QI z{m{Bkw2aOLyut9I&hT&pRFP9!TcfZ_r|@iG3(3S)wW@-BrM1Q(10yACzyB`76fRm6 zFm00O2Hdl1L<cwi(`=x?_MQj^u48}x0x&c&g`3=tzP?}SmM}eRkCqr@kXgzw;;Ov+ za^b-{wxbXdqBl|RL}${#>f}g&xXI~F4>SdpF=}>txuOnRXWw)h?5err52KI#yB1WW zx#cTsqn;B2@)@<W{L>$N!IsJsAMtga3bsDAQyVJ_Xq0?%DD*61#BjpX=ll`>{Qi$- zeD5!<L_(!tX6W<UOuD7AC`f24z;Hbr%)<}rq1$0sSyJm!3Aks3C~^7=N(COC3}Crm z*?4%Yo5X!*0>5t0sL4q+ygV7Hj!*lp+Z>2ezJFUpga`!3$dzx}s<=#_Sv0KkcgZMk zk;g6kFZzL!JGWI2=gur^tugfV3paYj<FidAR7i-}to3N8ypE;Gi#eTuA<nm+`j`#y zlsm`1W@CX)F#cp=Q3O`rP2VBUd|kFE>EET2P1*vRL}cP_DDqWER<Pq++#91$x6+U7 zrlZZ@obyA6W<G;tR!j_a-S1`Fp7BfDA6ok^F4~_*l&dP7EGtDDqq49lSWpt&AZ-p+ zGV~#E_=lh6`LE0(Ct))HO-~o9Vqzefzf9Qjv$3}lgjEk`IS&8m{K_Gib$m>LpZeCl zgpLSfXE_Tvfn*pGyadxu1$*4&0;hFbg>j%Buc;A_iML)606}MI4*f;i>1CSP`hHeK zpHD4NUWj46Q5HJf;_-9+&Ii4N+H(wMflbY3_>l&k2)$GROOg}XbfJSe?!=R<*rMRe zw*BLtMA@w|8v$|ckzx~5sz@MgLeWqR(S5W)Ju^;SJbl95F2<`|vM~;l>4v*K?C>#n zqy+=d5ti|dH*0bM=MBC7yP!-z{3Og)pp7@Ll;%*EzJHj~ro9qrnz9jYAGxiE`Yr#1 z?tnV1gr5@+y1Svil!M13@9UdHPJ%&uk@9vwbM?8iN|eP__cC4dHVOoXn9Q#G%xG)1 z-FRe&bpRgj+4v>68e<^#CEUIs1H%A65^J)C|4}=}8@z<Y9>@n795?n7ObPYk{3=C3 zM3Z4d`r%rypEkI+3%QpJO#$OldZ8c1HwSP)K0_lyYhO2+w0tij3}g;80}%~JN|kO= zzsr){^!Lea=ZUxPeg6&}nQf@M|8mR*X$=-m-DpvgeW`-OxqjlpG?;WSeulWWTqt4m zHgyUYwm8xyJ~-i6Y76XOl*7tVb|FoI?f)(f2BDP<%clSGhwwX~vX%e<Mh%@_CdklA z0Pm}+5n_F+YcYTtseI5uHeok?ny}Taktd%#{mp^*IO3tEaq)KseC-lQd~>!MN50Y- z7h2e|`iowB0oMf>phE+_*#}Jh{fL2@;gxv(=tal|rXZ+SwWN$T9mCGz^=R=e9EWa~ zC+3}X0+dqg`)z;R?&HO`d(EPr9^-1KFz|v=UVXf4PI$_XUk%G7kIw*Af|~HqC6|OS zMNYjl#6@p2!B<68(oDkf0bJlHl>GNeho_%4qM>NTuajB}wj+O1;JRu)m9p61=OMd1 zRld!d!J9gy*U8V+=Y`+jKA{3vc5R47rd{|^^(Zgzo#;e=4}6|ea0~W)PB8xZDlNOP z=8sWOAz8{j{fUdekJEx#JhH)}(iQTp7@QLV+NhJQIFggp#v`**1vHg~k&op4Bx~qm zqCnzUrr^u*;3cn};~J2_{5r-C^7ntqQ7GgUqfa*{N#=LVS$xI@!xpIuFC?>xXdapW 
zNqqCKb+;3waiD-V1Zq55>evdA&cH_&V4bgZZhg6$97y{4glxTG0%qa>!gW!X)-oF` zA8@qXZYx!{W>Q4>J%nWjmfJuvsx6|8lsq&?Vok9QG(6M{1Nl)Qf#P3(gJKgO0Gv`2 z&AqY%j7Yi~HYwK>ae!U%131QuH{o3+H943+d*G;Hqx_db=hN755XodpJO*XiOF27% zhXn9DGMc=WO}xDK1k9P9FqX#o<1j_1seyu#Ep%G`bQL#(7#9^C9ETY-6&{yR(9bCg zWj5hoYm>x1_omG(Hxg`GK}%B7(V~=`qaN@3a{x~L!J-TbB7pf`H@awoir4462lnod zWX|PydNF(4#ZZ7ojl?;$p&{PD-U|_(KWDG=%|Y+?>@_j>$M@2tbO-UWfB+#sso>q> z@Stk5&vA>ZM=uJmMbL)DLlV^KC7HqJhL*}rFS_0U@0|OC>vIl>np3e0N%?2LE<eWt zN}<X1;+j>^nWiT*y3PfQiq?7&eIp{UnWQae!>r#qQ(3YBPF1Z;Uf@6p@ODI0gNoEz z(8gG>&;Sq3eGrs5pI|mBrrh*_+W43tT0|H9NDavz`sV9}b!;y+J}~RvSN1VO)MFRC zduymgR)n4lF;DY6QtVQ?^i!pEX|7co4N;ELJbfI-Z3B6gTsh7S6c!||FCx9(@mws$ z7z6p9)Su6~*j|f@UK9g!O`$UFCr#I_QN<*QX@hW=x2<kejRtH8rW~W$+#X!D<Z~RG z)T*4?DLKHdSiaBlm!?%sXk!S0x1~DR{8KjOFIiZB>uu9m-Q&I)aa1ax&_<b<Xx>Ug zNv5sfg_oFyIn0JG`^_}($_#LW9cr_|%MpIeT%VpB2^5H4%JelxUyB9DD5T<&^5d=v z%_%uqF@0mU%K8JaaVs?}epBzRQQ&-M3NrS5gewU<2gg+=LyF&@d;tQt3t$~dY4fo! zdmATkgf^?GVT;U)++pT#=rf$;i>uEI$SRA3^K1o{O{@+1?hP$K!4YmCW&mw1aZl|} zo#vUZm}(iWI9seXuC6eEk9J0{w5h4>^ntjZ76-EwREfE3@hdlsHC&wB0i_%dOqTU; z0!|FDNkVZBmD<dE=vJ%&-d&l0%o^}9BFNccs7xg@)orAs2vdLn5|(KNW}11;Fx(nn zS9F+Zh1*(%hiQC_F>SU=zGwcN8`4tA_52ca<B!f1+WC<uDt;+eX&D~N$n#luS`+a= zXw5layHRCrS120{PnQ(Kh1%~Jo-8CO{aR^r-OP)wxl4S|t(A)MN}25F7T8emSn2@> zql8TY`U>#uuVH_WzXTiQBfNY(VBKHeLPcOnHAIz9>i80H+5E0|2dcu~XYP1+K)XAZ zY%93rAs!{HP#r|hz$bY0p$!hbwmc&M;d#CKl^ZBmw@K&y;7r_szPEWZFdMn6t%^=3 zFBJNA6c)AVma5c{)xLJe8b?<x#8boUci$h}ZZ;%?>SVa65L+u-{=XV&vEy<)q}9O- z&f@6uAT|Ida*>J>H$Hbv-s+hlwUzKvOq+dXdJ-BHVxEt-Mc7Mphk0&af>@Z%;Ld2P z2C<KRP?+(gXEHLYIu-DyU}mYCxh%Tdzgc+OUT8nJ;k%K1OpMuEr`oW!-R@<H`4=%@ zB!8J$v_z-D9wq1cx8O41xu`@l?pnKXToqb004U#Yt3Q6j!PI_Gy#eT1nA(m6ZtKvs z3Hzevi-{b@XA2i4>MS5j`jVT31jRuE2JN$QD5Yilto02#NP>Rg-ErG(_@kqB#1)X5 zd<dx*A@$r9&`B>lFp&3a5)W?3EDQPE4lJrv_Hzl-T+NjVE{9FUpQT{tZ%t)#sQr|Q z{l|NUU?xW^-EvZe?ipgrhHA5A$WGeL1dvbi0>v3zpxD8P=i6_PUTfW=$S>dXw#yQZ zb7L%>Na!HvV*~Noau9=f&^06+dM_T&n;thgTfyb6=mWVl#)gD#<RtqkTh6w^cW<0{ 
zrWrj_v)>6pX?Jd$u#-M7^PcPyly5zwu`<Jp!&1Y-*B=-a>C1q`37bH*(y0C>;CAv? z?;D|BXDiPwl0lST-N4B~ePIW&v|KU*3i|9SK<iSoEi;ZcYZA*poQNd<2r=axh(oW& z9TVAb1|@(9IeGP|%ot#YyTgC{0xa!KF1n+>hZXA<oAAeN@KLnKB(6O!T;`c|Emq{J z0jH@&!_yxbVkG@!w)~Dbg=Eg9gaM=R{aCV=|2`1>cw2|gzheoAp5i|73B35+yW`~_ zOOl*kUU=QzD39D_P;<Ej!asL~Cjo3Mcb||m?Zav|gMb2Jw;(-6`t|+X*8+=)=+#l_ zTQm{UIBe4VRtKLZiRYc%TUK!E$R8QX0MX8GWoDZ+oSnLY1xCb+<6Q5QO3eB4oK-Tf zKv`?k6GPOsa4=S&B6?{VHAKa@GDySe%MlNn9%i)@4!pEaMol$z;9MEJ$%O4fLVAC* z(R1#GD9)ua^+e9gs2uD|&gIOib4-G^A&Wn=IF~j;7TWq)GPX^%W(pd5$U~&5pya!k za@6lq6+ZeGvn>XDD6K+on%aA>{VPW?^@d<1;uh|D`h|hoefey$(>><IPl5RG(M1u$ z?)+IUe~Ttcet2MuI{(|le?V%w5~z@>x-4=b%ivMm*CtzdnYiqdEDgL&FKwF%=U<s} z5Pw$mJHJr;*yQ{sO51yVL{7gjS|0O(9p3@lH#`}dBkO9kfr1Nj0Zr1kJ5KbyLIB0z zUixi)u=qM!U|L@<{N(uXMd;up0Ay#iv>Rai^2(xs$Y73A$h$j{#|6?JfkMj6kQ_2w z!L)YZ{_RJ?NSJSnaIPx83lAB`$$7D|B;8zOjCMD@akA-xtXu~KA9u>r)4IW-3Vrf1 zmKLXJ;24u;<YBk&FuKOGIsUMSQ~y8C%5&Bn^(Sb{4!OvZ3f9%-;pqOpW62pH5!K-w zU8r)e)<M}2Ma(=mL*=B7`(4p7wA=15)tc)RB-40CUyzuyF?eZmtSZQvNU1hHTTLVF z>e?5qw<0~VQbvb)bF}FMJXZ*?8gBDdH!~~hw8_BS3E)3$=0-nbv6mnp!};t|hcddk z8HZrNipaD0)PcGry?{#;aGn$OZ8>$pwEoGugdc<DDT|0Xkq^q>n+N4gi$5abt7#$D zMP=cDsdF)$2u=9;V_CY{aw-uMaLFN*?Ym&=LR%$wg);?(_nh_J<}Ix$ZUUx6504EQ zF0XOpmvt5gtNGYrv=_=!d}1a@N~*vq3OC~7K?}l=H|d_iVdBY%l>f(dM9*H`?5c<C zf%c<E6;K@C^He6isgR-0DYc7`AN^lHffiq3(qC~mroU$HHi^en3o-ees#}rtoiSWa zMfF)LZ}XO0oF5X(US-@uM*5B*BAAXGOIFERM^_ib-|uV#ddD^rlVqd~lAK9H?Cs@= zu>hvNtR@7eisO)z!pbY*xVKaNGyxS@U0R5n`ON+2d~?M+i3*p+%0AfAGxmXw!$3%t zPaHzbr!&=|&zd<`j?;K=n*Do9vj0izq!E3IaVq8in5@{}wPiccKi(%(MGrv15Km5X zV40LZTh+9$fI#A5#|y}}Jg%B4P|EpWR9eTS=CT+UG?45NH}J#AMK7u5Zq=j|>Z_9M zPv(WTD^0@e9MKE;Mhg>NtodHPlX?3@gQ9kKtA7v1hI^LSes|PpLveH>D^k>ab7@qF zUjmuY=BKuO#kY$;8cpVPiOWHKjZx{@<mWZuoE6ZE=}?-S9u((qI>{zsg9G=x!XL3N z>N(FA-r(71Z9%t?%Jlc(X>O+B$`q?}^k2uJ|1`f=eNTD+XnLzil=B7PvusaAxj&SH zAp0h54~IY2445O2Tm8E-Dn!K_m{OVW-@Zvw7LvNeConI-*_YH?_856BPREW}9xj$~ zCStW!-^F}Xo{XjgBWH>`jE}{Rn;&;MnX_#A1TtY-Gd#^vMX-||qrf@;1BbYJk@0`p 
zn;|^K(H9I)J%(~75t760M$Zs+>CfeP$R2NYqHwfT62u29TiG9JYF(~7by^)Rezd{1 zpZjv!Yr|WcK>XO^SN+^}cR!%|0i_C4y#0Dt<kr6}%VlFjcZkLzm?)-*SIO3ugb%~? zA04BzTKum(A2$7GyQuA}V0!s5I9TJ~uZD;JJQMYY|1|PKy!HQ?N<jVWzyGP}pGh~= zkHCztf4)H=Q~xxAQ-Aosj^F<G@xW*Q=kW6V|6ch&4YrVo;QwCvpZ;*_Mcyj@_sakF z1%uCiWc|<b{%808pR)IODd+_dOc^Z=RTHWxfy2=LM-KmgwU~Uez_Q$ivbk7F-8ik= LdKwjK4w3&0wn$FJ literal 0 HcmV?d00001 diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 29fdeb70bf..f5d1af4477 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -101,6 +101,8 @@ def spm(path): 'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'), 'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'), 'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'), + 'transformer.flores101.mm100.615M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz'), + 'transformer.flores101.mm100.175M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz'), } # fmt: on From 8b861beae282ec5dd5051686440948a3f893c3ec Mon Sep 17 00:00:00 2001 From: lematt1991 <lematt1991@gmail.com> Date: Wed, 28 Apr 2021 05:56:31 -0700 Subject: [PATCH 564/707] Escape % in `keep_interval_updates_pattern` help description (#3514) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This leads to the following error: ``` ValueError: unsupported format character 'k' (0x6b) at index 95 ``` Resolves https://github.com/pytorch/fairseq/issues/3491 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3514 Test Plan: `fairseq-train --help` Produces: ``` ... 
--keep-interval-updates-pattern KEEP_INTERVAL_UPDATES_PATTERN when used with --keep-interval-updates, skips deleting any checkpoints with update X where X % keep_interval_updates_pattern == 0 ... ``` # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). https://github.com/pytorch/fairseq/issues/3491 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Reviewed By: myleott Differential Revision: D28037028 Pulled By: lematt1991 fbshipit-source-id: b237a151b82e851954ad3ea51a0c4a14c572ffab --- fairseq/dataclass/configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 89d83b5b6b..902756bfff 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -576,7 +576,7 @@ class CheckpointConfig(FairseqDataclass): metadata={ "help": "when used with --keep-interval-updates, skips deleting " "any checkpoints with update X where " - "X % keep_interval_updates_pattern == 0" + "X %% keep_interval_updates_pattern == 0" }, ) keep_last_epochs: int = field( From 1305008e97872335d4ae8de4d015ddf8c43e87df Mon Sep 17 00:00:00 2001 From: Naman Goyal <namangoyal@learnfair0732.h2.fair> Date: Wed, 28 Apr 2021 10:44:50 -0700 Subject: [PATCH 565/707] fixed typo in flores model download link (#1827) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1827 Reviewed By: ngoyal2707 Differential Revision: D28060291 Pulled By: lematt1991 fbshipit-source-id: 2540eb2a7d6a1fe37af9a3e9b4ed3df9e05a0823 --- examples/flores101/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flores101/README.md b/examples/flores101/README.md index 58d9c05aff..635c13f40b 100644 --- a/examples/flores101/README.md +++ b/examples/flores101/README.md @@ -19,7 +19,7 @@ Flores announement blog post: https://ai.facebook.com/blog/flores-researchers-ki Model | Num layers | Embed dimension | FFN dimension| Vocab Size | #params | Download ---|---|---|---|---|---|--- `flores101_mm100_615M` | 12 | 1024 | 4096 | 256,000 | 615M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz -`flores101_mm100_175M` | 6 | 512 | 2048 | 256,000 | 175M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.tgz +`flores101_mm100_175M` | 6 | 512 | 2048 | 256,000 | 175M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz These models are trained similar to [M2M-100](https://arxiv.org/abs/2010.11125) with additional support for the languages that are part of the WMT Large-Scale Multilingual Machine Translation track. Full list of languages can be found at the bottom. 
@@ -220,4 +220,4 @@ Xhosa | xh Yiddish | yi Yoruba | yo Chinese| zh -Zulu | zu \ No newline at end of file +Zulu | zu From c6409c029dc9b3af5308a269ec5c68dbdbafdc78 Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Wed, 28 Apr 2021 23:22:25 -0700 Subject: [PATCH 566/707] Fix issue with encoder padding mask Summary: - Fix issue with encoder padding mask - Also add lengths as a field in encoder_out of encode_src method - Add a conditional clause in transformer_monotonic_attention.py to handle the case where encoder_padding_mask is None Reviewed By: jmp84 Differential Revision: D28080936 fbshipit-source-id: 99f78c5e3fe5644960ade44210ea78280ef53b8c --- .../models/transformer_monotonic_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index d7aeca5ea5..1062e9b955 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -152,7 +152,8 @@ def pre_attention( encoder_out = encoder_out_dict["encoder_out"][0] encoder_padding_mask = ( encoder_out_dict["encoder_padding_mask"][0] - if len(encoder_out_dict["encoder_padding_mask"]) > 0 + if encoder_out_dict["encoder_padding_mask"] + and len(encoder_out_dict["encoder_padding_mask"]) > 0 else None ) From 9cb6fe4c93c17f1d4e327ea345cd1c653432c76c Mon Sep 17 00:00:00 2001 From: Ruslan Mavlyutov <mavlyutov@fb.com> Date: Thu, 29 Apr 2021 10:24:58 -0700 Subject: [PATCH 567/707] Copy to local before loading checkpoint Summary: Follow-up for "Fix FSDP optim state loading (#1819)". Update for remote file systems. 
Reviewed By: sshleifer Differential Revision: D28088088 fbshipit-source-id: 5d2f3ea5084fbbb21564d053317d2c07565cf2bc --- fairseq/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5e87b573f1..0b70a97356 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -428,7 +428,8 @@ def load_checkpoint( last_optim_state = state.get("last_optimizer_state", None) if last_optim_state == -1: master_path = re.sub("shard[0-9]+", "shard0", filename) - last_optim_state = torch.load(master_path, map_location='cpu')['last_optimizer_state'] + local_master_path = PathManager.get_local_path(master_path) + last_optim_state = torch.load(local_master_path, map_location='cpu')['last_optimizer_state'] # If doing zero_sharding, do not broadcast global optimizer # state. Later we will broadcast sharded states to each rank From d6855baec88f99ac776962027b91d404fe917eea Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Thu, 29 Apr 2021 16:16:09 -0700 Subject: [PATCH 568/707] Simplify CountingIterator Summary: Simplify the implementation of `CountingIterator` and add test cases. The old implementation could fail on a test case such as: ``` ref = list(range(10)) itr = CountingIterator(ref) first_item = next(itr) # consume one item remaining_items = list(itr) # raises exception because of "length mismatch" ``` This happens because `list(itr)` invokes `itr.__iter__` and re-iterates the underlying list from the start, but `itr.n` has already been incremented by `next(itr)`. The new implementation is simpler and avoids such an error. 
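For illustration only, the behaviour described above can be reproduced with a standalone sketch. The class below mirrors the simplified logic of this change (a single underlying iterator, with `n` advanced only in `__next__`). It is a reduced copy written for this example, not the fairseq class itself, and it leaves out `skip` and `take`:

```python
class CountingIterator:
    """Reduced, standalone sketch of the simplified iterator (assumption:
    written for this example only; not imported from fairseq)."""

    def __init__(self, iterable, start=None, total=None):
        # Keep exactly one iterator over the wrapped iterable.
        self._itr = iter(iterable)
        self.n = start or getattr(iterable, "n", 0)
        self.total = total or self.n + len(iterable)

    def __len__(self):
        return self.total

    def __iter__(self):
        # Return self, so iteration never restarts from the beginning.
        return self

    def __next__(self):
        if not self.has_next():
            raise StopIteration
        try:
            x = next(self._itr)
        except StopIteration:
            # The wrapped iterable ran out before `total` items were seen.
            raise IndexError(
                f"Iterator expected to have length {self.total}, "
                f"but exhausted at position {self.n}."
            )
        self.n += 1
        return x

    def has_next(self):
        # True while fewer than `total` items have been consumed.
        return self.n < self.total


ref = list(range(10))
itr = CountingIterator(ref)
first_item = next(itr)       # consume one item
remaining_items = list(itr)  # continues from the second item
```

Because `__iter__` returns the object itself, `list(itr)` picks up where `next(itr)` stopped, so the count in `itr.n` stays consistent with the number of items actually consumed.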
Reviewed By: myleott Differential Revision: D27802505 fbshipit-source-id: c97fd0a27d865c0ff3b24016fa6aa0afabbf0a73 --- fairseq/data/iterators.py | 72 ++++++++++++++------------------------- tests/test_iterators.py | 58 +++++++++++++++++++------------ 2 files changed, 61 insertions(+), 69 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 66eaf875cb..293d853822 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -31,74 +31,52 @@ class CountingIterator(object): iterable (iterable): iterable to wrap start (int): starting iteration count. Note that this doesn't actually advance the iterator. - total (int): override the iterator length returned by - ``__len__``. This can be used to truncate *iterator*. + total (int): override the iterator length returned by ``__len``. + This can be used to truncate *iterator*. Attributes: n (int): number of elements consumed from this iterator """ def __init__(self, iterable, start=None, total=None): - self.iterable = iterable - self.itr = iter(self) - - if start is None: - self.n = getattr(iterable, "n", 0) - else: - self.n = start - - if total is None: - self.total = self.n + len(iterable) - else: - self.total = total + self._itr = iter(iterable) + self.n = start or getattr(iterable, "n", 0) + self.total = total or self.n + len(iterable) def __len__(self): return self.total def __iter__(self): - for x in self.iterable: - if self.n >= self.total: - raise RuntimeError( - "Mismatch between actual and expected iterable length. " - "This may be caused by resuming training from a checkpoint using " - "a different number of GPUs, in which case you can try the " - "--reset-dataloader option. Alternatively you may have a train or " - "validation set that is smaller than the number of GPUs. If none " - "of these apply, please report this to the fairseq developers." 
- ) - self.n += 1 - yield x + return self def __next__(self): - return next(self.itr) + if not self.has_next(): + raise StopIteration + try: + x = next(self._itr) + except StopIteration: + raise IndexError(f"Iterator expected to have length {self.total}, " + "but exhausted at position {self.n}.") + self.n += 1 + return x def has_next(self): """Whether the iterator has been exhausted.""" - return self.n < len(self) + return self.n < self.total - def skip(self, num_to_skip): - """Fast-forward the iterator by skipping *num_to_skip* elements.""" - next(itertools.islice(self.itr, num_to_skip, num_to_skip), None) + def skip(self, n): + """Fast-forward the iterator by skipping n elements.""" + for _ in range(n): + next(self) return self def take(self, n): - """ - Truncates the iterator to n elements at most. - """ + """Truncate the iterator to n elements at most.""" self.total = min(self.total, n) - # Propagate this change to the underlying iterator - # Only take after what we have already consumed (i.e. 
after restarting - # from checkpoint mid epoch, we have to subtract self.n which is the - # starting point) - # - # This to maintain the invariant self.total = self.n + len(iterable), - # before calling __next__ or __iter__ - propagated_take = max(n - self.n, 0) - if hasattr(self.iterable, "take"): - self.iterable.take(propagated_take) - else: - self.iterable = itertools.islice(self.iterable, propagated_take) + if hasattr(self._itr, "take"): + self._itr.take(max(n - self.n, 0)) + return self class EpochBatchIterating(object): @@ -620,10 +598,10 @@ def __len__(self): def take(self, n): self.total = min(self.total, n) - # Propagate this change to the underlying iterator if hasattr(self._iterable, "take"): self._iterable.take(n) + return self def __next__(self): # Create consumer if not created yet diff --git a/tests/test_iterators.py b/tests/test_iterators.py index 3d2c4d6251..7b3dd48485 100644 --- a/tests/test_iterators.py +++ b/tests/test_iterators.py @@ -9,7 +9,8 @@ class TestIterators(unittest.TestCase): - def test_counting_iterator(self, ref=None, itr=None): + def test_counting_iterator_index(self, ref=None, itr=None): + # Test the indexing functionality of CountingIterator if ref is None: assert itr is None ref = list(range(10)) @@ -17,6 +18,7 @@ def test_counting_iterator(self, ref=None, itr=None): else: assert len(ref) == 10 assert itr is not None + self.assertTrue(itr.has_next()) self.assertEqual(itr.n, 0) self.assertEqual(next(itr), ref[0]) @@ -26,9 +28,36 @@ def test_counting_iterator(self, ref=None, itr=None): itr.skip(3) self.assertEqual(itr.n, 5) self.assertEqual(next(itr), ref[5]) - itr.skip(3) - self.assertEqual(itr.n, 9) - self.assertEqual(next(itr), ref[9]) + itr.skip(2) + self.assertEqual(itr.n, 8) + self.assertEqual(list(itr), [ref[8], ref[9]]) + self.assertFalse(itr.has_next()) + + def test_counting_iterator_length_mismatch(self): + ref = list(range(10)) + # When the underlying iterable is longer than the CountingIterator, + # the remaining items 
in the iterable should be ignored + itr = iterators.CountingIterator(ref, total=8) + self.assertEqual(list(itr), ref[:8]) + # When the underlying iterable is shorter than the CountingIterator, + # raise an IndexError when the underlying iterable is exhausted + itr = iterators.CountingIterator(ref, total=12) + self.assertRaises(IndexError, list, itr) + + def test_counting_iterator_take(self): + # Test the "take" method of CountingIterator + ref = list(range(10)) + itr = iterators.CountingIterator(ref) + itr.take(5) + self.assertEqual(len(itr), len(list(iter(itr)))) + self.assertEqual(len(itr), 5) + + itr = iterators.CountingIterator(ref) + itr.take(5) + self.assertEqual(next(itr), ref[0]) + self.assertEqual(next(itr), ref[1]) + itr.skip(2) + self.assertEqual(next(itr), ref[4]) self.assertFalse(itr.has_next()) def test_grouped_iterator(self): @@ -41,11 +70,11 @@ def test_grouped_iterator(self): itr = iterators.GroupedIterator(x, 5) self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - # test CountingIterator functionality + # test the GroupIterator also works correctly as a CountingIterator x = list(range(30)) ref = list(iterators.GroupedIterator(x, 3)) itr = iterators.GroupedIterator(x, 3) - self.test_counting_iterator(ref, itr) + self.test_counting_iterator_index(ref, itr) def test_sharded_iterator(self): # test correctness @@ -67,22 +96,7 @@ def test_sharded_iterator(self): x = list(range(30)) ref = list(iterators.ShardedIterator(x, num_shards=3, shard_id=0)) itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0) - self.test_counting_iterator(ref, itr) - - def test_counting_iterator_take(self): - ref = list(range(10)) - itr = iterators.CountingIterator(ref) - itr.take(5) - self.assertEqual(len(itr), len(list(iter(itr)))) - self.assertEqual(len(itr), 5) - - itr = iterators.CountingIterator(ref) - itr.take(5) - self.assertEqual(next(itr), ref[0]) - self.assertEqual(next(itr), ref[1]) - itr.skip(2) - self.assertEqual(next(itr), ref[4]) - 
self.assertFalse(itr.has_next()) + self.test_counting_iterator_index(ref, itr) def test_counting_iterator_buffered_iterator_take(self): ref = list(range(10)) From a4e1d4a3daf4f6f5557505026fd94b8716fba7b3 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Sun, 2 May 2021 19:24:59 -0700 Subject: [PATCH 569/707] add binarized audio dataset for large datasets (#1840) Summary: adds a binarized audio dataset to prevent oom when training with very large datasets Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1840 Reviewed By: arbabu123 Differential Revision: D28137756 Pulled By: alexeib fbshipit-source-id: bfe58e9d2bc9909b38876d78cc2e8aea783b3fed --- .../wav2vec2_large_librivox_tpu-pod.yaml | 9 +- .../wav2vec2_large_librivox_tpu.yaml | 9 +- examples/wav2vec/scripts/binarize_manifest.sh | 33 +++ fairseq/data/__init__.py | 3 +- fairseq/data/audio/raw_audio_dataset.py | 191 ++++++++++++------ fairseq/data/dictionary.py | 3 +- fairseq/tasks/audio_pretraining.py | 113 ++++++----- 7 files changed, 246 insertions(+), 115 deletions(-) create mode 100644 examples/wav2vec/scripts/binarize_manifest.sh diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml index 676c9fe339..ff35a95b65 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml @@ -55,7 +55,7 @@ model: quantize_targets: true extractor_mode: layer_norm layer_norm_first: true - final_dim: 256 + final_dim: 768 latent_temp: [2.0,0.1,0.999995] encoder_layerdrop: 0.00 dropout_input: 0.0 @@ -64,8 +64,9 @@ model: attention_dropout: 0.0 conv_bias: true - mask_channel_prob: 0.1 - mask_prob: 0.1 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 feature_grad_mult: 1.0 - diff --git 
a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml index c45c4d9117..2036e23c6b 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml @@ -55,7 +55,7 @@ model: quantize_targets: true extractor_mode: layer_norm layer_norm_first: true - final_dim: 256 + final_dim: 768 latent_temp: [2.0,0.1,0.999995] encoder_layerdrop: 0.00 dropout_input: 0.0 @@ -64,8 +64,9 @@ model: attention_dropout: 0.0 conv_bias: true - mask_channel_prob: 0.1 - mask_prob: 0.1 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 feature_grad_mult: 1.0 - diff --git a/examples/wav2vec/scripts/binarize_manifest.sh b/examples/wav2vec/scripts/binarize_manifest.sh new file mode 100644 index 0000000000..6f201bdb52 --- /dev/null +++ b/examples/wav2vec/scripts/binarize_manifest.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# usage: bash binarize_manifest <dest_dir> <train_split> <valid_split> + +DEST_DIR=$1 +TRAIN_SPLIT=$2 +VALID_SPLIT=$3 +FAIRSEQ_ROOT=$4 + +mkdir -p $DEST_DIR + +# split file path and lengths into separate files +cut -f1 $TRAIN_SPLIT.tsv > $DEST_DIR/train_fnames.txt +cut -f1 $VALID_SPLIT.tsv > $DEST_DIR/valid_fnames.txt +cut -f2 $TRAIN_SPLIT.tsv > $DEST_DIR/train.lengths +cut -f2 $VALID_SPLIT.tsv > $DEST_DIR/valid.lengths + +# copy root directory +head -1 $TRAIN_SPLIT.tsv > $DEST_DIR/train.root +head -1 $VALID_SPLIT.tsv > $DEST_DIR/valid.root + +# remove root directory +sed -i '1d' $DEST_DIR/train_fnames.txt +sed -i '1d' $DEST_DIR/valid_fnames.txt +sed -i '1d' $DEST_DIR/train.lengths +sed -i '1d' $DEST_DIR/valid.lengths + +# insert spaces between characters +sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/train_fnames.txt +sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/valid_fnames.txt + +# run preprocessor +PYTHONPATH=$FAIRSEQ_ROOT python 
$FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $DEST_DIR/train_fnames.txt --validpref $DEST_DIR/valid_fnames.txt --workers 60 --only-source --destdir $DEST_DIR diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 9b30813955..30af792185 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -12,7 +12,7 @@ from .add_target_dataset import AddTargetDataset from .append_token_dataset import AppendTokenDataset -from .audio.raw_audio_dataset import FileAudioDataset +from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset from .backtranslation_dataset import BacktranslationDataset from .bucket_pad_length_dataset import BucketPadLengthDataset from .colorize_dataset import ColorizeDataset @@ -69,6 +69,7 @@ "AppendTokenDataset", "BacktranslationDataset", "BaseWrapperDataset", + "BinarizedAudioDataset", "BucketPadLengthDataset", "ColorizeDataset", "ConcatDataset", diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index d0ff604e2b..2d3dd238a6 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -46,10 +46,8 @@ def __init__( if self.compute_mask_indices: self.mask_compute_kwargs = mask_compute_kwargs self._features_size_map = {} - self._C = mask_compute_kwargs['encoder_embed_dim'] - self._conv_feature_layers = eval( - mask_compute_kwargs['conv_feature_layers'] - ) + self._C = mask_compute_kwargs["encoder_embed_dim"] + self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"]) def __getitem__(self, index): raise NotImplementedError() @@ -84,34 +82,32 @@ def crop_to_max_size(self, wav, target_size): def _compute_mask_indices(self, dims, padding_mask): B, T, C = dims mask_indices, mask_channel_indices = None, None - if self.mask_compute_kwargs['mask_prob'] > 0: + if self.mask_compute_kwargs["mask_prob"] > 0: mask_indices = compute_mask_indices( (B, T), padding_mask, - 
self.mask_compute_kwargs['mask_prob'], - self.mask_compute_kwargs['mask_length'], - self.mask_compute_kwargs['mask_selection'], - self.mask_compute_kwargs['mask_other'], + self.mask_compute_kwargs["mask_prob"], + self.mask_compute_kwargs["mask_length"], + self.mask_compute_kwargs["mask_selection"], + self.mask_compute_kwargs["mask_other"], min_masks=2, - no_overlap=self.mask_compute_kwargs['no_mask_overlap'], - min_space=self.mask_compute_kwargs['mask_min_space'], + no_overlap=self.mask_compute_kwargs["no_mask_overlap"], + min_space=self.mask_compute_kwargs["mask_min_space"], ) mask_indices = torch.from_numpy(mask_indices) - if self.mask_compute_kwargs['mask_channel_prob'] > 0: + if self.mask_compute_kwargs["mask_channel_prob"] > 0: mask_channel_indices = compute_mask_indices( (B, C), None, - self.mask_compute_kwargs['mask_channel_prob'], - self.mask_compute_kwargs['mask_channel_length'], - self.mask_compute_kwargs['mask_channel_selection'], - self.mask_compute_kwargs['mask_channel_other'], - no_overlap=self.mask_compute_kwargs['no_mask_channel_overlap'], - min_space=self.mask_compute_kwargs['mask_channel_min_space'], + self.mask_compute_kwargs["mask_channel_prob"], + self.mask_compute_kwargs["mask_channel_length"], + self.mask_compute_kwargs["mask_channel_selection"], + self.mask_compute_kwargs["mask_channel_other"], + no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"], + min_space=self.mask_compute_kwargs["mask_channel_min_space"], ) mask_channel_indices = ( - torch.from_numpy(mask_channel_indices) - .unsqueeze(1) - .expand(-1, T, -1) + torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1) ) return mask_indices, mask_channel_indices @@ -155,22 +151,18 @@ def collater(self, samples): if self.pad: input["padding_mask"] = padding_mask - if hasattr(self, 'num_buckets') and self.num_buckets > 0: + if hasattr(self, "num_buckets") and self.num_buckets > 0: assert self.pad, "Cannot bucket without padding first." 
- bucket = max(self._bucketed_sizes[s['id']] for s in samples) + bucket = max(self._bucketed_sizes[s["id"]] for s in samples) num_pad = bucket - collated_sources.size(-1) if num_pad: - input['source'] = self._bucket_tensor( - collated_sources, num_pad, 0 - ) - input['padding_mask'] = self._bucket_tensor( - padding_mask, num_pad, True - ) + input["source"] = self._bucket_tensor(collated_sources, num_pad, 0) + input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True) if self.compute_mask_indices: - B = input['source'].size(0) - T = self._get_mask_indices_dims(input['source'].size(-1)) - padding_mask_reshaped = input['padding_mask'].clone() + B = input["source"].size(0) + T = self._get_mask_indices_dims(input["source"].size(-1)) + padding_mask_reshaped = input["padding_mask"].clone() extra = padding_mask_reshaped.size(1) % T if extra > 0: padding_mask_reshaped = padding_mask_reshaped[:, :-extra] @@ -178,15 +170,14 @@ def collater(self, samples): padding_mask_reshaped.size(0), T, -1 ) padding_mask_reshaped = padding_mask_reshaped.all(-1) - input['padding_count'] = ( - padding_mask_reshaped.sum(-1).max().item() - ) + input["padding_count"] = padding_mask_reshaped.sum(-1).max().item() mask_indices, mask_channel_indices = self._compute_mask_indices( - (B, T, self._C), padding_mask_reshaped, + (B, T, self._C), + padding_mask_reshaped, ) input["mask_indices"] = mask_indices input["mask_channel_indices"] = mask_channel_indices - out['sample_size'] = mask_indices.sum().item() + out["sample_size"] = mask_indices.sum().item() out["net_input"] = input return out @@ -195,7 +186,7 @@ def _get_mask_indices_dims(self, size, padding=0, dilation=1): if size not in self._features_size_map: L_in = size for (_, kernel_size, stride) in self._conv_feature_layers: - L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1 + L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1 L_out = 1 + L_out // stride L_in = L_out self._features_size_map[size] = L_out @@ -223,6 
+214,25 @@ def ordered_indices(self): order.append(self.sizes) return np.lexsort(order)[::-1] + def set_bucket_info(self, num_buckets): + self.num_buckets = num_buckets + if self.num_buckets > 0: + self._collated_sizes = np.minimum( + np.array(self.sizes), + self.max_sample_size, + ) + self.buckets = get_buckets( + self._collated_sizes, + self.num_buckets, + ) + self._bucketed_sizes = get_bucketed_sizes( + self._collated_sizes, self.buckets + ) + logger.info( + f"{len(self.buckets)} bucket(s) for the audio dataset: " + f"{self.buckets}" + ) + class FileAudioDataset(RawAudioDataset): def __init__( @@ -249,10 +259,9 @@ def __init__( **mask_compute_kwargs, ) - self.fnames = [] - self.line_inds = set() - skipped = 0 + self.fnames = [] + sizes = [] with open(manifest_path, "r") as f: self.root_dir = f.readline().strip() for i, line in enumerate(f): @@ -263,32 +272,94 @@ def __init__( skipped += 1 continue self.fnames.append(items[0]) - self.line_inds.add(i) - self.sizes.append(sz) - self.set_bucket_info(num_buckets) + sizes.append(sz) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") - def set_bucket_info(self, num_buckets): - self.num_buckets = num_buckets - if self.num_buckets > 0: - self._collated_sizes = np.minimum( - np.array(self.sizes), self.max_sample_size, - ) - self.buckets = get_buckets( - self._collated_sizes, self.num_buckets, - ) - self._bucketed_sizes = get_bucketed_sizes( - self._collated_sizes, self.buckets - ) - logger.info( - f"{len(self.buckets)} bucket(s) for the audio dataset: " - f"{self.buckets}" + self.sizes = np.array(sizes, dtype=np.int64) + + try: + import pyarrow + + self.fnames = pyarrow.array(self.fnames) + except: + logger.debug( + "Could not create a pyarrow array. 
Please install pyarrow for better performance" ) + pass + + self.set_bucket_info(num_buckets) def __getitem__(self, index): import soundfile as sf - fname = os.path.join(self.root_dir, self.fnames[index]) + fname = os.path.join(self.root_dir, str(self.fnames[index])) + wav, curr_sample_rate = sf.read(fname) + feats = torch.from_numpy(wav).float() + feats = self.postprocess(feats, curr_sample_rate) + return {"id": index, "source": feats} + + +class BinarizedAudioDataset(RawAudioDataset): + def __init__( + self, + data_dir, + split, + sample_rate, + max_sample_size=None, + min_sample_size=0, + shuffle=True, + pad=False, + normalize=False, + num_buckets=0, + compute_mask_indices=False, + **mask_compute_kwargs, + ): + super().__init__( + sample_rate=sample_rate, + max_sample_size=max_sample_size, + min_sample_size=min_sample_size, + shuffle=shuffle, + pad=pad, + normalize=normalize, + compute_mask_indices=compute_mask_indices, + **mask_compute_kwargs, + ) + + from fairseq.data import data_utils, Dictionary + + self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt")) + + root_path = os.path.join(data_dir, f"{split}.root") + if os.path.exists(root_path): + with open(root_path, "r") as f: + self.root_dir = next(f).strip() + else: + self.root_dir = None + + fnames_path = os.path.join(data_dir, split) + self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict) + lengths_path = os.path.join(data_dir, f"{split}.lengths") + + with open(lengths_path, "r") as f: + for line in f: + sz = int(line.rstrip()) + assert ( + sz >= min_sample_size + ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}" + self.sizes.append(sz) + + self.sizes = np.array(self.sizes, dtype=np.int64) + + self.set_bucket_info(num_buckets) + logger.info(f"loaded {len(self.fnames)} samples") + + def __getitem__(self, index): + import soundfile as sf + + fname = self.fnames_dict.string(self.fnames[index], separator="") + if self.root_dir: 
+ fname = os.path.join(self.root_dir, fname) + wav, curr_sample_rate = sf.read(fname) feats = torch.from_numpy(wav).float() feats = self.postprocess(feats, curr_sample_rate) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 8d219e20ef..0d8308a811 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -70,6 +70,7 @@ def string( extra_symbols_to_ignore=None, unk_string=None, include_eos=False, + separator=" ", ): """Helper for converting a tensor of token indices to a string. @@ -96,7 +97,7 @@ def token_string(i): if hasattr(self, "bos_index"): extra_symbols_to_ignore.add(self.bos()) - sent = " ".join( + sent = separator.join( token_string(i) for i in tensor if utils.item(i) not in extra_symbols_to_ignore diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index df073a1814..a84798d076 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -12,10 +12,17 @@ from argparse import Namespace from dataclasses import dataclass, field +import numpy as np from typing import Optional, Any from omegaconf import MISSING, II -from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders +from fairseq.data import ( + AddTargetDataset, + BinarizedAudioDataset, + Dictionary, + FileAudioDataset, + encoders, +) from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import GenerationConfig @@ -44,6 +51,13 @@ class AudioPretrainingConfig(FairseqDataclass): default=None, metadata={"help": "extension of the label file to load, used for fine-tuning"}, ) + binarized_dataset: bool = field( + default=False, + metadata={ + "help": "if true, loads binarized dataset (useful for very large datasets). 
" + "See examples/wav2vec/scripts/binarize_manifest.sh" + }, + ) sample_rate: int = field( default=16_000, metadata={ @@ -92,9 +106,7 @@ class AudioPretrainingConfig(FairseqDataclass): ) num_batch_buckets: int = field( default=0, - metadata={ - "help": "number of buckets" - }, + metadata={"help": "number of buckets"}, ) precompute_mask_indices: bool = field( default=False, @@ -152,61 +164,70 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): def load_target_dictionary(self): if self.cfg.labels: - dict_path = os.path.join( - self.cfg.data, f"dict.{self.cfg.labels}.txt" - ) + dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") return Dictionary.load(dict_path) return None def _get_mask_precompute_kwargs(self, cfg): if self.cfg.precompute_mask_indices or self.cfg.tpu: args = [ - 'mask_length', - 'mask_prob', - 'mask_selection', - 'mask_other', - 'no_mask_overlap', - 'mask_min_space', - 'mask_channel_length', - 'mask_channel_prob', - 'mask_channel_selection', - 'mask_channel_other', - 'no_mask_channel_overlap', - 'mask_channel_min_space', - 'encoder_embed_dim', - 'conv_feature_layers', + "mask_length", + "mask_prob", + "mask_selection", + "mask_other", + "no_mask_overlap", + "mask_min_space", + "mask_channel_length", + "mask_channel_prob", + "mask_channel_selection", + "mask_channel_other", + "no_mask_channel_overlap", + "mask_channel_min_space", + "encoder_embed_dim", + "conv_feature_layers", ] return {arg: cfg[arg] for arg in args} else: return {} - def load_dataset( - self, split: str, task_cfg: FairseqDataclass = None, **kwargs - ): + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): data_path = self.cfg.data task_cfg = task_cfg or self.cfg # upgrade old task if isinstance(task_cfg, Namespace): if not hasattr(task_cfg, "autoregressive"): - task_cfg.autoregressive = not task_cfg.criterion == 'ctc' - - manifest = os.path.join(data_path, "{}.tsv".format(split)) - self.datasets[split] = FileAudioDataset( - 
manifest, - sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate), - max_sample_size=self.cfg.max_sample_size, - min_sample_size=self.cfg.min_sample_size, - pad=task_cfg.labels is not None or task_cfg.enable_padding, - normalize=task_cfg.normalize, - num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), - compute_mask_indices=( - self.cfg.precompute_mask_indices or self.cfg.tpu - ), - **self._get_mask_precompute_kwargs(task_cfg), - ) + task_cfg.autoregressive = not task_cfg.criterion == "ctc" + + if task_cfg.binarized_dataset: + self.datasets[split] = BinarizedAudioDataset( + data_path, + split=split, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu), + **self._get_mask_precompute_kwargs(task_cfg), + ) + else: + manifest_path = os.path.join(data_path, "{}.tsv".format(split)) + + self.datasets[split] = FileAudioDataset( + manifest_path=manifest_path, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu), + **self._get_mask_precompute_kwargs(task_cfg), + ) - if self.cfg.tpu and task_cfg['mask_channel_prob'] == 0.0: + if self.cfg.tpu and task_cfg["mask_channel_prob"] == 0.0: logger.info( "Pretraining on TPUs may suffer convergence " "issues when training with `mask_channel_prob` value of " @@ -217,13 +238,15 @@ def load_dataset( label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") with 
open(label_path, "r") as f: labels = [ - line for i, line in enumerate(f) + line + for i, line in enumerate(f) if i in self.datasets[split].line_inds ] assert len(labels) == len(self.datasets[split]), ( - f"labels length ({len(labels)}) and dataset length " - f"({len(self.datasets[split])}) do not match") + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) process_label = LabelEncoder(self.target_dictionary) @@ -234,7 +257,7 @@ def load_dataset( eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, - add_to_input=task_cfg.get('autoregressive', False), + add_to_input=task_cfg.get("autoregressive", False), ) @property From a8c8f0be177649f8178fdcbae519c17894efd4d7 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Tue, 4 May 2021 18:27:48 -0700 Subject: [PATCH 570/707] fix wav2vec finetuning (#1848) Summary: Fixes a regression from the previous PR, which removed line_inds from the raw audio dataset. This time we don't store line indices (which can be very large for big datasets) but instead store the indices of skipped examples where applicable. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1848 Reviewed By: arbabu123 Differential Revision: D28198764 Pulled By: alexeib fbshipit-source-id: 49580e09e6c1145b45c18802f4481d6df3de8cd2 --- fairseq/data/audio/raw_audio_dataset.py | 3 +++ fairseq/tasks/audio_pretraining.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 2d3dd238a6..1f945993a8 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -262,6 +262,8 @@ def __init__( skipped = 0 self.fnames = [] sizes = [] + self.skipped_indices = set() + with open(manifest_path, "r") as f: self.root_dir = f.readline().strip() for i, line in enumerate(f): @@ -270,6 +272,7 @@ def __init__( sz = int(items[1]) if 
min_sample_size is not None and sz < min_sample_size: skipped += 1 + self.skipped_indices.add(i) continue self.fnames.append(items[0]) sizes.append(sz) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index a84798d076..c27668439b 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -236,11 +236,12 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], 'skipped_indices', set()) with open(label_path, "r") as f: labels = [ line for i, line in enumerate(f) - if i in self.datasets[split].line_inds + if i not in skipped_indices ] assert len(labels) == len(self.datasets[split]), ( From 374fdc5cd94d361bb9b1089fe2c1d30a2eb15fdd Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Wed, 5 May 2021 17:35:54 -0700 Subject: [PATCH 571/707] fix eval of older checkpoints (fixes #3528) (#1851) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1851 Reviewed By: michaelauli, arbabu123 Differential Revision: D28226892 Pulled By: alexeib fbshipit-source-id: e07641dda46be2708e1f9d0c0cbc5b8dedaa92e7 --- fairseq/tasks/audio_pretraining.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index c27668439b..071331a10a 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -199,7 +199,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if not hasattr(task_cfg, "autoregressive"): task_cfg.autoregressive = not task_cfg.criterion == "ctc" - if task_cfg.binarized_dataset: + if getattr(task_cfg, 'binarized_dataset', False): self.datasets[split] = BinarizedAudioDataset( data_path, split=split, From eb228ee74c6bc9803eb7dbd398d8cda16c55ccd2 Mon Sep 17 
00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Wed, 5 May 2021 23:41:04 -0700 Subject: [PATCH 572/707] do per gpu seeding for shuffling batches with batch level sampling, fix bug with not disabling iterator cache Summary: we want to avoid the case where each gpu has a batch from the same dataset. not sure if this is happening right now, but to avoid this we can seed a shuffle by the GPU rank. also fixed a bug where we need to reset iterator for the new batch sampling Differential Revision: D28085750 fbshipit-source-id: 1643738b397d850f737fcd27c6398c216342464a --- fairseq/data/multi_corpus_dataset.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index acb91f3df6..1bd61c32eb 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -45,6 +45,7 @@ def __init__( seed: int, sort_indices: bool = False, batch_sample: bool = False, + distributed_rank=None, ): super().__init__() assert isinstance(datasets, OrderedDict) @@ -55,6 +56,7 @@ def __init__( self.seed = seed self.sort_indices = sort_indices self.batch_sample = batch_sample + self.distributed_rank = distributed_rank # Avoid repeated conversions to list later self.dataset_list = list(datasets.values()) @@ -72,6 +74,7 @@ def __init__( def ordered_indices(self): start = time.time() with data_utils.numpy_seed(self.seed, self.epoch): + logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}") sampled_indices = [] num_selected_instances = 0 @@ -186,6 +189,7 @@ def can_reuse_epoch_itr_across_epochs(self): def set_epoch(self, epoch, **unused): super().set_epoch(epoch) + logger.info(f"setting epoch of multi_corpus_dataset to {epoch}") self.epoch = epoch @property @@ -227,5 +231,10 @@ def batch_by_size( logger.info(f"Created {len(cur_batches)} batches for dataset {key}") batches += cur_batches - # Assume shuffling is handled in fairseq/data/iterators.py + # 
If this dataset is used in a distributed training setup, + # then shuffle such that the order is seeded by the distributed rank + # as well + if self.distributed_rank is not None: + with data_utils.numpy_seed(self.seed, self.epoch, self.distributed_rank): + np.random.shuffle(batches) return batches From 14439b12ad37bd8b9e2b1383209603df996bd3f5 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Fri, 7 May 2021 12:31:59 -0700 Subject: [PATCH 573/707] add fp16 comm hook Summary: speed up fp32 distributed training runs. Reviewed By: myleott Differential Revision: D28128720 fbshipit-source-id: 7855e4ecd43e194fd79e95bf9f35d4377a98779a --- fairseq/dataclass/configs.py | 4 ++++ fairseq/dataclass/constants.py | 1 + fairseq/models/distributed_fairseq_model.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 902756bfff..f5a405ec76 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -12,6 +12,7 @@ from fairseq.dataclass.constants import ( DATASET_IMPL_CHOICES, DDP_BACKEND_CHOICES, + DDP_COMM_HOOK_CHOICES, GENERATION_CONSTRAINTS_CHOICES, GENERATION_DECODING_FORMAT_CHOICES, LOG_FORMAT_CHOICES, @@ -244,6 +245,9 @@ class DistributedTrainingConfig(FairseqDataclass): ddp_backend: DDP_BACKEND_CHOICES = field( default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) + ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( + default="none", metadata={"help": "communication hook"} + ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} ) diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index faba0862fa..442c25982b 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -43,6 +43,7 @@ def ChoiceEnum(choices: List[str]): "pytorch_ddp", "slow_mo", ]) +DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"]) DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"]) 
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum( diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index 3422faea74..6af288b10e 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -18,6 +18,7 @@ ModuleProxyWrapper, TPUDistributedDataParallel, ) +from torch.distributed.algorithms.ddp_comm_hooks import register_ddp_comm_hook, DDPCommHookType logger = logging.getLogger(__name__) @@ -64,6 +65,9 @@ def DistributedFairseqModel(args, model, process_group, device): process_group=process_group, find_unused_parameters=args.find_unused_parameters, ) + if args.ddp_comm_hook == 'fp16': + logger.info("enable fp16 communication hook in DDP") + register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model) # forward missing getattr and state_dict/load_state_dict to orig model wrapped_model = ModuleProxyWrapper(wrapped_model) elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: From a2314b4e8a2a769eca073e87cef9e52f0e01ec08 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Mon, 10 May 2021 02:24:17 -0700 Subject: [PATCH 574/707] offload state_dict to cpu Summary: During checkpointing, state_dict is summoned onto GPU 0. Unfortunately, with large models this exceeds the memory limit of a single GPU, so this change moves it to CPU. Question for Myle: should we set this option as the default? When saving a checkpoint, the state_dict would eventually be moved to CPU anyway. 
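For illustration only, the idea can be sketched in plain PyTorch, independent of fairscale: copy every tensor of a state_dict to CPU before it is serialized. `offload_state_dict_to_cpu` is a hypothetical helper and is not part of this patch; the patch itself only passes the `state_dict_device` option to FSDP.

```python
import torch


def offload_state_dict_to_cpu(state_dict):
    """Return a copy of state_dict with every tensor moved to CPU (hypothetical helper)."""
    return {
        name: value.cpu() if torch.is_tensor(value) else value
        for name, value in state_dict.items()
    }


model = torch.nn.Linear(4, 2)
if torch.cuda.is_available():
    # Place the model on the GPU when one exists, to mimic the training setup.
    model = model.cuda()

cpu_state = offload_state_dict_to_cpu(model.state_dict())
# Every entry now lives on the CPU, wherever the model itself is placed.
print({name: str(tensor.device) for name, tensor in cpu_state.items()})
```

The sketch copies tensors one at a time, so peak GPU memory is not increased by the offload itself; whether FSDP behaves the same way internally is not shown here.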
Reviewed By: myleott Differential Revision: D28203587 fbshipit-source-id: e738f48c83e35873c46bcec3471d105f2b4f4d8e --- fairseq/distributed/fully_sharded_data_parallel.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py index 9c290b3fda..ff66c0b1d1 100644 --- a/fairseq/distributed/fully_sharded_data_parallel.py +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -7,13 +7,13 @@ from typing import Optional import torch - from fairseq.dataclass.configs import DistributedTrainingConfig from fairseq.distributed import utils as dist_utils try: from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + has_FSDP = True except ImportError: FSDP = torch.nn.Module @@ -50,7 +50,7 @@ def unwrapped_module(self) -> torch.nn.Module: else: return self.module - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): if self.use_sharded_state: return super().local_state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars @@ -90,6 +90,7 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = F group = dist_utils.get_data_parallel_group() if group is None and cfg.distributed_world_size == 1: from fairscale.utils.testing import DummyProcessGroup + group = DummyProcessGroup(rank=0, size=1) fsdp_config = { "process_group": group, @@ -100,6 +101,7 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = F "cpu_offload": cfg.cpu_offload, "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, "bucket_cap_mb": cfg.bucket_cap_mb, + "state_dict_device": torch.device("cpu"), # reduce GPU mem usage } with enable_wrap( wrapper_cls=FullyShardedDataParallel, @@ -120,6 +122,7 @@ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): """ try: from fairscale.nn import wrap + if 
min_num_params is not None: num_params = sum(p.numel() for p in module.parameters()) if num_params >= min_num_params: From 97969ac5f52090fd172f508975a3e9069c57e1af Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Mon, 10 May 2021 23:42:41 -0700 Subject: [PATCH 575/707] --combine-valid-sets (#1843) Summary: - `--combine-valid-subsets` (alias `--combine-val`) causes valid.bin, valid1.bin, ... to be concatenated. All metrics will be reported together. - `--valid-subset` works the same. If you pass `--valid-subset valid1,valid2` you get valid1_loss and valid2_loss logged separately. - If the user passes `--valid-subset valid` (the default) and we see files named valid1, valid2, we raise an error. The user must pass `--ignore-unused-valid-subsets` to override. This previously led to valid1, valid2 being silently ignored. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1843 Reviewed By: myleott Differential Revision: D28323815 Pulled By: sshleifer fbshipit-source-id: dfd46076d3f684e36f8dacfadd38fd0038ce6755 --- fairseq/data/data_utils.py | 36 +++++++- fairseq/dataclass/configs.py | 13 +++ fairseq/tasks/language_modeling.py | 2 +- fairseq_cli/train.py | 10 ++- tests/test_valid_subset_checks.py | 128 +++++++++++++++++++++++++++++ 5 files changed, 182 insertions(+), 7 deletions(-) create mode 100644 tests/test_valid_subset_checks.py diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 63c7fcd118..8c5e5a490d 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -10,7 +10,7 @@ import contextlib import itertools import logging -import os +import re import warnings from typing import Optional, Tuple @@ -18,7 +18,8 @@ import torch from fairseq.file_io import PathManager - +from fairseq import utils +import os logger = logging.getLogger(__name__) @@ -68,7 +69,6 @@ def copy_tensor(src, dst): copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) return res - def load_indexed_dataset( path, dictionary=None, 
dataset_impl=None, combine=False, default="cached" ): @@ -558,3 +558,33 @@ def get_bucketed_sizes(orig_sizes, buckets): sizes[mask] = end_val start_val = end_val return sizes + + + +def _find_extra_valid_paths(dataset_path: str) -> set: + paths = utils.split_paths(dataset_path) + all_valid_paths = set() + for sub_dir in paths: + contents = PathManager.ls(sub_dir) + valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None] + all_valid_paths |= {os.path.basename(p) for p in valid_paths} + # Remove .bin, .idx etc + roots = {os.path.splitext(p)[0] for p in all_valid_paths} + return roots + + +def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None: + """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored.""" + if ( + train_cfg.dataset.ignore_unused_valid_subsets + or train_cfg.dataset.combine_valid_subsets + or train_cfg.dataset.disable_validation + ): + return + other_paths = _find_extra_valid_paths(train_cfg.task.data) + specified_subsets = train_cfg.dataset.valid_subset.split(",") + ignored_paths = [p for p in other_paths if p not in specified_subsets] + if ignored_paths: + advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them." + msg = f"Valid paths {ignored_paths} will be ignored. {advice}" + raise ValueError(msg) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index f5a405ec76..8bd246c8c4 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -426,6 +426,19 @@ class DatasetConfig(FairseqDataclass): " (e.g. train, valid, test)" }, ) + combine_valid_subsets: Optional[bool] = field( + default=None, + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. 
train, valid, test)", + "argparse_alias": "--combine-val", + }, + ) + ignore_unused_valid_subsets: Optional[bool] = field( + default=False, + metadata={"help": "do not raise error if valid subsets are ignored"}, + ) + validate_interval: int = field( default=1, metadata={"help": "validate every N epochs"} ) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 3069490fdc..4b76a51c61 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -201,7 +201,7 @@ def load_dataset( """Load a given dataset split. Args: - split (str): name of the split (e.g., train, valid, test) + split (str): name of the split (e.g., train, valid, valid1, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index c1f2fbb4c7..cb49915827 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -23,7 +23,7 @@ tasks, utils, ) -from fairseq.data import iterators +from fairseq.data import iterators, data_utils from fairseq.data.plasma_utils import PlasmaStore from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf @@ -114,8 +114,12 @@ def main(cfg: FairseqConfig) -> None: # Load valid dataset (we load training data below, based on the latest checkpoint) # We load the valid dataset AFTER building the model - for valid_sub_split in cfg.dataset.valid_subset.split(","): - task.load_dataset(valid_sub_split, combine=False, epoch=1) + data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) + if cfg.dataset.combine_valid_subsets: + task.load_dataset("valid", combine=True, epoch=1) + else: + for valid_sub_split in cfg.dataset.valid_subset.split(","): + task.load_dataset(valid_sub_split, combine=False, epoch=1) # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: diff --git a/tests/test_valid_subset_checks.py b/tests/test_valid_subset_checks.py new 
file mode 100644 index 0000000000..ab778fb3fa --- /dev/null +++ b/tests/test_valid_subset_checks.py @@ -0,0 +1,128 @@ +import os +import shutil +import tempfile +import unittest + +from fairseq import options +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored +from .utils import create_dummy_data, preprocess_lm_data, train_language_model + + +def make_lm_config( + data_dir, + extra_flags=None, + task="language_modeling", + arch="transformer_lm_gpt2_tiny", +): + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + "--task", + task, + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", + data_dir, + "--max-epoch", + "1", + ] + + (extra_flags or []), + ) + cfg = convert_namespace_to_omegaconf(train_args) + return cfg + + +def write_empty_file(path): + with open(path, "w"): + pass + assert os.path.exists(path) + + +class TestValidSubsetsErrors(unittest.TestCase): + """Test various filesystem, clarg combinations and ensure that error raising happens as expected""" + + def _test_case(self, paths, extra_flags): + with tempfile.TemporaryDirectory() as data_dir: + [ + write_empty_file(os.path.join(data_dir, f"{p}.bin")) + for p in paths + ["train"] + ] + cfg = make_lm_config(data_dir, extra_flags=extra_flags) + raise_if_valid_subsets_unintentionally_ignored(cfg) + + def test_default_raises(self): + with self.assertRaises(ValueError): + self._test_case(["valid", "valid1"], []) + with self.assertRaises(ValueError): + self._test_case( + ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"] + ) + + def partially_specified_valid_subsets(self): + with self.assertRaises(ValueError): + self._test_case( + ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"] + ) + # Fix with ignore unused + 
self._test_case( + ["valid", "valid1", "valid2"], + ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"], + ) + + def test_legal_configs(self): + self._test_case(["valid"], []) + self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"]) + self._test_case(["valid", "valid1"], ["--combine-val"]) + self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"]) + self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"]) + self._test_case( + ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"] + ) + self._test_case( + ["valid1"], ["--valid-subset", "valid1"] + ) # valid.bin doesn't need to be ignored. + + def test_disable_validation(self): + self._test_case([], ["--disable-validation"]) + self._test_case(["valid", "valid1"], ["--disable-validation"]) + + +class TestCombineValidSubsets(unittest.TestCase): + def _train(self, extra_flags): + with self.assertLogs() as logs: + with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir: + create_dummy_data(data_dir, num_examples=20) + preprocess_lm_data(data_dir) + + shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin") + shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx") + train_language_model( + data_dir, + "transformer_lm", + ["--max-update", "0", "--log-format", "json"] + extra_flags, + run_validation=False, + ) + return [x.message for x in logs.records] + + def test_combined(self): + flags = ["--combine-valid-subsets"] + logs = self._train(flags) + assert any(["valid1" in x for x in logs]) # loaded 100 examples from valid1 + assert not any(["valid1_ppl" in x for x in logs]) # metrics are combined + + def test_subsets(self): + flags = ["--valid-subset", "valid,valid1"] + logs = self._train(flags) + assert any(["valid_ppl" in x for x in logs]) # metrics for valid are logged + assert any(["valid1_ppl" in x for x in logs]) # metrics for valid1 are logged separately From 8f1a34af7cb6e92fa3c443e61803e0ff347a784e Mon Sep 17 00:00:00 2001 
From: Changhan Wang <changhan@fb.com> Date: Wed, 12 May 2021 09:01:06 -0700 Subject: [PATCH 576/707] add zip file support to raw_audio_dataset Summary: Add zip file support to raw_audio_dataset - Allow reading WAV/FLAC/OGG audios from stored zip file with given byte offset and length - Path format in the manifest TSV: `[zip_path]:[byte_offset]:[byte_length]` (e.g. `en/flac.zip:33255867035:212288`) - Packing audios in small number of zip files facilitates file management and improves loading speed (avoiding random access of many small audio files) Reviewed By: jmp84 Differential Revision: D28343999 fbshipit-source-id: a9cd2fbeb6e318cf9787065beb3bbddac25d0aba --- fairseq/data/audio/audio_utils.py | 46 +++++++++++++++++++- fairseq/data/audio/raw_audio_dataset.py | 15 ++++++- fairseq/data/audio/speech_to_text_dataset.py | 44 ++++++------------- 3 files changed, 71 insertions(+), 34 deletions(-) diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index ddd5642c7e..f51cb0cddc 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -1,11 +1,12 @@ from pathlib import Path -from typing import BinaryIO, Optional, Tuple, Union +from typing import BinaryIO, Optional, Tuple, Union, List import numpy as np import torch SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"} +FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"} def _convert_to_mono( @@ -128,3 +129,46 @@ def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: ) return features + + +def is_npy_data(data: bytes) -> bool: + return data[0] == 147 and data[1] == 78 + + +def is_sf_audio_data(data: bytes) -> bool: + is_wav = (data[0] == 82 and data[1] == 73 and data[2] == 70) + is_flac = (data[0] == 102 and data[1] == 76 and data[2] == 97) + is_ogg = (data[0] == 79 and data[1] == 103 and data[2] == 103) + return is_wav or is_flac or is_ogg + + +def read_from_stored_zip(zip_path: str, offset: int, file_size: int) -> bytes: + 
with open(zip_path, "rb") as f: + f.seek(offset) + data = f.read(file_size) + return data + + +def parse_path(path: str) -> Tuple[str, List[int]]: + """Parse data path which is either a path to + 1. a .npy/.wav/.flac/.ogg file + 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]" + + Args: + path (str): the data path to parse + + Returns: + file_path (str): the file path + slice_ptr (list of int): empty in case 1; + byte offset and length for the slice in case 2 + """ + + if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: + _path, slice_ptr = path, [] + else: + _path, *slice_ptr = path.split(":") + if not Path(_path).is_file(): + raise FileNotFoundError(f"File not found: {_path}") + assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}" + slice_ptr = [int(i) for i in slice_ptr] + return _path, slice_ptr diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 1f945993a8..1ceef8ce06 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -7,6 +7,7 @@ import logging import os import sys +import io import numpy as np import torch @@ -14,6 +15,9 @@ from .. 
import FairseqDataset, BaseWrapperDataset from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes +from fairseq.data.audio.audio_utils import ( + parse_path, read_from_stored_zip, is_sf_audio_data +) logger = logging.getLogger(__name__) @@ -295,8 +299,15 @@ def __init__( def __getitem__(self, index): import soundfile as sf - fname = os.path.join(self.root_dir, str(self.fnames[index])) - wav, curr_sample_rate = sf.read(fname) + path_or_fp = os.path.join(self.root_dir, self.fnames[index]) + _path, slice_ptr = parse_path(path_or_fp) + if len(slice_ptr) == 2: + byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + assert is_sf_audio_data(byte_data) + path_or_fp = io.BytesIO(byte_data) + + wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32") + feats = torch.from_numpy(wav).float() feats = self.postprocess(feats, curr_sample_rate) return {"id": index, "source": feats} diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index c6c64db084..b889ff5356 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -19,7 +19,10 @@ ResamplingDataset, data_utils as fairseq_data_utils, ) -from fairseq.data.audio.audio_utils import get_fbank, get_waveform +from fairseq.data.audio.audio_utils import ( + get_fbank, get_waveform, read_from_stored_zip, is_npy_data, + is_sf_audio_data, parse_path, FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS +) from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform @@ -120,39 +123,22 @@ def get_feature_transforms(self, split, is_train): return cfg -def is_npy_data(data: bytes) -> bool: - return data[0] == 147 and data[1] == 78 - - -def is_flac_or_wav_data(data: bytes) -> bool: - is_flac = data[0] == 102 and data[1] == 76 - is_wav = data[0] == 82 and data[1] == 73 - return is_flac or is_wav - - -def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes: - with open(file_path, 
"rb") as f: - f.seek(offset) - data = f.read(file_size) - return data - - def get_features_from_npy_or_audio(path): ext = op.splitext(op.basename(path))[1] - if ext not in {".npy", ".flac", ".wav"}: + if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: raise ValueError(f'Unsupported file format for "{path}"') return np.load(path) if ext == ".npy" else get_fbank(path) -def get_features_or_waveform_from_uncompressed_zip( +def get_features_or_waveform_from_stored_zip( path, byte_offset, byte_size, need_waveform=False ): assert path.endswith(".zip") - data = read_from_uncompressed_zip(path, byte_offset, byte_size) + data = read_from_stored_zip(path, byte_offset, byte_size) f = io.BytesIO(data) if is_npy_data(data): features_or_waveform = np.load(f) - elif is_flac_or_wav_data(data): + elif is_sf_audio_data(data): features_or_waveform = \ get_waveform(f, always_2d=False)[0] if need_waveform else get_fbank(f) else: @@ -173,18 +159,14 @@ def get_features_or_waveform(path: str, need_waveform=False): Returns: features_or_waveform (numpy.ndarray): speech features or waveform. 
""" - _path, *extra = path.split(":") - if not op.exists(_path): - raise FileNotFoundError(f"File not found: {_path}") - - if len(extra) == 0: + _path, slice_ptr = parse_path(path) + if len(slice_ptr) == 0: if need_waveform: return get_waveform(_path, always_2d=False) return get_features_from_npy_or_audio(_path) - elif len(extra) == 2: - extra = [int(i) for i in extra] - features_or_waveform = get_features_or_waveform_from_uncompressed_zip( - _path, extra[0], extra[1], need_waveform=need_waveform + elif len(slice_ptr) == 2: + features_or_waveform = get_features_or_waveform_from_stored_zip( + _path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform ) else: raise ValueError(f"Invalid path: {path}") From d151f2787240cca4e3c7e47640e647f8ae028c37 Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Wed, 12 May 2021 18:03:01 -0700 Subject: [PATCH 577/707] S2T example bug fixes Summary: S2T example bug fixes Reviewed By: jmp84 Differential Revision: D27930414 fbshipit-source-id: 14edc85a34094a4fff53646390d366b39ffd8206 --- examples/speech_to_text/docs/librispeech_example.md | 4 ++-- examples/speech_to_text/docs/mustc_example.md | 8 ++++---- fairseq/data/audio/speech_to_text_dataset.py | 4 ++-- fairseq/models/speech_to_text/s2t_transformer.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/speech_to_text/docs/librispeech_example.md b/examples/speech_to_text/docs/librispeech_example.md index 4749e6cecc..4040fda942 100644 --- a/examples/speech_to_text/docs/librispeech_example.md +++ b/examples/speech_to_text/docs/librispeech_example.md @@ -23,9 +23,9 @@ if you want to use our pre-trained models. 
## Training ```bash fairseq-train ${LS_ROOT} --save-dir ${SAVE_DIR} \ - --config-yaml config.yaml --train-subset train --valid-subset dev \ + --config-yaml config.yaml --train-subset train-clean-100,train-clean-360,train-other-500 --valid-subset dev-clean,dev-other \ --num-workers 4 --max-tokens 40000 --max-update 300000 \ - --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \ --arch s2t_transformer_s --share-decoder-input-output-embed \ --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \ --clip-norm 10.0 --seed 1 --update-freq 8 diff --git a/examples/speech_to_text/docs/mustc_example.md b/examples/speech_to_text/docs/mustc_example.md index 79df0aafdc..c95ef3e156 100644 --- a/examples/speech_to_text/docs/mustc_example.md +++ b/examples/speech_to_text/docs/mustc_example.md @@ -45,7 +45,7 @@ En-De as example: fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ - --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \ --arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 ``` @@ -56,7 +56,7 @@ fairseq-train ${MUSTC_ROOT} \ --train-subset train_de_asr,train_nl_asr,train_es_asr,train_fr_asr,train_it_asr,train_pt_asr,train_ro_asr,train_ru_asr \ --valid-subset dev_de_asr,dev_nl_asr,dev_es_asr,dev_fr_asr,dev_it_asr,dev_pt_asr,dev_ro_asr,dev_ru_asr \ --save-dir ${JOINT_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ - --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --task speech_to_text 
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \ --arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 ``` @@ -98,7 +98,7 @@ En-De as example: fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ - --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \ --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} @@ -110,7 +110,7 @@ fairseq-train ${MUSTC_ROOT} \ --train-subset train_de_st,train_nl_st,train_es_st,train_fr_st,train_it_st,train_pt_st,train_ro_st,train_ru_st \ --valid-subset dev_de_st,dev_nl_st,dev_es_st,dev_fr_st,dev_it_st,dev_pt_st,dev_ro_st,dev_ru_st \ --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ - --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \ --arch s2t_transformer_s --ignore-prefix-size 1 --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \ --load-pretrained-encoder-from ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index b889ff5356..d4b5668d8f 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -43,9 +43,9 @@ def __init__(self, yaml_path): with open(yaml_path) as f: 
self.config = yaml.load(f, Loader=yaml.FullLoader) except Exception as e: - logger.info(f"Failed to load config from {yaml_path}: {e}") + raise Exception(f"Failed to load config from {yaml_path}: {e}") else: - logger.info(f"Cannot find {yaml_path}") + raise FileNotFoundError(f"{yaml_path} not found") @property def vocab_filename(self): diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 7480dc7967..ff3d2100c7 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -279,7 +279,7 @@ class S2TTransformerEncoder(FairseqEncoder): def __init__(self, args): super().__init__(None) - self.encoder_freezing_updates = args.encoder_freezing_updates + self.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) self.num_updates = 0 self.dropout_module = FairseqDropout( From 425c36eafff535fe7337f8bdd5ace22ebacc78cb Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Fri, 14 May 2021 18:52:06 -0700 Subject: [PATCH 578/707] support use_sharded_state on command line Summary: We wanted to use sharded state for two reasons: 1. to save memory 2. to support sharded state loading, which allows MoE models' weights to live on their respective shards. I added use_sharded_state as a config option, plus a unit test to make sure it runs fine. Old revision's comment: fairseq.FSDP has a flag use_sharded_state, but I had to address a few problems before being able to use it. 1. fairscale FSDP (FSDP for short) calls self.state_dict/load_state_dict, which fairseq.FSDP overrides; this is not the desired behavior 2. the optimizer state shouldn't be sharded again when use_sharded_state is True 3. expose this option on the command line.
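A minimal sketch of points 2 and 3 above, assuming a simplified stand-in for the real config. `ConfigSketch`, `parse_cfg`, and `should_shard_optim_state` are hypothetical names used only for illustration, and the default backend value is an assumption. In the actual change the flag is a field on `DistributedTrainingConfig`, and the trainer reads it from the wrapped model rather than from the config.

```python
import argparse
from dataclasses import dataclass, field


@dataclass
class ConfigSketch:
    # Hypothetical stand-in for the two config fields involved in this change.
    ddp_backend: str = field(default="pytorch_ddp")  # default is an assumption
    use_sharded_state: bool = field(
        default=False, metadata={"help": "use sharded checkpoint files"}
    )


def parse_cfg(argv):
    """Expose the option on the command line (point 3)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ddp-backend", default="pytorch_ddp")
    parser.add_argument("--use-sharded-state", action="store_true")
    ns = parser.parse_args(argv)
    return ConfigSketch(
        ddp_backend=ns.ddp_backend, use_sharded_state=ns.use_sharded_state
    )


def should_shard_optim_state(cfg):
    """Point 2: a consolidated optimizer state is split into per-rank shards
    only for the fully_sharded backend, and only if the checkpoint was not
    already saved in sharded form."""
    return cfg.ddp_backend == "fully_sharded" and not cfg.use_sharded_state
```

With `--ddp-backend fully_sharded --use-sharded-state` (the flags exercised by the new test), the sketch reports that the optimizer state should be loaded as-is, without a second sharding pass.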
Reviewed By: sshleifer Differential Revision: D28375035 fbshipit-source-id: c2f59a9c62163405033f34ed595ba78528aea850 --- fairseq/dataclass/configs.py | 3 +++ fairseq/distributed/fully_sharded_data_parallel.py | 4 ++-- fairseq/trainer.py | 3 ++- tests/gpu/test_binaries_gpu.py | 3 +++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 8bd246c8c4..f41cfcd94f 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -379,6 +379,9 @@ class DistributedTrainingConfig(FairseqDataclass): cpu_offload: bool = field( default=False, metadata={"help": "offload FP32 params to CPU"} ) + use_sharded_state: bool = field( + default=False, metadata={"help": "use sharded checkpoint files"}, + ) @dataclass diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py index ff66c0b1d1..8a96bfc765 100644 --- a/fairseq/distributed/fully_sharded_data_parallel.py +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -77,7 +77,7 @@ def load_state_dict(self, state_dict, strict=True, model_cfg=None): @contextlib.contextmanager -def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False): +def fsdp_enable_wrap(cfg: DistributedTrainingConfig): try: from fairscale.nn import enable_wrap except ImportError: @@ -105,7 +105,7 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = F } with enable_wrap( wrapper_cls=FullyShardedDataParallel, - use_sharded_state=use_sharded_state, + use_sharded_state=cfg.use_sharded_state, **fsdp_config, ): yield diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 0b70a97356..dc06928dfc 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -496,7 +496,8 @@ def load_checkpoint( last_optim_state = self.optimizer.broadcast_global_state_dict( last_optim_state ) - elif self.cfg.distributed_training.ddp_backend == 'fully_sharded': + elif 
self.cfg.distributed_training.ddp_backend == 'fully_sharded' and not self.model.use_sharded_state: + # if use_sharded_state, the last_optim_state is already sharded, skip this last_optim_state = self.model.get_shard_from_optim_state_dict(last_optim_state) self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 45417c7eb7..a0824c23ad 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -59,6 +59,9 @@ def parse_logs(logfile): def test_resume_training_fsdp(self): self._test_resume_training(["--ddp-backend", "fully_sharded"]) + def test_resume_training_fsdp_sharded_state(self): + self._test_resume_training(["--ddp-backend", "fully_sharded", "--use-sharded-state"]) + def test_resume_training_noc10d(self): self._test_resume_training([]) From a1fea2eb0e5a68c9f91b18a344056675332181a3 Mon Sep 17 00:00:00 2001 From: Mandeep Singh Baines <mandeep.baines@gmail.com> Date: Thu, 20 May 2021 14:04:37 -0700 Subject: [PATCH 579/707] enable pin_memory for DataLoaders (#3560) Summary: To avoid the creation of a cuda:0 context, I needed to make sure that the `BackgroundConsumer` thread had its cuda device context set to the correct GPU. 
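The idea in the summary above can be sketched roughly as follows: the parent thread records its current CUDA device when it builds the consumer, and the consumer thread selects that device before touching any data. This is an illustrative sketch, not the fairseq code. `DeviceAwareConsumer`, `current_cuda_device`, `drain`, and `DONE` are hypothetical names, and the optional `torch` import is only there so the sketch also runs on machines without PyTorch or CUDA.

```python
import queue
import threading

try:
    import torch
except ImportError:  # lets the sketch run on machines without PyTorch
    torch = None

DONE = object()  # sentinel marking the end of the stream


def current_cuda_device():
    """Return the caller's current CUDA device index, or None without CUDA."""
    if torch is not None and torch.cuda.is_available():
        return torch.cuda.current_device()
    return None


class DeviceAwareConsumer(threading.Thread):
    def __init__(self, out_queue, source, cuda_device):
        super().__init__(daemon=True)
        self._queue = out_queue
        self._source = source
        self.cuda_device = cuda_device

    def run(self):
        # Select the parent's device first, so that work done in this thread
        # (for example fetching pinned batches) does not fall back to GPU 0.
        if self.cuda_device is not None:
            torch.cuda.set_device(self.cuda_device)
        for item in self._source:
            self._queue.put(item)
        self._queue.put(DONE)


def drain(source):
    """Consume `source` through a background thread and collect the items."""
    q = queue.Queue()
    consumer = DeviceAwareConsumer(q, source, current_cuda_device())
    consumer.start()
    items = []
    while True:
        item = q.get()
        if item is DONE:
            break
        items.append(item)
    consumer.join()
    return items
```

The device index is captured in the parent and passed in explicitly because the consumer cannot ask for "the parent's device" once it is running in its own thread.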
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3560 Reviewed By: myleott Differential Revision: D28573071 Pulled By: msbaines fbshipit-source-id: c2bedf67d8f356a29fa82eb4d8f15983efce3ffc --- fairseq/data/iterators.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 293d853822..8321a49b54 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -214,6 +214,7 @@ def _get_iterator_for_epoch(self, epoch, shuffle, offset=0): num_workers=self.num_workers, timeout=self.timeout, worker_init_fn=worker_init_fn, + pin_memory=True, ) # Wrap with a BufferedIterator if needed @@ -469,6 +470,7 @@ def shuffle_batches(batches, seed): batch_sampler=batches[offset:], num_workers=self.num_workers, timeout=self.timeout, + pin_memory=True, ) # Wrap with a BufferedIterator if needed @@ -546,15 +548,20 @@ def __init__(self, iterable, num_shards, shard_id, fill_value=None): class BackgroundConsumer(Thread): - def __init__(self, queue, source, max_len): + def __init__(self, queue, source, max_len, cuda_device): Thread.__init__(self) self._queue = queue self._source = source self._max_len = max_len self.count = 0 + self.cuda_device = cuda_device def run(self): + # set_device to avoid creation of GPU0 context when using pin_memory + if self.cuda_device is not None: + torch.cuda.set_device(self.cuda_device) + try: for item in self._source: self._queue.put(item) @@ -586,6 +593,7 @@ def _create_consumer(self): self._queue, self._iterable, self.total, + torch.cuda.current_device() if torch.cuda.is_available() else None ) self._consumer.daemon = True self._consumer.start() From f68de08a7326d1915461b84ad8e6ccb979d39578 Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Thu, 20 May 2021 21:39:31 -0700 Subject: [PATCH 580/707] Auto formatting changes Summary: TSIA Reviewed By: jmp84 Differential Revision: D28523558 fbshipit-source-id: 
97a90050e426be071f59127596a84bf71ede476d --- .../models/speech_to_text/convtransformer.py | 10 ++++----- fairseq/models/transformer.py | 21 +++++++++++-------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index 40e6dd3f4e..eba000d7b0 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -7,8 +7,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from fairseq.data.data_utils import lengths_to_padding_mask from fairseq import checkpoint_utils, utils +from fairseq.data.data_utils import lengths_to_padding_mask from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -28,6 +28,7 @@ class ConvTransformerModel(FairseqEncoderDecoderModel): Transformer-based Speech translation model from ESPNet-ST https://arxiv.org/abs/2004.10234 """ + def __init__(self, encoder, decoder): super().__init__(encoder, decoder) @@ -304,10 +305,10 @@ def forward(self, src_tokens, src_lengths): subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long() - input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to(input_len_0.device) - input_lengths = torch.min( - input_len_0, input_len_1 + input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to( + input_len_0.device ) + input_lengths = torch.min(input_len_0, input_len_1) encoder_padding_mask = lengths_to_padding_mask(input_lengths) @@ -323,7 +324,6 @@ def forward(self, src_tokens, src_lengths): else: maybe_encoder_padding_mask = encoder_padding_mask - return { "encoder_out": [x], "encoder_padding_mask": [maybe_encoder_padding_mask] diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f5d1af4477..b7b8783fa2 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -417,7 
+417,8 @@ def build_encoder_layer(self, args): # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) - if not checkpoint else 0 + if not checkpoint + else 0 ) layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer @@ -468,10 +469,9 @@ def forward( hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ - return self.forward_scriptable(src_tokens, - src_lengths, - return_all_hiddens, - token_embeddings) + return self.forward_scriptable( + src_tokens, src_lengths, return_all_hiddens, token_embeddings + ) # TorchScript doesn't support super() method so that the scriptable Subclass # can't access the base class model in Torchscript. @@ -509,7 +509,7 @@ def forward_scriptable( """ # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) - has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) + has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) @@ -767,8 +767,10 @@ def build_output_projection(self, args, dictionary, embed_tokens): ) num_base_layers = getattr(args, "base_layers", 0) for i in range(num_base_layers): - self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args)) - + self.layers.insert( + ((i + 1) * args.decoder_layers) // (num_base_layers + 1), + BaseLayer(args), + ) def build_decoder_layer(self, args, no_encoder_attn=False): layer = TransformerDecoderLayer(args, no_encoder_attn) @@ -780,7 +782,8 @@ def build_decoder_layer(self, args, no_encoder_attn=False): # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) - if not checkpoint else 0 + if not checkpoint + else 0 ) layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer From 
a5df5de926838c2d3b890c7b97fd68d7883cec2a Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 20 May 2021 23:20:14 -0700 Subject: [PATCH 581/707] wav2vec_u_readme (#1888) Summary: initial wav2vec-U readme to be updated Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1888 Reviewed By: michaelauli Differential Revision: D28595427 Pulled By: alexeib fbshipit-source-id: 8e1baca8a367f9b38a66e58489ad127341214f58 --- examples/wav2vec/unsupervised/README.md | 175 ++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 examples/wav2vec/unsupervised/README.md diff --git a/examples/wav2vec/unsupervised/README.md b/examples/wav2vec/unsupervised/README.md new file mode 100644 index 0000000000..fdfcc04d26 --- /dev/null +++ b/examples/wav2vec/unsupervised/README.md @@ -0,0 +1,175 @@ + +# wav2vec Unsupervised (wav2vec-U) + +Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition systems without any labeled training data as described in [Unsupervised Speech Recognition (Baevski et al., 2021)](https://ai.facebook.com/research/publications/unsupervised-speech-recognition). The model takes as input wav2vec 2.0 or XLSR representations (see [pretrained models](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec)) as well as unlabeled speech and text data. + + The wav2vec-U training procedure consists of three consecutive main steps: +* Preparation of speech representations and text data +* Generative adversarial training (GAN) +* Iterative self-training + Kaldi LM-decoding + + +## Preparation of speech and text data +Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md), data folders contain {train,valid,test}.{tsv,wrd,phn} files, where audio paths are stored in tsv files, and word, letter or phoneme transcriptions are stored in .{wrd,ltr,phn}. 
+ + +In **/path/to/data/with_silence** you need a *train.tsv* file as well as *{valid,test}.{tsv,wrd,phn}*. It is nice to have *10h.{tsv,phn}* files there too for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except *.tsv* files contain audios with silences removed using rVAD. + +Here is how you can create new audio files without silences from a list of input audio files: +``` +python scripts/unsupervised/remove_silences.py /path/to/data/with_silence/train.tsv \ + --save-dir /path/to/data/without_silence/audio \ + --output /path/to/data/without_silence/train.tsv & +``` + + +In this first part, we use mostly phonemized text. Here is how you can transform a text file into its phonemized .phn version: +``` +# Will phonemize word dictionary and then phonemize text using dict lookup (for language $lg) + +python scripts/unsupervised/phonemize.py $path_to_wrd_dict $lg < text.wrd > text.phn & + +``` +Next, you can reproduce Figure 2/3 of the wav2vec-U paper by training linear models on top of each layer's frozen wav2vec 2.0 representations, using supervised data. You can observe that certain layers provide lower PER, which shows the closeness of their representations to phoneme outputs. Note that this step requires supervision and is thus not necessary. 
+ +``` +# Learn linear model on top of layer N(=15) using supervised data + +fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data/without_silence \ + model.w2v_path=/path/to/model.pt \ + model.layer=15 \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-name vox_10h_phn +``` + + + +We can extract features of layer *N* using the following: +``` +# Extract features from layer N +split=train # valid test +python scripts/unsupervised/fb_wav2vec_ctc_filter.py \ + /path/to/data/without_silence \ + --split $split \ + --layer=15 \ + --checkpoint /path/to/model.pt \ + --save-dir /path/to/features & +``` + + + +Next we perform clustering of wav2vec representations (step 2 in the paper): +``` +# Identify clusters in the representations with k-means + +python scripts/unsupervised/fb_wav2vec_cluster_faiss.py \ + /path/to/data/train.tsv \ + -f "CLUS128" \ + --sample-pct 0.5 \ + --layer 15 \ + --checkpoint /path/to/model.pt \ + --save-dir /path/to/features/clustering/segmented & +``` + +And use those clusters to segment the audio data (step 3 in the paper): +``` +# Transcribe cluster ids of audio data +python scripts/unsupervised/fb_wav2vec_apply_cluster_faiss.py \ + /path/to/data \ + --split $split \ + --checkpoint /path/to/model.pt \ + --path /path/to/features/clustering/segmented/CLUS128 & +``` + + Learn and apply PCA to the representations to retain important features +``` +# Compute PCA +python scripts/pca.py \ + /path/to/features/unfiltered/train.npy \ + --dim 512 \ + --output /path/to/features/unfiltered/unfiltered_pca + +# Apply PCA +python scripts/apply_pca.py \ + $outdir \ + --split $split \ + --pca-path /path/to/features/unfiltered/unfiltered_pca/512_pca \ + --batch-size 1048000 \ + --save-dir /path/to/features/unfiltered/precompute_unfiltered_pca${dim} +``` + +Then we build segment representations by mean-pooling representations according to clusters: + + + +``` +# Build segment 
representations + +python scripts/unsupervised/merge_clusters.py \ + /path/to/features/unfiltered/precompute_unfiltered_pca512 \ + --split $split \ + --cluster-dir /path/to/features/clustering/segmented/CLUS128 \ + --pooling mean \ + --save-dir /path/to/features/unfiltered/precompute_unfiltered_pca512_cls128_mean & + +``` +Finally, we found that segment boundaries are noisy due to the lack of supervision and we therefore found it useful to also mean-pool pairs of adjacent segment representations to increase robustness: +``` +# Mean-pool adjacent time steps + +python scripts/unsupervised/mean_pool.py \ + /path/to/features/unfiltered/precompute_unfiltered_pca512_cls128_mean \ + --split $split & + --save-dir $savedir & +``` + +For adversarial training, we preprocess the text data by adding silence tokens. +``` +# Add <SIL> tokens on text in preparation for GAN training +python scripts/unsupervised/fb_wrd_to_phonemizer.py \ + -s 0.25 --surround < /path/to/data/gan.txt > /path/to/data/gan.txt_s0.25.phns & + +# Binarize with fairseq-preprocess +fairseq-preprocess --dataset-impl mmap \ + --trainpref /path/to/data/gan.txt_s0.25.phns \ + --workers 6 --thresholdsrc 0 --only-source \ + --destdir /path/to/data --srcdict /path/to/data/dict.phn.txt & +``` + + +## Generative adversarial training (GAN) + +We then use a GAN model to build a first unsupervised ASR model. The data preparation above of both speech features and text data is a necessary procedure that enables the generator to match speech to text in an unsupervised way. 
+ +Launching GAN training on top of preprocessed features, with default hyperparameters can be done with: + +``` +PREFIX=w2v_unsup_gan_xp +TASK_DATA=/path/to/features/unfiltered/precompute_unfiltered_pca512_cls128_mean_pooled +TEXT_DATA=/path/to/data # path to fairseq-preprocessed GAN data +KENLM_PATH=/path/to/data/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here) + +PREFIX=$PREFIX fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + -m --config-dir configs/unsup \ + --config-name gan_feats_by_label \ + dataset.valid_subset=valid \ + task.data=${TASK_DATA} \ + task.text_data=${TEXT_DATA} \ + task.kenlm_path=${KENLM_PATH} \ + 'common.seed=range(0,5)' & +``` +However, this step requires an hyperparameter search, which can be launched with: +``` + +``` + +Note that hyperparameter search and model/epoch selection are done using a fully unsupervised metric (see Section 4.3). + +## Iterative self-training + Kaldi LM-decoding +After the GAN training provides a first unsupervised model, we can then progressively refine the quality of transcriptions using several iterations of semi-supervised learning. We perform two iterations: first, pseudo-label the training data with the unsupervised GAN model and train an HMM on the pseudo-labels. Second, we relabel the training data with the HMM and then fine-tune the original wav2vec 2.0 model using the HMM pseudo-labels with a CTC loss. Note that HMM models use phonemes as output, while wav2vec 2.0 use letter. Both are decoded using WFST decoders into words. + + +Please see [this README](http://github.com/pytorch/fairseq/tree/master/examples/wav2vec/unsupervised/kaldi_st) for more instructions on how to do iterative self-training + Kaldi LM-decoding. 
From f9edd9f9b919b5fe77255296c69e052e4a930b2b Mon Sep 17 00:00:00 2001 From: freewym <freewym@gmail.com> Date: Fri, 21 May 2021 00:39:54 -0700 Subject: [PATCH 582/707] fix the error when using pyarrow for raw_audio_dataset (#3561) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [X] Did you write any new necessary tests? ## What does this PR do? Fixes the error when using pyarrow for raw_audio_dataset. Currently there is an error `TypeError: join() argument must be str, bytes, or os.PathLike object, not 'StringScalar'`. This PR converts the file name to a string in `__getitem__()` before passing it to `os.path.join()`. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3561 Reviewed By: jmp84 Differential Revision: D28594620 Pulled By: kahne fbshipit-source-id: fd2daa992df85ac0919ba30fa9afa67a0c89d956 --- fairseq/data/audio/raw_audio_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 1ceef8ce06..4cb5193bde 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -299,7 +299,7 @@ def __init__( def __getitem__(self, index): import soundfile as sf - path_or_fp = os.path.join(self.root_dir, self.fnames[index]) + path_or_fp = os.path.join(self.root_dir, str(self.fnames[index])) _path, slice_ptr = parse_path(path_or_fp) if len(slice_ptr) == 2: byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) From 649af635f40dfdd4ab47e26c5183e69a62f8c49c Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Fri, 21 May 2021 07:33:12 -0700 Subject: [PATCH 583/707] Wav2vec u (#1889) Summary: Wav2vec-U implementation Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1889 Reviewed By: michaelauli Differential Revision: D28596815 Pulled By: alexeib fbshipit-source-id: bb09d081d167d5d10968acc6e056044bf96679ac --- examples/speech_recognition/kaldi/__init__.py | 0 .../kaldi/add-self-loop-simple.cc | 94 +++ .../speech_recognition/kaldi/kaldi_decoder.py | 244 ++++++ .../kaldi/kaldi_initializer.py | 698 +++++++++++++++++ examples/wav2vec/__init__.py | 0 .../wav2vec/config/finetuning/base_960h.yaml | 2 +- examples/wav2vec/unsupervised/README.md | 148 +--- examples/wav2vec/unsupervised/__init__.py | 0 .../config/finetuning/w2v_finetune.yaml | 62 ++ .../wav2vec/unsupervised/config/gan/w2vu.yaml | 108 +++ .../unsupervised/config/generate/viterbi.yaml | 22 + .../wav2vec/unsupervised/data/__init__.py | 13 + .../data/extracted_features_dataset.py | 144 ++++
.../unsupervised/data/random_input_dataset.py | 62 ++ .../unsupervised/kaldi_self_train/README.md | 56 ++ .../unsupervised/kaldi_self_train/st/cmd.sh | 15 + .../kaldi_self_train/st/decode_phone.sh | 33 + .../kaldi_self_train/st/decode_word_step1.sh | 46 ++ .../kaldi_self_train/st/decode_word_step2.sh | 30 + .../st/local/copy_aligned_text.py | 4 + .../kaldi_self_train/st/local/decode.sh | 38 + .../st/local/prepare_data_from_w2v.py | 56 ++ .../kaldi_self_train/st/local/prepare_lang.sh | 37 + .../st/local/prepare_lang_word.sh | 35 + .../kaldi_self_train/st/local/prepare_lm.sh | 35 + .../kaldi_self_train/st/local/score.sh | 63 ++ .../kaldi_self_train/st/local/show_wer.sh | 52 ++ .../st/local/train_subset_lgbeam.sh | 129 ++++ .../kaldi_self_train/st/local/unsup_select.py | 135 ++++ .../st/local/unsup_select_decode.sh | 37 + .../st/local/unsup_select_decode_word.sh | 35 + .../unsupervised/kaldi_self_train/st/path.sh | 5 + .../unsupervised/kaldi_self_train/st/steps | 1 + .../st/steps_gan/train_deltas.sh | 175 +++++ .../st/steps_gan/train_lda_mllt.sh | 239 ++++++ .../st/steps_gan/train_sat.sh | 281 +++++++ .../unsupervised/kaldi_self_train/st/train.sh | 43 ++ .../unsupervised/kaldi_self_train/st/utils | 1 + .../wav2vec/unsupervised/models/__init__.py | 11 + .../wav2vec/unsupervised/models/wav2vec_u.py | 658 ++++++++++++++++ .../wav2vec/unsupervised/scripts/apply_pca.py | 72 ++ .../unsupervised/scripts/copy_labels.py | 10 + .../unsupervised/scripts/filter_lexicon.py | 40 + .../unsupervised/scripts/filter_tsv.py | 37 + .../unsupervised/scripts/g2p_wrd_to_phn.py | 41 + .../unsupervised/scripts/ltr_to_wrd.py | 16 + .../wav2vec/unsupervised/scripts/mean_pool.py | 90 +++ .../unsupervised/scripts/merge_clusters.py | 112 +++ .../scripts/normalize_and_filter_text.py | 57 ++ .../unsupervised/scripts/normalize_text.py | 22 + examples/wav2vec/unsupervised/scripts/pca.py | 53 ++ .../scripts/phonemize_with_sil.py | 83 ++ .../unsupervised/scripts/prepare_audio.sh | 57 ++ 
.../unsupervised/scripts/prepare_text.sh | 56 ++ .../unsupervised/scripts/remove_silence.py | 64 ++ examples/wav2vec/unsupervised/scripts/vads.py | 81 ++ .../scripts/wav2vec_apply_cluster_faiss.py | 111 +++ .../scripts/wav2vec_cluster_faiss.py | 210 ++++++ .../scripts/wav2vec_extract_features.py | 117 +++ examples/wav2vec/unsupervised/scripts/wer.py | 82 ++ .../unsupervised/scripts/wrd_to_ltr.py | 16 + .../wav2vec/unsupervised/tasks/__init__.py | 11 + .../unsupervised/tasks/unpaired_audio_text.py | 437 +++++++++++ .../wav2vec/unsupervised/w2vu_generate.py | 706 ++++++++++++++++++ examples/wav2vec/wav2vec_manifest.py | 16 +- fairseq/data/audio/raw_audio_dataset.py | 18 +- fairseq/data/data_utils.py | 4 + fairseq/logging/meters.py | 32 + fairseq/logging/metrics.py | 24 +- fairseq/models/distributed_fairseq_model.py | 14 +- fairseq/models/wav2vec/wav2vec2.py | 133 +++- fairseq/models/wav2vec/wav2vec2_asr.py | 88 ++- fairseq/optim/adam.py | 10 +- fairseq/options.py | 2 + fairseq/sequence_generator.py | 18 +- fairseq/tasks/audio_pretraining.py | 15 +- fairseq/utils.py | 48 +- fairseq_cli/eval_lm.py | 2 +- fairseq_cli/hydra_train.py | 17 +- fairseq_cli/preprocess.py | 3 + fairseq_cli/train.py | 7 +- fairseq_cli/validate.py | 1 + 82 files changed, 6625 insertions(+), 255 deletions(-) create mode 100644 examples/speech_recognition/kaldi/__init__.py create mode 100644 examples/speech_recognition/kaldi/add-self-loop-simple.cc create mode 100644 examples/speech_recognition/kaldi/kaldi_decoder.py create mode 100644 examples/speech_recognition/kaldi/kaldi_initializer.py create mode 100644 examples/wav2vec/__init__.py create mode 100644 examples/wav2vec/unsupervised/__init__.py create mode 100644 examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml create mode 100644 examples/wav2vec/unsupervised/config/gan/w2vu.yaml create mode 100644 examples/wav2vec/unsupervised/config/generate/viterbi.yaml create mode 100644 examples/wav2vec/unsupervised/data/__init__.py create 
mode 100644 examples/wav2vec/unsupervised/data/extracted_features_dataset.py create mode 100644 examples/wav2vec/unsupervised/data/random_input_dataset.py create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/README.md create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step1.sh create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh create mode 120000 examples/wav2vec/unsupervised/kaldi_self_train/st/steps create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh create mode 100755 
examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh create mode 100755 examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh create mode 100644 examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh create mode 120000 examples/wav2vec/unsupervised/kaldi_self_train/st/utils create mode 100644 examples/wav2vec/unsupervised/models/__init__.py create mode 100644 examples/wav2vec/unsupervised/models/wav2vec_u.py create mode 100644 examples/wav2vec/unsupervised/scripts/apply_pca.py create mode 100644 examples/wav2vec/unsupervised/scripts/copy_labels.py create mode 100644 examples/wav2vec/unsupervised/scripts/filter_lexicon.py create mode 100644 examples/wav2vec/unsupervised/scripts/filter_tsv.py create mode 100644 examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py create mode 100644 examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py create mode 100644 examples/wav2vec/unsupervised/scripts/mean_pool.py create mode 100644 examples/wav2vec/unsupervised/scripts/merge_clusters.py create mode 100644 examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py create mode 100644 examples/wav2vec/unsupervised/scripts/normalize_text.py create mode 100644 examples/wav2vec/unsupervised/scripts/pca.py create mode 100644 examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py create mode 100644 examples/wav2vec/unsupervised/scripts/prepare_audio.sh create mode 100644 examples/wav2vec/unsupervised/scripts/prepare_text.sh create mode 100644 examples/wav2vec/unsupervised/scripts/remove_silence.py create mode 100644 examples/wav2vec/unsupervised/scripts/vads.py create mode 100644 examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py create mode 100644 examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py create mode 100644 examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py create mode 100644 examples/wav2vec/unsupervised/scripts/wer.py create mode 100644 
examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py create mode 100644 examples/wav2vec/unsupervised/tasks/__init__.py create mode 100644 examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py create mode 100644 examples/wav2vec/unsupervised/w2vu_generate.py diff --git a/examples/speech_recognition/kaldi/__init__.py b/examples/speech_recognition/kaldi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/speech_recognition/kaldi/add-self-loop-simple.cc b/examples/speech_recognition/kaldi/add-self-loop-simple.cc new file mode 100644 index 0000000000..89754b925e --- /dev/null +++ b/examples/speech_recognition/kaldi/add-self-loop-simple.cc @@ -0,0 +1,94 @@ +/* +* Copyright (c) Facebook, Inc. and its affiliates. +* +* This source code is licensed under the MIT license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include <iostream> +#include "fstext/fstext-lib.h" // @manual +#include "util/common-utils.h" // @manual + +/* + * This program is to modify a FST without self-loop by: + * for each incoming arc with non-eps input symbol, add a self-loop arc + * with that non-eps symbol as input and eps as output. 
+ * + * This is to make sure the resultant FST can do deduplication for repeated + * symbols, which is very common in acoustic model + * + */ +namespace { +int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) { + typedef fst::MutableArcIterator<fst::StdVectorFst> IterType; + + int32 num_states_before = fst->NumStates(); + fst::MakePrecedingInputSymbolsSame(false, fst); + int32 num_states_after = fst->NumStates(); + KALDI_LOG << "There are " << num_states_before + << " states in the original FST; " + << " after MakePrecedingInputSymbolsSame, there are " + << num_states_after << " states " << std::endl; + + auto weight_one = fst::StdArc::Weight::One(); + + int32 num_arc_added = 0; + + fst::StdArc self_loop_arc; + self_loop_arc.weight = weight_one; + + int32 num_states = fst->NumStates(); + std::vector<std::set<int32>> incoming_non_eps_label_per_state(num_states); + + for (int32 state = 0; state < num_states; state++) { + for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) { + fst::StdArc arc(aiter.Value()); + if (arc.ilabel != 0) { + incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel); + } + } + } + + for (int32 state = 0; state < num_states; state++) { + if (!incoming_non_eps_label_per_state[state].empty()) { + auto& ilabel_set = incoming_non_eps_label_per_state[state]; + for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) { + self_loop_arc.ilabel = *it; + self_loop_arc.olabel = 0; + self_loop_arc.nextstate = state; + fst->AddArc(state, self_loop_arc); + num_arc_added++; + } + } + } + return num_arc_added; +} + +void print_usage() { + std::cout << "add-self-loop-simple usage:\n" + "\tadd-self-loop-simple <in-fst> <out-fst> \n"; +} +} // namespace + +int main(int argc, char** argv) { + if (argc != 3) { + print_usage(); + exit(1); + } + + auto input = argv[1]; + auto output = argv[2]; + + auto fst = fst::ReadFstKaldi(input); + auto num_states = fst->NumStates(); + KALDI_LOG << "Loading FST from " << input << " with " << num_states + 
<< " states." << std::endl; + + int32 num_arc_added = AddSelfLoopsSimple(fst); + KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl; + + fst::WriteFstKaldi(*fst, std::string(output)); + KALDI_LOG << "Writing FST to " << output << std::endl; + + delete fst; +} \ No newline at end of file diff --git a/examples/speech_recognition/kaldi/kaldi_decoder.py b/examples/speech_recognition/kaldi/kaldi_decoder.py new file mode 100644 index 0000000000..5f62cc58ae --- /dev/null +++ b/examples/speech_recognition/kaldi/kaldi_decoder.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from concurrent.futures import ThreadPoolExecutor +import logging +from omegaconf import MISSING +import os +import torch +from typing import Optional +import warnings + + +from dataclasses import dataclass +from fairseq.dataclass import FairseqDataclass +from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi + + +logger = logging.getLogger(__name__) + + +@dataclass +class KaldiDecoderConfig(FairseqDataclass): + hlg_graph_path: Optional[str] = None + output_dict: str = MISSING + + kaldi_initializer_config: Optional[KaldiInitializerConfig] = None + + acoustic_scale: float = 0.5 + max_active: int = 10000 + beam_delta: float = 0.5 + hash_ratio: float = 2.0 + + is_lattice: bool = False + lattice_beam: float = 10.0 + prune_interval: int = 25 + determinize_lattice: bool = True + prune_scale: float = 0.1 + max_mem: int = 0 + phone_determinize: bool = True + word_determinize: bool = True + minimize: bool = True + + num_threads: int = 1 + + +class KaldiDecoder(object): + def __init__( + self, + cfg: KaldiDecoderConfig, + beam: int, + nbest: int = 1, + ): + try: + from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer + from kaldi.base import set_verbose_level + from kaldi.decoder import 
( + FasterDecoder, + FasterDecoderOptions, + LatticeFasterDecoder, + LatticeFasterDecoderOptions, + ) + from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions + from kaldi.fstext import read_fst_kaldi, SymbolTable + except: + warnings.warn( + "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi" + ) + + # set_verbose_level(2) + + self.acoustic_scale = cfg.acoustic_scale + self.nbest = nbest + + if cfg.hlg_graph_path is None: + assert ( + cfg.kaldi_initializer_config is not None + ), "Must provide hlg graph path or kaldi initializer config" + cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config) + + assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path + + if cfg.is_lattice: + self.dec_cls = LatticeFasterDecoder + opt_cls = LatticeFasterDecoderOptions + self.rec_cls = LatticeFasterRecognizer + else: + assert self.nbest == 1, "nbest > 1 requires lattice decoder" + self.dec_cls = FasterDecoder + opt_cls = FasterDecoderOptions + self.rec_cls = FasterRecognizer + + self.decoder_options = opt_cls() + self.decoder_options.beam = beam + self.decoder_options.max_active = cfg.max_active + self.decoder_options.beam_delta = cfg.beam_delta + self.decoder_options.hash_ratio = cfg.hash_ratio + + if cfg.is_lattice: + self.decoder_options.lattice_beam = cfg.lattice_beam + self.decoder_options.prune_interval = cfg.prune_interval + self.decoder_options.determinize_lattice = cfg.determinize_lattice + self.decoder_options.prune_scale = cfg.prune_scale + det_opts = DeterminizeLatticePhonePrunedOptions() + det_opts.max_mem = cfg.max_mem + det_opts.phone_determinize = cfg.phone_determinize + det_opts.word_determinize = cfg.word_determinize + det_opts.minimize = cfg.minimize + self.decoder_options.det_opts = det_opts + + self.output_symbols = {} + with open(cfg.output_dict, "r") as f: + for line in f: + items = line.rstrip().split() + assert len(items) == 2 + self.output_symbols[int(items[1])] = items[0] + + 
logger.info(f"Loading FST from {cfg.hlg_graph_path}") + self.fst = read_fst_kaldi(cfg.hlg_graph_path) + self.symbol_table = SymbolTable.read_text(cfg.output_dict) + + self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads) + + def generate(self, models, sample, **unused): + """Generate a batch of inferences.""" + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" + } + emissions, padding = self.get_emissions(models, encoder_input) + return self.decode(emissions, padding) + + def get_emissions(self, models, encoder_input): + """Run encoder and normalize emissions""" + model = models[0] + + all_encoder_out = [m(**encoder_input) for m in models] + + if len(all_encoder_out) > 1: + + if "encoder_out" in all_encoder_out[0]: + encoder_out = { + "encoder_out": sum(e["encoder_out"] for e in all_encoder_out) + / len(all_encoder_out), + "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"], + } + padding = encoder_out["encoder_padding_mask"] + else: + encoder_out = { + "logits": sum(e["logits"] for e in all_encoder_out) + / len(all_encoder_out), + "padding_mask": all_encoder_out[0]["padding_mask"], + } + padding = encoder_out["padding_mask"] + else: + encoder_out = all_encoder_out[0] + padding = ( + encoder_out["padding_mask"] + if "padding_mask" in encoder_out + else encoder_out["encoder_padding_mask"] + ) + + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out, normalize=True) + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + + return ( + emissions.cpu().float().transpose(0, 1), + padding.cpu() if padding is not None and padding.any() else None, + ) + + def decode_one(self, logits, padding): + from kaldi.matrix import Matrix + + decoder = self.dec_cls(self.fst, self.decoder_options) + asr = self.rec_cls( + decoder, 
self.symbol_table, acoustic_scale=self.acoustic_scale + ) + + if padding is not None: + logits = logits[~padding] + + mat = Matrix(logits.numpy()) + + out = asr.decode(mat) + + if self.nbest > 1: + from kaldi.fstext import shortestpath + from kaldi.fstext.utils import ( + convert_compact_lattice_to_lattice, + convert_lattice_to_std, + convert_nbest_to_list, + get_linear_symbol_sequence, + ) + + lat = out["lattice"] + + sp = shortestpath(lat, nshortest=self.nbest) + + sp = convert_compact_lattice_to_lattice(sp) + sp = convert_lattice_to_std(sp) + seq = convert_nbest_to_list(sp) + + results = [] + for s in seq: + _, o, w = get_linear_symbol_sequence(s) + words = list(self.output_symbols[z] for z in o) + results.append( + { + "tokens": words, + "words": words, + "score": w.value, + "emissions": logits, + } + ) + return results + else: + words = out["text"].split() + return [ + { + "tokens": words, + "words": words, + "score": out["likelihood"], + "emissions": logits, + } + ] + + def decode(self, emissions, padding): + if padding is None: + padding = [None] * len(emissions) + + ret = list( + map( + lambda e, p: self.executor.submit(self.decode_one, e, p), + emissions, + padding, + ) + ) + return ret diff --git a/examples/speech_recognition/kaldi/kaldi_initializer.py b/examples/speech_recognition/kaldi/kaldi_initializer.py new file mode 100644 index 0000000000..6d2a2a4b6b --- /dev/null +++ b/examples/speech_recognition/kaldi/kaldi_initializer.py @@ -0,0 +1,698 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +import hydra +from hydra.core.config_store import ConfigStore +import logging +from omegaconf import MISSING, OmegaConf +import os +import os.path as osp +from pathlib import Path +import subprocess +from typing import Optional + +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import FairseqDataclass + +script_dir = Path(__file__).resolve().parent +config_path = script_dir / "config" + + +logger = logging.getLogger(__name__) + + +@dataclass +class KaldiInitializerConfig(FairseqDataclass): + data_dir: str = MISSING + fst_dir: Optional[str] = None + in_labels: str = MISSING + out_labels: Optional[str] = None + wav2letter_lexicon: Optional[str] = None + lm_arpa: str = MISSING + kaldi_root: str = MISSING + blank_symbol: str = "<s>" + silence_symbol: Optional[str] = None + + +def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path: + in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt" + if not in_units_file.exists(): + + logger.info(f"Creating {in_units_file}") + + with open(in_units_file, "w") as f: + print("<eps> 0", file=f) + i = 1 + for symb in vocab.symbols[vocab.nspecial :]: + if not symb.startswith("madeupword"): + print(f"{symb} {i}", file=f) + i += 1 + return in_units_file + + +def create_lexicon( + cfg: KaldiInitializerConfig, + fst_dir: Path, + unique_label: str, + in_units_file: Path, + out_words_file: Path, +) -> (Path, Path): + + disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt" + lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt" + disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt" + if ( + not lexicon_file.exists() + or not disambig_lexicon_file.exists() + or not disambig_in_units_file.exists() + ): + logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})") + + assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels + + if cfg.wav2letter_lexicon is not None: + lm_words 
= set() + with open(out_words_file, "r") as lm_dict_f: + for line in lm_dict_f: + lm_words.add(line.split()[0]) + + num_skipped = 0 + total = 0 + with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open( + lexicon_file, "w" + ) as out_f: + for line in w2l_lex_f: + items = line.rstrip().split("\t") + assert len(items) == 2, items + if items[0] in lm_words: + print(items[0], items[1], file=out_f) + else: + num_skipped += 1 + logger.debug( + f"Skipping word {items[0]} as it was not found in LM" + ) + total += 1 + if num_skipped > 0: + logger.warning( + f"Skipped {num_skipped} out of {total} words as they were not found in LM" + ) + else: + with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f: + for line in in_f: + symb = line.split()[0] + if symb != "<eps>" and symb != "<ctc_blank>" and symb != "<SIL>": + print(symb, symb, file=out_f) + + lex_disambig_path = ( + Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_lex_disambig.pl" + ) + res = subprocess.run( + [lex_disambig_path, lexicon_file, disambig_lexicon_file], + check=True, + capture_output=True, + ) + ndisambig = int(res.stdout) + disamib_path = Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_disambig.pl" + res = subprocess.run( + [disamib_path, "--include-zero", in_units_file, str(ndisambig)], + check=True, + capture_output=True, + ) + with open(disambig_in_units_file, "wb") as f: + f.write(res.stdout) + + return disambig_lexicon_file, disambig_in_units_file + + +def create_G( + kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str +) -> (Path, Path): + + out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt" + grammar_graph = fst_dir / f"G_{arpa_base}.fst" + if not grammar_graph.exists() or not out_words_file.exists(): + logger.info(f"Creating {grammar_graph}") + arpa2fst = kaldi_root / "src/lmbin/arpa2fst" + subprocess.run( + [ + arpa2fst, + "--disambig-symbol=#0", + f"--write-symbol-table={out_words_file}", + lm_arpa, + grammar_graph, + ], + check=True, + ) + return grammar_graph, 
out_words_file + + +def create_L( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + lexicon_file: Path, + in_units_file: Path, + out_words_file: Path, +) -> Path: + lexicon_graph = fst_dir / f"L.{unique_label}.fst" + + if not lexicon_graph.exists(): + logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})") + make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl" + fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile" + fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops" + fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" + + def write_disambig_symbol(file): + with open(file, "r") as f: + for line in f: + items = line.rstrip().split() + if items[0] == "#0": + out_path = str(file) + "_disamig" + with open(out_path, "w") as out_f: + print(items[1], file=out_f) + return out_path + + return None + + in_disambig_sym = write_disambig_symbol(in_units_file) + assert in_disambig_sym is not None + out_disambig_sym = write_disambig_symbol(out_words_file) + assert out_disambig_sym is not None + + try: + with open(lexicon_graph, "wb") as out_f: + res = subprocess.run( + [make_lex, lexicon_file], capture_output=True, check=True + ) + assert len(res.stderr) == 0, res.stderr.decode("utf-8") + res = subprocess.run( + [ + fstcompile, + f"--isymbols={in_units_file}", + f"--osymbols={out_words_file}", + "--keep_isymbols=false", + "--keep_osymbols=false", + ], + input=res.stdout, + capture_output=True, + ) + assert len(res.stderr) == 0, res.stderr.decode("utf-8") + res = subprocess.run( + [fstaddselfloops, in_disambig_sym, out_disambig_sym], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstarcsort, "--sort_type=olabel"], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(lexicon_graph) + raise + except AssertionError: + 
os.remove(lexicon_graph) + raise + + return lexicon_graph + + +def create_LG( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + lexicon_graph: Path, + grammar_graph: Path, +) -> Path: + lg_graph = fst_dir / f"LG.{unique_label}.fst" + + if not lg_graph.exists(): + logger.info(f"Creating {lg_graph}") + + fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" + fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" + fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" + fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial" + fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" + + try: + with open(lg_graph, "wb") as out_f: + res = subprocess.run( + [fsttablecompose, lexicon_graph, grammar_graph], + capture_output=True, + check=True, + ) + res = subprocess.run( + [ + fstdeterminizestar, + "--use-log=true", + ], + input=res.stdout, + capture_output=True, + ) + res = subprocess.run( + [fstminimizeencoded], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstpushspecial], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstarcsort, "--sort_type=ilabel"], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(lg_graph) + raise + + return lg_graph + + +def create_H( + kaldi_root: Path, + fst_dir: Path, + disambig_out_units_file: Path, + in_labels: str, + vocab: Dictionary, + blk_sym: str, + silence_symbol: Optional[str], +) -> (Path, Path, Path): + h_graph = ( + fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst" + ) + h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt" + disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int") + disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int") + if ( + not h_graph.exists() + or not 
h_out_units_file.exists() + or not disambig_in_units_file_int.exists() + ): + logger.info(f"Creating {h_graph}") + eps_sym = "<eps>" + + num_disambig = 0 + osymbols = [] + + with open(disambig_out_units_file, "r") as f, open( + disambig_out_units_file_int, "w" + ) as out_f: + for line in f: + symb, id = line.rstrip().split() + if line.startswith("#"): + num_disambig += 1 + print(id, file=out_f) + else: + if len(osymbols) == 0: + assert symb == eps_sym, symb + osymbols.append((symb, id)) + + i_idx = 0 + isymbols = [(eps_sym, 0)] + + imap = {} + + for i, s in enumerate(vocab.symbols): + i_idx += 1 + isymbols.append((s, i_idx)) + imap[s] = i_idx + + fst_str = [] + + node_idx = 0 + root_node = node_idx + + special_symbols = [blk_sym] + if silence_symbol is not None: + special_symbols.append(silence_symbol) + + for ss in special_symbols: + fst_str.append("{} {} {} {}".format(root_node, root_node, ss, eps_sym)) + + for symbol, _ in osymbols: + if symbol == eps_sym or symbol.startswith("#"): + continue + + node_idx += 1 + # 1. from root to emitting state + fst_str.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol)) + # 2. from emitting state back to root + fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym)) + # 3. from emitting state to optional blank state + pre_node = node_idx + node_idx += 1 + for ss in special_symbols: + fst_str.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym)) + # 4. 
from blank state back to root + fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym)) + + fst_str.append("{}".format(root_node)) + + fst_str = "\n".join(fst_str) + h_str = str(h_graph) + isym_file = h_str + ".isym" + + with open(isym_file, "w") as f: + for sym, id in isymbols: + f.write("{} {}\n".format(sym, id)) + + with open(h_out_units_file, "w") as f: + for sym, id in osymbols: + f.write("{} {}\n".format(sym, id)) + + with open(disambig_in_units_file_int, "w") as f: + disam_sym_id = len(isymbols) + for _ in range(num_disambig): + f.write("{}\n".format(disam_sym_id)) + disam_sym_id += 1 + + fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile" + fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops" + fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" + + try: + with open(h_graph, "wb") as out_f: + res = subprocess.run( + [ + fstcompile, + f"--isymbols={isym_file}", + f"--osymbols={h_out_units_file}", + "--keep_isymbols=false", + "--keep_osymbols=false", + ], + input=str.encode(fst_str), + capture_output=True, + check=True, + ) + res = subprocess.run( + [ + fstaddselfloops, + disambig_in_units_file_int, + disambig_out_units_file_int, + ], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstarcsort, "--sort_type=olabel"], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(h_graph) + raise + return h_graph, h_out_units_file, disambig_in_units_file_int + + +def create_HLGa( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + h_graph: Path, + lg_graph: Path, + disambig_in_words_file_int: Path, +) -> Path: + hlga_graph = fst_dir / f"HLGa.{unique_label}.fst" + + if not hlga_graph.exists(): + logger.info(f"Creating {hlga_graph}") + + fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" + fstdeterminizestar = 
kaldi_root / "src/fstbin/fstdeterminizestar" + fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols" + fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal" + fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" + + try: + with open(hlga_graph, "wb") as out_f: + res = subprocess.run( + [ + fsttablecompose, + h_graph, + lg_graph, + ], + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstdeterminizestar, "--use-log=true"], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmsymbols, disambig_in_words_file_int], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmepslocal], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstminimizeencoded], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(hlga_graph) + raise + + return hlga_graph + + +def create_HLa( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + h_graph: Path, + l_graph: Path, + disambig_in_words_file_int: Path, +) -> Path: + hla_graph = fst_dir / f"HLa.{unique_label}.fst" + + if not hla_graph.exists(): + logger.info(f"Creating {hla_graph}") + + fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" + fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" + fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols" + fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal" + fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" + + try: + with open(hla_graph, "wb") as out_f: + res = subprocess.run( + [ + fsttablecompose, + h_graph, + l_graph, + ], + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstdeterminizestar, "--use-log=true"], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmsymbols, 
disambig_in_words_file_int], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstrmepslocal], + input=res.stdout, + capture_output=True, + check=True, + ) + res = subprocess.run( + [fstminimizeencoded], + input=res.stdout, + capture_output=True, + check=True, + ) + out_f.write(res.stdout) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + os.remove(hla_graph) + raise + + return hla_graph + + +def create_HLG( + kaldi_root: Path, + fst_dir: Path, + unique_label: str, + hlga_graph: Path, + prefix: str = "HLG", +) -> Path: + hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst" + + if not hlg_graph.exists(): + logger.info(f"Creating {hlg_graph}") + + add_self_loop = script_dir / "add-self-loop-simple" + kaldi_src = kaldi_root / "src" + kaldi_lib = kaldi_src / "lib" + + try: + if not add_self_loop.exists(): + fst_include = kaldi_root / "tools/openfst-1.6.7/include" + add_self_loop_src = script_dir / "add-self-loop-simple.cc" + + subprocess.run( + [ + "c++", + f"-I{kaldi_src}", + f"-I{fst_include}", + f"-L{kaldi_lib}", + add_self_loop_src, + "-lkaldi-base", + "-lkaldi-fstext", + "-o", + add_self_loop, + ], + check=True, + ) + + my_env = os.environ.copy() + my_env["LD_LIBRARY_PATH"] = f"{kaldi_lib}:{my_env['LD_LIBRARY_PATH']}" + + subprocess.run( + [ + add_self_loop, + hlga_graph, + hlg_graph, + ], + check=True, + capture_output=True, + env=my_env, + ) + except subprocess.CalledProcessError as e: + logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") + raise + + return hlg_graph + + +def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path: + if cfg.fst_dir is None: + cfg.fst_dir = osp.join(cfg.data_dir, "kaldi") + if cfg.out_labels is None: + cfg.out_labels = cfg.in_labels + + kaldi_root = Path(cfg.kaldi_root) + data_dir = Path(cfg.data_dir) + fst_dir = Path(cfg.fst_dir) + fst_dir.mkdir(parents=True, exist_ok=True) + + arpa_base = 
osp.splitext(osp.basename(cfg.lm_arpa))[0] + unique_label = f"{cfg.in_labels}.{arpa_base}" + + with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f: + vocab = Dictionary.load(f) + + in_units_file = create_units(fst_dir, cfg.in_labels, vocab) + + grammar_graph, out_words_file = create_G( + kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base + ) + + disambig_lexicon_file, disambig_L_in_units_file = create_lexicon( + cfg, fst_dir, unique_label, in_units_file, out_words_file + ) + + h_graph, h_out_units_file, disambig_in_units_file_int = create_H( + kaldi_root, + fst_dir, + disambig_L_in_units_file, + cfg.in_labels, + vocab, + cfg.blank_symbol, + cfg.silence_symbol, + ) + lexicon_graph = create_L( + kaldi_root, + fst_dir, + unique_label, + disambig_lexicon_file, + disambig_L_in_units_file, + out_words_file, + ) + lg_graph = create_LG( + kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph + ) + hlga_graph = create_HLGa( + kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int + ) + hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph) + + # for debugging + # hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int) + # hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped") + # create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped") + + return hlg_graph + + +@hydra.main(config_path=config_path, config_name="kaldi_initializer") +def cli_main(cfg: KaldiInitializerConfig) -> None: + container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + cfg = OmegaConf.create(container) + OmegaConf.set_struct(cfg, True) + initalize_kaldi(cfg) + + +if __name__ == "__main__": + + logging.root.setLevel(logging.INFO) + logging.basicConfig(level=logging.INFO) + + try: + from hydra._internal.utils import ( + get_args, + ) # pylint: disable=import-outside-toplevel + + cfg_name = get_args().config_name or "kaldi_initializer" + 
except ImportError: + logger.warning("Failed to get config name from hydra args") + cfg_name = "kaldi_initializer" + + cs = ConfigStore.instance() + cs.store(name=cfg_name, node=KaldiInitializerConfig) + + cli_main() diff --git a/examples/wav2vec/__init__.py b/examples/wav2vec/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/wav2vec/config/finetuning/base_960h.yaml b/examples/wav2vec/config/finetuning/base_960h.yaml index e393805ad8..2d38211e91 100644 --- a/examples/wav2vec/config/finetuning/base_960h.yaml +++ b/examples/wav2vec/config/finetuning/base_960h.yaml @@ -31,7 +31,7 @@ criterion: optimization: max_update: 320000 - lr: [0.00001] + lr: [0.0001] sentence_avg: true optimizer: diff --git a/examples/wav2vec/unsupervised/README.md b/examples/wav2vec/unsupervised/README.md index fdfcc04d26..2277e65ffb 100644 --- a/examples/wav2vec/unsupervised/README.md +++ b/examples/wav2vec/unsupervised/README.md @@ -1,4 +1,3 @@ - # wav2vec Unsupervised (wav2vec-U) Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition systems without any labeled training data as described in [Unsupervised Speech Recognition (Baevski et al., 2021)](https://ai.facebook.com/research/publications/unsupervised-speech-recognition). The model takes as input wav2vec 2.0 or XLSR representations (see [pretrained models](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec)) as well as unlabeled speech and text data. @@ -8,135 +7,35 @@ Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition * Generative adversarial training (GAN) * Iterative self-training + Kaldi LM-decoding - ## Preparation of speech and text data Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md), data folders contain {train,valid,test}.{tsv,wrd,phn} files, where audio paths are stored in tsv files, and word, letter or phoneme transcriptions are stored in .{wrd,ltr,phn}. 
- -In **/path/to/data/with_silence** you need a *train.tsv* file as well as *{valid,test}.{tsv,wrd,phn}*. It is nice to have *10h.{tsv,phn}* files there too for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except *.tsv* files contain audios with silences removed using rVAD. +In **/path/to/data/with_silence** you need a *train.tsv* file as well as (optionally) *{valid,test}.{tsv,wrd,phn}*. It is nice to have *10h.{tsv,phn}* files there too for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except *.tsv* files contain audios with silences removed using rVAD. Here is how you can create new audio files without silences from a list of input audio files: -``` -python scripts/unsupervised/remove_silences.py /path/to/data/with_silence/train.tsv \ - --save-dir /path/to/data/without_silence/audio \ - --output /path/to/data/without_silence/train.tsv & -``` +```shell +python scripts/vads.py < /path/to/train.tsv > train.vads +python scripts/remove_silence.py --tsv /path/to/train.tsv --vads train.vads --out /dir/to/save/audio/files -In this first part, we use mostly phonemized text. Here is how you can transform a text file into its phonemized .phn version: +python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0 ``` -# Will phonemize word dictionary and then phonemize text using dict lookup (for language $lg) -python scripts/unsupervised/phonemize.py $path_to_wrd_dict $lg < text.wrd > text.phn & - -``` -Next, you can reproduce Figure 2/3 of the wav2vec-U paper by training linear models on top of each layer's frozen wav2vec 2.0 representations, using supervised data. You can observe that certain layers provide lower PER, which shows the closeness of their representations to phoneme outputs. Note that this step requires supervision and is thus not necessary. 
- -``` -# Learn linear model on top of layer N(=15) using supervised data - -fairseq-hydra-train \ - distributed_training.distributed_port=$PORT \ - task.data=/path/to/data/without_silence \ - model.w2v_path=/path/to/model.pt \ - model.layer=15 \ - --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ - --config-name vox_10h_phn -``` - - - -We can extract features of layer *N* using the following: -``` -# Extract features from layer N -split=train # valid test -python scripts/unsupervised/fb_wav2vec_ctc_filter.py \ - /path/to/data/without_silence \ - --split $split \ - --layer=15 \ - --checkpoint /path/to/model.pt \ - --save-dir /path/to/features & -``` - - - -Next we perform clustering of wav2vec representations (step 2 in the paper): -``` -# Identify clusters in the representations with k-means - -python scripts/unsupervised/fb_wav2vec_cluster_faiss.py \ - /path/to/data/train.tsv \ - -f "CLUS128" \ - --sample-pct 0.5 \ - --layer 15 \ - --checkpoint /path/to/model.pt \ - --save-dir /path/to/features/clustering/segmented & -``` - -And use those clusters to segment the audio data (step 3 in the paper): -``` -# Transcribe cluster ids of audio data -python scripts/unsupervised/fb_wav2vec_apply_cluster_faiss.py \ - /path/to/data \ - --split $split \ - --checkpoint /path/to/model.pt \ - --path /path/to/features/clustering/segmented/CLUS128 & -``` - - Learn and apply PCA to the representations to retain important features -``` -# Compute PCA -python scripts/pca.py \ - /path/to/features/unfiltered/train.npy \ - --dim 512 \ - --output /path/to/features/unfiltered/unfiltered_pca - -# Apply PCA -python scripts/apply_pca.py \ - $outdir \ - --split $split \ - --pca-path /path/to/features/unfiltered/unfiltered_pca/512_pca \ - --batch-size 1048000 \ - --save-dir /path/to/features/unfiltered/precompute_unfiltered_pca${dim} -``` +You will need to add the path to rVAD directory to vads.py. 
-Then we build segment representations by mean-pooling representations according to clusters: - - +Next, we need to preprocess the audio data to better match phonemized text data: +```shell +zsh scripts/prepare_audio.sh /dir/with/{train,test,valid}.tsv /output/dir /path/to/wav2vec2/model.pt ``` -# Build segment representations - -python scripts/unsupervised/merge_clusters.py \ - /path/to/features/unfiltered/precompute_unfiltered_pca512 \ - --split $split \ - --cluster-dir /path/to/features/clustering/segmented/CLUS128 \ - --pooling mean \ - --save-dir /path/to/features/unfiltered/precompute_unfiltered_pca512_cls128_mean & +Note that if you have splits different than train/valid/test, you will need to modify this script. +Now we need to prepare text data: +```shell +zsh scripts/prepare_text.sh language /path/to/text/file /output/dir ``` -Finally, we found that segment boundaries are noisy due to the lack of supervision and we therefore found it useful to also mean-pool pairs of adjacent segment representations to increase robustness: -``` -# Mean-pool adjacent time steps -python scripts/unsupervised/mean_pool.py \ - /path/to/features/unfiltered/precompute_unfiltered_pca512_cls128_mean \ - --split $split & - --save-dir $savedir & -``` - -For adversarial training, we preprocess the text data by adding silence tokens. -``` -# Add <SIL> tokens on text in preparation for GAN training -python scripts/unsupervised/fb_wrd_to_phonemizer.py \ - -s 0.25 --surround < /path/to/data/gan.txt > /path/to/data/gan.txt_s0.25.phns & - -# Binarize with fairseq-preprocess -fairseq-preprocess --dataset-impl mmap \ - --trainpref /path/to/data/gan.txt_s0.25.phns \ - --workers 6 --thresholdsrc 0 --only-source \ - --destdir /path/to/data --srcdict /path/to/data/dict.phn.txt & -``` +Note that if you want to use a different phonemizer, such as G2P, you will need to modify this script. 
## Generative adversarial training (GAN) @@ -152,24 +51,25 @@ TEXT_DATA=/path/to/data # path to fairseq-preprocessed GAN data KENLM_PATH=/path/to/data/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here) PREFIX=$PREFIX fairseq-hydra-train \ - distributed_training.distributed_port=$PORT \ - -m --config-dir configs/unsup \ - --config-name gan_feats_by_label \ - dataset.valid_subset=valid \ + -m --config-dir configs/gan \ + --config-name w2vu \ task.data=${TASK_DATA} \ task.text_data=${TEXT_DATA} \ task.kenlm_path=${KENLM_PATH} \ 'common.seed=range(0,5)' & ``` -However, this step requires an hyperparameter search, which can be launched with: -``` -``` - -Note that hyperparameter search and model/epoch selection are done using a fully unsupervised metric (see Section 4.3). +Once we find the best checkpoint (chosen using unsupervised metric that combined language model perplexity and vocabulary usage), we can use it to generate phone labels (or word labels with an appropriate kaldi WFST): +```shell +python w2vu_generate.py --config-dir config/generate --config-name viterbi \ +fairseq.task.data=/path/to/dir/with/tsvs fairseq.common_eval.path=/path/to/gan/checkpoint \ +fairseq.dataset.gen_subset=valid results_path=/where/to/save/transcriptions +``` ## Iterative self-training + Kaldi LM-decoding After the GAN training provides a first unsupervised model, we can then progressively refine the quality of transcriptions using several iterations of semi-supervised learning. We perform two iterations: first, pseudo-label the training data with the unsupervised GAN model and train an HMM on the pseudo-labels. Second, we relabel the training data with the HMM and then fine-tune the original wav2vec 2.0 model using the HMM pseudo-labels with a CTC loss. Note that HMM models use phonemes as output, while wav2vec 2.0 use letter. Both are decoded using WFST decoders into words. 
Please see [this README](http://github.com/pytorch/fairseq/tree/master/examples/wav2vec/unsupervised/kaldi_st) for more instructions on how to do iterative self-training + Kaldi LM-decoding. + +*** Note: these instructions are a work in progress and will be updated over the next few days diff --git a/examples/wav2vec/unsupervised/__init__.py b/examples/wav2vec/unsupervised/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml b/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml new file mode 100644 index 0000000000..e94da2ba4e --- /dev/null +++ b/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml @@ -0,0 +1,62 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + +checkpoint: + no_epoch_checkpoints: true + save_interval_updates: 20000 + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 800000 + skip_invalid_size_inputs_valid_test: true + train_subset: train + valid_subset: valid + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + find_unused_parameters: True + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + +optimization: + max_update: 80000 + lr: [0.00003] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.25 + mask_channel_prob: 0.1 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 diff --git a/examples/wav2vec/unsupervised/config/gan/w2vu.yaml b/examples/wav2vec/unsupervised/config/gan/w2vu.yaml new file mode 100644 index 0000000000..d168a11e19 --- /dev/null +++ b/examples/wav2vec/unsupervised/config/gan/w2vu.yaml @@ -0,0 +1,108 @@ +# @package _group_ + +common: + fp16: false + fp16_no_flatten_grads: true + log_format: json + log_interval: 100 + tensorboard_logdir: tb + reset_logging: false + suppress_crashes: false + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: weighted_lm_ppl + save_dir: . + +task: + _name: unpaired_audio_text + data: ??? + text_data: ??? + labels: phn + sort_by_length: false + unfiltered: false + max_length: null + append_eos: false + kenlm_path: ??? + +dataset: + num_workers: 6 + batch_size: 160 + skip_invalid_size_inputs_valid_test: true + valid_subset: valid + +criterion: + _name: model + log_keys: + - accuracy_dense + - accuracy_token + - temp + - code_ppl + +optimization: + max_update: 150000 + clip_norm: 5.0 + lr: [0] + +optimizer: + _name: composite + groups: + generator: + lr: [0.0004] + lr_float: null + optimizer: + _name: adam + adam_betas: [0.5,0.98] + adam_eps: 1e-06 + weight_decay: 0 + amsgrad: false + lr_scheduler: + _name: fixed + warmup_updates: 0 + discriminator: + lr: [ 0.0005 ] + lr_float: null + optimizer: + _name: adam + adam_betas: [0.5,0.98] + adam_eps: 1e-06 + weight_decay: 0.0001 + amsgrad: false + lr_scheduler: + _name: fixed + warmup_updates: 0 + +lr_scheduler: pass_through + +model: + _name: wav2vec_u + + discriminator_dim: 384 + discriminator_depth: 2 + discriminator_kernel: 6 + discriminator_linear_emb: false + discriminator_causal: true + discriminator_max_pool: false + discriminator_act_after_linear: false + discriminator_dropout: 0.0 + discriminator_weight_norm: false + + 
generator_stride: 1 + generator_kernel: 4 + generator_bias: false + generator_dropout: 0.1 + + smoothness_weight: 0.5 + smoothing: 0 + smoothing_one_sided: false + gumbel: false + hard_gumbel: false + gradient_penalty: 1.5 + code_penalty: 4.0 + temp: [ 2,0.1,0.99995 ] + input_dim: 512 + + segmentation: + type: JOIN + mean_pool_join: false + remove_zeros: false diff --git a/examples/wav2vec/unsupervised/config/generate/viterbi.yaml b/examples/wav2vec/unsupervised/config/generate/viterbi.yaml new file mode 100644 index 0000000000..0f850bb3e7 --- /dev/null +++ b/examples/wav2vec/unsupervised/config/generate/viterbi.yaml @@ -0,0 +1,22 @@ +# @package _group_ + +fairseq: + task: + _name: unpaired_audio_text + labels: phn + data: ??? + sort_by_length: false + shuffle: false + text_data: '' + + common_eval: + path: ??? + quiet: true + + dataset: + gen_subset: valid + batch_size: 1 + +w2l_decoder: VITERBI +lm_model: ??? +post_process: silence diff --git a/examples/wav2vec/unsupervised/data/__init__.py b/examples/wav2vec/unsupervised/data/__init__.py new file mode 100644 index 0000000000..d0545627ef --- /dev/null +++ b/examples/wav2vec/unsupervised/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .extracted_features_dataset import ExtractedFeaturesDataset +from .random_input_dataset import RandomInputDataset + + +__all__ = [ + "ExtractedFeaturesDataset", + "RandomInputDataset", +] diff --git a/examples/wav2vec/unsupervised/data/extracted_features_dataset.py b/examples/wav2vec/unsupervised/data/extracted_features_dataset.py new file mode 100644 index 0000000000..d6ee9c4a36 --- /dev/null +++ b/examples/wav2vec/unsupervised/data/extracted_features_dataset.py @@ -0,0 +1,144 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +import contextlib + +import numpy as np +import torch + +from fairseq.data import FairseqDataset, data_utils + + +logger = logging.getLogger(__name__) + + +class ExtractedFeaturesDataset(FairseqDataset): + def __init__( + self, + path, + split, + min_length=3, + max_length=None, + labels=None, + label_dict=None, + shuffle=True, + sort_by_length=True, + ): + super().__init__() + + self.min_length = min_length + self.max_length = max_length + self.shuffle = shuffle + self.sort_by_length = sort_by_length + self.label_dict = label_dict + + if labels is not None: + assert label_dict is not None + + self.sizes = [] + self.offsets = [] + self.labels = [] + + path = os.path.join(path, split) + data_path = path + self.data = np.load(data_path + ".npy", mmap_mode="r") + + offset = 0 + skipped = 0 + + if not os.path.exists(path + f".{labels}"): + labels = None + + with open(data_path + ".lengths", "r") as len_f, open( + path + f".{labels}", "r" + ) if labels is not None else contextlib.ExitStack() as lbl_f: + for line in len_f: + length = int(line.rstrip()) + lbl = None if labels is None else next(lbl_f).rstrip().split() + if length >= min_length and ( + max_length is None or length <= max_length + ): + self.sizes.append(length) + self.offsets.append(offset) + if lbl is not None: + self.labels.append(lbl) + offset += length + + self.sizes = np.asarray(self.sizes) + self.offsets = np.asarray(self.offsets) + + logger.info(f"loaded {len(self.offsets)}, skipped {skipped} samples") + + def __getitem__(self, index): + offset = self.offsets[index] + end = self.sizes[index] + offset + feats = torch.from_numpy(self.data[offset:end].copy()).float() + + res = {"id": index, "features": feats} + if len(self.labels) > 0: + res["target"] = self.label_dict.encode_line( + self.labels[index], + line_tokenizer=lambda x: x, + 
append_eos=False, + ) + + return res + + def __len__(self): + return len(self.sizes) + + def collater(self, samples): + if len(samples) == 0: + return {} + + features = [s["features"] for s in samples] + sizes = [len(s) for s in features] + + target_size = max(sizes) + + collated_features = features[0].new_zeros( + len(features), target_size, features[0].size(-1) + ) + padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False) + for i, (f, size) in enumerate(zip(features, sizes)): + collated_features[i, :size] = f + padding_mask[i, size:] = True + + res = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": {"features": collated_features, "padding_mask": padding_mask}, + } + + if len(self.labels) > 0: + target = data_utils.collate_tokens( + [s["target"] for s in samples], + pad_idx=self.label_dict.pad(), + left_pad=False, + ) + res["target"] = target + return res + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + if self.sort_by_length: + order.append(self.sizes) + return np.lexsort(order)[::-1] + else: + return order[0] diff --git a/examples/wav2vec/unsupervised/data/random_input_dataset.py b/examples/wav2vec/unsupervised/data/random_input_dataset.py new file mode 100644 index 0000000000..886505616c --- /dev/null +++ b/examples/wav2vec/unsupervised/data/random_input_dataset.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import random +from typing import List + +from fairseq.data import BaseWrapperDataset, data_utils + + +class RandomInputDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + random_input_dataset, + input_key_path: List[str], + add_to_input, + pad_idx, + ): + super().__init__(dataset) + self.random_input_dataset = random_input_dataset + if isinstance(input_key_path, str): + input_key_path = [input_key_path] + assert len(input_key_path) > 0 + self.input_key_path = input_key_path + self.add_to_input = add_to_input + self.pad_idx = pad_idx + + def get_target(self, item): + target_loc = item + for p in self.input_key_path[:-1]: + target_loc = target_loc[p] + return self.input_key_path[-1], target_loc + + def get_target_value(self, item): + k, target_loc = self.get_target(item) + return target_loc[k] + + def __getitem__(self, index): + item = self.dataset[index] + k, target_loc = self.get_target(item) + target_loc[k] = random.choice(self.random_input_dataset) + return item + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + + random_inputs = data_utils.collate_tokens( + [self.get_target_value(s) for s in samples if s["id"] in indices], + pad_idx=self.pad_idx, + left_pad=False, + ) + k, target_loc = self.get_target( + collated if not self.add_to_input else collated["net_input"] + ) + target_loc[k] = random_inputs + + return collated diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/README.md b/examples/wav2vec/unsupervised/kaldi_self_train/README.md new file mode 100644 index 0000000000..314984fcbb --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/README.md @@ -0,0 +1,56 @@ +# Self-Training with Kaldi HMM Models +This folder contains recipes for self-training on pseudo phone transcripts and +decoding into phones or words with [kaldi](https://github.com/kaldi-asr/kaldi). 
+ +To start, download and install kaldi following its instructions, and place this +folder in `path/to/kaldi/egs`. + +## Training +Assuming the following has been prepared: +- `w2v_dir`: contains features `{train,valid}.{npy,lengths}`, real transcripts `{train,valid}.${label}`, and dict `dict.${label}.txt` +- `lab_dir`: contains pseudo labels `{train,valid}.txt` +- `arpa_lm`: Arpa-format n-gram phone LM for decoding +- `arpa_lm_bin`: Arpa-format n-gram phone LM for unsupervised model selection to be used with KenLM + +Set these variables in `train.sh`, as well as `out_dir`, the output directory, +and then run it. + +The output will be: +``` +==== WER w.r.t. real transcript (select based on unsupervised metric) +INFO:root:./out/exp/mono/decode_valid/scoring/14.0.0.tra.txt: score 0.9178 wer 28.71% lm_ppl 24.4500 gt_wer 25.57% +INFO:root:./out/exp/tri1/decode_valid/scoring/17.1.0.tra.txt: score 0.9257 wer 26.99% lm_ppl 30.8494 gt_wer 21.90% +INFO:root:./out/exp/tri2b/decode_valid/scoring/8.0.0.tra.txt: score 0.7506 wer 23.15% lm_ppl 25.5944 gt_wer 15.78% +``` +where `wer` is the word error rate with respect to the pseudo label, `gt_wer` to +the ground truth label, `lm_ppl` the language model perplexity of HMM-predicted +transcripts, and `score` is the unsupervised metric for model selection. We +choose the model and the LM parameter of the one with the lowest score. In the +example above, it is `tri2b`, `8.0.0`. + + +## Decoding into Phones +In `decode_phone.sh`, set `out_dir` the same as used in `train.sh`, set +`dec_exp` and `dec_lmparam` to the selected model and LM parameter (e.g. +`tri2b` and `8.0.0` in the above example). `dec_script` needs to be set +according to `dec_exp`: for mono/tri1/tri2b, use `decode.sh`; for tri3b, use +`decode_fmllr.sh`. + +The output will be saved at `out_dir/dec_data` + + +## Decoding into Words +`decode_word_step1.sh` prepares WFSTs for word decoding.
Besides the variables +mentioned above, set +- `wrd_arpa_lm`: Arpa-format n-gram word LM for decoding +- `wrd_arpa_lm_bin`: Arpa-format n-gram word LM for unsupervised model selection + +`decode_word_step1.sh` decodes the `train` and `valid` splits into words and runs +unsupervised model selection using the `valid` split. The output looks like: +``` +INFO:root:./out/exp/tri2b/decodeword_valid/scoring/17.0.0.tra.txt: score 1.8693 wer 24.97% lm_ppl 1785.5333 gt_wer 31.45% +``` + +After determining the LM parameter (`17.0.0` in the example above), set it in +`decode_word_step2.sh` and run it. The output will be saved at +`out_dir/dec_data_word`. diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh new file mode 100644 index 0000000000..e74953194d --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh @@ -0,0 +1,15 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.
+ +export train_cmd="run.pl --mem 2G" +export decode_cmd="run.pl --mem 4G" +export mkgraph_cmd="run.pl --mem 8G" diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh new file mode 100644 index 0000000000..947342a0b7 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# decode into phones (and prepare a new data directory for HMM outputs) + +. ./path.sh + +set -eu + +out_dir= # same as in train.sh +dec_lmparam= # LM hyperparameters (e.g., 7.0.0) +dec_exp= +dec_script= +dec_splits="train valid" +dec_data_dir=$out_dir/dec_data # where to write HMM output + +data_dir=${out_dir}/data + +local/decode.sh --nj 40 --graph_name graph \ + --val_sets "$dec_splits" --decode_script $dec_script \ + $out_dir/exp/$dec_exp $data_dir $data_dir/lang_test + +if [ ! -z $dec_lmparam ]; then + for x in $dec_splits; do + mkdir -p $dec_data_dir/$x + cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/ + + tra=$out_dir/exp/$dec_exp/decode_${x}/scoring/${dec_lmparam}.tra + cat $tra | utils/int2sym.pl -f 2- $data_dir/lang/words.txt | \ + sed 's:<UNK>::g' | sed 's:<SIL>::g' > $dec_data_dir/${x}/text + utils/fix_data_dir.sh $dec_data_dir/${x} + echo "WER on ${x} is" $(compute-wer ark:$data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-) + done +fi diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step1.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step1.sh new file mode 100644 index 0000000000..c1276bbe4d --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step1.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# prepare word WFSTs, reference data, and decode + +set -eu + +w2v_dir= # same as in train.sh +out_dir= # same as in train.sh +lexicon= # word to phone mapping +wrd_arpa_lm= # word LM +wrd_arpa_lm_bin= # word LM for KenLM, used in 
unsupervised selection + +dec_exp= # what HMM stage to decode (e.g., tri3b) +dec_script= # what decoding script to use (e.g., steps/decode_fmllr.sh) +phn_label=phnc +wrd_label=wrd +dec_suffix=word +dec_splits="train valid" +valid_split="valid" + +data_dir=$out_dir/data +wrd_data_dir=$out_dir/data_word + +lexicon_clean=$(mktemp) +cat $lexicon | sort | uniq > $lexicon_clean +local/prepare_lang_word.sh $w2v_dir/dict.${phn_label}.txt $data_dir $lexicon_clean && rm $lexicon_clean +local/prepare_lm.sh --langdir $data_dir/lang_word --lmdir $data_dir/lang_test_word $wrd_arpa_lm $data_dir + +for x in $dec_splits; do + x_gt=${x}_gt + mkdir -p $wrd_data_dir/$x_gt + cp $data_dir/$x_gt/{feats.scp,cmvn.scp,utt2spk,spk2utt} $wrd_data_dir/$x_gt/ + python local/copy_aligned_text.py < $w2v_dir/$x.$wrd_label > $wrd_data_dir/$x_gt/text +done + +local/decode.sh --nj 40 --graph_name graph${dec_suffix} --decode_suffix $dec_suffix \ + --val_sets "$dec_splits" --decode_script $dec_script \ + $out_dir/exp/$dec_exp $data_dir $data_dir/lang_test_word + +local/unsup_select_decode_word.sh \ + --split $valid_split --kenlm_path $wrd_arpa_lm_bin \ + --ref_txt $wrd_data_dir/${valid_split}_gt/text \ + --psd_txt $data_dir/${valid_split}/text \ + --dec_name decode${dec_suffix} --graph_name graph${dec_suffix} \ + --phonemize_lexicon $data_dir/local/dict_word/lexicon.txt \ + $out_dir/exp diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh new file mode 100644 index 0000000000..59a6cbb125 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +# prepare a new data directory of HMM word output + +. 
./path.sh + +set -eu + +out_dir= # same as in train.sh +dec_lmparam= # LM hyperparameters (e.g., 7.0.0) + +dec_exp=tri3b # what HMM stage to decode (e.g., tri3b) +dec_suffix=word +dec_splits="train valid" +dec_data_dir=$out_dir/dec_data_word # where to write HMM output + +data_dir=$out_dir/data +wrd_data_dir=$out_dir/data_word + +for x in $dec_splits; do + mkdir -p $dec_data_dir/$x + cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/ + + tra=$out_dir/exp/$dec_exp/decode${dec_suffix}_${x}/scoring/${dec_lmparam}.tra + cat $tra | utils/int2sym.pl -f 2- $data_dir/lang_word/words.txt | \ + sed 's:<UNK>::g' | sed 's:<SIL>::g' > $dec_data_dir/$x/text + utils/fix_data_dir.sh $dec_data_dir/$x + echo "WER on $x is" $(compute-wer ark:$wrd_data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-) +done + diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py new file mode 100644 index 0000000000..5f4faa9921 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py @@ -0,0 +1,4 @@ +import sys + +for idx, line in enumerate(sys.stdin): + print(f"utt{idx:010d} {line}", end='') \ No newline at end of file diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh new file mode 100755 index 0000000000..811cb63c88 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +set -u + +val_sets="dev_other" +graph_name=graph +decode_suffix="" +decode_script="steps/decode_fmllr.sh" +decode_args="" +nj=60 + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +set -x +exp_dir=$1 +data_root=$2 +lang_test=$3 + +graph=$exp_dir/$graph_name + +if [ ! 
-d $graph ]; then + utils/mkgraph.sh $lang_test $exp_dir $graph +fi + +for part in $val_sets; do + dec_dir=$exp_dir/decode${decode_suffix}_${part} + if [ ! -d $dec_dir ]; then + echo "decoding $part for $exp_dir" + $decode_script --nj $nj --cmd "$decode_cmd" $decode_args \ + $graph $data_root/$part $dec_dir & + else + echo "$dec_dir exists. skip" + fi +done + +wait diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py new file mode 100644 index 0000000000..66954ea5c9 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py @@ -0,0 +1,56 @@ +import kaldi_io +import numpy as np +import os + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("w2v_dir", help="wav2vec feature and text directory") + parser.add_argument("tar_root", help="output data directory in kaldi's format") + parser.add_argument("split", help="name of the subset") + parser.add_argument("--label", default="", help="if specified, copy labels too") + return parser + +def main(): + parser = get_parser() + args = parser.parse_args() + + tar_dir = os.path.join(args.tar_root, args.split) + os.makedirs(tar_dir, exist_ok=True) + + lengths_path = os.path.join(args.w2v_dir, f"{args.split}.lengths") + with open(lengths_path) as f: + lengths = [int(line.rstrip()) for line in f] + offsets = [0] + np.cumsum(lengths[:-1]).tolist() + feats = np.load( + os.path.join(args.w2v_dir, f"{args.split}.npy"), + mmap_mode="r" + ) + assert feats.shape[0] == sum(lengths), \ + f"lengths mismatch {feats.shape[0]} != {sum(lengths)}" + + ark_path = os.path.join(tar_dir, "feats.ark") + scp_path = os.path.join(tar_dir, "feats.scp") + wspec = f"ark:| copy-feats --compress=true ark:- ark,scp:{ark_path},{scp_path}" + with kaldi_io.open_or_fd(wspec, "wb") as f: + for idx, (offset, length) in enumerate(zip(offsets, lengths)): + 
feat = feats[offset:offset+length] + kaldi_io.write_mat(f, feat, key=f"utt{idx:010d}") + + u2s_path = os.path.join(tar_dir, "utt2spk") + s2u_path = os.path.join(tar_dir, "spk2utt") + with open(u2s_path, "w") as f_u2s, open(s2u_path, "w") as f_s2u: + for idx in range(len(lengths)): + f_u2s.write(f"utt{idx:010d} utt{idx:010d}\n") + f_s2u.write(f"utt{idx:010d} utt{idx:010d}\n") + + if bool(args.label): + lab_path = os.path.join(args.w2v_dir, f"{args.split}.{args.label}") + txt_path = os.path.join(tar_dir, "text") + with open(lab_path) as f_lab, open(txt_path, "w") as f_txt: + for idx, line in enumerate(f_lab): + f_txt.write(f"utt{idx:010d} {line}") + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh new file mode 100755 index 0000000000..e9a80001eb --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +sil_prob=0.5 +num_sil_states=3 +num_nonsil_states=1 + +. ./cmd.sh +. ./path.sh +. 
parse_options.sh + +set -eux + +dict=$1 +data_dir=$2 + +dict_dir=$data_dir/local/dict +tmplm_dir=$data_dir/local/lang_tmp +lm_dir=$data_dir/lang + +mkdir -p $dict_dir $tmplm_dir $lm_dir + +# prepare dict +echo "SIL" > $dict_dir/silence_phones.txt +echo "SIL" > $dict_dir/optional_silence.txt +awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt + +echo "SIL SIL" > $dict_dir/lexicon.txt +echo "<UNK> SIL" >> $dict_dir/lexicon.txt +awk '{print $1" "$1}' $dict >> $dict_dir/lexicon.txt + +echo "SIL" > $dict_dir/extra_questions.txt +awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt + +# prepare lang +utils/prepare_lang.sh --sil-prob $sil_prob --position-dependent-phones false \ + --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \ + $dict_dir "<UNK>" $tmplm_dir $lm_dir diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh new file mode 100755 index 0000000000..a7ea3877be --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +num_sil_states=3 +num_nonsil_states=1 + +. ./cmd.sh +. ./path.sh +. 
parse_options.sh + +set -eux + +dict=$1 +data_dir=$2 +lexicon=$3 + +dict_dir=$data_dir/local/dict_word +tmplm_dir=$data_dir/local/lang_tmp_word +lm_dir=$data_dir/lang_word + +mkdir -p $dict_dir $tmplm_dir $lm_dir + +# prepare dict +echo "SIL" > $dict_dir/silence_phones.txt +echo "SIL" > $dict_dir/optional_silence.txt +awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt + +(echo "!SIL SIL"; echo "<UNK> SIL";) | cat - $lexicon > $dict_dir/lexicon.txt + +echo "SIL" > $dict_dir/extra_questions.txt +awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt + +# prepare lang +utils/prepare_lang.sh --position-dependent-phones false \ + --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \ + $dict_dir "<UNK>" $tmplm_dir $lm_dir diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh new file mode 100755 index 0000000000..c2edcefede --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +langdir="" +lmdir="" + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +arpa_lm=$1 +data=$2 + +if [ -z $langdir ]; then + langdir=$data/lang +fi +if [ -z $lmdir ]; then + lmdir=$data/lang_test +fi + +if [ ! -d $langdir ]; then + echo "$langdir not found. 
run local/prepare_lang.sh first" && exit 1 +fi + +mkdir -p $lmdir +cp -r $langdir/* $lmdir + +if [[ "$arpa_lm" == *.gz ]]; then + gunzip -c $arpa_lm | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt - $lmdir/G.fst +else + arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt $arpa_lm $lmdir/G.fst +fi +fstisstochastic $lmdir/G.fst +utils/validate_lang.pl $lmdir || exit 1 + +echo "done preparing lm ($lmdir)" diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh new file mode 100755 index 0000000000..cb5bbb7277 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# 2014 Guoguo Chen +# Apache 2.0 + +[ -f ./path.sh ] && . ./path.sh + +# begin configuration section. +cmd=run.pl +stage=0 +decode_mbr=true +word_ins_penalty=0.0,0.5,1.0 +min_lmwt=7 +max_lmwt=17 +iter=final +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] <data-dir> <lang-dir|graph-dir> <decode-dir>" + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --decode_mbr (true/false) # minimum Bayes risk decoding (confusion network)." + echo " --min_lmwt <int> # minimum LM-weight for lattice rescoring " + echo " --max_lmwt <int> # maximum LM-weight for lattice rescoring " + exit 1; +fi + +data=$1 +lang_or_graph=$2 +dir=$3 + +symtab=$lang_or_graph/words.txt + +for f in $symtab $dir/lat.1.gz $data/text; do + [ !
-f $f ] && echo "score.sh: no such file $f" && exit 1; +done + +mkdir -p $dir/scoring/log + +cat $data/text | sed 's:<NOISE>::g' | sed 's:<SPOKEN_NOISE>::g' > $dir/scoring/test_filt.txt + +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.$wip.log \ + lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ + lattice-best-path --word-symbol-table=$symtab \ + ark:- ark,t:$dir/scoring/LMWT.$wip.tra || exit 1; +done + +# Note: the double level of quoting for the sed command +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.$wip.log \ + cat $dir/scoring/LMWT.$wip.tra \| \ + utils/int2sym.pl -f 2- $symtab \| sed 's:\<UNK\>::g' \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1; +done + +exit 0; diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh new file mode 100755 index 0000000000..9ecf1690c6 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +split="dev_other" +ref_data="" +get_best_wer=true +dec_name="decode" +graph_name="graph" + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +exp_root=$1 + +set -eu + +echo "==== WER w.r.t. pseudo transcript" +for x in $exp_root/*/${dec_name}_${split}*; do grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh; done + + +if [ ! -z $ref_data ]; then + echo "==== WER w.r.t. 
real transcript (select based on pseudo WER)" + ref_txt=$ref_data/$split/text + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + lmwt=$( + grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh | + sed 's/.*wer_\(.*\)$/\1/g' | sed 's/_/./g' + ) + tra=$x/scoring/$lmwt.tra + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' | \ + compute-wer --text --mode=present \ + ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra + done +fi + +if [ ! -z $ref_data ] && $get_best_wer; then + echo "==== WER w.r.t. real transcript (select based on true WER)" + ref_txt=$ref_data/$split/text + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + for tra in $x/scoring/*.tra; do + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' | \ + compute-wer --text --mode=present \ + ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra + done | sort -k2n | head -n1 + done +fi + +exit 0; diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh new file mode 100755 index 0000000000..913c1d8e43 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh @@ -0,0 +1,129 @@ +#!/usr/bin/env bash + +out_root=/tmp +out_name=train_${RANDOM} +num_nonsil_states=1 + +valid="dev_other" +train="train" +mono_size="-1" # 2000 +tri1_size="-1" # 5000 +tri2b_size="-1" # 10000 +tri3b_size="-1" # 10000 + +# Acoustic model parameters +numLeavesTri1=2000 +numGaussTri1=10000 +numLeavesMLLT=2500 +numGaussMLLT=15000 +numLeavesSAT=2500 +numGaussSAT=15000 + +stage=1 +max_stage=1 + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +data=$1 +lang=$2 +lang_test=$3 + +exp_root=$out_root/$out_name + +# you might not want to do this for interactive shells. 
+set -e + + +if [ $stage -le 1 ] && [ $max_stage -ge 1 ]; then + # train a monophone system + if [ ! $mono_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $mono_size $data/${train}_${mono_size} + mono_train=${train}_${mono_size} + else + mono_train=${train} + fi + + steps/train_mono.sh --boost-silence 1.25 --nj 20 --cmd "$train_cmd" \ + --initial-beam 40 --regular-beam 60 --retry-beam 120 \ + $data/$mono_train $lang $exp_root/mono + + utils/mkgraph.sh $lang_test $exp_root/mono $exp_root/mono/graph + steps/decode.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/mono/graph $data/$valid $exp_root/mono/decode_$valid & +fi + + +if [ $stage -le 2 ] && [ $max_stage -ge 2 ]; then + # train a first delta + delta-delta triphone system on a subset of 5000 utterances + if [ ! $tri1_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $tri1_size $data/${train}_${tri1_size} + tri1_train=${train}_${tri1_size} + else + tri1_train=${train} + fi + + steps/align_si.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \ + $data/$tri1_train $lang \ + $exp_root/mono $exp_root/mono_ali_${tri1_train} + + steps_gan/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \ + --num_nonsil_states $num_nonsil_states $numLeavesTri1 $numGaussTri1 \ + $data/$tri1_train $lang \ + $exp_root/mono_ali_${tri1_train} $exp_root/tri1 + + utils/mkgraph.sh $lang_test $exp_root/tri1 $exp_root/tri1/graph + steps/decode.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/tri1/graph $data/$valid $exp_root/tri1/decode_$valid & +fi + +if [ $stage -le 3 ] && [ $max_stage -ge 3 ]; then + # train an LDA+MLLT system. + if [ ! 
$tri2b_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $tri2b_size $data/${train}_${tri2b_size} + tri2b_train=${train}_${tri2b_size} + else + tri2b_train=${train} + fi + + steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + $data/$tri2b_train $lang \ + $exp_root/tri1 $exp_root/tri1_ali_${tri2b_train} + + steps_gan/train_lda_mllt.sh --cmd "$train_cmd" \ + --num_nonsil_states $num_nonsil_states \ + --splice-opts "--left-context=3 --right-context=3" $numLeavesMLLT $numGaussMLLT \ + $data/$tri2b_train $lang \ + $exp_root/tri1_ali_${tri2b_train} $exp_root/tri2b + + utils/mkgraph.sh $lang_test $exp_root/tri2b $exp_root/tri2b/graph + steps/decode.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/tri2b/graph $data/$valid $exp_root/tri2b/decode_$valid & +fi + + +if [ $stage -le 4 ] && [ $max_stage -ge 4 ]; then + # Train tri3b, which is LDA+MLLT+SAT on 10k utts + if [ ! $tri3b_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $tri3b_size $data/${train}_${tri3b_size} + tri3b_train=${train}_${tri3b_size} + else + tri3b_train=${train} + fi + + steps/align_si.sh --nj 10 --cmd "$train_cmd" --use-graphs true \ + $data/$tri3b_train $lang \ + $exp_root/tri2b $exp_root/tri2b_ali_${tri2b_train} + + steps_gan/train_sat.sh --cmd "$train_cmd" \ + --num_nonsil_states $num_nonsil_states $numLeavesSAT $numGaussSAT \ + $data/$tri3b_train $lang \ + $exp_root/tri2b_ali_${tri2b_train} $exp_root/tri3b + + utils/mkgraph.sh $lang_test $exp_root/tri3b $exp_root/tri3b/graph + steps/decode_fmllr.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/tri3b/graph $data/$valid $exp_root/tri3b/decode_$valid & +fi + +wait diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py new file mode 100644 index 0000000000..1122c88c19 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py @@ -0,0 +1,135 @@ +""" +Implement unsupervised metric for decoding hyperparameter 
selection: + $$ log(LM_PPL) * max(ViterbiUER, min_vt_uer) $$ +""" +import argparse +import logging +import math +import sys + +import kenlm +import editdistance +from g2p_en import G2p + +logging.root.setLevel(logging.INFO) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("ref_tra", help="reference pseudo labels") + parser.add_argument("hyp_tra", help="decoded pseudo labels to be assessed") + parser.add_argument("--kenlm_path", default="/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o5.bin", help="") + parser.add_argument("--uppercase", action="store_true", help="") + parser.add_argument("--skipwords", default="", help="") + parser.add_argument("--gt_tra", default="", help="ground truth pseudo labels for computing oracle WER") + parser.add_argument("--min_vt_uer", default=0.0, type=float) + parser.add_argument("--phonemize", action="store_true", help="phonemize word hypotheses, used when reference is phone transcript") + parser.add_argument("--phonemize_lexicon", default="", type=str, help="use a lexicon for phonemizing") + return parser + +def load_tra(tra_path): + with open(tra_path, "r") as f: + uid_to_tra = {} + for line in f: + toks = line.rstrip().split() + uid, tra = toks[0], " ".join(toks[1:]) + uid_to_tra[uid] = tra + logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}") + return uid_to_tra + +def load_lex(lex_path): + with open(lex_path, "r") as f: + w2p = {} + for line in f: + w, p = line.rstrip().split(None, 1) + w2p[w] = p.split() + return w2p + +def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict): + d_cnt = 0 + w_cnt = 0 + w_cnt_h = 0 + for uid in hyp_uid_to_tra: + ref = ref_uid_to_tra[uid].split() + if g2p_dict is not None: + hyp = [] + for word in hyp_uid_to_tra[uid].split(): + if word in g2p_dict: + hyp = hyp + g2p_dict[word] + else: + logger.warning(f"{word} not in g2p_dict") +
elif g2p is not None: + hyp = g2p(hyp_uid_to_tra[uid]) + hyp = [p for p in hyp if p != "'" and p != " "] + hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp] + else: + hyp = hyp_uid_to_tra[uid].split() + logger.debug(( + f"======================\n" + f"HYP: {' '.join(hyp)}\n" + f"REF: {' '.join(ref)}" + )) + d_cnt += editdistance.eval(ref, hyp) + w_cnt += len(ref) + w_cnt_h += len(hyp) + wer = float(d_cnt) / w_cnt + logger.debug(( + f"wer = {wer*100:.2f}%; num. of ref words = {w_cnt}; " + f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}" + )) + return wer + +def compute_lm_ppl(hyp_uid_to_tra, score_fn): + lm_score = 0. + w_cnt = 0 + for hyp in hyp_uid_to_tra.values(): + cur_score = score_fn(hyp) + cur_cnt = len(hyp.split()) + 1 # plus one for </s> + lm_score += cur_score + w_cnt += cur_cnt + logger.debug(( + f"======================\n" + f"score sum/avg = {cur_score:.2f}/{cur_score/cur_cnt:.2f}\n" + f"hyp = {hyp}" + )) + lm_ppl = math.pow(10, -lm_score / w_cnt) + logger.debug(f"lm ppl = {lm_ppl:.2f}; num. 
of words = {w_cnt}") + return lm_ppl + +def main(): + args = get_parser().parse_args() + logger.debug(f"Args: {args}") + + ref_uid_to_tra = load_tra(args.ref_tra) + hyp_uid_to_tra = load_tra(args.hyp_tra) + assert not bool(set(hyp_uid_to_tra.keys()) - set(ref_uid_to_tra.keys())) + + lm = kenlm.Model(args.kenlm_path) + skipwords = set(args.skipwords.split(",")) + def compute_lm_score(s): + s = " ".join(w for w in s.split() if w not in skipwords) + s = s.upper() if args.uppercase else s + return lm.score(s) + + g2p, g2p_dict = None, None + if args.phonemize: + if args.phonemize_lexicon: + g2p_dict = load_lex(args.phonemize_lexicon) + else: + g2p = G2p() + + wer = compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict) + lm_ppl = compute_lm_ppl(hyp_uid_to_tra, compute_lm_score) + + gt_wer = -math.inf + if args.gt_tra: + gt_uid_to_tra = load_tra(args.gt_tra) + gt_wer = compute_wer(gt_uid_to_tra, hyp_uid_to_tra, None, None) + + score = math.log(lm_ppl) * max(wer, args.min_vt_uer) + logging.info(f"{args.hyp_tra}: score={score:.4f}; wer={wer*100:.2f}%; lm_ppl={lm_ppl:.4f}; gt_wer={gt_wer*100:.2f}%") + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh new file mode 100755 index 0000000000..b34c5b6e06 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +split="dev_other" +ref_txt="" # ground truth transcript path +psd_txt="" # pseudo transcript path +get_best_wer=true +dec_name="decode" +graph_name="graph" +kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +exp_root=$1 +unsup_args="" +if [ $# -ge 2 ]; then + unsup_args=$2 +fi + +set -eu + +if [ ! -z $ref_txt ] && $get_best_wer; then + echo "==== WER w.r.t. 
real transcript (select based on unsupervised metric)" + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + ( + for tra in $x/scoring/*.tra; do + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' > $tra.txt + python local/unsup_select.py $psd_txt $tra.txt --kenlm_path $kenlm_path --gt_tra $ref_txt $unsup_args + done 2>/dev/null | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1 + ) & + done +fi +wait + diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh new file mode 100755 index 0000000000..c10a6b8809 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +split="dev_other" +ref_txt="" # ground truth transcript path +psd_txt="" # pseudo transcript path +get_best_wer=true +dec_name="decode" +graph_name="graph" +kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin +phonemize_lexicon="" + +. ./cmd.sh +. ./path.sh +. parse_options.sh +. /private/home/wnhsu/unsup_asr/fairseq-py-unsup/env.sh + +exp_root=$1 + +set -eu + +if [ ! -z $ref_txt ] && $get_best_wer; then + echo "==== WER w.r.t. 
real transcript (select based on unsupervised metric)" + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + for tra in $x/scoring/*.tra; do + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' > $tra.txt + python local/unsup_select.py $psd_txt $tra.txt \ + --kenlm_path $kenlm_path --gt_tra $ref_txt --phonemize \ + --phonemize_lexicon "$phonemize_lexicon" + done | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1 + done +fi + + diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh new file mode 100755 index 0000000000..1a6fb5f891 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh @@ -0,0 +1,5 @@ +export KALDI_ROOT=`pwd`/../../.. +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/steps b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps new file mode 120000 index 0000000000..6e99bf5b5a --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh new file mode 100755 index 0000000000..af68715ab0 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh @@ -0,0 +1,175 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0 + +# Begin configuration. +stage=-4 # This allows restarting after partway, when something went wrong.
+config= +cmd=run.pl +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +realign_iters="10 20 30"; +num_iters=35 # Number of iterations of training +max_iter_inc=25 # Last iter to increase #Gauss on. +beam=10 +careful=false +retry_beam=40 +boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment +power=0.25 # Exponent for number of gaussians according to occurrence counts +cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves +norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=true" + # use the option --cmvn-opts "--norm-means=false" +cmvn_opts= +delta_opts= +context_opts= # use"--context-width=5 --central-position=2" for quinphone +num_nonsil_states=3 +# End configuration. + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh; +. parse_options.sh || exit 1; + +if [ $# != 6 ]; then + echo "Usage: steps/train_deltas.sh <num-leaves> <tot-gauss> <data-dir> <lang-dir> <alignment-dir> <exp-dir>" + echo "e.g.: steps/train_deltas.sh 2000 10000 data/train_si84_half data/lang exp/mono_ali exp/tri1" + echo "main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs." + echo " --config <config-file> # config containing options" + echo " --stage <stage> # stage to do partial re-run from." + exit 1; +fi + +numleaves=$1 +totgauss=$2 +data=$3 +lang=$4 +alidir=$5 +dir=$6 + +for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do + [ ! 
-f $f ] && echo "train_deltas.sh: no such file $f" && exit 1; +done + +numgauss=$numleaves +incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter increment for #Gauss +oov=`cat $lang/oov.int` || exit 1; +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +nj=`cat $alidir/num_jobs` || exit 1; +mkdir -p $dir/log +echo $nj > $dir/num_jobs + +utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; + +sdata=$data/split$nj; +split_data.sh $data $nj || exit 1; + + +[ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \ + echo "$0: warning: ignoring CMVN options from source directory $alidir" +$norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts" +echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN. +[ ! -z $delta_opts ] && echo $delta_opts > $dir/delta_opts + +feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |" + +rm $dir/.error 2>/dev/null + +if [ $stage -le -3 ]; then + echo "$0: accumulating tree stats" + $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ + acc-tree-stats $context_opts \ + --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ + "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; + sum-tree-stats $dir/treeacc $dir/*.treeacc 2>$dir/log/sum_tree_acc.log || exit 1; + rm $dir/*.treeacc +fi + +if [ $stage -le -2 ]; then + echo "$0: getting questions for tree-building, via clustering" + # preparing questions, roots file... 
+ cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts \ + $dir/treeacc $lang/phones/sets.int \ + $dir/questions.int 2> $dir/log/questions.log || exit 1; + cat $lang/phones/extra_questions.int >> $dir/questions.int + compile-questions $context_opts $lang/topo $dir/questions.int \ + $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; + + echo "$0: building the tree" + $cmd $dir/log/build_tree.log \ + build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ + --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ + $dir/questions.qst $lang/topo $dir/tree || exit 1; + + $cmd $dir/log/init_model.log \ + gmm-init-model --write-occs=$dir/1.occs \ + $dir/tree $dir/treeacc $lang/topo $dir/1.mdl || exit 1; + if grep 'no stats' $dir/log/init_model.log; then + echo "** The warnings above about 'no stats' generally mean you have phones **" + echo "** (or groups of phones) in your phone set that had no corresponding data. **" + echo "** You should probably figure out whether something went wrong, **" + echo "** or whether your data just doesn't happen to have examples of those **" + echo "** phones. **" + fi + + gmm-mixup --mix-up=$numgauss $dir/1.mdl $dir/1.occs $dir/1.mdl 2>$dir/log/mixup.log || exit 1; + rm $dir/treeacc +fi + +if [ $stage -le -1 ]; then + # Convert the alignments. 
+ echo "$0: converting alignments from $alidir to use current tree" + $cmd JOB=1:$nj $dir/log/convert.JOB.log \ + convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \ + "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; +fi + +if [ $stage -le 0 ]; then + echo "$0: compiling graphs of transcripts" + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ + "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; +fi + +x=1 +while [ $x -lt $num_iters ]; do + echo "$0: training pass $x" + if [ $stage -le $x ]; then + if echo $realign_iters | grep -w $x >/dev/null; then + echo "$0: aligning data" + mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" + $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ + "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ + "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; + fi + $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ + gmm-acc-stats-ali $dir/$x.mdl "$feats" \ + "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; + $cmd $dir/log/update.$x.log \ + gmm-est --mix-up=$numgauss --power=$power \ + --write-occs=$dir/$[$x+1].occs $dir/$x.mdl \ + "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; + rm $dir/$x.mdl $dir/$x.*.acc + rm $dir/$x.occs + fi + [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; + x=$[$x+1]; +done + +rm $dir/final.mdl $dir/final.occs 2>/dev/null +ln -s $x.mdl $dir/final.mdl +ln -s $x.occs $dir/final.occs + +steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir + +# Summarize warning messages... 
+utils/summarize_warnings.pl $dir/log + +steps/info/gmm_dir_info.pl $dir + +echo "$0: Done training system with delta+delta-delta features in $dir" + +exit 0 diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh new file mode 100755 index 0000000000..9d8c319ce8 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh @@ -0,0 +1,239 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# +# LDA+MLLT refers to the way we transform the features after computing +# the MFCCs: we splice across several frames, reduce the dimension (to 40 +# by default) using Linear Discriminant Analysis), and then later estimate, +# over multiple iterations, a diagonalizing transform known as MLLT or STC. +# See http://kaldi-asr.org/doc/transform.html for more explanation. +# +# Apache 2.0. + +# Begin configuration. +cmd=run.pl +config= +stage=-5 +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +realign_iters="10 20 30"; +mllt_iters="2 4 6 12"; +num_iters=35 # Number of iterations of training +max_iter_inc=25 # Last iter to increase #Gauss on. +dim=40 +beam=10 +retry_beam=40 +careful=false +boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment +power=0.25 # Exponent for number of gaussians according to occurrence counts +randprune=4.0 # This is approximately the ratio by which we will speed up the + # LDA and MLLT calculations via randomized pruning. +splice_opts= +cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves +norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=false" +cmvn_opts= +context_opts= # use "--context-width=5 --central-position=2" for quinphone. +# End configuration. +train_tree=true # if false, don't actually train the tree. +use_lda_mat= # If supplied, use this LDA[+MLLT] matrix. 
+num_nonsil_states=3 + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# != 6 ]; then + echo "Usage: steps/train_lda_mllt.sh [options] <#leaves> <#gauss> <data> <lang> <alignments> <dir>" + echo " e.g.: steps/train_lda_mllt.sh 2500 15000 data/train_si84 data/lang exp/tri1_ali_si84 exp/tri2b" + echo "Main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs." + echo " --config <config-file> # config containing options" + echo " --stage <stage> # stage to do partial re-run from." + exit 1; +fi + +numleaves=$1 +totgauss=$2 +data=$3 +lang=$4 +alidir=$5 +dir=$6 + +for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do + [ ! -f $f ] && echo "train_lda_mllt.sh: no such file $f" && exit 1; +done + +numgauss=$numleaves +incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter #gauss increment +oov=`cat $lang/oov.int` || exit 1; +nj=`cat $alidir/num_jobs` || exit 1; +silphonelist=`cat $lang/phones/silence.csl` || exit 1; +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; + +mkdir -p $dir/log + +utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; + +echo $nj >$dir/num_jobs +echo "$splice_opts" >$dir/splice_opts # keep track of frame-splicing options + # so that later stages of system building can know what they were. + + +[ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \ + echo "$0: warning: ignoring CMVN options from source directory $alidir" +$norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts" +echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN. 
+ +sdata=$data/split$nj; +split_data.sh $data $nj || exit 1; + +splicedfeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- |" +# Note: $feats gets overwritten later in the script. +feats="$splicedfeats transform-feats $dir/0.mat ark:- ark:- |" + + + +if [ $stage -le -5 ]; then + if [ -z "$use_lda_mat" ]; then + echo "$0: Accumulating LDA statistics." + rm $dir/lda.*.acc 2>/dev/null + $cmd JOB=1:$nj $dir/log/lda_acc.JOB.log \ + ali-to-post "ark:gunzip -c $alidir/ali.JOB.gz|" ark:- \| \ + weight-silence-post 0.0 $silphonelist $alidir/final.mdl ark:- ark:- \| \ + acc-lda --rand-prune=$randprune $alidir/final.mdl "$splicedfeats" ark,s,cs:- \ + $dir/lda.JOB.acc || exit 1; + est-lda --write-full-matrix=$dir/full.mat --dim=$dim $dir/0.mat $dir/lda.*.acc \ + 2>$dir/log/lda_est.log || exit 1; + rm $dir/lda.*.acc + else + echo "$0: Using supplied LDA matrix $use_lda_mat" + cp $use_lda_mat $dir/0.mat || exit 1; + [ ! -z "$mllt_iters" ] && \ + echo "$0: Warning: using supplied LDA matrix $use_lda_mat but we will do MLLT," && \ + echo " which you might not want; to disable MLLT, specify --mllt-iters ''" && \ + sleep 5 + fi +fi + +cur_lda_iter=0 + +if [ $stage -le -4 ] && $train_tree; then + echo "$0: Accumulating tree stats" + $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ + acc-tree-stats $context_opts \ + --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ + "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; + [ `ls $dir/*.treeacc | wc -w` -ne "$nj" ] && echo "$0: Wrong #tree-accs" && exit 1; + $cmd $dir/log/sum_tree_acc.log \ + sum-tree-stats $dir/treeacc $dir/*.treeacc || exit 1; + rm $dir/*.treeacc +fi + + +if [ $stage -le -3 ] && $train_tree; then + echo "$0: Getting questions for tree clustering." + # preparing questions, roots file... 
+ cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts $dir/treeacc $lang/phones/sets.int \ + $dir/questions.int 2> $dir/log/questions.log || exit 1; + cat $lang/phones/extra_questions.int >> $dir/questions.int + compile-questions $context_opts $lang/topo $dir/questions.int \ + $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; + + echo "$0: Building the tree" + $cmd $dir/log/build_tree.log \ + build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ + --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ + $dir/questions.qst $lang/topo $dir/tree || exit 1; +fi + +if [ $stage -le -2 ]; then + echo "$0: Initializing the model" + if $train_tree; then + gmm-init-model --write-occs=$dir/1.occs \ + $dir/tree $dir/treeacc $lang/topo $dir/1.mdl 2> $dir/log/init_model.log || exit 1; + grep 'no stats' $dir/log/init_model.log && echo "This is a bad warning."; + rm $dir/treeacc + else + cp $alidir/tree $dir/ || exit 1; + $cmd JOB=1 $dir/log/init_model.log \ + gmm-init-model-flat $dir/tree $lang/topo $dir/1.mdl \ + "$feats subset-feats ark:- ark:-|" || exit 1; + fi +fi + + +if [ $stage -le -1 ]; then + # Convert the alignments. 
+ echo "$0: Converting alignments from $alidir to use current tree" + $cmd JOB=1:$nj $dir/log/convert.JOB.log \ + convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \ + "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; +fi + +if [ $stage -le 0 ] && [ "$realign_iters" != "" ]; then + echo "$0: Compiling graphs of transcripts" + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ + "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $data/split$nj/JOB/text |" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; +fi + + +x=1 +while [ $x -lt $num_iters ]; do + echo Training pass $x + if echo $realign_iters | grep -w $x >/dev/null && [ $stage -le $x ]; then + echo Aligning data + mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" + $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ + "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ + "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; + fi + if echo $mllt_iters | grep -w $x >/dev/null; then + if [ $stage -le $x ]; then + echo "$0: Estimating MLLT" + $cmd JOB=1:$nj $dir/log/macc.$x.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ + weight-silence-post 0.0 $silphonelist $dir/$x.mdl ark:- ark:- \| \ + gmm-acc-mllt --rand-prune=$randprune $dir/$x.mdl "$feats" ark:- $dir/$x.JOB.macc \ + || exit 1; + est-mllt $dir/$x.mat.new $dir/$x.*.macc 2> $dir/log/mupdate.$x.log || exit 1; + gmm-transform-means $dir/$x.mat.new $dir/$x.mdl $dir/$x.mdl \ + 2> $dir/log/transform_means.$x.log || exit 1; + compose-transforms --print-args=false $dir/$x.mat.new $dir/$cur_lda_iter.mat $dir/$x.mat || exit 1; + rm $dir/$x.*.macc + fi + feats="$splicedfeats transform-feats $dir/$x.mat ark:- ark:- |" + cur_lda_iter=$x + fi + + if [ $stage -le $x ]; then + $cmd JOB=1:$nj 
$dir/log/acc.$x.JOB.log \ + gmm-acc-stats-ali $dir/$x.mdl "$feats" \ + "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; + $cmd $dir/log/update.$x.log \ + gmm-est --write-occs=$dir/$[$x+1].occs --mix-up=$numgauss --power=$power \ + $dir/$x.mdl "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; + rm $dir/$x.mdl $dir/$x.*.acc $dir/$x.occs + fi + [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; + x=$[$x+1]; +done + +rm $dir/final.{mdl,mat,occs} 2>/dev/null +ln -s $x.mdl $dir/final.mdl +ln -s $x.occs $dir/final.occs +ln -s $cur_lda_iter.mat $dir/final.mat + +steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log + +steps/info/gmm_dir_info.pl $dir + +echo "$0: Done training system with LDA+MLLT features in $dir" + +exit 0 diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh new file mode 100755 index 0000000000..f75afafb1c --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh @@ -0,0 +1,281 @@ +#!/usr/bin/env bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. + + +# This does Speaker Adapted Training (SAT), i.e. train on +# fMLLR-adapted features. It can be done on top of either LDA+MLLT, or +# delta and delta-delta features. If there are no transforms supplied +# in the alignment directory, it will estimate transforms itself before +# building the tree (and in any case, it estimates transforms a number +# of times during training). + + +# Begin configuration section. +stage=-5 +exit_stage=-100 # you can use this to require it to exit at the + # beginning of a specific stage. Not all values are + # supported. 
+fmllr_update_type=full +cmd=run.pl +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +beam=10 +retry_beam=40 +careful=false +boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment +context_opts= # e.g. set this to "--context-width 5 --central-position 2" for quinphone. +realign_iters="10 20 30"; +fmllr_iters="2 4 6 12"; +silence_weight=0.0 # Weight on silence in fMLLR estimation. +num_iters=35 # Number of iterations of training +max_iter_inc=25 # Last iter to increase #Gauss on. +power=0.2 # Exponent for number of gaussians according to occurrence counts +cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves +phone_map= +train_tree=true +tree_stats_opts= +cluster_phones_opts= +compile_questions_opts= +# End configuration section. +num_nonsil_states=3 + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# != 6 ]; then + echo "Usage: steps/train_sat.sh <#leaves> <#gauss> <data> <lang> <ali-dir> <exp-dir>" + echo " e.g.: steps/train_sat.sh 2500 15000 data/train_si84 data/lang exp/tri2b_ali_si84 exp/tri3b" + echo "Main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs." + echo " --config <config-file> # config containing options" + echo " --stage <stage> # stage to do partial re-run from." + exit 1; +fi + +numleaves=$1 +totgauss=$2 +data=$3 +lang=$4 +alidir=$5 +dir=$6 + +for f in $data/feats.scp $lang/phones.txt $alidir/final.mdl $alidir/ali.1.gz; do + [ ! 
-f $f ] && echo "train_sat.sh: no such file $f" && exit 1; +done + +numgauss=$numleaves +incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter #gauss increment +oov=`cat $lang/oov.int` +nj=`cat $alidir/num_jobs` || exit 1; +silphonelist=`cat $lang/phones/silence.csl` +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +sdata=$data/split$nj; +splice_opts=`cat $alidir/splice_opts 2>/dev/null` # frame-splicing options. +cmvn_opts=`cat $alidir/cmvn_opts 2>/dev/null` +delta_opts=`cat $alidir/delta_opts 2>/dev/null` +phone_map_opt= +[ ! -z "$phone_map" ] && phone_map_opt="--phone-map='$phone_map'" + +mkdir -p $dir/log +cp $alidir/splice_opts $dir 2>/dev/null # frame-splicing options. +cp $alidir/cmvn_opts $dir 2>/dev/null # cmn/cmvn option. +cp $alidir/delta_opts $dir 2>/dev/null # delta option. + +utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; + +echo $nj >$dir/num_jobs +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; + +# Set up features. + +if [ -f $alidir/final.mat ]; then feat_type=lda; else feat_type=delta; fi +echo "$0: feature type is $feat_type" + +## Set up speaker-independent features. 
+case $feat_type in + delta) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |";; + lda) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" + cp $alidir/final.mat $dir + cp $alidir/full.mat $dir 2>/dev/null + ;; + *) echo "$0: invalid feature type $feat_type" && exit 1; +esac + +## Get initial fMLLR transforms (possibly from alignment dir) +if [ -f $alidir/trans.1 ]; then + echo "$0: Using transforms from $alidir" + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$alidir/trans.JOB ark:- ark:- |" + cur_trans_dir=$alidir +else + if [ $stage -le -5 ]; then + echo "$0: obtaining initial fMLLR transforms since not present in $alidir" + # The next line is necessary because of $silphonelist otherwise being incorrect; would require + # old $lang dir which would require another option. Not needed anyway. + [ ! -z "$phone_map" ] && \ + echo "$0: error: you must provide transforms if you use the --phone-map option." && exit 1; + $cmd JOB=1:$nj $dir/log/fmllr.0.JOB.log \ + ali-to-post "ark:gunzip -c $alidir/ali.JOB.gz|" ark:- \| \ + weight-silence-post $silence_weight $silphonelist $alidir/final.mdl ark:- ark:- \| \ + gmm-est-fmllr --fmllr-update-type=$fmllr_update_type \ + --spk2utt=ark:$sdata/JOB/spk2utt $alidir/final.mdl "$sifeats" \ + ark:- ark:$dir/trans.JOB || exit 1; + fi + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$dir/trans.JOB ark:- ark:- |" + cur_trans_dir=$dir +fi + +if [ $stage -le -4 ] && $train_tree; then + # Get tree stats. 
+ echo "$0: Accumulating tree stats" + $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ + acc-tree-stats $context_opts $tree_stats_opts $phone_map_opt --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ + "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; + [ "`ls $dir/*.treeacc | wc -w`" -ne "$nj" ] && echo "$0: Wrong #tree-accs" && exit 1; + $cmd $dir/log/sum_tree_acc.log \ + sum-tree-stats $dir/treeacc $dir/*.treeacc || exit 1; + rm $dir/*.treeacc +fi + +if [ $stage -le -3 ] && $train_tree; then + echo "$0: Getting questions for tree clustering." + # preparing questions, roots file... + cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) \ + $cluster_phones_opts $context_opts \ + $dir/treeacc $lang/phones/sets.int $dir/questions.int 2>$dir/log/questions.log || exit 1; + cat $lang/phones/extra_questions.int >> $dir/questions.int + compile-questions $context_opts $compile_questions_opts $lang/topo $dir/questions.int $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; + + echo "$0: Building the tree" + $cmd $dir/log/build_tree.log \ + build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ + --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ + $dir/questions.qst $lang/topo $dir/tree || exit 1; +fi + +if [ $stage -le -2 ]; then + echo "$0: Initializing the model" + if $train_tree; then + gmm-init-model --write-occs=$dir/1.occs \ + $dir/tree $dir/treeacc $lang/topo $dir/1.mdl 2> $dir/log/init_model.log || exit 1; + grep 'no stats' $dir/log/init_model.log && echo "This is a bad warning."; + rm $dir/treeacc + else + cp $alidir/tree $dir/ || exit 1; + $cmd JOB=1 $dir/log/init_model.log \ + gmm-init-model-flat $dir/tree $lang/topo $dir/1.mdl \ + "$feats subset-feats ark:- ark:-|" || exit 1; + fi +fi + +if [ $stage -le -1 ]; then + # Convert the alignments. 
+ echo "$0: Converting alignments from $alidir to use current tree" + $cmd JOB=1:$nj $dir/log/convert.JOB.log \ + convert-ali $phone_map_opt $alidir/final.mdl $dir/1.mdl $dir/tree \ + "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; +fi + +[ "$exit_stage" -eq 0 ] && echo "$0: Exiting early: --exit-stage $exit_stage" && exit 0; + +if [ $stage -le 0 ] && [ "$realign_iters" != "" ]; then + echo "$0: Compiling graphs of transcripts" + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ + "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; +fi + +x=1 +while [ $x -lt $num_iters ]; do + echo Pass $x + if echo $realign_iters | grep -w $x >/dev/null && [ $stage -le $x ]; then + echo Aligning data + mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" + $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ + "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ + "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; + fi + + if echo $fmllr_iters | grep -w $x >/dev/null; then + if [ $stage -le $x ]; then + echo Estimating fMLLR transforms + # We estimate a transform that's additional to the previous transform; + # we'll compose them. + $cmd JOB=1:$nj $dir/log/fmllr.$x.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ + weight-silence-post $silence_weight $silphonelist $dir/$x.mdl ark:- ark:- \| \ + gmm-est-fmllr --fmllr-update-type=$fmllr_update_type \ + --spk2utt=ark:$sdata/JOB/spk2utt $dir/$x.mdl \ + "$feats" ark:- ark:$dir/tmp_trans.JOB || exit 1; + for n in `seq $nj`; do + ! 
( compose-transforms --b-is-affine=true \ + ark:$dir/tmp_trans.$n ark:$cur_trans_dir/trans.$n ark:$dir/composed_trans.$n \ + && mv $dir/composed_trans.$n $dir/trans.$n && \ + rm $dir/tmp_trans.$n ) 2>$dir/log/compose_transforms.$x.log \ + && echo "$0: Error composing transforms" && exit 1; + done + fi + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$dir/trans.JOB ark:- ark:- |" + cur_trans_dir=$dir + fi + + if [ $stage -le $x ]; then + $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ + gmm-acc-stats-ali $dir/$x.mdl "$feats" \ + "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; + [ `ls $dir/$x.*.acc | wc -w` -ne "$nj" ] && echo "$0: Wrong #accs" && exit 1; + $cmd $dir/log/update.$x.log \ + gmm-est --power=$power --write-occs=$dir/$[$x+1].occs --mix-up=$numgauss $dir/$x.mdl \ + "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; + rm $dir/$x.mdl $dir/$x.*.acc + rm $dir/$x.occs + fi + [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; + x=$[$x+1]; +done + + +if [ $stage -le $x ]; then + # Accumulate stats for "alignment model"-- this model is + # computed with the speaker-independent features, but matches Gaussian-for-Gaussian + # with the final speaker-adapted model. + $cmd JOB=1:$nj $dir/log/acc_alimdl.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ + gmm-acc-stats-twofeats $dir/$x.mdl "$feats" "$sifeats" \ + ark,s,cs:- $dir/$x.JOB.acc || exit 1; + [ `ls $dir/$x.*.acc | wc -w` -ne "$nj" ] && echo "$0: Wrong #accs" && exit 1; + # Update model. 
+ $cmd $dir/log/est_alimdl.log \ + gmm-est --power=$power --remove-low-count-gaussians=false $dir/$x.mdl \ + "gmm-sum-accs - $dir/$x.*.acc|" $dir/$x.alimdl || exit 1; + rm $dir/$x.*.acc +fi + +rm $dir/final.{mdl,alimdl,occs} 2>/dev/null +ln -s $x.mdl $dir/final.mdl +ln -s $x.occs $dir/final.occs +ln -s $x.alimdl $dir/final.alimdl + + +steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir + +utils/summarize_warnings.pl $dir/log +( + echo "$0: Likelihood evolution:" + for x in `seq $[$num_iters-1]`; do + tail -n 30 $dir/log/acc.$x.*.log | awk '/Overall avg like/{l += $(NF-3)*$(NF-1); t += $(NF-1); } + /Overall average logdet/{d += $(NF-3)*$(NF-1); t2 += $(NF-1);} + END{ d /= t2; l /= t; printf("%s ", d+l); } ' + done + echo +) | tee $dir/log/summary.log + + +steps/info/gmm_dir_info.pl $dir + +echo "$0: done training SAT system in $dir" + +exit 0 diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh new file mode 100644 index 0000000000..f3a3d3fc7c --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +set -eu + +w2v_dir= # contains features `{train,valid}.{npy,lengths}`, real transcripts `{train,valid}.${label}`, and dict `dict.${label}.txt` +lab_dir= # contains pseudo labels `{train,valid}.txt` +out_dir= # output root +arpa_lm= # phone LM +arpa_lm_bin= # (binary) phone LM for KenLM, used in unsupervised selection + +label=phnc +train_name="train" +valid_name="valid" +data_dir=${out_dir}/data + +mkdir -p ${out_dir}/exp +local/prepare_lang.sh $w2v_dir/dict.${label}.txt $data_dir +local/prepare_lm.sh $arpa_lm $data_dir + +for x in $train_name $valid_name; do + x_gt=${x}_gt + + # prepare pseudo data + python local/prepare_data_from_w2v.py $w2v_dir $data_dir $x + steps/compute_cmvn_stats.sh $data_dir/$x $out_dir/exp/make_feat/$x $out_dir/feats/$x + python local/copy_aligned_text.py < $lab_dir/$x.txt > $data_dir/$x/text + + # 
prepare ground truth data + mkdir $data_dir/$x_gt + cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $data_dir/$x_gt/ + python local/copy_aligned_text.py < $w2v_dir/$x.$label > $data_dir/$x_gt/text +done + +local/train_subset_lgbeam.sh \ + --out_root ${out_dir} --out_name exp --train $train_name --valid $valid_name \ + --mono_size 2000 --tri1_size 5000 --tri2b_size -1 --tri3b_size -1 \ + --stage 1 --max_stage 3 $data_dir $data_dir/lang $data_dir/lang_test + +local/unsup_select_decode.sh \ + --split $valid_name --kenlm_path $arpa_lm_bin \ + --ref_txt $data_dir/${valid_name}_gt/text \ + --psd_txt $data_dir/${valid_name}/text \ + $out_dir/exp diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/utils b/examples/wav2vec/unsupervised/kaldi_self_train/st/utils new file mode 120000 index 0000000000..b240885218 --- /dev/null +++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/examples/wav2vec/unsupervised/models/__init__.py b/examples/wav2vec/unsupervised/models/__init__.py new file mode 100644 index 0000000000..3e3039b708 --- /dev/null +++ b/examples/wav2vec/unsupervised/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .wav2vec_u import Wav2vec_U + + +__all__ = [ + "Wav2vec_U", +] diff --git a/examples/wav2vec/unsupervised/models/wav2vec_u.py b/examples/wav2vec/unsupervised/models/wav2vec_u.py new file mode 100644 index 0000000000..d3f195b94e --- /dev/null +++ b/examples/wav2vec/unsupervised/models/wav2vec_u.py @@ -0,0 +1,658 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from enum import Enum, auto +import math +import numpy as np +from typing import Tuple, List, Optional, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import autograd + +from fairseq import checkpoint_utils, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.modules import ( + SamePad, + TransposeLast, +) + + +class SegmentationType(Enum): + NONE = auto() + RANDOM = auto() + UNIFORM_RANDOM = auto() + UNIFORM_RANDOM_JOIN = auto() + JOIN = auto() + + +@dataclass +class SegmentationConfig(FairseqDataclass): + type: SegmentationType = SegmentationType.NONE + subsample_rate: float = 0.25 + mean_pool: bool = True + mean_pool_join: bool = False + remove_zeros: bool = False + + +@dataclass +class Wav2vec_UConfig(FairseqDataclass): + + discriminator_kernel: int = 3 + discriminator_dilation: int = 1 + discriminator_dim: int = 256 + discriminator_causal: bool = True + discriminator_linear_emb: bool = False + discriminator_depth: int = 1 + discriminator_max_pool: bool = False + discriminator_act_after_linear: bool = False + discriminator_dropout: float = 0.0 + discriminator_spectral_norm: bool = False + discriminator_weight_norm: bool = False + + generator_kernel: int = 4 + generator_dilation: int = 1 + generator_stride: int = 1 + generator_bias: bool = False + generator_dropout: float = 0.0 + + blank_weight: float = 0 + blank_mode: str = "add" + blank_is_sil: bool = False + no_softmax: bool = False + + smoothness_weight: float = 0.0 + smoothing: float = 0.0 + smoothing_one_sided: bool = False + gradient_penalty: float = 0.0 + probabilistic_grad_penalty_slicing: bool = False + code_penalty: float = 0.0 + gumbel: bool = False + hard_gumbel: bool = True + temp: Tuple[float, float, float] = (2, 0.1, 0.99995) + input_dim: int = 128 + wgan_loss: bool = False + + segmentation: SegmentationConfig = SegmentationConfig() + + 
+class Segmenter(nn.Module): + cfg: SegmentationConfig + + def __init__(self, cfg: SegmentationConfig): + super().__init__() + self.cfg = cfg + self.subsample_rate = cfg.subsample_rate + + def pre_segment(self, dense_x, dense_padding_mask): + return dense_x, dense_padding_mask + + def logit_segment(self, logits, padding_mask): + return logits, padding_mask + + +class RandomSegmenter(Segmenter): + def pre_segment(self, dense_x, dense_padding_mask): + target_num = math.ceil(dense_x.size(1) * self.subsample_rate) + ones = torch.ones(dense_x.shape[:-1], device=dense_x.device) + indices, _ = ones.multinomial(target_num).sort(dim=-1) + indices_ld = indices.unsqueeze(-1).expand(-1, -1, dense_x.size(-1)) + dense_x = dense_x.gather(1, indices_ld) + dense_padding_mask = dense_padding_mask.gather(1, index=indices) + return dense_x, dense_padding_mask + + +class UniformRandomSegmenter(Segmenter): + def pre_segment(self, dense_x, dense_padding_mask): + bsz, tsz, fsz = dense_x.shape + + target_num = math.ceil(tsz * self.subsample_rate) + + rem = tsz % target_num + + if rem > 0: + dense_x = F.pad(dense_x, [0, 0, 0, target_num - rem]) + dense_padding_mask = F.pad( + dense_padding_mask, [0, target_num - rem], value=True + ) + + dense_x = dense_x.view(bsz, target_num, -1, fsz) + dense_padding_mask = dense_padding_mask.view(bsz, target_num, -1) + + if self.cfg.mean_pool: + dense_x = dense_x.mean(dim=-2) + dense_padding_mask = dense_padding_mask.all(dim=-1) + else: + ones = torch.ones((bsz, dense_x.size(2)), device=dense_x.device) + indices = ones.multinomial(1) + indices = indices.unsqueeze(-1).expand(-1, target_num, -1) + indices_ld = indices.unsqueeze(-1).expand(-1, -1, -1, fsz) + dense_x = dense_x.gather(2, indices_ld).reshape(bsz, -1, fsz) + dense_padding_mask = dense_padding_mask.gather(2, index=indices).reshape( + bsz, -1 + ) + return dense_x, dense_padding_mask + + +class JoinSegmenter(Segmenter): + def logit_segment(self, logits, padding_mask): + preds = logits.argmax(dim=-1) 
+ + if padding_mask.any(): + preds[padding_mask] = -1 # mark pad + uniques = [] + + bsz, tsz, csz = logits.shape + + for p in preds: + uniques.append( + p.cpu().unique_consecutive(return_inverse=True, return_counts=True) + ) + + new_tsz = max(u[0].numel() for u in uniques) + new_logits = logits.new_zeros(bsz, new_tsz, csz) + new_pad = padding_mask.new_zeros(bsz, new_tsz) + + for b in range(bsz): + u, idx, c = uniques[b] + keep = u != -1 + + if self.cfg.remove_zeros: + keep.logical_and_(u != 0) + + if self.training and not self.cfg.mean_pool_join: + u[0] = 0 + u[1:] = c.cumsum(0)[:-1] + m = c > 1 + r = torch.rand(m.sum()) + o = (c[m] * r).long() + u[m] += o + new_logits[b, : u.numel()] = logits[b, u] + else: + new_logits[b].index_add_( + dim=0, index=idx.to(new_logits.device), source=logits[b] + ) + new_logits[b, : c.numel()] /= c.unsqueeze(-1).to(new_logits.device) + + new_sz = keep.sum() + if not keep.all(): + kept_logits = new_logits[b, : c.numel()][keep] + new_logits[b, :new_sz] = kept_logits + + if new_sz < new_tsz: + pad = new_tsz - new_sz + new_logits[b, -pad:] = 0 + new_pad[b, -pad:] = True + + return new_logits, new_pad + + +class UniformRandomJoinSegmenter(UniformRandomSegmenter, JoinSegmenter): + pass + + +SEGMENT_FACTORY = { + SegmentationType.NONE: Segmenter, + SegmentationType.RANDOM: RandomSegmenter, + SegmentationType.UNIFORM_RANDOM: UniformRandomSegmenter, + SegmentationType.UNIFORM_RANDOM_JOIN: UniformRandomJoinSegmenter, + SegmentationType.JOIN: JoinSegmenter, +} + + +class Discriminator(nn.Module): + def __init__(self, dim, cfg: Wav2vec_UConfig): + super().__init__() + + inner_dim = cfg.discriminator_dim + kernel = cfg.discriminator_kernel + dilation = cfg.discriminator_dilation + self.max_pool = cfg.discriminator_max_pool + + if cfg.discriminator_causal: + padding = kernel - 1 + else: + padding = kernel // 2 + + def make_conv(in_d, out_d, k, p=0, has_dilation=True): + conv = nn.Conv1d( + in_d, + out_d, + kernel_size=k, + padding=p, + 
dilation=dilation if has_dilation else 1, + ) + if cfg.discriminator_spectral_norm: + conv = nn.utils.spectral_norm(conv) + elif cfg.discriminator_weight_norm: + conv = nn.utils.weight_norm(conv) + return conv + + inner_net = [ + nn.Sequential( + make_conv(inner_dim, inner_dim, kernel, padding), + SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), + nn.Dropout(cfg.discriminator_dropout), + nn.GELU(), + ) + for _ in range(cfg.discriminator_depth - 1) + ] + [ + make_conv(inner_dim, 1, kernel, padding, has_dilation=False), + SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), + ] + + if cfg.discriminator_linear_emb: + emb_net = [make_conv(dim, inner_dim, 1)] + else: + emb_net = [ + make_conv(dim, inner_dim, kernel, padding), + SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), + ] + + if cfg.discriminator_act_after_linear: + emb_net.append(nn.GELU()) + + self.net = nn.Sequential( + *emb_net, + nn.Dropout(cfg.discriminator_dropout), + *inner_net, + ) + + def forward(self, x, padding_mask): + x = x.transpose(1, 2) # BTC -> BCT + x = self.net(x) + x = x.transpose(1, 2) + x_sz = x.size(1) + if padding_mask is not None and padding_mask.any() and padding_mask.dim() > 1: + padding_mask = padding_mask[:, : x.size(1)] + x[padding_mask] = float("-inf") if self.max_pool else 0 + x_sz = x_sz - padding_mask.sum(dim=-1) + x = x.squeeze(-1) + if self.max_pool: + x, _ = x.max(dim=-1) + else: + x = x.sum(dim=-1) + x = x / x_sz + return x + + +class Generator(nn.Module): + def __init__(self, input_dim, output_dim, cfg: Wav2vec_UConfig): + super().__init__() + + self.cfg = cfg + self.output_dim = output_dim + self.stride = cfg.generator_stride + self.dropout = nn.Dropout(cfg.generator_dropout) + + padding = cfg.generator_kernel // 2 + self.proj = nn.Sequential( + TransposeLast(), + nn.Conv1d( + input_dim, + output_dim, + kernel_size=cfg.generator_kernel, + stride=cfg.generator_stride, + dilation=cfg.generator_dilation, + padding=padding, + 
bias=cfg.generator_bias, + ), + TransposeLast(), + ) + + def forward(self, dense_x, tokens, dense_padding_mask): + dense_x = self.dropout(dense_x) + + dense_x = self.proj(dense_x) + if self.stride > 1: + dense_padding_mask = dense_padding_mask[:, :: self.stride] + + if dense_padding_mask.size(1) != dense_x.size(1): + new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1]) + diff = new_padding.size(1) - dense_padding_mask.size(1) + assert ( + diff > 0 + ), f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}" + if diff > 0: + new_padding[:, diff:] = dense_padding_mask + else: + assert diff < 0 + new_padding = dense_padding_mask[:, :diff] + + dense_padding_mask = new_padding + + result = {} + + token_x = None + if tokens is not None: + token_x = dense_x.new_zeros(tokens.numel(), self.output_dim) + token_x.scatter_(1, tokens.view(-1, 1).long(), 1) + token_x = token_x.view(tokens.shape + (self.output_dim,)) + + result["dense_x"] = dense_x + result["token_x"] = token_x + result["dense_padding_mask"] = dense_padding_mask + + return result + + +@register_model("wav2vec_u", dataclass=Wav2vec_UConfig) +class Wav2vec_U(BaseFairseqModel): + def calc_gradient_penalty(self, real_data, fake_data): + + b_size = min(real_data.size(0), fake_data.size(0)) + t_size = min(real_data.size(1), fake_data.size(1)) + + if self.cfg.probabilistic_grad_penalty_slicing: + + def get_slice(data, dim, target_size): + + size = data.size(dim) + diff = size - target_size + if diff <= 0: + return data + + start = np.random.randint(0, diff + 1) + return data.narrow(dim=dim, start=start, length=target_size) + + real_data = get_slice(real_data, 0, b_size) + real_data = get_slice(real_data, 1, t_size) + fake_data = get_slice(fake_data, 0, b_size) + fake_data = get_slice(fake_data, 1, t_size) + + else: + real_data = real_data[:b_size, :t_size] + fake_data = fake_data[:b_size, :t_size] + + alpha = torch.rand(real_data.size(0), 1, 1) + alpha = alpha.expand(real_data.size()) + 
alpha = alpha.to(real_data.device) + + interpolates = alpha * real_data + ((1 - alpha) * fake_data) + + disc_interpolates = self.discriminator(interpolates, None) + + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size(), device=real_data.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradient_penalty = (gradients.norm(2, dim=1) - 1) ** 2 + return gradient_penalty + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self.update_num = num_updates + self.curr_temp = max( + self.max_temp * self.temp_decay ** num_updates, self.min_temp + ) + + def discrim_step(self, num_updates): + if num_updates < self.zero_pretrain_updates: + return False + if self.dynamic_step_thresh <= 0 or self.last_acc is None: + return num_updates % 2 == 1 + else: + return self.last_acc < self.dynamic_step_thresh + + def get_groups_for_update(self, num_updates): + return "discriminator" if self.discrim_step(num_updates) else "generator" + + def __init__(self, cfg: Wav2vec_UConfig, target_dict): + super().__init__() + + self.cfg = cfg + self.zero_index = target_dict.index("<SIL>") if "<SIL>" in target_dict else 0 + self.smoothness_weight = cfg.smoothness_weight + self.wgan_loss = cfg.wgan_loss + + output_size = len(target_dict) + self.pad = target_dict.pad() + self.eos = target_dict.eos() + self.smoothing = cfg.smoothing + self.smoothing_one_sided = cfg.smoothing_one_sided + self.no_softmax = cfg.no_softmax + self.gumbel = cfg.gumbel + self.hard_gumbel = cfg.hard_gumbel + self.last_acc = None + + self.gradient_penalty = cfg.gradient_penalty + self.code_penalty = cfg.code_penalty + self.blank_weight = cfg.blank_weight + self.blank_mode = cfg.blank_mode + self.blank_index = target_dict.index("<SIL>") if cfg.blank_is_sil else 0 + assert self.blank_index != target_dict.unk() + + self.discriminator = self.Discriminator(output_size, cfg) + for p in 
self.discriminator.parameters(): + p.param_group = "discriminator" + + self.pca_A = self.pca_b = None + d = cfg.input_dim + + self.segmenter = SEGMENT_FACTORY[cfg.segmentation.type](cfg.segmentation) + + self.generator = self.Generator( + d, output_size, cfg, lambda x: self.normalize(x)[0] + ) + + for p in self.generator.parameters(): + p.param_group = "generator" + + for p in self.segmenter.parameters(): + p.param_group = "generator" + + self.max_temp, self.min_temp, self.temp_decay = cfg.temp + self.curr_temp = self.max_temp + self.update_num = 0 + + @classmethod + def build_model(cls, cfg, task): + return cls(cfg, task.target_dictionary) + + def get_logits( + self, + net_output: Optional[Dict[str, List[Optional[torch.Tensor]]]], + normalize: bool = False, + ): + logits = net_output["logits"] + + if self.blank_weight != 0: + if self.blank_mode == "add": + logits[..., self.blank_index] += self.blank_weight + elif self.blank_mode == "set": + logits[..., self.blank_index] = self.blank_weight + else: + raise Exception(f"invalid blank mode {self.blank_mode}") + + padding = net_output["padding_mask"] + if padding.any(): + logits[padding] = float("-inf") + logits[padding][..., self.blank_index] = float("inf") + + if normalize: + logits = utils.log_softmax(logits.float(), dim=-1) + + return logits.transpose(0, 1) + + def get_normalized_probs( + self, + net_output: Tuple[ + torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]] + ], + log_probs: bool, + sample: Optional[Dict[str, torch.Tensor]] = None, + ): + logits = self.get_logits(net_output) + + probs = super().get_normalized_probs(logits, log_probs, sample) + # BTC -> TBC for ctc + probs = probs.transpose(0, 1) + return probs + + def normalize(self, dense_x): + + bsz, tsz, csz = dense_x.shape + + if dense_x.numel() == 0: + raise Exception(dense_x.shape) + _, k = dense_x.max(-1) + hard_x = ( + dense_x.new_zeros(bsz * tsz, csz) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(-1, csz) + ) + hard_probs = 
torch.mean(hard_x.float(), dim=0) + code_perplexity = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ) + + avg_probs = torch.softmax(dense_x.reshape(-1, csz).float(), dim=-1).mean(dim=0) + prob_perplexity = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ) + + if not self.no_softmax: + if self.training and self.gumbel: + dense_x = F.gumbel_softmax( + dense_x.float(), tau=self.curr_temp, hard=self.hard_gumbel + ).type_as(dense_x) + else: + dense_x = dense_x.softmax(-1) + + return dense_x, code_perplexity, prob_perplexity + + def forward( + self, + features, + padding_mask, + random_label=None, + dense_x_only=False, + segment=True, + ): + if segment: + features, padding_mask = self.segmenter.pre_segment(features, padding_mask) + + orig_size = features.size(0) * features.size(1) - padding_mask.sum() + + gen_result = self.generator(features, random_label, padding_mask) + + orig_dense_x, token_x = gen_result["dense_x"], gen_result["token_x"] + orig_dense_padding_mask = gen_result["dense_padding_mask"] + + if segment: + dense_x, dense_padding_mask = self.segmenter.logit_segment( + orig_dense_x, orig_dense_padding_mask + ) + else: + dense_x = orig_dense_x + dense_padding_mask = orig_dense_padding_mask + + dense_logits = dense_x + prob_perplexity = None + code_perplexity = None + + if not (self.no_softmax and dense_x_only): + dense_x, code_perplexity, prob_perplexity = self.normalize(dense_logits) + + if dense_x_only or self.discriminator is None: + return { + "logits": dense_x, + "padding_mask": dense_padding_mask, + } + + token_padding_mask = random_label == self.pad + + dense_y = self.discriminator(dense_x, dense_padding_mask) + token_y = self.discriminator(token_x, token_padding_mask) + + sample_size = features.size(0) + + d_step = self.discrim_step(self.update_num) + + fake_smooth = self.smoothing + real_smooth = self.smoothing + if self.smoothing_one_sided: + fake_smooth = 0 + + zero_loss = None + 
smoothness_loss = None + code_pen = None + + if d_step: + if self.wgan_loss: + loss_dense = dense_y.sum() + loss_token = -1 * token_y.sum() + else: + loss_dense = F.binary_cross_entropy_with_logits( + dense_y, + dense_y.new_ones(dense_y.shape) - fake_smooth, + reduction="sum", + ) + loss_token = F.binary_cross_entropy_with_logits( + token_y, + token_y.new_zeros(token_y.shape) + real_smooth, + reduction="sum", + ) + if self.training and self.gradient_penalty > 0: + grad_pen = self.calc_gradient_penalty(token_x, dense_x) + grad_pen = grad_pen.sum() * self.gradient_penalty + else: + grad_pen = None + else: + grad_pen = None + loss_token = None + if self.update_num >= self.zero_pretrain_updates: + if self.wgan_loss: + loss_dense = -1 * dense_y.sum() + else: + loss_dense = F.binary_cross_entropy_with_logits( + dense_y, + dense_y.new_zeros(dense_y.shape) + fake_smooth, + reduction="sum", + ) + num_vars = dense_x.size(-1) + if prob_perplexity is not None: + code_pen = (num_vars - prob_perplexity) / num_vars + if self.exponential_code_pen: + code_pen = (1 - 1 / code_pen ** 2).exp() + code_pen = code_pen * sample_size * self.code_penalty + else: + loss_dense = None + + if self.smoothness_weight > 0: + smoothness_loss = F.mse_loss( + dense_logits[:, :-1], dense_logits[:, 1:], reduction="none" + ) + smoothness_loss[dense_padding_mask[:, 1:]] = 0 + smoothness_loss = ( + smoothness_loss.mean() * sample_size * self.smoothness_weight + ) + + result = { + "losses": { + "grad_pen": grad_pen, + "code_pen": code_pen, + "smoothness": smoothness_loss, + }, + "temp": self.curr_temp, + "code_ppl": code_perplexity, + "prob_ppl": prob_perplexity, + "d_steps": int(d_step), + "sample_size": sample_size, + } + + suff = "_d" if d_step else "_g" + result["losses"]["dense" + suff] = loss_dense + result["losses"]["token" + suff] = loss_token + + return result diff --git a/examples/wav2vec/unsupervised/scripts/apply_pca.py b/examples/wav2vec/unsupervised/scripts/apply_pca.py new file mode 100644 
index 0000000000..0cddd20001 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/apply_pca.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import os.path as osp +import math +import numpy as np +import tqdm +import torch +from shutil import copyfile + +from npy_append_array import NpyAppendArray + + +def get_parser(): + parser = argparse.ArgumentParser( + description="transforms features via a given pca and stores them in target dir" + ) + # fmt: off + parser.add_argument('source', help='directory with features') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--pca-path', type=str, help='pca location. will append _A.npy and _b.npy', required=True) + parser.add_argument('--batch-size', type=int, default=2048000, help='batch size') + parser.add_argument('--unfiltered', action='store_true', help='process the unfiltered version') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + source_path = osp.join(args.source, args.split) + data_path = source_path + "_unfiltered" if args.unfiltered else source_path + + print(f"data path: {data_path}") + + features = np.load(data_path + ".npy", mmap_mode="r") + pca_A = torch.from_numpy(np.load(args.pca_path + "_A.npy")).cuda() + pca_b = torch.from_numpy(np.load(args.pca_path + "_b.npy")).cuda() + + os.makedirs(args.save_dir, exist_ok=True) + save_path = osp.join(args.save_dir, args.split) + + copyfile(source_path + ".tsv", save_path + ".tsv") + copyfile(data_path + ".lengths", save_path + ".lengths") + copyfile(source_path + ".phn", save_path + ".phn") + copyfile(source_path + ".wrd", save_path + ".wrd") + + if osp.exists(save_path + 
".npy"): + os.remove(save_path + ".npy") + npaa = NpyAppendArray(save_path + ".npy") + + batches = math.ceil(features.shape[0] / args.batch_size) + + with torch.no_grad(): + for b in tqdm.trange(batches): + start = b * args.batch_size + end = start + args.batch_size + x = torch.from_numpy(features[start:end]).cuda() + x = torch.matmul(x, pca_A) + pca_b + npaa.append(x.cpu().numpy()) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/copy_labels.py b/examples/wav2vec/unsupervised/scripts/copy_labels.py new file mode 100644 index 0000000000..989868388e --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/copy_labels.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +for idx, line in enumerate(sys.stdin): + print(f"utt{idx:010d} {line}", end="") diff --git a/examples/wav2vec/unsupervised/scripts/filter_lexicon.py b/examples/wav2vec/unsupervised/scripts/filter_lexicon.py new file mode 100644 index 0000000000..5bf3e51e7a --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/filter_lexicon.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import sys + +from fairseq.data import Dictionary + + +def get_parser(): + parser = argparse.ArgumentParser( + description="filters a lexicon given a unit dictionary" + ) + parser.add_argument("-d", "--unit-dict", help="unit dictionary", required=True) + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + d = Dictionary.load(args.unit_dict) + symbols = set(d.symbols) + + for line in sys.stdin: + items = line.rstrip().split() + skip = len(items) < 2 + for x in items[1:]: + if x not in symbols: + skip = True + break + if not skip: + print(line, end="") + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/filter_tsv.py b/examples/wav2vec/unsupervised/scripts/filter_tsv.py new file mode 100644 index 0000000000..a09d79acf3 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/filter_tsv.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import argparse +import sys + + +parser = argparse.ArgumentParser() +parser.add_argument("--tsv", required=True, type=str) +parser.add_argument("--no-skip", action="store_true") +parser.add_argument("--keep", action="store_true") +params = parser.parse_args() + + +def get_fname(line): + p = os.path.basename(line.split("\t")[0]) + p = os.path.splitext(p)[0] + return p + + +# filenames to exclude +seen = set() +with open(params.tsv) as f: + if not params.no_skip: + root = next(f).rstrip() + for line in f: + seen.add(get_fname(line)) + +for i, line in enumerate(sys.stdin): + exists = get_fname(line) in seen + keep = (exists and params.keep) or (not exists and not params.keep) + if i == 0 or keep: + print(line, end="") diff --git a/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py b/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py new file mode 100644 index 0000000000..8c3138e55b --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import sys + +from g2p_en import G2p + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("root_dirs", nargs="*") + parser.add_argument("--insert-silence", "-s", action="store_true") + args = parser.parse_args() + sil = "<s>" + + wrd_to_phn = {} + g2p = G2p() + for line in sys.stdin: + words = line.strip().split() + phones = [] + if args.insert_silence: + phones.append(sil) + for w in words: + if w not in wrd_to_phn: + wrd_to_phn[w] = g2p(w) + phones.extend(wrd_to_phn[w]) + if args.insert_silence: + phones.append(sil) + try: + print(" ".join(phones)) + except: + print(wrd_to_phn, w, phones, file=sys.stderr) + raise + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py b/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py new file mode 100644 index 0000000000..36c85d1e2f --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + + +def main(): + for line in sys.stdin: + print(line.replace(" ", "").replace("|", " ").strip()) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/mean_pool.py b/examples/wav2vec/unsupervised/scripts/mean_pool.py new file mode 100644 index 0000000000..1145e774eb --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/mean_pool.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import os +import os.path as osp +import math +import numpy as np +import tqdm +import torch +import torch.nn.functional as F +from shutil import copyfile + +from npy_append_array import NpyAppendArray + + +def get_parser(): + parser = argparse.ArgumentParser( + description="mean pools representations by compressing uniform splits of the data" + ) + # fmt: off + parser.add_argument('source', help='directory with features') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--subsample-rate', type=float, default=0.5, help='size to subsample data to') + + parser.add_argument('--remove-extra', action='store_true', help='if true, removes extra states that cant be pooled, otherwise pads with 0s') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + source_path = osp.join(args.source, args.split) + + print(f"data path: {source_path}") + + features = np.load(source_path + ".npy", mmap_mode="r") + + os.makedirs(args.save_dir, exist_ok=True) + save_path = osp.join(args.save_dir, args.split) + + copyfile(source_path + ".tsv", save_path + ".tsv") + copyfile(source_path + ".phn", save_path + ".phn") + copyfile(source_path + ".wrd", save_path + ".wrd") + + if osp.exists(save_path + ".npy"): + os.remove(save_path + ".npy") + npaa = NpyAppendArray(save_path + ".npy") + + with open(source_path + ".lengths", "r") as lf: + lengths = lf.readlines() + + fsz = features.shape[-1] + start = 0 + with torch.no_grad(): + with open(save_path + ".lengths", "w") as lengths_out: + for length in tqdm.tqdm(lengths): + length = int(length) + end = start + length + feats = features[start:end] + start += length + x = torch.from_numpy(feats).cuda() + target_num = math.ceil(length * args.subsample_rate) + rem = length % target_num + + if rem > 0: + if args.remove_extra: + to_rem = target_num - rem + 
target_num -= 1 + x = x[:-to_rem] + else: + to_add = target_num - rem + x = F.pad(x, [0, 0, 0, to_add]) + x[-to_add:] = x[-to_add - 1] + + x = x.view(target_num, -1, fsz) + x = x.mean(dim=-2) + print(target_num, file=lengths_out) + npaa.append(x.cpu().numpy()) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/merge_clusters.py b/examples/wav2vec/unsupervised/scripts/merge_clusters.py new file mode 100644 index 0000000000..6502ed5718 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/merge_clusters.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import os.path as osp +import numpy as np +import tqdm +import torch +import random +from shutil import copyfile + +from npy_append_array import NpyAppendArray + + +def get_parser(): + parser = argparse.ArgumentParser( + description="merges consecutive features assigned to the same cluster and stores them in target dir" + ) + # fmt: off + parser.add_argument('source', help='directory with features') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--cluster-dir', help='where the clusters are') + parser.add_argument('--pooling', type=str, default='mean', choices=['mean', 'sample'], help='how to pool') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + source_path = osp.join(args.source, args.split) + cluster_path = osp.join(args.cluster_dir, args.split + ".src") + print(f"data path: {source_path}") + + features = np.load(source_path + ".npy", mmap_mode="r") + sizes = [] + offsets = [] + offset = 0 + with open(source_path + ".lengths", "r") as len_f: + for line in len_f: + length = int(line.rstrip()) + 
sizes.append(length) + offsets.append(offset) + offset += length + + clusters = [] + with open(cluster_path, "r") as cf: + for line in cf: + line = line.rstrip() + items = line.split() + items = list(map(int, items)) + clusters.append(items) + + os.makedirs(args.save_dir, exist_ok=True) + save_path = osp.join(args.save_dir, args.split) + + copyfile(source_path + ".tsv", save_path + ".tsv") + copyfile(source_path + ".phn", save_path + ".phn") + if os.path.exists(source_path + ".phnsc"): + copyfile(source_path + ".phnsc", save_path + ".phnsc") + copyfile( + osp.join(args.source, "dict.phnsc.txt"), + osp.join(args.save_dir, "dict.phnsc.txt"), + ) + copyfile(source_path + ".wrd", save_path + ".wrd") + + if osp.exists(save_path + ".npy"): + os.remove(save_path + ".npy") + npaa = NpyAppendArray(save_path + ".npy") + + def merge(feats, clust): + feats = torch.from_numpy(feats.copy()) + clust = torch.LongTensor(clust) + _, counts = clust.unique_consecutive(return_counts=True) + curr = 0 + + merged = [] + for c in counts: + c = c.item() + start = curr + end = curr + c + curr += c + if args.pooling == "mean": + new_x = feats[start:end].mean(dim=0) + elif args.pooling == "sample": + new_x = feats[start + int(random.random() * c)] + else: + raise NotImplementedError() + merged.append(new_x) + + return torch.stack(merged, dim=0).numpy() + + with open(save_path + ".lengths", "w") as l_f: + for size, offset, clust in tqdm.tqdm( + zip(sizes, offsets, clusters), total=len(sizes) + ): + end = size + offset + feats = features[offset:end] + feats = merge(feats, clust) + print(len(feats), file=l_f) + npaa.append(feats) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py b/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py new file mode 100644 index 0000000000..1284747795 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# 
Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import fasttext as ft +import regex +import sys + + +def get_parser(): + parser = argparse.ArgumentParser( + description="reads text from stdin and outputs normalized, lid-filtered version to stdout" + ) + parser.add_argument( + "--fasttext-model", + help="path to fasttext model", + default="lid.187.bin", + ) + parser.add_argument("--lang", help="language id", required=True) + parser.add_argument( + "--lid-threshold", + type=float, + help="threshold for this lang id probability", + default=0.4, + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]") + + lg = args.lang.lower() + lg_label = f"__label__{lg}" + thresh = args.lid_threshold + + model = ft.load_model(args.fasttext_model) + for line in sys.stdin: + line = line.strip() + line = filter_r.sub(" ", line) + line = " ".join(line.split()) + lid, prob = model.predict(line, k=100) + try: + target_idx = lid.index(lg_label) + except ValueError: + continue + if target_idx == 0 or prob[target_idx] >= thresh: + print(line) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/normalize_text.py b/examples/wav2vec/unsupervised/scripts/normalize_text.py new file mode 100644 index 0000000000..9d0ffeb27d --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/normalize_text.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import regex +import sys + + +def main(): + filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]") + + for line in sys.stdin: + line = line.strip() + line = filter_r.sub(" ", line) + line = " ".join(line.split()) + print(line) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/pca.py b/examples/wav2vec/unsupervised/scripts/pca.py new file mode 100644 index 0000000000..948cf5319f --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/pca.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import os.path as osp +import numpy as np + +import faiss + + + +def get_parser(): + parser = argparse.ArgumentParser( + description="compute a pca matrix given an array of numpy features" + ) + # fmt: off + parser.add_argument('data', help='numpy file containing features') + parser.add_argument('--output', help='where to save the pca matrix', required=True) + parser.add_argument('--dim', type=int, help='dim for pca reduction', required=True) + parser.add_argument('--eigen-power', type=float, default=0, help='eigen power, -0.5 for whitening') + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + print("Reading features") + x = np.load(args.data, mmap_mode="r") + + print("Computing PCA") + pca = faiss.PCAMatrix(x.shape[-1], args.dim, args.eigen_power) + pca.train(x) + b = faiss.vector_to_array(pca.b) + A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in) + + os.makedirs(args.output, exist_ok=True) + + prefix = str(args.dim) + if args.eigen_power != 0: + prefix += f"_{args.eigen_power}" + + np.save(osp.join(args.output, f"{prefix}_pca_A"), A.T) + np.save(osp.join(args.output, f"{prefix}_pca_b"), b) + + +if __name__ == "__main__": + main() diff --git 
a/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py b/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py new file mode 100644 index 0000000000..c6512d7322 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import numpy as np +import sys + + +def get_parser(): + parser = argparse.ArgumentParser( + description="converts words to phones adding optional silences around in between words" + ) + parser.add_argument( + "--sil-prob", + "-s", + type=float, + default=0, + help="probability of inserting silence between each word", + ) + parser.add_argument( + "--surround", + action="store_true", + help="if set, surrounds each example with silence", + ) + parser.add_argument( + "--lexicon", + help="lexicon to convert to phones", + required=True, + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + sil_prob = args.sil_prob + surround = args.surround + sil = "<SIL>" + + wrd_to_phn = {} + + with open(args.lexicon, "r") as lf: + for line in lf: + items = line.rstrip().split() + assert len(items) > 1, line + assert items[0] not in wrd_to_phn, items + wrd_to_phn[items[0]] = items[1:] + + for line in sys.stdin: + words = line.strip().split() + + if not all(w in wrd_to_phn for w in words): + continue + + phones = [] + if surround: + phones.append(sil) + + sample_sil_probs = None + if sil_prob > 0 and len(words) > 1: + sample_sil_probs = np.random.random(len(words) - 1) + + for i, w in enumerate(words): + phones.extend(wrd_to_phn[w]) + if ( + sample_sil_probs is not None + and i < len(sample_sil_probs) + and sample_sil_probs[i] < sil_prob + ): + phones.append(sil) + + if surround: + phones.append(sil) + print(" ".join(phones)) + + +if __name__ == 
"__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/prepare_audio.sh b/examples/wav2vec/unsupervised/scripts/prepare_audio.sh new file mode 100644 index 0000000000..893c9fda1a --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/prepare_audio.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env zsh +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +source_dir=$1 +tgt_dir=$2 +model=$3 + +if [ -z "$4" ] + then + dim=512 + else + dim=$4 +fi + +echo "using $dim dim for PCA" + +train_split=train +valid_split=valid +test_split=test + +mkdir -p $tgt_dir + +cp $source_dir/*.tsv $tgt_dir +cp $source_dir/*.wrd $tgt_dir +cp $source_dir/*.ltr $tgt_dir +cp $source_dir/*.phn $tgt_dir +cp $source_dir/dict* $tgt_dir + +setopt shwordsplit + +for split in $train_split $valid_split $test_split; do + python wav2vec_extract_features.py $source_dir --split $split \ + --save-dir $tgt_dir --checkpoint $model +done + +python wav2vec_cluster_faiss.py $tgt_dir/${train_split}.tsv \ +--checkpoint $model --save-dir $tgt_dir -f "CLUS128" --sample-pct 1.0 + +for split in $train_split $valid_split $test_split; do + python wav2vec_apply_cluster_faiss.py $tgt_dir \ + --checkpoint $model --path $tgt_dir/CLUS128 --split $split +done + +python pca.py $tgt_dir/${train_split}.npy --output $tgt_dir/pca --dim $dim + +for split in $train_split $valid_split $test_split; do + python apply_pca.py $tgt_dir --split $split --save-dir $tgt_dir/precompute_pca$dim --pca-path $tgt_dir/pca/${dim}_pca --batch-size 1048000 + + python merge_clusters.py $tgt_dir/precompute_pca$dim --cluster-dir $tgt_dir/CLUS128 \ + --split $split --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean --pooling mean + + python mean_pool.py $tgt_dir/precompute_pca${dim}_cls128_mean \ + --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean_pooled --split $split +done \ No newline at end of file diff --git 
a/examples/wav2vec/unsupervised/scripts/prepare_text.sh b/examples/wav2vec/unsupervised/scripts/prepare_text.sh new file mode 100644 index 0000000000..e9090a3d80 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/prepare_text.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env zsh +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +lg=$1 +text_path=$2 +target_dir=$3 + +ph_lg=${lg:l} +if test "$lg" = 'fr'; then + ph_lg='fr-fr' +elif test "$lg" = 'en'; then + ph_lg='en-us' +elif test "$lg" = 'pt'; then + ph_lg='pt-br' +fi + +echo $lg +echo $ph_lg +echo $text_path +echo $target_dir + +mkdir -p $target_dir +python normalize_and_filter_text.py --lang $lg < $text_path | grep -v '\-\-\-' >! $target_dir/lm.upper.lid.txt +python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/lm.upper.lid.txt --only-source --destdir $target_dir --thresholdsrc 2 --padding-factor 1 --dict-only +cut -f1 -d' ' $target_dir/dict.txt | grep -v -x '[[:punct:]]*' | grep -Pv '\d\d\d\d\d+' >! $target_dir/words.txt + +one=$(echo "1" | PHONEMIZER_ESPEAK_PATH=$(which espeak) phonemize -p ' ' -w '' -l $ph_lg --language-switch remove-flags) +sed 's/$/ 1/' $target_dir/words.txt | PHONEMIZER_ESPEAK_PATH=$(which espeak) phonemize -o $target_dir/phones.txt -p ' ' -w '' -l $ph_lg -j 70 --language-switch remove-flags + +echo "one is ${one}" + +sed -i "s/${one}$//" $target_dir/phones.txt +paste $target_dir/words.txt $target_dir/phones.txt >! $target_dir/lexicon.lst + +python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones.txt --only-source --destdir $target_dir/phones --thresholdsrc 1000 --padding-factor 1 --dict-only + +python filter_lexicon.py -d $target_dir/phones/dict.txt < $target_dir/lexicon.lst >! 
$target_dir/lexicon_filtered.lst +python phonemize_with_sil.py -s 0.25 --surround --lexicon $target_dir/lexicon_filtered.lst < $target_dir/lm.upper.lid.txt >! $target_dir/phones/lm.phones.filtered.txt +cp $target_dir/phones/dict.txt $target_dir/phones/dict.phn.txt +echo "<SIL> 0" >> $target_dir/phones/dict.phn.txt +python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones/lm.phones.filtered.txt --workers 70 --only-source --destdir $target_dir/phones --srcdict $target_dir/phones/dict.phn.txt + +lmplz -o 4 < $target_dir/lm.upper.lid.txt --discount_fallback --prune 0 0 0 3 >! $target_dir/kenlm.wrd.o40003.arpa +build_binary $target_dir/kenlm.wrd.o40003.arpa $target_dir/kenlm.wrd.o40003.bin +lg=$lg python examples/speech_recognition/kaldi/kaldi_initializer.py fst_dir=$target_dir/fst/phn_to_words_sil lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones "blank_symbol='<SIL>'" +lg=$lg python examples/speech_recognition/kaldi/kaldi_initializer.py fst_dir=$target_dir/fst/phn_to_words lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones + +lmplz -o 4 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.04.arpa +build_binary $target_dir/phones/lm.phones.filtered.04.arpa $target_dir/phones/lm.phones.filtered.04.bin +lmplz -o 6 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! 
$target_dir/phones/lm.phones.filtered.06.arpa +build_binary $target_dir/phones/lm.phones.filtered.06.arpa $target_dir/phones/lm.phones.filtered.06.bin + +lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py fst_dir=$target_dir/fst/phn_to_phn_sil lm_arpa=$target_dir/phones/lm.phones.filtered.06.arpa data_dir=$target_dir/phones "blank_symbol='<SIL>'" diff --git a/examples/wav2vec/unsupervised/scripts/remove_silence.py b/examples/wav2vec/unsupervised/scripts/remove_silence.py new file mode 100644 index 0000000000..417b703de8 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/remove_silence.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +reads speech intervals from a .vads file, removes the silences outside them, and saves the audio data under the --out folder +paths=shards/train.tsv +vads=shards/train.vads +python remove_silence.py --tsv $paths --vads $vads --out $out +""" + +import os +import argparse +import torch +import torchaudio +import tqdm + + +parser = argparse.ArgumentParser() +parser.add_argument("--tsv", default="", type=str) +parser.add_argument("--vads", default="", type=str) +parser.add_argument("--out", type=str) +params = parser.parse_args() + +# load paths +paths = [] +with open(params.tsv) as f: + root = next(f).rstrip() + for line in f: + paths.append(os.path.join(root, line.rstrip().split("\t")[0])) + +# load vads +list_intervals = [] +with open(params.vads) as f: + for line in f: + interval = [ + [int(w.split(":")[0]), int(w.split(":")[1])] for w in line.rstrip().split() + ] + list_intervals.append(interval) + + +# load audio and keep only intervals (i.e. 
remove silences) +for i in tqdm.trange(len(paths)): + data, _ = torchaudio.load(paths[i]) + if len(list_intervals[i]) > 0: + data_filtered = torch.cat( + [data[0][int(it[0]) : int(it[1])] for it in list_intervals[i]] + ).unsqueeze(0) + else: + data_filtered = data + + # YOU MAY NEED TO MODIFY THIS TO GET THE RIGHT SUBPATH + # outpath = params.out + '/'.join(paths[i].split('/')[-1]) + outpath = params.out + "/" + "/".join(paths[i].split("/")[-2:]) + + if not os.path.isdir("/".join(outpath.split("/")[:-1])): + os.makedirs("/".join(outpath.split("/")[:-1])) + if not os.path.exists(outpath): + print(outpath) + torchaudio.save(outpath, data_filtered, sample_rate=16000) + else: + print(outpath, "exists!") diff --git a/examples/wav2vec/unsupervised/scripts/vads.py b/examples/wav2vec/unsupervised/scripts/vads.py new file mode 100644 index 0000000000..1acd95369c --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/vads.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import sys + +sys.path.append("/path/to/rVADfast_py_2.0") +import speechproc +from copy import deepcopy +from scipy.signal import lfilter + +import numpy as np +from tqdm import tqdm +import soundfile as sf +import os.path as osp + + +def rvad(path): + winlen, ovrlen, pre_coef, nfilter, nftt = 0.025, 0.01, 0.97, 20, 512 + ftThres = 0.5 + vadThres = 0.4 + opts = 1 + + data, fs = sf.read(path) + assert fs == 16_000, "sample rate must be 16khz" + ft, flen, fsh10, nfr10 = speechproc.sflux(data, fs, winlen, ovrlen, nftt) + + # --spectral flatness -- + pv01 = np.zeros(ft.shape[0]) + pv01[np.less_equal(ft, ftThres)] = 1 + pitch = deepcopy(ft) + + pvblk = speechproc.pitchblockdetect(pv01, pitch, nfr10, opts) + + # --filtering-- + ENERGYFLOOR = np.exp(-50) + b = np.array([0.9770, -0.9770]) + a = np.array([1.0000, -0.9540]) + fdata = lfilter(b, a, data, axis=0) + + # --pass 1-- + noise_samp, noise_seg, n_noise_samp = speechproc.snre_highenergy( + fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk + ) + + # sets noisy segments to zero + for j in range(n_noise_samp): + fdata[range(int(noise_samp[j, 0]), int(noise_samp[j, 1]) + 1)] = 0 + + vad_seg = speechproc.snre_vad( + fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk, vadThres + ) + return vad_seg, data + + +def main(): + stride = 160 + lines = sys.stdin.readlines() + root = lines[0].rstrip() + for fpath in tqdm(lines[1:]): + path = osp.join(root, fpath.split()[0]) + vads, wav = rvad(path) + + start = None + vad_segs = [] + for i, v in enumerate(vads): + if start is None and v == 1: + start = i * stride + elif start is not None and v == 0: + vad_segs.append((start, i * stride)) + start = None + if start is not None: + vad_segs.append((start, len(wav))) + + print(" ".join(f"{v[0]}:{v[1]}" for v in vad_segs)) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py b/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py new file mode 100644 
index 0000000000..25bc4e41ac --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os.path as osp +import numpy as np +import tqdm +import torch +import sys + +import faiss +import torch.nn.functional as F + +from wav2vec_cluster_faiss import parse_faiss_specs, Wav2VecFeatureReader + + +def get_parser(): + parser = argparse.ArgumentParser(description="apply clusters") + # fmt: off + parser.add_argument('data', help='location of tsv files') + parser.add_argument('--split', help='split to process', required=True) + parser.add_argument('--labels', help='split to process', default="phn") + parser.add_argument('--path', help='path to pca and centroids', required=True) + parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True) + parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14) + parser.add_argument('--max-tsz', type=int, help='batch kmeans up to this much', default=14) + # fmt: on + + return parser + + +def get_iterator(args): + with open(osp.join(args.data, f"{args.split}.tsv"), "r") as fp, open( + osp.join(args.data, f"{args.split}.{args.labels}"), "r" + ) as lp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + files = [line.rstrip() for line in lines if len(line) > 0] + lbls = [line.rstrip() for line in lp] + + num = len(files) + reader = Wav2VecFeatureReader(args.checkpoint, args.layer) + + def iterate(): + for fname, lbl in zip(files, lbls): + file = osp.join(root, fname.split("\t")[0]) + feats = reader.get_feats(file) + yield feats.data, fname, lbl + + return iterate, num, root + + +def main(): + parser = get_parser() + args = parser.parse_args() + + spec = 
osp.basename(args.path) + + try: + faiss_spec = parse_faiss_specs(spec.rstrip("/"))[0] + except: + print(spec) + raise + + print("Faiss Spec:", faiss_spec, file=sys.stderr) + + if faiss_spec.pca: + A = torch.from_numpy(np.load(osp.join(args.path, "pca_A.npy"))).cuda() + b = torch.from_numpy(np.load(osp.join(args.path, "pca_b.npy"))).cuda() + print("Loaded PCA", file=sys.stderr) + + centroids = np.load(osp.join(args.path, "centroids.npy")) + print("Loaded centroids", centroids.shape, file=sys.stderr) + + res = faiss.StandardGpuResources() + index_flat = ( + faiss.IndexFlatL2(centroids.shape[1]) + if not faiss_spec.sphere + else faiss.IndexFlatIP(centroids.shape[1]) + ) + faiss_index = faiss.index_cpu_to_gpu(res, 0, index_flat) + faiss_index.add(centroids) + + generator, num, root = get_iterator(args) + iterator = generator() + + with torch.no_grad(): + with open(osp.join(args.path, f"{args.split}.src"), "w") as fp, open( + osp.join(args.path, f"{args.split}.tsv"), "w" + ) as pp, open(osp.join(args.path, f"{args.split}.{args.labels}"), "w") as lp: + print(root, file=pp) + for f, fname, lbl in tqdm.tqdm(iterator, total=num): + if faiss_spec.pca: + f = torch.mm(f, A) + b + if faiss_spec.norm: + f = F.normalize(f, p=2, dim=-1) + + f = f.cpu().numpy() + + _, z = faiss_index.search(f, 1) + + print(" ".join(str(x.item()) for x in z), file=fp) + print(fname, file=pp) + print(lbl, file=lp) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py b/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py new file mode 100644 index 0000000000..632a69e9f4 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import gc +import os +import os.path as osp +import random +import numpy as np +import tqdm +import torch + +from collections import namedtuple + +import faiss + +import fairseq +import soundfile as sf + + +def get_parser(): + parser = argparse.ArgumentParser( + description="compute kmeans codebook from kaldi-computed feats" + ) + # fmt: off + parser.add_argument('data', help='location of tsv files') + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True) + parser.add_argument('--sample-pct', '-r', type=float, help='percentage of timesteps to sample', default=0) + parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14) + parser.add_argument('--faiss-specs', '-f', type=str, + help='faiss index specs; separated by space ' + 'format is: PCAx_NORM_CLUSx_SPHERICAL -> ' + 'PCAx if exists first apply PCA ' + 'NORM if exists, normalize the vector by L2 norm ' + 'CLUSx must exist, cluster to x clusters ' + 'SPHERICAL if exists, apply spherical kmeans', + default='l2') + # fmt: on + + return parser + + +faiss_spec = namedtuple("faiss_spec", ["pca", "norm", "n_clus", "sphere", "spec_str"]) + + +def parse_faiss_specs(specs_str): + specs = [] + for ss in specs_str.split(): + comps = ss.split("_") + pca = 0 + norm = False + n_clus = 0 + sphere = False + for c in comps: + if c.startswith("PCA"): + pca = int(c[3:]) + elif c == "NORM": + norm = True + elif c.startswith("CLUS"): + n_clus = int(c[4:]) + elif c == "SPHERICAL": + sphere = True + assert n_clus > 0 + specs.append( + faiss_spec(pca=pca, norm=norm, n_clus=n_clus, sphere=sphere, spec_str=ss) + ) + return specs + + +class Wav2VecFeatureReader(object): + def __init__(self, cp_file, layer): + state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(cp_file) + + self.layer = layer + + if "cfg" in state: + w2v_args = 
state["cfg"] + task = fairseq.tasks.setup_task(w2v_args.task) + model = task.build_model(w2v_args.model) + else: + w2v_args = state["args"] + task = fairseq.tasks.setup_task(w2v_args) + model = task.build_model(w2v_args) + model.load_state_dict(state["model"], strict=True) + model.eval() + model.cuda() + self.model = model + + def read_audio(self, fname): + """Load an audio file and return PCM along with the sample rate""" + wav, sr = sf.read(fname) + assert sr == 16e3 + + return wav + + def get_feats(self, loc): + x = self.read_audio(loc) + with torch.no_grad(): + source = torch.from_numpy(x).view(1, -1).float().cuda() + res = self.model( + source=source, mask=False, features_only=True, layer=self.layer + ) + return res["layer_results"][self.layer][0].squeeze(1) + + +def get_iterator(args): + with open(args.data, "r") as fp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0] + + if getattr(args, "sample_pct", 0) > 0: + files = random.sample(files, int(args.sample_pct * len(files))) + num = len(files) + reader = Wav2VecFeatureReader(args.checkpoint, args.layer) + + def iterate(): + for fname in files: + feats = reader.get_feats(fname) + yield feats.cpu().numpy() + + return iterate, num + + +def main(): + parser = get_parser() + args = parser.parse_args() + + faiss_specs = parse_faiss_specs(args.faiss_specs) + print("Faiss Specs:", faiss_specs) + + feat_path = osp.join(args.save_dir, "features") + if osp.exists(feat_path + ".npy"): + feats = np.load(feat_path + ".npy") + else: + generator, num = get_iterator(args) + iterator = generator() + + feats = [] + for f in tqdm.tqdm(iterator, total=num): + feats.append(f) + + del iterator + del generator + + feats = np.concatenate(feats) + + print(feats.shape) + + os.makedirs(args.save_dir, exist_ok=True) + # np.save(feat_path, feats) + + gc.collect() + torch.cuda.empty_cache() + + reload = False + for spec in faiss_specs: + 
print("Processing spec", spec) + + if reload: + print("Reloading...") + del feats + gc.collect() + feats = np.load(feat_path + ".npy") + + save_path = osp.join(args.save_dir, spec.spec_str) + os.makedirs(save_path, exist_ok=True) + d = feats.shape[-1] + x = feats + if spec.pca > 0: + print("Computing PCA") + pca = faiss.PCAMatrix(d, spec.pca) + pca.train(x) + d = spec.pca + b = faiss.vector_to_array(pca.b) + A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in) + np.save(osp.join(save_path, "pca_A"), A.T) + np.save(osp.join(save_path, "pca_b"), b) + print("Applying PCA") + x = pca.apply_py(x) + + if spec.norm: + reload = spec.pca <= 0 + print("Normalizing") + faiss.normalize_L2(x) + + print("Computing kmeans") + kmeans = faiss.Kmeans( + d, + spec.n_clus, + niter=50, + verbose=True, + spherical=spec.sphere, + max_points_per_centroid=feats.shape[0], + gpu=True, + nredo=3, + ) + kmeans.train(x) + np.save(osp.join(save_path, "centroids"), kmeans.centroids) + del kmeans + del x + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py b/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py new file mode 100644 index 0000000000..023dd1aaa5 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import os +import os.path as osp +import tqdm +import torch +import torch.nn.functional as F +from shutil import copyfile + +from npy_append_array import NpyAppendArray + +import fairseq +import soundfile as sf + + +def get_parser(): + parser = argparse.ArgumentParser( + description="compute kmeans codebook from kaldi-computed feats" + ) + # fmt: off + parser.add_argument('data', help='location of tsv files') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec ctc model', required=True) + # fmt: on + + return parser + + +class Wav2VecFeatureReader(object): + def __init__(self, cp_file): + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [cp_file] + ) + model = model[0] + model.eval() + model.cuda() + self.model = model + self.task = task + + def read_audio(self, fname): + """Load an audio file and return PCM along with the sample rate""" + wav, sr = sf.read(fname) + assert sr == 16e3 + + return wav + + def get_feats(self, loc): + x = self.read_audio(loc) + with torch.no_grad(): + source = torch.from_numpy(x).float().cuda() + if self.task.cfg.normalize: + assert source.dim() == 1, source.dim() + with torch.no_grad(): + source = F.layer_norm(source, source.shape) + source = source.view(1, -1) + + m_res = self.model(source=source, mask=False, features_only=True, layer=14) + return m_res["x"].squeeze(0).cpu() + + +def get_iterator(args): + with open(osp.join(args.data, args.split) + ".tsv", "r") as fp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0] + + num = len(files) + reader = Wav2VecFeatureReader(args.checkpoint) + + def iterate(): + for fname in files: + w2v_feats = reader.get_feats(fname) + yield w2v_feats + + return iterate, num + + +def 
main(): + parser = get_parser() + args = parser.parse_args() + + os.makedirs(args.save_dir, exist_ok=True) + + def create_files(dest): + copyfile(osp.join(args.data, args.split) + ".tsv", dest + ".tsv") + if osp.exists(osp.join(args.data, args.split) + ".wrd"): + copyfile(osp.join(args.data, args.split) + ".wrd", dest + ".wrd") + if osp.exists(osp.join(args.data, args.split) + ".phn"): + copyfile(osp.join(args.data, args.split) + ".phn", dest + ".phn") + + if osp.exists(dest + ".npy"): + os.remove(dest + ".npy") + npaa = NpyAppendArray(dest + ".npy") + return npaa + + save_path = osp.join(args.save_dir, args.split) + npaa = create_files(save_path) + + generator, num = get_iterator(args) + iterator = generator() + + with open(save_path + ".lengths", "w") as l_f: + for w2v_feats in tqdm.tqdm(iterator, total=num): + print(len(w2v_feats), file=l_f) + + if len(w2v_feats) > 0: + npaa.append(w2v_feats.numpy()) + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/scripts/wer.py b/examples/wav2vec/unsupervised/scripts/wer.py new file mode 100644 index 0000000000..613ab50d39 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/wer.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Implement unsupervised metric for decoding hyperparameter selection: + $$ alpha * LM_PPL + ViterbiUER(%) * 100 $$ +""" +import argparse +import logging +import sys + +import editdistance + +logging.root.setLevel(logging.INFO) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-s", "--hypo", help="hypo transcription", required=True) + parser.add_argument( + "-r", "--reference", help="reference transcription", required=True + ) + return parser + + +def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p): + d_cnt = 0 + w_cnt = 0 + w_cnt_h = 0 + for uid in hyp_uid_to_tra: + ref = ref_uid_to_tra[uid].split() + if g2p is not None: + hyp = g2p(hyp_uid_to_tra[uid]) + hyp = [p for p in hyp if p != "'" and p != " "] + hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp] + else: + hyp = hyp_uid_to_tra[uid].split() + d_cnt += editdistance.eval(ref, hyp) + w_cnt += len(ref) + w_cnt_h += len(hyp) + wer = float(d_cnt) / w_cnt + logger.debug( + ( + f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; " + f"num. of hyp words = {w_cnt_h}; num. 
of sentences = {len(ref_uid_to_tra)}" + ) + ) + return wer + + +def main(): + args = get_parser().parse_args() + + errs = 0 + count = 0 + with open(args.hypo, "r") as hf, open(args.reference, "r") as rf: + for h, r in zip(hf, rf): + h = h.rstrip().split() + r = r.rstrip().split() + errs += editdistance.eval(r, h) + count += len(r) + + logger.info(f"UER: {errs / count * 100:.2f}%") + + +if __name__ == "__main__": + main() + + +def load_tra(tra_path): + with open(tra_path, "r") as f: + uid_to_tra = {} + for line in f: + uid, tra = line.split(None, 1) + uid_to_tra[uid] = tra + logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}") + return uid_to_tra diff --git a/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py b/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py new file mode 100644 index 0000000000..f83471409a --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + + +def main(): + for line in sys.stdin: + print(" ".join(list(line.strip().replace(" ", "|"))) + " |") + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec/unsupervised/tasks/__init__.py b/examples/wav2vec/unsupervised/tasks/__init__.py new file mode 100644 index 0000000000..6d7dd625e0 --- /dev/null +++ b/examples/wav2vec/unsupervised/tasks/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .unpaired_audio_text import UnpairedAudioText + + +__all__ = [ + "UnpairedAudioText", +] diff --git a/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py b/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py new file mode 100644 index 0000000000..0b770a1509 --- /dev/null +++ b/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py @@ -0,0 +1,437 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from dataclasses import dataclass, field +import logging +import math +import os +from typing import Optional +import torch + +from fairseq.logging import metrics +from fairseq.tasks import FairseqTask, register_task +from ..data import ExtractedFeaturesDataset, RandomInputDataset + +from fairseq.data import ( + Dictionary, + data_utils, + StripTokenDataset, +) +from fairseq.dataclass import FairseqDataclass +from fairseq.distributed.utils import get_data_parallel_world_size +from omegaconf import MISSING + +from examples.speech_recognition.kaldi.kaldi_decoder import ( + KaldiDecoder, + KaldiDecoderConfig, +) + + +logger = logging.getLogger(__name__) + + +@dataclass +class DecodingConfig(FairseqDataclass): + kenlm_path: Optional[str] = None + lm_weight: float = 0 + blank_weight: float = 0 + + +@dataclass +class UnpairedAudioTextConfig(FairseqDataclass): + data: str = field( + default=MISSING, metadata={"help": "path to data directory containing audio"} + ) + text_data: str = field( + default=MISSING, metadata={"help": "path to data directory containing text"} + ) + max_length: Optional[int] = None + labels: Optional[str] = field( + default=None, + metadata={"help": "extension of the label file to load, used for fine-tuning"}, + ) + unfiltered: bool = field( + default=False, metadata={"help": "load data with 
_unfiltered suffix"} + ) + ctc_eval: bool = field( + default=False, metadata={"help": "eval UER as if computed by CTC"} + ) + sort_by_length: bool = field( + default=True, metadata={"help": "sort examples by length of audio timesteps"} + ) + shuffle: bool = field(default=True, metadata={"help": "shuffle examples"}) + append_eos: bool = field(default=False, metadata={"help": "append eos"}) + uppercase: Optional[bool] = field( + default=False, metadata={"help": "uppercase for LM score computation"} + ) + skipwords: Optional[str] = field( + default="", + metadata={ + "help": "comma-separated words to be removed for LM score computation" + }, + ) + kenlm_path: Optional[str] = None + vocab_usage_power: float = 2 + + word_decoder_config: Optional[KaldiDecoderConfig] = None + word_kenlm_path: Optional[str] = None + + decoding_config: DecodingConfig = DecodingConfig() + + +@register_task("gan_audio_pretraining_feats", dataclass=UnpairedAudioTextConfig) +class UnpairedAudioText(FairseqTask): + """ """ + + cfg: UnpairedAudioTextConfig + + def __init__( + self, + cfg: UnpairedAudioTextConfig, + source_dictionary=None, + target_dictionary=None, + ): + super().__init__(cfg) + + self._target_dictionary = target_dictionary + self._source_dictionary = source_dictionary + self.num_symbols = ( + len([s for s in target_dictionary.symbols if not s.startswith("madeup")]) + - target_dictionary.nspecial + ) + self.sil_id = ( + target_dictionary.index("<SIL>") if "<SIL>" in target_dictionary else -1 + ) + self.kenlm = None + if cfg.kenlm_path is not None: + import kenlm + + self.kenlm = kenlm.Model(cfg.kenlm_path) + + self.word_kenlm = None + if cfg.word_kenlm_path is not None: + import kenlm + + self.word_kenlm = kenlm.Model(cfg.word_kenlm_path) + + self.uppercase = cfg.uppercase + self.skipwords = set(cfg.skipwords.split(",")) + + def str_postprocess(s): + s = " ".join(w for w in s.split() if w not in self.skipwords) + s = s.upper() if self.uppercase else s + return s + + 
self.str_postprocess = str_postprocess + self.compute_lm_score = lambda s: self.kenlm.score(self.str_postprocess(s)) + + self.compute_word_score = None + if cfg.word_decoder_config is not None: + self.kaldi_decoder = KaldiDecoder(cfg.word_decoder_config, beam=10) + + def compute_word_score(logits, padding): + res = self.kaldi_decoder.decode(logits, padding) + for r in res: + r = r.result() + assert len(r) == 1 + r = r[0] + yield r["score"], r["words"] + + self.compute_word_score = compute_word_score + + @classmethod + def setup_task(cls, cfg: UnpairedAudioTextConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (UnpairedAudioTextConfig): configuration of this task + """ + + dict_path = os.path.join(cfg.text_data, "dict.txt") + if os.path.exists(dict_path): + target_dictionary = Dictionary.load(dict_path) + else: + dict_path = os.path.join(cfg.data, f"dict.{cfg.labels}.txt") + target_dictionary = Dictionary.load(dict_path) + + return cls(cfg, target_dictionary=target_dictionary) + + def optimizer_step(self, optimizer, model, update_num): + if hasattr(model, "get_groups_for_update"): + groups = model.get_groups_for_update(update_num) + optimizer.step(groups={groups}) + else: + optimizer.step() + + def valid_step(self, sample, model, criterion): + res = model( + **sample["net_input"], + dense_x_only=True, + ) + + dense_x = res["logits"] + padding_mask = res["padding_mask"] + + word_scores = None + if self.compute_word_score is not None: + word_scores = self.compute_word_score(dense_x.cpu(), padding_mask.cpu()) + + z = dense_x.argmax(-1) + z[padding_mask] = self.target_dictionary.pad() + + vocab_seen = torch.zeros(self.num_symbols, dtype=torch.bool) + + import editdistance + + c_err = 0 + c_len = 0 + pred_c_len = 0 + lm_score_sum = 0 + for i, (x, t, id) in enumerate( + zip( + z, + sample["target"], + sample["id"], + ) + ): + + t = t[(t >= self.target_dictionary.nspecial)] + x = x[ + (x >= self.target_dictionary.nspecial) + & (x < 
(self.num_symbols + self.target_dictionary.nspecial)) + ] + if self.sil_id >= 0: + x = x[x != self.sil_id] + + vocab_seen[x - self.target_dictionary.nspecial] = True + + pred_units_arr = x + if self.cfg.ctc_eval: + pred_units_arr = pred_units_arr.unique_consecutive() + pred_units_arr = pred_units_arr[pred_units_arr != 0] + + if id == 0: + logger.info(f"REF: {self.target_dictionary.string(t)}") + logger.info(f"HYP: {self.target_dictionary.string(pred_units_arr)}") + + if self.kenlm is not None: + ref_lm_s = self.compute_lm_score(self.target_dictionary.string(t)) + hyp_lm_s = self.compute_lm_score( + self.target_dictionary.string(pred_units_arr) + ) + logger.info( + f"LM [REF]: {ref_lm_s}, {math.pow(10, ref_lm_s / (len(t) + 1))}" + ) + logger.info( + f"LM [HYP]: {hyp_lm_s}, {math.pow(10, hyp_lm_s / (len(pred_units_arr) + 1))}" + ) + + pred_units_arr = pred_units_arr.tolist() + + t = t.tolist() + c_err += editdistance.eval(pred_units_arr, t) + c_len += len(t) + pred_c_len += len(pred_units_arr) + + if self.kenlm is not None: + pred_str = self.target_dictionary.string(pred_units_arr) + lm_score = self.compute_lm_score(pred_str) + lm_score_sum += lm_score + + kaldi_score_sum = 0 + word_lm_sum = 0 + num_words = 0 + if word_scores is not None: + for score, words in word_scores: + kaldi_score_sum += score + num_words += len(words) + if self.word_kenlm is not None: + word_lm_sum += self.word_kenlm.score(" ".join(words)) + + try: + world_size = get_data_parallel_world_size() + except Exception: + world_size = 1 + + logging_output = { + "loss": c_err, + "_num_char_errors": c_err, + "_num_chars": c_len, + "_num_pred_chars": pred_c_len, + "ntokens": c_len, + "nsentences": z.size(0), + "sample_size": c_len, + "_world_size": world_size, + "_lm_score_sum": lm_score_sum, + "_kaldi_score_sum": kaldi_score_sum, + "_word_lm_sum": word_lm_sum, + "_num_words": num_words, + "_vocab_seen": vocab_seen, + } + + return c_err, c_len, logging_output + + def load_dataset(self, split: str, task_cfg: 
FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + task_cfg = task_cfg or self.cfg + + has_unpaired_text = os.path.exists( + os.path.join(self.cfg.text_data, f"{split}.idx") + ) + + self.datasets[split] = ExtractedFeaturesDataset( + path=data_path, + split=split, + min_length=3, + max_length=task_cfg.max_length, + labels=None if has_unpaired_text else task_cfg.labels, + label_dict=self.target_dictionary, + shuffle=getattr(task_cfg, "shuffle", True), + sort_by_length=task_cfg.sort_by_length, + ) + + logger.info(f"split {split} has unpaired text? {has_unpaired_text}") + if has_unpaired_text: + text_dataset = data_utils.load_indexed_dataset( + os.path.join(self.cfg.text_data, split), self.target_dictionary + ) + text_dataset = StripTokenDataset(text_dataset, self.target_dictionary.eos()) + self.datasets[split] = RandomInputDataset( + self.datasets[split], + text_dataset, + ["random_label"], + add_to_input=True, + pad_idx=self.target_dictionary.pad(), + ) + + @property + def source_dictionary(self): + return self._source_dictionary + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self._target_dictionary + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + num_pred_chars = sum( + log.get("_num_pred_chars", zero) for log in logging_outputs + ) + + lm_score_sum = sum(log.get("_lm_score_sum", zero) for log in logging_outputs) + vocab_seen = ( + 
sum(log.get("_vocab_seen", zero) for log in logging_outputs) + .bool() + .sum() + .item() + ) + kaldi_score_sum = sum( + log.get("_kaldi_score_sum", zero) for log in logging_outputs + ) + word_lm_sum = sum(log.get("_word_lm_sum", zero) for log in logging_outputs) + + metrics.log_scalar_sum("_num_char_errors", num_char_errors) + metrics.log_scalar_sum("_num_chars", num_chars) + metrics.log_scalar_sum("_num_word_errors", num_word_errors) + metrics.log_scalar_sum("_num_words", num_words) + + metrics.log_scalar_sum("lm_score_sum", lm_score_sum) + metrics.log_scalar_sum("num_pred_chars", num_pred_chars) + + if self.cfg.word_kenlm_path is not None: + metrics.log_scalar_sum("kaldi_score_sum", kaldi_score_sum) + metrics.log_scalar_sum("word_lm_sum", word_lm_sum) + + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + + if lm_score_sum < 0 and vocab_seen > 0: + metrics.log_scalar("vocab_seen_pct", vocab_seen / self.num_symbols) + + metrics.log_derived( + "weighted_lm_ppl", + lambda meters: math.pow( + 10, + -meters["lm_score_sum"].sum + / ( + meters["num_pred_chars"].sum + meters["nsentences"].sum + ), # account for </s> + ) + / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power, + ) + + metrics.log_derived( + "lm_ppl", + lambda meters: math.pow( + 10, + -meters["lm_score_sum"].sum + / ( + meters["num_pred_chars"].sum + meters["nsentences"].sum + ), # account for </s> + ), + ) + else: + metrics.log_derived("weighted_lm_ppl", lambda meters: float("inf")) + + if num_words > 0: + if word_lm_sum != 0: + metrics.log_derived( + "word_lm_ppl", + lambda meters: math.pow( + 10, + -meters["word_lm_sum"].sum + / ( + meters["_num_words"].sum + meters["nsentences"].sum + ), # account for </s> + ), + ) + metrics.log_derived( + "weighted_word_lm_ppl", + lambda meters: math.pow( + 10, + -meters["word_lm_sum"].sum + / ( + 
meters["_num_words"].sum + meters["nsentences"].sum + ), # account for </s> + ) + / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power, + ) + + if self.cfg.word_kenlm_path is not None: + metrics.log_derived( + "kaldi_score", + lambda meters: meters["kaldi_score_sum"].sum + / meters["nsentences"].sum, + ) + + def build_model(self, cfg: FairseqDataclass): + model = super().build_model(cfg) + + return model diff --git a/examples/wav2vec/unsupervised/w2vu_generate.py b/examples/wav2vec/unsupervised/w2vu_generate.py new file mode 100644 index 0000000000..a1bc0ec706 --- /dev/null +++ b/examples/wav2vec/unsupervised/w2vu_generate.py @@ -0,0 +1,706 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run inference for pre-processed data with a trained model. +""" + +import ast +from collections import namedtuple +from dataclasses import dataclass, field +from enum import Enum, auto +import hydra +from hydra.core.config_store import ConfigStore +import logging +import math +import os +from omegaconf import OmegaConf +from typing import Optional +import sys + +import editdistance +import torch + +from hydra.core.hydra_config import HydraConfig + +from fairseq import checkpoint_utils, progress_bar, tasks, utils +from fairseq.data.data_utils import post_process +from fairseq.dataclass.configs import FairseqDataclass, FairseqConfig +from fairseq.logging.meters import StopwatchMeter +from omegaconf import open_dict + +from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoderConfig + +logging.root.setLevel(logging.INFO) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger(__name__) + + +class DecoderType(Enum): + VITERBI = auto() + KENLM = auto() + FAIRSEQ = auto() + KALDI = auto() + + +@dataclass +class UnsupGenerateConfig(FairseqDataclass): + fairseq: 
FairseqConfig = FairseqConfig() + lm_weight: float = field( + default=2.0, + metadata={"help": "language model weight"}, + ) + w2l_decoder: DecoderType = field( + default=DecoderType.VITERBI, + metadata={"help": "type of decoder to use"}, + ) + kaldi_decoder_config: Optional[KaldiDecoderConfig] = None + lexicon: Optional[str] = field( + default=None, + metadata={ + "help": "path to lexicon. This is also used to 'phonemize' for unsupervised param tuning" + }, + ) + lm_model: Optional[str] = field( + default=None, + metadata={"help": "path to language model (kenlm or fairseq)"}, + ) + unit_lm: bool = field( + default=False, + metadata={"help": "whether to use unit lm"}, + ) + beam_threshold: float = field( + default=50.0, + metadata={"help": "beam score threshold"}, + ) + beam_size_token: float = field( + default=100.0, + metadata={"help": "max tokens per beam"}, + ) + beam: int = field( + default=5, + metadata={"help": "decoder beam size"}, + ) + nbest: int = field( + default=1, + metadata={"help": "number of results to return"}, + ) + word_score: float = field( + default=1.0, + metadata={"help": "word score to add at end of word"}, + ) + unk_weight: float = field( + default=-math.inf, + metadata={"help": "unknown token weight"}, + ) + sil_weight: float = field( + default=0.0, + metadata={"help": "silence token weight"}, + ) + targets: Optional[str] = field( + default=None, + metadata={"help": "extension of ground truth labels to compute UER"}, + ) + results_path: Optional[str] = field( + default=None, + metadata={"help": "where to store results"}, + ) + post_process: Optional[str] = field( + default=None, + metadata={"help": "how to post process results"}, + ) + vocab_usage_power: float = field( + default=2, + metadata={"help": "for unsupervised param tuning"}, + ) + + viterbi_transcript: Optional[str] = field( + default=None, + metadata={"help": "for unsupervised param tuning"}, + ) + min_lm_ppl: float = field( + default=0, + metadata={"help": "for unsupervised 
param tuning"}, + ) + min_vt_uer: float = field( + default=0, + metadata={"help": "for unsupervised param tuning"}, + ) + + blank_weight: float = field( + default=0, + metadata={"help": "value to add or set for blank emission"}, + ) + blank_mode: str = field( + default="set", + metadata={ + "help": "can be add or set, how to modify blank emission with blank weight" + }, + ) + sil_is_blank: bool = field( + default=False, + metadata={"help": "if true, <SIL> token is same as blank token"}, + ) + + unsupervised_tuning: bool = field( + default=False, + metadata={ + "help": "if true, returns a score based on unsupervised param selection metric instead of UER" + }, + ) + is_ax: bool = field( + default=False, + metadata={ + "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume" + }, + ) + + +def get_dataset_itr(cfg, task): + return task.get_batch_iterator( + dataset=task.dataset(cfg.fairseq.dataset.gen_subset), + max_tokens=cfg.fairseq.dataset.max_tokens, + max_sentences=cfg.fairseq.dataset.batch_size, + max_positions=(sys.maxsize, sys.maxsize), + ignore_invalid_inputs=cfg.fairseq.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.fairseq.dataset.required_batch_size_multiple, + num_shards=cfg.fairseq.dataset.num_shards, + shard_id=cfg.fairseq.dataset.shard_id, + num_workers=cfg.fairseq.dataset.num_workers, + data_buffer_size=cfg.fairseq.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + + +def process_predictions( + cfg: UnsupGenerateConfig, + hypos, + tgt_dict, + target_tokens, + res_files, +): + retval = [] + word_preds = [] + transcriptions = [] + dec_scores = [] + + for i, hypo in enumerate(hypos[: min(len(hypos), cfg.nbest)]): + if torch.is_tensor(hypo["tokens"]): + tokens = hypo["tokens"].int().cpu() + tokens = tokens[tokens >= tgt_dict.nspecial] + hyp_pieces = tgt_dict.string(tokens) + else: + hyp_pieces = " ".join(hypo["tokens"]) + + if "words" in hypo and len(hypo["words"]) > 0: + 
hyp_words = " ".join(hypo["words"]) + else: + hyp_words = post_process(hyp_pieces, cfg.post_process) + + to_write = {} + if res_files is not None: + to_write[res_files["hypo.units"]] = hyp_pieces + to_write[res_files["hypo.words"]] = hyp_words + + tgt_words = "" + if target_tokens is not None: + if isinstance(target_tokens, str): + tgt_pieces = tgt_words = target_tokens + else: + tgt_pieces = tgt_dict.string(target_tokens) + tgt_words = post_process(tgt_pieces, cfg.post_process) + + if res_files is not None: + to_write[res_files["ref.units"]] = tgt_pieces + to_write[res_files["ref.words"]] = tgt_words + + if not cfg.fairseq.common_eval.quiet: + logger.info(f"HYPO {i}:" + hyp_words) + if tgt_words: + logger.info("TARGET:" + tgt_words) + + if "am_score" in hypo and "lm_score" in hypo: + logger.info( + f"DECODER AM SCORE: {hypo['am_score']}, DECODER LM SCORE: {hypo['lm_score']}, DECODER SCORE: {hypo['score']}" + ) + elif "score" in hypo: + logger.info(f"DECODER SCORE: {hypo['score']}") + + logger.info("___________________") + + hyp_words_arr = hyp_words.split() + tgt_words_arr = tgt_words.split() + + retval.append( + ( + editdistance.eval(hyp_words_arr, tgt_words_arr), + len(hyp_words_arr), + len(tgt_words_arr), + hyp_pieces, + hyp_words, + ) + ) + word_preds.append(hyp_words_arr) + transcriptions.append(to_write) + dec_scores.append(-hypo.get("score", 0)) # negate cuz kaldi returns NLL + + if len(retval) > 1: + best = None + for r, t in zip(retval, transcriptions): + if best is None or r[0] < best[0][0]: + best = r, t + for dest, tran in best[1].items(): + print(tran, file=dest) + dest.flush() + return best[0] + + assert len(transcriptions) == 1 + for dest, tran in transcriptions[0].items(): + print(tran, file=dest) + + return retval[0] + + +def prepare_result_files(cfg: UnsupGenerateConfig): + def get_res_file(file_prefix): + if cfg.fairseq.dataset.num_shards > 1: + file_prefix = f"{cfg.fairseq.dataset.shard_id}_{file_prefix}" + path = os.path.join( + 
cfg.results_path, + "{}{}.txt".format( + cfg.fairseq.dataset.gen_subset, + file_prefix, + ), + ) + return open(path, "w", buffering=1) + + if not cfg.results_path: + return None + + return { + "hypo.words": get_res_file(""), + "hypo.units": get_res_file("_units"), + "ref.words": get_res_file("_ref"), + "ref.units": get_res_file("_ref_units"), + "hypo.nbest.words": get_res_file("_nbest_words"), + } + + +def optimize_models(cfg: UnsupGenerateConfig, use_cuda, models): + """Optimize ensemble for generation""" + for model in models: + model.eval() + if cfg.fairseq.common.fp16: + model.half() + if use_cuda: + model.cuda() + + +GenResult = namedtuple( + "GenResult", + [ + "count", + "errs_t", + "gen_timer", + "lengths_hyp_unit_t", + "lengths_hyp_t", + "lengths_t", + "lm_score_t", + "num_feats", + "num_sentences", + "num_symbols", + "vt_err_t", + "vt_length_t", + ], +) + + +def generate(cfg: UnsupGenerateConfig, models, saved_cfg, use_cuda): + task = tasks.setup_task(cfg.fairseq.task) + saved_cfg.task.labels = cfg.fairseq.task.labels + task.load_dataset(cfg.fairseq.dataset.gen_subset, task_cfg=saved_cfg.task) + # Set dictionary + tgt_dict = task.target_dictionary + logger.info( + "| {} {} {} examples".format( + cfg.fairseq.task.data, + cfg.fairseq.dataset.gen_subset, + len(task.dataset(cfg.fairseq.dataset.gen_subset)), + ) + ) + # Load dataset (possibly sharded) + itr = get_dataset_itr(cfg, task) + # Initialize generator + gen_timer = StopwatchMeter() + + def build_generator(cfg: UnsupGenerateConfig): + w2l_decoder = cfg.w2l_decoder + if w2l_decoder == DecoderType.VITERBI: + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(cfg, task.target_dictionary) + elif w2l_decoder == DecoderType.KENLM: + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + return W2lKenLMDecoder(cfg, task.target_dictionary) + elif w2l_decoder == DecoderType.FAIRSEQ: + from examples.speech_recognition.w2l_decoder import 
W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(cfg, task.target_dictionary) + elif w2l_decoder == DecoderType.KALDI: + from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder + + assert cfg.kaldi_decoder_config is not None + + return KaldiDecoder( + cfg.kaldi_decoder_config, + cfg.beam, + ) + else: + raise NotImplementedError( + "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment but found " + + str(w2l_decoder) + ) + + generator = build_generator(cfg) + + kenlm = None + fairseq_lm = None + if cfg.lm_model is not None: + import kenlm + + kenlm = kenlm.Model(cfg.lm_model) + + num_sentences = 0 + if cfg.results_path is not None and not os.path.exists(cfg.results_path): + os.makedirs(cfg.results_path) + + res_files = prepare_result_files(cfg) + errs_t = 0 + lengths_hyp_t = 0 + lengths_hyp_unit_t = 0 + lengths_t = 0 + count = 0 + num_feats = 0 + all_hyp_pieces = [] + all_hyp_words = [] + + num_symbols = ( + len([s for s in tgt_dict.symbols if not s.startswith("madeup")]) + - tgt_dict.nspecial + ) + targets = None + if cfg.targets is not None: + tgt_path = os.path.join( + cfg.fairseq.task.data, cfg.fairseq.dataset.gen_subset + "." 
+ cfg.targets + ) + if os.path.exists(tgt_path): + with open(tgt_path, "r") as f: + targets = f.read().splitlines() + viterbi_transcript = None + if cfg.viterbi_transcript is not None and len(cfg.viterbi_transcript) > 0: + logger.info(f"loading viterbi transcript from {cfg.viterbi_transcript}") + with open(cfg.viterbi_transcript, "r") as vf: + viterbi_transcript = vf.readlines() + viterbi_transcript = [v.rstrip().split() for v in viterbi_transcript] + + gen_timer.start() + + start = 0 + end = len(itr) + + hypo_futures = None + if cfg.w2l_decoder == DecoderType.KALDI: + logger.info("Extracting features") + hypo_futures = [] + samples = [] + with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t: + for i, sample in enumerate(t): + if "net_input" not in sample or i < start or i >= end: + continue + if "padding_mask" not in sample["net_input"]: + sample["net_input"]["padding_mask"] = None + + hypos, num_feats = gen_hypos( + generator, models, num_feats, sample, task, use_cuda + ) + hypo_futures.append(hypos) + samples.append(sample) + if cfg.debug: + break + itr = list(zip(hypo_futures, samples)) + start = 0 + end = len(itr) + logger.info("Finished extracting features") + + with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t: + for i, sample in enumerate(t): + if i < start or i >= end: + continue + + if hypo_futures is not None: + hypos, sample = sample + hypos = [h.result() for h in hypos] + else: + if "net_input" not in sample: + continue + + hypos, num_feats = gen_hypos( + generator, models, num_feats, sample, task, use_cuda + ) + + for i, sample_id in enumerate(sample["id"].tolist()): + if targets is not None: + target_tokens = targets[sample_id] + elif "target" in sample or "target_label" in sample: + toks = ( + sample["target"][i, :] + if "target_label" not in sample + else sample["target_label"][i, :] + ) + + target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu() + else: + target_tokens = None + + # Process top predictions + 
( + errs, + length_hyp, + length, + hyp_pieces, + hyp_words, + ) = process_predictions( + cfg, + hypos[i], + tgt_dict, + target_tokens, + res_files, + ) + errs_t += errs + lengths_hyp_t += length_hyp + lengths_hyp_unit_t += ( + len(hyp_pieces) if len(hyp_pieces) > 0 else len(hyp_words) + ) + lengths_t += length + count += 1 + all_hyp_pieces.append(hyp_pieces) + all_hyp_words.append(hyp_words) + + num_sentences += ( + sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + ) + + lm_score_sum = 0 + if kenlm is not None: + + if cfg.unit_lm: + lm_score_sum = sum(kenlm.score(w) for w in all_hyp_pieces) + else: + lm_score_sum = sum(kenlm.score(w) for w in all_hyp_words) + elif fairseq_lm is not None: + lm_score_sum = sum(fairseq_lm.score([h.split() for h in all_hyp_words])[0]) + + vt_err_t = 0 + vt_length_t = 0 + if viterbi_transcript is not None: + unit_hyps = [] + if cfg.targets is not None and cfg.lexicon is not None: + lex = {} + with open(cfg.lexicon, "r") as lf: + for line in lf: + items = line.rstrip().split() + lex[items[0]] = items[1:] + for h in all_hyp_pieces: + hyp_ws = [] + for w in h.split(): + assert w in lex, w + hyp_ws.extend(lex[w]) + unit_hyps.append(hyp_ws) + + else: + unit_hyps.extend([h.split() for h in all_hyp_words]) + + vt_err_t = sum( + editdistance.eval(vt, h) for vt, h in zip(viterbi_transcript, unit_hyps) + ) + + vt_length_t = sum(len(h) for h in viterbi_transcript) + + if res_files is not None: + for r in res_files.values(): + r.close() + + gen_timer.stop(lengths_hyp_t) + + return GenResult( + count, + errs_t, + gen_timer, + lengths_hyp_unit_t, + lengths_hyp_t, + lengths_t, + lm_score_sum, + num_feats, + num_sentences, + num_symbols, + vt_err_t, + vt_length_t, + ) + + +def gen_hypos(generator, models, num_feats, sample, task, use_cuda): + sample = utils.move_to_cuda(sample) if use_cuda else sample + + if "features" in sample["net_input"]: + sample["net_input"]["dense_x_only"] = True + num_feats += ( + 
sample["net_input"]["features"].shape[0] + * sample["net_input"]["features"].shape[1] + ) + hypos = task.inference_step(generator, models, sample, None) + return hypos, num_feats + + +def main(cfg: UnsupGenerateConfig, model=None): + if ( + cfg.fairseq.dataset.max_tokens is None + and cfg.fairseq.dataset.batch_size is None + ): + cfg.fairseq.dataset.max_tokens = 1024000 + + use_cuda = torch.cuda.is_available() and not cfg.fairseq.common.cpu + + task = tasks.setup_task(cfg.fairseq.task) + + overrides = ast.literal_eval(cfg.fairseq.common_eval.model_overrides) + + if cfg.fairseq.task._name == "gan_audio_pretraining_feats": + overrides["model"] = { + "blank_weight": cfg.blank_weight, + "blank_mode": cfg.blank_mode, + "blank_is_sil": cfg.sil_is_blank, + "no_softmax": True, + "segmentation": { + "type": "NONE", + }, + } + else: + overrides["model"] = { + "blank_weight": cfg.blank_weight, + "blank_mode": cfg.blank_mode, + } + + if model is None: + # Load ensemble + logger.info("| loading model(s) from {}".format(cfg.fairseq.common_eval.path)) + models, saved_cfg = checkpoint_utils.load_model_ensemble( + cfg.fairseq.common_eval.path.split("\\"), + arg_overrides=overrides, + task=task, + suffix=cfg.fairseq.checkpoint.checkpoint_suffix, + strict=(cfg.fairseq.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.fairseq.checkpoint.checkpoint_shard_count, + ) + optimize_models(cfg, use_cuda, models) + else: + models = [model] + saved_cfg = cfg.fairseq + + with open_dict(saved_cfg.task): + saved_cfg.task.shuffle = False + saved_cfg.task.sort_by_length = False + + gen_result = generate(cfg, models, saved_cfg, use_cuda) + + wer = None + if gen_result.lengths_t > 0: + wer = gen_result.errs_t * 100.0 / gen_result.lengths_t + logger.info(f"WER: {wer}") + + lm_ppl = float("inf") + + if gen_result.lm_score_t != 0 and gen_result.lengths_hyp_t > 0: + hyp_len = gen_result.lengths_hyp_t + lm_ppl = math.pow( + 10, -gen_result.lm_score_t / (hyp_len + gen_result.num_sentences) + ) + 
logger.info(f"LM PPL: {lm_ppl}") + + logger.info( + "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}" + " sentences/s, {:.2f} tokens/s)".format( + gen_result.num_sentences, + gen_result.gen_timer.n, + gen_result.gen_timer.sum, + gen_result.num_sentences / gen_result.gen_timer.sum, + 1.0 / gen_result.gen_timer.avg, + ) + ) + + vt_diff = None + if gen_result.vt_length_t > 0: + vt_diff = gen_result.vt_err_t / gen_result.vt_length_t + vt_diff = max(cfg.min_vt_uer, vt_diff) + + lm_ppl = max(cfg.min_lm_ppl, lm_ppl) + + if not cfg.unsupervised_tuning: + weighted_score = wer + else: + weighted_score = math.log(lm_ppl) * (vt_diff or 1.0) + + res = ( + f"| Generate {cfg.fairseq.dataset.gen_subset} with beam={cfg.beam}, " + f"lm_weight={cfg.kaldi_decoder_config.acoustic_scale if cfg.kaldi_decoder_config else cfg.lm_weight}, " + f"word_score={cfg.word_score}, sil_weight={cfg.sil_weight}, blank_weight={cfg.blank_weight}, " + f"WER: {wer}, LM_PPL: {lm_ppl}, num feats: {gen_result.num_feats}, " + f"length: {gen_result.lengths_hyp_t}, UER to viterbi: {(vt_diff or 0) * 100}, score: {weighted_score}" + ) + + logger.info(res) + # print(res) + + return task, weighted_score + + +@hydra.main( + config_path=os.path.join("../../..", "fairseq", "config"), config_name="config" +) +def hydra_main(cfg): + with open_dict(cfg): + # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126) + cfg.job_logging_cfg = OmegaConf.to_container( + HydraConfig.get().job_logging, resolve=True + ) + + cfg = OmegaConf.create( + OmegaConf.to_container(cfg, resolve=False, enum_to_str=False) + ) + OmegaConf.set_struct(cfg, True) + logger.info(cfg) + _, score = main(cfg) + + if cfg.is_ax: + return score, None + return score + + +def cli_main(): + try: + from hydra._internal.utils import get_args + + cfg_name = get_args().config_name or "config" + except: + logger.warning("Failed to get config name from hydra args") + cfg_name = "config" + + cs = 
ConfigStore.instance() + cs.store(name=cfg_name, node=UnsupGenerateConfig) + hydra_main() + + +if __name__ == "__main__": + cli_main() diff --git a/examples/wav2vec/wav2vec_manifest.py b/examples/wav2vec/wav2vec_manifest.py index 5417084554..9b8aa180e8 100644 --- a/examples/wav2vec/wav2vec_manifest.py +++ b/examples/wav2vec/wav2vec_manifest.py @@ -54,11 +54,17 @@ def main(args): search_path = os.path.join(dir_path, "**/*." + args.ext) rand = random.Random(args.seed) - with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open( - os.path.join(args.dest, "valid.tsv"), "w" - ) as valid_f: + valid_f = ( + open(os.path.join(args.dest, "valid.tsv"), "w") + if args.valid_percent > 0 + else None + ) + + with open(os.path.join(args.dest, "train.tsv"), "w") as train_f: print(dir_path, file=train_f) - print(dir_path, file=valid_f) + + if valid_f is not None: + print(dir_path, file=valid_f) for fname in glob.iglob(search_path, recursive=True): file_path = os.path.realpath(fname) @@ -71,6 +77,8 @@ def main(args): print( "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest ) + if valid_f is not None: + valid_f.close() if __name__ == "__main__": diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 4cb5193bde..9ce3f7e39d 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -13,10 +13,12 @@ import torch import torch.nn.functional as F -from .. import FairseqDataset, BaseWrapperDataset +from .. 
import FairseqDataset from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes from fairseq.data.audio.audio_utils import ( - parse_path, read_from_stored_zip, is_sf_audio_data + parse_path, + read_from_stored_zip, + is_sf_audio_data, ) @@ -212,11 +214,15 @@ def ordered_indices(self): if self.shuffle: order = [np.random.permutation(len(self))] + order.append( + np.minimum( + np.array(self.sizes), + self.max_sample_size, + ) + ) + return np.lexsort(order)[::-1] else: - order = [np.arange(len(self))] - - order.append(self.sizes) - return np.lexsort(order)[::-1] + return np.arange(len(self)) def set_bucket_info(self, num_buckets): self.num_buckets = num_buckets diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 8c5e5a490d..70a4086cd0 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -373,6 +373,10 @@ def post_process(sentence: str, symbol: str): sentence = sentence.replace(" ", "").replace("_", " ").strip() elif symbol == "letter": sentence = sentence.replace(" ", "").replace("|", " ").strip() + elif symbol == "silence": + import re + sentence = sentence.replace("<SIL>", "") + sentence = re.sub(' +', ' ', sentence).strip() elif symbol == "_EOW": sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() elif symbol in {"subword_nmt", "@@ ", "@@"}: diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py index 6793ef54e6..2100b1fa0b 100644 --- a/fairseq/logging/meters.py +++ b/fairseq/logging/meters.py @@ -109,6 +109,38 @@ def smoothed_value(self) -> float: return val +class SumMeter(Meter): + """Computes and stores the sum""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.sum = 0 # sum from all updates + + def update(self, val): + if val is not None: + self.sum = type_as(self.sum, val) + val + + def state_dict(self): + return { + "sum": self.sum, + "round": self.round, + } + + def load_state_dict(self, 
state_dict): + self.sum = state_dict["sum"] + self.round = state_dict.get("round", None) + + @property + def smoothed_value(self) -> float: + val = self.sum + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + class TimeMeter(Meter): """Computes the average occurrence of some event per second""" diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 2bb1da086f..58c2fb64e1 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -12,10 +12,9 @@ """ import contextlib -import time import uuid -from collections import OrderedDict, defaultdict -from typing import Callable, Dict, List, Optional +from collections import defaultdict +from typing import Callable, List, Optional from .meters import * @@ -131,6 +130,25 @@ def log_scalar( agg.add_meter(key, AverageMeter(round=round), priority) agg[key].update(value, weight) +def log_scalar_sum( + key: str, + value: float, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, SumMeter(round=round), priority) + agg[key].update(value) + def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): """Log a scalar value derived from other meters. 
diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index 6af288b10e..06905455fd 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -18,7 +18,6 @@ ModuleProxyWrapper, TPUDistributedDataParallel, ) -from torch.distributed.algorithms.ddp_comm_hooks import register_ddp_comm_hook, DDPCommHookType logger = logging.getLogger(__name__) @@ -65,8 +64,19 @@ def DistributedFairseqModel(args, model, process_group, device): process_group=process_group, find_unused_parameters=args.find_unused_parameters, ) - if args.ddp_comm_hook == 'fp16': + if args.ddp_comm_hook == "fp16": logger.info("enable fp16 communication hook in DDP") + try: + from torch.distributed.algorithms.ddp_comm_hooks import ( + register_ddp_comm_hook, + DDPCommHookType, + ) + except: + logger.error( + "Could not import from torch.distributed.algorithms.ddp_comm_hooks; you may need to update your pytorch version" + ) + raise + register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model) # forward missing getattr and state_dict/load_state_dict to orig model wrapped_model = ModuleProxyWrapper(wrapped_model) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 6999dca2d9..6002d28438 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -165,6 +165,7 @@ class Wav2Vec2Config(FairseqDataclass): mask_channel_prob: float = field( default=0.0, metadata={"help": "probability of replacing a feature with 0"} ) + mask_channel_before: bool = False mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( default="static", metadata={"help": "how to choose mask length for channel masking"}, @@ -249,6 +250,7 @@ def __init__(self, cfg: Wav2Vec2Config): self.mask_min_space = cfg.mask_min_space self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_before = cfg.mask_channel_before self.mask_channel_selection = 
cfg.mask_channel_selection self.mask_channel_other = cfg.mask_channel_other self.mask_channel_length = cfg.mask_channel_length @@ -331,10 +333,33 @@ def build_model(cls, cfg: Wav2Vec2Config, task=None): return cls(cfg) def apply_mask( - self, x, padding_mask, - mask_indices=None, mask_channel_indices=None, + self, + x, + padding_mask, + mask_indices=None, + mask_channel_indices=None, ): B, T, C = x.shape + + if self.mask_channel_prob > 0 and self.mask_channel_before: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + if self.mask_prob > 0: if mask_indices is None: mask_indices = compute_mask_indices( @@ -353,7 +378,7 @@ def apply_mask( else: mask_indices = None - if self.mask_channel_prob > 0: + if self.mask_channel_prob > 0 and not self.mask_channel_before: if mask_channel_indices is None: mask_channel_indices = compute_mask_indices( (B, C), @@ -445,12 +470,12 @@ def compute_preds(self, x, y, negatives): logits = logits / self.logit_temp if is_xla_tensor(logits) or neg_is_pos.any(): - fillval = -float(2**30) - if not hasattr(self, '_inftensor'): + fillval = -float(2 ** 30) + if not hasattr(self, "_inftensor"): self._inftensor = ( torch.tensor(fillval).to(x.device) - if is_xla_tensor(logits) else - float("-inf") + if is_xla_tensor(logits) + else float("-inf") ) logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) @@ -467,13 +492,21 @@ def _conv_out_length(input_length, kernel_size, stride): conv_cfg_list = eval(self.cfg.conv_feature_layers) for i in range(len(conv_cfg_list)): - input_lengths = _conv_out_length(input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]) + input_lengths = 
_conv_out_length( + input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2] + ) return input_lengths.to(torch.long) def forward( - self, source, padding_mask=None, mask=True, features_only=False, - mask_indices=None, mask_channel_indices=None, + self, + source, + padding_mask=None, + mask=True, + features_only=False, + layer=None, + mask_indices=None, + mask_channel_indices=None, padding_count=None, ): @@ -491,7 +524,7 @@ def forward( features = self.layer_norm(features) unmasked_features = features.clone() - if padding_mask is not None: + if padding_mask is not None and padding_mask.any(): input_lengths = (1 - padding_mask.long()).sum(-1) # apply conv formula to get real output_lengths output_lengths = self._get_feat_extract_output_lengths(input_lengths) @@ -502,8 +535,15 @@ def forward( # these two operations makes sure that all values # before the output lengths indices are attended to - padding_mask[(torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)] = 1 + padding_mask[ + ( + torch.arange(padding_mask.shape[0], device=padding_mask.device), + output_lengths - 1, + ) + ] = 1 padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() + else: + padding_mask = None if self.post_extract_proj is not None: features = self.post_extract_proj(features) @@ -527,7 +567,8 @@ def forward( if mask: x, mask_indices = self.apply_mask( - features, padding_mask, + features, + padding_mask, mask_indices=mask_indices, mask_channel_indices=mask_channel_indices, ) @@ -544,10 +585,15 @@ def forward( y = unmasked_features mask_indices = None - x = self.encoder(x, padding_mask=padding_mask) + x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer) if features_only: - return {"x": x, "padding_mask": padding_mask} + return { + "x": x, + "padding_mask": padding_mask, + "features": unmasked_features, + "layer_results": layer_results, + } if self.quantizer: q = self.quantizer(y, produce_targets=False) @@ -560,17 +606,21 @@ 
def forward( y = self.project_q(y) if self.negatives_from_everywhere: - neg_cands = self.quantizer( - unmasked_features, produce_targets=False - )["x"] + neg_cands = self.quantizer(unmasked_features, produce_targets=False)[ + "x" + ] negs, _ = self.sample_negatives( - neg_cands, y.size(1), padding_count=padding_count, + neg_cands, + y.size(1), + padding_count=padding_count, ) negs = self.project_q(negs) else: negs, _ = self.sample_negatives( - y, y.size(1), padding_count=padding_count, + y, + y.size(1), + padding_count=padding_count, ) if self.codebook_negatives > 0: @@ -587,13 +637,16 @@ def forward( if self.negatives_from_everywhere: negs, _ = self.sample_negatives( - unmasked_features, y.size(1), + unmasked_features, + y.size(1), padding_count=padding_count, ) negs = self.project_q(negs) else: negs, _ = self.sample_negatives( - y, y.size(1), padding_count=padding_count, + y, + y.size(1), + padding_count=padding_count, ) if not is_xla_tensor(x): @@ -609,7 +662,9 @@ def forward( x = self.compute_preds(x, y, negs) result = { - "x": x, "padding_mask": padding_mask, "features_pen": features_pen, + "x": x, + "padding_mask": padding_mask, + "features_pen": features_pen, } if prob_ppl is not None: @@ -627,9 +682,11 @@ def quantize(self, x): x = self.layer_norm(x) return self.quantizer.forward_idx(x) - def extract_features(self, source, padding_mask, mask=False): - res = self.forward(source, padding_mask, mask=mask, features_only=True) - return res["x"], res["padding_mask"] + def extract_features(self, source, padding_mask, mask=False, layer=None): + res = self.forward( + source, padding_mask, mask=mask, features_only=True, layer=layer + ) + return res def get_logits(self, net_output): logits = net_output["x"] @@ -787,15 +844,15 @@ def __init__(self, args): self.apply(init_bert_params) - def forward(self, x, padding_mask=None): - x = self.extract_features(x, padding_mask) + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, 
padding_mask, layer) - if self.layer_norm_first: + if self.layer_norm_first and layer is None: x = self.layer_norm(x) - return x + return x, layer_results - def extract_features(self, x, padding_mask=None): + def extract_features(self, x, padding_mask=None, tgt_layer=None): if padding_mask is not None: x = index_put(x, padding_mask, 0) @@ -813,16 +870,24 @@ def extract_features(self, x, padding_mask=None): x = x.transpose(0, 1) layer_results = [] + r = None for i, layer in enumerate(self.layers): dropout_probability = np.random.random() if not self.training or (dropout_probability > self.layerdrop): x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) - layer_results.append(x) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r # T x B x C -> B x T x C x = x.transpose(0, 1) - return x + return x, layer_results def max_positions(self): """Maximum output length supported by the encoder.""" @@ -901,7 +966,6 @@ def forward( key=x, value=x, key_padding_mask=self_attn_padding_mask, - need_weights=False, attn_mask=self_attn_mask, ) x = self.dropout1(x) @@ -920,7 +984,6 @@ def forward( key=x, value=x, key_padding_mask=self_attn_padding_mask, - need_weights=need_weights, ) x = self.dropout1(x) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index e8a1d03eb2..abae9d1ab3 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from dataclasses import dataclass, field from omegaconf import MISSING, II, open_dict -from typing import Optional, Any +from typing import Any, Optional from fairseq import checkpoint_utils, tasks, utils from fairseq.dataclass import FairseqDataclass @@ -27,7 +27,12 @@ register_model, ) from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES -from fairseq.modules import LayerNorm, PositionalEmbedding, 
TransformerDecoderLayer +from fairseq.modules import ( + LayerNorm, + PositionalEmbedding, + TransformerDecoderLayer, + SamePad, +) @dataclass @@ -119,6 +124,7 @@ class Wav2Vec2AsrConfig(FairseqDataclass): layerdrop: float = field( default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"} ) + mask_channel_before: bool = False normalize: bool = II("task.normalize") data: str = II("task.data") # this holds the loaded wav2vec args @@ -127,6 +133,8 @@ class Wav2Vec2AsrConfig(FairseqDataclass): @dataclass class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): + blank_weight: float = 0 + blank_mode: str = "add" mask_min_space: Optional[int] = field( default=1, metadata={"help": "min space between spans (if no overlap is enabled)"}, @@ -156,6 +164,8 @@ def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel): super().__init__() self.cfg = cfg self.w2v_encoder = w2v_encoder + self.blank_weight = cfg.blank_weight + self.blank_mode = cfg.blank_mode def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) @@ -164,28 +174,38 @@ def upgrade_state_dict_named(self, state_dict, name): @classmethod def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask): """Build a new model instance.""" - w2v_encoder = Wav2VecEncoder(cfg, task.target_dictionary) + w2v_encoder = Wav2VecEncoder(cfg, len(task.target_dictionary)) return cls(cfg, w2v_encoder) + def get_logits(self, net_output, normalize=False): + logits = net_output["encoder_out"] + if self.blank_weight != 0: + if self.blank_mode == "add": + logits[..., 0] += self.blank_weight + elif self.blank_mode == "set": + logits[..., 0] = self.blank_weight + else: + raise Exception(f"invalid blank mode {self.blank_mode}") + + if net_output["padding_mask"] is not None and net_output["padding_mask"].any(): + logits[net_output["padding_mask"].T][..., 0] = float("inf") + logits[net_output["padding_mask"].T][..., 1:] = float("-inf") + + if normalize: + logits = 
utils.log_softmax(logits.float(), dim=-1) + + return logits + def get_normalized_probs(self, net_output, log_probs): """Get normalized probabilities (or log probs) from a net's output.""" - logits = net_output["encoder_out"] + logits = self.get_logits(net_output) + if log_probs: return utils.log_softmax(logits.float(), dim=-1) else: return utils.softmax(logits.float(), dim=-1) - def get_logits(self, net_output): - logits = net_output["encoder_out"] - padding = net_output["padding_mask"] - if padding is not None and padding.any(): - padding = padding.T - logits[padding][...,0] = 0 - logits[padding][...,1:] = float('-inf') - - return logits - def forward(self, **kwargs): x = self.w2v_encoder(**kwargs) return x @@ -237,7 +257,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): max_target_positions: int = field( default=2048, metadata={"help": "max target positions"} ) - share_decoder_input_output_embed: bool = field( + share_decoder_input_output_embed: bool = field( default=False, metadata={"help": "share decoder input and output embeddings"} ) autoregressive: bool = II("task.autoregressive") @@ -252,7 +272,9 @@ def __init__(self, encoder, decoder): def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask): """Build a new model instance.""" - assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models" + assert ( + cfg.autoregressive + ), "Please set task.autoregressive=true for seq2seq asr models" src_dict, tgt_dict = task.source_dictionary, task.target_dictionary @@ -288,7 +310,7 @@ def upgrade_state_dict_named(self, state_dict, name): class Wav2VecEncoder(FairseqEncoder): - def __init__(self, cfg: Wav2Vec2AsrConfig, tgt_dict=None): + def __init__(self, cfg: Wav2Vec2AsrConfig, output_size=None): self.apply_mask = cfg.apply_mask arg_overrides = { @@ -303,6 +325,7 @@ def __init__(self, cfg: Wav2Vec2AsrConfig, tgt_dict=None): "no_mask_overlap": cfg.no_mask_overlap, "mask_channel_length": cfg.mask_channel_length, 
"mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_before": cfg.mask_channel_before, "mask_channel_selection": cfg.mask_channel_selection, "mask_channel_other": cfg.mask_channel_other, "no_mask_channel_overlap": cfg.no_mask_channel_overlap, @@ -346,12 +369,16 @@ def __init__(self, cfg: Wav2Vec2AsrConfig, tgt_dict=None): self.freeze_finetune_updates = cfg.freeze_finetune_updates self.num_updates = 0 - if tgt_dict is not None: - self.proj = Linear(d, len(tgt_dict)) + targ_d = None + self.proj = None + + if output_size is not None: + targ_d = output_size elif getattr(cfg, "decoder_embed_dim", d) != d: - self.proj = Linear(d, cfg.decoder_embed_dim) - else: - self.proj = None + targ_d = cfg.decoder_embed_dim + + if targ_d is not None: + self.proj = Linear(d, targ_d) def set_num_updates(self, num_updates): """Set the number of parameters updates.""" @@ -359,7 +386,6 @@ def set_num_updates(self, num_updates): self.num_updates = num_updates def forward(self, source, padding_mask, tbc=True, **kwargs): - w2v_args = { "source": source, "padding_mask": padding_mask, @@ -369,10 +395,13 @@ def forward(self, source, padding_mask, tbc=True, **kwargs): ft = self.freeze_finetune_updates <= self.num_updates with torch.no_grad() if not ft else contextlib.ExitStack(): - x, padding_mask = self.w2v_model.extract_features(**w2v_args) + res = self.w2v_model.extract_features(**w2v_args) + + x = res["x"] + padding_mask = res["padding_mask"] if tbc: - # B x T x C -> T x B x C + # BTC -> TBC x = x.transpose(0, 1) x = self.final_dropout(x) @@ -382,8 +411,11 @@ def forward(self, source, padding_mask, tbc=True, **kwargs): return { "encoder_out": x, # T x B x C - "encoder_padding_mask": padding_mask.transpose(0, 1), # T x B + "encoder_padding_mask": padding_mask.transpose(0, 1) + if padding_mask is not None + else None, # T x B "padding_mask": padding_mask, + "layer_results": res["layer_results"], } def reorder_encoder_out(self, encoder_out, new_order): @@ -562,9 +594,7 @@ def 
extract_features( x, attn, _ = layer( x, encoder_out["encoder_out"] if encoder_out is not None else None, - encoder_out["padding_mask"] - if encoder_out is not None - else None, + encoder_out["padding_mask"] if encoder_out is not None else None, incremental_state, self_attn_mask=self.buffered_future_mask(x) if incremental_state is None diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index f73804718a..cfe948a194 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -7,7 +7,7 @@ import math from collections.abc import Collection from dataclasses import dataclass, field -from typing import List +from typing import Any, List import torch import torch.distributed as dist @@ -23,8 +23,8 @@ @dataclass class FairseqAdamConfig(FairseqDataclass): - adam_betas: str = field( - default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"} + adam_betas: Any = field( + default=(0.9, 0.999), metadata={"help": "betas for Adam optimizer"} ) adam_eps: float = field( default=1e-8, metadata={"help": "epsilon for Adam optimizer"} @@ -47,7 +47,7 @@ class FairseqAdam(FairseqOptimizer): analogous to torch.optim.AdamW from PyTorch. 
""" - def __init__(self, cfg: DictConfig, params): + def __init__(self, cfg: FairseqAdamConfig, params): super().__init__(cfg) fused_adam_cls = get_fused_adam_class() use_fused_adam = ( @@ -77,7 +77,7 @@ def optimizer_config(self): "lr": self.cfg.lr[0] if isinstance(self.cfg.lr, Collection) else self.cfg.lr, - "betas": eval(self.cfg.adam_betas), + "betas": eval(self.cfg.adam_betas) if isinstance(self.cfg.adam_betas, str) else self.cfg.adam_betas, "eps": self.cfg.adam_eps, "weight_decay": self.cfg.weight_decay, } diff --git a/fairseq/options.py b/fairseq/options.py index 7558264fce..2d9f8381a7 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -286,6 +286,8 @@ def add_preprocess_args(parser): help="Pad dictionary size to be multiple of N") group.add_argument("--workers", metavar="N", default=1, type=int, help="number of parallel workers") + group.add_argument("--dict-only", action='store_true', + help="if true, only builds a dictionary and then exits") # fmt: on return parser diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index ddef3d58d2..8a3858563e 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -5,6 +5,7 @@ import math from typing import Dict, List, Optional +import sys import torch import torch.nn as nn @@ -214,8 +215,15 @@ def _generate( if net_input["padding_mask"] is not None else torch.tensor(src_tokens.size(-1)).to(src_tokens) ) + elif "features" in net_input: + src_tokens = net_input["features"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) else: - raise Exception("expected src_tokens or source in net input") + raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys())) # bsz: total number of sentences in beam # Note that src_tokens may have more than 2 dimensions (i.e. 
audio features) @@ -750,7 +758,7 @@ def has_incremental_states(self): return self.has_incremental def max_decoder_positions(self): - return min([m.max_decoder_positions() for m in self.models]) + return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize]) @torch.jit.export def forward_encoder(self, net_input: Dict[str, Tensor]): @@ -780,7 +788,10 @@ def forward_decoder( incremental_state=incremental_states[i], ) else: - decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + if hasattr(model, "decoder"): + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + else: + decoder_out = model.forward(tokens) attn: Optional[Tensor] = None decoder_len = len(decoder_out) @@ -800,7 +811,6 @@ def forward_decoder( decoder_out[0][:, -1:, :].div_(temperature), None if decoder_len <= 1 else decoder_out[1], ) - probs = model.get_normalized_probs( decoder_out_tuple, log_probs=True, sample=None ) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 071331a10a..e0b001b667 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -137,7 +137,7 @@ class AudioPretrainingConfig(FairseqDataclass): @register_task("audio_pretraining", dataclass=AudioPretrainingConfig) class AudioPretrainingTask(FairseqTask): - """""" + """ """ cfg: AudioPretrainingConfig @@ -199,7 +199,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if not hasattr(task_cfg, "autoregressive"): task_cfg.autoregressive = not task_cfg.criterion == "ctc" - if getattr(task_cfg, 'binarized_dataset', False): + if getattr(task_cfg, "binarized_dataset", False): self.datasets[split] = BinarizedAudioDataset( data_path, split=split, @@ -236,13 +236,9 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") - skipped_indices = 
getattr(self.datasets[split], 'skipped_indices', set()) + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) with open(label_path, "r") as f: - labels = [ - line - for i, line in enumerate(f) - if i not in skipped_indices - ] + labels = [line for i, line in enumerate(f) if i not in skipped_indices] assert len(labels) == len(self.datasets[split]), ( f"labels length ({len(labels)}) and dataset length " @@ -360,7 +356,7 @@ def reduce_metrics(self, logging_outputs, criterion): metrics.log_scalar("_num_chars", num_chars) metrics.log_scalar("_num_word_errors", num_word_errors) metrics.log_scalar("_num_words", num_words) - if num_words > 0: + if num_chars > 0: metrics.log_derived( "uer", lambda meters: meters["_num_char_errors"].sum @@ -369,6 +365,7 @@ def reduce_metrics(self, logging_outputs, criterion): if meters["_num_chars"].sum > 0 else float("nan"), ) + if num_words > 0: metrics.log_derived( "wer", lambda meters: meters["_num_word_errors"].sum diff --git a/fairseq/utils.py b/fairseq/utils.py index 03826d18d0..d0ce16ae6b 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -110,7 +110,6 @@ def _move_to_cuda(tensor): def move_to_cpu(sample): - def _move_to_cpu(tensor): # PyTorch has poor support for half tensors (float16) on CPU. # Move any such tensors to float32. @@ -124,6 +123,7 @@ def _move_to_cpu(tensor): def move_to_tpu(sample): import torch_xla.core.xla_model as xm + device = xm.xla_device() def _move_to_tpu(tensor): @@ -302,7 +302,7 @@ def convert_padding_direction( def item(tensor): # tpu-comment: making this a no-op for xla devices. 
- if torch.is_tensor(tensor) and tensor.device.type == 'xla': + if torch.is_tensor(tensor) and tensor.device.type == "xla": return tensor.detach() if hasattr(tensor, "item"): return tensor.item() @@ -341,11 +341,16 @@ def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: def grad_exists(p): return p is not None and getattr(p, "grad", None) is not None + if isinstance(params, torch.Tensor): params = [params] params = list(params) - grads = [p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, 'expert')] - expert_grads = [p.grad.detach() for p in params if grad_exists(p) and hasattr(p, 'expert')] + grads = [ + p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, "expert") + ] + expert_grads = [ + p.grad.detach() for p in params if grad_exists(p) and hasattr(p, "expert") + ] if len(grads) == 0: if len(params) > 0: @@ -454,7 +459,9 @@ def import_user_module(args): module_path = getattr(args, "user_dir", None) if module_path is not None: module_path = os.path.abspath(args.user_dir) - if not os.path.exists(module_path) and not os.path.isfile(os.path.dirname(module_path)): + if not os.path.exists(module_path) and not os.path.isfile( + os.path.dirname(module_path) + ): fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) if os.path.exists(fairseq_rel_path): module_path = fairseq_rel_path @@ -515,7 +522,7 @@ def deprecation_warning(message, stacklevel=3): def get_activation_fn(activation: str) -> Callable: - """ Returns the activation function corresponding to `activation` """ + """Returns the activation function corresponding to `activation`""" from fairseq.modules import gelu, gelu_accurate if activation == "relu": @@ -653,18 +660,13 @@ def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): - tgt_valid = ( - ((tgt_sent != pad)).nonzero(as_tuple=False) 
- ) - src_valid = ( - ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1) - ) + tgt_valid = ((tgt_sent != pad)).nonzero(as_tuple=False) + src_valid = ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1) alignment = [] if len(tgt_valid) != 0 and len(src_valid) != 0: attn_valid = attn[tgt_valid, src_valid] alignment = [ - ["{:.6f}".format(p) for p in src_probs.tolist()] - for src_probs in attn_valid + ["{:.6f}".format(p) for p in src_probs.tolist()] for src_probs in attn_valid ] return alignment @@ -699,7 +701,7 @@ def tpu_data_loader(itr): def is_xla_tensor(tensor): - return torch.is_tensor(tensor) and tensor.device.type == 'xla' + return torch.is_tensor(tensor) and tensor.device.type == "xla" def index_put(tensor, indices, value): @@ -716,6 +718,7 @@ def index_put(tensor, indices, value): def xla_device_to_cpu(dat): import torch_xla.core.xla_model as xm + return xm._maybe_convert_to_cpu(dat) @@ -778,3 +781,18 @@ def eval_bool(x, default=False): return bool(eval(x)) except TypeError: return default + + +def reset_logging(): + root = logging.getLogger() + for handler in root.handlers: + root.removeHandler(handler) + root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + root.addHandler(handler) diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 4501cac67e..ab6e77029e 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -42,7 +42,7 @@ def eval_lm( output_word_probs: bool = False, output_word_stats: bool = False, target_dictionary: Optional[fairseq.data.Dictionary] = None, - softmax_batch: int = False, + softmax_batch: int = 0, remove_bos_token: bool = False, device: Optional[torch.device] = None, ): diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index 180bd40717..9de01084ba 100644 --- 
a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -6,12 +6,12 @@ import logging import os -import sys from fairseq.dataclass.initialize import add_defaults, hydra_init from fairseq_cli.train import main as pre_main from fairseq import distributed_utils, metrics from fairseq.dataclass.configs import FairseqConfig +from fairseq.utils import reset_logging import hydra from hydra.core.hydra_config import HydraConfig @@ -63,21 +63,6 @@ def hydra_main(cfg: FairseqConfig) -> float: return best_val -def reset_logging(): - root = logging.getLogger() - for handler in root.handlers: - root.removeHandler(handler) - root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter( - logging.Formatter( - fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - ) - root.addHandler(handler) - - def cli_main(): try: from hydra._internal.utils import get_args diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py index fa77da8dba..b788900d30 100644 --- a/fairseq_cli/preprocess.py +++ b/fairseq_cli/preprocess.py @@ -117,6 +117,9 @@ def build_dictionary(filenames, src=False, tgt=False): if target and tgt_dict is not None: tgt_dict.save(dict_path(args.target_lang)) + if args.dict_only: + return + def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers): logger.info("[{}] Dictionary: {} types".format(lang, len(vocab))) n_seq_tok = [0, 0] diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index cb49915827..a1b7cb58e2 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -162,6 +162,7 @@ def main(cfg: FairseqConfig) -> None: max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() + train_meter = meters.StopwatchMeter() train_meter.start() while epoch_itr.next_epoch_idx <= max_epoch: @@ -381,7 +382,7 @@ def validate_and_save( and num_updates > 0 and num_updates % cfg.dataset.validate_interval_updates == 0 ) - 
) and not cfg.dataset.disable_validation + ) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates # Validate valid_losses = [None] @@ -460,6 +461,10 @@ def validate( # log validation stats stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) + + if hasattr(task, "post_validate"): + task.post_validate(trainer.get_model(), stats, agg) + progress.print(stats, tag=subset, step=trainer.get_num_updates()) valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 90d7e4c6a9..f0d983ee6b 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -66,6 +66,7 @@ def main(cfg: DictConfig, override_args=None): # Move models to GPU for model in models: + model.eval() if use_fp16: model.half() if use_cuda: From a8fe9434adeb6caea65b34ca968f4e5cd1d75be6 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Fri, 21 May 2021 12:11:57 -0700 Subject: [PATCH 584/707] fix readme link (#1890) Summary: fix link to self training readme Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1890 Reviewed By: arbabu123, ArmenAg Differential Revision: D28609910 Pulled By: alexeib fbshipit-source-id: 4cbbb75cfec876938d80c8d7bc0df4ba93d4f54d --- examples/wav2vec/unsupervised/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wav2vec/unsupervised/README.md b/examples/wav2vec/unsupervised/README.md index 2277e65ffb..e9ec59f06d 100644 --- a/examples/wav2vec/unsupervised/README.md +++ b/examples/wav2vec/unsupervised/README.md @@ -70,6 +70,6 @@ fairseq.dataset.gen_subset=valid results_path=/where/to/save/transcriptions After the GAN training provides a first unsupervised model, we can then progressively refine the quality of transcriptions using several iterations of semi-supervised learning. 
We perform two iterations: first, pseudo-label the training data with the unsupervised GAN model and train an HMM on the pseudo-labels. Second, we relabel the training data with the HMM and then fine-tune the original wav2vec 2.0 model using the HMM pseudo-labels with a CTC loss. Note that HMM models use phonemes as output, while wav2vec 2.0 use letter. Both are decoded using WFST decoders into words.

-Please see [this README](http://github.com/pytorch/fairseq/tree/master/examples/wav2vec/unsupervised/kaldi_st) for more instructions on how to do iterative self-training + Kaldi LM-decoding.
+Please see [this README](kaldi_self_train/README.md) for more instructions on how to do iterative self-training + Kaldi LM-decoding.

 *** Note: these instructions are a work in progress and will be updated over the next few days
From 78e75fa3edf4a4ce02f9aa59ded01d814036dc81 Mon Sep 17 00:00:00 2001
From: Weiyi Zheng <wyz@fb.com>
Date: Fri, 21 May 2021 16:17:21 -0700
Subject: [PATCH 585/707] attempt to make non-sharded FSDP checkpoint behave like regular checkpoint

Summary: overall just wondering if feature is desirable. if it is, it makes the next diff, which supports loading sharded checkpoint into a consolidated state dict, cleaner.

a couple of advantages:
1. allows resuming from other DDP trainers.
2. allows resuming into other DDP trainers, or FSDP of a different configuration.
3. non-sharded FSDP checkpoint can be loaded with regular load_model_ensemble_and_task()

For old training workflow that's not using `--use-sharded-state`, please rename the checkpoint to remove the "-shard0" for resuming training.
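The rename step mentioned above could be scripted roughly as follows. This is only a sketch under stated assumptions: the helper name `strip_shard0_suffix` is made up for illustration, and the `.pt` extension and the flat checkpoint directory layout are assumptions, not something this commit specifies.

```python
from pathlib import Path

# Assumed file-name tail of an old rank-0 FSDP checkpoint; the "-shard0" part
# comes from the commit message, the ".pt" extension is an assumption.
SHARD0_TAIL = "-shard0.pt"


def strip_shard0_suffix(checkpoint_dir):
    """Rename e.g. checkpoint_last-shard0.pt to checkpoint_last.pt.

    Hypothetical helper. Returns the new file names, and skips any
    checkpoint whose target name already exists so nothing is overwritten.
    """
    renamed = []
    # sorted() materializes the listing before any file is renamed.
    for path in sorted(Path(checkpoint_dir).glob("*" + SHARD0_TAIL)):
        target = path.with_name(path.name[: -len(SHARD0_TAIL)] + ".pt")
        if target.exists():
            continue
        path.rename(target)
        renamed.append(target.name)
    return renamed
```

Files for other ranks (for example a name ending in `-shard1.pt`) are left untouched by this sketch.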
Reviewed By: sshleifer Differential Revision: D28563032 fbshipit-source-id: ced72bed969319ab6306059721f56e29b2c3d892 --- fairseq/checkpoint_utils.py | 2 + fairseq/trainer.py | 103 ++++++++++++++++++++++++------------ 2 files changed, 71 insertions(+), 34 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index ac6c7339d4..23677be83d 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -45,6 +45,8 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state if not trainer.should_save_checkpoint_on_current_rank: + if trainer.always_call_state_dict_during_save_checkpoint: + trainer.state_dict() return write_timer = meters.StopwatchMeter() diff --git a/fairseq/trainer.py b/fairseq/trainer.py index dc06928dfc..d3da876c29 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -24,9 +24,7 @@ from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler - from omegaconf import OmegaConf -import re logger = logging.getLogger(__name__) @@ -185,8 +183,7 @@ def is_data_parallel_master(self): @property def use_distributed_wrapper(self) -> bool: return ( - self.data_parallel_world_size > 1 - and not self.cfg.optimization.use_bmuf + self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf ) or ( self.cfg.distributed_training.ddp_backend == "fully_sharded" and self.cfg.distributed_training.cpu_offload @@ -195,26 +192,42 @@ def use_distributed_wrapper(self) -> bool: @property def should_save_checkpoint_on_current_rank(self) -> bool: """Indicates whether to save checkpoints on the current DDP rank.""" - if self.cfg.distributed_training.ddp_backend == "fully_sharded" or getattr(self.cfg.model, "base_layers", 0) > 0: + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.distributed_training.use_sharded_state + 
) or getattr(self.cfg.model, "base_layers", 0) > 0: return True else: return self.is_data_parallel_master + @property + def always_call_state_dict_during_save_checkpoint(self) -> bool: + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and not self.cfg.distributed_training.use_sharded_state + ): + # FSDP calls communication collective when consolidating checkpoints + return True + else: + return False + @property def checkpoint_suffix(self) -> str: """Suffix to add to the checkpoint file name.""" - if self.cfg.distributed_training.ddp_backend == "fully_sharded": - return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank) + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.distributed_training.use_sharded_state + ): + return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format( + self.data_parallel_rank + ) else: return self.cfg.checkpoint.checkpoint_suffix or "" @property def criterion(self): if self._wrapped_criterion is None: - if ( - utils.has_parameters(self._criterion) - and self.use_distributed_wrapper - ): + if utils.has_parameters(self._criterion) and self.use_distributed_wrapper: self._wrapped_criterion = models.DistributedFairseqModel( self.cfg.distributed_training, self._criterion, @@ -293,8 +306,9 @@ def _build_optimizer(self): self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) if self.cfg.distributed_training.ddp_backend == "fully_sharded": - assert not self.cfg.optimization.use_bmuf, \ - "--ddp-backend=fully_sharded is not compatible with BMUF" + assert ( + not self.cfg.optimization.use_bmuf + ), "--ddp-backend=fully_sharded is not compatible with BMUF" assert self._optimizer.supports_flat_params, ( "--ddp-backend=fully_sharded is only compatible with pointwise " "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). 
" @@ -337,10 +351,13 @@ def consolidate_optimizer(self): if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() - elif self.cfg.distributed_training.ddp_backend == 'fully_sharded' and not self.model.use_sharded_state: - st = self.model.gather_full_optim_state_dict(self.optimizer) # only returns on rank 0 - if st is None: - st = -1 # sentinel so that workers do not save optimizer.state_dict() + elif ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and not self.model.use_sharded_state + ): + st = self.model.gather_full_optim_state_dict( + self.optimizer + ) # only returns on rank 0 self._gathered_optim_state = st def state_dict(self): @@ -348,12 +365,14 @@ def state_dict(self): "args": None, # legacy "cfg": ( OmegaConf.to_container(self.cfg) - if OmegaConf.is_config(self.cfg) else self.cfg + if OmegaConf.is_config(self.cfg) + else self.cfg ), "model": self.model.state_dict(), "criterion": ( self.criterion.state_dict() - if utils.has_parameters(self.criterion) else None + if utils.has_parameters(self.criterion) + else None ), "optimizer_history": (self._optim_history or []) + [ @@ -368,7 +387,7 @@ def state_dict(self): "extra_state": { "metrics": metrics.state_dict(), "previous_training_time": self.cumulative_training_time(), - } + }, } if not self.cfg.checkpoint.no_save_optimizer_state: if self._gathered_optim_state is not None: @@ -417,7 +436,10 @@ def load_checkpoint( # on every worker for now or self.tpu # FSDP requires loading checkpoint shards on all ranks - or self.cfg.distributed_training.ddp_backend == "fully_sharded" + or ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.distributed_training.use_sharded_state + ) or getattr(self.cfg.model, "base_layers", 0) > 0 ) @@ -426,10 +448,6 @@ def load_checkpoint( filename, load_on_all_ranks=load_on_all_ranks ) last_optim_state = state.get("last_optimizer_state", None) - if last_optim_state == -1: - master_path = 
re.sub("shard[0-9]+", "shard0", filename) - local_master_path = PathManager.get_local_path(master_path) - last_optim_state = torch.load(local_master_path, map_location='cpu')['last_optimizer_state'] # If doing zero_sharding, do not broadcast global optimizer # state. Later we will broadcast sharded states to each rank @@ -492,13 +510,18 @@ def load_checkpoint( if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) - if not load_on_all_ranks and is_distributed: + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and not self.model.use_sharded_state + ): + # if use_sharded_state, the last_optim_state is already sharded, skip this + last_optim_state = self.model.get_shard_from_optim_state_dict( + last_optim_state + ) + elif not load_on_all_ranks and is_distributed: last_optim_state = self.optimizer.broadcast_global_state_dict( last_optim_state ) - elif self.cfg.distributed_training.ddp_backend == 'fully_sharded' and not self.model.use_sharded_state: - # if use_sharded_state, the last_optim_state is already sharded, skip this - last_optim_state = self.model.get_shard_from_optim_state_dict(last_optim_state) self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) @@ -514,7 +537,10 @@ def load_checkpoint( self.lr_step(epoch) - if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0: + if ( + itr_state.get("version", 1) >= 2 + and itr_state["iterations_in_epoch"] == 0 + ): # reset meters at start of epoch reset_meters = True @@ -801,7 +827,9 @@ def maybe_no_sync(): raise except OverflowError as e: overflow = True - logger.info(f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}") + logger.info( + f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}" + ) grad_norm = torch.tensor(0.0).cuda() self.zero_grad() except RuntimeError as e: @@ -846,7 +874,9 @@ def maybe_no_sync(): metrics.log_scalar( "gb_total", gb_total, priority=1600, round=1, weight=0 ) - 
logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) + logging_outputs = self._xla_markstep_and_send_to_cpu( + logging_outputs + ) logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm ) @@ -899,6 +929,7 @@ def valid_step(self, sample, raise_oom=False): """Do forward pass in evaluation mode.""" if self.tpu: import torch_xla.core.xla_model as xm + xm.rendezvous("valid_step") # wait for all workers with torch.no_grad(): @@ -1040,7 +1071,6 @@ def set_num_updates(self, num_updates): metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) def clip_grad_norm(self, clip_norm): - def agg_norm_fn(total_norm): total_norm = total_norm.cuda().float() ** 2 total_norm = distributed_utils.all_reduce( @@ -1190,7 +1220,10 @@ def _all_gather_list_sync( return logging_outputs, extra_stats_to_sum def _fast_stat_sync_sum( - self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, ): """ Sync logging outputs across workers. fast_stat_sync_sum is @@ -1323,9 +1356,11 @@ def _check_xla_compilation(self): def _xla_markstep_and_send_to_cpu(self, data=None): import torch_xla.core.xla_model as xm + xm.mark_step() if data is not None: from fairseq.utils import xla_device_to_cpu + return xla_device_to_cpu(data) From 4aef9036cef814a24193dff3a678ce0f5c27309f Mon Sep 17 00:00:00 2001 From: Kushal Lakhotia <kushall@fb.com> Date: Fri, 21 May 2021 18:40:01 -0700 Subject: [PATCH 586/707] Merge Hubert to master (#1877) Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ X] Did you make sure to update the docs? - [ X] Did you write any new necessary tests? ## What does this PR do? 
This PR adds relevant code for pre-training HuBERT and fine-tuning a pretrained HuBERT for ASR. It also shares trained models of different sizes.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1877
Reviewed By: wnhsu
Differential Revision: D28513359
Pulled By: hikushalhere
fbshipit-source-id: 8755862f236b7d840105b0fa8f5461ac053d79cc
---
 examples/hubert/README.md | 108 ++++
 .../hubert/config/decode/ax_sweep/ngram.yaml | 33 +
 .../config/decode/ax_sweep/transformer.yaml | 33 +
 .../hubert/config/decode/infer_fsqlm.yaml | 41 ++
 .../hubert/config/decode/infer_kenlm.yaml | 41 ++
 .../hubert/config/decode/infer_viterbi.yaml | 34 ++
 .../config/decode/run/submitit_slurm.yaml | 17 +
 .../decode/run/submitit_slurm_8gpu.yaml | 17 +
 examples/hubert/config/finetune/base_10h.yaml | 99 +++
 examples/hubert/config/finetune/ckpt/it1.yaml | 7 +
 .../hubert/config/finetune/lm/ls_4gram.yaml | 7 +
 .../config/finetune/run/submitit_reg.yaml | 20 +
 .../hubert/config/pretrain/data/iter1.yaml | 8 +
 .../hubert/config/pretrain/data/iter2.yaml | 8 +
 .../pretrain/hubert_base_librispeech.yaml | 97 +++
 .../pretrain/hubert_large_librivox.yaml | 101 ++++
 .../pretrain/hubert_xlarge_librivox.yaml | 101 ++++
 .../config/pretrain/run/submitit_reg.yaml | 20 +
 examples/hubert/measure_teacher_quality.py | 241 ++++++++
 examples/hubert/simple_kmeans/README.md | 71 +++
 .../simple_kmeans/dump_hubert_feature.py | 133 +++++
 .../simple_kmeans/dump_hubert_feature_s2t.py | 126 ++++
 .../hubert/simple_kmeans/dump_km_label.py | 98 +++
 .../hubert/simple_kmeans/dump_mfcc_feature.py | 116 ++++
 examples/hubert/simple_kmeans/learn_kmeans.py | 146 +++++
 examples/hubert/update_ckpt.py | 22 +
 fairseq/criterions/hubert_criterion.py | 177 ++++++
 fairseq/data/__init__.py | 9 +-
 fairseq/data/audio/hubert_dataset.py | 358 +++++++++++
 fairseq/distributed/utils.py | 2 +-
 fairseq/models/hubert/__init__.py | 7 +
 fairseq/models/hubert/hubert.py | 563 ++++++++++++++++++
 fairseq/models/hubert/hubert_asr.py | 373 ++++++++++++
 fairseq/tasks/hubert_pretraining.py | 189 ++++++
 34 files changed, 3419 insertions(+), 4 deletions(-)
 create mode 100644 examples/hubert/README.md
 create mode 100644 examples/hubert/config/decode/ax_sweep/ngram.yaml
 create mode 100644 examples/hubert/config/decode/ax_sweep/transformer.yaml
 create mode 100644 examples/hubert/config/decode/infer_fsqlm.yaml
 create mode 100644 examples/hubert/config/decode/infer_kenlm.yaml
 create mode 100644 examples/hubert/config/decode/infer_viterbi.yaml
 create mode 100644 examples/hubert/config/decode/run/submitit_slurm.yaml
 create mode 100644 examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml
 create mode 100644 examples/hubert/config/finetune/base_10h.yaml
 create mode 100644 examples/hubert/config/finetune/ckpt/it1.yaml
 create mode 100644 examples/hubert/config/finetune/lm/ls_4gram.yaml
 create mode 100644 examples/hubert/config/finetune/run/submitit_reg.yaml
 create mode 100644 examples/hubert/config/pretrain/data/iter1.yaml
 create mode 100644 examples/hubert/config/pretrain/data/iter2.yaml
 create mode 100644 examples/hubert/config/pretrain/hubert_base_librispeech.yaml
 create mode 100644 examples/hubert/config/pretrain/hubert_large_librivox.yaml
 create mode 100644 examples/hubert/config/pretrain/hubert_xlarge_librivox.yaml
 create mode 100644 examples/hubert/config/pretrain/run/submitit_reg.yaml
 create mode 100644 examples/hubert/measure_teacher_quality.py
 create mode 100644 examples/hubert/simple_kmeans/README.md
 create mode 100644 examples/hubert/simple_kmeans/dump_hubert_feature.py
 create mode 100644 examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py
 create mode 100644 examples/hubert/simple_kmeans/dump_km_label.py
 create mode 100644 examples/hubert/simple_kmeans/dump_mfcc_feature.py
 create mode 100644 examples/hubert/simple_kmeans/learn_kmeans.py
 create mode 100644 examples/hubert/update_ckpt.py
 create mode 100644 fairseq/criterions/hubert_criterion.py
 create mode 100644 fairseq/data/audio/hubert_dataset.py
 create mode 100644 fairseq/models/hubert/__init__.py
 create mode 100644 fairseq/models/hubert/hubert.py
 create mode 100644 fairseq/models/hubert/hubert_asr.py
 create mode 100644 fairseq/tasks/hubert_pretraining.py

diff --git a/examples/hubert/README.md b/examples/hubert/README.md
new file mode 100644
index 0000000000..88973c22f2
--- /dev/null
+++ b/examples/hubert/README.md
@@ -0,0 +1,108 @@
+# HuBERT
+
+## Pre-trained and fine-tuned (ASR) models
+Model | Pretraining Data | Finetuning Dataset | Model
+|---|---|---|---
+HuBERT Base | [Librispeech](http://www.openslr.org/12) 960 hr | No finetuning | [download](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt)
+HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt)
+HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k_finetune_ls960.pt)
+HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k.pt)
+HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k_finetune_ls960.pt)
+
+## Train a new model
+
+### Data preparation
+
+Follow the steps in `./simple_kmeans` to create:
+- `{train,valid}.tsv` waveform list files
+- `{train,valid}.km` frame-aligned pseudo label files.
+The `label_rate` is the same as the feature frame rate used for clustering,
+which is 100Hz for MFCC features and 50Hz for HuBERT features by default.
+
+### Pre-train a HuBERT model
+
+Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
+are saved at `/path/to/labels`, and the label rate is 100Hz.
+
+To train a base model (12 layer transformer), run:
+```sh
+$ python fairseq_cli/hydra_train.py \
+  --config-dir /path/to/fairseq-py/examples/hubert/config/pretrain \
+  --config-name hubert_base_librispeech \
+  task.data=/path/to/data task.label_dir=/path/to/labels model.label_rate=100
+```
+
+### Fine-tune a HuBERT model with a CTC loss
+
+Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their
+corresponding character transcripts `{train,valid}.ltr` are saved at
+`/path/to/trans`.
+
+To fine-tune a pre-trained HuBERT model at `/path/to/checkpoint`, run:
+```sh
+$ python fairseq_cli/hydra_train.py \
+  --config-dir /path/to/fairseq-py/examples/hubert/config/finetune \
+  --config-name base_10h \
+  task.data=/path/to/data task.label_dir=/path/to/trans \
+  model.w2v_path=/path/to/checkpoint
+```
+
+### Decode a HuBERT model
+
+Suppose the `test.tsv` and `test.ltr` are the waveform list and transcripts of
+the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
+saved at `/path/to/checkpoint`. We support three decoding modes:
+- Viterbi decoding: greedy decoding without a language model
+- KenLM decoding: decoding with an arpa-format KenLM n-gram language model
+- Fairseq-LM decoding: decoding with a Fairseq neural language model
+
+
+#### Viterbi decoding
+
+`task.normalize` needs to be consistent with the value used during fine-tuning.
+Decoding results will be saved at
+`/path/to/experiment/directory/decode/viterbi/test`.
+
+```sh
+$ python examples/speech_recognition/hydra/infer.py \
+  --config-dir /path/to/fairseq-py/examples/hubert/config/decode \
+  --config-name infer_viterbi \
+  task.data=/path/to/data \
+  task.normalize=[true|false] \
+  decoding.exp_dir=/path/to/experiment/directory \
+  common_eval.path=/path/to/checkpoint \
+  dataset.gen_subset=test
+```
+
+#### KenLM / Fairseq-LM decoding
+
+Suppose the pronunciation lexicon and the n-gram LM are saved at
+`/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be
+saved at `/path/to/experiment/directory/decode/kenlm/test`.
+
+```sh
+$ python examples/speech_recognition/hydra/infer.py \
+  --config-dir /path/to/fairseq-py/examples/hubert/config/decode \
+  --config-name infer_kenlm \
+  task.data=/path/to/data \
+  task.normalize=[true|false] \
+  decoding.exp_dir=/path/to/experiment/directory \
+  common_eval.path=/path/to/checkpoint \
+  dataset.gen_subset=test \
+  decoding.decoder.lexicon=/path/to/lexicon \
+  decoding.decoder.lmpath=/path/to/arpa
+```
+
+The command above uses the default decoding hyperparameters, which can be found
+in `examples/speech_recognition/hydra/decoder.py`. These parameters can be
+configured from the command line. For example, to search with a beam size of
+500, append `decoding.decoder.beam=500` to the command above.
+Important parameters include:
+- decoding.decoder.beam
+- decoding.decoder.beamthreshold
+- decoding.decoder.lmweight
+- decoding.decoder.wordscore
+- decoding.decoder.silweight
+
+To decode with a Fairseq LM, use `--config-name infer_fsqlm` instead, and
+change the path of lexicon and LM accordingly.
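As a rough illustration of how the decoding commands in the README above are assembled from `key=value` overrides, the sketch below builds the argument list in Python. The helper `build_decode_command` is hypothetical (it is not part of fairseq); the script path, config directory, and override keys are the ones quoted in the README, and every path is a placeholder.

```python
import shlex

INFER_SCRIPT = "examples/speech_recognition/hydra/infer.py"
CONFIG_DIR = "/path/to/fairseq-py/examples/hubert/config/decode"  # placeholder path


def build_decode_command(config_name, overrides, decoder_overrides=None):
    """Return the argv list for one decoding run (hypothetical helper).

    `overrides` maps full keys such as "task.data" to values;
    `decoder_overrides` maps short names such as "beam" to values and is
    expanded to "decoding.decoder.<name>=<value>".
    """
    argv = [
        "python", INFER_SCRIPT,
        "--config-dir", CONFIG_DIR,
        "--config-name", config_name,
    ]
    argv += [f"{key}={value}" for key, value in overrides.items()]
    for key, value in (decoder_overrides or {}).items():
        argv.append(f"decoding.decoder.{key}={value}")
    return argv


if __name__ == "__main__":
    command = build_decode_command(
        "infer_kenlm",
        {
            "task.data": "/path/to/data",
            "task.normalize": "true",
            "decoding.exp_dir": "/path/to/experiment/directory",
            "common_eval.path": "/path/to/checkpoint",
            "dataset.gen_subset": "test",
        },
        decoder_overrides={
            "lexicon": "/path/to/lexicon",
            "lmpath": "/path/to/arpa",
            "beam": 500,
        },
    )
    # shlex.join quotes each argument so the line can be pasted into a shell.
    print(shlex.join(command))
```

Building the command as a list avoids the missing or trailing line-continuation backslashes that are easy to introduce when editing the multi-line shell form by hand.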
diff --git a/examples/hubert/config/decode/ax_sweep/ngram.yaml b/examples/hubert/config/decode/ax_sweep/ngram.yaml new file mode 100644 index 0000000000..5a02df1f7d --- /dev/null +++ b/examples/hubert/config/decode/ax_sweep/ngram.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +common_eval: + results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}_ax/${dataset.gen_subset} + +hydra: + sweeper: + ax_config: + max_trials: 60 + early_stop: + minimize: true + max_epochs_without_improvement: 10 + epsilon: 0.025 + experiment: + name: ${dataset.gen_subset} + objective_name: wer + minimize: true + parameter_constraints: null + outcome_constraints: null + status_quo: null + client: + verbose_logging: false + random_seed: null + params: + decoding.decoder.lmweight: + type: range + bounds: [0.0, 8.0] + decoding.decoder.wordscore: + type: range + bounds: [-5.0, 5.0] + decoding.decoder.silweight: + type: range + bounds: [-10.0, 0.0] diff --git a/examples/hubert/config/decode/ax_sweep/transformer.yaml b/examples/hubert/config/decode/ax_sweep/transformer.yaml new file mode 100644 index 0000000000..85ed3bd1a5 --- /dev/null +++ b/examples/hubert/config/decode/ax_sweep/transformer.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +common_eval: + results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}_ax/${dataset.gen_subset} + +hydra: + sweeper: + ax_config: + max_trials: 60 + early_stop: + minimize: true + max_epochs_without_improvement: 10 + epsilon: 0.025 + experiment: + name: ${dataset.gen_subset} + objective_name: wer + minimize: true + parameter_constraints: null + outcome_constraints: null + status_quo: null + client: + verbose_logging: false + random_seed: null + params: + decoding.decoder.lmweight: + type: range + bounds: [0.0, 4.0] + decoding.decoder.wordscore: + type: range + bounds: [-5.0, 5.0] + decoding.decoder.silweight: + type: range + bounds: [-8.0, 0.0] diff --git a/examples/hubert/config/decode/infer_fsqlm.yaml 
b/examples/hubert/config/decode/infer_fsqlm.yaml new file mode 100644 index 0000000000..b9fb845066 --- /dev/null +++ b/examples/hubert/config/decode/infer_fsqlm.yaml @@ -0,0 +1,41 @@ +# @package _group_ + +defaults: + - model: null + +hydra: + run: + dir: ${common_eval.results_path}/beam${decoding.decoder.beam}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + sweep: + dir: ${common_eval.results_path} + subdir: beam${decoding.decoder.beam}_th${decoding.decoder.beamthreshold}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + +task: + _name: hubert_pretraining + single_target: true + data: ??? + normalize: ??? + +decoding: + exp_dir: ??? + decoder: + name: fairseqlm + lexicon: ??? + lmpath: ??? + beamthreshold: 25 # 100 + beam: 500 + lmweight: 2 + wordscore: -1 + silweight: 0 + write_sentences: true + unique_wer_file: true +common_eval: + results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}/${dataset.gen_subset} + path: ??? + post_process: letter +generation: + nbest: 1 + beam: 500 +dataset: + max_tokens: 1100000 + gen_subset: ??? diff --git a/examples/hubert/config/decode/infer_kenlm.yaml b/examples/hubert/config/decode/infer_kenlm.yaml new file mode 100644 index 0000000000..fe464eaae5 --- /dev/null +++ b/examples/hubert/config/decode/infer_kenlm.yaml @@ -0,0 +1,41 @@ +# @package _group_ + +defaults: + - model: null + +hydra: + run: + dir: ${common_eval.results_path}/beam${decoding.decoder.beam}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + sweep: + dir: ${common_eval.results_path} + subdir: beam${decoding.decoder.beam}_th${decoding.decoder.beamthreshold}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + +task: + _name: hubert_pretraining + single_target: true + data: ??? + normalize: ??? + +decoding: + exp_dir: ??? 
+ decoder: + name: kenlm + lexicon: ??? + lmpath: ??? + beamthreshold: 100 + beam: 500 + lmweight: 2 + wordscore: -1 + silweight: 0 + write_sentences: true + unique_wer_file: true +common_eval: + results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}/${dataset.gen_subset} + path: ??? + post_process: letter +generation: + nbest: 1 + beam: 500 +dataset: + max_tokens: 1100000 + gen_subset: ??? diff --git a/examples/hubert/config/decode/infer_viterbi.yaml b/examples/hubert/config/decode/infer_viterbi.yaml new file mode 100644 index 0000000000..d0de9cfd26 --- /dev/null +++ b/examples/hubert/config/decode/infer_viterbi.yaml @@ -0,0 +1,34 @@ +# @package _group_ + +defaults: + - model: null + +hydra: + run: + dir: ${common_eval.results_path}/beam${decoding.decoder.beam}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + sweep: + dir: ${common_eval.results_path} + subdir: beam${decoding.decoder.beam}_th${decoding.decoder.beamthreshold}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + +task: + _name: hubert_pretraining + single_target: true + data: ??? + normalize: ??? + +decoding: + exp_dir: ??? + decoder: + name: viterbi + write_sentences: true + unique_wer_file: true +common_eval: + results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}/${dataset.gen_subset} + path: ??? + post_process: letter +generation: + nbest: 1 + beam: 500 +dataset: + max_tokens: 1100000 + gen_subset: ??? 
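The decode configs above build `common_eval.results_path` and `hydra.run.dir` from `${...}` interpolations. As a rough illustration of how those templates expand, here is a toy resolver written with the standard library only. It is an assumption-laden stand-in: real runs rely on Hydra/OmegaConf for this, and `resolve` is a hypothetical function that only mimics plain dotted lookups. The values are the `infer_kenlm.yaml` defaults plus the example paths from the README.

```python
import re

_INTERP = re.compile(r"\$\{([^}]+)\}")


def resolve(template, cfg):
    """Expand ${a.b.c} references against a nested dict (toy stand-in for OmegaConf)."""

    def lookup(match):
        node = cfg
        for part in match.group(1).split("."):
            node = node[part]
        # A referenced value may itself contain ${...}, so resolve recursively.
        return resolve(str(node), cfg)

    return _INTERP.sub(lookup, template)


cfg = {
    "decoding": {
        "exp_dir": "/path/to/experiment/directory",
        "decoder": {
            "name": "kenlm",
            "beam": 500,
            "lmweight": 2,
            "wordscore": -1,
            "silweight": 0,
        },
    },
    "dataset": {"gen_subset": "test"},
    "common_eval": {
        "results_path": (
            "${decoding.exp_dir}/decode/"
            "${decoding.decoder.name}/${dataset.gen_subset}"
        )
    },
}

run_dir_template = (
    "${common_eval.results_path}/beam${decoding.decoder.beam}"
    "_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}"
    "_sil${decoding.decoder.silweight}"
)
results_path = resolve("${common_eval.results_path}", cfg)
run_dir = resolve(run_dir_template, cfg)
print(results_path)
print(run_dir)
```

Because the decoder settings are part of the directory name, each combination of overrides gets its own run directory under the same results path.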
diff --git a/examples/hubert/config/decode/run/submitit_slurm.yaml b/examples/hubert/config/decode/run/submitit_slurm.yaml new file mode 100644 index 0000000000..0b8065832e --- /dev/null +++ b/examples/hubert/config/decode/run/submitit_slurm.yaml @@ -0,0 +1,17 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: ${distributed_training.distributed_world_size} + gpus_per_node: ${distributed_training.distributed_world_size} + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 1 + mem_gb: 200 + timeout_min: 4320 + max_num_timeout: 50 + name: ${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/submitit + +distributed_training: + distributed_world_size: 1 + distributed_no_spawn: true + distributed_port: 29761 diff --git a/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml b/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml new file mode 100644 index 0000000000..2f669f3763 --- /dev/null +++ b/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml @@ -0,0 +1,17 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: ${distributed_training.distributed_world_size} + gpus_per_node: ${distributed_training.distributed_world_size} + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 1 + mem_gb: 200 + timeout_min: 4320 + max_num_timeout: 50 + name: ${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/submitit + +distributed_training: + distributed_world_size: 8 + distributed_no_spawn: true + distributed_port: 29761 diff --git a/examples/hubert/config/finetune/base_10h.yaml b/examples/hubert/config/finetune/base_10h.yaml new file mode 100644 index 0000000000..844484d7fb --- /dev/null +++ b/examples/hubert/config/finetune/base_10h.yaml @@ -0,0 +1,99 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 5 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: 
+ ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 1 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: hubert_pretraining + data: ??? + label_dir: ??? + normalize: false # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train + valid_subset: valid + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 25000 + lr: [2e-5] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: hubert_ctc + w2v_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.w2v_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? 
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/hubert/config/finetune/ckpt/it1.yaml b/examples/hubert/config/finetune/ckpt/it1.yaml new file mode 100644 index 0000000000..2af96b3f72 --- /dev/null +++ b/examples/hubert/config/finetune/ckpt/it1.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +task: + normalize: false + +model: + w2v_path: /checkpoint/wnhsu/w2v/hubert_final/iter1/hubert.km.randcrop.pmw1_0.puw0_0.grpnorm.ml10.mp0_8.untie.mxsz250000.ufreq1.maxtok1400000.MU400k.s1337.ngpu32/checkpoint_last.pt diff --git a/examples/hubert/config/finetune/lm/ls_4gram.yaml b/examples/hubert/config/finetune/lm/ls_4gram.yaml new file mode 100644 index 0000000000..8c7728ad29 --- /dev/null +++ b/examples/hubert/config/finetune/lm/ls_4gram.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +criterion: + wer_kenlm_model: /checkpoint/abdo/old_checkpoint02/datasets/librispeech/4-gram.bin + wer_lexicon: /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw/lexicon_ltr.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 diff --git a/examples/hubert/config/finetune/run/submitit_reg.yaml b/examples/hubert/config/finetune/run/submitit_reg.yaml new file mode 100644 index 0000000000..27509503e7 --- /dev/null +++ b/examples/hubert/config/finetune/run/submitit_reg.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +hydra: + launcher: + cpus_per_task: 8 + gpus_per_node: 8 + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 1 + comment: null + mem_gb: 384 + timeout_min: 4320 + max_num_timeout: 100 + constraint: volta32gb + name: ${hydra.job.config_name}/${hydra.job.override_dirname} + submitit_folder: ${hydra.sweep.dir}/submitit/%j + +distributed_training: + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 diff --git a/examples/hubert/config/pretrain/data/iter1.yaml b/examples/hubert/config/pretrain/data/iter1.yaml new file mode 100644 index 0000000000..0a1b65d802 --- /dev/null +++ b/examples/hubert/config/pretrain/data/iter1.yaml @@ 
-0,0 +1,8 @@ +# @package _global_ + +task: + label_dir: ??? + labels: ["km"] + +model: + label_rate: 100 diff --git a/examples/hubert/config/pretrain/data/iter2.yaml b/examples/hubert/config/pretrain/data/iter2.yaml new file mode 100644 index 0000000000..2d4bfe61cc --- /dev/null +++ b/examples/hubert/config/pretrain/data/iter2.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +task: + label_dir: ??? + labels: ["km"] + +model: + label_rate: 50 diff --git a/examples/hubert/config/pretrain/hubert_base_librispeech.yaml b/examples/hubert/config/pretrain/hubert_base_librispeech.yaml new file mode 100644 index 0000000000..bd84461a16 --- /dev/null +++ b/examples/hubert/config/pretrain/hubert_base_librispeech.yaml @@ -0,0 +1,97 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + seed: 1337 + tensorboard_logdir: tblog + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 32 + distributed_port: 29671 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: hubert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 16000 + max_sample_size: 250000 + min_sample_size: 32000 + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + +dataset: + num_workers: 6 + max_tokens: 1400000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10,] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: hubert + label_rate: ??? 
+ skip_masked: false + skip_nomask: false + mask_prob: 0.80 + extractor_mode: default + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + final_dim: 256 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/hubert/config/pretrain/hubert_large_librivox.yaml b/examples/hubert/config/pretrain/hubert_large_librivox.yaml new file mode 100644 index 0000000000..a5192b5f29 --- /dev/null +++ b/examples/hubert/config/pretrain/hubert_large_librivox.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + seed: 1337 + tensorboard_logdir: tblog + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 128 + distributed_port: 29671 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: hubert_pretraining + data: ??? + label_dir: ??? + labels: ??? 
+ label_rate: ${model.label_rate} + sample_rate: 16000 + max_sample_size: 250000 + min_sample_size: 32000 + pad_audio: false + random_crop: true + normalize: true # must be consistent with extractor + +dataset: + num_workers: 6 + max_tokens: 900000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10,] + +optimization: + max_update: 400000 + lr: [0.0015] + clip_norm: 1.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: hubert + label_rate: ??? + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + final_dim: 768 + skip_masked: false + skip_nomask: false + mask_prob: 0.80 + extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + layer_norm_first: true + feature_grad_mult: 1.0 + untie_final_proj: true + activation_dropout: 0.0 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + run: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + sweep: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/hubert/config/pretrain/hubert_xlarge_librivox.yaml b/examples/hubert/config/pretrain/hubert_xlarge_librivox.yaml new file mode 100644 index 0000000000..34e8f2bfb9 --- /dev/null +++ b/examples/hubert/config/pretrain/hubert_xlarge_librivox.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + seed: 1337 + tensorboard_logdir: tblog + +checkpoint: + save_interval_updates: 25000 + 
keep_interval_updates: 1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 256 + distributed_port: 29671 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: hubert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 16000 + max_sample_size: 250000 + min_sample_size: 32000 + pad_audio: false + random_crop: true + normalize: true # must be consistent with extractor + +dataset: + num_workers: 6 + max_tokens: 360000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10,] + +optimization: + max_update: 400000 + lr: [0.003] + clip_norm: 1.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: hubert + label_rate: ??? 
+ encoder_layers: 48 + encoder_embed_dim: 1280 + encoder_ffn_embed_dim: 5120 + encoder_attention_heads: 16 + final_dim: 1024 + skip_masked: false + skip_nomask: false + mask_prob: 0.80 + extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + layer_norm_first: true + feature_grad_mult: 1.0 + untie_final_proj: true + activation_dropout: 0.0 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + run: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + sweep: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/examples/hubert/config/pretrain/run/submitit_reg.yaml b/examples/hubert/config/pretrain/run/submitit_reg.yaml new file mode 100644 index 0000000000..46c979cd28 --- /dev/null +++ b/examples/hubert/config/pretrain/run/submitit_reg.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +hydra: + launcher: + cpus_per_task: 8 + gpus_per_node: 8 + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 4 + comment: null + mem_gb: 384 + timeout_min: 4320 + max_num_timeout: 100 + constraint: volta32gb + name: ${hydra.job.config_name}/${hydra.job.override_dirname} + submitit_folder: ${hydra.sweep.dir}/submitit/%j + +distributed_training: + distributed_world_size: 32 + distributed_port: 29671 + nprocs_per_node: 8 diff --git a/examples/hubert/measure_teacher_quality.py b/examples/hubert/measure_teacher_quality.py new file mode 100644 index 0000000000..92279b2214 --- /dev/null +++ b/examples/hubert/measure_teacher_quality.py @@ -0,0 +1,241 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +import os.path as op +import re +from tabulate import tabulate +from collections import Counter + + +def comp_purity(p_xy, axis): + max_p = p_xy.max(axis=axis) + marg_p = p_xy.sum(axis=axis) + indv_pur = max_p / marg_p + aggr_pur = max_p.sum() + return indv_pur, aggr_pur + + +def comp_entropy(p): + return (-p * np.log(p + 1e-8)).sum() + + +def comp_norm_mutual_info(p_xy): + p_x = p_xy.sum(axis=1, keepdims=True) + p_y = p_xy.sum(axis=0, keepdims=True) + pmi = np.log(p_xy / np.matmul(p_x, p_y) + 1e-8) + mi = (p_xy * pmi).sum() + h_x = comp_entropy(p_x) + h_y = comp_entropy(p_y) + return mi, mi / h_x, mi / h_y, h_x, h_y + + +def pad(labs, n): + if n == 0: + return np.array(labs) + return np.concatenate([[labs[0]] * n, labs, [labs[-1]] * n]) + + +def comp_avg_seg_dur(labs_list): + n_frms = 0 + n_segs = 0 + for labs in labs_list: + labs = np.array(labs) + edges = np.zeros(len(labs)).astype(bool) + edges[0] = True + edges[1:] = labs[1:] != labs[:-1] + n_frms += len(edges) + n_segs += edges.astype(int).sum() + return n_frms / n_segs + + +def comp_joint_prob(uid2refs, uid2hyps): + """ + Args: + pad: padding for spliced-feature derived labels + """ + cnts = Counter() + skipped = [] + abs_frmdiff = 0 + for uid in uid2refs: + if uid not in uid2hyps: + skipped.append(uid) + continue + refs = uid2refs[uid] + hyps = uid2hyps[uid] + abs_frmdiff += abs(len(refs) - len(hyps)) + min_len = min(len(refs), len(hyps)) + refs = refs[:min_len] + hyps = hyps[:min_len] + cnts.update(zip(refs, hyps)) + tot = sum(cnts.values()) + + ref_set = sorted({ref for ref, _ in cnts.keys()}) + hyp_set = sorted({hyp for _, hyp in cnts.keys()}) + ref2pid = dict(zip(ref_set, range(len(ref_set)))) + hyp2lid = dict(zip(hyp_set, range(len(hyp_set)))) + # print(hyp_set) + p_xy = np.zeros((len(ref2pid), len(hyp2lid)), dtype=float) + for (ref, hyp), cnt in cnts.items(): + p_xy[ref2pid[ref], hyp2lid[hyp]] = cnt + p_xy /= p_xy.sum() + return p_xy, ref2pid, hyp2lid, tot, abs_frmdiff, skipped 
+ + +def read_phn(tsv_path, rm_stress=True): + uid2phns = {} + with open(tsv_path) as f: + for line in f: + uid, phns = line.rstrip().split("\t") + phns = phns.split(",") + if rm_stress: + phns = [re.sub("[0-9]", "", phn) for phn in phns] + uid2phns[uid] = phns + return uid2phns + + +def read_lab(tsv_path, lab_path, pad_len=0, upsample=1): + """ + tsv is needed to retrieve the uids for the labels + """ + with open(tsv_path) as f: + f.readline() + uids = [op.splitext(op.basename(line.rstrip().split()[0]))[0] for line in f] + with open(lab_path) as f: + labs_list = [pad(line.rstrip().split(), pad_len).repeat(upsample) for line in f] + assert len(uids) == len(labs_list) + return dict(zip(uids, labs_list)) + + +def main_lab_lab( + tsv_dir, + lab_dir, + lab_name, + lab_sets, + ref_dir, + ref_name, + pad_len=0, + upsample=1, + verbose=False, +): + # assume tsv_dir is the same for both the reference and the hypotheses + tsv_dir = lab_dir if tsv_dir is None else tsv_dir + + uid2refs = {} + for s in lab_sets: + uid2refs.update(read_lab(f"{tsv_dir}/{s}.tsv", f"{ref_dir}/{s}.{ref_name}")) + + uid2hyps = {} + for s in lab_sets: + uid2hyps.update( + read_lab( + f"{tsv_dir}/{s}.tsv", f"{lab_dir}/{s}.{lab_name}", pad_len, upsample + ) + ) + _main(uid2refs, uid2hyps, verbose) + + +def main_phn_lab( + tsv_dir, + lab_dir, + lab_name, + lab_sets, + phn_dir, + phn_sets, + pad_len=0, + upsample=1, + verbose=False, +): + uid2refs = {} + for s in phn_sets: + uid2refs.update(read_phn(f"{phn_dir}/{s}.tsv")) + + uid2hyps = {} + tsv_dir = lab_dir if tsv_dir is None else tsv_dir + for s in lab_sets: + uid2hyps.update( + read_lab( + f"{tsv_dir}/{s}.tsv", f"{lab_dir}/{s}.{lab_name}", pad_len, upsample + ) + ) + _main(uid2refs, uid2hyps, verbose) + + +def _main(uid2refs, uid2hyps, verbose): + (p_xy, ref2pid, hyp2lid, tot, frmdiff, skipped) = comp_joint_prob( + uid2refs, uid2hyps + ) + ref_pur_by_hyp, ref_pur = comp_purity(p_xy, axis=0) + hyp_pur_by_ref, hyp_pur = comp_purity(p_xy, axis=1) + (mi, 
mi_norm_by_ref, mi_norm_by_hyp, h_ref, h_hyp) = comp_norm_mutual_info(p_xy) + outputs = { + "ref pur": ref_pur, + "hyp pur": hyp_pur, + "H(ref)": h_ref, + "H(hyp)": h_hyp, + "MI": mi, + "MI/H(ref)": mi_norm_by_ref, + "ref segL": comp_avg_seg_dur(uid2refs.values()), + "hyp segL": comp_avg_seg_dur(uid2hyps.values()), + "p_xy shape": p_xy.shape, + "frm tot": tot, + "frm diff": frmdiff, + "utt tot": len(uid2refs), + "utt miss": len(skipped), + } + print(tabulate([outputs.values()], outputs.keys(), floatfmt=".4f")) + + +if __name__ == "__main__": + """ + compute quality of labels with respect to phone or another labels if set + """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("tsv_dir") + parser.add_argument("lab_dir") + parser.add_argument("lab_name") + parser.add_argument("--lab_sets", default=["valid"], type=str, nargs="+") + parser.add_argument( + "--phn_dir", + default="/checkpoint/wnhsu/data/librispeech/960h/fa/raw_phn/phone_frame_align_v1", + ) + parser.add_argument( + "--phn_sets", default=["dev-clean", "dev-other"], type=str, nargs="+" + ) + parser.add_argument("--pad_len", default=0, type=int, help="padding for hypotheses") + parser.add_argument( + "--upsample", default=1, type=int, help="upsample factor for hypotheses" + ) + parser.add_argument("--ref_lab_dir", default="") + parser.add_argument("--ref_lab_name", default="") + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + if args.ref_lab_dir and args.ref_lab_name: + main_lab_lab( + args.tsv_dir, + args.lab_dir, + args.lab_name, + args.lab_sets, + args.ref_lab_dir, + args.ref_lab_name, + args.pad_len, + args.upsample, + args.verbose, + ) + else: + main_phn_lab( + args.tsv_dir, + args.lab_dir, + args.lab_name, + args.lab_sets, + args.phn_dir, + args.phn_sets, + args.pad_len, + args.upsample, + args.verbose, + ) diff --git a/examples/hubert/simple_kmeans/README.md b/examples/hubert/simple_kmeans/README.md new file mode 100644 index 
0000000000..cd17da3b3e
--- /dev/null
+++ b/examples/hubert/simple_kmeans/README.md
@@ -0,0 +1,71 @@
+# Sharded Feature Extraction and K-means Application
+
+This folder contains scripts for preparing HuBERT labels from tsv files. The
+steps are:
+1. feature extraction
+2. k-means clustering
+3. k-means application
+
+
+## Data preparation
+
+A `*.tsv` file contains a list of audio files: the first line is the root
+directory, and each following line is the subpath of one audio file:
+```
+<root-dir>
+<audio-path-1>
+<audio-path-2>
+...
+```
+
+
+## Feature extraction
+
+### MFCC feature
+Suppose the tsv file is at `${tsv_dir}/${split}.tsv`. To extract 39-D
+mfcc+delta+ddelta features for the 1st iteration HuBERT training, run:
+```sh
+python dump_mfcc_feature.py ${tsv_dir} ${split} ${nshard} ${rank} ${feat_dir}
+```
+This shards the tsv file into `${nshard}` shards and extracts features for the
+`${rank}`-th shard, where rank is an integer in `[0, nshard-1]`. Features are
+saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.
+
+
+### HuBERT feature
+To extract features from the `${layer}`-th transformer layer of a trained
+HuBERT model saved at `${ckpt_path}`, run:
+```sh
+python dump_hubert_feature.py ${tsv_dir} ${split} ${ckpt_path} ${layer} ${nshard} ${rank} ${feat_dir}
+```
+Features are also saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.
+
+- If you run out of memory, decrease the chunk size with `--max_chunk`.
+
+
+## K-means clustering
+To fit a k-means model with `${n_clusters}` clusters on 10% of the `${split}` data, run:
+```sh
+python learn_kmeans.py ${feat_dir} ${split} ${nshard} ${km_path} ${n_clusters} --percent 0.1
+```
+This saves the k-means model to `${km_path}`.
+
+- Set `--percent -1` to use all data.
+- More k-means options can be found with the `-h` flag.
+
+
+## K-means application
+To apply a trained k-means model `${km_path}` to obtain labels for `${split}`, run:
+```sh
+python dump_km_label.py ${feat_dir} ${split} ${km_path} ${nshard} ${rank} ${lab_dir}
+```
+This extracts labels for the `${rank}`-th shard out of `${nshard}` shards
+and dumps them to `${lab_dir}/${split}_${rank}_${nshard}.km`.
+
+
+Finally, merge shards for `${split}` by running:
+```sh
+for rank in $(seq 0 $((nshard - 1))); do
+  cat $lab_dir/${split}_${rank}_${nshard}.km
+done > $lab_dir/${split}.km
+```
diff --git a/examples/hubert/simple_kmeans/dump_hubert_feature.py b/examples/hubert/simple_kmeans/dump_hubert_feature.py
new file mode 100644
index 0000000000..cd242890e5
--- /dev/null
+++ b/examples/hubert/simple_kmeans/dump_hubert_feature.py
@@ -0,0 +1,133 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +import logging +import math +import os +import sys + +import fairseq +import soundfile as sf +import torch +import torch.nn.functional as F +import tqdm +from npy_append_array import NpyAppendArray + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("dump_hubert_feature") + + +class HubertFeatureReader(object): + def __init__(self, ckpt_path, layer, max_chunk=1600000): + ( + model, + cfg, + task, + ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) + self.model = model[0].eval().cuda() + self.task = task + self.layer = layer + self.max_chunk = max_chunk + logger.info(f"TASK CONFIG:\n{self.task.cfg}") + logger.info(f" max_chunk = {self.max_chunk}") + + def read_audio(self, path, ref_len=None): + wav, sr = sf.read(path) + assert sr == self.task.cfg.sample_rate, sr + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + if ref_len is not None and abs(ref_len - len(wav)) > 160: + logging.warning(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + def get_feats(self, path, ref_len=None): + x = self.read_audio(path, ref_len) + with torch.no_grad(): + x = torch.from_numpy(x).float().cuda() + if self.task.cfg.normalize: + x = F.layer_norm(x, x.shape) + x = x.view(1, -1) + + feat = [] + for start in range(0, x.size(1), self.max_chunk): + x_chunk = x[:, start: start + self.max_chunk] + feat_chunk, _ = self.model.extract_features( + source=x_chunk, + padding_mask=None, + mask=False, + output_layer=self.layer, + ) + feat.append(feat_chunk) + return torch.cat(feat, 1).squeeze(0) + + +def get_path_iterator(tsv, nshard, rank): + with open(tsv, "r") as f: + root = f.readline().rstrip() + lines = [line.rstrip() for line in f] + tot = len(lines) + shard_size = math.ceil(tot / nshard) + start, end = rank * shard_size, min((rank + 1) * shard_size, tot) + assert 
start < end, f"start={start}, end={end}" + logger.info( + f"rank {rank} of {nshard}, process {end-start} " + f"({start}-{end}) out of {tot}" + ) + + lines = lines[start:end] + + def iterate(): + for line in lines: + subpath, nsample = line.split("\t") + yield f"{root}/{subpath}", int(nsample) + + return iterate, len(lines) + + +def dump_feature( + tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk +): + reader = HubertFeatureReader(ckpt_path, layer, max_chunk) + generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank) + iterator = generator() + + feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy" + leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len" + + os.makedirs(feat_dir, exist_ok=True) + if os.path.exists(feat_path): + os.remove(feat_path) + + feat_f = NpyAppendArray(feat_path) + with open(leng_path, "w") as leng_f: + for path, nsample in tqdm.tqdm(iterator, total=num): + feat = reader.get_feats(path, nsample) + feat_f.append(feat.cpu().numpy()) + leng_f.write(f"{len(feat)}\n") + logger.info("finished successfully") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("tsv_dir") + parser.add_argument("split") + parser.add_argument("ckpt_path") + parser.add_argument("layer", type=int) + parser.add_argument("nshard", type=int) + parser.add_argument("rank", type=int) + parser.add_argument("feat_dir") + parser.add_argument("--max_chunk", type=int, default=1600000) + args = parser.parse_args() + logger.info(args) + + dump_feature(**vars(args)) diff --git a/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py b/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py new file mode 100644 index 0000000000..7ec8a7311b --- /dev/null +++ b/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates.
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import csv +import io +import logging +import math +import os +import os.path as op +import sys + +import tqdm +from dump_hubert_feature import HubertFeatureReader +from fairseq.data.audio.audio_utils import get_waveform +from fairseq.data.audio.speech_to_text_dataset import ( + read_from_uncompressed_zip, +) +from npy_append_array import NpyAppendArray + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("dump_hubert_feature_s2t") + + +class HubertFeatureReaderS2T(HubertFeatureReader): + def read_audio(self, path, ref_len=None): + path, *extra = path.split(":") + assert len(extra) == 2 + assert path.endswith(".zip") + + data = read_from_uncompressed_zip(path, int(extra[0]), int(extra[1])) + f = io.BytesIO(data) + wav, sr = get_waveform(f) + assert sr == self.task.cfg.sample_rate, sr + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + if ref_len is not None and abs(ref_len - len(wav)) > 160: + logging.warning(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + +def get_path_iterator(root, tsv, nshard, rank): + with open(tsv) as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + subpaths = [op.join(root, e["audio"]) for e in reader] + + tot = len(subpaths) + shard_size = math.ceil(tot / nshard) + start, end = rank * shard_size, min((rank + 1) * shard_size, tot) + assert start < end, f"start={start}, end={end}" + logger.info( + f"rank {rank} of {nshard}, process {end-start} " + f"({start}-{end}) out of {tot}" + ) + + subpaths = subpaths[start:end] + + def iterate(): + for subpath in subpaths: + # subpaths were already joined with root above + yield subpath + + return iterate,
len(subpaths) + + +def dump_feature( + root, + tsv_path, + ckpt_path, + layer, + nshard, + rank, + feat_dir, + feat_name, + max_chunk, +): + reader = HubertFeatureReaderS2T(ckpt_path, layer, max_chunk) + generator, num = get_path_iterator(root, tsv_path, nshard, rank) + iterator = generator() + + feat_path = f"{feat_dir}/{feat_name}_{rank}_{nshard}.npy" + leng_path = f"{feat_dir}/{feat_name}_{rank}_{nshard}.len" + + os.makedirs(feat_dir, exist_ok=True) + if op.exists(feat_path): + os.remove(feat_path) + + feat_f = NpyAppendArray(feat_path) + with open(leng_path, "w") as leng_f: + for path in tqdm.tqdm(iterator, total=num): + feat = reader.get_feats(path) + feat_f.append(feat.cpu().numpy()) + leng_f.write(f"{len(feat)}\n") + logger.info("finished successfully") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("root") + parser.add_argument("tsv_path") + parser.add_argument("ckpt_path") + parser.add_argument("layer", type=int) + parser.add_argument("nshard", type=int) + parser.add_argument("rank", type=int) + parser.add_argument("feat_dir") + parser.add_argument("feat_name") + parser.add_argument("--max_chunk", type=int, default=1600000) + args = parser.parse_args() + logger.info(args) + + dump_feature(**vars(args)) diff --git a/examples/hubert/simple_kmeans/dump_km_label.py b/examples/hubert/simple_kmeans/dump_km_label.py new file mode 100644 index 0000000000..8871307804 --- /dev/null +++ b/examples/hubert/simple_kmeans/dump_km_label.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +import sys + +import numpy as np + +import joblib +import torch +import tqdm + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("dump_km_label") + + +class ApplyKmeans(object): + def __init__(self, km_path): + self.km_model = joblib.load(km_path) + self.C_np = self.km_model.cluster_centers_.transpose() + self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True) + + self.C = torch.from_numpy(self.C_np) + self.Cnorm = torch.from_numpy(self.Cnorm_np) + if torch.cuda.is_available(): + self.C = self.C.cuda() + self.Cnorm = self.Cnorm.cuda() + + def __call__(self, x): + if isinstance(x, torch.Tensor): + dist = ( + x.pow(2).sum(1, keepdim=True) + - 2 * torch.matmul(x, self.C) + + self.Cnorm + ) + return dist.argmin(dim=1).cpu().numpy() + else: + dist = ( + (x ** 2).sum(1, keepdims=True) + - 2 * np.matmul(x, self.C_np) + + self.Cnorm_np + ) + return np.argmin(dist, axis=1) + + +def get_feat_iterator(feat_dir, split, nshard, rank): + feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy" + leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len" + with open(leng_path, "r") as f: + lengs = [int(line.rstrip()) for line in f] + offsets = [0] + np.cumsum(lengs[:-1]).tolist() + + def iterate(): + feat = np.load(feat_path, mmap_mode="r") + assert feat.shape[0] == (offsets[-1] + lengs[-1]) + for offset, leng in zip(offsets, lengs): + yield feat[offset: offset + leng] + + return iterate, len(lengs) + + +def dump_label(feat_dir, split, km_path, nshard, rank, lab_dir): + apply_kmeans = ApplyKmeans(km_path) + generator, num = get_feat_iterator(feat_dir, split, nshard, rank) + iterator = generator() + + lab_path = f"{lab_dir}/{split}_{rank}_{nshard}.km" + os.makedirs(lab_dir, exist_ok=True) + with open(lab_path, "w") as f: + for feat in tqdm.tqdm(iterator, total=num): + # feat = 
torch.from_numpy(feat).cuda() + lab = apply_kmeans(feat).tolist() + f.write(" ".join(map(str, lab)) + "\n") + logger.info("finished successfully") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("feat_dir") + parser.add_argument("split") + parser.add_argument("km_path") + parser.add_argument("nshard", type=int) + parser.add_argument("rank", type=int) + parser.add_argument("lab_dir") + args = parser.parse_args() + logging.info(str(args)) + + dump_label(**vars(args)) diff --git a/examples/hubert/simple_kmeans/dump_mfcc_feature.py b/examples/hubert/simple_kmeans/dump_mfcc_feature.py new file mode 100644 index 0000000000..a36fa643bd --- /dev/null +++ b/examples/hubert/simple_kmeans/dump_mfcc_feature.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import os +import sys + +import soundfile as sf +import torch +import torchaudio +import tqdm +from npy_append_array import NpyAppendArray + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("dump_mfcc_feature") + + +class MfccFeatureReader(object): + def __init__(self, sample_rate): + self.sample_rate = sample_rate + + def read_audio(self, path, ref_len=None): + wav, sr = sf.read(path) + assert sr == self.sample_rate, sr + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + if ref_len is not None and abs(ref_len - len(wav)) > 160: + logging.warning(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + def get_feats(self, path, ref_len=None): + x = self.read_audio(path, ref_len) + with torch.no_grad(): + x = torch.from_numpy(x).float() + x = x.view(1, -1) + + mfccs = 
torchaudio.compliance.kaldi.mfcc( + waveform=x, + sample_frequency=self.sample_rate, + use_energy=False, + ) # (time, freq) + mfccs = mfccs.transpose(0, 1) # (freq, time) + deltas = torchaudio.functional.compute_deltas(mfccs) + ddeltas = torchaudio.functional.compute_deltas(deltas) + concat = torch.cat([mfccs, deltas, ddeltas], dim=0) + concat = concat.transpose(0, 1).contiguous() # (time, freq) + return concat + + +def get_path_iterator(tsv, nshard, rank): + with open(tsv, "r") as f: + root = f.readline().rstrip() + lines = [line.rstrip() for line in f] + tot = len(lines) + shard_size = math.ceil(tot / nshard) + start, end = rank * shard_size, min((rank + 1) * shard_size, tot) + assert start < end, f"start={start}, end={end}" + logger.info( + f"rank {rank} of {nshard}, process {end-start} " + f"({start}-{end}) out of {tot}" + ) + + lines = lines[start:end] + + def iterate(): + for line in lines: + subpath, nsample = line.split("\t") + yield f"{root}/{subpath}", int(nsample) + + return iterate, len(lines) + + +def dump_feature(tsv_dir, split, sample_rate, nshard, rank, feat_dir): + reader = MfccFeatureReader(sample_rate) + generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank) + iterator = generator() + + feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy" + leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len" + + os.makedirs(feat_dir, exist_ok=True) + if os.path.exists(feat_path): + os.remove(feat_path) + + feat_f = NpyAppendArray(feat_path) + with open(leng_path, "w") as leng_f: + for path, nsample in tqdm.tqdm(iterator, total=num): + feat = reader.get_feats(path, nsample) + feat_f.append(feat.cpu().numpy()) + leng_f.write(f"{len(feat)}\n") + logger.info("finished successfully") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("tsv_dir") + parser.add_argument("split") + parser.add_argument("nshard", type=int) + parser.add_argument("rank", type=int) + 
parser.add_argument("feat_dir") + parser.add_argument("--sample_rate", type=int, default=16000) + args = parser.parse_args() + logger.info(args) + + dump_feature(**vars(args)) diff --git a/examples/hubert/simple_kmeans/learn_kmeans.py b/examples/hubert/simple_kmeans/learn_kmeans.py new file mode 100644 index 0000000000..113ac655b8 --- /dev/null +++ b/examples/hubert/simple_kmeans/learn_kmeans.py @@ -0,0 +1,146 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +import numpy as np +from sklearn.cluster import MiniBatchKMeans + +import joblib + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("learn_kmeans") + + +def get_km_model( + n_clusters, + init, + max_iter, + batch_size, + tol, + max_no_improvement, + n_init, + reassignment_ratio, +): + return MiniBatchKMeans( + n_clusters=n_clusters, + init=init, + max_iter=max_iter, + batch_size=batch_size, + verbose=1, + compute_labels=False, + tol=tol, + max_no_improvement=max_no_improvement, + init_size=None, + n_init=n_init, + reassignment_ratio=reassignment_ratio, + ) + + +def load_feature_shard(feat_dir, split, nshard, rank, percent): + feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy" + leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len" + with open(leng_path, "r") as f: + lengs = [int(line.rstrip()) for line in f] + offsets = [0] + np.cumsum(lengs[:-1]).tolist() + + if percent < 0: + return np.load(feat_path, mmap_mode="r") + else: + nsample = int(np.ceil(len(lengs) * percent)) + indices = np.random.choice(len(lengs), nsample, replace=False) + feat = np.load(feat_path, mmap_mode="r") + sampled_feat = np.concatenate( + [feat[offsets[i]: offsets[i] + lengs[i]] for i in 
indices], axis=0 + ) + logger.info( + ( + f"sampled {nsample} utterances, {len(sampled_feat)} frames " + f"from shard {rank}/{nshard}" + ) + ) + return sampled_feat + + +def load_feature(feat_dir, split, nshard, seed, percent): + assert percent <= 1.0 + feat = np.concatenate( + [ + load_feature_shard(feat_dir, split, nshard, r, percent) + for r in range(nshard) + ], + axis=0, + ) + logging.info(f"loaded feature with dimension {feat.shape}") + return feat + + +def learn_kmeans( + feat_dir, + split, + nshard, + km_path, + n_clusters, + seed, + percent, + init, + max_iter, + batch_size, + tol, + n_init, + reassignment_ratio, + max_no_improvement, +): + np.random.seed(seed) + feat = load_feature(feat_dir, split, nshard, seed, percent) + km_model = get_km_model( + n_clusters, + init, + max_iter, + batch_size, + tol, + max_no_improvement, + n_init, + reassignment_ratio, + ) + km_model.fit(feat) + joblib.dump(km_model, km_path) + + inertia = -km_model.score(feat) / len(feat) + logger.info("total inertia: %.5f", inertia) + logger.info("finished successfully") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("feat_dir", type=str) + parser.add_argument("split", type=str) + parser.add_argument("nshard", type=int) + parser.add_argument("km_path", type=str) + parser.add_argument("n_clusters", type=int) + parser.add_argument("--seed", default=0, type=int) + parser.add_argument( + "--percent", default=-1, type=float, help="sample a subset; -1 for all" + ) + parser.add_argument("--init", default="k-means++") + parser.add_argument("--max_iter", default=100, type=int) + parser.add_argument("--batch_size", default=10000, type=int) + parser.add_argument("--tol", default=0.0, type=float) + parser.add_argument("--max_no_improvement", default=100, type=int) + parser.add_argument("--n_init", default=20, type=int) + parser.add_argument("--reassignment_ratio", default=0.0, type=float) + args = parser.parse_args() + 
logging.info(str(args)) + + learn_kmeans(**vars(args)) diff --git a/examples/hubert/update_ckpt.py b/examples/hubert/update_ckpt.py new file mode 100644 index 0000000000..53c9e74ea6 --- /dev/null +++ b/examples/hubert/update_ckpt.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +src_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2.pt" +ref_ckpt = "/checkpoint/wnhsu/w2v/hubert_icassp_oss_v3/iter2_km100-400k-grp-L6/oss.km500_p0_1_s334.pmw1_0.puw0_0.grpnorm.ml10.mp0_8.untie.mxsz250000.ufreq1.maxtok1400000.MU100k.s1337.ngpu32/checkpoint_last.pt" +new_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2_updated.pt" + + +def update_state(state): + state["model"]["label_embs_concat"] = state["model"].pop("label_embs") + state["args"].task = "hubert_pretraining" + state["args"].labels = f"['{state['args'].labels}']" + return state + + +src_state = torch.load(src_ckpt) +src_state = update_state(src_state) +torch.save(src_state, new_ckpt) diff --git a/fairseq/criterions/hubert_criterion.py b/fairseq/criterions/hubert_criterion.py new file mode 100644 index 0000000000..68cb24e6f1 --- /dev/null +++ b/fairseq/criterions/hubert_criterion.py @@ -0,0 +1,177 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +import re +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class HubertCriterionConfig(FairseqDataclass): + pred_masked_weight: float = field( + default=1.0, + metadata={"help": "weight for predictive loss for masked frames"}, + ) + pred_nomask_weight: float = field( + default=0.0, + metadata={"help": "weight for predictive loss for unmasked frames"}, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("hubert", dataclass=HubertCriterionConfig) +class HubertCriterion(FairseqCriterion): + def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None): + super().__init__(task) + self.pred_masked_weight = pred_masked_weight + self.pred_nomask_weight = pred_nomask_weight + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + + def forward(self, model, sample, reduce=True, log_pred=False): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(target_list=sample["target_list"], **sample["net_input"]) + loss = 0. 
+ sample_size = 0 + logging_output = {} + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list = model.get_logits(net_output, True) + targ_m_list = model.get_targets(net_output, True) + assert self.pred_masked_weight == 0 or len(logp_m_list) > 0 + for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)): + loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_{i}"] = loss_m.detach().item() + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += targ_m_list[0].numel() + + loss_u_list = [] + logp_u_list = model.get_logits(net_output, False) + targ_u_list = model.get_targets(net_output, False) + assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0 + for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)): + loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_{i}"] = loss_u.detach().item() + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += targ_u_list[0].numel() + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses, names = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + **logging_output, + } + + for lk in 
self.log_keys: + if lk in net_output: + logging_output[lk] = float((net_output[lk])) + + def compute_correct(logits): + if logits.numel() == 0: + return 0, 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + return corr, count + + with torch.no_grad(): + for i, logp_m in enumerate(logp_m_list): + corr_m, count_m = compute_correct(logp_m) + logging_output[f"correct_m_{i}"] = corr_m + logging_output[f"count_m_{i}"] = count_m + + for i, logp_u in enumerate(logp_u_list): + corr_u, count_u = compute_correct(logp_u) + logging_output[f"correct_u_{i}"] = corr_u + logging_output[f"count_u_{i}"] = count_u + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training (copied from normal cross entropy).""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3) + if sample_size != ntokens: + metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3) + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)) + else: + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)) + + counts = {} + for lk in logging_outputs[0].keys(): + if lk.startswith("count_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val) + counts[lk] = val + + for lk in logging_outputs[0].keys(): + if lk.startswith("loss_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / sample_size / math.log(2), round=3) + elif lk.startswith("correct_"): + val = sum(log[lk] for log in 
logging_outputs) + metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)]) + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + raise NotImplementedError() + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improve distributed training speed. + """ + return False diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 30af792185..8b7eb2ec4f 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -13,6 +13,7 @@ from .add_target_dataset import AddTargetDataset from .append_token_dataset import AppendTokenDataset from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset +from .audio.hubert_dataset import HubertDataset from .backtranslation_dataset import BacktranslationDataset from .bucket_pad_length_dataset import BucketPadLengthDataset from .colorize_dataset import ColorizeDataset @@ -82,7 +83,9 @@ "FairseqDataset", "FairseqIterableDataset", "FastaDataset", + "FileAudioDataset", "GroupedIterator", + "HubertDataset", "IdDataset", "IndexedCachedDataset", "IndexedDataset", @@ -104,12 +107,12 @@ "PadDataset", "PrependDataset", "PrependTokenDataset", - "ReplaceDataset", - "RollDataset", - "FileAudioDataset", + "RandomCropDataset", "RawLabelDataset", "ResamplingDataset", + "ReplaceDataset", "RightPadDataset", + "RollDataset", "RoundRobinZipDatasets", "SampledMultiDataset", "SampledMultiEpochDataset", diff --git a/fairseq/data/audio/hubert_dataset.py b/fairseq/data/audio/hubert_dataset.py new file mode 100644 index 0000000000..f00fe301a6 --- /dev/null +++ b/fairseq/data/audio/hubert_dataset.py @@ -0,0 +1,358 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os +import sys +from typing import Any, List, Optional, Union + +import numpy as np + +import torch +import torch.nn.functional as F +from fairseq.data import data_utils +from fairseq.data.fairseq_dataset import FairseqDataset + +logger = logging.getLogger(__name__) + + +def load_audio(manifest_path, max_keep, min_keep): + n_long, n_short = 0, 0 + names, inds, sizes = [], [], [] + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) == 2, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + names.append(items[0]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes + + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [line.rstrip() for line in f] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + + +def load_label_offset(label_path, inds, tot): + with open(label_path) as f: + code_lengths = [len(line.encode("utf-8")) for line in f] + assert ( + len(code_lengths) == tot + ), f"number of labels does not match ({len(code_lengths)} != {tot})" + offsets = list(itertools.accumulate([0] + code_lengths)) + offsets = [(offsets[i], offsets[i + 1]) for i in inds] + return offsets + + +def verify_label_lengths( + audio_sizes, + audio_rate, + label_path, + label_rate, + inds, + tot, + tol=0.1, # tolerance in seconds +): + if label_rate < 
0: + logger.info(f"{label_path} is sequence label. skipped") + return + + with open(label_path) as f: + lengths = [len(line.rstrip().split()) for line in f] + assert len(lengths) == tot + lengths = [lengths[i] for i in inds] + num_invalid = 0 + for i, ind in enumerate(inds): + dur_from_audio = audio_sizes[i] / audio_rate + dur_from_label = lengths[i] / label_rate + if abs(dur_from_audio - dur_from_label) > tol: + logger.warning( + ( + f"audio and label duration differ too much " + f"(|{dur_from_audio} - {dur_from_label}| > {tol}) " + f"in line {ind+1} of {label_path}. Check if `label_rate` " + f"is correctly set (currently {label_rate}). " + f"num. of samples = {audio_sizes[i]}; " + f"label length = {lengths[i]}" + ) + ) + num_invalid += 1 + if num_invalid > 0: + logger.warning( + f"total {num_invalid} (audio, label) pairs with mismatched lengths" + ) + + +class HubertDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_rates: Union[List[float], float], # -1 for sequence labels + pad_list: List[str], + eos_list: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + max_sample_size: Optional[int] = None, + shuffle: bool = True, + pad_audio: bool = False, + normalize: bool = False, + store_labels: bool = True, + random_crop: bool = False, + single_target: bool = False, + ): + self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio( + manifest_path, max_keep_sample_size, min_keep_sample_size + ) + self.sample_rate = sample_rate + self.shuffle = shuffle + self.random_crop = random_crop + + self.num_labels = len(label_paths) + self.pad_list = pad_list + self.eos_list = eos_list + self.label_processors = label_processors + self.single_target = single_target + self.label_rates = ( + [label_rates for _ in range(len(label_paths))] + if isinstance(label_rates, (int, float)) + else label_rates + ) + 
self.store_labels = store_labels + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + assert ( + label_processors is None + or len(label_processors) == self.num_labels + ) + for label_path, label_rate in zip(label_paths, self.label_rates): + verify_label_lengths( + self.sizes, sample_rate, label_path, label_rate, inds, tot + ) + + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.pad_audio = pad_audio + self.normalize = normalize + logger.info( + f"pad_audio={pad_audio}, random_crop={random_crop}, " + f"normalize={normalize}, max_sample_size={self.max_sample_size}" + ) + + def get_audio(self, index): + import soundfile as sf + + wav_path = os.path.join(self.audio_root, self.audio_names[index]) + wav, cur_sample_rate = sf.read(wav_path) + wav = torch.from_numpy(wav).float() + wav = self.postprocess(wav, cur_sample_rate) + return wav + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + with open(self.label_paths[label_idx]) as f: + offset_s, offset_e = self.label_offsets_list[label_idx][index] + f.seek(offset_s) + label = f.read(offset_e - offset_s) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def __getitem__(self, index): + wav = self.get_audio(index) + labels = self.get_labels(index) + return {"id": index, "source": wav, "label_list": labels} + + def __len__(self): + return len(self.sizes) + + def crop_to_max_size(self, wav, target_size): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = 
size - diff + start + return wav[start:end], start + + def collater(self, samples): + # target = max(sizes) -> random_crop not used + # target = max_sample_size -> random_crop used for long + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + audios = [s["source"] for s in samples] + audio_sizes = [len(s) for s in audios] + if self.pad_audio: + audio_size = min(max(audio_sizes), self.max_sample_size) + else: + audio_size = min(min(audio_sizes), self.max_sample_size) + collated_audios, padding_mask, audio_starts = self.collater_audio( + audios, audio_size + ) + + targets_by_label = [ + [s["label_list"][i] for s in samples] + for i in range(self.num_labels) + ] + targets_list, lengths_list, ntokens_list = self.collater_label( + targets_by_label, audio_size, audio_starts + ) + + net_input = {"source": collated_audios, "padding_mask": padding_mask} + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + } + + if self.single_target: + batch["target_lengths"] = lengths_list[0] + batch["ntokens"] = ntokens_list[0] + batch["target"] = targets_list[0] + else: + batch["target_lengths_list"] = lengths_list + batch["ntokens_list"] = ntokens_list + batch["target_list"] = targets_list + return batch + + def collater_audio(self, audios, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + # if self.pad_audio else None + ) + audio_starts = [0 for _ in audios] + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat( + [audio, audio.new_full((-diff,), 0.0)] + ) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size + ) + return collated_audios, padding_mask, audio_starts + + def collater_frm_label( + 
self, targets, audio_size, audio_starts, label_rate, pad + ): + assert label_rate > 0 + s2f = label_rate / self.sample_rate + frm_starts = [int(round(s * s2f)) for s in audio_starts] + frm_size = int(round(audio_size * s2f)) + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] + frm_size = min(frm_size, *rem_size) + targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)] + logger.debug(f"audio_starts={audio_starts}") + logger.debug(f"frame_starts={frm_starts}") + logger.debug(f"frame_size={frm_size}") + + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_label(self, targets_by_label, audio_size, audio_starts): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, self.label_rates, self.pad_list) + for targets, label_rate, pad in itr: + if label_rate == -1: + targets, lengths, ntokens = self.collater_seq_label( + targets, pad + ) + else: + targets, lengths, ntokens = self.collater_frm_label( + targets, audio_size, audio_starts, label_rate, pad + ) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + if self.pad_audio: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order)[::-1] + + 
def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index b09e87fe09..b7736116f9 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -306,7 +306,7 @@ def distributed_init(cfg: FairseqConfig): model_part_number = get_model_parallel_rank() cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) - if getattr(cfg.model, "base_layers", 0) > 0: + if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0: cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}" return cfg.distributed_training.distributed_rank diff --git a/fairseq/models/hubert/__init__.py b/fairseq/models/hubert/__init__.py new file mode 100644 index 0000000000..a1b0eabbdb --- /dev/null +++ b/fairseq/models/hubert/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hubert import * # noqa +from .hubert_asr import * # noqa diff --git a/fairseq/models/hubert/hubert.py b/fairseq/models/hubert/hubert.py new file mode 100644 index 0000000000..232a5e402a --- /dev/null +++ b/fairseq/models/hubert/hubert.py @@ -0,0 +1,563 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +from dataclasses import dataclass, field +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import ( + ConvFeatureExtractionModel, + TransformerEncoder, +) +from fairseq.modules import GradMultiply, LayerNorm +from fairseq.tasks.hubert_pretraining import ( + HubertPretrainingConfig, + HubertPretrainingTask, +) +from omegaconf import II + +logger = logging.getLogger(__name__) + +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum( + ["static", "uniform", "normal", "poisson"] +) + + +@dataclass +class HubertConfig(FairseqDataclass): + label_rate: int = II("task.label_rate") + + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. 
default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a transformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={ + "help": "dropout to apply to the features (after feat extr)" + }, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions.
set to encoder_embed_dim if <= 0" + }, + ) + untie_final_proj: bool = field( + default=False, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={ + "help": "min space between spans (if no overlap is enabled)" + }, + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + 
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={ + "help": "min space between spans (if no overlap is enabled)" + }, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={ + "help": "number of filters for convolutional positional embeddings" + }, + ) + conv_pos_groups: int = field( + default=16, + metadata={ + "help": "number of groups for convolutional positional embedding" + }, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + + +@register_model("hubert", dataclass=HubertConfig) +class HubertModel(BaseFairseqModel): + def __init__( + self, + cfg: HubertConfig, + task_cfg: HubertPretrainingConfig, + dictionaries: List[Dictionary], + ) -> None: + super().__init__() + logger.info(f"HubertModel Config: {cfg}") + + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = ( + cfg.label_rate * feature_ds_rate / 
task_cfg.sample_rate + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + final_dim = ( + cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + ) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.untie_final_proj = cfg.untie_final_proj + if self.untie_final_proj: + self.final_proj = nn.Linear( + cfg.encoder_embed_dim, final_dim * len(dictionaries) + ) + else: + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + # modules below are not needed during fine-tuning + if any([d is None for d in dictionaries]): + logger.info( + "cannot find dictionary. 
assume will be used for fine-tuning" + ) + else: + self.num_classes = [len(d) for d in dictionaries] + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim) + ) + nn.init.uniform_(self.label_embs_concat) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: HubertConfig, task: HubertPretrainingTask): + """Build a new model instance.""" + + model = HubertModel(cfg, task.cfg, task.dictionaries) + return model + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity( + x.float(), targets.float(), dim=-1 + ).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + def 
forward_features(self, source: torch.Tensor) -> torch.Tensor: + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets( + self, features: torch.Tensor, target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ) -> Dict[str, torch.Tensor]: + """output layer is 1-based""" + features = self.forward_features(source) + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if 
self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask, target_list + ) + else: + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + def compute_pred(proj_x, target, label_embs): + # compute logits for the i-th label set + y = torch.index_select(label_embs, 0, target.long()) + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + # proj_x: (S, D) + # y: (S, D) + # negs: (Neg, S, D) + return self.compute_nce(proj_x, y, negs) + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = self.final_proj(x[masked_indices]) + if self.untie_final_proj: + proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1) + else: + proj_x_m_list = [proj_x_m for _ in range(len(target_list))] + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + for i, (proj_x_m, t) in enumerate( + zip(proj_x_m_list, target_list) + ) + ] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + if self.untie_final_proj: + proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1) + else: + proj_x_u_list = [proj_x_u for _ in range(len(target_list))] + + logit_u_list = [ + 
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate( + zip(proj_x_u_list, target_list) + ) + ] + else: + logit_u_list = [None for _ in target_list] + + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + return result + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + logits_list = self.get_logits(net_output, is_masked) + targets_list = [ + x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list + ] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None diff --git a/fairseq/models/hubert/hubert_asr.py b/fairseq/models/hubert/hubert_asr.py new file mode 100644 index 0000000000..4cb3fb7153 --- /dev/null +++ b/fairseq/models/hubert/hubert_asr.py @@ -0,0 +1,373 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import contextlib +from argparse import Namespace +from typing import Any + +import torch +import torch.nn as nn +from dataclasses import dataclass, field +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model +from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES +from fairseq.tasks import FairseqTask +from omegaconf import II, MISSING + + +@dataclass +class HubertAsrConfig(FairseqDataclass): + w2v_path: str = field( + default=MISSING, metadata={"help": "path to hubert model"} + ) + no_pretrained_weights: bool = field( + default=False, + metadata={"help": "if true, does not load pretrained weights"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + final_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout after transformer and before final projection" + }, + ) + dropout: float = field( + default=0.0, + metadata={"help": "dropout probability inside hubert model"}, + ) + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " + "inside hubert model" + }, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside hubert model" + }, + ) + + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} + ) + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} + ) + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask " + "(normalized by length)" + }, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} + ) + 
mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + freeze_finetune_updates: int = field( + default=0, + metadata={"help": "don't finetune hubert for this many updates"}, + ) + feature_grad_mult: float = field( + default=0.0, + metadata={"help": "reset feature grad mult in hubert to this"}, + ) + layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a layer in hubert"}, + ) + normalize: bool = II("task.normalize") + data: str = II("task.data") + + # this holds the loaded hubert args + w2v_args: Any = None + + +@dataclass +class HubertCtcConfig(HubertAsrConfig): + pass + + +@register_model("hubert_ctc", dataclass=HubertCtcConfig) +class HubertCtc(BaseFairseqModel): + def __init__(self, cfg: HubertCtcConfig, w2v_encoder: BaseFairseqModel): + super().__init__() + self.cfg = cfg + self.w2v_encoder = w2v_encoder + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, 
cfg: HubertCtcConfig, task: FairseqTask): + """Build a new model instance.""" + w2v_encoder = HubertEncoder(cfg, task.target_dictionary) + return cls(cfg, w2v_encoder) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def get_logits(self, net_output): + logits = net_output["encoder_out"] + padding = net_output["encoder_padding_mask"] + if padding is not None and padding.any(): + padding = padding.T + logits[padding][..., 0] = 0 + logits[padding][..., 1:] = float("-inf") + + return logits + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + +@dataclass +class HubertSeq2SeqConfig(HubertAsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field( + default=6, metadata={"help": "num of decoder layers"} + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, + metadata={"help": "apply layernorm before each decoder block"}, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings " + "(outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + 
metadata={ + "help": "dropout probability for attention weights " + "inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, + metadata={"help": "share decoder input and output embeddings"}, + ) + + +class HubertEncoder(FairseqEncoder): + def __init__(self, cfg: HubertAsrConfig, tgt_dict=None): + self.apply_mask = cfg.apply_mask + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + } + + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.w2v_path, arg_overrides + ) + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + cfg.w2v_args = w2v_args + else: + state = None + w2v_args = cfg.w2v_args + if isinstance(w2v_args, Namespace): + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( + w2v_args + ) + + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. 
" + "Please check that --normalize is set or unset for " + "both pre-training and here" + ) + + w2v_args.task.data = cfg.data + task = tasks.setup_task(w2v_args.task) + model = task.build_model(w2v_args.model) + + if state is not None and not cfg.no_pretrained_weights: + # set strict=False because we omit some modules + model.load_state_dict(state["model"], strict=False) + + model.remove_pretraining_modules() + + super().__init__(task.source_dictionary) + + d = w2v_args.model.encoder_embed_dim + + self.w2v_model = model + + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates + self.num_updates = 0 + + if tgt_dict is not None: + self.proj = Linear(d, len(tgt_dict)) + elif getattr(cfg, "decoder_embed_dim", d) != d: + self.proj = Linear(d, cfg.decoder_embed_dim) + else: + self.proj = None + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + x, padding_mask = self.w2v_model.extract_features(**w2v_args) + + if tbc: + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out[ + "encoder_out" + ].index_select(1, new_order) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, 
new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py new file mode 100644 index 0000000000..aff4100bb8 --- /dev/null +++ b/fairseq/tasks/hubert_pretraining.py @@ -0,0 +1,189 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +import logging +import os +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq.data import Dictionary, HubertDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False, + ) + + +@dataclass +class HubertPretrainingConfig(FairseqDataclass): + data: str = field( + default=MISSING, metadata={"help": "path to data directory"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: int = field( + default=-1, + metadata={"help": "label frame rate. -1 for sequence label"}, + ) + + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. 
audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={ + "help": "if set, normalizes input to have 0 mean and unit variance" + }, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + single_target: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, AddTargetDatasets outputs same keys " + "as AddTargetDataset" + }, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig) +class HubertPretrainingTask(FairseqTask): + + cfg: HubertPretrainingConfig + + def __init__( + self, + cfg: HubertPretrainingConfig, + dictionaries: Dict[str, Dictionary], + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"HubertPretrainingTask Config {cfg}") + self._dictionaries = dictionaries + + self._source_dictionary = None + self._target_dictionary = None + + if len(self.dictionaries) == 1: + self._target_dictionary = self.dictionaries[0] + self.blank_symbol = "<s>" + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return self._source_dictionary + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self._target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return [self._dictionaries[l] for l in self.cfg.labels] + + @classmethod + def setup_task( + cls, cfg: HubertPretrainingConfig, **kwargs 
+ ) -> "HubertPretrainingTask": + label_dir = cfg.data if cfg.label_dir is None else cfg.label_dir + dictionaries = { + label: Dictionary.load(f"{label_dir}/dict.{label}.txt") + if os.path.exists(f"{label_dir}/dict.{label}.txt") + else None + for label in cfg.labels + } + return cls(cfg, dictionaries) + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None: + return self.cfg.data + return self.cfg.label_dir + + def load_dataset(self, split: str, **kwargs) -> None: + manifest = f"{self.cfg.data}/{split}.tsv" + pad_list = [self._dictionaries[l].pad() for l in self.cfg.labels] + eos_list = [self._dictionaries[l].eos() for l in self.cfg.labels] + procs = [LabelEncoder(self._dictionaries[l]) for l in self.cfg.labels] + paths = [ + f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels + ] + + # hubert v1: pad_audio=True, random_crop=False; + self.datasets[split] = HubertDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, + label_rates=self.cfg.label_rate, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=None, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.cfg.max_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=False, + random_crop=self.cfg.random_crop, + single_target=self.cfg.single_target, + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size( + self, indices: np.array, *args, **kwargs + ) -> np.array: + return indices From 49cf3e0bc389f0c704f7540c7d45100066ce7f7c Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Sat, 22 May 2021 00:21:06 -0700 Subject: [PATCH 587/707] fixing s2t transformer and N-best checkpoint saving Summary: - fixing the default value for `encoder_freezing_updates` in s2t transformer - fixing N-best checkpoint saving: the previous implementation compares the new checkpoint with only the previous best one but 
not the previous N best ones. This leads to suboptimal results on N-best checkpoint averaging. Reviewed By: jmp84 Differential Revision: D28546493 fbshipit-source-id: 44ec6d5ab49347f392d71269c5dcfd154b00c11e --- fairseq/checkpoint_utils.py | 20 +++++++++++++++---- .../models/speech_to_text/s2t_transformer.py | 3 +-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 23677be83d..b1a19d8515 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -12,6 +12,7 @@ import traceback from collections import OrderedDict from typing import Any, Dict, Optional, Union +from random import randint import torch from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig @@ -76,11 +77,22 @@ def is_better(a, b): or is_better(val_loss, save_checkpoint.best) ) if val_loss is not None and cfg.keep_best_checkpoints > 0: - checkpoint_conds[ - "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss) - ] = not hasattr(save_checkpoint, "best") or is_better( - val_loss, save_checkpoint.best + worst_best = getattr(save_checkpoint, "best", None) + chkpts = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( + cfg.best_checkpoint_metric + ), ) + if len(chkpts) > 0: + p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] + worst_best = float(p.rsplit("_")[-1].replace(".pt", "")) + # add random digits to resolve ties + rand_sfx = randint(0, cfg.keep_best_checkpoints) + checkpoint_conds[ + "checkpoint.best_{}_{:.3f}{}.pt".format(cfg.best_checkpoint_metric, + val_loss, rand_sfx) + ] = worst_best is None or is_better(val_loss, worst_best) checkpoint_conds[ "checkpoint_last{}.pt".format(suffix) ] = not cfg.no_last_checkpoints diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index ff3d2100c7..5c935efaf5 100644 --- 
a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -203,7 +203,6 @@ def add_args(parser): ) parser.add_argument( '--encoder-freezing-updates', - default=None, type=int, metavar='N', help='freeze encoder for first N updates' @@ -279,7 +278,7 @@ class S2TTransformerEncoder(FairseqEncoder): def __init__(self, args): super().__init__(None) - self.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) + self.encoder_freezing_updates = args.encoder_freezing_updates self.num_updates = 0 self.dropout_module = FairseqDropout( From 366974d9817138d1618693f021ea1690f9e53f33 Mon Sep 17 00:00:00 2001 From: Patrick von Platen <patrick.v.platen@gmail.com> Date: Sun, 23 May 2021 16:18:55 -0700 Subject: [PATCH 588/707] HF Wav2Vec2 Example (#3502) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ## What does this PR do? This PR updates some outdated code from the Hugging Face Transformers library to the new, better format. ## PR review alexeib ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3502 Reviewed By: arbabu123 Differential Revision: D28140574 Pulled By: alexeib fbshipit-source-id: f03643e7ebba04015d942a3aa9529f7f6600c734 --- examples/wav2vec/README.md | 41 +++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index bfed3913cf..238639a9ba 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -149,31 +149,54 @@ To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the t ## Use wav2vec 2.0 with 🤗Transformers: -Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since version 4.3. 
+Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since version 4.4. -Pretrained Models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) +Pretrained Models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). Usage example: ```python # !pip install transformers +# !pip install datasets import soundfile as sf import torch -from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer +from datasets import load_dataset +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor # load pretrained model -tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") -model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") +processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") +model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + + +librispeech_samples_ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") # load audio -audio_input, _ = sf.read("path/to/audio/file") +audio_input, sample_rate = sf.read(librispeech_samples_ds[0]["file"]) -# transcribe -input_values = tokenizer(audio_input, return_tensors="pt").input_values +# pad input values and return pt tensor +input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values + +# INFERENCE + +# retrieve logits & take argmax logits = model(input_values).logits predicted_ids = torch.argmax(logits, dim=-1) -transcription = tokenizer.batch_decode(predicted_ids)[0] + +# transcribe +transcription = processor.decode(predicted_ids[0]) + +# FINE-TUNE + +target_transcription = "A MAN SAID TO THE UNIVERSE I EXIST" + +# encode labels +with processor.as_target_processor(): + labels = processor(target_transcription, return_tensors="pt").input_ids + +# compute loss by passing labels +loss = 
model(input_values, labels=labels).loss +loss.backward() ``` # wav2vec From 342d5daf34acaf34c93b5a0a313f46bc4104c7fc Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Sun, 23 May 2021 21:24:38 -0700 Subject: [PATCH 589/707] propagate quantizer depth and factor args through w2v (#1892) Summary: makes quantizer larger which helps accuracy in certain cases Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1892 Reviewed By: arbabu123 Differential Revision: D28630035 Pulled By: alexeib fbshipit-source-id: ba5a902ff1623025e7566e901aa81cdf377a7aa0 --- fairseq/models/wav2vec/wav2vec2.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 6002d28438..714fd3ab50 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -119,6 +119,16 @@ class Wav2Vec2Config(FairseqDataclass): feature_grad_mult: float = field( default=1.0, metadata={"help": "multiply feature extractor var grads by this"} ) + quantizer_depth: int = field( + default=1, + metadata={"help": "number of quantizer layers"}, + ) + quantizer_factor: int = field( + default=3, + metadata={ + "help": "dimensionality increase for inner quantizer layers (if depth > 1)" + }, + ) latent_vars: int = field( default=320, metadata={"help": "number of latent variables V in each group of the codebook"}, @@ -284,6 +294,8 @@ def __init__(self, cfg: Wav2Vec2Config): combine_groups=False, vq_dim=vq_dim, time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, ) self.project_q = nn.Linear(vq_dim, final_dim) else: @@ -303,6 +315,8 @@ def __init__(self, cfg: Wav2Vec2Config): combine_groups=False, vq_dim=vq_dim, time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, ) self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) From 2be2f3c7c1ba9ec3ee6ef929f7edce13052b6844 Mon Sep 17 
00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Mon, 24 May 2021 08:58:56 -0700 Subject: [PATCH 590/707] Plasma tests: ask for less disk (#1893) Summary: Old logs: ``` /arrow/cpp/src/plasma/store.cc:1274: Allowing the Plasma store to use up to 107.374GB of memory. ``` New logs: ``` ... up to 1e-05GB of memory. ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1893 Reviewed By: myleott Differential Revision: D28641488 Pulled By: sshleifer fbshipit-source-id: 3373526042cdcbf434c61790be62a09f15e6ad06 --- tests/test_plasma_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_plasma_utils.py b/tests/test_plasma_utils.py index a5cf386b86..e6344c2a5a 100644 --- a/tests/test_plasma_utils.py +++ b/tests/test_plasma_utils.py @@ -23,7 +23,7 @@ class TestPlasmaView(unittest.TestCase): def setUp(self) -> None: self.tmp_file = tempfile.NamedTemporaryFile() # noqa: P201 self.path = self.tmp_file.name - self.server = PlasmaStore.start(path=self.path) + self.server = PlasmaStore.start(path=self.path, nbytes=10000) self.client = plasma.connect(self.path, num_retries=10) def tearDown(self) -> None: From 30003ba4192f173d84cb0dbfee678296d8af6378 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Mon, 24 May 2021 19:09:23 -0700 Subject: [PATCH 591/707] fix serialization on python 3.6 (#1894) Summary: fixes serialization errors when using python 3.6 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1894 Reviewed By: arbabu123 Differential Revision: D28655932 Pulled By: alexeib fbshipit-source-id: df40f972966e828817a2861e6e907835fe1d9573 --- fairseq/optim/adam.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index cfe948a194..6a31e53a62 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -15,7 +15,7 @@ from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, 
register_optimizer from fairseq.optim.fused_adam import get_fused_adam_class -from omegaconf import II, DictConfig +from omegaconf import II, OmegaConf logger = logging.getLogger(__name__) @@ -77,7 +77,9 @@ def optimizer_config(self): "lr": self.cfg.lr[0] if isinstance(self.cfg.lr, Collection) else self.cfg.lr, - "betas": eval(self.cfg.adam_betas) if isinstance(self.cfg.adam_betas, str) else self.cfg.adam_betas, + "betas": eval(self.cfg.adam_betas) + if isinstance(self.cfg.adam_betas, str) + else OmegaConf.to_container(self.cfg.adam_betas), "eps": self.cfg.adam_eps, "weight_decay": self.cfg.weight_decay, } From 5a75b079bf8911a327940c28794608e003a9fa52 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Mon, 24 May 2021 19:56:14 -0700 Subject: [PATCH 592/707] fix saving w2v args in config (#1896) Summary: previous changes broke saving the updated w2v_args in config, as the model had a copy of the config. this change makes the task copy over the field to save. not the nicest approach, but it works for now Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1896 Reviewed By: arbabu123 Differential Revision: D28658802 Pulled By: alexeib fbshipit-source-id: a13866c42c3b88c48b8b91864c1bf1aeaeba4e8a --- fairseq/tasks/audio_pretraining.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index e0b001b667..71cefcfcaa 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -303,6 +303,12 @@ def build_model(self, model_cfg: FairseqDataclass): self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) else: self.tokenizer = None + + actualized_cfg = getattr(model, "cfg") + if actualized_cfg is not None: + if "w2v_args" in actualized_cfg: + model_cfg.w2v_args = actualized_cfg.w2v_args + return model def _inference_with_wer(self, generator, sample, model): From 95cf58056dc30a9fa653dd8adcbd5b76a180a63d Mon Sep 17 00:00:00 2001 
From: Kushal Lakhotia <kushall@fb.com> Date: Tue, 25 May 2021 13:43:42 -0700 Subject: [PATCH 593/707] Update model table in README (#1901) Summary: ## What does this PR do? Updated the models' table in README to show the model sizes and groups pretrained models followed by fine tuned models. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1901 Reviewed By: wnhsu Differential Revision: D28688952 Pulled By: hikushalhere fbshipit-source-id: 8621398a785caa3d7bdc68367789ad7f48499d0d --- examples/hubert/README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/hubert/README.md b/examples/hubert/README.md index 88973c22f2..ca714469c6 100644 --- a/examples/hubert/README.md +++ b/examples/hubert/README.md @@ -3,12 +3,13 @@ ## Pre-trained and fine-tuned (ASR) models Model | Pretraining Data | Finetuning Dataset | Model |---|---|---|--- -HuBERT Base | [Librispeech](http://www.openslr.org/12) 960 hr | No finetuning | [download](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) -HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt) +HuBERT Base (~95M params) | [Librispeech](http://www.openslr.org/12) 960 hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) +HuBERT Large (~316M params) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt) +HuBERT Extra Large (~1B params) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k.pt) HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | 
[download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k_finetune_ls960.pt) -HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k.pt) HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k_finetune_ls960.pt) + ## Train a new model ### Data preparation From 8df9e3a4a55bad55078967e97e8a8f31d90ec987 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng <wyz@fb.com> Date: Tue, 25 May 2021 17:45:04 -0700 Subject: [PATCH 594/707] support FSDP sharded_state checkpoint loading during inference Summary: using the very useful feature added by QuentinDuval https://github.com/facebookresearch/fairscale/pull/683/files , we can consolidate sharded states into a full regular state. this allows inference on sharded state almost transparently. The main complexity comes from trying to be smart about what kind of checkpoint the user wants to load. not sure if this is over-engineering 1. if the file checkpoint-shard0.pt exists, and `--checkpoint-shard-count` is > 1, then we load a sharded FSDP checkpoint 2. if checkpoint-shard0.pt exists but --checkpoint-shard-count=1, we load a consolidated FSDP checkpoint 3. if checkpoint-shard0.pt does not exist, but --checkpoint-shard-count > 1, we load a model parallel checkpoint 4. otherwise we are loading a single, plain checkpoint. In theory we could be even smarter and load shard0.pt to check how many more checkpoints are needed. this is not implemented, though it would save the user from having to specify --checkpoint-shard-count. 
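The four cases listed above can be sketched as a small decision function. This is only an illustration of the rules as stated in this summary, not the code in this patch: `checkpoint_kind` and its two parameters are hypothetical names, and the existence check for checkpoint-shard0.pt is passed in as a boolean so the example stays self-contained.

```python
def checkpoint_kind(shard0_exists: bool, shard_count: int) -> str:
    """Classify a checkpoint according to the four cases in this summary.

    Hypothetical helper for illustration only; it is not part of fairseq.
    `shard0_exists` stands for "the file checkpoint-shard0.pt exists" and
    `shard_count` stands for the value of --checkpoint-shard-count.
    """
    if shard_count < 1:
        raise ValueError("shard_count must be at least 1")
    if shard0_exists and shard_count > 1:
        return "sharded FSDP"        # case 1
    if shard0_exists:
        return "consolidated FSDP"   # case 2
    if shard_count > 1:
        return "model parallel"      # case 3
    return "plain"                   # case 4


if __name__ == "__main__":
    # Walk through every combination of the two inputs.
    for exists in (True, False):
        for count in (2, 1):
            print(exists, count, checkpoint_kind(exists, count))
```

Keeping the existence check outside the function makes the rules easy to test without touching the file system.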
Reviewed By: sshleifer Differential Revision: D28563441 fbshipit-source-id: dcafcaa7c9eaf5c9ff94f55c16bb3424c98dfa59 --- fairseq/checkpoint_utils.py | 78 ++++++++++++++++++++++++++++------ fairseq/trainer.py | 3 ++ tests/gpu/test_binaries_gpu.py | 32 ++++++++++++++ 3 files changed, 100 insertions(+), 13 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index b1a19d8515..ecc45f4351 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -9,6 +9,7 @@ import logging import os import re +import time import traceback from collections import OrderedDict from typing import Any, Dict, Optional, Union @@ -20,6 +21,7 @@ convert_namespace_to_omegaconf, overwrite_args_by_name, ) +from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder from omegaconf import Container, DictConfig, open_dict, OmegaConf @@ -134,9 +136,15 @@ def is_better(a, b): ) else: checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), keep_match=True + cfg.save_dir, + pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), + keep_match=True, ) - checkpoints = [x[0] for x in checkpoints if x[1] % cfg.keep_interval_updates_pattern != 0] + checkpoints = [ + x[0] + for x in checkpoints + if x[1] % cfg.keep_interval_updates_pattern != 0 + ] for old_chk in checkpoints[cfg.keep_interval_updates :]: if os.path.lexists(old_chk): @@ -146,7 +154,9 @@ def is_better(a, b): if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)) + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) for old_chk in checkpoints[cfg.keep_last_epochs :]: if os.path.lexists(old_chk): os.remove(old_chk) @@ -351,6 +361,21 @@ def 
load_model_ensemble( return ensemble, args +def get_maybe_sharded_checkpoint_filename( + filename: str, suffix: str, shard_idx: int, num_shards: int +) -> str: + orig_filename = filename + filename = filename.replace(".pt", suffix + ".pt") + fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt" + model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if PathManager.exists(fsdp_filename): + return fsdp_filename + elif num_shards > 1: + return model_parallel_filename + else: + return filename + + def load_model_ensemble_and_task( filenames, arg_overrides: Optional[Dict[str, Any]] = None, @@ -371,12 +396,13 @@ def load_model_ensemble_and_task( cfg = None for filename in filenames: orig_filename = filename + model_shard_state = {"shard_weights": [], "shard_metadata": []} assert num_shards > 0 + st = time.time() for shard_idx in range(num_shards): - if num_shards == 1: - filename = filename.replace(".pt", suffix + ".pt") - else: - filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + filename = get_maybe_sharded_checkpoint_filename( + orig_filename, suffix, shard_idx, num_shards + ) if not PathManager.exists(filename): raise IOError("Model file not found: {}".format(filename)) @@ -397,14 +423,38 @@ def load_model_ensemble_and_task( if "task_state" in state: task.load_state_dict(state["task_state"]) - # build model for ensemble - model = task.build_model(cfg.model) - - model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) + if "fsdp_metadata" in state and num_shards > 1: + model_shard_state["shard_weights"].append(state["model"]) + model_shard_state["shard_metadata"].append(state["fsdp_metadata"]) + # check FSDP import before the code goes too far + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. 
" + "Please install fairscale with: pip install fairscale" + ) + if shard_idx == num_shards - 1: + consolidated_model_state = FSDP.consolidate_shard_weights( + shard_weights=model_shard_state["shard_weights"], + shard_metadata=model_shard_state["shard_metadata"], + ) + model = task.build_model(cfg.model) + model.load_state_dict( + consolidated_model_state, strict=strict, model_cfg=cfg.model + ) + else: + # model parallel checkpoint or unsharded checkpoint + model = task.build_model(cfg.model) + model.load_state_dict( + state["model"], strict=strict, model_cfg=cfg.model + ) # reset state so it gets loaded for the next model in ensemble state = None + if shard_idx % 10 == 0 and shard_idx > 0: + elapsed = time.time() - st + logger.info(f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard") + # build model for ensemble ensemble.append(model) return ensemble, cfg, task @@ -500,8 +550,10 @@ def _upgrade_state_dict(state): if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 # old model checkpoints may not have separate source/target positions - if "args" in state and hasattr(state["args"], "max_positions") and not hasattr( - state["args"], "max_source_positions" + if ( + "args" in state + and hasattr(state["args"], "max_positions") + and not hasattr(state["args"], "max_source_positions") ): state["args"].max_source_positions = state["args"].max_positions state["args"].max_target_positions = state["args"].max_positions diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d3da876c29..924b2b6b34 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -395,6 +395,9 @@ def state_dict(self): self._gathered_optim_state = None else: state_dict["last_optimizer_state"] = self.optimizer.state_dict() + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + # save meta data for recombining checkpoint upon loading + state_dict["fsdp_metadata"] = self.model.local_metadata_dict() 
return state_dict def save_checkpoint(self, filename, extra_state): diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index a0824c23ad..54dc96079f 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -184,6 +184,38 @@ def test_levenshtein_transformer(self): ), ) + def test_fsdp_checkpoint_generate(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fsdp_sharded") as data_dir: + log = os.path.join(data_dir, "train.log") + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + world_size = min(torch.cuda.device_count(), 2) + train_translation_model( + data_dir, + "fconv_iwslt_de_en", + ["--log-file", log, "--ddp-backend", "fully_sharded"], + world_size=world_size, + ) + generate_main(data_dir) + assert os.path.exists(log) + + def test_fsdp_sharded_checkpoint_generate(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fsdp_sharded") as data_dir: + log = os.path.join(data_dir, "train.log") + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + world_size = min(torch.cuda.device_count(), 2) + train_translation_model( + data_dir, + "fconv_iwslt_de_en", + ["--log-file", log, "--ddp-backend", "fully_sharded", "--use-sharded-state"], + world_size=world_size, + ) + generate_main(data_dir, ["--checkpoint-shard-count", str(world_size)]) + assert os.path.exists(log) + def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=False): train_parser = options.get_training_parser() From 237184e5222b347475456f4a44f31a510c64ca35 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh <gagandeep.singh1@nuance.com> Date: Wed, 26 May 2021 14:38:16 -0700 Subject: [PATCH 595/707] Add torch.cuda.amp support (#3460) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3282 Add support for `torch.cuda.amp` AMP can be enabled by `--amp`, instead of using `--fp16` for the already present full fp16 support. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3460 Reviewed By: sshleifer, msbaines Differential Revision: D27932253 Pulled By: myleott fbshipit-source-id: 21637aefb5e788c59bf4f3c5de6c4a80f7319543 --- .../pointer_generator_src/transformer_pg.py | 10 +- fairseq/dataclass/configs.py | 14 ++- fairseq/optim/__init__.py | 2 + fairseq/optim/amp_optimizer.py | 105 ++++++++++++++++++ fairseq/tasks/fairseq_task.py | 4 +- fairseq/trainer.py | 42 +++++-- tests/gpu/test_binaries_gpu.py | 34 ++++++ tests/test_amp_optimizer.py | 78 +++++++++++++ tests/test_reproducibility.py | 12 ++ 9 files changed, 285 insertions(+), 16 deletions(-) create mode 100644 fairseq/optim/amp_optimizer.py create mode 100644 tests/test_amp_optimizer.py diff --git a/examples/pointer_generator/pointer_generator_src/transformer_pg.py b/examples/pointer_generator/pointer_generator_src/transformer_pg.py index e109a8e269..4ccf30f4eb 100644 --- a/examples/pointer_generator/pointer_generator_src/transformer_pg.py +++ b/examples/pointer_generator/pointer_generator_src/transformer_pg.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from fairseq import metrics, utils +from fairseq import utils from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer import ( 
DEFAULT_MAX_SOURCE_POSITIONS, @@ -300,7 +300,7 @@ def forward( prev_output_embed *= self.embed_scale predictors = torch.cat((prev_output_embed, x), 2) p_gens = self.project_p_gens(predictors) - p_gens = torch.sigmoid(p_gens) + p_gens = torch.sigmoid(p_gens.float()) # Torchscript complains if encoder_out or attn are None because # `output_layer()` signature expects tensors instead attn: Optional[Tensor] = extra["attn"][0] @@ -351,18 +351,18 @@ def output_layer( # vocab_size]. Each attention weight will be written into a location # that is for other dimensions the same as in the index tensor, but for # the third dimension it's the value of the index tensor (the token ID). - attn = torch.mul(attn, 1 - p_gens) + attn = torch.mul(attn.float(), 1 - p_gens) index = src_tokens[:, None, :] index = index.expand(batch_size, output_length, src_length) attn_dists_size = (batch_size, output_length, self.num_types) attn_dists = attn.new_zeros(attn_dists_size) - attn_dists.scatter_add_(2, index, attn) + attn_dists.scatter_add_(2, index, attn.float()) # Final distributions, [batch_size, output_length, num_types]. 
return gen_dists + attn_dists def get_normalized_probs( - self, + self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index f41cfcd94f..70d7476d31 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -150,11 +150,23 @@ class CommonConfig(FairseqDataclass): ) min_loss_scale: float = field( default=1e-4, - metadata={"help": "minimum FP16 loss scale, after which training is stopped"}, + metadata={"help": "minimum FP16/AMP loss scale, after which training is stopped"}, ) threshold_loss_scale: Optional[float] = field( default=None, metadata={"help": "threshold FP16 loss scale from below"} ) + amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"}) + amp_batch_retries: int = field( + default=2, + metadata={"help": "number of retries of same batch after reducing loss scale with AMP"}, + ) + amp_init_scale: int = field( + default=2 ** 7, metadata={"help": "default AMP loss scale"} + ) + amp_scale_window: Optional[int] = field( + default=None, + metadata={"help": "number of updates before increasing AMP loss scale"}, + ) user_dir: Optional[str] = field( default=None, metadata={ diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 01c08c98d2..be783be896 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -13,11 +13,13 @@ FairseqOptimizer, LegacyFairseqOptimizer, ) +from fairseq.optim.amp_optimizer import AMPOptimizer from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer from fairseq.optim.shard import shard_ from omegaconf import DictConfig __all__ = [ + "AMPOptimizer", "FairseqOptimizer", "FP16Optimizer", "MemoryEfficientFP16Optimizer", diff --git a/fairseq/optim/amp_optimizer.py b/fairseq/optim/amp_optimizer.py new file mode 100644 index 0000000000..3b7958e50c --- /dev/null +++ 
b/fairseq/optim/amp_optimizer.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from fairseq import optim +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) + + +class AMPOptimizer(optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support AMP (automatic mixed precision) training. + """ + + def __init__(self, cfg: DictConfig, params, fp32_optimizer, **kwargs): + super().__init__(cfg.optimizer) + self.fp32_optimizer = fp32_optimizer + amp_kwargs = {"init_scale": cfg.common.fp16_init_scale} + if getattr(cfg.common, "amp_scale_window", None) is not None: + amp_kwargs["growth_interval"] = cfg.common.amp_init_scale + self._grad_scaler = torch.cuda.amp.GradScaler(**amp_kwargs) + self.min_loss_scale = cfg.common.min_loss_scale + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + cfg (omegaconf.DictConfig): fairseq args + params (iterable): iterable of parameters to optimize + """ + fp32_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp32_optimizer, **kwargs) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. 
+ """ + self._grad_scaler.scale(loss).backward() + + def step(self): + self.scaler.step(self.fp32_optimizer) + self.scaler.update() + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + self.scaler.unscale_(self.optimizer) + grad_norm = self.fp32_optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + if not torch.isfinite(grad_norm).all(): + new_loss_scale = self.next_loss_scale + if new_loss_scale <= self.min_loss_scale: + raise FloatingPointError( + ( + "AMP: Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try restarting training or use fp32. {}" + ).format(self.min_loss_scale, new_loss_scale) + ) + else: + logger.info("AMP: overflow detected, setting scale to " + f"to {new_loss_scale}") + return grad_norm + + @property + def scaler(self): + return self._grad_scaler + + @property + def next_loss_scale(self): + return self.scaler.get_scale() * self.scaler.get_backoff_factor() + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.fp32_optimizer.all_reduce_grads(module) + + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 375b5277b9..e30b2cd985 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -14,6 +14,7 @@ from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import 
gen_parser_from_dataclass +from fairseq.optim.amp_optimizer import AMPOptimizer from omegaconf import DictConfig @@ -472,7 +473,8 @@ def train_step( model.train() model.set_num_updates(update_num) with torch.autograd.profiler.record_function("forward"): - loss, sample_size, logging_output = criterion(model, sample) + with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): + loss, sample_size, logging_output = criterion(model, sample) if ignore_grad: loss *= 0 with torch.autograd.profiler.record_function("backward"): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 924b2b6b34..d1d08025f6 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -81,11 +81,14 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): self._model = model if cfg.distributed_training.ddp_backend != "fully_sharded": if cfg.common.fp16: + assert not cfg.common.amp, "Cannot use fp16 and AMP together" self._criterion = self._criterion.half() self._model = self._model.half() elif cfg.common.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) + elif cfg.common.amp: + self._amp_retries = 0 if ( not cfg.distributed_training.pipeline_model_parallel # the DistributedFairseqModel wrapper will handle moving to device, @@ -285,10 +288,10 @@ def _build_optimizer(self): self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( self.cfg, params, allow_unsupported=allow_unsupported ) - elif self.cfg.common.fp16 or self.cfg.common.bf16: + elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: logger.info( - "NOTE: your device does NOT support faster training with --fp16, " + "NOTE: your device does NOT support faster training with --fp16 or --amp, " "please switch to FP32 which is likely to be faster" ) if ( @@ -298,11 +301,13 @@ def _build_optimizer(self): self._optimizer = 
optim.MemoryEfficientFP16Optimizer.build_optimizer( self.cfg, params ) + elif self.cfg.common.amp: + self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params) else: self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) else: if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: - logger.info("NOTE: your device may support faster training with --fp16") + logger.info("NOTE: your device may support faster training with --fp16 or --amp") self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) if self.cfg.distributed_training.ddp_backend == "fully_sharded": @@ -803,14 +808,26 @@ def maybe_no_sync(): ): self._check_grad_norms(grad_norm) if not torch.isfinite(grad_norm).all(): - # check local gradnorm single GPU case, trigger NanDetector - raise FloatingPointError("gradients are Nan/Inf") + # in case of AMP, if gradients are Nan/Inf then + # optimizer step is still required + if self.cfg.common.amp: + overflow = True + else: + # check local gradnorm single GPU case, trigger NanDetector + raise FloatingPointError("gradients are Nan/Inf") with torch.autograd.profiler.record_function("optimizer"): # take an optimization step self.task.optimizer_step( self.optimizer, model=self.model, update_num=self.get_num_updates() ) + if self.cfg.common.amp and overflow: + if self._amp_retries == self.cfg.common.amp_batch_retries: + logger.info("AMP: skipping this batch.") + self._amp_retries = 0 + else: + self._amp_retries += 1 + return self.train_step(samples, raise_oom) # recursion to feed in same batch except FloatingPointError: # re-run the forward and backward pass with hooks attached to print @@ -915,10 +932,14 @@ def maybe_no_sync(): ): torch.cuda.empty_cache() - if self.cfg.common.fp16: + if self.cfg.common.fp16 or self.cfg.common.amp: metrics.log_scalar( "loss_scale", - self.optimizer.scaler.loss_scale, + ( + self.optimizer.scaler.loss_scale + if self.cfg.common.fp16 + else self.optimizer.scaler.get_scale() + ), priority=700, 
round=4, weight=0, @@ -1274,8 +1295,11 @@ def _check_grad_norms(self, grad_norm): def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) return ( - torch.isfinite(tensor).all() - and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + (torch.isfinite(tensor).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all()) + or + (self.cfg.common.amp and not torch.isfinite(tensor).all()) + # in case of amp non-finite grads are fine ) if not is_consistent(self._grad_norm_buf): diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 54dc96079f..de8c242613 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -142,6 +142,40 @@ def test_transformer_fp16(self): ) generate_main(data_dir) + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_amp(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_amp") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, "fconv_iwslt_de_en", ["--amp"]) + generate_main(data_dir) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_transformer_amp(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "64", + "--decoder-embed-dim", + "64", + "--amp", + ], + run_validation=True, + ) + generate_main(data_dir) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_levenshtein_transformer(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory( diff --git a/tests/test_amp_optimizer.py b/tests/test_amp_optimizer.py new file mode 100644 index 
0000000000..3a785e1830 --- /dev/null +++ b/tests/test_amp_optimizer.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import copy +import unittest + +import torch +from torch.cuda.amp import autocast, GradScaler +from fairseq.optim import build_optimizer + + +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") +class TestGradientScalingAMP(unittest.TestCase): + def setUp(self): + self.x = torch.tensor([2.0]).cuda().half() + weight = 3.0 + bias = 5.0 + self.error = 1.0 + self.target = torch.tensor([self.x * weight + bias + self.error]).cuda() + self.loss_fn = torch.nn.L1Loss() + + self.model = torch.nn.Linear(1, 1) + self.model.weight.data = torch.tensor([[weight]]) + self.model.bias.data = torch.tensor([bias]) + self.model.cuda() + self.params = list(self.model.parameters()) + + self.namespace_dls = argparse.Namespace( + optimizer="adam", + lr=[0.1], + adam_betas="(0.9, 0.999)", + adam_eps=1e-8, + weight_decay=0.0, + threshold_loss_scale=1, + min_loss_scale=1e-4, + ) + self.scaler = GradScaler( + init_scale=1, + growth_interval=1, + ) + + def run_iter(self, model, params, optimizer): + optimizer.zero_grad() + with autocast(): + y = model(self.x) + loss = self.loss_fn(y, self.target) + self.scaler.scale(loss).backward() + self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16)) + + self.scaler.unscale_(optimizer) + grad_norm = optimizer.clip_grad_norm(0) + self.assertAlmostEqual(grad_norm.item(), 2.2361, 4) + + self.scaler.step(optimizer) + self.scaler.update() + self.assertEqual( + model.weight, + torch.tensor( + [[3.1]], device="cuda:0", requires_grad=True + ), + ) + self.assertEqual( + model.bias, + torch.tensor( + [5.1], device="cuda:0", requires_grad=True + ), + ) + self.assertEqual(self.scaler.get_scale(), 2.0) + + def 
test_automatic_mixed_precision(self): + model = copy.deepcopy(self.model) + params = list(model.parameters()) + optimizer = build_optimizer(self.namespace_dls, params) + + self.run_iter(model, params, optimizer) diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 405d545593..94931b2a07 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -125,6 +125,18 @@ def test_reproducibility_memory_efficient_fp16(self): ], ) + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_reproducibility_amp(self): + self._test_reproducibility( + "test_reproducibility_amp", + [ + "--amp", + "--fp16-init-scale", + "4096", + ], + delta=0.011, + ) + def test_mid_epoch_reproducibility(self): self._test_reproducibility( "test_mid_epoch_reproducibility", From e6eddd805ebbc5c17bf5100c2fde6e0dfc946d2c Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Wed, 26 May 2021 16:27:59 -0700 Subject: [PATCH 596/707] =?UTF-8?q?make=20hydra/infer.py=20work;=20also=20?= =?UTF-8?q?dont=20break=20if=20something=20is=20removed=20fro=E2=80=A6=20(?= =?UTF-8?q?#1903)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: previously hydra/infer.py did not always work for several reasons which are addressed here new example usage: PYTHONPATH=. 
python examples/speech_recognition/new/infer.py --config-dir examples/speech_recognition/hydra/conf --config-name infer task=audio_pretraining task.data=/path/to/data task.labels=ltr decoding.type=kenlm decoding.lexicon=/path/to/lexicon decoding.lmpath=/path/to/lm dataset.gen_subset=dev_other common_eval.path=/path/to/model.pt decoding.beam=5 decoding.lmweight=2 decoding.wordscore=-1 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1903 Reviewed By: arbabu123 Differential Revision: D28700795 Pulled By: alexeib fbshipit-source-id: 66fe454de49c1bf511b3529ac683f1c8cb08e579 --- examples/hubert/README.md | 4 +- .../hubert/config/decode/infer_fsqlm.yaml | 25 +- .../hubert/config/decode/infer_kenlm.yaml | 25 +- .../hubert/config/decode/infer_viterbi.yaml | 7 +- examples/speech_recognition/infer.py | 2 +- .../{hydra => new}/README.md | 24 +- examples/speech_recognition/new/__init__.py | 0 .../{hydra => new}/conf/hydra/sweeper/ax.yaml | 8 +- .../{hydra => new}/conf/infer.yaml | 13 +- .../new/decoders/__init__.py | 0 .../new/decoders/base_decoder.py | 62 ++++ .../new/decoders/decoder.py | 32 ++ .../new/decoders/decoder_config.py | 70 ++++ .../decoders/flashlight_decoder.py} | 326 ++++-------------- .../new/decoders/viterbi_decoder.py | 24 ++ .../{hydra => new}/infer.py | 140 ++++---- fairseq/criterions/ctc.py | 15 +- fairseq/dataclass/utils.py | 16 +- fairseq/models/wav2vec/wav2vec2_asr.py | 1 - fairseq/tasks/audio_pretraining.py | 69 ++-- fairseq/utils.py | 4 +- fairseq_cli/validate.py | 3 + 22 files changed, 450 insertions(+), 420 deletions(-) rename examples/speech_recognition/{hydra => new}/README.md (54%) create mode 100644 examples/speech_recognition/new/__init__.py rename examples/speech_recognition/{hydra => new}/conf/hydra/sweeper/ax.yaml (80%) rename examples/speech_recognition/{hydra => new}/conf/infer.yaml (67%) create mode 100644 examples/speech_recognition/new/decoders/__init__.py create mode 100644 
examples/speech_recognition/new/decoders/base_decoder.py create mode 100644 examples/speech_recognition/new/decoders/decoder.py create mode 100644 examples/speech_recognition/new/decoders/decoder_config.py rename examples/speech_recognition/{hydra/decoder.py => new/decoders/flashlight_decoder.py} (54%) create mode 100644 examples/speech_recognition/new/decoders/viterbi_decoder.py rename examples/speech_recognition/{hydra => new}/infer.py (80%) diff --git a/examples/hubert/README.md b/examples/hubert/README.md index ca714469c6..c0b1125cb5 100644 --- a/examples/hubert/README.md +++ b/examples/hubert/README.md @@ -65,7 +65,7 @@ Decoding results will be saved at `/path/to/experiment/directory/decode/viterbi/test`. ```sh -$ python examples/speech_recognition/hydra/infer.py \ +$ python examples/speech_recognition/new/infer.py \ --config-dir /path/to/fairseq-py/examples/hubert/config/decode \ --config-name infer_viterbi \ task.data=/path/to/data \ @@ -82,7 +82,7 @@ Suppose the pronunciation lexicon and the n-gram LM are saved at saved at `/path/to/experiment/directory/decode/kenlm/test`. ```sh -$ python examples/speech_recognition/hydra/infer.py \ +$ python examples/speech_recognition/new/infer.py \ --config-dir /path/to/fairseq-py/examples/hubert/config/decode \ --config-name infer_kenlm \ task.data=/path/to/data \ diff --git a/examples/hubert/config/decode/infer_fsqlm.yaml b/examples/hubert/config/decode/infer_fsqlm.yaml index b9fb845066..bc77cab32e 100644 --- a/examples/hubert/config/decode/infer_fsqlm.yaml +++ b/examples/hubert/config/decode/infer_fsqlm.yaml @@ -17,25 +17,20 @@ task: normalize: ??? decoding: - exp_dir: ??? - decoder: - name: fairseqlm - lexicon: ??? - lmpath: ??? - beamthreshold: 25 # 100 - beam: 500 - lmweight: 2 - wordscore: -1 - silweight: 0 - write_sentences: true + type: fairseqlm + lexicon: ??? + lmpath: ??? 
+ beamthreshold: 25 # 100 + beam: 500 + lmweight: 2 + wordscore: -1 + silweight: 0 unique_wer_file: true + beam: 500 common_eval: - results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}/${dataset.gen_subset} + results_path: ??? path: ??? post_process: letter -generation: - nbest: 1 - beam: 500 dataset: max_tokens: 1100000 gen_subset: ??? diff --git a/examples/hubert/config/decode/infer_kenlm.yaml b/examples/hubert/config/decode/infer_kenlm.yaml index fe464eaae5..26f5c48928 100644 --- a/examples/hubert/config/decode/infer_kenlm.yaml +++ b/examples/hubert/config/decode/infer_kenlm.yaml @@ -17,25 +17,20 @@ task: normalize: ??? decoding: - exp_dir: ??? - decoder: - name: kenlm - lexicon: ??? - lmpath: ??? - beamthreshold: 100 - beam: 500 - lmweight: 2 - wordscore: -1 - silweight: 0 - write_sentences: true + type: kenlm + lexicon: ??? + lmpath: ??? + beamthreshold: 100 + beam: 500 + lmweight: 2 + wordscore: -1 + silweight: 0 unique_wer_file: true + beam: 500 common_eval: - results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}/${dataset.gen_subset} + results_path: ??? path: ??? post_process: letter -generation: - nbest: 1 - beam: 500 dataset: max_tokens: 1100000 gen_subset: ??? diff --git a/examples/hubert/config/decode/infer_viterbi.yaml b/examples/hubert/config/decode/infer_viterbi.yaml index d0de9cfd26..935d7d1d01 100644 --- a/examples/hubert/config/decode/infer_viterbi.yaml +++ b/examples/hubert/config/decode/infer_viterbi.yaml @@ -17,13 +17,10 @@ task: normalize: ??? decoding: - exp_dir: ??? - decoder: - name: viterbi - write_sentences: true + type: viterbi unique_wer_file: true common_eval: - results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}/${dataset.gen_subset} + results_path: ??? path: ??? 
post_process: letter generation: diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index f4efbf39c8..6e9a878af4 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -227,7 +227,7 @@ def main(args, task=None, model_state=None): else: logger.info("| loading model(s) from {}".format(args.path)) models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - utils.split_paths(args.path), + utils.split_paths(args.path, separator="\\"), arg_overrides=ast.literal_eval(args.model_overrides), task=task, suffix=args.checkpoint_suffix, diff --git a/examples/speech_recognition/hydra/README.md b/examples/speech_recognition/new/README.md similarity index 54% rename from examples/speech_recognition/hydra/README.md rename to examples/speech_recognition/new/README.md index 17d5946675..5fa0e97245 100644 --- a/examples/speech_recognition/hydra/README.md +++ b/examples/speech_recognition/new/README.md @@ -7,7 +7,7 @@ This script runs decoding for pre-trained speech recognition models. 
Assuming a few variables: ```bash -exp_dir=<path-to-experiment-directory> +checkpoint=<path-to-checkpoint> data=<path-to-data-directory> lm_model=<path-to-language-model> lexicon=<path-to-lexicon> @@ -16,30 +16,28 @@ lexicon=<path-to-lexicon> Example usage for decoding a fine-tuned Wav2Vec model: ```bash -python $FAIRSEQ_ROOT/examples/speech_recognition/hydra/infer.py --multirun \ +python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \ task=audio_pretraining \ task.data=$data \ task.labels=ltr \ - decoding.exp_dir=$exp_dir \ - decoding.decoder.name=kenlm \ - decoding.decoder.lexicon=$lexicon \ - decoding.decoder.lmpath=$lm_model \ + common_eval.path=$checkpoint \ + decoding.type=kenlm \ + decoding.lexicon=$lexicon \ + decoding.lmpath=$lm_model \ dataset.gen_subset=dev_clean,dev_other,test_clean,test_other ``` Example usage for using Ax to sweep WER parameters (requires `pip install hydra-ax-sweeper`): ```bash -python $FAIRSEQ_ROOT/examples/speech_recognition/hydra/infer.py --multirun \ +python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \ hydra/sweeper=ax \ task=audio_pretraining \ task.data=$data \ task.labels=ltr \ - decoding.exp_dir=$exp_dir \ - decoding.decoder.name=kenlm \ - decoding.decoder.lexicon=$lexicon \ - decoding.decoder.lmpath=$lm_model \ - decoding.write_sentences=false \ - decoding.unique_wer_file=true \ + common_eval.path=$checkpoint \ + decoding.type=kenlm \ + decoding.lexicon=$lexicon \ + decoding.lmpath=$lm_model \ dataset.gen_subset=dev_other ``` diff --git a/examples/speech_recognition/new/__init__.py b/examples/speech_recognition/new/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/speech_recognition/hydra/conf/hydra/sweeper/ax.yaml b/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml similarity index 80% rename from examples/speech_recognition/hydra/conf/hydra/sweeper/ax.yaml rename to examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml index 
7700712ea0..fbeff17ca6 100644 --- a/examples/speech_recognition/hydra/conf/hydra/sweeper/ax.yaml +++ b/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml @@ -2,10 +2,10 @@ _target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper max_batch_size: null ax_config: - max_trials: 100 + max_trials: 128 early_stop: minimize: true - max_epochs_without_improvement: 10 + max_epochs_without_improvement: 32 epsilon: 1.0e-05 experiment: name: ${dataset.gen_subset} @@ -18,9 +18,9 @@ ax_config: verbose_logging: false random_seed: null params: - decoding.decoder.lmweight: + decoding.lmweight: type: range bounds: [0.0, 5.0] - decoding.decoder.wordscore: + decoding.wordscore: type: range bounds: [-5.0, 5.0] diff --git a/examples/speech_recognition/hydra/conf/infer.yaml b/examples/speech_recognition/new/conf/infer.yaml similarity index 67% rename from examples/speech_recognition/hydra/conf/infer.yaml rename to examples/speech_recognition/new/conf/infer.yaml index 1d78ba14cb..f176228082 100644 --- a/examples/speech_recognition/hydra/conf/infer.yaml +++ b/examples/speech_recognition/new/conf/infer.yaml @@ -11,12 +11,15 @@ hydra: dir: ${common_eval.results_path} subdir: ${dataset.gen_subset} common_eval: - results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name} - path: ${decoding.exp_dir}/checkpoint_best.pt + results_path: null + path: null post_process: letter -generation: - nbest: 1 - beam: 500 + quiet: true dataset: max_tokens: 1000000 gen_subset: test +distributed_training: + distributed_world_size: 1 +decoding: + beam: 5 + type: viterbi diff --git a/examples/speech_recognition/new/decoders/__init__.py b/examples/speech_recognition/new/decoders/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/speech_recognition/new/decoders/base_decoder.py b/examples/speech_recognition/new/decoders/base_decoder.py new file mode 100644 index 0000000000..a097969b3c --- /dev/null +++ b/examples/speech_recognition/new/decoders/base_decoder.py @@ 
-0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools as it +from typing import Any, Dict, List + +import torch +from fairseq.data.dictionary import Dictionary +from fairseq.models.fairseq_model import FairseqModel + + +class BaseDecoder: + def __init__(self, tgt_dict: Dictionary) -> None: + self.tgt_dict = tgt_dict + self.vocab_size = len(tgt_dict) + + self.blank = ( + tgt_dict.index("<ctc_blank>") + if "<ctc_blank>" in tgt_dict.indices + else tgt_dict.bos() + ) + if "<sep>" in tgt_dict.indices: + self.silence = tgt_dict.index("<sep>") + elif "|" in tgt_dict.indices: + self.silence = tgt_dict.index("|") + else: + self.silence = tgt_dict.eos() + + def generate( + self, models: List[FairseqModel], sample: Dict[str, Any], **unused + ) -> List[List[Dict[str, torch.LongTensor]]]: + encoder_input = { + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" + } + emissions = self.get_emissions(models, encoder_input) + return self.decode(emissions) + + def get_emissions( + self, + models: List[FairseqModel], + encoder_input: Dict[str, Any], + ) -> torch.FloatTensor: + model = models[0] + encoder_out = model(**encoder_input) + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out) + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + return emissions.transpose(0, 1).float().cpu().contiguous() + + def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: + idxs = (g[0] for g in it.groupby(idxs)) + idxs = filter(lambda x: x != self.blank, idxs) + return torch.LongTensor(list(idxs)) + + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + raise NotImplementedError diff --git a/examples/speech_recognition/new/decoders/decoder.py b/examples/speech_recognition/new/decoders/decoder.py new file mode 
100644 index 0000000000..b5bec8cf70 --- /dev/null +++ b/examples/speech_recognition/new/decoders/decoder.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +from fairseq.data.dictionary import Dictionary + +from .decoder_config import DecoderConfig, FlashlightDecoderConfig +from .base_decoder import BaseDecoder + + +def Decoder( + cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary +) -> BaseDecoder: + + if cfg.type == "viterbi": + from .viterbi_decoder import ViterbiDecoder + + return ViterbiDecoder(tgt_dict) + if cfg.type == "kenlm": + from .flashlight_decoder import KenLMDecoder + + return KenLMDecoder(cfg, tgt_dict) + if cfg.type == "fairseqlm": + from .flashlight_decoder import FairseqLMDecoder + + return FairseqLMDecoder(cfg, tgt_dict) + raise NotImplementedError(f"Invalid decoder name: {cfg.type}") diff --git a/examples/speech_recognition/new/decoders/decoder_config.py b/examples/speech_recognition/new/decoders/decoder_config.py new file mode 100644 index 0000000000..659eb94a9b --- /dev/null +++ b/examples/speech_recognition/new/decoders/decoder_config.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +import math +from dataclasses import dataclass, field +from typing import Optional + +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.dataclass.constants import ChoiceEnum +from omegaconf import MISSING + + +DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"]) + + +@dataclass +class DecoderConfig(FairseqDataclass): + type: DECODER_CHOICES = field( + default="viterbi", + metadata={"help": "The type of decoder to use"}, + ) + + +@dataclass +class FlashlightDecoderConfig(FairseqDataclass): + nbest: int = field( + default=1, + metadata={"help": "Number of decodings to return"}, + ) + unitlm: bool = field( + default=False, + metadata={"help": "If set, use unit language model"}, + ) + lmpath: str = field( + default=MISSING, + metadata={"help": "Language model for KenLM decoder"}, + ) + lexicon: Optional[str] = field( + default=None, + metadata={"help": "Lexicon for Flashlight decoder"}, + ) + beam: int = field( + default=50, + metadata={"help": "Number of beams to use for decoding"}, + ) + beamthreshold: float = field( + default=50.0, + metadata={"help": "Threshold for beam search decoding"}, + ) + beamsizetoken: Optional[int] = field( + default=None, metadata={"help": "Beam size to use"} + ) + wordscore: float = field( + default=-1, + metadata={"help": "Word score for KenLM decoder"}, + ) + unkweight: float = field( + default=-math.inf, + metadata={"help": "Unknown weight for KenLM decoder"}, + ) + silweight: float = field( + default=0, + metadata={"help": "Silence weight for KenLM decoder"}, + ) + lmweight: float = field( + default=2, + metadata={"help": "Weight for LM while interpolating score"}, + ) diff --git a/examples/speech_recognition/hydra/decoder.py b/examples/speech_recognition/new/decoders/flashlight_decoder.py similarity index 54% rename from examples/speech_recognition/hydra/decoder.py rename to examples/speech_recognition/new/decoders/flashlight_decoder.py index d182b95a32..8a548bdf66 100644 --- 
a/examples/speech_recognition/hydra/decoder.py +++ b/examples/speech_recognition/new/decoders/flashlight_decoder.py @@ -6,35 +6,39 @@ # LICENSE file in the root directory of this source tree. import gc -import itertools as it -import math import os.path as osp import warnings from collections import deque, namedtuple -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Tuple import numpy as np import torch -from examples.speech_recognition.data.replabels import unpack_replabels from fairseq import tasks from fairseq.data.dictionary import Dictionary -from fairseq.dataclass.configs import FairseqDataclass -from fairseq.dataclass.constants import ChoiceEnum from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.models.fairseq_model import FairseqModel from fairseq.utils import apply_to_sample -from omegaconf import MISSING, open_dict +from omegaconf import open_dict, OmegaConf + +from typing import List + +from .decoder_config import FlashlightDecoderConfig +from .base_decoder import BaseDecoder try: - from flashlight.lib.sequence.criterion import (CpuViterbiPath, - get_data_ptr_as_bytes) - from flashlight.lib.text.decoder import (LM, CriterionType, DecodeResult, - KenLM, LexiconDecoder, - LexiconDecoderOptions, - LexiconFreeDecoder, - LexiconFreeDecoderOptions, - LMState, SmearingMode, Trie) + from flashlight.lib.text.decoder import ( + LM, + CriterionType, + DecodeResult, + KenLM, + LexiconDecoder, + LexiconDecoderOptions, + LexiconFreeDecoder, + LexiconFreeDecoderOptions, + LMState, + SmearingMode, + Trie, + ) from flashlight.lib.text.dictionary import create_word_dict, load_words except ImportError: warnings.warn( @@ -46,192 +50,13 @@ LMState = object -CRITERION_CHOICES = ChoiceEnum(["ctc", "asg"]) -DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"]) - - -@dataclass -class DecoderConfig(FairseqDataclass): - name: DECODER_CHOICES = field( - 
default="viterbi", - metadata={"help": "The type of decoder to use"}, - ) - nbest: int = field( - default=1, - metadata={"help": "Number of decodings to return"}, - ) - criterion: CRITERION_CHOICES = field( - default="ctc", - metadata={"help": "Criterion to use"}, - ) - asgtransitions: List[int] = field( - default=MISSING, - metadata={"help": "ASG transition indices"}, - ) - maxreplabel: int = field( - default=2, - metadata={"help": "Maximum repeated labels for ASG criterion"}, - ) - unitlm: bool = field( - default=False, - metadata={"help": "If set, use unit language model"}, - ) - lmpath: str = field( - default=MISSING, - metadata={"help": "Language model for KenLM decoder"}, - ) - lexicon: Optional[str] = field( - default=None, - metadata={"help": "Lexicon for Flashlight decoder"}, - ) - beam: int = field( - default=50, - metadata={"help": "Number of beams to use for decoding"}, - ) - beamthreshold: float = field( - default=15.0, - metadata={"help": "Threshold for beam search decoding"}, - ) - beamsizetoken: Optional[int] = field( - default=None, - metadata={"help": "Beam size to use"} - ) - wordscore: float = field( - default=1.5, - metadata={"help": "Word score for KenLM decoder"}, - ) - unkweight: float = field( - default=-math.inf, - metadata={"help": "Unknown weight for KenLM decoder"}, - ) - silweight: float = field( - default=-0.3, - metadata={"help": "Silence weight for KenLM decoder"}, - ) - lmweight: float = field( - default=1.5, - metadata={"help": "Weight for LM while interpolating score"}, - ) - unitlm: bool = field( - default=False, - metadata={"help": "If using a unit language model"}, - ) - +class KenLMDecoder(BaseDecoder): + def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None: + super().__init__(tgt_dict) -class BaseDecoder: - def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: - self.tgt_dict = tgt_dict - self.vocab_size = len(tgt_dict) self.nbest = cfg.nbest self.unitlm = cfg.unitlm - if 
cfg.criterion == "ctc": - self.criterion_type = CriterionType.CTC - self.blank = ( - tgt_dict.index("<ctc_blank>") - if "<ctc_blank>" in tgt_dict.indices - else tgt_dict.bos() - ) - if "<sep>" in tgt_dict.indices: - self.silence = tgt_dict.index("<sep>") - elif "|" in tgt_dict.indices: - self.silence = tgt_dict.index("|") - else: - self.silence = tgt_dict.eos() - self.asgtransitions = None - elif cfg.criterion == "asg_loss": - self.criterion_type = CriterionType.ASG - self.blank = -1 - self.silence = -1 - self.asgtransitions = cfg.asgtransitions - self.maxreplabel = cfg.maxreplabel - assert len(self.asgtransitions) == self.vocab_size ** 2 - else: - raise RuntimeError(f"unknown criterion: {cfg.criterion}") - - def generate( - self, - models: List[FairseqModel], - sample: Dict[str, Any], - **unused - ) -> List[List[Dict[str, torch.LongTensor]]]: - encoder_input = { - k: v - for k, v in sample["net_input"].items() - if k != "prev_output_tokens" - } - emissions = self.get_emissions(models, encoder_input) - return self.decode(emissions) - - def get_emissions( - self, - models: List[FairseqModel], - encoder_input: Dict[str, Any], - ) -> torch.FloatTensor: - model = models[0] - encoder_out = model(**encoder_input) - if self.criterion_type == CriterionType.CTC: - if hasattr(model, "get_logits"): - emissions = model.get_logits(encoder_out) - else: - emissions = model.get_normalized_probs( - encoder_out, log_probs=True) - elif self.criterion_type == CriterionType.ASG: - emissions = encoder_out["encoder_out"] - else: - raise ValueError("Criterion not implemented: " - f"{self.criterion_type}") - return emissions.transpose(0, 1).float().cpu().contiguous() - - def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: - idxs = (g[0] for g in it.groupby(idxs)) - if self.criterion_type == CriterionType.CTC: - idxs = filter(lambda x: x != self.blank, idxs) - elif self.criterion_type == CriterionType.ASG: - idxs = filter(lambda x: x >= 0, idxs) - idxs = unpack_replabels( - 
list(idxs), self.tgt_dict, self.maxreplabel) - return torch.LongTensor(list(idxs)) - - def decode( - self, - emissions: torch.FloatTensor, - ) -> List[List[Dict[str, torch.LongTensor]]]: - raise NotImplementedError - - -class ViterbiDecoder(BaseDecoder): - def decode( - self, - emissions: torch.FloatTensor, - ) -> List[List[Dict[str, torch.LongTensor]]]: - B, T, N = emissions.size() - if self.asgtransitions is None: - transitions = torch.FloatTensor(N, N).zero_() - else: - transitions = torch.FloatTensor(self.asgtransitions).view(N, N) - viterbi_path = torch.IntTensor(B, T) - workspace = torch.ByteTensor( - CpuViterbiPath.get_workspace_size(B, T, N)) - CpuViterbiPath.compute( - B, - T, - N, - get_data_ptr_as_bytes(emissions), - get_data_ptr_as_bytes(transitions), - get_data_ptr_as_bytes(viterbi_path), - get_data_ptr_as_bytes(workspace), - ) - return [ - [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] - for b in range(B) - ] - - -class KenLMDecoder(BaseDecoder): - def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: - super().__init__(cfg, tgt_dict) - if cfg.lexicon: self.lexicon = load_words(cfg.lexicon) self.word_dict = create_word_dict(self.lexicon) @@ -245,12 +70,10 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: word_idx = self.word_dict.get_index(word) _, score = self.lm.score(start_state, word_idx) for spelling in spellings: - spelling_idxs = [ - tgt_dict.index(token) - for token in spelling - ] - assert tgt_dict.unk() not in spelling_idxs, \ - f"{word} {spelling} {spelling_idxs}" + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{word} {spelling} {spelling_idxs}" self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) @@ -263,12 +86,9 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: unk_score=cfg.unkweight, sil_score=cfg.silweight, log_add=False, - 
criterion_type=self.criterion_type, + criterion_type=CriterionType.CTC, ) - if self.asgtransitions is None: - self.asgtransitions = [] - self.decoder = LexiconDecoder( self.decoder_opts, self.trie, @@ -276,7 +96,7 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: self.silence, self.blank, self.unk_word, - self.asgtransitions, + [], self.unitlm, ) else: @@ -292,7 +112,7 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: lm_weight=cfg.lmweight, sil_score=cfg.silweight, log_add=False, - criterion_type=self.criterion_type, + criterion_type=CriterionType.CTC, ) self.decoder = LexiconFreeDecoder( self.decoder_opts, self.lm, self.silence, self.blank, [] @@ -309,16 +129,18 @@ def decode( results = self.decoder.decode(emissions_ptr, T, N) nbest_results = results[: self.nbest] - hypos.append([ - { - "tokens": self.get_tokens(result.tokens), - "score": result.score, - "words": [ - self.word_dict.get_entry(x) - for x in result.words if x >= 0 - ], - } for result in nbest_results - ]) + hypos.append( + [ + { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + "words": [ + self.word_dict.get_entry(x) for x in result.words if x >= 0 + ], + } + for result in nbest_results + ] + ) return hypos @@ -328,7 +150,7 @@ def decode( "prefix", "incremental_state", "probs", - ] + ], ) @@ -343,7 +165,8 @@ def __init__(self, dictionary: Dictionary, model: FairseqModel) -> None: self.save_incremental = False # this currently does not work properly self.max_cache = 20_000 - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() model.make_generation_fast_() @@ -355,14 +178,11 @@ def start(self, start_with_nothing: bool) -> LMState: prefix = torch.LongTensor([[self.dictionary.eos()]]) incremental_state = {} if self.save_incremental else None with torch.no_grad(): - res = self.model( - prefix.cuda(), incremental_state=incremental_state) - probs = self.model.get_normalized_probs( - res, log_probs=True, sample=None) + 
res = self.model(prefix.cuda(), incremental_state=incremental_state) + probs = self.model.get_normalized_probs(res, log_probs=True, sample=None) if incremental_state is not None: - incremental_state = apply_to_sample( - lambda x: x.cpu(), incremental_state) + incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state) self.states[state] = FairseqLMState( prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy() ) @@ -425,8 +245,7 @@ def trim_cache(targ_size: int) -> None: ) curr_state = FairseqLMState( - curr_state.prefix, new_incremental_state, probs[0, -1].cpu( - ).numpy() + curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy() ) if not no_cache: @@ -467,8 +286,11 @@ def empty_cache(self) -> None: class FairseqLMDecoder(BaseDecoder): - def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: - super().__init__(cfg, tgt_dict) + def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None: + super().__init__(tgt_dict) + + self.nbest = cfg.nbest + self.unitlm = cfg.unitlm self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None self.idx_to_wrd = {} @@ -480,6 +302,9 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: else: lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) + if not OmegaConf.is_dict(lm_args): + lm_args = OmegaConf.create(lm_args) + with open_dict(lm_args.task): lm_args.task.data = osp.dirname(cfg.lmpath) @@ -502,16 +327,13 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: score = 0 else: word_idx = self.word_dict.index(word) - _, score = self.lm.score( - start_state, word_idx, no_cache=True) + _, score = self.lm.score(start_state, word_idx, no_cache=True) for spelling in spellings: - spelling_idxs = [ - tgt_dict.index(token) - for token in spelling - ] - assert tgt_dict.unk() not in spelling_idxs, \ - f"{spelling} {spelling_idxs}" + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in 
spelling_idxs + ), f"{spelling} {spelling_idxs}" self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) @@ -524,12 +346,9 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: unk_score=cfg.unkweight, sil_score=cfg.silweight, log_add=False, - criterion_type=self.criterion_type, + criterion_type=CriterionType.CTC, ) - if self.asgtransitions is None: - self.asgtransitions = [] - self.decoder = LexiconDecoder( self.decoder_opts, self.trie, @@ -537,7 +356,7 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: self.silence, self.blank, self.unk_word, - self.asgtransitions, + [], self.unitlm, ) else: @@ -553,7 +372,7 @@ def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: lm_weight=cfg.lmweight, sil_score=cfg.silweight, log_add=False, - criterion_type=self.criterion_type, + criterion_type=CriterionType.CTC, ) self.decoder = LexiconFreeDecoder( self.decoder_opts, self.lm, self.silence, self.blank, [] @@ -574,7 +393,8 @@ def make_hypo(result: DecodeResult) -> Dict[str, Any]: if self.lexicon: hypo["words"] = [ self.idx_to_wrd[x] if self.unitlm else self.word_dict[x] - for x in result.words if x >= 0 + for x in result.words + if x >= 0 ] return hypo @@ -582,18 +402,8 @@ def make_hypo(result: DecodeResult) -> Dict[str, Any]: emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) results = self.decoder.decode(emissions_ptr, T, N) - nbest_results = results[:self.nbest] + nbest_results = results[: self.nbest] hypos.append([make_hypo(result) for result in nbest_results]) self.lm.empty_cache() return hypos - - -def Decoder(cfg: DecoderConfig, tgt_dict: Dictionary) -> BaseDecoder: - if cfg.name == "viterbi": - return ViterbiDecoder(cfg, tgt_dict) - if cfg.name == "kenlm": - return KenLMDecoder(cfg, tgt_dict) - if cfg.name == "fairseqlm": - return FairseqLMDecoder(cfg, tgt_dict) - raise NotImplementedError(f"Invalid decoder name: {cfg.name}") diff --git 
a/examples/speech_recognition/new/decoders/viterbi_decoder.py b/examples/speech_recognition/new/decoders/viterbi_decoder.py new file mode 100644 index 0000000000..b1c47868fa --- /dev/null +++ b/examples/speech_recognition/new/decoders/viterbi_decoder.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from typing import List, Dict + +from .base_decoder import BaseDecoder + + +class ViterbiDecoder(BaseDecoder): + def decode( + self, + emissions: torch.FloatTensor, + ) -> List[List[Dict[str, torch.LongTensor]]]: + def get_pred(e): + toks = e.argmax(dim=-1).unique_consecutive() + return toks[toks != self.blank] + + return [[{"tokens": get_pred(x), "score": 0}] for x in emissions] diff --git a/examples/speech_recognition/hydra/infer.py b/examples/speech_recognition/new/infer.py similarity index 80% rename from examples/speech_recognition/hydra/infer.py rename to examples/speech_recognition/new/infer.py index 1b49823553..79afbc426d 100644 --- a/examples/speech_recognition/hydra/infer.py +++ b/examples/speech_recognition/new/infer.py @@ -10,25 +10,32 @@ import os import shutil import sys -from dataclasses import dataclass, field +from dataclasses import dataclass, field, is_dataclass from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import editdistance import torch import torch.distributed as dist -from examples.speech_recognition.hydra.decoder import Decoder, DecoderConfig -from fairseq import (checkpoint_utils, distributed_utils, progress_bar, tasks, - utils) +from examples.speech_recognition.new.decoders.decoder_config import ( + DecoderConfig, + FlashlightDecoderConfig, +) +from examples.speech_recognition.new.decoders.decoder import Decoder +from fairseq import checkpoint_utils, 
distributed_utils, progress_bar, tasks, utils from fairseq.data.data_utils import post_process -from fairseq.dataclass.configs import (CheckpointConfig, CommonConfig, - CommonEvalConfig, DatasetConfig, - DistributedTrainingConfig, - FairseqDataclass, GenerationConfig) +from fairseq.dataclass.configs import ( + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + FairseqDataclass, +) from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.logging.progress_bar import BaseProgressBar from fairseq.models.fairseq_model import FairseqModel -from omegaconf import MISSING, OmegaConf +from omegaconf import OmegaConf import hydra from hydra.core.config_store import ConfigStore @@ -41,20 +48,17 @@ @dataclass -class DecodingConfig(FairseqDataclass): - exp_dir: str = field( - default=MISSING, - metadata={"help": "Path to the experiment directory"}, - ) +class DecodingConfig(DecoderConfig, FlashlightDecoderConfig): unique_wer_file: bool = field( default=False, metadata={"help": "If set, use a unique file for storing WER"}, ) - write_sentences: bool = field( - default=True, - metadata={"help": "If set, write hypothesis and reference sentences"}, + results_path: Optional[str] = field( + default=None, + metadata={ + "help": "If set, write hypothesis and reference sentences into this directory" + }, ) - decoder: DecoderConfig = DecoderConfig() @dataclass @@ -64,9 +68,14 @@ class InferConfig(FairseqDataclass): common: CommonConfig = CommonConfig() common_eval: CommonEvalConfig = CommonEvalConfig() checkpoint: CheckpointConfig = CheckpointConfig() - generation: GenerationConfig = GenerationConfig() distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() dataset: DatasetConfig = DatasetConfig() + is_ax: bool = field( + default=False, + metadata={ + "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume" + }, + ) def reset_logging(): @@ -85,6 +94,8 @@ def 
reset_logging(): class InferenceProcessor: + cfg: InferConfig + def __init__(self, cfg: InferConfig) -> None: self.cfg = cfg self.task = tasks.setup_task(cfg.task) @@ -98,7 +109,7 @@ def __init__(self, cfg: InferConfig) -> None: self.cfg.dataset.gen_subset, task_cfg=saved_cfg.task, ) - self.generator = Decoder(cfg.decoding.decoder, self.tgt_dict) + self.generator = Decoder(cfg.decoding, self.tgt_dict) self.gen_timer = StopwatchMeter() self.wps_meter = TimeMeter() self.num_sentences = 0 @@ -113,7 +124,7 @@ def __init__(self, cfg: InferConfig) -> None: self.progress_bar = self.build_progress_bar() def __enter__(self) -> "InferenceProcessor": - if self.cfg.decoding.write_sentences: + if self.cfg.decoding.results_path is not None: self.hypo_words_file = self.get_res_file("hypo.word") self.hypo_units_file = self.get_res_file("hypo.units") self.ref_words_file = self.get_res_file("ref.word") @@ -121,7 +132,7 @@ def __enter__(self) -> "InferenceProcessor": return self def __exit__(self, *exc) -> bool: - if self.cfg.decoding.write_sentences: + if self.cfg.decoding.results_path is not None: self.hypo_words_file.close() self.hypo_units_file.close() self.ref_words_file.close() @@ -145,6 +156,7 @@ def print(self, *args, **kwargs): self.progress_bar.print(*args, **kwargs) def get_res_file(self, fname: str) -> None: + fname = os.path.join(self.cfg.decoding.results_path, fname) if self.data_parallel_world_size > 1: fname = f"{fname}.{self.data_parallel_rank}" return open(fname, "w", buffering=1) @@ -156,7 +168,9 @@ def merge_shards(self) -> None: num_shards = self.data_parallel_world_size if self.data_parallel_world_size > 1: + def merge_shards_with_root(fname: str) -> None: + fname = os.path.join(self.cfg.decoding.results_path, fname) logger.info("Merging %s on shard %d", fname, shard_id) base_fpath = Path(f"{fname}.0") with open(base_fpath, "a") as out_file: @@ -180,11 +194,7 @@ def merge_shards_with_root(fname: str) -> None: dist.barrier() def optimize_model(self, model: 
FairseqModel) -> None: - gcfg = self.cfg.generation - model.make_generation_fast_( - beamable_mm_beam_size=None if gcfg.no_beamable_mm else gcfg.beam, - need_attn=gcfg.print_alignment, - ) + model.make_generation_fast_() if self.cfg.common.fp16: model.half() if not self.cfg.common.cpu: @@ -193,7 +203,7 @@ def optimize_model(self, model: FairseqModel) -> None: def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]: arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides) models, saved_cfg = checkpoint_utils.load_model_ensemble( - utils.split_paths(self.cfg.common_eval.path), + utils.split_paths(self.cfg.common_eval.path, separator="\\"), arg_overrides=arg_overrides, task=self.task, suffix=self.cfg.checkpoint.checkpoint_suffix, @@ -268,21 +278,24 @@ def process_sentence( if "words" in hypo: hyp_words = " ".join(hypo["words"]) else: - hyp_words = post_process(hyp_pieces, - self.cfg.common_eval.post_process) + hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process) # Processes target. 
target_tokens = utils.strip_pad(toks, self.tgt_dict.pad()) tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu()) - tgt_words = post_process(tgt_pieces, - self.cfg.common_eval.post_process) + tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process) - if self.cfg.decoding.write_sentences: + if self.cfg.decoding.results_path is not None: print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file) print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file) print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file) print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file) + if not self.cfg.common_eval.quiet: + logger.info(f"HYPO: {hyp_words}") + logger.info(f"REF: {tgt_words}") + logger.info("---------------------") + hyp_words, tgt_words = hyp_words.split(), tgt_words.split() return editdistance.eval(hyp_words, tgt_words), len(tgt_words) @@ -315,11 +328,15 @@ def process_sample(self, sample: Dict[str, Any]) -> None: self.num_sentences += sample["id"].numel() def log_generation_time(self) -> None: - logger.info("Processed %d sentences (%d tokens) in %.1fs %.2f " - "sentences per second, %.2f tokens per second)", - self.num_sentences, self.gen_timer.n, self.gen_timer.sum, - self.num_sentences / self.gen_timer.sum, - 1.0 / self.gen_timer.avg) + logger.info( + "Processed %d sentences (%d tokens) in %.1fs %.2f " + "sentences per second, %.2f tokens per second)", + self.num_sentences, + self.gen_timer.n, + self.gen_timer.sum, + self.num_sentences / self.gen_timer.sum, + 1.0 / self.gen_timer.avg, + ) def parse_wer(wer_file: Path) -> float: @@ -329,12 +346,16 @@ def parse_wer(wer_file: Path) -> float: def get_wer_file(cfg: InferConfig) -> Path: """Hashes the decoding parameters to a unique file ID.""" + base_path = "wer" + if cfg.decoding.results_path is not None: + base_path = os.path.join(cfg.decoding.results_path, base_path) + if cfg.decoding.unique_wer_file: yaml_str = OmegaConf.to_yaml(cfg.decoding) fid = 
int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) - return Path(f"wer.{fid % 1000000}") + return Path(f"{base_path}.{fid % 1000000}") else: - return Path("wer") + return Path(base_path) def main(cfg: InferConfig) -> float: @@ -356,8 +377,6 @@ def main(cfg: InferConfig) -> float: cfg.dataset.max_tokens = 4000000 if not cfg.common.cpu and not torch.cuda.is_available(): raise ValueError("CUDA not found; set `cpu=True` to run without CUDA") - if cfg.generation.nbest > 1: - raise ValueError("`nbest > 1` not implemented yet") with InferenceProcessor(cfg) as processor: for sample in processor: @@ -365,7 +384,7 @@ def main(cfg: InferConfig) -> float: processor.log_generation_time() - if cfg.decoding.write_sentences: + if cfg.decoding.results_path is not None: processor.merge_shards() errs_t, leng_t = processor.total_errors, processor.total_length @@ -381,17 +400,19 @@ def main(cfg: InferConfig) -> float: if distributed_utils.is_master(cfg.distributed_training): with open(wer_file, "w") as f: - f.write(( - f"WER: {wer}\n" - f"err / num_ref_words = {errs_t} / {leng_t}\n\n" - f"{yaml_str}" - )) + f.write( + ( + f"WER: {wer}\n" + f"err / num_ref_words = {errs_t} / {leng_t}\n\n" + f"{yaml_str}" + ) + ) return wer @hydra.main(config_path=config_path, config_name="infer") -def hydra_main(cfg: InferConfig) -> None: +def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]: container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) cfg = OmegaConf.create(container) OmegaConf.set_struct(cfg, True) @@ -399,8 +420,7 @@ def hydra_main(cfg: InferConfig) -> None: if cfg.common.reset_logging: reset_logging() - logger.info("Config:\n%s", OmegaConf.to_yaml(cfg)) - logger.info("Working directory: %s", Path.cwd()) + # logger.info("Config:\n%s", OmegaConf.to_yaml(cfg)) wer = float("inf") try: @@ -419,13 +439,18 @@ def hydra_main(cfg: InferConfig) -> None: logger.error("Crashed! 
%s", str(e)) logger.info("Word error rate: %.4f", wer) + if cfg.is_ax: + return wer, None + return wer def cli_main() -> None: try: - from hydra._internal.utils import \ - get_args # pylint: disable=import-outside-toplevel + from hydra._internal.utils import ( + get_args, + ) # pylint: disable=import-outside-toplevel + cfg_name = get_args().config_name or "infer" except ImportError: logger.warning("Failed to get config name from hydra args") @@ -435,12 +460,9 @@ def cli_main() -> None: cs.store(name=cfg_name, node=InferConfig) for k in InferConfig.__dataclass_fields__: - v = InferConfig.__dataclass_fields__[k].default - try: + if is_dataclass(InferConfig.__dataclass_fields__[k].type): + v = InferConfig.__dataclass_fields__[k].default cs.store(name=k, node=v) - except BaseException: - logger.error(f"{k} - {v}") - raise hydra_main() # pylint: disable=no-value-for-parameter diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 543e796da3..10e3618382 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -66,7 +66,11 @@ class CtcCriterionConfig(FairseqDataclass): class CtcCriterion(FairseqCriterion): def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): super().__init__(task) - self.blank_idx = task.target_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0 + self.blank_idx = ( + task.target_dictionary.index(task.blank_symbol) + if hasattr(task, "blank_symbol") + else 0 + ) self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() self.post_process = cfg.post_process @@ -111,8 +115,13 @@ def forward(self, model, sample, reduce=True): if "src_lengths" in sample["net_input"]: input_lengths = sample["net_input"]["src_lengths"] else: - non_padding_mask = ~net_output["padding_mask"] - input_lengths = non_padding_mask.long().sum(-1) + if net_output["padding_mask"] is not None: + non_padding_mask = ~net_output["padding_mask"] + input_lengths = non_padding_mask.long().sum(-1) 
+ else: + input_lengths = lprobs.new_full( + (lprobs.size(1),), lprobs.size(0), dtype=torch.long + ) pad_mask = (sample["target"] != self.pad_idx) & ( sample["target"] != self.eos_idx diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 27c9006fdb..89206125d1 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -9,7 +9,7 @@ import os import re from argparse import ArgumentError, ArgumentParser, Namespace -from dataclasses import _MISSING_TYPE, MISSING +from dataclasses import _MISSING_TYPE, MISSING, is_dataclass from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type @@ -457,7 +457,19 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): cfg[k] = overrides[k] -def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass): +def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=True): + if remove_missing: + + if is_dataclass(dc): + target_keys = set(dc.__dataclass_fields__.keys()) + else: + target_keys = set(dc.keys()) + + with open_dict(cfg): + for k in list(cfg.keys()): + if k not in target_keys: + del cfg[k] + merged_cfg = OmegaConf.merge(dc, cfg) merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] OmegaConf.set_struct(merged_cfg, True) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index abae9d1ab3..405d1e613a 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -31,7 +31,6 @@ LayerNorm, PositionalEmbedding, TransformerDecoderLayer, - SamePad, ) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 71cefcfcaa..ce454e12b7 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -12,9 +12,8 @@ from argparse import Namespace from dataclasses import dataclass, field -import numpy as np from typing import Optional, Any -from omegaconf import MISSING, II +from omegaconf import MISSING, 
II, OmegaConf from fairseq.data import ( AddTargetDataset, @@ -44,6 +43,27 @@ def __call__(self, label): ) +@dataclass +class InferredW2vConfig: + # The following are needed to precompute mask and mask channel indices + # before model's forward. + mask_length: Optional[int] = II("model.mask_length") + mask_prob: Optional[float] = II("model.mask_prob") + mask_selection: Optional[str] = II("model.mask_selection") + mask_other: Optional[float] = II("model.mask_other") + no_mask_overlap: Optional[bool] = II("model.no_mask_overlap") + mask_min_space: Optional[int] = II("model.mask_min_space") + mask_channel_length: Optional[int] = II("model.mask_channel_length") + mask_channel_prob: Optional[float] = II("model.mask_channel_prob") + mask_channel_selection: Optional[str] = II("model.mask_channel_selection") + mask_channel_other: Optional[float] = II("model.mask_channel_other") + no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap") + mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space") + + conv_feature_layers: Optional[str] = II("model.conv_feature_layers") + encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim") + + @dataclass class AudioPretrainingConfig(FairseqDataclass): data: str = field(default=MISSING, metadata={"help": "path to data directory"}) @@ -114,23 +134,13 @@ class AudioPretrainingConfig(FairseqDataclass): "help": "flag to compute mask indices in data preparation.", }, ) - # The following are needed to precompute mask and mask channel indices - # before model's forward. 
- mask_length: Optional[int] = II("model.mask_length") - mask_prob: Optional[float] = II("model.mask_prob") - mask_selection: Optional[str] = II("model.mask_selection") - mask_other: Optional[float] = II("model.mask_other") - no_mask_overlap: Optional[bool] = II("model.no_mask_overlap") - mask_min_space: Optional[int] = II("model.mask_min_space") - mask_channel_length: Optional[int] = II("model.mask_channel_length") - mask_channel_prob: Optional[float] = II("model.mask_channel_prob") - mask_channel_selection: Optional[str] = II("model.mask_channel_selection") - mask_channel_other: Optional[float] = II("model.mask_channel_other") - no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap") - mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space") - conv_feature_layers: Optional[str] = II("model.conv_feature_layers") - encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim") + inferred_w2v_config: Optional[InferredW2vConfig] = field( + default=None, + metadata={ + "help": "wav2vec 2.0 masking arguments used to pre-compute masks (required for TPU)", + }, + ) tpu: bool = II("common.tpu") @@ -170,23 +180,12 @@ def load_target_dictionary(self): def _get_mask_precompute_kwargs(self, cfg): if self.cfg.precompute_mask_indices or self.cfg.tpu: - args = [ - "mask_length", - "mask_prob", - "mask_selection", - "mask_other", - "no_mask_overlap", - "mask_min_space", - "mask_channel_length", - "mask_channel_prob", - "mask_channel_selection", - "mask_channel_other", - "no_mask_channel_overlap", - "mask_channel_min_space", - "encoder_embed_dim", - "conv_feature_layers", - ] - return {arg: cfg[arg] for arg in args} + assert ( + cfg.inferred_w2v_config is not None + ), "inferred_w2v_config must be set" + return OmegaConf.to_container( + cfg.inferred_w2v_config, resolve=True, enum_to_str=True + ) else: return {} diff --git a/fairseq/utils.py b/fairseq/utils.py index d0ce16ae6b..bf5727edfd 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py 
@@ -57,9 +57,9 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, argument) -def split_paths(paths: str) -> List[str]: +def split_paths(paths: str, separator=os.pathsep) -> List[str]: return ( - paths.split(os.pathsep) + paths.split(separator) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP) ) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index f0d983ee6b..22b93e9a6a 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -14,6 +14,7 @@ from fairseq import checkpoint_utils, distributed_utils, options, utils from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import metrics, progress_bar +from fairseq.utils import reset_logging from omegaconf import DictConfig @@ -32,6 +33,8 @@ def main(cfg: DictConfig, override_args=None): utils.import_user_module(cfg.common) + reset_logging() + assert ( cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" From c8223e350cfc616bb47196151d1223683e483b6d Mon Sep 17 00:00:00 2001 From: Nicola De Cao <nicola.decao@uva.nl> Date: Wed, 26 May 2021 18:20:04 -0700 Subject: [PATCH 597/707] fixing prefix_allowed_tokens_fn (#3276) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes the use of `prefix_allowed_tokens_fn` in generation. It was working for `fairseq==0.9.0` (see https://github.com/facebookresearch/GENRE) but with the current version is broken. ## PR review Anyone in the community is free to review the PR once the tests have passed. 
## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3276 Reviewed By: alexeib Differential Revision: D26725494 Pulled By: myleott fbshipit-source-id: ce3da725f36352687e5cb5d62a59b4c89ce0b0bc --- fairseq/hub_utils.py | 7 ++++++- fairseq/tasks/fairseq_task.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 7de2e2b0d4..d74470d2ec 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -151,6 +151,7 @@ def generate( verbose: bool = False, skip_invalid_size_inputs=False, inference_step_args=None, + prefix_allowed_tokens_fn=None, **kwargs ) -> List[List[Dict[str, torch.Tensor]]]: if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1: @@ -164,7 +165,11 @@ def generate( gen_args.beam = beam for k, v in kwargs.items(): setattr(gen_args, k, v) - generator = self.task.build_generator(self.models, gen_args) + generator = self.task.build_generator( + self.models, + gen_args, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + ) inference_step_args = inference_step_args or {} results = [] diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index e30b2cd985..fbec9bb2a5 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -341,8 +341,32 @@ def build_criterion(self, cfg: DictConfig): return criterions.build_criterion(cfg, self) def build_generator( - self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None, ): + """ + Build a :class:`~fairseq.SequenceGenerator` instance for this + task. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + args (fairseq.dataclass.configs.GenerationConfig): + configuration object (dataclass) for generation + extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass + through to SequenceGenerator + prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): + If provided, this function constrains the beam search to + allowed tokens only at each step. The provided function + should take 2 arguments: the batch ID (`batch_id: int`) + and a unidimensional tensor of token ids (`inputs_ids: + torch.Tensor`). It has to return a `List[int]` with the + allowed tokens for the next generation step conditioned + on the previously generated tokens (`inputs_ids`) and + the batch ID (`batch_id`). This argument is useful for + constrained generation conditioned on the prefix, as + described in "Autoregressive Entity Retrieval" + (https://arxiv.org/abs/2010.00904) and + https://github.com/facebookresearch/GENRE. + """ if getattr(args, "score_reference", False): from fairseq.sequence_scorer import SequenceScorer @@ -369,7 +393,8 @@ def build_generator( match_source_len = getattr(args, "match_source_len", False) diversity_rate = getattr(args, "diversity_rate", -1) constrained = getattr(args, "constraints", False) - prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) + if prefix_allowed_tokens_fn is None: + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) if ( sum( int(cond) From 9497ae3cfb04bb6ec4735758bbe8dc767276932c Mon Sep 17 00:00:00 2001 From: Mandeep Singh Baines <mandeep.baines@gmail.com> Date: Thu, 27 May 2021 12:14:40 -0700 Subject: [PATCH 598/707] disable raise_if_valid_subsets_unintentionally_ignored check for dummy tasks (#3552) Summary: Fixes the following crash: ```python Traceback (most recent call last): File "/private/home/msb/.conda/envs/fairseq-20210102-pt181/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in 
_wrap fn(i, *args) File "/private/home/msb/code/fairseq/fairseq/distributed/utils.py", line 328, in distributed_main main(cfg, **kwargs) File "/private/home/msb/code/fairseq/fairseq_cli/train.py", line 117, in main data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) File "/private/home/msb/code/fairseq/fairseq/data/data_utils.py", line 584, in raise_if_valid_subsets_unintentionally_ignored other_paths = _find_extra_valid_paths(train_cfg.task.data) AttributeError: 'Namespace' object has no attribute 'data' ``` Pull Request resolved: https://github.com/pytorch/fairseq/pull/3552 Reviewed By: sshleifer Differential Revision: D28667773 Pulled By: msbaines fbshipit-source-id: bc9a633184105dbae0cce58756bb1d379b03980a --- fairseq/data/data_utils.py | 1 + tests/test_valid_subset_checks.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 70a4086cd0..b3de57681e 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -583,6 +583,7 @@ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None: train_cfg.dataset.ignore_unused_valid_subsets or train_cfg.dataset.combine_valid_subsets or train_cfg.dataset.disable_validation + or not hasattr(train_cfg.task, "data") ): return other_paths = _find_extra_valid_paths(train_cfg.task.data) diff --git a/tests/test_valid_subset_checks.py b/tests/test_valid_subset_checks.py index ab778fb3fa..8da79cfb82 100644 --- a/tests/test_valid_subset_checks.py +++ b/tests/test_valid_subset_checks.py @@ -10,18 +10,20 @@ def make_lm_config( - data_dir, + data_dir=None, extra_flags=None, task="language_modeling", arch="transformer_lm_gpt2_tiny", ): + task_args = [task] + if data_dir is not None: + task_args += [data_dir] train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( train_parser, [ "--task", - task, - data_dir, + *task_args, "--arch", arch, "--optimizer", @@ -97,6 +99,10 @@ def 
test_disable_validation(self): self._test_case([], ["--disable-validation"]) self._test_case(["valid", "valid1"], ["--disable-validation"]) + def test_dummy_task(self): + cfg = make_lm_config(task="dummy_lm") + raise_if_valid_subsets_unintentionally_ignored(cfg) + class TestCombineValidSubsets(unittest.TestCase): def _train(self, extra_flags): From 19793a78e5cd9aa0de427065f125c93100a942ea Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Thu, 27 May 2021 14:28:23 -0700 Subject: [PATCH 599/707] Remove duplicate registration of ManifoldPathHandler Summary: `ManifoldPathHandler` is automatically registered with `IOPathManager` upon importing the latter (see D27960781). Therefore it is no longer necessary to register `ManifoldPathManager` in fairseq, as introduced by D27809504 (https://github.com/pytorch/fairseq/commit/3a90a859d4dfdbf13f15399be12a1928aa2c54ff). Reviewed By: sujitoc Differential Revision: D28735316 fbshipit-source-id: 03e246dd17ba9f2a9a81dd4e741cce88f26feedd --- fairseq/file_io.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 6266e6a1d8..dba663d4aa 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -29,14 +29,6 @@ "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." 
) - try: - # [FB only] Add extra FB only PathHandlers for PathManager - import fairseq.fb_file_io as fb_file_io - - fb_file_io.update_path_manager(IOPathManager) - except ImportError: - pass - except ImportError: IOPathManager = None From 62ccebaf70cd8d392e178b67e0af661da0431a20 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Mon, 31 May 2021 01:13:41 -0700 Subject: [PATCH 600/707] =?UTF-8?q?fix=20'=5Fpickle.PicklingError:=20Can't?= =?UTF-8?q?=20pickle=20<enum=20'Choices'>:=20attribute=20=E2=80=A6=20(#191?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: for whatever reason, checkpoints are failing to save because choiceenum can't be pickled again (could be env specific). this should permanently resolve it by converting choice enum to string in the config before saving Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1915 Reviewed By: arbabu123 Differential Revision: D28784506 Pulled By: alexeib fbshipit-source-id: 17843cfa00e8e624eb06262df8e1b71b062a237b --- fairseq/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d1d08025f6..64c6fabed6 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -369,7 +369,7 @@ def state_dict(self): state_dict = { "args": None, # legacy "cfg": ( - OmegaConf.to_container(self.cfg) + OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True) if OmegaConf.is_config(self.cfg) else self.cfg ), From c47a9b2eef0f41b0564c8daf52cb82ea97fc6548 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Tue, 1 Jun 2021 16:42:48 -0700 Subject: [PATCH 601/707] fix #3574 (#1921) Summary: support pre-hydra w2v models Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1921 Reviewed By: arbabu123 Differential Revision: D28807630 Pulled By: alexeib fbshipit-source-id: 0fc8bcda12cf677e909d88678f235bfdeb50e726 --- 
fairseq/tasks/audio_pretraining.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index ce454e12b7..c642ff5226 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -303,7 +303,7 @@ def build_model(self, model_cfg: FairseqDataclass): else: self.tokenizer = None - actualized_cfg = getattr(model, "cfg") + actualized_cfg = getattr(model, "cfg", None) if actualized_cfg is not None: if "w2v_args" in actualized_cfg: model_cfg.w2v_args = actualized_cfg.w2v_args From 4950c56f461db1646872159b2c470fe57ae72c69 Mon Sep 17 00:00:00 2001 From: Henry Hu <henryhu6@fb.com> Date: Thu, 3 Jun 2021 16:19:27 -0700 Subject: [PATCH 602/707] Add export flag to transform, so LayerNorm can be TorchScripted. Summary: Previously on cuda, LayerNorm would always default to FusedLayerNorm, which could not be exported. Add export flag, so torch.nn.LayerNorm would be used. Reviewed By: myleott, mikekgfb, kpuatfb Differential Revision: D28858633 fbshipit-source-id: 58dd4945f596b2bcc94a6b74356bd9fd3c73ca1a --- fairseq/models/transformer.py | 12 ++++++------ fairseq/modules/transformer_layer.py | 10 ++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index b7b8783fa2..f4f6bea27b 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -378,9 +378,9 @@ def __init__(self, args, dictionary, embed_tokens): if not args.no_token_positional_embeddings else None ) - + export = getattr(args, "export", False) if getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim) + self.layernorm_embedding = LayerNorm(embed_dim, export=export) else: self.layernorm_embedding = None @@ -403,7 +403,7 @@ def __init__(self, args, dictionary, embed_tokens): self.num_layers = len(self.layers) if args.encoder_normalize_before: - self.layer_norm = 
LayerNorm(embed_dim) + self.layer_norm = LayerNorm(embed_dim, export=export) else: self.layer_norm = None @@ -702,9 +702,9 @@ def __init__( if not args.no_token_positional_embeddings else None ) - + export = getattr(args, "export", False) if getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim) + self.layernorm_embedding = LayerNorm(embed_dim, export=export) else: self.layernorm_embedding = None @@ -725,7 +725,7 @@ def __init__( if args.decoder_normalize_before and not getattr( args, "no_decoder_final_norm", False ): - self.layer_norm = LayerNorm(embed_dim) + self.layer_norm = LayerNorm(embed_dim, export=export) else: self.layer_norm = None diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index f9ada37bde..4f9ea22a9b 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -36,7 +36,8 @@ def __init__(self, args): self.quant_noise = getattr(args, 'quant_noise_pq', 0) self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 self.self_attn = self.build_self_attention(self.embed_dim, args) - self.self_attn_layer_norm = LayerNorm(self.embed_dim) + export = getattr(args, "export", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) @@ -64,7 +65,7 @@ def __init__(self, args): self.quant_noise_block_size, ) - self.final_layer_norm = LayerNorm(self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise( @@ -207,10 +208,7 @@ def __init__( ) self.normalize_before = args.decoder_normalize_before - # use layerNorm rather than FusedLayerNorm for exporting. - # char_inputs can be used to determint this. 
- # TODO remove this once we update apex with the fix - export = getattr(args, "char_inputs", False) + export = getattr(args, "export", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: From 50f3766a9d8413a5875a9b2b599b233b5690c7f2 Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Thu, 3 Jun 2021 17:48:17 -0700 Subject: [PATCH 603/707] TALNet: Use batch size as sample_size Summary: `Wav2VecCriterion` uses a sample_size for two purposes: 1. It weights the extra loss by multiplying it by sample_size; 2. It divides the total loss by sample_size before reporting it in the learning curves. By default, when using the binary cross-entropy loss (`infonce = False`), `Wav2VecCriterion` uses the number of 1's in the label matrix as sample_size. For TALNet, because each recording may have multiple labels, this sample_size is not a constant across batches. TALNet also uses a consistency loss between the predictions on two different copies of augmented data as an extra loss, and it is undesirable for the weight of the extra loss to vary from batch to batch. This diff adds a field "sample_size" to the batch in the `AcousticEventCollater`, and makes it equal to the batch size (number of recordings in a batch). Because the extra loss is multiplied by sample_size in `Wav2VecCriterion`, this diff also divides the consistency loss by the batch size in the `forward` method of `TALNetModel`. This diff also adds a unit test for the consistency loss.
Reviewed By: alexeib Differential Revision: D28728699 fbshipit-source-id: dda1f2a1b02e49b894842c8990218b5fe92d0330 --- fairseq/criterions/wav2vec_criterion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 521d0cf1ad..a5048fdb4a 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -82,7 +82,7 @@ def forward(self, model, sample, reduce=True): ) loss = (loss * mi).sum() if reduce else (loss * mi) - if 'sample_size' in sample and self.infonce: + if 'sample_size' in sample: sample_size = sample['sample_size'] elif 'mask_indices' in sample['net_input']: sample_size = sample['net_input']['mask_indices'].sum() From 3084b812beb72a880b6ddb5d9076ead60b7232d6 Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Thu, 3 Jun 2021 17:48:17 -0700 Subject: [PATCH 604/707] Teacher-student learning for TALNet Summary: This diff implements teacher-student learning for TALNet. Three classes take part in the teacher-student learning: * The task loads the teacher models; * The model generates predictions using the teacher models, and mixes them with the original targets; * The `Wav2VecCriterion` reads the mixed targets to compute the loss. However, it still uses the original targets to compute the MAP and MAUC metrics. There are two types of teachers: * Static teachers: a file that stores predictions on training data which have been produced by running a model offline; * Dynamic teachers: model files that are loaded at the beginning of training and executed on the fly to produce predictions. We actually no longer use static teachers. The code for static teachers is copied over from the `KnowledgeDistillationBinaryCrossEntropyCriterion` class. This class will be cleaned up in D28728718. The teacher models are stored in the task object, and will not be saved into checkpoints.
Reviewed By: alexeib Differential Revision: D28728707 fbshipit-source-id: 0fcfc00db2e7194a6f7ee687cad9fa72e82a028b --- fairseq/checkpoint_utils.py | 9 +++++++++ fairseq/criterions/wav2vec_criterion.py | 8 +++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index ecc45f4351..402921744d 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -611,6 +611,15 @@ def _upgrade_state_dict(state): and len(state["args"].data) > 0 ): state["args"].data = state["args"].data[0] + # remove keys in state["args"] related to teacher-student learning + for key in [ + "static_teachers", + "static_teacher_weights", + "dynamic_teachers", + "dynamic_teacher_weights", + ]: + if key in state["args"]: + delattr(state["args"], key) state["cfg"] = convert_namespace_to_omegaconf(state["args"]) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index a5048fdb4a..e04786cc3b 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -121,7 +121,13 @@ def forward(self, model, sample, reduce=True): logging_output["logits"] = logits.cpu().numpy() elif lk == "target": if not self.training: - logging_output["target"] = target.cpu().numpy() + # If the targets have been mixed with the predictions of + # teacher models, find the original targets + if hasattr(model, "get_original_targets"): + original_target = model.get_original_targets(sample, net_output) + else: + original_target = target + logging_output["target"] = original_target.cpu().numpy() elif lk in net_output: value = net_output[lk] if not is_xla_tensor(value): From 45d8fefaa6871afbb747e5e65ba58b8f9fda37fe Mon Sep 17 00:00:00 2001 From: Mandeep Singh Baines <mandeep.baines@gmail.com> Date: Fri, 4 Jun 2021 11:23:32 -0700 Subject: [PATCH 605/707] fix logging when running single-process (#3592) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Summary: In file_io.py, there is a logging message that happens in the global scope. This logging message can be invoked before calling logging.basicConfig() in fairseq_cli/train.py resulting in that call becoming a no-op. This was causing the loglevel to remain at WARNING. Fix is to call logging.basicConfig() before importing any fairseq libraries that may do logging in global scope. Verified that logging.info messages are now visible after applying this PR. # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3592 Reviewed By: sujitoc Differential Revision: D28900871 Pulled By: msbaines fbshipit-source-id: ff5393aa7c5e4cbec168ff0b846da048de76cdbc --- fairseq_cli/train.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index a1b7cb58e2..8347587313 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -14,6 +14,15 @@ import sys from typing import Dict, Optional, Any, List, Tuple, Callable +# We need to setup root logger before importing any fairseq libraries.
+logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.train") + import numpy as np import torch from fairseq import ( @@ -35,13 +44,6 @@ from omegaconf import DictConfig, OmegaConf -logging.basicConfig( - format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - level=os.environ.get("LOGLEVEL", "INFO").upper(), - stream=sys.stdout, -) -logger = logging.getLogger("fairseq_cli.train") def main(cfg: FairseqConfig) -> None: From fc391ff6974969649ec94dc18eb7795de66bda4f Mon Sep 17 00:00:00 2001 From: Yun Wang <yunwang@fb.com> Date: Fri, 4 Jun 2021 16:19:04 -0700 Subject: [PATCH 606/707] Fix loading some TALNet models Summary: D28728718 cleaned up the "kd_binary_cross_entropy" criterion, but this caused loading old models trained with this criterion to fail. This diff replaces the "kd_binary_cross_entropy" criterion with the "wav2vec" criterion when loading models, and fixes this error. It also removes the "log_keys" argument if it's `None`. Some criteria (e.g. wav2vec) require this argument to be a list, and will supply a default value of `[]` when it's absent. The presence of the `None` value prevents the use of this default value and causes an error. 
Differential Revision: D28901263 fbshipit-source-id: 9b33aed35e76d2c734d1d4e2cbca1ff193a8c920 --- fairseq/checkpoint_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 402921744d..80c797bcdd 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -16,7 +16,7 @@ from random import randint import torch -from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig +from fairseq.dataclass.configs import CheckpointConfig from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, overwrite_args_by_name, @@ -24,7 +24,7 @@ from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder -from omegaconf import Container, DictConfig, open_dict, OmegaConf +from omegaconf import DictConfig, open_dict, OmegaConf logger = logging.getLogger(__name__) @@ -512,7 +512,6 @@ def _torch_persistent_save(obj, f): def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" - from fairseq import models, registry, tasks # add optimizer_history if "optimizer_history" not in state: @@ -586,12 +585,18 @@ def _upgrade_state_dict(state): if hasattr(state["args"], "min_lr"): state["args"].stop_min_lr = state["args"].min_lr del state["args"].min_lr - # binary_cross_entropy => wav2vec criterion + # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion if ( hasattr(state["args"], "criterion") - and state["args"].criterion == "binary_cross_entropy" + and state["args"].criterion in [ + "binary_cross_entropy", + "kd_binary_cross_entropy", + ] ): state["args"].criterion = "wav2vec" + # remove log_keys if it's None (criteria will supply a default value of []) + if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: + delattr(state["args"], "log_keys") # speech_pretraining => audio pretraining if ( 
hasattr(state["args"], "task") From 2fd9d8a972794ba919174baf0d1828a5a4c626f3 Mon Sep 17 00:00:00 2001 From: Naman Goyal <namangoyal@learnfair1299.h2.fair> Date: Mon, 7 Jun 2021 15:04:31 -0700 Subject: [PATCH 607/707] released xlmr xl and xxl model weights (#1944) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1944 Reviewed By: jingfeidu Differential Revision: D28944206 fbshipit-source-id: 583837f7dd387341574d27dd9acc145455d640a8 --- README.md | 1 + examples/xlmr/README.md | 31 +++++++++++++++++++++++----- fairseq/models/roberta/model_xlmr.py | 2 ++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 839dd8e1de..82b6ba7cd8 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) * February 2021 [Added LASER training code](examples/laser/README.md) * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) diff --git a/examples/xlmr/README.md b/examples/xlmr/README.md index 65d4be13de..b95bfe15d3 100644 ---
a/examples/xlmr/README.md +++ b/examples/xlmr/README.md @@ -1,9 +1,16 @@ # Unsupervised Cross-lingual Representation Learning at Scale (XLM-RoBERTa) https://arxiv.org/pdf/1911.02116.pdf +# Larger-Scale Transformers for Multilingual Masked Language Modeling +https://arxiv.org/pdf/2105.00572.pdf + + +## What's New: +- June 2021: `XLMR-XL` AND `XLMR-XXL` models released. + ## Introduction -XLM-R (XLM-RoBERTa) is a generic cross lingual sentence encoder that obtains state-of-the-art results on many cross-lingual understanding (XLU) benchmarks. It is trained on 2.5T of filtered CommonCrawl data in 100 languages (list below). +`XLM-R` (`XLM-RoBERTa`) is a generic cross lingual sentence encoder that obtains state-of-the-art results on many cross-lingual understanding (XLU) benchmarks. It is trained on `2.5T` of filtered CommonCrawl data in 100 languages (list below). Language | Language|Language |Language | Language ---|---|---|---|--- @@ -34,8 +41,8 @@ Model | Description | #params | vocab size | Download ---|---|---|---|--- `xlmr.base` | XLM-R using the BERT-base architecture | 250M | 250k | [xlm.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz) `xlmr.large` | XLM-R using the BERT-large architecture | 560M | 250k | [xlm.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz) - -(Note: Above are final model checkpoints. If you were using previously released `v0` version, we recommend using above. They have same architecture and dictionary.) 
+`xlmr.xl` | XLM-R (`layers=36, model_dim=2560`) | 3.5B | 250k | [xlm.xl.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xl.tar.gz) +`xlmr.xxl` | XLM-R (`layers=48, model_dim=4096`) | 10.7B | 250k | [xlm.xxl.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xxl.tar.gz) ## Results @@ -44,7 +51,9 @@ Model | Description | #params | vocab size | Download Model | average | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur ---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--- `roberta.large.mnli` _(TRANSLATE-TEST)_ | 77.8 | 91.3 | 82.9 | 84.3 | 81.2 | 81.7 | 83.1 | 78.3 | 76.8 | 76.6 | 74.2 | 74.1 | 77.5 | 70.9 | 66.7 | 66.8 -`xlmr.large` _(TRANSLATE-TRAIN-ALL)_ | **83.6** | 89.1 | 85.1 | 86.6 | 85.7 | 85.3 | 85.9 | 83.5 | 83.2 | 83.1 | 83.7 | 81.5 | 83.7 | 81.6 | 78.0 | 78.1 +`xlmr.large` _(TRANSLATE-TRAIN-ALL)_ | 83.6 | 89.1 | 85.1 | 86.6 | 85.7 | 85.3 | 85.9 | 83.5 | 83.2 | 83.1 | 83.7 | 81.5 | 83.7 | 81.6 | 78.0 | 78.1 +`xlmr.xl` _(TRANSLATE-TRAIN-ALL)_ | 85.4 | 91.1 | 87.2 | 88.1 | 87.0 | 87.4 | 87.8 | 85.3 | 85.2 | 85.3 | 86.2 | 83.8 | 85.3 | 83.1 | 79.8 | 78.2 | 85.4 +`xlmr.xxl` _(TRANSLATE-TRAIN-ALL)_ | 86.0 | 91.5 | 87.6 | 88.7 | 87.8 | 87.4 | 88.2 | 85.6 | 85.1 | 85.8 | 86.3 | 83.9 | 85.6 | 84.6 | 81.7 | 80.6 **[MLQA (Lewis et al., 2018)](https://arxiv.org/abs/1910.07475)** @@ -52,7 +61,9 @@ Model | average | en | es | de | ar | hi | vi | zh ---|---|---|---|---|---|---|---|--- `BERT-large` | - | 80.2/67.4 | - | - | - | - | - | - `mBERT` | 57.7 / 41.6 | 77.7 / 65.2 | 64.3 / 46.6 | 57.9 / 44.3 | 45.7 / 29.8| 43.8 / 29.7 | 57.1 / 38.6 | 57.5 / 37.3 -`xlmr.large` | **70.7 / 52.7** | 80.6 / 67.8 | 74.1 / 56.0 | 68.5 / 53.6 | 63.1 / 43.5 | 69.2 / 51.6 | 71.3 / 50.9 | 68.0 / 45.4 +`xlmr.large` | 70.7 / 52.7 | 80.6 / 67.8 | 74.1 / 56.0 | 68.5 / 53.6 | 63.1 / 43.5 | 69.2 / 51.6 | 71.3 / 50.9 | 68.0 / 45.4 +`xlmr.xl` | 73.4 / 55.3 | 85.1 / 72.6 | 66.7 / 46.2 | 70.5 / 55.5 | 74.3 / 56.9 | 72.2 / 54.7 | 
74.4 / 52.9 | 70.9 / 48.5 +`xlmr.xxl` | 74.8 / 56.6 | 85.5 / 72.4 | 68.6 / 48.4 | 72.7 / 57.8 | 75.4 / 57.6 | 73.7 / 55.8 | 76.0 / 55.0 | 71.7 / 48.9 ## Example usage @@ -121,3 +132,13 @@ assert torch.all(all_layers[-1] == last_layer_features) year={2019} } ``` + + +```bibtex +@article{goyal2021larger, + title={Larger-Scale Transformers for Multilingual Masked Language Modeling}, + author={Goyal, Naman and Du, Jingfei and Ott, Myle and Anantharaman, Giri and Conneau, Alexis}, + journal={arXiv preprint arXiv:2105.00572}, + year={2021} +} +``` diff --git a/fairseq/models/roberta/model_xlmr.py b/fairseq/models/roberta/model_xlmr.py index 5886880f73..cf6e354d53 100644 --- a/fairseq/models/roberta/model_xlmr.py +++ b/fairseq/models/roberta/model_xlmr.py @@ -19,6 +19,8 @@ def hub_models(cls): return { "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz", "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz", + "xlmr.xl": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xl.tar.gz", + "xlmr.xxl": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xxl.tar.gz", } @classmethod From 50158da3a7b293f2d2fa06a23e90c160b92f54ce Mon Sep 17 00:00:00 2001 From: Diana Liskovich <dianaml@devfair0471.h2.fair> Date: Thu, 10 Jun 2021 09:42:18 -0700 Subject: [PATCH 608/707] Migrate DummyMaskedLMTask to FairseqTask (#3593) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. 
If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3593 Reviewed By: msbaines Differential Revision: D28992614 Pulled By: dianaml0 fbshipit-source-id: b2dfcab472a65c41536e78600a0e6b3745dc3a08 --- fairseq/benchmark/__init__.py | 2 +- fairseq/benchmark/dummy_dataset.py | 36 ++++++++ fairseq/benchmark/dummy_lm.py | 39 +-------- fairseq/benchmark/dummy_masked_lm.py | 119 ++++++++++----------------- tests/test_valid_subset_checks.py | 4 + 5 files changed, 86 insertions(+), 114 deletions(-) create mode 100644 fairseq/benchmark/dummy_dataset.py diff --git a/fairseq/benchmark/__init__.py b/fairseq/benchmark/__init__.py index f6584661bd..0317d5c623 100644 --- a/fairseq/benchmark/__init__.py +++ b/fairseq/benchmark/__init__.py @@ -4,4 +4,4 @@ # LICENSE file in the root directory of this source tree. # import models/tasks to register them -from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa +from .
import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa diff --git a/fairseq/benchmark/dummy_dataset.py b/fairseq/benchmark/dummy_dataset.py new file mode 100644 index 0000000000..2f051754af --- /dev/null +++ b/fairseq/benchmark/dummy_dataset.py @@ -0,0 +1,36 @@ +import numpy as np +from fairseq.data import FairseqDataset + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index d917e28837..c6246a0c0e 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -7,9 +7,9 @@ from dataclasses import dataclass, field from typing import Optional -import numpy as np import torch -from fairseq.data import Dictionary, FairseqDataset +from .dummy_dataset import DummyDataset +from fairseq.data import Dictionary from fairseq.dataclass import FairseqDataclass from fairseq.tasks import FairseqTask, register_task from omegaconf import II @@ -33,7 +33,6 @@ class DummyLMConfig(FairseqDataclass): @register_task("dummy_lm", dataclass=DummyLMConfig) class DummyLMTask(FairseqTask): - def __init__(self, cfg: DummyLMConfig): super().__init__(cfg) @@ -82,37 +81,3 @@ def source_dictionary(self): @property def target_dictionary(self): return self.dictionary - - -class DummyDataset(FairseqDataset): - def __init__(self, batch, num_items, item_size): - super().__init__() - 
self.batch = batch - self.num_items = num_items - self.item_size = item_size - - def __getitem__(self, index): - return index - - def __len__(self): - return self.num_items - - def collater(self, samples): - return self.batch - - @property - def sizes(self): - return np.array([self.item_size] * self.num_items) - - def num_tokens(self, index): - return self.item_size - - def size(self, index): - return self.item_size - - def ordered_indices(self): - return np.arange(self.num_items) - - @property - def supports_prefetch(self): - return False diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index ab506fe1d5..12b9c5d0f5 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -4,43 +4,53 @@ # LICENSE file in the root directory of this source tree. import logging +from dataclasses import dataclass, field +from typing import Optional -import numpy as np import torch -from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import LegacyFairseqTask, register_task +from omegaconf import II +from .dummy_dataset import DummyDataset +from fairseq.data import Dictionary +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task logger = logging.getLogger(__name__) -@register_task("dummy_masked_lm") -class DummyMaskedLMTask(LegacyFairseqTask): - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - parser.add_argument("--dict-size", default=49995, type=int) - parser.add_argument("--dataset-size", default=100000, type=int) - parser.add_argument( - "--tokens-per-sample", - default=512, - type=int, - help="max number of total tokens over all segments " - "per sample for BERT dataset", - ) - - def __init__(self, args, dictionary): - super().__init__(args) - self.dictionary = dictionary - +@dataclass +class DummyMaskedLMConfig(FairseqDataclass): + dict_size: int = 49996 + dataset_size: int = 100000 + 
tokens_per_sample: int = field( + default=512, + metadata={ + "help": "max number of total tokens over all" + " segments per sample for BERT dataset" + }, + ) + batch_size: Optional[int] = II("dataset.batch_size") + max_tokens: Optional[int] = II("dataset.max_tokens") + max_target_positions: int = II("task.tokens_per_sample") + + +@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig) +class DummyMaskedLMTask(FairseqTask): + def __init__(self, cfg: DummyMaskedLMConfig): + super().__init__(cfg) + + self.dictionary = Dictionary() + for i in range(cfg.dict_size): + self.dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(self.dictionary))) # add mask token - self.mask_idx = dictionary.add_symbol("<mask>") - dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + self.mask_idx = self.dictionary.add_symbol("<mask>") + self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8 mask_idx = 0 pad_idx = 1 - seq = torch.arange(args.tokens_per_sample) + pad_idx + 1 - mask = torch.arange(2, args.tokens_per_sample, 7) # ~15% + seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1 + mask = torch.arange(2, cfg.tokens_per_sample, 7) # ~15% src = seq.clone() src[mask] = mask_idx tgt = torch.full_like(seq, pad_idx) @@ -49,39 +59,30 @@ def __init__(self, args, dictionary): self.dummy_src = src self.dummy_tgt = tgt - @classmethod - def setup_task(cls, args, **kwargs): - """Setup the task. """ - dictionary = Dictionary() - for i in range(args.dict_size): - dictionary.add_symbol("word{}".format(i)) - logger.info("dictionary: {} types".format(len(dictionary))) - return cls(args, dictionary) - def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
Args: split (str): name of the split (e.g., train, valid, test) """ - if self.args.batch_size is not None: - bsz = self.args.batch_size + if self.cfg.batch_size is not None: + bsz = self.cfg.batch_size else: - bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) + bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample) self.datasets[split] = DummyDataset( { "id": 1, "net_input": { "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), "src_lengths": torch.full( - (bsz,), self.args.tokens_per_sample, dtype=torch.long + (bsz,), self.cfg.tokens_per_sample, dtype=torch.long ), }, "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), "nsentences": bsz, - "ntokens": bsz * self.args.tokens_per_sample, + "ntokens": bsz * self.cfg.tokens_per_sample, }, - num_items=self.args.dataset_size, - item_size=self.args.tokens_per_sample, + num_items=self.cfg.dataset_size, + item_size=self.cfg.tokens_per_sample, ) @property @@ -91,37 +92,3 @@ def source_dictionary(self): @property def target_dictionary(self): return self.dictionary - - -class DummyDataset(FairseqDataset): - def __init__(self, batch, num_items, item_size): - super().__init__() - self.batch = batch - self.num_items = num_items - self.item_size = item_size - - def __getitem__(self, index): - return index - - def __len__(self): - return self.num_items - - def collater(self, samples): - return self.batch - - @property - def sizes(self): - return np.array([self.item_size] * self.num_items) - - def num_tokens(self, index): - return self.item_size - - def size(self, index): - return self.item_size - - def ordered_indices(self): - return np.arange(self.num_items) - - @property - def supports_prefetch(self): - return False diff --git a/tests/test_valid_subset_checks.py b/tests/test_valid_subset_checks.py index 8da79cfb82..3e9191bda6 100644 --- a/tests/test_valid_subset_checks.py +++ b/tests/test_valid_subset_checks.py @@ -103,6 +103,10 @@ def test_dummy_task(self): cfg = 
make_lm_config(task="dummy_lm") raise_if_valid_subsets_unintentionally_ignored(cfg) + def test_masked_dummy_task(self): + cfg = make_lm_config(task="dummy_masked_lm") + raise_if_valid_subsets_unintentionally_ignored(cfg) + class TestCombineValidSubsets(unittest.TestCase): def _train(self, extra_flags): From f8a7c93440cd925f70979a6082c18f830b39e44b Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 10 Jun 2021 21:57:48 -0700 Subject: [PATCH 609/707] W2v u update (#1954) Summary: updating the scripts and examples to be easier to follow Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1954 Reviewed By: wnhsu Differential Revision: D29041166 Pulled By: alexeib fbshipit-source-id: d9410c6e925337b810e92b393e226869ef9e1733 --- .../kaldi/config/kaldi_initializer.yaml | 8 +++ examples/speech_recognition/w2l_decoder.py | 51 +++++---------- examples/wav2vec/unsupervised/README.md | 60 +++++++++++------ .../wav2vec/unsupervised/config/gan/w2vu.yaml | 7 ++ .../unsupervised/config/generate/viterbi.yaml | 1 - .../wav2vec/unsupervised/models/wav2vec_u.py | 65 +++++++------------ .../wav2vec/unsupervised/scripts/apply_pca.py | 8 ++- .../unsupervised/scripts/g2p_wrd_to_phn.py | 20 +++--- .../wav2vec/unsupervised/scripts/mean_pool.py | 13 +++- .../unsupervised/scripts/merge_clusters.py | 14 ++-- .../scripts/normalize_and_filter_text.py | 29 +++++++-- .../unsupervised/scripts/prepare_audio.sh | 45 +++++++++---- .../unsupervised/scripts/prepare_text.sh | 60 ++++++++++++----- .../unsupervised/scripts/remove_silence.py | 1 - examples/wav2vec/unsupervised/scripts/vads.py | 25 +++++-- .../scripts/wav2vec_apply_cluster_faiss.py | 29 +++++++-- .../scripts/wav2vec_extract_features.py | 8 ++- .../unsupervised/tasks/unpaired_audio_text.py | 34 ++++++---- .../wav2vec/unsupervised/w2vu_generate.py | 3 + fairseq/models/__init__.py | 46 +++++++------ fairseq/models/wav2vec/wav2vec.py | 2 +- fairseq/tasks/__init__.py | 52 ++++++++------- fairseq/utils.py | 
17 +++-- 23 files changed, 372 insertions(+), 226 deletions(-) create mode 100644 examples/speech_recognition/kaldi/config/kaldi_initializer.yaml diff --git a/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml b/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml new file mode 100644 index 0000000000..be9ba98f55 --- /dev/null +++ b/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml @@ -0,0 +1,8 @@ +# @package _group_ + +data_dir: ??? +fst_dir: ??? +in_labels: ??? +kaldi_root: ??? +lm_arpa: ??? +blank_symbol: <s> diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 8b158293a0..aef4481593 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -52,29 +52,19 @@ def __init__(self, args, tgt_dict): self.nbest = args.nbest # criterion-specific init - if args.criterion == "ctc": - self.criterion_type = CriterionType.CTC - self.blank = ( - tgt_dict.index("<ctc_blank>") - if "<ctc_blank>" in tgt_dict.indices - else tgt_dict.bos() - ) - if "<sep>" in tgt_dict.indices: - self.silence = tgt_dict.index("<sep>") - elif "|" in tgt_dict.indices: - self.silence = tgt_dict.index("|") - else: - self.silence = tgt_dict.eos() - self.asg_transitions = None - elif args.criterion == "asg_loss": - self.criterion_type = CriterionType.ASG - self.blank = -1 - self.silence = -1 - self.asg_transitions = args.asg_transitions - self.max_replabel = args.max_replabel - assert len(self.asg_transitions) == self.vocab_size ** 2 + self.criterion_type = CriterionType.CTC + self.blank = ( + tgt_dict.index("<ctc_blank>") + if "<ctc_blank>" in tgt_dict.indices + else tgt_dict.bos() + ) + if "<sep>" in tgt_dict.indices: + self.silence = tgt_dict.index("<sep>") + elif "|" in tgt_dict.indices: + self.silence = tgt_dict.index("|") else: - raise RuntimeError(f"unknown criterion: {args.criterion}") + self.silence = tgt_dict.eos() + self.asg_transitions = None def 
generate(self, models, sample, **unused): """Generate a batch of inferences.""" @@ -90,23 +80,16 @@ def get_emissions(self, models, encoder_input): """Run encoder and normalize emissions""" model = models[0] encoder_out = model(**encoder_input) - if self.criterion_type == CriterionType.CTC: - if hasattr(model, "get_logits"): - emissions = model.get_logits(encoder_out) # no need to normalize emissions - else: - emissions = model.get_normalized_probs(encoder_out, log_probs=True) - elif self.criterion_type == CriterionType.ASG: - emissions = encoder_out["encoder_out"] + if hasattr(model, "get_logits"): + emissions = model.get_logits(encoder_out) # no need to normalize emissions + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) return emissions.transpose(0, 1).float().cpu().contiguous() def get_tokens(self, idxs): """Normalize tokens by handling CTC blank, ASG replabels, etc.""" idxs = (g[0] for g in it.groupby(idxs)) - if self.criterion_type == CriterionType.CTC: - idxs = filter(lambda x: x != self.blank, idxs) - elif self.criterion_type == CriterionType.ASG: - idxs = filter(lambda x: x >= 0, idxs) - idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel) + idxs = filter(lambda x: x != self.blank, idxs) return torch.LongTensor(list(idxs)) diff --git a/examples/wav2vec/unsupervised/README.md b/examples/wav2vec/unsupervised/README.md index e9ec59f06d..c2a935d414 100644 --- a/examples/wav2vec/unsupervised/README.md +++ b/examples/wav2vec/unsupervised/README.md @@ -12,31 +12,41 @@ Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/master/examples In **/path/to/data/with_silence** you need a *train.tsv* file as well as (optionally) *{valid,test}.{tsv,wrd,phn}*. It is nice to have *10h.{tsv,phn}* files there too for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except *.tsv* files contain audios with silences removed using rVAD. 
-Here is how you can create new audio files without silences from a list of input audio files: +Prerequisites: +* set the FAIRSEQ_ROOT environment variable to your fairseq installation +* set the RVAD_ROOT environment variable to a checkout of [rVADfast](https://github.com/zhenghuatan/rVADfast) +* set the KENLM_ROOT environment variable to the location of the [KenLM](https://github.com/kpu/kenlm) binaries +* install [PyKaldi](https://github.com/pykaldi/pykaldi) and set the KALDI_ROOT environment variable to the location of your Kaldi installation. To use the version bundled with PyKaldi, you can use /path/to/pykaldi/tools/kaldi + +Create new audio files without silences: ```shell -python scripts/vads.py < /path/to/train.tsv > train.vads +# create a manifest file for the original set of audio files +python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0 + +python scripts/vads.py -r $RVAD_ROOT < /path/to/train.tsv > train.vads python scripts/remove_silence.py --tsv /path/to/train.tsv --vads train.vads --out /dir/to/save/audio/files -python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0 +python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0.01 ``` -You will need to add the path to rVAD directory to vads.py. - Next, we need to preprocess the audio data to better match phonemized text data: ```shell -zsh scripts/prepare_audio.sh /dir/with/{train,test,valid}.tsv /output/dir /path/to/wav2vec2/model.pt +zsh scripts/prepare_audio.sh /dir/with/{train,test,valid}.tsv /output/dir /path/to/wav2vec2/model.pt 512 14 ``` -Note that if you have splits different than train/valid/test, you will need to modify this script. +Note that if you have splits different than train/valid/test, you will need to modify this script. 
The last two arguments are the PCA dimensionality and the 0-based index of the layer from which to extract representations. Now we need to prepare text data: ```shell -zsh scripts/prepare_text.sh language /path/to/text/file /output/dir +zsh scripts/prepare_text.sh language /path/to/text/file /output/dir 1000 espeak /path/to/fasttext/lid/model ``` -Note that if you want to use a different phonemizer, such as G2P, you will need to modify this script. +The fourth argument is the minimum number of observations a phone needs in order to be kept. If your text corpus is small, you might want to reduce this number. +The fifth argument selects the phonemizer to use. Supported values are [espeak](http://espeak.sourceforge.net/), [espeak-ng](https://github.com/espeak-ng/espeak-ng), and [G2P](https://github.com/Kyubyong/g2p) (English only). + +Pre-trained fasttext LID models (the sixth argument) can be downloaded [here](https://fasttext.cc/docs/en/language-identification.html). ## Generative adversarial training (GAN) @@ -46,26 +56,34 @@ Launching GAN training on top of preprocessed features, with default hyperparame ``` PREFIX=w2v_unsup_gan_xp -TASK_DATA=/path/to/features/unfiltered/precompute_unfiltered_pca512_cls128_mean_pooled -TEXT_DATA=/path/to/data # path to fairseq-preprocessed GAN data -KENLM_PATH=/path/to/data/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here) - -PREFIX=$PREFIX fairseq-hydra-train \ - -m --config-dir configs/gan \ - --config-name w2vu \ - task.data=${TASK_DATA} \ - task.text_data=${TEXT_DATA} \ - task.kenlm_path=${KENLM_PATH} \ - 'common.seed=range(0,5)' & +TASK_DATA=/path/to/features/precompute_unfiltered_pca512_cls128_mean_pooled +TEXT_DATA=/path/to/data/phones # path to fairseq-preprocessed GAN data (phones dir) +KENLM_PATH=/path/to/data/phones/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here) + +PYTHONPATH=$FAIRSEQ_ROOT PREFIX=$PREFIX fairseq-hydra-train \ + -m --config-dir config/gan \ + --config-name w2vu \ + 
task.data=${TASK_DATA} \ + task.text_data=${TEXT_DATA} \ + task.kenlm_path=${KENLM_PATH} \ + common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \ + model.code_penalty=2,4 model.gradient_penalty=1.5,2.0 \ + model.smoothness_weight=0.5,0.75,1.0 'common.seed=range(0,5)' ``` + Once we find the best checkpoint (chosen using unsupervised metric that combined language model perplexity and vocabulary usage), we can use it to generate phone labels (or word labels with an appropriate kaldi WFST): ```shell python w2vu_generate.py --config-dir config/generate --config-name viterbi \ -fairseq.task.data=/path/to/dir/with/tsvs fairseq.common_eval.path=/path/to/gan/checkpoint \ +fairseq.common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \ +fairseq.task.data=/path/to/dir/with/features \ +fairseq.common_eval.path=/path/to/gan/checkpoint \ fairseq.dataset.gen_subset=valid results_path=/where/to/save/transcriptions ``` + +Decoding without an LM works best on the same adjacent-mean-pooled features that the GAN was trained on, while decoding with an LM works better on the features from before the adjacent-timestep mean-pooling step (without the "_pooled" suffix). + ## Iterative self-training + Kaldi LM-decoding After the GAN training provides a first unsupervised model, we can then progressively refine the quality of transcriptions using several iterations of semi-supervised learning. We perform two iterations: first, pseudo-label the training data with the unsupervised GAN model and train an HMM on the pseudo-labels. Second, we relabel the training data with the HMM and then fine-tune the original wav2vec 2.0 model using the HMM pseudo-labels with a CTC loss. Note that HMM models use phonemes as output, while wav2vec 2.0 use letter. Both are decoded using WFST decoders into words. 
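The GAN training command in the README hunk above is launched with `-m` (multirun) and passes comma-separated values plus `range(0,5)` as overrides. As a rough illustration only, assuming Hydra's default sweeper takes the Cartesian product of the swept values, the following Python sketch enumerates the resulting grid. `expand_sweep` is a hypothetical helper written for this note; it is not part of fairseq or Hydra.

```python
from itertools import product

# Sweep values copied from the fairseq-hydra-train command above.
sweep = {
    "model.code_penalty": [2, 4],
    "model.gradient_penalty": [1.5, 2.0],
    "model.smoothness_weight": [0.5, 0.75, 1.0],
    "common.seed": list(range(0, 5)),
}


def expand_sweep(sweep):
    """Hypothetical helper: one override dict per combination of swept values."""
    keys = list(sweep)
    return [dict(zip(keys, values)) for values in product(*(sweep[k] for k in keys))]


jobs = expand_sweep(sweep)
print(len(jobs))  # number of runs in the grid: 2 * 2 * 3 * 5
```

Under that assumption the command launches 60 training runs, so it may be worth trimming the swept values when compute is limited.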
diff --git a/examples/wav2vec/unsupervised/config/gan/w2vu.yaml b/examples/wav2vec/unsupervised/config/gan/w2vu.yaml index d168a11e19..74f1829d14 100644 --- a/examples/wav2vec/unsupervised/config/gan/w2vu.yaml +++ b/examples/wav2vec/unsupervised/config/gan/w2vu.yaml @@ -10,10 +10,15 @@ common: suppress_crashes: false checkpoint: + save_interval: 1000 + save_interval_updates: 1000 no_epoch_checkpoints: true best_checkpoint_metric: weighted_lm_ppl save_dir: . +distributed_training: + distributed_world_size: 1 + task: _name: unpaired_audio_text data: ??? @@ -30,6 +35,8 @@ dataset: batch_size: 160 skip_invalid_size_inputs_valid_test: true valid_subset: valid + validate_interval: 1000 + validate_interval_updates: 1000 criterion: _name: model diff --git a/examples/wav2vec/unsupervised/config/generate/viterbi.yaml b/examples/wav2vec/unsupervised/config/generate/viterbi.yaml index 0f850bb3e7..9c88beebcb 100644 --- a/examples/wav2vec/unsupervised/config/generate/viterbi.yaml +++ b/examples/wav2vec/unsupervised/config/generate/viterbi.yaml @@ -18,5 +18,4 @@ fairseq: batch_size: 1 w2l_decoder: VITERBI -lm_model: ??? 
post_process: silence diff --git a/examples/wav2vec/unsupervised/models/wav2vec_u.py b/examples/wav2vec/unsupervised/models/wav2vec_u.py index d3f195b94e..27792ebda8 100644 --- a/examples/wav2vec/unsupervised/models/wav2vec_u.py +++ b/examples/wav2vec/unsupervised/models/wav2vec_u.py @@ -76,7 +76,6 @@ class Wav2vec_UConfig(FairseqDataclass): hard_gumbel: bool = True temp: Tuple[float, float, float] = (2, 0.1, 0.99995) input_dim: int = 128 - wgan_loss: bool = False segmentation: SegmentationConfig = SegmentationConfig() @@ -397,12 +396,7 @@ def set_num_updates(self, num_updates): ) def discrim_step(self, num_updates): - if num_updates < self.zero_pretrain_updates: - return False - if self.dynamic_step_thresh <= 0 or self.last_acc is None: - return num_updates % 2 == 1 - else: - return self.last_acc < self.dynamic_step_thresh + return num_updates % 2 == 1 def get_groups_for_update(self, num_updates): return "discriminator" if self.discrim_step(num_updates) else "generator" @@ -413,7 +407,6 @@ def __init__(self, cfg: Wav2vec_UConfig, target_dict): self.cfg = cfg self.zero_index = target_dict.index("<SIL>") if "<SIL>" in target_dict else 0 self.smoothness_weight = cfg.smoothness_weight - self.wgan_loss = cfg.wgan_loss output_size = len(target_dict) self.pad = target_dict.pad() @@ -432,7 +425,7 @@ def __init__(self, cfg: Wav2vec_UConfig, target_dict): self.blank_index = target_dict.index("<SIL>") if cfg.blank_is_sil else 0 assert self.blank_index != target_dict.unk() - self.discriminator = self.Discriminator(output_size, cfg) + self.discriminator = Discriminator(output_size, cfg) for p in self.discriminator.parameters(): p.param_group = "discriminator" @@ -441,9 +434,7 @@ def __init__(self, cfg: Wav2vec_UConfig, target_dict): self.segmenter = SEGMENT_FACTORY[cfg.segmentation.type](cfg.segmentation) - self.generator = self.Generator( - d, output_size, cfg, lambda x: self.normalize(x)[0] - ) + self.generator = Generator(d, output_size, cfg) for p in 
self.generator.parameters(): p.param_group = "generator" @@ -589,20 +580,16 @@ def forward( code_pen = None if d_step: - if self.wgan_loss: - loss_dense = dense_y.sum() - loss_token = -1 * token_y.sum() - else: - loss_dense = F.binary_cross_entropy_with_logits( - dense_y, - dense_y.new_ones(dense_y.shape) - fake_smooth, - reduction="sum", - ) - loss_token = F.binary_cross_entropy_with_logits( - token_y, - token_y.new_zeros(token_y.shape) + real_smooth, - reduction="sum", - ) + loss_dense = F.binary_cross_entropy_with_logits( + dense_y, + dense_y.new_ones(dense_y.shape) - fake_smooth, + reduction="sum", + ) + loss_token = F.binary_cross_entropy_with_logits( + token_y, + token_y.new_zeros(token_y.shape) + real_smooth, + reduction="sum", + ) if self.training and self.gradient_penalty > 0: grad_pen = self.calc_gradient_penalty(token_x, dense_x) grad_pen = grad_pen.sum() * self.gradient_penalty @@ -611,23 +598,15 @@ def forward( else: grad_pen = None loss_token = None - if self.update_num >= self.zero_pretrain_updates: - if self.wgan_loss: - loss_dense = -1 * dense_y.sum() - else: - loss_dense = F.binary_cross_entropy_with_logits( - dense_y, - dense_y.new_zeros(dense_y.shape) + fake_smooth, - reduction="sum", - ) - num_vars = dense_x.size(-1) - if prob_perplexity is not None: - code_pen = (num_vars - prob_perplexity) / num_vars - if self.exponential_code_pen: - code_pen = (1 - 1 / code_pen ** 2).exp() - code_pen = code_pen * sample_size * self.code_penalty - else: - loss_dense = None + loss_dense = F.binary_cross_entropy_with_logits( + dense_y, + dense_y.new_zeros(dense_y.shape) + fake_smooth, + reduction="sum", + ) + num_vars = dense_x.size(-1) + if prob_perplexity is not None: + code_pen = (num_vars - prob_perplexity) / num_vars + code_pen = code_pen * sample_size * self.code_penalty if self.smoothness_weight > 0: smoothness_loss = F.mse_loss( diff --git a/examples/wav2vec/unsupervised/scripts/apply_pca.py b/examples/wav2vec/unsupervised/scripts/apply_pca.py index 
0cddd20001..10ad6ce47c 100644 --- a/examples/wav2vec/unsupervised/scripts/apply_pca.py +++ b/examples/wav2vec/unsupervised/scripts/apply_pca.py @@ -50,8 +50,12 @@ def main(): copyfile(source_path + ".tsv", save_path + ".tsv") copyfile(data_poth + ".lengths", save_path + ".lengths") - copyfile(source_path + ".phn", save_path + ".phn") - copyfile(source_path + ".wrd", save_path + ".wrd") + + if osp.exists(source_path + ".phn"): + copyfile(source_path + ".phn", save_path + ".phn") + + if osp.exists(source_path + ".wrd"): + copyfile(source_path + ".wrd", save_path + ".wrd") if osp.exists(save_path + ".npy"): os.remove(save_path + ".npy") diff --git a/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py b/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py index 8c3138e55b..2e31c307bd 100644 --- a/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py +++ b/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py @@ -12,28 +12,32 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument("root_dirs", nargs="*") - parser.add_argument("--insert-silence", "-s", action="store_true") + parser.add_argument( + "--compact", + action="store_true", + help="if set, compacts phones", + ) args = parser.parse_args() - sil = "<s>" + + compact = args.compact wrd_to_phn = {} g2p = G2p() for line in sys.stdin: words = line.strip().split() phones = [] - if args.insert_silence: - phones.append(sil) for w in words: if w not in wrd_to_phn: wrd_to_phn[w] = g2p(w) + if compact: + wrd_to_phn[w] = [ + p[:-1] if p[-1].isnumeric() else p for p in wrd_to_phn[w] + ] phones.extend(wrd_to_phn[w]) - if args.insert_silence: - phones.append(sil) try: print(" ".join(phones)) except: - print(wrd_to_phn, w, phones, file=sys.stderr) + print(wrd_to_phn, words, phones, file=sys.stderr) raise diff --git a/examples/wav2vec/unsupervised/scripts/mean_pool.py b/examples/wav2vec/unsupervised/scripts/mean_pool.py index 1145e774eb..4eea048ef3 100644 --- 
a/examples/wav2vec/unsupervised/scripts/mean_pool.py +++ b/examples/wav2vec/unsupervised/scripts/mean_pool.py @@ -47,8 +47,17 @@ def main(): save_path = osp.join(args.save_dir, args.split) copyfile(source_path + ".tsv", save_path + ".tsv") - copyfile(source_path + ".phn", save_path + ".phn") - copyfile(source_path + ".wrd", save_path + ".wrd") + + if os.path.exists(source_path + ".phn"): + copyfile(source_path + ".phn", save_path + ".phn") + if os.path.exists(source_path + ".wrd"): + copyfile(source_path + ".wrd", save_path + ".wrd") + + if os.path.exists(osp.join(args.source, "dict.phn.txt")): + copyfile( + osp.join(args.source, "dict.phn.txt"), + osp.join(args.save_dir, "dict.phn.txt"), + ) if osp.exists(save_path + ".npy"): os.remove(save_path + ".npy") diff --git a/examples/wav2vec/unsupervised/scripts/merge_clusters.py b/examples/wav2vec/unsupervised/scripts/merge_clusters.py index 6502ed5718..2780f9d971 100644 --- a/examples/wav2vec/unsupervised/scripts/merge_clusters.py +++ b/examples/wav2vec/unsupervised/scripts/merge_clusters.py @@ -62,14 +62,16 @@ def main(): save_path = osp.join(args.save_dir, args.split) copyfile(source_path + ".tsv", save_path + ".tsv") - copyfile(source_path + ".phn", save_path + ".phn") - if os.path.exists(source_path + ".phnsc"): - copyfile(source_path + ".phnsc", save_path + ".phnsc") + + if os.path.exists(source_path + ".phn"): + copyfile(source_path + ".phn", save_path + ".phn") + if os.path.exists(osp.join(args.source, "dict.phn.txt")): copyfile( - osp.join(args.source, "dict.phnsc.txt"), - osp.join(args.save_dir, "dict.phnsc.txt"), + osp.join(args.source, "dict.phn.txt"), + osp.join(args.save_dir, "dict.phn.txt"), ) - copyfile(source_path + ".wrd", save_path + ".wrd") + if os.path.exists(source_path + ".wrd"): + copyfile(source_path + ".wrd", save_path + ".wrd") if osp.exists(save_path + ".npy"): os.remove(save_path + ".npy") diff --git a/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py 
b/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py index 1284747795..c2bd16efb5 100644 --- a/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py +++ b/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py @@ -6,6 +6,7 @@ import argparse import fasttext as ft +import os import regex import sys @@ -39,17 +40,31 @@ def main(): lg_label = f"__label__{lg}" thresh = args.lid_threshold - model = ft.load_model(args.fasttext_model) + if os.path.exists(args.fasttext_model): + model = ft.load_model(args.fasttext_model) + else: + print( + f"fasttext language id model {args.fasttext_model} not found. Proceeding without language filtering. " + f"To enable language filtering, please download the latest language id model " + f"from https://fasttext.cc/docs/en/language-identification.html", + file=sys.stderr, + ) + model = None + for line in sys.stdin: line = line.strip() line = filter_r.sub(" ", line) line = " ".join(line.split()) - lid, prob = model.predict(line, k=100) - try: - target_idx = lid.index(lg_label) - except ValueError: - continue - if target_idx == 0 or prob[target_idx] >= thresh: + + if model is not None: + lid, prob = model.predict(line, k=100) + try: + target_idx = lid.index(lg_label) + except ValueError: + continue + if target_idx == 0 or prob[target_idx] >= thresh: + print(line) + else: print(line) diff --git a/examples/wav2vec/unsupervised/scripts/prepare_audio.sh b/examples/wav2vec/unsupervised/scripts/prepare_audio.sh index 893c9fda1a..013f7a9b05 100644 --- a/examples/wav2vec/unsupervised/scripts/prepare_audio.sh +++ b/examples/wav2vec/unsupervised/scripts/prepare_audio.sh @@ -17,10 +17,31 @@ fi echo "using $dim dim for PCA" +if [ -z "$5" ] + then + layer=14 + else + layer=$5 +fi + +echo "extracting from layer $layer" + train_split=train valid_split=valid test_split=test +all_splits=($train_split) + +if [[ -f "$source_dir/valid.tsv" ]]; then + all_splits+=('valid') +fi + +if [[ -f "$source_dir/test.tsv" ]]; then + 
all_splits+=('test') +fi + +echo "processing splits: $all_splits" + mkdir -p $tgt_dir cp $source_dir/*.tsv $tgt_dir @@ -31,27 +52,27 @@ cp $source_dir/dict* $tgt_dir setopt shwordsplit -for split in $train_split $valid_split $test_split; do - python wav2vec_extract_features.py $source_dir --split $split \ - --save-dir $tgt_dir --checkpoint $model +for split in $all_splits; do + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py $source_dir --split $split \ + --save-dir $tgt_dir --checkpoint $model --layer $layer done -python wav2vec_cluster_faiss.py $tgt_dir/${train_split}.tsv \ +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py $tgt_dir/${train_split}.tsv \ --checkpoint $model --save-dir $tgt_dir -f "CLUS128" --sample-pct 1.0 -for split in $train_split $valid_split $test_split; do - python wav2vec_apply_cluster_faiss.py $tgt_dir \ +for split in $all_splits; do + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py $tgt_dir \ --checkpoint $model --path $tgt_dir/CLUS128 --split $split done -python pca.py $tgt_dir/${train_split}.npy --output $tgt_dir/pca --dim $dim +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/pca.py $tgt_dir/${train_split}.npy --output $tgt_dir/pca --dim $dim -for split in $train_split $valid_split $test_split; do - python apply_pca.py $tgt_dir --split $split --save-dir $tgt_dir/precompute_pca$dim --pca-path $tgt_dir/pca/${dim}_pca --batch-size 1048000 +for split in $all_splits; do + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/apply_pca.py $tgt_dir --split $split --save-dir $tgt_dir/precompute_pca$dim --pca-path $tgt_dir/pca/${dim}_pca --batch-size 1048000 - python merge_clusters.py $tgt_dir/precompute_pca$dim --cluster-dir $tgt_dir/CLUS128 \ + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/merge_clusters.py $tgt_dir/precompute_pca$dim --cluster-dir $tgt_dir/CLUS128 \ --split $split --save-dir 
$tgt_dir/precompute_pca${dim}_cls128_mean --pooling mean - python mean_pool.py $tgt_dir/precompute_pca${dim}_cls128_mean \ + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/mean_pool.py $tgt_dir/precompute_pca${dim}_cls128_mean \ --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean_pooled --split $split -done \ No newline at end of file +done diff --git a/examples/wav2vec/unsupervised/scripts/prepare_text.sh b/examples/wav2vec/unsupervised/scripts/prepare_text.sh index e9090a3d80..1caf13cb6a 100644 --- a/examples/wav2vec/unsupervised/scripts/prepare_text.sh +++ b/examples/wav2vec/unsupervised/scripts/prepare_text.sh @@ -7,6 +7,13 @@ lg=$1 text_path=$2 target_dir=$3 +min_phones=$4 +phonemizer=$5 +lid_path=$6 + +if [ -z "$lid_path" ]; then + lid_path="lid.187.bin" +fi ph_lg=${lg:l} if test "$lg" = 'fr'; then @@ -17,40 +24,59 @@ elif test "$lg" = 'pt'; then ph_lg='pt-br' fi +ESPEAK_PATH='' +if test "$phonemizer" = 'espeak'; then + ESPEAK_PATH=$(which espeak) +elif test "$phonemizer" = 'espeak-ng'; then + ESPEAK_PATH=$(which espeak-ng) +elif test "$phonemizer" = 'G2P'; then + ESPEAK_PATH='' +else + echo "Unknown phonemizer $phonemizer. Valid options are espeak, espeak-ng and G2P" + exit 1 +fi + echo $lg echo $ph_lg echo $text_path echo $target_dir +echo "min phone seen threshold is $min_phones" mkdir -p $target_dir -python normalize_and_filter_text.py --lang $lg < $text_path | grep -v '\-\-\-' >! $target_dir/lm.upper.lid.txt +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py --lang $lg --fasttext-model $lid_path < $text_path | grep -v '\-\-\-' >! $target_dir/lm.upper.lid.txt python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/lm.upper.lid.txt --only-source --destdir $target_dir --thresholdsrc 2 --padding-factor 1 --dict-only cut -f1 -d' ' $target_dir/dict.txt | grep -v -x '[[:punct:]]*' | grep -Pv '\d\d\d\d\d+' >! 
$target_dir/words.txt -one=$(echo "1" | PHONEMIZER_ESPEAK_PATH=$(which espeak) phonemize -p ' ' -w '' -l $ph_lg --language-switch remove-flags) -sed 's/$/ 1/' $target_dir/words.txt | PHONEMIZER_ESPEAK_PATH=$(which espeak) phonemize -o $target_dir/phones.txt -p ' ' -w '' -l $ph_lg -j 70 --language-switch remove-flags -echo "one is ${one}" +if [ -z "$ESPEAK_PATH" ]; then + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py --compact < $target_dir/words.txt > $target_dir/phones.txt +else + # echoing 1 into corpus will prevent the mismatch lines between lexicon and phones in case the phonemizer fails + one=$(echo "1" | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -p ' ' -w '' -l $ph_lg --language-switch remove-flags) + sed 's/$/ 1/' $target_dir/words.txt | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -o $target_dir/phones.txt -p ' ' -w '' -l $ph_lg -j 70 --language-switch remove-flags + echo "one is ${one}" + sed -i "s/${one}$//" $target_dir/phones.txt +fi -sed -i "s/${one}$//" $target_dir/phones.txt paste $target_dir/words.txt $target_dir/phones.txt >! $target_dir/lexicon.lst -python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones.txt --only-source --destdir $target_dir/phones --thresholdsrc 1000 --padding-factor 1 --dict-only +python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones.txt --only-source --destdir $target_dir/phones --thresholdsrc $min_phones --padding-factor 1 --dict-only -python filter_lexicon.py -d $target_dir/phones/dict.txt < $target_dir/lexicon.lst >! $target_dir/lexicon_filtered.lst -python phonemize_with_sil.py -s 0.25 --surround --lexicon $target_dir/lexicon_filtered.lst < $target_dir/lm.upper.lid.txt >! $target_dir/phones/lm.phones.filtered.txt +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/filter_lexicon.py -d $target_dir/phones/dict.txt < $target_dir/lexicon.lst >! 
$target_dir/lexicon_filtered.lst +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py -s 0.25 --surround --lexicon $target_dir/lexicon_filtered.lst < $target_dir/lm.upper.lid.txt >! $target_dir/phones/lm.phones.filtered.txt cp $target_dir/phones/dict.txt $target_dir/phones/dict.phn.txt echo "<SIL> 0" >> $target_dir/phones/dict.phn.txt python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones/lm.phones.filtered.txt --workers 70 --only-source --destdir $target_dir/phones --srcdict $target_dir/phones/dict.phn.txt -lmplz -o 4 < $target_dir/lm.upper.lid.txt --discount_fallback --prune 0 0 0 3 >! $target_dir/kenlm.wrd.o40003.arpa -build_binary $target_dir/kenlm.wrd.o40003.arpa $target_dir/kenlm.wrd.o40003.bin -lg=$lg python examples/speech_recognition/kaldi/kaldi_initializer.py fst_dir=$target_dir/fst/phn_to_words_sil lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones "blank_symbol='<SIL>'" -lg=$lg python examples/speech_recognition/kaldi/kaldi_initializer.py fst_dir=$target_dir/fst/phn_to_words lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones +$KENLM_ROOT/lmplz -o 4 < $target_dir/lm.upper.lid.txt --discount_fallback --prune 0 0 0 3 >! 
$target_dir/kenlm.wrd.o40003.arpa +$KENLM_ROOT/build_binary $target_dir/kenlm.wrd.o40003.arpa $target_dir/kenlm.wrd.o40003.bin + +lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words_sil lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn "blank_symbol='<SIL>'" +lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn -lmplz -o 4 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.04.arpa -build_binary $target_dir/phones/lm.phones.filtered.04.arpa $target_dir/phones/lm.phones.filtered.04.bin -lmplz -o 6 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.06.arpa -build_binary $target_dir/phones/lm.phones.filtered.06.arpa $target_dir/phones/lm.phones.filtered.06.bin +$KENLM_ROOT/lmplz -o 4 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.04.arpa +$KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.04.arpa $target_dir/phones/lm.phones.filtered.04.bin +$KENLM_ROOT/lmplz -o 6 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! 
$target_dir/phones/lm.phones.filtered.06.arpa +$KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.06.arpa $target_dir/phones/lm.phones.filtered.06.bin -lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py fst_dir=$target_dir/fst/phn_to_phn_sil lm_arpa=$target_dir/phones/lm.phones.filtered.06.arpa data_dir=$target_dir/phones "blank_symbol='<SIL>'" +lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_phn_sil lm_arpa=$target_dir/phones/lm.phones.filtered.06.arpa data_dir=$target_dir/phones in_labels=phn "blank_symbol='<SIL>'" diff --git a/examples/wav2vec/unsupervised/scripts/remove_silence.py b/examples/wav2vec/unsupervised/scripts/remove_silence.py index 417b703de8..fac88b9897 100644 --- a/examples/wav2vec/unsupervised/scripts/remove_silence.py +++ b/examples/wav2vec/unsupervised/scripts/remove_silence.py @@ -58,7 +58,6 @@ if not os.path.isdir("/".join(outpath.split("/")[:-1])): os.makedirs("/".join(outpath.split("/")[:-1])) if not os.path.exists(outpath): - print(outpath) torchaudio.save(outpath, data_filtered, sample_rate=16000) else: print(outpath, "exists!") diff --git a/examples/wav2vec/unsupervised/scripts/vads.py b/examples/wav2vec/unsupervised/scripts/vads.py index 1acd95369c..2398da97d8 100644 --- a/examples/wav2vec/unsupervised/scripts/vads.py +++ b/examples/wav2vec/unsupervised/scripts/vads.py @@ -4,10 +4,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import argparse import sys -sys.path.append("/path/to/rVADfast_py_2.0") -import speechproc from copy import deepcopy from scipy.signal import lfilter @@ -17,7 +16,19 @@ import os.path as osp -def rvad(path): +def get_parser(): + parser = argparse.ArgumentParser(description="compute vad segments") + parser.add_argument( + "--rvad-home", + "-r", + help="path to rvad home (see https://github.com/zhenghuatan/rVADfast)", + required=True, + ) + + return parser + + +def rvad(speechproc, path): winlen, ovrlen, pre_coef, nfilter, nftt = 0.025, 0.01, 0.97, 20, 512 ftThres = 0.5 vadThres = 0.4 @@ -56,12 +67,18 @@ def rvad(path): def main(): + parser = get_parser() + args = parser.parse_args() + + sys.path.append(args.rvad_home) + import speechproc + stride = 160 lines = sys.stdin.readlines() root = lines[0].rstrip() for fpath in tqdm(lines[1:]): path = osp.join(root, fpath.split()[0]) - vads, wav = rvad(path) + vads, wav = rvad(speechproc, path) start = None vad_segs = [] diff --git a/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py b/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py index 25bc4e41ac..a5dd7ae6c1 100644 --- a/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py +++ b/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import os import os.path as osp import numpy as np import tqdm @@ -33,13 +34,21 @@ def get_parser(): def get_iterator(args): - with open(osp.join(args.data, f"{args.split}.tsv"), "r") as fp, open( - osp.join(args.data, f"{args.split}.{args.labels}"), "r" - ) as lp: + label_path = osp.join(args.data, f"{args.split}.{args.labels}") + if osp.exists(label_path): + lp = open(label_path, "r") + else: + lp = None + + with open(osp.join(args.data, f"{args.split}.tsv"), "r") as fp: lines = fp.read().split("\n") root = lines.pop(0).strip() files = [line.rstrip() for line in lines if len(line) > 0] - lbls = [line.rstrip() for line in lp] + + if lp is not None: + lbls = [line.rstrip() for line in lp] + else: + lbls = [None] * len(files) num = len(files) reader = Wav2VecFeatureReader(args.checkpoint, args.layer) @@ -87,10 +96,13 @@ def main(): generator, num, root = get_iterator(args) iterator = generator() + had_labels = False + label_path = osp.join(args.path, f"{args.split}.{args.labels}") + with torch.no_grad(): with open(osp.join(args.path, f"{args.split}.src"), "w") as fp, open( osp.join(args.path, f"{args.split}.tsv"), "w" - ) as pp, open(osp.join(args.path, f"{args.split}.{args.labels}"), "w") as lp: + ) as pp, open(label_path, "w") as lp: print(root, file=pp) for f, fname, lbl in tqdm.tqdm(iterator, total=num): if faiss_spec.pca: @@ -104,7 +116,12 @@ def main(): print(" ".join(str(x.item()) for x in z), file=fp) print(fname, file=pp) - print(lbl, file=lp) + + if lbl is not None: + print(lbl, file=lp) + had_labels = True + if not had_labels: + os.remove(label_path) if __name__ == "__main__": diff --git a/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py b/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py index 023dd1aaa5..b07e274d20 100644 --- a/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py +++ b/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py @@ -27,13 +27,14 @@ def get_parser(): 
parser.add_argument('--split', help='which split to read', required=True) parser.add_argument('--save-dir', help='where to save the output', required=True) parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec ctc model', required=True) + parser.add_argument('--layer', type=int, default=14, help='which layer to use') # fmt: on return parser class Wav2VecFeatureReader(object): - def __init__(self, cp_file): + def __init__(self, cp_file, layer): model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( [cp_file] ) @@ -42,6 +43,7 @@ def __init__(self, cp_file): model.cuda() self.model = model self.task = task + self.layer = layer def read_audio(self, fname): """Load an audio file and return PCM along with the sample rate""" @@ -60,7 +62,7 @@ def get_feats(self, loc): source = F.layer_norm(source, source.shape) source = source.view(1, -1) - m_res = self.model(source=source, mask=False, features_only=True, layer=14) + m_res = self.model(source=source, mask=False, features_only=True, layer=self.layer) return m_res["x"].squeeze(0).cpu() @@ -71,7 +73,7 @@ def get_iterator(args): files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0] num = len(files) - reader = Wav2VecFeatureReader(args.checkpoint) + reader = Wav2VecFeatureReader(args.checkpoint, args.layer) def iterate(): for fname in files: diff --git a/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py b/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py index 0b770a1509..5f292528f8 100644 --- a/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py +++ b/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py @@ -83,7 +83,7 @@ class UnpairedAudioTextConfig(FairseqDataclass): decoding_config: DecodingConfig = DecodingConfig() -@register_task("gan_audio_pretraining_feats", dataclass=UnpairedAudioTextConfig) +@register_task("unpaired_audio_text", dataclass=UnpairedAudioTextConfig) class UnpairedAudioText(FairseqTask): """ """ @@ -194,12 
+194,13 @@ def valid_step(self, sample, model, criterion): for i, (x, t, id) in enumerate( zip( z, - sample["target"], + sample["target"] if "target" in sample else [None] * len(z), sample["id"], ) ): - t = t[(t >= self.target_dictionary.nspecial)] + if t is not None: + t = t[(t >= self.target_dictionary.nspecial)] x = x[ (x >= self.target_dictionary.nspecial) & (x < (self.num_symbols + self.target_dictionary.nspecial)) @@ -215,28 +216,37 @@ def valid_step(self, sample, model, criterion): pred_units_arr = pred_units_arr[pred_units_arr != 0] if id == 0: - logger.info(f"REF: {self.target_dictionary.string(t)}") + if t is not None: + logger.info(f"REF: {self.target_dictionary.string(t)}") logger.info(f"HYP: {self.target_dictionary.string(pred_units_arr)}") if self.kenlm is not None: - ref_lm_s = self.compute_lm_score(self.target_dictionary.string(t)) + if t is not None: + ref_lm_s = self.compute_lm_score( + self.target_dictionary.string(t) + ) + logger.info( + f"LM [REF]: {ref_lm_s}, {math.pow(10, -ref_lm_s / (len(t) + 1))}" + ) + hyp_lm_s = self.compute_lm_score( self.target_dictionary.string(pred_units_arr) ) logger.info( - f"LM [REF]: {ref_lm_s}, {math.pow(10, ref_lm_s / (len(t) + 1))}" - ) - logger.info( - f"LM [HYP]: {hyp_lm_s}, {math.pow(10, hyp_lm_s / (len(pred_units_arr) + 1))}" + f"LM [HYP]: {hyp_lm_s}, {math.pow(10, -hyp_lm_s / (len(pred_units_arr) + 1))}" ) pred_units_arr = pred_units_arr.tolist() - t = t.tolist() - c_err += editdistance.eval(pred_units_arr, t) - c_len += len(t) pred_c_len += len(pred_units_arr) + if t is not None: + t = t.tolist() + c_err += editdistance.eval(pred_units_arr, t) + c_len += len(t) + else: + c_len = pred_c_len + if self.kenlm is not None: pred_str = self.target_dictionary.string(pred_units_arr) lm_score = self.compute_lm_score(pred_str) diff --git a/examples/wav2vec/unsupervised/w2vu_generate.py b/examples/wav2vec/unsupervised/w2vu_generate.py index a1bc0ec706..b1e126665f 100644 --- 
a/examples/wav2vec/unsupervised/w2vu_generate.py +++ b/examples/wav2vec/unsupervised/w2vu_generate.py @@ -681,6 +681,9 @@ def hydra_main(cfg): ) OmegaConf.set_struct(cfg, True) logger.info(cfg) + + utils.import_user_module(cfg.fairseq.common) + _, score = main(cfg) if cfg.is_ax: diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 135530d5c0..61425c8ef5 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -195,25 +195,31 @@ def register_model_arch_fn(fn): return register_model_arch_fn +def import_models(models_dir, namespace): + for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + model_name) + + # extra `model_parser` for sphinx + if model_name in MODEL_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_archs = parser.add_argument_group("Named architectures") + group_archs.add_argument( + "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] + ) + group_args = parser.add_argument_group( + "Additional command-line arguments" + ) + MODEL_REGISTRY[model_name].add_args(group_args) + globals()[model_name + "_parser"] = parser + + # automatically import any Python files in the models/ directory models_dir = os.path.dirname(__file__) -for file in os.listdir(models_dir): - path = os.path.join(models_dir, file) - if ( - not file.startswith("_") - and not file.startswith(".") - and (file.endswith(".py") or os.path.isdir(path)) - ): - model_name = file[: file.find(".py")] if file.endswith(".py") else file - module = importlib.import_module("fairseq.models." 
+ model_name) - - # extra `model_parser` for sphinx - if model_name in MODEL_REGISTRY: - parser = argparse.ArgumentParser(add_help=False) - group_archs = parser.add_argument_group("Named architectures") - group_archs.add_argument( - "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] - ) - group_args = parser.add_argument_group("Additional command-line arguments") - MODEL_REGISTRY[model_name].add_args(group_args) - globals()[model_name + "_parser"] = parser +import_models(models_dir, "fairseq.models") diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py index 83b6461129..af6604da10 100644 --- a/fairseq/models/wav2vec/wav2vec.py +++ b/fairseq/models/wav2vec/wav2vec.py @@ -50,7 +50,7 @@ class Wav2VecConfig(FairseqDataclass): default=0, metadata={"help": "num of cross sampled negatives"} ) num_negatives: int = field( - default=10, metadata={"help": "num of cross sampled negatives"} + default=10, metadata={"help": "num of sampled negatives"} ) conv_feature_layers: str = field( default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]", diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 95b4a9647f..79dde74057 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -39,7 +39,9 @@ def setup_task(cfg: FairseqDataclass, **kwargs): cfg = merge_with_parent(dc(), cfg) task = TASK_REGISTRY[task_name] - assert task is not None, f"Could not infer task type from {cfg}. Available tasks: {TASK_REGISTRY.keys()}" + assert ( + task is not None + ), f"Could not infer task type from {cfg}. 
Available tasks: {TASK_REGISTRY.keys()}" return task.setup_task(cfg, **kwargs) @@ -103,26 +105,32 @@ def get_task(name): return TASK_REGISTRY[name] +def import_tasks(tasks_dir, namespace): + for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + task_name) + + # expose `task_parser` for sphinx + if task_name in TASK_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_task = parser.add_argument_group("Task name") + # fmt: off + group_task.add_argument('--task', metavar=task_name, + help='Enable this task with: ``--task=' + task_name + '``') + # fmt: on + group_args = parser.add_argument_group( + "Additional command-line arguments" + ) + TASK_REGISTRY[task_name].add_args(group_args) + globals()[task_name + "_parser"] = parser + + # automatically import any Python files in the tasks/ directory tasks_dir = os.path.dirname(__file__) -for file in os.listdir(tasks_dir): - path = os.path.join(tasks_dir, file) - if ( - not file.startswith("_") - and not file.startswith(".") - and (file.endswith(".py") or os.path.isdir(path)) - ): - task_name = file[: file.find(".py")] if file.endswith(".py") else file - module = importlib.import_module("fairseq.tasks." 
+ task_name) - - # expose `task_parser` for sphinx - if task_name in TASK_REGISTRY: - parser = argparse.ArgumentParser(add_help=False) - group_task = parser.add_argument_group("Task name") - # fmt: off - group_task.add_argument('--task', metavar=task_name, - help='Enable this task with: ``--task=' + task_name + '``') - # fmt: on - group_args = parser.add_argument_group("Additional command-line arguments") - TASK_REGISTRY[task_name].add_args(group_args) - globals()[task_name + "_parser"] = parser +import_tasks(tasks_dir, "fairseq.tasks") diff --git a/fairseq/utils.py b/fairseq/utils.py index bf5727edfd..4fe95b9e8b 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -10,7 +10,6 @@ import logging import os import sys -import tempfile import warnings from itertools import accumulate from typing import Callable, Dict, List, Optional @@ -59,9 +58,7 @@ def __call__(self, parser, namespace, values, option_string=None): def split_paths(paths: str, separator=os.pathsep) -> List[str]: return ( - paths.split(separator) - if "://" not in paths - else paths.split(MANIFOLD_PATH_SEP) + paths.split(separator) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP) ) @@ -483,6 +480,18 @@ def import_user_module(args): if module_name not in sys.modules: sys.path.insert(0, module_parent) importlib.import_module(module_name) + + tasks_path = os.path.join(module_path, "tasks") + if os.path.exists(tasks_path): + from fairseq.tasks import import_tasks + + import_tasks(tasks_path, f"{module_name}.tasks") + + models_path = os.path.join(module_path, "models") + if os.path.exists(models_path): + from fairseq.models import import_models + + import_models(models_path, f"{module_name}.models") else: raise ImportError( "Failed to import --user-dir={} because the corresponding module name " From c36294ea4fd35eac757f417de9668b32c57d4b3d Mon Sep 17 00:00:00 2001 From: Valentin Andrei <vandrei@fb.com> Date: Fri, 11 Jun 2021 17:59:24 -0700 Subject: [PATCH 610/707] Do FP16/BF16 conversions on the 
host to transfer less through PCIe Summary: If we do the FP16/BF16 conversion on the host, we do it at DRAM speed but transfer a 2X smaller buffer to the GPU through PCIe. PCIe bandwidth is an order of magnitude lower so we actually gain about 50% of execution time compared to when performing the quantization on the GPU. Also, by transferring an already FP16 buffer, we save memory capacity. Reviewed By: zhengwy888 Differential Revision: D24146486 fbshipit-source-id: b897e7a32835aa1b571b0fae5f3d72a131ad16a1 --- fairseq/dataclass/configs.py | 7 ++++++ fairseq/trainer.py | 49 ++++++++++++++++++++++-------------- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 70d7476d31..f2dc5b4fd6 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -148,6 +148,13 @@ class CommonConfig(FairseqDataclass): "help": "pct of updates that can overflow before decreasing the loss scale" }, ) + on_cpu_convert_precision: bool = field( + default=False, + metadata={ + "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. " + "This reduces bus transfer time and GPU memory usage."
+ } + ) min_loss_scale: float = field( default=1e-4, metadata={"help": "minimum FP16/AMP loss scale, after which training is stopped"}, diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 64c6fabed6..a55eb0ba3e 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1124,6 +1124,25 @@ def _local_cumulative_training_time(self): """Aggregate training time in seconds.""" return time.time() - self._start_time + self._previous_training_time + def _fp_convert_sample(self, sample): + def apply_half(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.half) + return t + + def apply_bfloat16(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.bfloat16) + return t + + if self.cfg.common.fp16: + sample = utils.apply_to_sample(apply_half, sample) + + if self.cfg.common.bf16: + sample = utils.apply_to_sample(apply_bfloat16, sample) + + return sample + def _prepare_sample(self, sample, is_dummy=False): if sample == "DUMMY": raise Exception( @@ -1139,33 +1158,25 @@ def _prepare_sample(self, sample, is_dummy=False): sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) return sample, True + # Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth + # it makes sense to do the format conversion on the CPU and then transfer + # a smaller buffer to the device. This also saves GPU memory capacity. 
+ if self.cfg.common.on_cpu_convert_precision: + sample = self._fp_convert_sample(sample) + if self.cuda: if self.pipeline_model_parallel: - if "target" in sample: - sample["target"] = utils.move_to_cuda( - sample["target"], device=self.last_device - ) + if 'target' in sample: + sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device) else: sample = utils.move_to_cuda(sample) elif self.tpu and is_dummy: # the dummy batch may not be on the appropriate device sample = utils.move_to_cuda(sample, device=self.device) - def apply_half(t): - if t.dtype is torch.float32: - return t.half() - return t - - def apply_bfloat16(t): - if t.dtype is torch.float32: - return t.to(dtype=torch.bfloat16) - return t - - if self.cfg.common.fp16: - sample = utils.apply_to_sample(apply_half, sample) - - if self.cfg.common.bf16: - sample = utils.apply_to_sample(apply_bfloat16, sample) + if not self.cfg.common.on_cpu_convert_precision: + sample = self._fp_convert_sample(sample) if self._dummy_batch == "DUMMY": self._dummy_batch = sample From 176b2e4e766c94e64f945b9c1323e612d8887d83 Mon Sep 17 00:00:00 2001 From: Henry Hu <henryhu6@fb.com> Date: Mon, 14 Jun 2021 09:09:25 -0700 Subject: [PATCH 611/707] Fix warning for empty tensor without type Summary: Fairseq creates an empty tensor without specifying a dtype, which triggers a warning for TorchScript models. Warning: Creating a tensor from an empty intlist will create a tensor of default floating point type (currently Float) in python but a tensor of type int in torchscript. This diff specifies the dtype explicitly.
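As a minimal illustration of the warning quoted above (not part of this diff, and assuming a standard eager-mode PyTorch install), the sketch below shows that an empty Python list yields a floating point tensor unless a dtype is given. The names `banned`, `implicit`, and `explicit` are hypothetical and only stand in for `banned_tokens[bbsz_idx]`.

```python
import torch

# Hypothetical stand-in for banned_tokens[bbsz_idx] when nothing is banned.
banned = []

# No dtype: eager-mode PyTorch falls back to the default floating point type.
implicit = torch.tensor(banned)

# Explicit dtype: the index tensor is int64 even when the list is empty.
explicit = torch.tensor(banned, dtype=torch.int64)

# Indexing with the empty int64 tensor selects nothing, so lprobs is unchanged.
lprobs = torch.zeros(4)
lprobs[explicit] = float("-inf")
```

Passing the dtype at construction time keeps the eager and scripted behavior aligned, which is presumably why the diff prefers it over a trailing `.long()` cast.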
Reviewed By: myleott Differential Revision: D29081170 fbshipit-source-id: 5c32aae65c9998b245eac43bfedc820bea509338 --- fairseq/ngram_repeat_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/ngram_repeat_block.py b/fairseq/ngram_repeat_block.py index ed2d744635..8541251494 100644 --- a/fairseq/ngram_repeat_block.py +++ b/fairseq/ngram_repeat_block.py @@ -123,7 +123,7 @@ def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): ] for bbsz_idx in range(bsz * beam_size): lprobs[bbsz_idx][ - torch.tensor(banned_tokens[bbsz_idx]).long() + torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64) ] = torch.tensor(-math.inf).to(lprobs) return lprobs From cd5775f30184baa414a354d9f06b747344a8ba74 Mon Sep 17 00:00:00 2001 From: msbaines <35972327+msbaines@users.noreply.github.com> Date: Mon, 14 Jun 2021 16:37:53 -0700 Subject: [PATCH 612/707] avoid freezing batches unnecessarily (#3610) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: In EpochBatchIterator, first_batch() freezes batches in order to generate the dummy_batch. We then freeze batches again in the call to next_epoch_itr(). We can avoid the second freeze and reduce time to first iteration by about 50% in cases where we have a callable batch_sampler. Before: ![Screen Shot 2021-06-10 at 5 08 22 PM](https://user-images.githubusercontent.com/35972327/121613200-d2366600-ca10-11eb-9d1d-bafc2403766a.png) After: ![Screen Shot 2021-06-10 at 5 07 54 PM](https://user-images.githubusercontent.com/35972327/121613224-dfebeb80-ca10-11eb-9d5a-07be9440db77.png) # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). 
## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/pytorch/fairseq/pull/3610 Reviewed By: myleott Differential Revision: D29105845 Pulled By: msbaines fbshipit-source-id: 9795d46d70a99ad1218ce225092cc22ee3192bbc --- fairseq/data/iterators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 8321a49b54..86f6d05533 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -356,6 +356,7 @@ def next_epoch_itr( """ if self.disable_shuffling: shuffle = False + prev_epoch = self.epoch self.epoch = self.next_epoch_idx if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): self.dataset.set_epoch(self.epoch) @@ -363,7 +364,7 @@ def next_epoch_itr( self._cur_epoch_itr = self._next_epoch_itr self._next_epoch_itr = None else: - if callable(self.batch_sampler): + if callable(self.batch_sampler) and prev_epoch != self.epoch: # reset _frozen_batches to refresh the next epoch self._frozen_batches = None self._cur_epoch_itr = self._get_iterator_for_epoch( From 8320f6708ff27537f51a8d3a4c4a0bfbceea71d9 Mon Sep 17 00:00:00 2001 From: Kushal Lakhotia <kushall@fb.com> Date: Tue, 15 Jun 2021 10:49:33 -0700 Subject: [PATCH 613/707] Instructions for loading HuBERT model (#1966) Summary: ## What does this PR do? Fixes the HuBERT README to contain instructions to load pretrained checkpoints. ## PR review Tested in a fresh environment that doesn't have access to FAIR's dev env.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1966 Reviewed By: wnhsu Differential Revision: D29117906 Pulled By: hikushalhere fbshipit-source-id: 89b0407ecf8cdbeddcab80f55e6b2f1fed24c967 --- examples/hubert/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/hubert/README.md b/examples/hubert/README.md index c0b1125cb5..3254b754f0 100644 --- a/examples/hubert/README.md +++ b/examples/hubert/README.md @@ -9,6 +9,13 @@ HuBERT Extra Large (~1B params) | [Libri-Light](https://github.com/facebookresea HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k_finetune_ls960.pt) HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k_finetune_ls960.pt) +## Load a pretrained model +``` +ckpt_path = "/path/to/the/checkpoint.pt" +models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path], strict=False) +model = models[0] +``` +** We will follow-up with a patch such that you wouldn't need to pass `strict=False` for loading the checkpoint in future. ## Train a new model From 128b4fc3789338d782c1ae4a27c5f0d6fa6dfed0 Mon Sep 17 00:00:00 2001 From: Kushal Lakhotia <kushall@fb.com> Date: Tue, 15 Jun 2021 14:09:01 -0700 Subject: [PATCH 614/707] Check attributes in trainer and checkpoint loading before using them (#1970) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ## What does this PR do? Fixes a None exception when some attributes don't exist in cfg. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun?
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1970 Reviewed By: alexeib Differential Revision: D29140036 Pulled By: hikushalhere fbshipit-source-id: 7d941bcae6bb000c281a43ca2cd0876a49912ab9 --- fairseq/checkpoint_utils.py | 2 ++ fairseq/dataclass/configs.py | 6 ++++++ fairseq/trainer.py | 5 ++++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 80c797bcdd..627f14160d 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -647,6 +647,8 @@ def _upgrade_state_dict(state): and ( hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args ) + and hasattr(cfg.model.w2v_args.task, "eval_wer_config") + and cfg.model.w2v_args.task.eval_wer_config is not None and isinstance( cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool ) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index f2dc5b4fd6..b0146fa4c7 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -229,6 +229,12 @@ class DistributedTrainingConfig(FairseqDataclass): "help": "total number of GPUs across all nodes (default: all visible GPUs)" }, ) + distributed_num_procs: Optional[int] = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "total number of processes to fork (default: all visible GPUs)" + }, + ) distributed_rank: Optional[int] = field( default=0, metadata={"help": "rank of the current worker"} ) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index a55eb0ba3e..1deb14326f 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -73,7 +73,10 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): "option (it's already built in)" ) else: - if self.cfg.distributed_training.cpu_offload: + if ( + hasattr(self.cfg.distributed_training, "cpu_offload") + and self.cfg.distributed_training.cpu_offload + ): raise ValueError("--cpu-offload
requires --ddp-backend=fully_sharded") # copy model and criterion to current device/dtype From afc77bdf4bb51453ce76f1572ef2ee6ddcda8eeb Mon Sep 17 00:00:00 2001 From: Neeyanth Kopparapu <neeyanth@fb.com> Date: Tue, 15 Jun 2021 22:01:55 -0700 Subject: [PATCH 615/707] Enabled storing of dictionaries (#3601) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ## What does this PR do? For HubertPretrainingTask, added dictionaries to the task state to enable the serialization of the dictionaries (thus removing the need to load from the disk after training) ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3601 Test Plan: To verify the success, run the Hubert Pretraining pipeline, load a checkpoint model, and verify that the "dictionaries" key is present in the state within the model. Specifically, ``` PYTHONPATH=. python /path/to/fairseq/fairseq_cli/hydra_train.py -m \ --config-dir ${fairseq_dir}/examples/hubert/config/pretrain \ --config-name hubert_base_librispeech \ hydra/launcher=submitit_local \ hydra.launcher.gpus_per_node=2 \ hydra.launcher.cpus_per_task=8 \ hydra.launcher.mem_gb=384 \ task.data=${tsv_dir} \ task.label_dir=${km_dir} \ task.labels=["km"] \ +data=iter1 \ optimization.max_update=250 \ hydra.sweep.dir=${exp_dir} \ hydra.run.dir=${exp_dir} > ${exp_dir}/log.out 2> ${exp_dir}/log.err & ``` Then, at the location of the model, load the model using `torch.load`, and verify that "dictionaries" is a key under the `task_state` key of the model. ## Did you have fun?
Make sure you had fun coding Reviewed By: wnhsu Differential Revision: D28995537 Pulled By: neeyanthkvk fbshipit-source-id: e10c5163c367285518961b3ce1e719a29da06aa6 --- examples/hubert/config/finetune/base_10h.yaml | 1 + fairseq/tasks/hubert_pretraining.py | 42 ++++++++++--------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/examples/hubert/config/finetune/base_10h.yaml b/examples/hubert/config/finetune/base_10h.yaml index 844484d7fb..a22c7c0347 100644 --- a/examples/hubert/config/finetune/base_10h.yaml +++ b/examples/hubert/config/finetune/base_10h.yaml @@ -23,6 +23,7 @@ distributed_training: task: _name: hubert_pretraining data: ??? + fine_tuning: true label_dir: ??? normalize: false # must be consistent with pre-training labels: ["ltr"] diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py index aff4100bb8..a63f2f6ef8 100644 --- a/fairseq/tasks/hubert_pretraining.py +++ b/fairseq/tasks/hubert_pretraining.py @@ -37,6 +37,9 @@ class HubertPretrainingConfig(FairseqDataclass): data: str = field( default=MISSING, metadata={"help": "path to data directory"} ) + fine_tuning: bool = field( + default=False, metadata={"help": "set to true if fine-tuning Hubert"} + ) labels: List[str] = field( default_factory=lambda: ["ltr"], metadata={ @@ -56,7 +59,6 @@ class HubertPretrainingConfig(FairseqDataclass): default=-1, metadata={"help": "label frame rate.
-1 for sequence label"}, ) - sample_rate: int = field( default=16_000, metadata={ @@ -107,19 +109,22 @@ class HubertPretrainingTask(FairseqTask): def __init__( self, cfg: HubertPretrainingConfig, - dictionaries: Dict[str, Dictionary], ) -> None: super().__init__(cfg) logger.info(f"current directory is {os.getcwd()}") logger.info(f"HubertPretrainingTask Config {cfg}") - self._dictionaries = dictionaries + + self.cfg = cfg + self.fine_tuning = cfg.fine_tuning + + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", lambda: self.load_dictionaries) + else: + self.state.add_factory("dictionaries", lambda: self.load_dictionaries) self._source_dictionary = None - self._target_dictionary = None - if len(self.dictionaries) == 1: - self._target_dictionary = self.dictionaries[0] self.blank_symbol = "<s>" @property @@ -128,24 +133,22 @@ def source_dictionary(self) -> Optional[Dictionary]: @property def target_dictionary(self) -> Optional[Dictionary]: - return self._target_dictionary + return self.state.target_dictionary @property def dictionaries(self) -> List[Dictionary]: - return [self._dictionaries[l] for l in self.cfg.labels] + return self.state.dictionaries @classmethod def setup_task( cls, cfg: HubertPretrainingConfig, **kwargs ) -> "HubertPretrainingTask": - label_dir = cfg.data if cfg.label_dir is None else cfg.label_dir - dictionaries = { - label: Dictionary.load(f"{label_dir}/dict.{label}.txt") - if os.path.exists(f"{label_dir}/dict.{label}.txt") - else None - for label in cfg.labels - } - return cls(cfg, dictionaries) + return cls(cfg) + + def load_dictionaries(self): + label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir + dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries def get_label_dir(self) -> str: if self.cfg.label_dir is None: @@ -154,9 +157,10 @@ def get_label_dir(self) -> str: def load_dataset(self, split: 
str, **kwargs) -> None: manifest = f"{self.cfg.data}/{split}.tsv" - pad_list = [self._dictionaries[l].pad() for l in self.cfg.labels] - eos_list = [self._dictionaries[l].eos() for l in self.cfg.labels] - procs = [LabelEncoder(self._dictionaries[l]) for l in self.cfg.labels] + dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries + pad_list = [dict.pad() for dict in dicts] + eos_list = [dict.eos() for dict in dicts] + procs = [LabelEncoder(dict) for dict in dicts] paths = [ f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels ] From 67138ceb08b02c3148034e445cf5e76a616fc859 Mon Sep 17 00:00:00 2001 From: Vimal Manohar <vimalmanohar@fb.com> Date: Thu, 17 Jun 2021 13:19:15 -0700 Subject: [PATCH 616/707] Fix lr for reduce_lr_on_plateau when there is no warmup Summary: warmup_init_lr should not be used if there is no warmup i.e. warmup_updates = 0 Created from Diffusion's 'Open in Editor' feature. Reviewed By: myleott Differential Revision: D29174059 fbshipit-source-id: c2e4cf998aebcff090584e689f692a0abe082e65 --- fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 6e29ba79b6..5ee9c1be4a 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -101,7 +101,7 @@ def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer): # initial learning rate # this self.lr is used only during init and/or warm up period - self.lr = cfg.warmup_init_lr + self.lr = warmup_end_lr if self.warmup_end else cfg.warmup_init_lr self.optimizer.set_lr(self.lr) def state_dict(self): From b3491ae9d4c3eaa24292512cee7c21def713c535 Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Thu, 17 Jun 2021 13:58:56 -0700 Subject: [PATCH 617/707] Add latency metrics to simulate tuna inference script and some 
other minor updates Summary: - Add average lagging latency metrics for online model. Offline models by default return 0 - Pad smaller input chunks with 0. - Enable export option in layer norm in transformer.py to avoid errors in scripted model inference. - Warm up prediction for scripted online model - Add additional args like force_read_cnt, data_split Reviewed By: xutaima Differential Revision: D28881594 fbshipit-source-id: fd4cce017539b5d8f6e39f9af9651341e47d6db0 --- fairseq/models/transformer.py | 4 +++- fairseq/modules/transformer_layer.py | 15 ++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f4f6bea27b..2562752af9 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -725,7 +725,9 @@ def __init__( if args.decoder_normalize_before and not getattr( args, "no_decoder_final_norm", False ): - self.layer_norm = LayerNorm(embed_dim, export=export) + self.layer_norm = LayerNorm( + embed_dim, export=getattr(args, "char_inputs", False) + ) else: self.layer_norm = None diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 4f9ea22a9b..79af17fe30 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -33,8 +33,8 @@ def __init__(self, args): super().__init__() self.args = args self.embed_dim = args.encoder_embed_dim - self.quant_noise = getattr(args, 'quant_noise_pq', 0) - self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 + self.quant_noise = getattr(args, "quant_noise_pq", 0) + self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) or 8 self.self_attn = self.build_self_attention(self.embed_dim, args) export = getattr(args, "export", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) @@ -42,7 +42,7 @@ def __init__(self, args): args.dropout, module_name=self.__class__.__name__ ) 
self.activation_fn = utils.get_activation_fn( - activation=getattr(args, 'activation_fn', 'relu') or "relu" + activation=getattr(args, "activation_fn", "relu") or "relu" ) activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 if activation_dropout_p == 0: @@ -104,7 +104,12 @@ def upgrade_state_dict_named(self, state_dict, name): state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] del state_dict[k] - def forward(self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor] = None): + def forward( + self, + x, + encoder_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor] = None, + ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` @@ -208,7 +213,7 @@ def __init__( ) self.normalize_before = args.decoder_normalize_before - export = getattr(args, "export", False) + export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: From fc77eeb550e8ae74a14b36e6b9babc129f896725 Mon Sep 17 00:00:00 2001 From: Sravya Popuri <spopuri@fb.com> Date: Fri, 18 Jun 2021 13:47:20 -0700 Subject: [PATCH 618/707] Change char_inputs to export as recommended in Fairseq Summary: TSIA Reviewed By: jmp84, henryhu6 Differential Revision: D29232406 fbshipit-source-id: 557006705faf28d723dc9f0ed9e92b0abe68e895 --- fairseq/models/transformer.py | 4 +--- fairseq/modules/transformer_layer.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 2562752af9..f4f6bea27b 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -725,9 +725,7 @@ def __init__( if args.decoder_normalize_before and not getattr( args, "no_decoder_final_norm", False ): - self.layer_norm = LayerNorm( - embed_dim, export=getattr(args, "char_inputs", False) - ) + self.layer_norm = LayerNorm(embed_dim, export=export) else: self.layer_norm = None diff --git 
a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 79af17fe30..aa06a42935 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -213,7 +213,7 @@ def __init__( ) self.normalize_before = args.decoder_normalize_before - export = getattr(args, "char_inputs", False) + export = getattr(args, "export", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: From 822442e42a020bac1ddaaaf4a99124db8b0cfab0 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Mon, 21 Jun 2021 11:05:15 -0700 Subject: [PATCH 619/707] fix task name in w2v-u generate (#1989) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1989 Reviewed By: arbabu123 Differential Revision: D29267899 Pulled By: alexeib fbshipit-source-id: b89b804c14dbf8779b5cb56657d33bb03530f303 --- examples/wav2vec/unsupervised/w2vu_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wav2vec/unsupervised/w2vu_generate.py b/examples/wav2vec/unsupervised/w2vu_generate.py index b1e126665f..2bad873616 100644 --- a/examples/wav2vec/unsupervised/w2vu_generate.py +++ b/examples/wav2vec/unsupervised/w2vu_generate.py @@ -577,7 +577,7 @@ def main(cfg: UnsupGenerateConfig, model=None): overrides = ast.literal_eval(cfg.fairseq.common_eval.model_overrides) - if cfg.fairseq.task._name == "gan_audio_pretraining_feats": + if cfg.fairseq.task._name == "unpaired_audio_text": overrides["model"] = { "blank_weight": cfg.blank_weight, "blank_mode": cfg.blank_mode, From e47a4c84da877382c4d37941be4e84449f99d186 Mon Sep 17 00:00:00 2001 From: Neeyanth Kopparapu <neeyanth@fb.com> Date: Mon, 21 Jun 2021 12:32:19 -0700 Subject: [PATCH 620/707] hotfix to change factory creation for dictionaries (#1987) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ## What does this PR do? 
Fixes issue of creating factories causing errors because the lambda function is not proper. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1987 Test Plan: Completed a small pretraining+finetuning procedure: Pretraining: ``` PYTHONPATH=. python /private/home/neeyanth/project/fairseq/fairseq_cli/hydra_train.py -m \ --config-dir ${fairseq_dir}/examples/hubert/config/pretrain \ --config-name hubert_base_librispeech \ hydra/launcher=submitit_local \ hydra.launcher.gpus_per_node=1 \ hydra.launcher.cpus_per_task=8 \ hydra.launcher.mem_gb=384 \ task.data=${tsv_dir} \ task.label_dir=${km_dir} \ task.labels=["km"] \ +data=iter1 \ optimization.max_update=250 \ hydra.sweep.dir=${exp_dir} \ hydra.run.dir=${exp_dir} > ${exp_dir}/log.out 2> ${exp_dir}/log.err & ``` Finetuning: ``` PYTHONPATH=. 
python /private/home/neeyanth/project/fairseq/fairseq_cli/hydra_train.py -m \ --config-dir ${fairseq_dir}/examples/hubert/config/finetune \ --config-name base_10h \ hydra/launcher=submitit_local \ hydra.launcher.gpus_per_node=1 \ hydra.launcher.cpus_per_task=8 \ hydra.launcher.mem_gb=384 \ task.data=${tsv_dir} \ task.label_dir=${tsv_dir} \ model.w2v_path=${model_dir} \ +data=iter1 \ optimization.max_update=250 \ hydra.sweep.dir=${exp_dir} \ hydra.run.dir=${exp_dir} > ${exp_dir}/log.out 2> ${exp_dir}/log.err & ``` Reviewed By: hikushalhere Differential Revision: D29266136 Pulled By: neeyanthkvk fbshipit-source-id: d36c668ae38a7761b4c44f4dcb0c4cc8e15e42ce --- fairseq/tasks/hubert_pretraining.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py index a63f2f6ef8..ee3fedce3f 100644 --- a/fairseq/tasks/hubert_pretraining.py +++ b/fairseq/tasks/hubert_pretraining.py @@ -119,9 +119,9 @@ def __init__( self.fine_tuning = cfg.fine_tuning if cfg.fine_tuning: - self.state.add_factory("target_dictionary", lambda: self.load_dictionaries) + self.state.add_factory("target_dictionary", self.load_dictionaries) else: - self.state.add_factory("dictionaries", lambda: self.load_dictionaries) + self.state.add_factory("dictionaries", self.load_dictionaries) self._source_dictionary = None From 900a607ea3226e0cfd966a894b8b1effe25faa5e Mon Sep 17 00:00:00 2001 From: Wei-Ning Hsu <wnhsu@csail.mit.edu> Date: Mon, 21 Jun 2021 19:40:02 -0700 Subject: [PATCH 621/707] add timit w2vu recipe (#1991) Summary: ## What does this PR do? 
Add TIMIT data preparation scripts for wav2vec-U Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1991 Reviewed By: alexeib Differential Revision: D29284481 Pulled By: wnhsu fbshipit-source-id: dccd75159a9de4f3cd95f9e4a90ce4bdf9264f2b --- examples/wav2vec/unsupervised/README.md | 10 + .../config/timit_matched/test.uid | 192 + .../config/timit_matched/train.uid | 3696 +++++++++++++++++ .../config/timit_matched/train_text.uid | 3696 +++++++++++++++++ .../config/timit_matched/valid.uid | 400 ++ .../config/timit_unmatched/test.uid | 1680 ++++++++ .../config/timit_unmatched/train.uid | 3000 +++++++++++++ .../config/timit_unmatched/train_text.uid | 1000 +++++ .../config/timit_unmatched/valid.uid | 620 +++ .../unsupervised/scripts/prepare_timit.sh | 79 + 10 files changed, 14373 insertions(+) create mode 100644 examples/wav2vec/unsupervised/config/timit_matched/test.uid create mode 100644 examples/wav2vec/unsupervised/config/timit_matched/train.uid create mode 100644 examples/wav2vec/unsupervised/config/timit_matched/train_text.uid create mode 100644 examples/wav2vec/unsupervised/config/timit_matched/valid.uid create mode 100644 examples/wav2vec/unsupervised/config/timit_unmatched/test.uid create mode 100644 examples/wav2vec/unsupervised/config/timit_unmatched/train.uid create mode 100644 examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid create mode 100644 examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid create mode 100644 examples/wav2vec/unsupervised/scripts/prepare_timit.sh diff --git a/examples/wav2vec/unsupervised/README.md b/examples/wav2vec/unsupervised/README.md index c2a935d414..046202e01c 100644 --- a/examples/wav2vec/unsupervised/README.md +++ b/examples/wav2vec/unsupervised/README.md @@ -48,6 +48,16 @@ The fifth argument is which phonemizer to use. Supported values are [espeak](htt Pre-trained fasttext LID models can be downloaded [here](https://fasttext.cc/docs/en/language-identification.html). 
+### Prepare TIMIT data +TIMIT transcripts include silence. Therefore VAD is not used for audio preprocessing, and we do not wrap transcripts with silences or insert random silence in between words. + +To prepare TIMIT data for both the matched and unmatched setup: +```shell +bash scripts/prepare_timit.sh /dir/to/timit/raw/data /output/dir /path/to/wav2vec2/model.pt +``` + +Note that we assume the TIMIT distribution with capitalized directories and filenames is used (e.g., `TRAIN/DR1/FCJF0/SA1.PHN`). + ## Generative adversarial training (GAN) We then use a GAN model to build a first unsupervised ASR model. The data preparation above of both speech features and text data is a necessary procedure that enables the generator to match speech to text in an unsupervised way. diff --git a/examples/wav2vec/unsupervised/config/timit_matched/test.uid b/examples/wav2vec/unsupervised/config/timit_matched/test.uid new file mode 100644 index 0000000000..401008246a --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_matched/test.uid @@ -0,0 +1,192 @@ +FDHC0_SI1559 +FDHC0_SI2189 +FDHC0_SI929 +FDHC0_SX119 +FDHC0_SX209 +FDHC0_SX29 +FDHC0_SX299 +FDHC0_SX389 +FELC0_SI1386 +FELC0_SI2016 +FELC0_SI756 +FELC0_SX126 +FELC0_SX216 +FELC0_SX306 +FELC0_SX36 +FELC0_SX396 +FJLM0_SI1043 +FJLM0_SI1673 +FJLM0_SI2303 +FJLM0_SX143 +FJLM0_SX233 +FJLM0_SX323 +FJLM0_SX413 +FJLM0_SX53 +FMGD0_SI1564 +FMGD0_SI2194 +FMGD0_SI934 +FMGD0_SX124 +FMGD0_SX214 +FMGD0_SX304 +FMGD0_SX34 +FMGD0_SX394 +FMLD0_SI2185 +FMLD0_SI822 +FMLD0_SI925 +FMLD0_SX115 +FMLD0_SX205 +FMLD0_SX25 +FMLD0_SX295 +FMLD0_SX385 +FNLP0_SI1308 +FNLP0_SI1938 +FNLP0_SI678 +FNLP0_SX138 +FNLP0_SX228 +FNLP0_SX318 +FNLP0_SX408 +FNLP0_SX48 +FPAS0_SI1272 +FPAS0_SI2204 +FPAS0_SI944 +FPAS0_SX134 +FPAS0_SX224 +FPAS0_SX314 +FPAS0_SX404 +FPAS0_SX44 +FPKT0_SI1538 +FPKT0_SI2168 +FPKT0_SI908 +FPKT0_SX188 +FPKT0_SX278 +FPKT0_SX368 +FPKT0_SX8 +FPKT0_SX98 +MBPM0_SI1577 +MBPM0_SI1584 +MBPM0_SI947 +MBPM0_SX137 +MBPM0_SX227 +MBPM0_SX317 +MBPM0_SX407 
+MBPM0_SX47 +MCMJ0_SI1094 +MCMJ0_SI464 +MCMJ0_SI602 +MCMJ0_SX104 +MCMJ0_SX14 +MCMJ0_SX194 +MCMJ0_SX284 +MCMJ0_SX374 +MDAB0_SI1039 +MDAB0_SI1669 +MDAB0_SI2299 +MDAB0_SX139 +MDAB0_SX229 +MDAB0_SX319 +MDAB0_SX409 +MDAB0_SX49 +MGRT0_SI1450 +MGRT0_SI2080 +MGRT0_SI820 +MGRT0_SX10 +MGRT0_SX100 +MGRT0_SX190 +MGRT0_SX280 +MGRT0_SX370 +MJDH0_SI1354 +MJDH0_SI1984 +MJDH0_SI724 +MJDH0_SX184 +MJDH0_SX274 +MJDH0_SX364 +MJDH0_SX4 +MJDH0_SX94 +MJLN0_SI1449 +MJLN0_SI2079 +MJLN0_SI819 +MJLN0_SX189 +MJLN0_SX279 +MJLN0_SX369 +MJLN0_SX9 +MJLN0_SX99 +MJMP0_SI1535 +MJMP0_SI1791 +MJMP0_SI905 +MJMP0_SX185 +MJMP0_SX275 +MJMP0_SX365 +MJMP0_SX5 +MJMP0_SX95 +MKLT0_SI1213 +MKLT0_SI1843 +MKLT0_SI583 +MKLT0_SX133 +MKLT0_SX223 +MKLT0_SX313 +MKLT0_SX403 +MKLT0_SX43 +MLLL0_SI1363 +MLLL0_SI1993 +MLLL0_SI733 +MLLL0_SX103 +MLLL0_SX13 +MLLL0_SX193 +MLLL0_SX283 +MLLL0_SX373 +MLNT0_SI1574 +MLNT0_SI1902 +MLNT0_SI642 +MLNT0_SX102 +MLNT0_SX12 +MLNT0_SX192 +MLNT0_SX282 +MLNT0_SX372 +MNJM0_SI1580 +MNJM0_SI2210 +MNJM0_SI950 +MNJM0_SX140 +MNJM0_SX230 +MNJM0_SX320 +MNJM0_SX410 +MNJM0_SX50 +MPAM0_SI1189 +MPAM0_SI1819 +MPAM0_SI1961 +MPAM0_SX109 +MPAM0_SX19 +MPAM0_SX199 +MPAM0_SX289 +MPAM0_SX379 +MTAS1_SI1473 +MTAS1_SI2098 +MTAS1_SI838 +MTAS1_SX118 +MTAS1_SX208 +MTAS1_SX28 +MTAS1_SX298 +MTAS1_SX388 +MTLS0_SI1370 +MTLS0_SI2000 +MTLS0_SI740 +MTLS0_SX110 +MTLS0_SX20 +MTLS0_SX200 +MTLS0_SX290 +MTLS0_SX380 +MWBT0_SI1553 +MWBT0_SI2183 +MWBT0_SI923 +MWBT0_SX113 +MWBT0_SX203 +MWBT0_SX23 +MWBT0_SX293 +MWBT0_SX383 +MWEW0_SI1361 +MWEW0_SI1991 +MWEW0_SI731 +MWEW0_SX101 +MWEW0_SX11 +MWEW0_SX191 +MWEW0_SX281 +MWEW0_SX371 diff --git a/examples/wav2vec/unsupervised/config/timit_matched/train.uid b/examples/wav2vec/unsupervised/config/timit_matched/train.uid new file mode 100644 index 0000000000..c39fd0b91d --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_matched/train.uid @@ -0,0 +1,3696 @@ +FAEM0_SI1392 +FAEM0_SI2022 +FAEM0_SI762 +FAEM0_SX132 +FAEM0_SX222 +FAEM0_SX312 +FAEM0_SX402 +FAEM0_SX42 +FAJW0_SI1263 
+FAJW0_SI1893 +FAJW0_SI633 +FAJW0_SX183 +FAJW0_SX273 +FAJW0_SX3 +FAJW0_SX363 +FAJW0_SX93 +FALK0_SI1086 +FALK0_SI456 +FALK0_SI658 +FALK0_SX186 +FALK0_SX276 +FALK0_SX366 +FALK0_SX6 +FALK0_SX96 +FALR0_SI1325 +FALR0_SI1955 +FALR0_SI695 +FALR0_SX155 +FALR0_SX245 +FALR0_SX335 +FALR0_SX425 +FALR0_SX65 +FAPB0_SI1063 +FAPB0_SI1693 +FAPB0_SI2323 +FAPB0_SX163 +FAPB0_SX253 +FAPB0_SX343 +FAPB0_SX433 +FAPB0_SX73 +FBAS0_SI1387 +FBAS0_SI1472 +FBAS0_SI2066 +FBAS0_SX127 +FBAS0_SX217 +FBAS0_SX307 +FBAS0_SX37 +FBAS0_SX397 +FBCG1_SI1612 +FBCG1_SI2242 +FBCG1_SI982 +FBCG1_SX172 +FBCG1_SX262 +FBCG1_SX352 +FBCG1_SX442 +FBCG1_SX82 +FBCH0_SI1586 +FBCH0_SI956 +FBCH0_SI959 +FBCH0_SX146 +FBCH0_SX236 +FBCH0_SX326 +FBCH0_SX416 +FBCH0_SX56 +FBJL0_SI1552 +FBJL0_SI2182 +FBJL0_SI922 +FBJL0_SX112 +FBJL0_SX202 +FBJL0_SX22 +FBJL0_SX292 +FBJL0_SX382 +FBLV0_SI1058 +FBLV0_SI1688 +FBLV0_SI2318 +FBLV0_SX158 +FBLV0_SX248 +FBLV0_SX338 +FBLV0_SX428 +FBLV0_SX68 +FBMH0_SI1136 +FBMH0_SI1766 +FBMH0_SI970 +FBMH0_SX146 +FBMH0_SX236 +FBMH0_SX326 +FBMH0_SX416 +FBMH0_SX56 +FBMJ0_SI1776 +FBMJ0_SI516 +FBMJ0_SI815 +FBMJ0_SX156 +FBMJ0_SX246 +FBMJ0_SX336 +FBMJ0_SX426 +FBMJ0_SX66 +FCAG0_SI1503 +FCAG0_SI1641 +FCAG0_SI2133 +FCAG0_SX153 +FCAG0_SX243 +FCAG0_SX333 +FCAG0_SX423 +FCAG0_SX63 +FCAJ0_SI1479 +FCAJ0_SI1804 +FCAJ0_SI849 +FCAJ0_SX129 +FCAJ0_SX219 +FCAJ0_SX309 +FCAJ0_SX39 +FCAJ0_SX399 +FCDR1_SI1186 +FCDR1_SI1816 +FCDR1_SI556 +FCDR1_SX106 +FCDR1_SX16 +FCDR1_SX196 +FCDR1_SX286 +FCDR1_SX376 +FCEG0_SI1248 +FCEG0_SI1878 +FCEG0_SI618 +FCEG0_SX168 +FCEG0_SX258 +FCEG0_SX348 +FCEG0_SX438 +FCEG0_SX78 +FCJF0_SI1027 +FCJF0_SI1657 +FCJF0_SI648 +FCJF0_SX127 +FCJF0_SX217 +FCJF0_SX307 +FCJF0_SX37 +FCJF0_SX397 +FCJS0_SI1607 +FCJS0_SI2237 +FCJS0_SI977 +FCJS0_SX167 +FCJS0_SX257 +FCJS0_SX347 +FCJS0_SX437 +FCJS0_SX77 +FCKE0_SI1111 +FCKE0_SI1741 +FCKE0_SI481 +FCKE0_SX121 +FCKE0_SX211 +FCKE0_SX301 +FCKE0_SX31 +FCKE0_SX391 +FCLT0_SI1438 +FCLT0_SI2068 +FCLT0_SI808 +FCLT0_SX178 +FCLT0_SX268 +FCLT0_SX358 +FCLT0_SX448 +FCLT0_SX88 +FCMG0_SI1142 
+FCMG0_SI1242 +FCMG0_SI1872 +FCMG0_SX162 +FCMG0_SX252 +FCMG0_SX342 +FCMG0_SX432 +FCMG0_SX72 +FCMM0_SI1083 +FCMM0_SI1957 +FCMM0_SI453 +FCMM0_SX183 +FCMM0_SX273 +FCMM0_SX363 +FCMM0_SX420 +FCMM0_SX93 +FCRZ0_SI1913 +FCRZ0_SI2053 +FCRZ0_SI793 +FCRZ0_SX163 +FCRZ0_SX253 +FCRZ0_SX343 +FCRZ0_SX433 +FCRZ0_SX73 +FCYL0_SI1297 +FCYL0_SI1927 +FCYL0_SI667 +FCYL0_SX127 +FCYL0_SX217 +FCYL0_SX349 +FCYL0_SX37 +FCYL0_SX397 +FDAS1_SI1461 +FDAS1_SI2091 +FDAS1_SI831 +FDAS1_SX111 +FDAS1_SX201 +FDAS1_SX21 +FDAS1_SX291 +FDAS1_SX381 +FDAW0_SI1271 +FDAW0_SI1406 +FDAW0_SI2036 +FDAW0_SX146 +FDAW0_SX236 +FDAW0_SX326 +FDAW0_SX416 +FDAW0_SX56 +FDFB0_SI1318 +FDFB0_SI1948 +FDFB0_SI2010 +FDFB0_SX148 +FDFB0_SX238 +FDFB0_SX328 +FDFB0_SX418 +FDFB0_SX58 +FDJH0_SI1565 +FDJH0_SI2195 +FDJH0_SI935 +FDJH0_SX125 +FDJH0_SX215 +FDJH0_SX305 +FDJH0_SX35 +FDJH0_SX395 +FDKN0_SI1081 +FDKN0_SI1202 +FDKN0_SI1711 +FDKN0_SX181 +FDKN0_SX271 +FDKN0_SX361 +FDKN0_SX451 +FDKN0_SX91 +FDML0_SI1149 +FDML0_SI1779 +FDML0_SI2075 +FDML0_SX159 +FDML0_SX249 +FDML0_SX339 +FDML0_SX429 +FDML0_SX69 +FDMY0_SI1197 +FDMY0_SI567 +FDMY0_SI714 +FDMY0_SX117 +FDMY0_SX207 +FDMY0_SX27 +FDMY0_SX297 +FDMY0_SX387 +FDNC0_SI1278 +FDNC0_SI1908 +FDNC0_SI2287 +FDNC0_SX108 +FDNC0_SX18 +FDNC0_SX198 +FDNC0_SX288 +FDNC0_SX378 +FDTD0_SI1561 +FDTD0_SI2191 +FDTD0_SI931 +FDTD0_SX121 +FDTD0_SX211 +FDTD0_SX301 +FDTD0_SX321 +FDTD0_SX391 +FDXW0_SI1511 +FDXW0_SI2141 +FDXW0_SI881 +FDXW0_SX161 +FDXW0_SX251 +FDXW0_SX341 +FDXW0_SX431 +FDXW0_SX71 +FEAC0_SI1245 +FEAC0_SI1875 +FEAC0_SI615 +FEAC0_SX165 +FEAC0_SX255 +FEAC0_SX345 +FEAC0_SX435 +FEAC0_SX75 +FEAR0_SI1252 +FEAR0_SI1882 +FEAR0_SI622 +FEAR0_SX172 +FEAR0_SX262 +FEAR0_SX352 +FEAR0_SX442 +FEAR0_SX82 +FECD0_SI1418 +FECD0_SI2048 +FECD0_SI788 +FECD0_SX158 +FECD0_SX248 +FECD0_SX338 +FECD0_SX428 +FECD0_SX68 +FEEH0_SI1112 +FEEH0_SI1742 +FEEH0_SI471 +FEEH0_SX122 +FEEH0_SX212 +FEEH0_SX302 +FEEH0_SX32 +FEEH0_SX392 +FEME0_SI1505 +FEME0_SI2135 +FEME0_SI875 +FEME0_SX155 +FEME0_SX245 +FEME0_SX335 +FEME0_SX425 +FEME0_SX65 
+FETB0_SI1148 +FETB0_SI1778 +FETB0_SI518 +FETB0_SX158 +FETB0_SX248 +FETB0_SX338 +FETB0_SX428 +FETB0_SX68 +FEXM0_SI1101 +FEXM0_SI1731 +FEXM0_SI482 +FEXM0_SX111 +FEXM0_SX201 +FEXM0_SX291 +FEXM0_SX366 +FEXM0_SX381 +FGCS0_SI1486 +FGCS0_SI2116 +FGCS0_SI856 +FGCS0_SX136 +FGCS0_SX226 +FGCS0_SX316 +FGCS0_SX406 +FGCS0_SX46 +FGDP0_SI1618 +FGDP0_SI2248 +FGDP0_SI988 +FGDP0_SX178 +FGDP0_SX268 +FGDP0_SX358 +FGDP0_SX448 +FGDP0_SX88 +FGMB0_SI1145 +FGMB0_SI1775 +FGMB0_SI515 +FGMB0_SX155 +FGMB0_SX245 +FGMB0_SX335 +FGMB0_SX425 +FGMB0_SX65 +FGRW0_SI1152 +FGRW0_SI1782 +FGRW0_SI1990 +FGRW0_SX162 +FGRW0_SX252 +FGRW0_SX342 +FGRW0_SX432 +FGRW0_SX72 +FHLM0_SI1560 +FHLM0_SI2190 +FHLM0_SI930 +FHLM0_SX120 +FHLM0_SX210 +FHLM0_SX300 +FHLM0_SX349 +FHLM0_SX390 +FHXS0_SI1075 +FHXS0_SI2302 +FHXS0_SI2335 +FHXS0_SX175 +FHXS0_SX265 +FHXS0_SX355 +FHXS0_SX445 +FHXS0_SX85 +FJDM2_SI1582 +FJDM2_SI1964 +FJDM2_SI2212 +FJDM2_SX142 +FJDM2_SX232 +FJDM2_SX322 +FJDM2_SX412 +FJDM2_SX52 +FJEN0_SI1047 +FJEN0_SI1677 +FJEN0_SI2307 +FJEN0_SX147 +FJEN0_SX237 +FJEN0_SX327 +FJEN0_SX417 +FJEN0_SX57 +FJHK0_SI1022 +FJHK0_SI1652 +FJHK0_SI2282 +FJHK0_SX122 +FJHK0_SX212 +FJHK0_SX302 +FJHK0_SX32 +FJHK0_SX392 +FJKL0_SI1562 +FJKL0_SI2192 +FJKL0_SI932 +FJKL0_SX122 +FJKL0_SX212 +FJKL0_SX302 +FJKL0_SX32 +FJKL0_SX392 +FJLG0_SI1506 +FJLG0_SI1889 +FJLG0_SI2306 +FJLG0_SX179 +FJLG0_SX269 +FJLG0_SX359 +FJLG0_SX449 +FJLG0_SX89 +FJLR0_SI1231 +FJLR0_SI1861 +FJLR0_SI601 +FJLR0_SX151 +FJLR0_SX241 +FJLR0_SX331 +FJLR0_SX421 +FJLR0_SX61 +FJRB0_SI1302 +FJRB0_SI1932 +FJRB0_SI672 +FJRB0_SX132 +FJRB0_SX222 +FJRB0_SX312 +FJRB0_SX402 +FJRB0_SX42 +FJRP1_SI1432 +FJRP1_SI2062 +FJRP1_SI802 +FJRP1_SX172 +FJRP1_SX262 +FJRP1_SX352 +FJRP1_SX442 +FJRP1_SX82 +FJSK0_SI1052 +FJSK0_SI1682 +FJSK0_SI2312 +FJSK0_SX152 +FJSK0_SX242 +FJSK0_SX332 +FJSK0_SX422 +FJSK0_SX62 +FJSP0_SI1434 +FJSP0_SI1763 +FJSP0_SI804 +FJSP0_SX174 +FJSP0_SX264 +FJSP0_SX354 +FJSP0_SX444 +FJSP0_SX84 +FJWB1_SI2055 +FJWB1_SI748 +FJWB1_SI795 +FJWB1_SX165 +FJWB1_SX255 +FJWB1_SX345 +FJWB1_SX435 
+FJWB1_SX75 +FJXM0_SI1211 +FJXM0_SI1971 +FJXM0_SI581 +FJXM0_SX131 +FJXM0_SX221 +FJXM0_SX311 +FJXM0_SX401 +FJXM0_SX41 +FJXP0_SI1122 +FJXP0_SI1752 +FJXP0_SI492 +FJXP0_SX132 +FJXP0_SX222 +FJXP0_SX312 +FJXP0_SX402 +FJXP0_SX42 +FKAA0_SI1208 +FKAA0_SI1838 +FKAA0_SI578 +FKAA0_SX128 +FKAA0_SX218 +FKAA0_SX308 +FKAA0_SX38 +FKAA0_SX398 +FKDE0_SI1141 +FKDE0_SI1771 +FKDE0_SI2221 +FKDE0_SX151 +FKDE0_SX241 +FKDE0_SX331 +FKDE0_SX421 +FKDE0_SX61 +FKDW0_SI1207 +FKDW0_SI1891 +FKDW0_SI577 +FKDW0_SX127 +FKDW0_SX217 +FKDW0_SX307 +FKDW0_SX37 +FKDW0_SX397 +FKFB0_SI1608 +FKFB0_SI2238 +FKFB0_SI978 +FKFB0_SX168 +FKFB0_SX258 +FKFB0_SX348 +FKFB0_SX438 +FKFB0_SX78 +FKKH0_SI1290 +FKKH0_SI1920 +FKKH0_SI660 +FKKH0_SX120 +FKKH0_SX210 +FKKH0_SX30 +FKKH0_SX300 +FKKH0_SX390 +FKLC0_SI1615 +FKLC0_SI2245 +FKLC0_SI985 +FKLC0_SX175 +FKLC0_SX265 +FKLC0_SX355 +FKLC0_SX445 +FKLC0_SX85 +FKLC1_SI1048 +FKLC1_SI1678 +FKLC1_SI2308 +FKLC1_SX148 +FKLC1_SX238 +FKLC1_SX328 +FKLC1_SX418 +FKLC1_SX58 +FKLH0_SI1257 +FKLH0_SI1887 +FKLH0_SI627 +FKLH0_SX177 +FKLH0_SX267 +FKLH0_SX357 +FKLH0_SX447 +FKLH0_SX87 +FKSR0_SI1117 +FKSR0_SI1747 +FKSR0_SI487 +FKSR0_SX161 +FKSR0_SX217 +FKSR0_SX366 +FKSR0_SX37 +FKSR0_SX397 +FLAC0_SI1339 +FLAC0_SI2161 +FLAC0_SI901 +FLAC0_SX181 +FLAC0_SX271 +FLAC0_SX361 +FLAC0_SX451 +FLAC0_SX91 +FLAG0_SI1464 +FLAG0_SI2094 +FLAG0_SI834 +FLAG0_SX114 +FLAG0_SX204 +FLAG0_SX24 +FLAG0_SX294 +FLAG0_SX384 +FLEH0_SI1051 +FLEH0_SI1681 +FLEH0_SI2311 +FLEH0_SX151 +FLEH0_SX241 +FLEH0_SX331 +FLEH0_SX421 +FLEH0_SX61 +FLET0_SI1137 +FLET0_SI1767 +FLET0_SI507 +FLET0_SX147 +FLET0_SX237 +FLET0_SX277 +FLET0_SX417 +FLET0_SX57 +FLHD0_SI1344 +FLHD0_SI1827 +FLHD0_SI1974 +FLHD0_SX174 +FLHD0_SX264 +FLHD0_SX354 +FLHD0_SX444 +FLHD0_SX84 +FLJA0_SI1078 +FLJA0_SI1708 +FLJA0_SI2338 +FLJA0_SX178 +FLJA0_SX268 +FLJA0_SX358 +FLJA0_SX448 +FLJA0_SX88 +FLJD0_SI1516 +FLJD0_SI2146 +FLJD0_SI886 +FLJD0_SX166 +FLJD0_SX256 +FLJD0_SX346 +FLJD0_SX436 +FLJD0_SX76 +FLJG0_SI1611 +FLJG0_SI2241 +FLJG0_SI981 +FLJG0_SX171 +FLJG0_SX261 +FLJG0_SX351 +FLJG0_SX441 
+FLJG0_SX81 +FLKM0_SI1880 +FLKM0_SI620 +FLKM0_SI686 +FLKM0_SX116 +FLKM0_SX260 +FLKM0_SX350 +FLKM0_SX440 +FLKM0_SX80 +FLMA0_SI1243 +FLMA0_SI1873 +FLMA0_SI613 +FLMA0_SX163 +FLMA0_SX253 +FLMA0_SX343 +FLMA0_SX433 +FLMA0_SX73 +FLMC0_SI1372 +FLMC0_SI2002 +FLMC0_SI742 +FLMC0_SX112 +FLMC0_SX22 +FLMC0_SX292 +FLMC0_SX336 +FLMC0_SX382 +FLMK0_SI1035 +FLMK0_SI1229 +FLMK0_SI2295 +FLMK0_SX135 +FLMK0_SX225 +FLMK0_SX315 +FLMK0_SX405 +FLMK0_SX45 +FLOD0_SI1287 +FLOD0_SI1917 +FLOD0_SI657 +FLOD0_SX117 +FLOD0_SX171 +FLOD0_SX207 +FLOD0_SX297 +FLOD0_SX387 +FLTM0_SI1070 +FLTM0_SI1700 +FLTM0_SI2330 +FLTM0_SX170 +FLTM0_SX260 +FLTM0_SX350 +FLTM0_SX440 +FLTM0_SX80 +FMAH1_SI1509 +FMAH1_SI2139 +FMAH1_SI879 +FMAH1_SX159 +FMAH1_SX249 +FMAH1_SX339 +FMAH1_SX429 +FMAH1_SX69 +FMBG0_SI1160 +FMBG0_SI1790 +FMBG0_SI2264 +FMBG0_SX260 +FMBG0_SX3 +FMBG0_SX350 +FMBG0_SX440 +FMBG0_SX80 +FMEM0_SI1377 +FMEM0_SI2007 +FMEM0_SI747 +FMEM0_SX117 +FMEM0_SX207 +FMEM0_SX297 +FMEM0_SX333 +FMEM0_SX387 +FMJB0_SI1177 +FMJB0_SI1807 +FMJB0_SI547 +FMJB0_SX187 +FMJB0_SX277 +FMJB0_SX367 +FMJB0_SX7 +FMJB0_SX97 +FMJF0_SI1254 +FMJF0_SI1884 +FMJF0_SI624 +FMJF0_SX174 +FMJF0_SX264 +FMJF0_SX354 +FMJF0_SX444 +FMJF0_SX84 +FMJU0_SI1389 +FMJU0_SI2019 +FMJU0_SI759 +FMJU0_SX129 +FMJU0_SX219 +FMJU0_SX309 +FMJU0_SX39 +FMJU0_SX399 +FMKC0_SI1041 +FMKC0_SI1072 +FMKC0_SI1702 +FMKC0_SX172 +FMKC0_SX262 +FMKC0_SX352 +FMKC0_SX442 +FMKC0_SX82 +FMKF0_SI1018 +FMKF0_SI1536 +FMKF0_SI906 +FMKF0_SX186 +FMKF0_SX276 +FMKF0_SX366 +FMKF0_SX6 +FMKF0_SX96 +FMMH0_SI1537 +FMMH0_SI2167 +FMMH0_SI907 +FMMH0_SX187 +FMMH0_SX367 +FMMH0_SX420 +FMMH0_SX7 +FMMH0_SX97 +FMPG0_SI1602 +FMPG0_SI2232 +FMPG0_SI972 +FMPG0_SX162 +FMPG0_SX252 +FMPG0_SX342 +FMPG0_SX432 +FMPG0_SX72 +FNKL0_SI1522 +FNKL0_SI2152 +FNKL0_SI892 +FNKL0_SX172 +FNKL0_SX196 +FNKL0_SX262 +FNKL0_SX442 +FNKL0_SX82 +FNTB0_SI1203 +FNTB0_SI573 +FNTB0_SI679 +FNTB0_SX123 +FNTB0_SX213 +FNTB0_SX303 +FNTB0_SX33 +FNTB0_SX393 +FPAB1_SI1471 +FPAB1_SI2101 +FPAB1_SI841 +FPAB1_SX121 +FPAB1_SX211 +FPAB1_SX301 +FPAB1_SX31 
+FPAB1_SX391 +FPAC0_SI1921 +FPAC0_SI2011 +FPAC0_SI661 +FPAC0_SX121 +FPAC0_SX211 +FPAC0_SX301 +FPAC0_SX31 +FPAC0_SX391 +FPAD0_SI1346 +FPAD0_SI1976 +FPAD0_SI716 +FPAD0_SX176 +FPAD0_SX266 +FPAD0_SX356 +FPAD0_SX446 +FPAD0_SX86 +FPAF0_SI1054 +FPAF0_SI1684 +FPAF0_SI2314 +FPAF0_SX154 +FPAF0_SX244 +FPAF0_SX334 +FPAF0_SX424 +FPAF0_SX64 +FPAZ0_SI1593 +FPAZ0_SI2223 +FPAZ0_SI963 +FPAZ0_SX153 +FPAZ0_SX243 +FPAZ0_SX27 +FPAZ0_SX423 +FPAZ0_SX63 +FPJF0_SI1046 +FPJF0_SI1259 +FPJF0_SI1676 +FPJF0_SX146 +FPJF0_SX236 +FPJF0_SX326 +FPJF0_SX352 +FPJF0_SX56 +FPLS0_SI1590 +FPLS0_SI2220 +FPLS0_SI960 +FPLS0_SX150 +FPLS0_SX240 +FPLS0_SX3 +FPLS0_SX330 +FPLS0_SX60 +FPMY0_SI1153 +FPMY0_SI1783 +FPMY0_SI523 +FPMY0_SX163 +FPMY0_SX196 +FPMY0_SX253 +FPMY0_SX343 +FPMY0_SX73 +FREH0_SI1315 +FREH0_SI1945 +FREH0_SI685 +FREH0_SX145 +FREH0_SX235 +FREH0_SX325 +FREH0_SX415 +FREH0_SX55 +FRJB0_SI1427 +FRJB0_SI1470 +FRJB0_SI1794 +FRJB0_SX167 +FRJB0_SX257 +FRJB0_SX347 +FRJB0_SX437 +FRJB0_SX77 +FRLL0_SI1514 +FRLL0_SI805 +FRLL0_SI884 +FRLL0_SX164 +FRLL0_SX254 +FRLL0_SX344 +FRLL0_SX434 +FRLL0_SX74 +FSAG0_SI1323 +FSAG0_SI1953 +FSAG0_SI693 +FSAG0_SX153 +FSAG0_SX243 +FSAG0_SX333 +FSAG0_SX423 +FSAG0_SX63 +FSAH0_SI1244 +FSAH0_SI1874 +FSAH0_SI614 +FSAH0_SX164 +FSAH0_SX327 +FSAH0_SX344 +FSAH0_SX434 +FSAH0_SX74 +FSAK0_SI1300 +FSAK0_SI1930 +FSAK0_SI670 +FSAK0_SX130 +FSAK0_SX220 +FSAK0_SX310 +FSAK0_SX40 +FSAK0_SX400 +FSBK0_SI1069 +FSBK0_SI1699 +FSBK0_SI2329 +FSBK0_SX169 +FSBK0_SX259 +FSBK0_SX349 +FSBK0_SX439 +FSBK0_SX79 +FSCN0_SI1886 +FSCN0_SI626 +FSCN0_SI705 +FSCN0_SX176 +FSCN0_SX266 +FSCN0_SX356 +FSCN0_SX446 +FSCN0_SX86 +FSDC0_SI1312 +FSDC0_SI1942 +FSDC0_SI2234 +FSDC0_SX142 +FSDC0_SX232 +FSDC0_SX322 +FSDC0_SX412 +FSDC0_SX52 +FSDJ0_SI1115 +FSDJ0_SI1745 +FSDJ0_SI485 +FSDJ0_SX125 +FSDJ0_SX215 +FSDJ0_SX305 +FSDJ0_SX35 +FSDJ0_SX395 +FSGF0_SI1557 +FSGF0_SI2187 +FSGF0_SI927 +FSGF0_SX117 +FSGF0_SX207 +FSGF0_SX27 +FSGF0_SX297 +FSGF0_SX387 +FSJG0_SI1570 +FSJG0_SI2200 +FSJG0_SI940 +FSJG0_SX130 +FSJG0_SX220 +FSJG0_SX310 +FSJG0_SX40 
+FSJG0_SX400 +FSJK1_SI1025 +FSJK1_SI2285 +FSJK1_SI696 +FSJK1_SX125 +FSJK1_SX215 +FSJK1_SX305 +FSJK1_SX35 +FSJK1_SX395 +FSJS0_SI1171 +FSJS0_SI1801 +FSJS0_SI541 +FSJS0_SX181 +FSJS0_SX271 +FSJS0_SX361 +FSJS0_SX451 +FSJS0_SX91 +FSJW0_SI1333 +FSJW0_SI1963 +FSJW0_SI703 +FSJW0_SX163 +FSJW0_SX253 +FSJW0_SX343 +FSJW0_SX433 +FSJW0_SX73 +FSKC0_SI1416 +FSKC0_SI2046 +FSKC0_SI786 +FSKC0_SX156 +FSKC0_SX246 +FSKC0_SX336 +FSKC0_SX426 +FSKC0_SX66 +FSKL0_SI1529 +FSKL0_SI2159 +FSKL0_SI899 +FSKL0_SX179 +FSKL0_SX269 +FSKL0_SX359 +FSKL0_SX449 +FSKL0_SX89 +FSKP0_SI1098 +FSKP0_SI1728 +FSKP0_SI468 +FSKP0_SX108 +FSKP0_SX18 +FSKP0_SX198 +FSKP0_SX288 +FSKP0_SX378 +FSLS0_SI1056 +FSLS0_SI1686 +FSLS0_SI2316 +FSLS0_SX156 +FSLS0_SX202 +FSLS0_SX246 +FSLS0_SX426 +FSLS0_SX66 +FSMA0_SI1621 +FSMA0_SI2251 +FSMA0_SI991 +FSMA0_SX181 +FSMA0_SX271 +FSMA0_SX361 +FSMA0_SX451 +FSMA0_SX91 +FSMM0_SI1314 +FSMM0_SI1944 +FSMM0_SI684 +FSMM0_SX144 +FSMM0_SX234 +FSMM0_SX324 +FSMM0_SX414 +FSMM0_SX54 +FSMS1_SI1504 +FSMS1_SI2134 +FSMS1_SI874 +FSMS1_SX154 +FSMS1_SX244 +FSMS1_SX334 +FSMS1_SX347 +FSMS1_SX64 +FSPM0_SI1241 +FSPM0_SI1871 +FSPM0_SI611 +FSPM0_SX161 +FSPM0_SX251 +FSPM0_SX341 +FSPM0_SX431 +FSPM0_SX71 +FSRH0_SI1719 +FSRH0_SI1931 +FSRH0_SI671 +FSRH0_SX131 +FSRH0_SX221 +FSRH0_SX311 +FSRH0_SX401 +FSRH0_SX41 +FSSB0_SI1082 +FSSB0_SI1712 +FSSB0_SI2342 +FSSB0_SX182 +FSSB0_SX272 +FSSB0_SX362 +FSSB0_SX452 +FSSB0_SX92 +FTAJ0_SI1329 +FTAJ0_SI474 +FTAJ0_SI699 +FTAJ0_SX159 +FTAJ0_SX249 +FTAJ0_SX339 +FTAJ0_SX429 +FTAJ0_SX69 +FTBR0_SI1402 +FTBR0_SI2181 +FTBR0_SI921 +FTBR0_SX111 +FTBR0_SX201 +FTBR0_SX21 +FTBR0_SX291 +FTBR0_SX381 +FTBW0_SI1345 +FTBW0_SI1975 +FTBW0_SI715 +FTBW0_SX175 +FTBW0_SX265 +FTBW0_SX355 +FTBW0_SX445 +FTBW0_SX85 +FTLG0_SI1743 +FTLG0_SI483 +FTLG0_SI840 +FTLG0_SX123 +FTLG0_SX213 +FTLG0_SX303 +FTLG0_SX33 +FTLG0_SX393 +FTMG0_SI1532 +FTMG0_SI2162 +FTMG0_SI902 +FTMG0_SX182 +FTMG0_SX272 +FTMG0_SX362 +FTMG0_SX452 +FTMG0_SX92 +FVFB0_SI1032 +FVFB0_SI1510 +FVFB0_SI2292 +FVFB0_SX132 +FVFB0_SX222 +FVFB0_SX312 +FVFB0_SX402 
+FVFB0_SX42 +FVKB0_SI1159 +FVKB0_SI1789 +FVKB0_SI529 +FVKB0_SX169 +FVKB0_SX259 +FVKB0_SX349 +FVKB0_SX439 +FVKB0_SX79 +FVMH0_SI1466 +FVMH0_SI2096 +FVMH0_SI836 +FVMH0_SX116 +FVMH0_SX206 +FVMH0_SX26 +FVMH0_SX296 +FVMH0_SX386 +MABC0_SI1620 +MABC0_SI2041 +MABC0_SI781 +MABC0_SX151 +MABC0_SX241 +MABC0_SX331 +MABC0_SX421 +MABC0_SX61 +MADC0_SI1367 +MADC0_SI1997 +MADC0_SI737 +MADC0_SX107 +MADC0_SX17 +MADC0_SX197 +MADC0_SX287 +MADC0_SX377 +MADD0_SI1295 +MADD0_SI1798 +MADD0_SI538 +MADD0_SX178 +MADD0_SX268 +MADD0_SX358 +MADD0_SX448 +MADD0_SX88 +MAEB0_SI1411 +MAEB0_SI2250 +MAEB0_SI990 +MAEB0_SX180 +MAEB0_SX270 +MAEB0_SX360 +MAEB0_SX450 +MAEB0_SX90 +MAEO0_SI1326 +MAEO0_SI1655 +MAEO0_SI1956 +MAEO0_SX156 +MAEO0_SX246 +MAEO0_SX336 +MAEO0_SX426 +MAEO0_SX66 +MAFM0_SI1569 +MAFM0_SI2199 +MAFM0_SI939 +MAFM0_SX129 +MAFM0_SX219 +MAFM0_SX309 +MAFM0_SX39 +MAFM0_SX399 +MAJP0_SI1074 +MAJP0_SI1704 +MAJP0_SI2334 +MAJP0_SX174 +MAJP0_SX264 +MAJP0_SX354 +MAJP0_SX444 +MAJP0_SX84 +MAKB0_SI1016 +MAKB0_SI1646 +MAKB0_SI2276 +MAKB0_SX116 +MAKB0_SX206 +MAKB0_SX26 +MAKB0_SX296 +MAKB0_SX386 +MAKR0_SI1352 +MAKR0_SI1982 +MAKR0_SI722 +MAKR0_SX182 +MAKR0_SX272 +MAKR0_SX362 +MAKR0_SX452 +MAKR0_SX92 +MAPV0_SI1293 +MAPV0_SI1923 +MAPV0_SI663 +MAPV0_SX123 +MAPV0_SX213 +MAPV0_SX303 +MAPV0_SX33 +MAPV0_SX393 +MARC0_SI1188 +MARC0_SI1818 +MARC0_SI558 +MARC0_SX108 +MARC0_SX18 +MARC0_SX198 +MARC0_SX288 +MARC0_SX378 +MARW0_SI1276 +MARW0_SI1906 +MARW0_SI646 +MARW0_SX106 +MARW0_SX16 +MARW0_SX286 +MARW0_SX349 +MARW0_SX376 +MBAR0_SI1319 +MBAR0_SI1949 +MBAR0_SI689 +MBAR0_SX149 +MBAR0_SX239 +MBAR0_SX329 +MBAR0_SX419 +MBAR0_SX59 +MBBR0_SI1055 +MBBR0_SI1685 +MBBR0_SI2315 +MBBR0_SX155 +MBBR0_SX245 +MBBR0_SX335 +MBBR0_SX425 +MBBR0_SX65 +MBCG0_SI2217 +MBCG0_SI486 +MBCG0_SI957 +MBCG0_SX147 +MBCG0_SX237 +MBCG0_SX327 +MBCG0_SX417 +MBCG0_SX57 +MBEF0_SI1281 +MBEF0_SI1911 +MBEF0_SI651 +MBEF0_SX111 +MBEF0_SX201 +MBEF0_SX21 +MBEF0_SX291 +MBEF0_SX381 +MBGT0_SI1341 +MBGT0_SI1841 +MBGT0_SI711 +MBGT0_SX171 +MBGT0_SX261 +MBGT0_SX351 +MBGT0_SX441 
+MBGT0_SX81 +MBJV0_SI1247 +MBJV0_SI1877 +MBJV0_SI617 +MBJV0_SX167 +MBJV0_SX257 +MBJV0_SX347 +MBJV0_SX437 +MBJV0_SX77 +MBMA0_SI1222 +MBMA0_SI1852 +MBMA0_SI592 +MBMA0_SX142 +MBMA0_SX232 +MBMA0_SX322 +MBMA0_SX412 +MBMA0_SX52 +MBMA1_SI2207 +MBMA1_SI2214 +MBMA1_SI954 +MBMA1_SX144 +MBMA1_SX234 +MBMA1_SX324 +MBMA1_SX414 +MBMA1_SX54 +MBML0_SI1169 +MBML0_SI1799 +MBML0_SI539 +MBML0_SX179 +MBML0_SX269 +MBML0_SX359 +MBML0_SX449 +MBML0_SX89 +MBOM0_SI1014 +MBOM0_SI1644 +MBOM0_SI2274 +MBOM0_SX114 +MBOM0_SX204 +MBOM0_SX294 +MBOM0_SX311 +MBOM0_SX384 +MBSB0_SI1353 +MBSB0_SI1983 +MBSB0_SI723 +MBSB0_SX183 +MBSB0_SX273 +MBSB0_SX3 +MBSB0_SX363 +MBSB0_SX93 +MBTH0_SI2102 +MBTH0_SI505 +MBTH0_SI757 +MBTH0_SX122 +MBTH0_SX212 +MBTH0_SX302 +MBTH0_SX32 +MBTH0_SX392 +MBWP0_SI1531 +MBWP0_SI1969 +MBWP0_SI709 +MBWP0_SX169 +MBWP0_SX259 +MBWP0_SX349 +MBWP0_SX439 +MBWP0_SX79 +MCAE0_SI1447 +MCAE0_SI2077 +MCAE0_SI817 +MCAE0_SX187 +MCAE0_SX277 +MCAE0_SX367 +MCAE0_SX7 +MCAE0_SX97 +MCAL0_SI1138 +MCAL0_SI1768 +MCAL0_SI508 +MCAL0_SX148 +MCAL0_SX238 +MCAL0_SX328 +MCAL0_SX418 +MCAL0_SX58 +MCDC0_SI1292 +MCDC0_SI1922 +MCDC0_SI662 +MCDC0_SX122 +MCDC0_SX212 +MCDC0_SX302 +MCDC0_SX32 +MCDC0_SX392 +MCDD0_SI1513 +MCDD0_SI2143 +MCDD0_SI883 +MCDD0_SX163 +MCDD0_SX253 +MCDD0_SX343 +MCDD0_SX433 +MCDD0_SX73 +MCDR0_SI1154 +MCDR0_SI1784 +MCDR0_SI524 +MCDR0_SX164 +MCDR0_SX254 +MCDR0_SX344 +MCDR0_SX434 +MCDR0_SX74 +MCEF0_SI1135 +MCEF0_SI1765 +MCEF0_SI842 +MCEF0_SX145 +MCEF0_SX235 +MCEF0_SX325 +MCEF0_SX415 +MCEF0_SX55 +MCEW0_SI1442 +MCEW0_SI2072 +MCEW0_SI812 +MCEW0_SX182 +MCEW0_SX272 +MCEW0_SX362 +MCEW0_SX452 +MCEW0_SX92 +MCHL0_SI1347 +MCHL0_SI1404 +MCHL0_SI1977 +MCHL0_SX177 +MCHL0_SX267 +MCHL0_SX357 +MCHL0_SX447 +MCHL0_SX87 +MCLK0_SI1660 +MCLK0_SI2290 +MCLK0_SI650 +MCLK0_SX130 +MCLK0_SX220 +MCLK0_SX310 +MCLK0_SX40 +MCLK0_SX400 +MCLM0_SI1456 +MCLM0_SI2086 +MCLM0_SI826 +MCLM0_SX106 +MCLM0_SX16 +MCLM0_SX196 +MCLM0_SX286 +MCLM0_SX376 +MCPM0_SI1194 +MCPM0_SI1824 +MCPM0_SI564 +MCPM0_SX114 +MCPM0_SX204 +MCPM0_SX24 +MCPM0_SX294 
+MCPM0_SX384 +MCRE0_SI1121 +MCRE0_SI1725 +MCRE0_SI1751 +MCRE0_SX131 +MCRE0_SX221 +MCRE0_SX24 +MCRE0_SX401 +MCRE0_SX41 +MCSS0_SI1380 +MCSS0_SI688 +MCSS0_SI750 +MCSS0_SX120 +MCSS0_SX210 +MCSS0_SX30 +MCSS0_SX300 +MCSS0_SX390 +MCTH0_SI1209 +MCTH0_SI1839 +MCTH0_SI579 +MCTH0_SX129 +MCTH0_SX219 +MCTH0_SX309 +MCTH0_SX39 +MCTH0_SX399 +MCTM0_SI1350 +MCTM0_SI1980 +MCTM0_SI720 +MCTM0_SX180 +MCTM0_SX270 +MCTM0_SX360 +MCTM0_SX450 +MCTM0_SX90 +MCXM0_SI1351 +MCXM0_SI1981 +MCXM0_SI721 +MCXM0_SX181 +MCXM0_SX271 +MCXM0_SX361 +MCXM0_SX451 +MCXM0_SX91 +MDAC0_SI1261 +MDAC0_SI1837 +MDAC0_SI631 +MDAC0_SX181 +MDAC0_SX271 +MDAC0_SX361 +MDAC0_SX451 +MDAC0_SX91 +MDAS0_SI1266 +MDAS0_SI1896 +MDAS0_SI636 +MDAS0_SX186 +MDAS0_SX21 +MDAS0_SX276 +MDAS0_SX6 +MDAS0_SX96 +MDBB1_SI1006 +MDBB1_SI1636 +MDBB1_SI2056 +MDBB1_SX106 +MDBB1_SX16 +MDBB1_SX196 +MDBB1_SX286 +MDBB1_SX376 +MDBP0_SI1158 +MDBP0_SI1788 +MDBP0_SI528 +MDBP0_SX168 +MDBP0_SX258 +MDBP0_SX348 +MDBP0_SX438 +MDBP0_SX78 +MDCD0_SI1415 +MDCD0_SI2045 +MDCD0_SI785 +MDCD0_SX155 +MDCD0_SX245 +MDCD0_SX335 +MDCD0_SX425 +MDCD0_SX65 +MDCM0_SI1480 +MDCM0_SI2110 +MDCM0_SI850 +MDCM0_SX130 +MDCM0_SX220 +MDCM0_SX310 +MDCM0_SX40 +MDCM0_SX400 +MDDC0_SI1419 +MDDC0_SI2049 +MDDC0_SI789 +MDDC0_SX159 +MDDC0_SX249 +MDDC0_SX339 +MDDC0_SX429 +MDDC0_SX69 +MDED0_SI1170 +MDED0_SI1800 +MDED0_SI540 +MDED0_SX180 +MDED0_SX270 +MDED0_SX360 +MDED0_SX450 +MDED0_SX90 +MDEF0_SI1123 +MDEF0_SI1563 +MDEF0_SI2193 +MDEF0_SX123 +MDEF0_SX213 +MDEF0_SX303 +MDEF0_SX33 +MDEF0_SX393 +MDEM0_SI1868 +MDEM0_SI608 +MDEM0_SI800 +MDEM0_SX158 +MDEM0_SX248 +MDEM0_SX338 +MDEM0_SX428 +MDEM0_SX68 +MDHL0_SI1439 +MDHL0_SI2069 +MDHL0_SI809 +MDHL0_SX179 +MDHL0_SX269 +MDHL0_SX359 +MDHL0_SX449 +MDHL0_SX89 +MDHS0_SI1530 +MDHS0_SI2160 +MDHS0_SI900 +MDHS0_SX180 +MDHS0_SX270 +MDHS0_SX360 +MDHS0_SX450 +MDHS0_SX90 +MDJM0_SI1455 +MDJM0_SI2085 +MDJM0_SI825 +MDJM0_SX105 +MDJM0_SX15 +MDJM0_SX195 +MDJM0_SX285 +MDJM0_SX375 +MDKS0_SI1066 +MDKS0_SI1696 +MDKS0_SI2326 +MDKS0_SX166 +MDKS0_SX256 +MDKS0_SX346 +MDKS0_SX436 
+MDKS0_SX76 +MDLB0_SI1306 +MDLB0_SI1936 +MDLB0_SI676 +MDLB0_SX136 +MDLB0_SX226 +MDLB0_SX316 +MDLB0_SX406 +MDLB0_SX46 +MDLC0_SI1395 +MDLC0_SI2025 +MDLC0_SI765 +MDLC0_SX135 +MDLC0_SX225 +MDLC0_SX315 +MDLC0_SX405 +MDLC0_SX45 +MDLC1_SI1435 +MDLC1_SI2065 +MDLC1_SI2144 +MDLC1_SX175 +MDLC1_SX265 +MDLC1_SX355 +MDLC1_SX445 +MDLC1_SX85 +MDLC2_SI1614 +MDLC2_SI2244 +MDLC2_SI984 +MDLC2_SX174 +MDLC2_SX264 +MDLC2_SX354 +MDLC2_SX444 +MDLC2_SX84 +MDLH0_SI1960 +MDLH0_SI574 +MDLH0_SI700 +MDLH0_SX160 +MDLH0_SX250 +MDLH0_SX340 +MDLH0_SX430 +MDLH0_SX70 +MDLM0_SI1234 +MDLM0_SI1864 +MDLM0_SI604 +MDLM0_SX154 +MDLM0_SX244 +MDLM0_SX334 +MDLM0_SX424 +MDLM0_SX64 +MDLR0_SI1233 +MDLR0_SI1863 +MDLR0_SI603 +MDLR0_SX153 +MDLR0_SX243 +MDLR0_SX333 +MDLR0_SX423 +MDLR0_SX63 +MDLR1_SI1299 +MDLR1_SI1929 +MDLR1_SI669 +MDLR1_SX129 +MDLR1_SX219 +MDLR1_SX309 +MDLR1_SX39 +MDLR1_SX399 +MDMA0_SI1238 +MDMA0_SI1430 +MDMA0_SI2060 +MDMA0_SX170 +MDMA0_SX260 +MDMA0_SX350 +MDMA0_SX440 +MDMA0_SX80 +MDMT0_SI1832 +MDMT0_SI2341 +MDMT0_SI572 +MDMT0_SX122 +MDMT0_SX212 +MDMT0_SX302 +MDMT0_SX32 +MDMT0_SX392 +MDNS0_SI1011 +MDNS0_SI2271 +MDNS0_SI873 +MDNS0_SX111 +MDNS0_SX201 +MDNS0_SX21 +MDNS0_SX291 +MDNS0_SX381 +MDPB0_SI1760 +MDPB0_SI2126 +MDPB0_SI866 +MDPB0_SX146 +MDPB0_SX236 +MDPB0_SX326 +MDPB0_SX416 +MDPB0_SX56 +MDPK0_SI1053 +MDPK0_SI1683 +MDPK0_SI552 +MDPK0_SX153 +MDPK0_SX243 +MDPK0_SX333 +MDPK0_SX423 +MDPK0_SX63 +MDPS0_SI1651 +MDPS0_SI1979 +MDPS0_SI719 +MDPS0_SX179 +MDPS0_SX269 +MDPS0_SX359 +MDPS0_SX449 +MDPS0_SX89 +MDRD0_SI1382 +MDRD0_SI2012 +MDRD0_SI752 +MDRD0_SX122 +MDRD0_SX212 +MDRD0_SX302 +MDRD0_SX32 +MDRD0_SX392 +MDSJ0_SI1462 +MDSJ0_SI2092 +MDSJ0_SI832 +MDSJ0_SX112 +MDSJ0_SX22 +MDSJ0_SX292 +MDSJ0_SX382 +MDSJ0_SX438 +MDSS0_SI1881 +MDSS0_SI2087 +MDSS0_SI621 +MDSS0_SX171 +MDSS0_SX261 +MDSS0_SX351 +MDSS0_SX441 +MDSS0_SX81 +MDSS1_SI1327 +MDSS1_SI1713 +MDSS1_SI697 +MDSS1_SX157 +MDSS1_SX247 +MDSS1_SX337 +MDSS1_SX427 +MDSS1_SX67 +MDTB0_SI1200 +MDTB0_SI1830 +MDTB0_SI570 +MDTB0_SX120 +MDTB0_SX210 +MDTB0_SX300 +MDTB0_SX321 
+MDTB0_SX390 +MDWD0_SI1260 +MDWD0_SI1890 +MDWD0_SI557 +MDWD0_SX180 +MDWD0_SX270 +MDWD0_SX360 +MDWD0_SX450 +MDWD0_SX90 +MDWH0_SI1168 +MDWH0_SI1925 +MDWH0_SI665 +MDWH0_SX125 +MDWH0_SX215 +MDWH0_SX305 +MDWH0_SX35 +MDWH0_SX395 +MDWM0_SI1546 +MDWM0_SI2176 +MDWM0_SI916 +MDWM0_SX106 +MDWM0_SX16 +MDWM0_SX286 +MDWM0_SX376 +MDWM0_SX433 +MEAL0_SI1547 +MEAL0_SI2177 +MEAL0_SI917 +MEAL0_SX107 +MEAL0_SX197 +MEAL0_SX287 +MEAL0_SX347 +MEAL0_SX377 +MEDR0_SI1374 +MEDR0_SI2004 +MEDR0_SI744 +MEDR0_SX114 +MEDR0_SX204 +MEDR0_SX24 +MEDR0_SX294 +MEDR0_SX384 +MEFG0_SI465 +MEFG0_SI491 +MEFG0_SI598 +MEFG0_SX105 +MEFG0_SX15 +MEFG0_SX195 +MEFG0_SX285 +MEFG0_SX375 +MEGJ0_SI1337 +MEGJ0_SI1967 +MEGJ0_SI707 +MEGJ0_SX167 +MEGJ0_SX257 +MEGJ0_SX3 +MEGJ0_SX437 +MEGJ0_SX77 +MEJL0_SI1592 +MEJL0_SI1654 +MEJL0_SI962 +MEJL0_SX152 +MEJL0_SX242 +MEJL0_SX332 +MEJL0_SX422 +MEJL0_SX62 +MEJS0_SI1240 +MEJS0_SI1870 +MEJS0_SI610 +MEJS0_SX160 +MEJS0_SX250 +MEJS0_SX340 +MEJS0_SX430 +MEJS0_SX70 +MESG0_SI1332 +MESG0_SI1962 +MESG0_SI702 +MESG0_SX162 +MESG0_SX252 +MESG0_SX342 +MESG0_SX432 +MESG0_SX72 +MESJ0_SI2039 +MESJ0_SI2257 +MESJ0_SI997 +MESJ0_SX187 +MESJ0_SX277 +MESJ0_SX367 +MESJ0_SX7 +MESJ0_SX97 +MEWM0_SI1348 +MEWM0_SI1978 +MEWM0_SI718 +MEWM0_SX178 +MEWM0_SX268 +MEWM0_SX358 +MEWM0_SX448 +MEWM0_SX88 +MFER0_SI1492 +MFER0_SI2122 +MFER0_SI862 +MFER0_SX142 +MFER0_SX232 +MFER0_SX322 +MFER0_SX412 +MFER0_SX52 +MFMC0_SI1132 +MFMC0_SI1762 +MFMC0_SI502 +MFMC0_SX142 +MFMC0_SX232 +MFMC0_SX322 +MFMC0_SX412 +MFMC0_SX52 +MFRM0_SI1155 +MFRM0_SI1717 +MFRM0_SI1785 +MFRM0_SX165 +MFRM0_SX255 +MFRM0_SX345 +MFRM0_SX435 +MFRM0_SX75 +MFWK0_SI1249 +MFWK0_SI1879 +MFWK0_SI619 +MFWK0_SX169 +MFWK0_SX259 +MFWK0_SX349 +MFWK0_SX439 +MFWK0_SX79 +MFXS0_SI1674 +MFXS0_SI2225 +MFXS0_SI2304 +MFXS0_SX144 +MFXS0_SX234 +MFXS0_SX324 +MFXS0_SX414 +MFXS0_SX54 +MFXV0_SI1005 +MFXV0_SI1342 +MFXV0_SI1635 +MFXV0_SX105 +MFXV0_SX15 +MFXV0_SX195 +MFXV0_SX285 +MFXV0_SX375 +MGAF0_SI1282 +MGAF0_SI1912 +MGAF0_SI652 +MGAF0_SX112 +MGAF0_SX202 +MGAF0_SX22 +MGAF0_SX292 
+MGAF0_SX382 +MGAG0_SI1321 +MGAG0_SI645 +MGAG0_SI691 +MGAG0_SX151 +MGAG0_SX241 +MGAG0_SX331 +MGAG0_SX421 +MGAG0_SX61 +MGAK0_SI1036 +MGAK0_SI1666 +MGAK0_SI2296 +MGAK0_SX136 +MGAK0_SX226 +MGAK0_SX316 +MGAK0_SX406 +MGAK0_SX46 +MGAR0_SI1212 +MGAR0_SI1694 +MGAR0_SI1842 +MGAR0_SX132 +MGAR0_SX222 +MGAR0_SX312 +MGAR0_SX402 +MGAR0_SX42 +MGAW0_SI1165 +MGAW0_SI1802 +MGAW0_SI535 +MGAW0_SX175 +MGAW0_SX265 +MGAW0_SX355 +MGAW0_SX445 +MGAW0_SX85 +MGES0_SI1481 +MGES0_SI2111 +MGES0_SI851 +MGES0_SX131 +MGES0_SX221 +MGES0_SX311 +MGES0_SX401 +MGES0_SX41 +MGJC0_SI1256 +MGJC0_SI1335 +MGJC0_SI1965 +MGJC0_SX165 +MGJC0_SX255 +MGJC0_SX345 +MGJC0_SX435 +MGJC0_SX75 +MGRL0_SI1497 +MGRL0_SI2127 +MGRL0_SI867 +MGRL0_SX147 +MGRL0_SX237 +MGRL0_SX327 +MGRL0_SX417 +MGRL0_SX57 +MGRP0_SI1317 +MGRP0_SI1947 +MGRP0_SI687 +MGRP0_SX147 +MGRP0_SX237 +MGRP0_SX327 +MGRP0_SX417 +MGRP0_SX57 +MGSH0_SI1176 +MGSH0_SI1806 +MGSH0_SI546 +MGSH0_SX127 +MGSH0_SX186 +MGSH0_SX276 +MGSH0_SX6 +MGSH0_SX96 +MGSL0_SI1164 +MGSL0_SI534 +MGSL0_SI797 +MGSL0_SX174 +MGSL0_SX264 +MGSL0_SX354 +MGSL0_SX444 +MGSL0_SX84 +MGXP0_SI1087 +MGXP0_SI457 +MGXP0_SI525 +MGXP0_SX187 +MGXP0_SX277 +MGXP0_SX367 +MGXP0_SX7 +MGXP0_SX97 +MHBS0_SI1575 +MHBS0_SI2205 +MHBS0_SI945 +MHBS0_SX135 +MHBS0_SX225 +MHBS0_SX315 +MHBS0_SX405 +MHBS0_SX45 +MHIT0_SI1613 +MHIT0_SI2243 +MHIT0_SI983 +MHIT0_SX173 +MHIT0_SX263 +MHIT0_SX353 +MHIT0_SX443 +MHIT0_SX83 +MHJB0_SI1017 +MHJB0_SI1647 +MHJB0_SI2277 +MHJB0_SX117 +MHJB0_SX207 +MHJB0_SX27 +MHJB0_SX297 +MHJB0_SX387 +MHMG0_SI1365 +MHMG0_SI1995 +MHMG0_SI735 +MHMG0_SX105 +MHMG0_SX15 +MHMG0_SX195 +MHMG0_SX285 +MHMG0_SX375 +MHMR0_SI1119 +MHMR0_SI1692 +MHMR0_SI489 +MHMR0_SX129 +MHMR0_SX219 +MHMR0_SX309 +MHMR0_SX39 +MHMR0_SX399 +MHRM0_SI1475 +MHRM0_SI2218 +MHRM0_SI958 +MHRM0_SX148 +MHRM0_SX238 +MHRM0_SX328 +MHRM0_SX418 +MHRM0_SX58 +MHXL0_SI1772 +MHXL0_SI512 +MHXL0_SI612 +MHXL0_SX152 +MHXL0_SX242 +MHXL0_SX332 +MHXL0_SX422 +MHXL0_SX62 +MILB0_SI2163 +MILB0_SI807 +MILB0_SI903 +MILB0_SX183 +MILB0_SX273 +MILB0_SX3 +MILB0_SX363 
+MILB0_SX93 +MJAC0_SI1331 +MJAC0_SI2148 +MJAC0_SI701 +MJAC0_SX251 +MJAC0_SX307 +MJAC0_SX341 +MJAC0_SX431 +MJAC0_SX71 +MJAE0_SI1524 +MJAE0_SI1999 +MJAE0_SI2154 +MJAE0_SX174 +MJAE0_SX264 +MJAE0_SX354 +MJAE0_SX444 +MJAE0_SX84 +MJAI0_SI1604 +MJAI0_SI682 +MJAI0_SI710 +MJAI0_SX164 +MJAI0_SX254 +MJAI0_SX344 +MJAI0_SX434 +MJAI0_SX74 +MJBG0_SI1232 +MJBG0_SI1724 +MJBG0_SI1862 +MJBG0_SX152 +MJBG0_SX242 +MJBG0_SX332 +MJBG0_SX422 +MJBG0_SX62 +MJDA0_SI1031 +MJDA0_SI1661 +MJDA0_SI2291 +MJDA0_SX131 +MJDA0_SX221 +MJDA0_SX311 +MJDA0_SX401 +MJDA0_SX41 +MJDC0_SI1161 +MJDC0_SI2165 +MJDC0_SI531 +MJDC0_SX171 +MJDC0_SX261 +MJDC0_SX351 +MJDC0_SX441 +MJDC0_SX81 +MJDE0_SI1120 +MJDE0_SI463 +MJDE0_SI490 +MJDE0_SX130 +MJDE0_SX220 +MJDE0_SX310 +MJDE0_SX40 +MJDE0_SX400 +MJDG0_SI1042 +MJDG0_SI1672 +MJDG0_SI1705 +MJDG0_SX142 +MJDG0_SX232 +MJDG0_SX322 +MJDG0_SX412 +MJDG0_SX52 +MJDM0_SI1340 +MJDM0_SI1937 +MJDM0_SI974 +MJDM0_SX170 +MJDM0_SX260 +MJDM0_SX350 +MJDM0_SX440 +MJDM0_SX80 +MJEB0_SI1286 +MJEB0_SI1916 +MJEB0_SI656 +MJEB0_SX170 +MJEB0_SX206 +MJEB0_SX26 +MJEB0_SX296 +MJEB0_SX386 +MJEB1_SI1467 +MJEB1_SI2097 +MJEB1_SI837 +MJEB1_SX117 +MJEB1_SX207 +MJEB1_SX27 +MJEB1_SX297 +MJEB1_SX387 +MJEE0_SI1237 +MJEE0_SI1867 +MJEE0_SI607 +MJEE0_SX157 +MJEE0_SX247 +MJEE0_SX337 +MJEE0_SX427 +MJEE0_SX67 +MJFH0_SI1107 +MJFH0_SI1737 +MJFH0_SI477 +MJFH0_SX117 +MJFH0_SX207 +MJFH0_SX27 +MJFH0_SX297 +MJFH0_SX387 +MJFR0_SI1605 +MJFR0_SI2235 +MJFR0_SI975 +MJFR0_SX165 +MJFR0_SX255 +MJFR0_SX345 +MJFR0_SX435 +MJFR0_SX75 +MJHI0_SI1328 +MJHI0_SI555 +MJHI0_SI698 +MJHI0_SX158 +MJHI0_SX248 +MJHI0_SX338 +MJHI0_SX428 +MJHI0_SX68 +MJJB0_SI1139 +MJJB0_SI1277 +MJJB0_SI1769 +MJJB0_SX149 +MJJB0_SX239 +MJJB0_SX329 +MJJB0_SX419 +MJJB0_SX59 +MJJJ0_SI1163 +MJJJ0_SI1793 +MJJJ0_SI533 +MJJJ0_SX173 +MJJJ0_SX263 +MJJJ0_SX353 +MJJJ0_SX443 +MJJJ0_SX83 +MJJM0_SI1251 +MJJM0_SI1457 +MJJM0_SI827 +MJJM0_SX107 +MJJM0_SX17 +MJJM0_SX197 +MJJM0_SX287 +MJJM0_SX377 +MJKR0_SI1201 +MJKR0_SI1831 +MJKR0_SI571 +MJKR0_SX121 +MJKR0_SX211 +MJKR0_SX301 +MJKR0_SX31 
+MJKR0_SX391 +MJLB0_SI1616 +MJLB0_SI2246 +MJLB0_SI986 +MJLB0_SX176 +MJLB0_SX266 +MJLB0_SX356 +MJLB0_SX446 +MJLB0_SX86 +MJLG1_SI1012 +MJLG1_SI1642 +MJLG1_SI2272 +MJLG1_SX112 +MJLG1_SX202 +MJLG1_SX22 +MJLG1_SX292 +MJLG1_SX382 +MJLS0_SI1096 +MJLS0_SI1726 +MJLS0_SI466 +MJLS0_SX106 +MJLS0_SX16 +MJLS0_SX196 +MJLS0_SX286 +MJLS0_SX376 +MJMA0_SI1495 +MJMA0_SI2125 +MJMA0_SI865 +MJMA0_SX145 +MJMA0_SX235 +MJMA0_SX325 +MJMA0_SX415 +MJMA0_SX55 +MJMD0_SI1028 +MJMD0_SI1658 +MJMD0_SI2288 +MJMD0_SX128 +MJMD0_SX218 +MJMD0_SX308 +MJMD0_SX38 +MJMD0_SX398 +MJMM0_SI1255 +MJMM0_SI1885 +MJMM0_SI625 +MJMM0_SX175 +MJMM0_SX265 +MJMM0_SX355 +MJMM0_SX445 +MJMM0_SX85 +MJPG0_SI1191 +MJPG0_SI1821 +MJPG0_SI561 +MJPG0_SX111 +MJPG0_SX201 +MJPG0_SX21 +MJPG0_SX291 +MJPG0_SX381 +MJPM0_SI1368 +MJPM0_SI1998 +MJPM0_SI738 +MJPM0_SX108 +MJPM0_SX18 +MJPM0_SX198 +MJPM0_SX288 +MJPM0_SX378 +MJPM1_SI1897 +MJPM1_SI2280 +MJPM1_SI761 +MJPM1_SX131 +MJPM1_SX221 +MJPM1_SX311 +MJPM1_SX401 +MJPM1_SX41 +MJRA0_SI1236 +MJRA0_SI1866 +MJRA0_SI606 +MJRA0_SX156 +MJRA0_SX246 +MJRA0_SX336 +MJRA0_SX426 +MJRA0_SX66 +MJRG0_SI1366 +MJRG0_SI1996 +MJRG0_SI736 +MJRG0_SX106 +MJRG0_SX16 +MJRG0_SX286 +MJRG0_SX352 +MJRG0_SX376 +MJRH0_SI1125 +MJRH0_SI1755 +MJRH0_SI1840 +MJRH0_SX135 +MJRH0_SX225 +MJRH0_SX315 +MJRH0_SX405 +MJRH0_SX45 +MJRH1_SI1558 +MJRH1_SI1774 +MJRH1_SI514 +MJRH1_SX154 +MJRH1_SX244 +MJRH1_SX334 +MJRH1_SX424 +MJRH1_SX64 +MJRK0_SI1662 +MJRK0_SI2103 +MJRK0_SI880 +MJRK0_SX160 +MJRK0_SX250 +MJRK0_SX340 +MJRK0_SX430 +MJRK0_SX70 +MJRP0_SI1835 +MJRP0_SI1845 +MJRP0_SI585 +MJRP0_SX135 +MJRP0_SX225 +MJRP0_SX315 +MJRP0_SX405 +MJRP0_SX45 +MJSR0_SI1424 +MJSR0_SI2054 +MJSR0_SI794 +MJSR0_SX164 +MJSR0_SX254 +MJSR0_SX344 +MJSR0_SX434 +MJSR0_SX74 +MJWG0_SI2155 +MJWG0_SI813 +MJWG0_SI895 +MJWG0_SX175 +MJWG0_SX265 +MJWG0_SX355 +MJWG0_SX445 +MJWG0_SX85 +MJWS0_SI1143 +MJWS0_SI1773 +MJWS0_SI513 +MJWS0_SX153 +MJWS0_SX243 +MJWS0_SX333 +MJWS0_SX423 +MJWS0_SX63 +MJWT0_SI1291 +MJWT0_SI1381 +MJWT0_SI751 +MJWT0_SX121 +MJWT0_SX211 +MJWT0_SX301 +MJWT0_SX31 
+MJWT0_SX391 +MJXA0_SI1507 +MJXA0_SI2137 +MJXA0_SI877 +MJXA0_SX157 +MJXA0_SX247 +MJXA0_SX337 +MJXA0_SX427 +MJXA0_SX67 +MJXL0_SI1172 +MJXL0_SI1795 +MJXL0_SI542 +MJXL0_SX182 +MJXL0_SX272 +MJXL0_SX362 +MJXL0_SX452 +MJXL0_SX92 +MKAG0_SI1609 +MKAG0_SI2239 +MKAG0_SI979 +MKAG0_SX169 +MKAG0_SX259 +MKAG0_SX30 +MKAG0_SX439 +MKAG0_SX79 +MKAH0_SI1528 +MKAH0_SI2158 +MKAH0_SI898 +MKAH0_SX178 +MKAH0_SX268 +MKAH0_SX358 +MKAH0_SX448 +MKAH0_SX88 +MKAJ0_SI1414 +MKAJ0_SI2044 +MKAJ0_SI784 +MKAJ0_SX154 +MKAJ0_SX244 +MKAJ0_SX334 +MKAJ0_SX424 +MKAJ0_SX64 +MKAM0_SI1250 +MKAM0_SI1316 +MKAM0_SI1465 +MKAM0_SX146 +MKAM0_SX236 +MKAM0_SX326 +MKAM0_SX416 +MKAM0_SX56 +MKDB0_SI2132 +MKDB0_SI588 +MKDB0_SI872 +MKDB0_SX152 +MKDB0_SX242 +MKDB0_SX332 +MKDB0_SX422 +MKDB0_SX62 +MKDD0_SI1567 +MKDD0_SI2197 +MKDD0_SI937 +MKDD0_SX127 +MKDD0_SX217 +MKDD0_SX307 +MKDD0_SX37 +MKDD0_SX397 +MKDT0_SI2153 +MKDT0_SI814 +MKDT0_SI893 +MKDT0_SX173 +MKDT0_SX263 +MKDT0_SX353 +MKDT0_SX443 +MKDT0_SX83 +MKES0_SI1253 +MKES0_SI1883 +MKES0_SI623 +MKES0_SX173 +MKES0_SX263 +MKES0_SX353 +MKES0_SX443 +MKES0_SX83 +MKJO0_SI1517 +MKJO0_SI2147 +MKJO0_SI887 +MKJO0_SX167 +MKJO0_SX257 +MKJO0_SX424 +MKJO0_SX437 +MKJO0_SX77 +MKLN0_SI1598 +MKLN0_SI2228 +MKLN0_SI968 +MKLN0_SX158 +MKLN0_SX248 +MKLN0_SX338 +MKLN0_SX428 +MKLN0_SX68 +MKLR0_SI1059 +MKLR0_SI1689 +MKLR0_SI2319 +MKLR0_SX159 +MKLR0_SX249 +MKLR0_SX339 +MKLR0_SX429 +MKLR0_SX69 +MKLS0_SI1437 +MKLS0_SI1533 +MKLS0_SI2067 +MKLS0_SX177 +MKLS0_SX267 +MKLS0_SX357 +MKLS0_SX447 +MKLS0_SX87 +MKLS1_SI1545 +MKLS1_SI2175 +MKLS1_SI915 +MKLS1_SX105 +MKLS1_SX15 +MKLS1_SX195 +MKLS1_SX285 +MKLS1_SX375 +MKLW0_SI1571 +MKLW0_SI1844 +MKLW0_SI2201 +MKLW0_SX131 +MKLW0_SX221 +MKLW0_SX311 +MKLW0_SX401 +MKLW0_SX41 +MKRG0_SI1491 +MKRG0_SI2121 +MKRG0_SI861 +MKRG0_SX141 +MKRG0_SX231 +MKRG0_SX31 +MKRG0_SX411 +MKRG0_SX51 +MKXL0_SI1185 +MKXL0_SI1815 +MKXL0_SI1958 +MKXL0_SX105 +MKXL0_SX15 +MKXL0_SX195 +MKXL0_SX285 +MKXL0_SX375 +MLBC0_SI1239 +MLBC0_SI1869 +MLBC0_SI609 +MLBC0_SX159 +MLBC0_SX249 +MLBC0_SX339 +MLBC0_SX429 
+MLBC0_SX69 +MLEL0_SI1246 +MLEL0_SI1876 +MLEL0_SI616 +MLEL0_SX166 +MLEL0_SX256 +MLEL0_SX346 +MLEL0_SX436 +MLEL0_SX76 +MLJC0_SI1225 +MLJC0_SI1855 +MLJC0_SI595 +MLJC0_SX145 +MLJC0_SX235 +MLJC0_SX325 +MLJC0_SX415 +MLJC0_SX55 +MLJH0_SI1324 +MLJH0_SI1422 +MLJH0_SI694 +MLJH0_SX154 +MLJH0_SX244 +MLJH0_SX334 +MLJH0_SX424 +MLJH0_SX64 +MLNS0_SI1407 +MLNS0_SI2037 +MLNS0_SI777 +MLNS0_SX147 +MLNS0_SX237 +MLNS0_SX327 +MLNS0_SX417 +MLNS0_SX57 +MLSH0_SI1417 +MLSH0_SI2047 +MLSH0_SI787 +MLSH0_SX157 +MLSH0_SX247 +MLSH0_SX337 +MLSH0_SX427 +MLSH0_SX67 +MMAA0_SI1588 +MMAA0_SI2105 +MMAA0_SI845 +MMAA0_SX125 +MMAA0_SX215 +MMAA0_SX305 +MMAA0_SX35 +MMAA0_SX395 +MMAB1_SI1494 +MMAB1_SI2124 +MMAB1_SI864 +MMAB1_SX144 +MMAB1_SX234 +MMAB1_SX324 +MMAB1_SX414 +MMAB1_SX54 +MMAG0_SI1126 +MMAG0_SI1756 +MMAG0_SI496 +MMAG0_SX136 +MMAG0_SX226 +MMAG0_SX316 +MMAG0_SX406 +MMAG0_SX46 +MMAM0_SI1597 +MMAM0_SI1668 +MMAM0_SI2227 +MMAM0_SX157 +MMAM0_SX247 +MMAM0_SX337 +MMAM0_SX427 +MMAM0_SX67 +MMAR0_SI1336 +MMAR0_SI1966 +MMAR0_SI706 +MMAR0_SX166 +MMAR0_SX256 +MMAR0_SX346 +MMAR0_SX436 +MMAR0_SX76 +MMBS0_SI1151 +MMBS0_SI1781 +MMBS0_SI521 +MMBS0_SX161 +MMBS0_SX251 +MMBS0_SX341 +MMBS0_SX431 +MMBS0_SX71 +MMCC0_SI1338 +MMCC0_SI1968 +MMCC0_SI708 +MMCC0_SX168 +MMCC0_SX258 +MMCC0_SX348 +MMCC0_SX438 +MMCC0_SX78 +MMDB0_SI1358 +MMDB0_SI1617 +MMDB0_SI987 +MMDB0_SX177 +MMDB0_SX267 +MMDB0_SX357 +MMDB0_SX447 +MMDB0_SX87 +MMDG0_SI1780 +MMDG0_SI2035 +MMDG0_SI520 +MMDG0_SX160 +MMDG0_SX250 +MMDG0_SX340 +MMDG0_SX430 +MMDG0_SX70 +MMDM0_SI1311 +MMDM0_SI1941 +MMDM0_SI681 +MMDM0_SX141 +MMDM0_SX231 +MMDM0_SX321 +MMDM0_SX411 +MMDM0_SX51 +MMDM1_SI1650 +MMDM1_SI2043 +MMDM1_SI783 +MMDM1_SX153 +MMDM1_SX243 +MMDM1_SX333 +MMDM1_SX423 +MMDM1_SX63 +MMDS0_SI1343 +MMDS0_SI1973 +MMDS0_SI713 +MMDS0_SX173 +MMDS0_SX263 +MMDS0_SX353 +MMDS0_SX443 +MMDS0_SX83 +MMEA0_SI1388 +MMEA0_SI2018 +MMEA0_SI758 +MMEA0_SX128 +MMEA0_SX218 +MMEA0_SX308 +MMEA0_SX38 +MMEA0_SX398 +MMEB0_SI1357 +MMEB0_SI1987 +MMEB0_SI727 +MMEB0_SX187 +MMEB0_SX327 +MMEB0_SX367 +MMEB0_SX7 
+MMEB0_SX97 +MMGC0_SI1305 +MMGC0_SI1935 +MMGC0_SI2184 +MMGC0_SX135 +MMGC0_SX225 +MMGC0_SX315 +MMGC0_SX405 +MMGC0_SX45 +MMGG0_SI1079 +MMGG0_SI1709 +MMGG0_SI2339 +MMGG0_SX179 +MMGG0_SX269 +MMGG0_SX359 +MMGG0_SX449 +MMGG0_SX89 +MMGK0_SI1322 +MMGK0_SI1952 +MMGK0_SI692 +MMGK0_SX152 +MMGK0_SX242 +MMGK0_SX332 +MMGK0_SX422 +MMGK0_SX62 +MMJB1_SI1408 +MMJB1_SI2038 +MMJB1_SI778 +MMJB1_SX148 +MMJB1_SX238 +MMJB1_SX328 +MMJB1_SX418 +MMJB1_SX58 +MMLM0_SI1527 +MMLM0_SI2150 +MMLM0_SI897 +MMLM0_SX177 +MMLM0_SX267 +MMLM0_SX357 +MMLM0_SX447 +MMLM0_SX87 +MMPM0_SI1061 +MMPM0_SI1691 +MMPM0_SI2321 +MMPM0_SX161 +MMPM0_SX251 +MMPM0_SX341 +MMPM0_SX431 +MMPM0_SX71 +MMRP0_SI2034 +MMRP0_SI717 +MMRP0_SI774 +MMRP0_SX144 +MMRP0_SX234 +MMRP0_SX324 +MMRP0_SX414 +MMRP0_SX54 +MMSM0_SI1106 +MMSM0_SI1736 +MMSM0_SI476 +MMSM0_SX116 +MMSM0_SX206 +MMSM0_SX26 +MMSM0_SX296 +MMSM0_SX386 +MMVP0_SI1284 +MMVP0_SI1914 +MMVP0_SI654 +MMVP0_SX114 +MMVP0_SX204 +MMVP0_SX294 +MMVP0_SX347 +MMVP0_SX384 +MMWB0_SI1619 +MMWB0_SI2249 +MMWB0_SI989 +MMWB0_SX179 +MMWB0_SX269 +MMWB0_SX359 +MMWB0_SX449 +MMWB0_SX89 +MMWS0_SI1518 +MMWS0_SI559 +MMWS0_SI888 +MMWS0_SX168 +MMWS0_SX258 +MMWS0_SX348 +MMWS0_SX438 +MMWS0_SX78 +MMWS1_SI1071 +MMWS1_SI1701 +MMWS1_SI2331 +MMWS1_SX261 +MMWS1_SX27 +MMWS1_SX351 +MMWS1_SX441 +MMWS1_SX81 +MMXS0_SI2136 +MMXS0_SI629 +MMXS0_SI876 +MMXS0_SX156 +MMXS0_SX246 +MMXS0_SX336 +MMXS0_SX426 +MMXS0_SX66 +MNET0_SI1446 +MNET0_SI2076 +MNET0_SI816 +MNET0_SX186 +MNET0_SX276 +MNET0_SX366 +MNET0_SX6 +MNET0_SX96 +MNTW0_SI1068 +MNTW0_SI1698 +MNTW0_SI2328 +MNTW0_SX168 +MNTW0_SX202 +MNTW0_SX258 +MNTW0_SX348 +MNTW0_SX78 +MPAR0_SI1576 +MPAR0_SI2206 +MPAR0_SI946 +MPAR0_SX136 +MPAR0_SX226 +MPAR0_SX316 +MPAR0_SX406 +MPAR0_SX46 +MPEB0_SI1034 +MPEB0_SI1860 +MPEB0_SI600 +MPEB0_SX150 +MPEB0_SX240 +MPEB0_SX330 +MPEB0_SX420 +MPEB0_SX60 +MPFU0_SI1258 +MPFU0_SI1888 +MPFU0_SI628 +MPFU0_SX178 +MPFU0_SX268 +MPFU0_SX358 +MPFU0_SX448 +MPFU0_SX88 +MPGH0_SI1554 +MPGH0_SI675 +MPGH0_SI924 +MPGH0_SX114 +MPGH0_SX204 +MPGH0_SX24 +MPGH0_SX294 
+MPGH0_SX384 +MPGR0_SI1410 +MPGR0_SI2040 +MPGR0_SI780 +MPGR0_SX150 +MPGR0_SX240 +MPGR0_SX330 +MPGR0_SX420 +MPGR0_SX60 +MPGR1_SI1269 +MPGR1_SI1499 +MPGR1_SI2129 +MPGR1_SX149 +MPGR1_SX239 +MPGR1_SX329 +MPGR1_SX419 +MPGR1_SX59 +MPMB0_SI1501 +MPMB0_SI2131 +MPMB0_SI871 +MPMB0_SX151 +MPMB0_SX241 +MPMB0_SX331 +MPMB0_SX421 +MPMB0_SX61 +MPPC0_SI1412 +MPPC0_SI2042 +MPPC0_SI782 +MPPC0_SX152 +MPPC0_SX242 +MPPC0_SX332 +MPPC0_SX422 +MPPC0_SX62 +MPRB0_SI1205 +MPRB0_SI1215 +MPRB0_SI575 +MPRB0_SX125 +MPRB0_SX215 +MPRB0_SX305 +MPRB0_SX35 +MPRB0_SX395 +MPRD0_SI1431 +MPRD0_SI2061 +MPRD0_SI801 +MPRD0_SX171 +MPRD0_SX261 +MPRD0_SX351 +MPRD0_SX441 +MPRD0_SX81 +MPRK0_SI1097 +MPRK0_SI1727 +MPRK0_SI467 +MPRK0_SX107 +MPRK0_SX17 +MPRK0_SX197 +MPRK0_SX287 +MPRK0_SX377 +MPRT0_SI1210 +MPRT0_SI495 +MPRT0_SI580 +MPRT0_SX130 +MPRT0_SX220 +MPRT0_SX310 +MPRT0_SX40 +MPRT0_SX400 +MPSW0_SI1067 +MPSW0_SI1697 +MPSW0_SI2327 +MPSW0_SX167 +MPSW0_SX24 +MPSW0_SX257 +MPSW0_SX437 +MPSW0_SX77 +MRAB0_SI1224 +MRAB0_SI1854 +MRAB0_SI594 +MRAB0_SX144 +MRAB0_SX234 +MRAB0_SX324 +MRAB0_SX414 +MRAB0_SX54 +MRAB1_SI1478 +MRAB1_SI2108 +MRAB1_SI848 +MRAB1_SX128 +MRAB1_SX218 +MRAB1_SX308 +MRAB1_SX38 +MRAB1_SX398 +MRAI0_SI1954 +MRAI0_SI2052 +MRAI0_SI792 +MRAI0_SX162 +MRAI0_SX252 +MRAI0_SX342 +MRAI0_SX432 +MRAI0_SX72 +MRAM0_SI1275 +MRAM0_SI1905 +MRAM0_SI1951 +MRAM0_SX105 +MRAM0_SX15 +MRAM0_SX195 +MRAM0_SX285 +MRAM0_SX375 +MRAV0_SI1008 +MRAV0_SI1638 +MRAV0_SI2268 +MRAV0_SX108 +MRAV0_SX18 +MRAV0_SX198 +MRAV0_SX288 +MRAV0_SX378 +MRBC0_SI1665 +MRBC0_SI1859 +MRBC0_SI599 +MRBC0_SX149 +MRBC0_SX239 +MRBC0_SX329 +MRBC0_SX419 +MRBC0_SX59 +MRCG0_SI1428 +MRCG0_SI2058 +MRCG0_SI798 +MRCG0_SX168 +MRCG0_SX258 +MRCG0_SX348 +MRCG0_SX438 +MRCG0_SX78 +MRCW0_SI1371 +MRCW0_SI2001 +MRCW0_SI741 +MRCW0_SX111 +MRCW0_SX201 +MRCW0_SX21 +MRCW0_SX291 +MRCW0_SX381 +MRDD0_SI1050 +MRDD0_SI1680 +MRDD0_SI2310 +MRDD0_SX150 +MRDD0_SX240 +MRDD0_SX277 +MRDD0_SX330 +MRDD0_SX60 +MRDM0_SI1044 +MRDM0_SI1595 +MRDM0_SI965 +MRDM0_SX155 +MRDM0_SX245 +MRDM0_SX335 +MRDM0_SX425 
+MRDM0_SX65 +MRDS0_SI1167 +MRDS0_SI1797 +MRDS0_SI537 +MRDS0_SX177 +MRDS0_SX267 +MRDS0_SX357 +MRDS0_SX447 +MRDS0_SX87 +MREE0_SI1104 +MREE0_SI1734 +MREE0_SI1959 +MREE0_SX114 +MREE0_SX204 +MREE0_SX24 +MREE0_SX294 +MREE0_SX384 +MREH1_SI1599 +MREH1_SI2229 +MREH1_SI969 +MREH1_SX159 +MREH1_SX249 +MREH1_SX339 +MREH1_SX429 +MREH1_SX69 +MREM0_SI1591 +MREM0_SI511 +MREM0_SI961 +MREM0_SX151 +MREM0_SX241 +MREM0_SX331 +MREM0_SX421 +MREM0_SX61 +MREW1_SI1500 +MREW1_SI2130 +MREW1_SI870 +MREW1_SX150 +MREW1_SX240 +MREW1_SX330 +MREW1_SX420 +MREW1_SX60 +MRFK0_SI1076 +MRFK0_SI1706 +MRFK0_SI2336 +MRFK0_SX176 +MRFK0_SX266 +MRFK0_SX356 +MRFK0_SX446 +MRFK0_SX86 +MRFL0_SI1156 +MRFL0_SI1786 +MRFL0_SI526 +MRFL0_SX166 +MRFL0_SX256 +MRFL0_SX346 +MRFL0_SX436 +MRFL0_SX76 +MRGM0_SI1162 +MRGM0_SI1792 +MRGM0_SI532 +MRGM0_SX172 +MRGM0_SX262 +MRGM0_SX416 +MRGM0_SX442 +MRGM0_SX82 +MRGS0_SI1356 +MRGS0_SI1986 +MRGS0_SI726 +MRGS0_SX186 +MRGS0_SX276 +MRGS0_SX366 +MRGS0_SX6 +MRGS0_SX96 +MRHL0_SI1515 +MRHL0_SI2145 +MRHL0_SI885 +MRHL0_SX165 +MRHL0_SX255 +MRHL0_SX345 +MRHL0_SX435 +MRHL0_SX75 +MRJB1_SI1020 +MRJB1_SI1413 +MRJB1_SI2021 +MRJB1_SX120 +MRJB1_SX210 +MRJB1_SX30 +MRJB1_SX300 +MRJB1_SX390 +MRJH0_SI1519 +MRJH0_SI889 +MRJH0_SI914 +MRJH0_SX169 +MRJH0_SX259 +MRJH0_SX307 +MRJH0_SX439 +MRJH0_SX79 +MRJM0_SI1095 +MRJM0_SI1228 +MRJM0_SI1858 +MRJM0_SX148 +MRJM0_SX238 +MRJM0_SX328 +MRJM0_SX418 +MRJM0_SX58 +MRJM1_SI1298 +MRJM1_SI1928 +MRJM1_SI668 +MRJM1_SX128 +MRJM1_SX218 +MRJM1_SX308 +MRJM1_SX38 +MRJM1_SX398 +MRJT0_SI1498 +MRJT0_SI1805 +MRJT0_SI868 +MRJT0_SX148 +MRJT0_SX238 +MRJT0_SX328 +MRJT0_SX418 +MRJT0_SX58 +MRKM0_SI1267 +MRKM0_SI1391 +MRKM0_SI637 +MRKM0_SX187 +MRKM0_SX277 +MRKM0_SX367 +MRKM0_SX7 +MRKM0_SX97 +MRLD0_SI1594 +MRLD0_SI2224 +MRLD0_SI964 +MRLD0_SX154 +MRLD0_SX244 +MRLD0_SX334 +MRLD0_SX424 +MRLD0_SX64 +MRLJ0_SI1420 +MRLJ0_SI2050 +MRLJ0_SI790 +MRLJ0_SX160 +MRLJ0_SX250 +MRLJ0_SX340 +MRLJ0_SX430 +MRLJ0_SX70 +MRLJ1_SI1671 +MRLJ1_SI2301 +MRLJ1_SI2332 +MRLJ1_SX141 +MRLJ1_SX231 +MRLJ1_SX321 +MRLJ1_SX411 
+MRLJ1_SX51 +MRLK0_SI1468 +MRLK0_SI2140 +MRLK0_SI843 +MRLK0_SX123 +MRLK0_SX213 +MRLK0_SX303 +MRLK0_SX33 +MRLK0_SX393 +MRLR0_SI1196 +MRLR0_SI1826 +MRLR0_SI566 +MRLR0_SX116 +MRLR0_SX206 +MRLR0_SX26 +MRLR0_SX296 +MRLR0_SX386 +MRMB0_SI1581 +MRMB0_SI2211 +MRMB0_SI951 +MRMB0_SX141 +MRMB0_SX231 +MRMB0_SX321 +MRMB0_SX411 +MRMB0_SX51 +MRMG0_SI1080 +MRMG0_SI1710 +MRMG0_SI2340 +MRMG0_SX180 +MRMG0_SX270 +MRMG0_SX360 +MRMG0_SX450 +MRMG0_SX90 +MRMH0_SI1021 +MRMH0_SI1349 +MRMH0_SI2281 +MRMH0_SX121 +MRMH0_SX211 +MRMH0_SX301 +MRMH0_SX31 +MRMH0_SX391 +MRML0_SI1421 +MRML0_SI2051 +MRML0_SI791 +MRML0_SX161 +MRML0_SX251 +MRML0_SX341 +MRML0_SX431 +MRML0_SX71 +MRMS0_SI1113 +MRMS0_SI2057 +MRMS0_SI2100 +MRMS0_SX120 +MRMS0_SX210 +MRMS0_SX30 +MRMS0_SX300 +MRMS0_SX390 +MRPC1_SI1482 +MRPC1_SI2026 +MRPC1_SI2112 +MRPC1_SX132 +MRPC1_SX222 +MRPC1_SX312 +MRPC1_SX402 +MRPC1_SX42 +MRRE0_SI1334 +MRRE0_SI704 +MRRE0_SI952 +MRRE0_SX164 +MRRE0_SX254 +MRRE0_SX344 +MRRE0_SX434 +MRRE0_SX74 +MRSO0_SI1206 +MRSO0_SI1659 +MRSO0_SI2289 +MRSO0_SX129 +MRSO0_SX219 +MRSO0_SX309 +MRSO0_SX39 +MRSO0_SX399 +MRSP0_SI1429 +MRSP0_SI2059 +MRSP0_SI799 +MRSP0_SX169 +MRSP0_SX196 +MRSP0_SX259 +MRSP0_SX439 +MRSP0_SX79 +MRTC0_SI1458 +MRTC0_SI2088 +MRTC0_SI828 +MRTC0_SX108 +MRTC0_SX18 +MRTC0_SX198 +MRTC0_SX288 +MRTC0_SX378 +MRTJ0_SI1551 +MRTJ0_SI2032 +MRTJ0_SI772 +MRTJ0_SX142 +MRTJ0_SX232 +MRTJ0_SX322 +MRTJ0_SX412 +MRTJ0_SX52 +MRVG0_SI1140 +MRVG0_SI1770 +MRVG0_SI510 +MRVG0_SX150 +MRVG0_SX240 +MRVG0_SX330 +MRVG0_SX420 +MRVG0_SX60 +MRWA0_SI1603 +MRWA0_SI2233 +MRWA0_SI973 +MRWA0_SX163 +MRWA0_SX253 +MRWA0_SX343 +MRWA0_SX433 +MRWA0_SX73 +MRWS0_SI1102 +MRWS0_SI1732 +MRWS0_SI472 +MRWS0_SX112 +MRWS0_SX202 +MRWS0_SX22 +MRWS0_SX292 +MRWS0_SX382 +MRXB0_SI1585 +MRXB0_SI2215 +MRXB0_SI955 +MRXB0_SX145 +MRXB0_SX235 +MRXB0_SX325 +MRXB0_SX415 +MRXB0_SX55 +MSAH1_SI1049 +MSAH1_SI1679 +MSAH1_SI2309 +MSAH1_SX149 +MSAH1_SX239 +MSAH1_SX329 +MSAH1_SX419 +MSAH1_SX59 +MSAS0_SI1376 +MSAS0_SI2006 +MSAS0_SI746 +MSAS0_SX116 +MSAS0_SX206 +MSAS0_SX26 +MSAS0_SX296 
+MSAS0_SX386 +MSAT0_SI1526 +MSAT0_SI2156 +MSAT0_SI896 +MSAT0_SX176 +MSAT0_SX266 +MSAT0_SX356 +MSAT0_SX446 +MSAT0_SX86 +MSAT1_SI1073 +MSAT1_SI1703 +MSAT1_SI2333 +MSAT1_SX173 +MSAT1_SX263 +MSAT1_SX353 +MSAT1_SX443 +MSAT1_SX83 +MSDB0_SI1007 +MSDB0_SI1637 +MSDB0_SI2267 +MSDB0_SX107 +MSDB0_SX17 +MSDB0_SX197 +MSDB0_SX287 +MSDB0_SX377 +MSDH0_SI2113 +MSDH0_SI2240 +MSDH0_SI980 +MSDH0_SX170 +MSDH0_SX260 +MSDH0_SX350 +MSDH0_SX440 +MSDH0_SX80 +MSDS0_SI1077 +MSDS0_SI1707 +MSDS0_SI2337 +MSDS0_SX177 +MSDS0_SX267 +MSDS0_SX357 +MSDS0_SX447 +MSDS0_SX87 +MSEM1_SI1440 +MSEM1_SI2070 +MSEM1_SI810 +MSEM1_SX180 +MSEM1_SX270 +MSEM1_SX360 +MSEM1_SX450 +MSEM1_SX90 +MSES0_SI1589 +MSES0_SI2216 +MSES0_SI2219 +MSES0_SX149 +MSES0_SX239 +MSES0_SX329 +MSES0_SX419 +MSES0_SX59 +MSFH0_SI1216 +MSFH0_SI1738 +MSFH0_SI586 +MSFH0_SX136 +MSFH0_SX226 +MSFH0_SX316 +MSFH0_SX406 +MSFH0_SX46 +MSFV0_SI1262 +MSFV0_SI1892 +MSFV0_SI632 +MSFV0_SX182 +MSFV0_SX272 +MSFV0_SX362 +MSFV0_SX452 +MSFV0_SX92 +MSJK0_SI1596 +MSJK0_SI2226 +MSJK0_SI966 +MSJK0_SX156 +MSJK0_SX246 +MSJK0_SX336 +MSJK0_SX426 +MSJK0_SX66 +MSMC0_SI1907 +MSMC0_SI509 +MSMC0_SI647 +MSMC0_SX107 +MSMC0_SX17 +MSMC0_SX197 +MSMC0_SX287 +MSMC0_SX377 +MSMR0_SI1150 +MSMR0_SI1405 +MSMR0_SI775 +MSMR0_SX145 +MSMR0_SX235 +MSMR0_SX325 +MSMR0_SX415 +MSMR0_SX55 +MSMS0_SI1433 +MSMS0_SI2063 +MSMS0_SI803 +MSMS0_SX173 +MSMS0_SX263 +MSMS0_SX353 +MSMS0_SX443 +MSMS0_SX83 +MSRG0_SI1221 +MSRG0_SI1851 +MSRG0_SI591 +MSRG0_SX141 +MSRG0_SX231 +MSRG0_SX321 +MSRG0_SX411 +MSRG0_SX51 +MSRR0_SI1131 +MSRR0_SI1761 +MSRR0_SI501 +MSRR0_SX141 +MSRR0_SX231 +MSRR0_SX30 +MSRR0_SX411 +MSRR0_SX51 +MSTF0_SI1396 +MSTF0_SI766 +MSTF0_SI852 +MSTF0_SX136 +MSTF0_SX226 +MSTF0_SX316 +MSTF0_SX406 +MSTF0_SX46 +MSVS0_SI1568 +MSVS0_SI2198 +MSVS0_SI938 +MSVS0_SX128 +MSVS0_SX218 +MSVS0_SX308 +MSVS0_SX38 +MSVS0_SX398 +MTAB0_SI1572 +MTAB0_SI2202 +MTAB0_SI942 +MTAB0_SX132 +MTAB0_SX222 +MTAB0_SX312 +MTAB0_SX402 +MTAB0_SX42 +MTAS0_SI1385 +MTAS0_SI2015 +MTAS0_SI755 +MTAS0_SX125 +MTAS0_SX215 +MTAS0_SX305 +MTAS0_SX35 
+MTAS0_SX395 +MTAT0_SI1110 +MTAT0_SI1740 +MTAT0_SI811 +MTAT0_SX120 +MTAT0_SX210 +MTAT0_SX30 +MTAT0_SX300 +MTAT0_SX390 +MTAT1_SI1409 +MTAT1_SI1627 +MTAT1_SI779 +MTAT1_SX149 +MTAT1_SX239 +MTAT1_SX329 +MTAT1_SX419 +MTAT1_SX59 +MTBC0_SI1173 +MTBC0_SI1803 +MTBC0_SI543 +MTBC0_SX183 +MTBC0_SX273 +MTBC0_SX347 +MTBC0_SX363 +MTBC0_SX93 +MTCS0_SI1972 +MTCS0_SI2265 +MTCS0_SI712 +MTCS0_SX172 +MTCS0_SX262 +MTCS0_SX352 +MTCS0_SX442 +MTCS0_SX82 +MTDB0_SI1401 +MTDB0_SI2031 +MTDB0_SI771 +MTDB0_SX141 +MTDB0_SX231 +MTDB0_SX321 +MTDB0_SX411 +MTDB0_SX51 +MTDP0_SI1274 +MTDP0_SI1521 +MTDP0_SI2151 +MTDP0_SX171 +MTDP0_SX261 +MTDP0_SX351 +MTDP0_SX441 +MTDP0_SX81 +MTER0_SI1157 +MTER0_SI1787 +MTER0_SI527 +MTER0_SX167 +MTER0_SX17 +MTER0_SX257 +MTER0_SX437 +MTER0_SX77 +MTJG0_SI1520 +MTJG0_SI2157 +MTJG0_SI890 +MTJG0_SX170 +MTJG0_SX260 +MTJG0_SX350 +MTJG0_SX440 +MTJG0_SX80 +MTJM0_SI1226 +MTJM0_SI1856 +MTJM0_SI655 +MTJM0_SX146 +MTJM0_SX236 +MTJM0_SX326 +MTJM0_SX416 +MTJM0_SX56 +MTJS0_SI1192 +MTJS0_SI1822 +MTJS0_SI562 +MTJS0_SX112 +MTJS0_SX202 +MTJS0_SX22 +MTJS0_SX292 +MTJS0_SX382 +MTJU0_SI2020 +MTJU0_SI2269 +MTJU0_SI760 +MTJU0_SX130 +MTJU0_SX220 +MTJU0_SX310 +MTJU0_SX40 +MTJU0_SX400 +MTKD0_SI1187 +MTKD0_SI1817 +MTKD0_SI630 +MTKD0_SX107 +MTKD0_SX17 +MTKD0_SX197 +MTKD0_SX287 +MTKD0_SX377 +MTKP0_SI1023 +MTKP0_SI2283 +MTKP0_SI454 +MTKP0_SX123 +MTKP0_SX213 +MTKP0_SX303 +MTKP0_SX33 +MTKP0_SX393 +MTLB0_SI1134 +MTLB0_SI1764 +MTLB0_SI504 +MTLB0_SX144 +MTLB0_SX234 +MTLB0_SX324 +MTLB0_SX414 +MTLB0_SX54 +MTLC0_SI1313 +MTLC0_SI1477 +MTLC0_SI847 +MTLC0_SX127 +MTLC0_SX217 +MTLC0_SX307 +MTLC0_SX37 +MTLC0_SX397 +MTML0_SI1065 +MTML0_SI1695 +MTML0_SI2325 +MTML0_SX165 +MTML0_SX255 +MTML0_SX345 +MTML0_SX435 +MTML0_SX75 +MTMN0_SI1064 +MTMN0_SI2324 +MTMN0_SI582 +MTMN0_SX164 +MTMN0_SX254 +MTMN0_SX344 +MTMN0_SX434 +MTMN0_SX74 +MTMT0_SI1118 +MTMT0_SI1748 +MTMT0_SI488 +MTMT0_SX128 +MTMT0_SX218 +MTMT0_SX308 +MTMT0_SX38 +MTMT0_SX398 +MTPF0_SI1235 +MTPF0_SI1865 +MTPF0_SI605 +MTPF0_SX155 +MTPF0_SX245 +MTPF0_SX335 +MTPF0_SX425 
+MTPF0_SX65 +MTPG0_SI1383 +MTPG0_SI2013 +MTPG0_SI753 +MTPG0_SX123 +MTPG0_SX213 +MTPG0_SX303 +MTPG0_SX33 +MTPG0_SX393 +MTPP0_SI1508 +MTPP0_SI2138 +MTPP0_SI878 +MTPP0_SX158 +MTPP0_SX248 +MTPP0_SX338 +MTPP0_SX428 +MTPP0_SX68 +MTPR0_SI1600 +MTPR0_SI2230 +MTPR0_SI506 +MTPR0_SX160 +MTPR0_SX250 +MTPR0_SX340 +MTPR0_SX430 +MTPR0_SX70 +MTQC0_SI1441 +MTQC0_SI2071 +MTQC0_SI480 +MTQC0_SX181 +MTQC0_SX271 +MTQC0_SX361 +MTQC0_SX451 +MTQC0_SX91 +MTRC0_SI1623 +MTRC0_SI589 +MTRC0_SI993 +MTRC0_SX170 +MTRC0_SX183 +MTRC0_SX273 +MTRC0_SX363 +MTRC0_SX93 +MTRR0_SI1548 +MTRR0_SI2178 +MTRR0_SI918 +MTRR0_SX108 +MTRR0_SX18 +MTRR0_SX198 +MTRR0_SX288 +MTRR0_SX378 +MTRT0_SI1227 +MTRT0_SI1857 +MTRT0_SI597 +MTRT0_SX147 +MTRT0_SX237 +MTRT0_SX254 +MTRT0_SX417 +MTRT0_SX57 +MTWH1_SI1512 +MTWH1_SI2142 +MTWH1_SI882 +MTWH1_SX162 +MTWH1_SX252 +MTWH1_SX342 +MTWH1_SX432 +MTWH1_SX72 +MTXS0_SI1060 +MTXS0_SI1690 +MTXS0_SI2320 +MTXS0_SX160 +MTXS0_SX250 +MTXS0_SX340 +MTXS0_SX430 +MTXS0_SX70 +MVJH0_SI1556 +MVJH0_SI2186 +MVJH0_SI926 +MVJH0_SX116 +MVJH0_SX206 +MVJH0_SX26 +MVJH0_SX296 +MVJH0_SX386 +MVLO0_SI1147 +MVLO0_SI1777 +MVLO0_SI517 +MVLO0_SX157 +MVLO0_SX247 +MVLO0_SX337 +MVLO0_SX427 +MVLO0_SX67 +MVRW0_SI1485 +MVRW0_SI2115 +MVRW0_SI855 +MVRW0_SX135 +MVRW0_SX225 +MVRW0_SX315 +MVRW0_SX405 +MVRW0_SX45 +MWAC0_SI1601 +MWAC0_SI2231 +MWAC0_SI971 +MWAC0_SX161 +MWAC0_SX251 +MWAC0_SX341 +MWAC0_SX431 +MWAC0_SX71 +MWAD0_SI1062 +MWAD0_SI1749 +MWAD0_SI2322 +MWAD0_SX162 +MWAD0_SX252 +MWAD0_SX342 +MWAD0_SX432 +MWAD0_SX72 +MWAR0_SI1045 +MWAR0_SI1675 +MWAR0_SI2305 +MWAR0_SX145 +MWAR0_SX235 +MWAR0_SX325 +MWAR0_SX415 +MWAR0_SX55 +MWCH0_SI1622 +MWCH0_SI1895 +MWCH0_SI2252 +MWCH0_SX182 +MWCH0_SX272 +MWCH0_SX362 +MWCH0_SX452 +MWCH0_SX92 +MWDK0_SI1436 +MWDK0_SI2017 +MWDK0_SI806 +MWDK0_SX176 +MWDK0_SX266 +MWDK0_SX356 +MWDK0_SX446 +MWDK0_SX86 +MWEM0_SI1320 +MWEM0_SI1393 +MWEM0_SI1950 +MWEM0_SX150 +MWEM0_SX240 +MWEM0_SX330 +MWEM0_SX420 +MWEM0_SX60 +MWGR0_SI1606 +MWGR0_SI2236 +MWGR0_SI976 +MWGR0_SX166 +MWGR0_SX256 +MWGR0_SX346 +MWGR0_SX436 
+MWGR0_SX76 +MWRE0_SI1057 +MWRE0_SI1687 +MWRE0_SI2317 +MWRE0_SX157 +MWRE0_SX247 +MWRE0_SX337 +MWRE0_SX427 +MWRE0_SX67 +MWRP0_SI1443 +MWRP0_SI1525 +MWRP0_SI2073 +MWRP0_SX183 +MWRP0_SX273 +MWRP0_SX3 +MWRP0_SX363 +MWRP0_SX93 +MWSB0_SI1626 +MWSB0_SI2256 +MWSB0_SI996 +MWSB0_SX186 +MWSB0_SX276 +MWSB0_SX366 +MWSB0_SX6 +MWSB0_SX96 +MWSH0_SI1426 +MWSH0_SI2266 +MWSH0_SI796 +MWSH0_SX166 +MWSH0_SX256 +MWSH0_SX346 +MWSH0_SX436 +MWSH0_SX76 +MZMB0_SI1166 +MZMB0_SI1796 +MZMB0_SI536 +MZMB0_SX176 +MZMB0_SX266 +MZMB0_SX356 +MZMB0_SX446 +MZMB0_SX86 diff --git a/examples/wav2vec/unsupervised/config/timit_matched/train_text.uid b/examples/wav2vec/unsupervised/config/timit_matched/train_text.uid new file mode 100644 index 0000000000..c39fd0b91d --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_matched/train_text.uid @@ -0,0 +1,3696 @@ +FAEM0_SI1392 +FAEM0_SI2022 +FAEM0_SI762 +FAEM0_SX132 +FAEM0_SX222 +FAEM0_SX312 +FAEM0_SX402 +FAEM0_SX42 +FAJW0_SI1263 +FAJW0_SI1893 +FAJW0_SI633 +FAJW0_SX183 +FAJW0_SX273 +FAJW0_SX3 +FAJW0_SX363 +FAJW0_SX93 +FALK0_SI1086 +FALK0_SI456 +FALK0_SI658 +FALK0_SX186 +FALK0_SX276 +FALK0_SX366 +FALK0_SX6 +FALK0_SX96 +FALR0_SI1325 +FALR0_SI1955 +FALR0_SI695 +FALR0_SX155 +FALR0_SX245 +FALR0_SX335 +FALR0_SX425 +FALR0_SX65 +FAPB0_SI1063 +FAPB0_SI1693 +FAPB0_SI2323 +FAPB0_SX163 +FAPB0_SX253 +FAPB0_SX343 +FAPB0_SX433 +FAPB0_SX73 +FBAS0_SI1387 +FBAS0_SI1472 +FBAS0_SI2066 +FBAS0_SX127 +FBAS0_SX217 +FBAS0_SX307 +FBAS0_SX37 +FBAS0_SX397 +FBCG1_SI1612 +FBCG1_SI2242 +FBCG1_SI982 +FBCG1_SX172 +FBCG1_SX262 +FBCG1_SX352 +FBCG1_SX442 +FBCG1_SX82 +FBCH0_SI1586 +FBCH0_SI956 +FBCH0_SI959 +FBCH0_SX146 +FBCH0_SX236 +FBCH0_SX326 +FBCH0_SX416 +FBCH0_SX56 +FBJL0_SI1552 +FBJL0_SI2182 +FBJL0_SI922 +FBJL0_SX112 +FBJL0_SX202 +FBJL0_SX22 +FBJL0_SX292 +FBJL0_SX382 +FBLV0_SI1058 +FBLV0_SI1688 +FBLV0_SI2318 +FBLV0_SX158 +FBLV0_SX248 +FBLV0_SX338 +FBLV0_SX428 +FBLV0_SX68 +FBMH0_SI1136 +FBMH0_SI1766 +FBMH0_SI970 +FBMH0_SX146 +FBMH0_SX236 +FBMH0_SX326 +FBMH0_SX416 +FBMH0_SX56 
+FBMJ0_SI1776 +FBMJ0_SI516 +FBMJ0_SI815 +FBMJ0_SX156 +FBMJ0_SX246 +FBMJ0_SX336 +FBMJ0_SX426 +FBMJ0_SX66 +FCAG0_SI1503 +FCAG0_SI1641 +FCAG0_SI2133 +FCAG0_SX153 +FCAG0_SX243 +FCAG0_SX333 +FCAG0_SX423 +FCAG0_SX63 +FCAJ0_SI1479 +FCAJ0_SI1804 +FCAJ0_SI849 +FCAJ0_SX129 +FCAJ0_SX219 +FCAJ0_SX309 +FCAJ0_SX39 +FCAJ0_SX399 +FCDR1_SI1186 +FCDR1_SI1816 +FCDR1_SI556 +FCDR1_SX106 +FCDR1_SX16 +FCDR1_SX196 +FCDR1_SX286 +FCDR1_SX376 +FCEG0_SI1248 +FCEG0_SI1878 +FCEG0_SI618 +FCEG0_SX168 +FCEG0_SX258 +FCEG0_SX348 +FCEG0_SX438 +FCEG0_SX78 +FCJF0_SI1027 +FCJF0_SI1657 +FCJF0_SI648 +FCJF0_SX127 +FCJF0_SX217 +FCJF0_SX307 +FCJF0_SX37 +FCJF0_SX397 +FCJS0_SI1607 +FCJS0_SI2237 +FCJS0_SI977 +FCJS0_SX167 +FCJS0_SX257 +FCJS0_SX347 +FCJS0_SX437 +FCJS0_SX77 +FCKE0_SI1111 +FCKE0_SI1741 +FCKE0_SI481 +FCKE0_SX121 +FCKE0_SX211 +FCKE0_SX301 +FCKE0_SX31 +FCKE0_SX391 +FCLT0_SI1438 +FCLT0_SI2068 +FCLT0_SI808 +FCLT0_SX178 +FCLT0_SX268 +FCLT0_SX358 +FCLT0_SX448 +FCLT0_SX88 +FCMG0_SI1142 +FCMG0_SI1242 +FCMG0_SI1872 +FCMG0_SX162 +FCMG0_SX252 +FCMG0_SX342 +FCMG0_SX432 +FCMG0_SX72 +FCMM0_SI1083 +FCMM0_SI1957 +FCMM0_SI453 +FCMM0_SX183 +FCMM0_SX273 +FCMM0_SX363 +FCMM0_SX420 +FCMM0_SX93 +FCRZ0_SI1913 +FCRZ0_SI2053 +FCRZ0_SI793 +FCRZ0_SX163 +FCRZ0_SX253 +FCRZ0_SX343 +FCRZ0_SX433 +FCRZ0_SX73 +FCYL0_SI1297 +FCYL0_SI1927 +FCYL0_SI667 +FCYL0_SX127 +FCYL0_SX217 +FCYL0_SX349 +FCYL0_SX37 +FCYL0_SX397 +FDAS1_SI1461 +FDAS1_SI2091 +FDAS1_SI831 +FDAS1_SX111 +FDAS1_SX201 +FDAS1_SX21 +FDAS1_SX291 +FDAS1_SX381 +FDAW0_SI1271 +FDAW0_SI1406 +FDAW0_SI2036 +FDAW0_SX146 +FDAW0_SX236 +FDAW0_SX326 +FDAW0_SX416 +FDAW0_SX56 +FDFB0_SI1318 +FDFB0_SI1948 +FDFB0_SI2010 +FDFB0_SX148 +FDFB0_SX238 +FDFB0_SX328 +FDFB0_SX418 +FDFB0_SX58 +FDJH0_SI1565 +FDJH0_SI2195 +FDJH0_SI935 +FDJH0_SX125 +FDJH0_SX215 +FDJH0_SX305 +FDJH0_SX35 +FDJH0_SX395 +FDKN0_SI1081 +FDKN0_SI1202 +FDKN0_SI1711 +FDKN0_SX181 +FDKN0_SX271 +FDKN0_SX361 +FDKN0_SX451 +FDKN0_SX91 +FDML0_SI1149 +FDML0_SI1779 +FDML0_SI2075 +FDML0_SX159 +FDML0_SX249 +FDML0_SX339 +FDML0_SX429 +FDML0_SX69 
+FDMY0_SI1197 +FDMY0_SI567 +FDMY0_SI714 +FDMY0_SX117 +FDMY0_SX207 +FDMY0_SX27 +FDMY0_SX297 +FDMY0_SX387 +FDNC0_SI1278 +FDNC0_SI1908 +FDNC0_SI2287 +FDNC0_SX108 +FDNC0_SX18 +FDNC0_SX198 +FDNC0_SX288 +FDNC0_SX378 +FDTD0_SI1561 +FDTD0_SI2191 +FDTD0_SI931 +FDTD0_SX121 +FDTD0_SX211 +FDTD0_SX301 +FDTD0_SX321 +FDTD0_SX391 +FDXW0_SI1511 +FDXW0_SI2141 +FDXW0_SI881 +FDXW0_SX161 +FDXW0_SX251 +FDXW0_SX341 +FDXW0_SX431 +FDXW0_SX71 +FEAC0_SI1245 +FEAC0_SI1875 +FEAC0_SI615 +FEAC0_SX165 +FEAC0_SX255 +FEAC0_SX345 +FEAC0_SX435 +FEAC0_SX75 +FEAR0_SI1252 +FEAR0_SI1882 +FEAR0_SI622 +FEAR0_SX172 +FEAR0_SX262 +FEAR0_SX352 +FEAR0_SX442 +FEAR0_SX82 +FECD0_SI1418 +FECD0_SI2048 +FECD0_SI788 +FECD0_SX158 +FECD0_SX248 +FECD0_SX338 +FECD0_SX428 +FECD0_SX68 +FEEH0_SI1112 +FEEH0_SI1742 +FEEH0_SI471 +FEEH0_SX122 +FEEH0_SX212 +FEEH0_SX302 +FEEH0_SX32 +FEEH0_SX392 +FEME0_SI1505 +FEME0_SI2135 +FEME0_SI875 +FEME0_SX155 +FEME0_SX245 +FEME0_SX335 +FEME0_SX425 +FEME0_SX65 +FETB0_SI1148 +FETB0_SI1778 +FETB0_SI518 +FETB0_SX158 +FETB0_SX248 +FETB0_SX338 +FETB0_SX428 +FETB0_SX68 +FEXM0_SI1101 +FEXM0_SI1731 +FEXM0_SI482 +FEXM0_SX111 +FEXM0_SX201 +FEXM0_SX291 +FEXM0_SX366 +FEXM0_SX381 +FGCS0_SI1486 +FGCS0_SI2116 +FGCS0_SI856 +FGCS0_SX136 +FGCS0_SX226 +FGCS0_SX316 +FGCS0_SX406 +FGCS0_SX46 +FGDP0_SI1618 +FGDP0_SI2248 +FGDP0_SI988 +FGDP0_SX178 +FGDP0_SX268 +FGDP0_SX358 +FGDP0_SX448 +FGDP0_SX88 +FGMB0_SI1145 +FGMB0_SI1775 +FGMB0_SI515 +FGMB0_SX155 +FGMB0_SX245 +FGMB0_SX335 +FGMB0_SX425 +FGMB0_SX65 +FGRW0_SI1152 +FGRW0_SI1782 +FGRW0_SI1990 +FGRW0_SX162 +FGRW0_SX252 +FGRW0_SX342 +FGRW0_SX432 +FGRW0_SX72 +FHLM0_SI1560 +FHLM0_SI2190 +FHLM0_SI930 +FHLM0_SX120 +FHLM0_SX210 +FHLM0_SX300 +FHLM0_SX349 +FHLM0_SX390 +FHXS0_SI1075 +FHXS0_SI2302 +FHXS0_SI2335 +FHXS0_SX175 +FHXS0_SX265 +FHXS0_SX355 +FHXS0_SX445 +FHXS0_SX85 +FJDM2_SI1582 +FJDM2_SI1964 +FJDM2_SI2212 +FJDM2_SX142 +FJDM2_SX232 +FJDM2_SX322 +FJDM2_SX412 +FJDM2_SX52 +FJEN0_SI1047 +FJEN0_SI1677 +FJEN0_SI2307 +FJEN0_SX147 +FJEN0_SX237 +FJEN0_SX327 +FJEN0_SX417 
+FJEN0_SX57 +FJHK0_SI1022 +FJHK0_SI1652 +FJHK0_SI2282 +FJHK0_SX122 +FJHK0_SX212 +FJHK0_SX302 +FJHK0_SX32 +FJHK0_SX392 +FJKL0_SI1562 +FJKL0_SI2192 +FJKL0_SI932 +FJKL0_SX122 +FJKL0_SX212 +FJKL0_SX302 +FJKL0_SX32 +FJKL0_SX392 +FJLG0_SI1506 +FJLG0_SI1889 +FJLG0_SI2306 +FJLG0_SX179 +FJLG0_SX269 +FJLG0_SX359 +FJLG0_SX449 +FJLG0_SX89 +FJLR0_SI1231 +FJLR0_SI1861 +FJLR0_SI601 +FJLR0_SX151 +FJLR0_SX241 +FJLR0_SX331 +FJLR0_SX421 +FJLR0_SX61 +FJRB0_SI1302 +FJRB0_SI1932 +FJRB0_SI672 +FJRB0_SX132 +FJRB0_SX222 +FJRB0_SX312 +FJRB0_SX402 +FJRB0_SX42 +FJRP1_SI1432 +FJRP1_SI2062 +FJRP1_SI802 +FJRP1_SX172 +FJRP1_SX262 +FJRP1_SX352 +FJRP1_SX442 +FJRP1_SX82 +FJSK0_SI1052 +FJSK0_SI1682 +FJSK0_SI2312 +FJSK0_SX152 +FJSK0_SX242 +FJSK0_SX332 +FJSK0_SX422 +FJSK0_SX62 +FJSP0_SI1434 +FJSP0_SI1763 +FJSP0_SI804 +FJSP0_SX174 +FJSP0_SX264 +FJSP0_SX354 +FJSP0_SX444 +FJSP0_SX84 +FJWB1_SI2055 +FJWB1_SI748 +FJWB1_SI795 +FJWB1_SX165 +FJWB1_SX255 +FJWB1_SX345 +FJWB1_SX435 +FJWB1_SX75 +FJXM0_SI1211 +FJXM0_SI1971 +FJXM0_SI581 +FJXM0_SX131 +FJXM0_SX221 +FJXM0_SX311 +FJXM0_SX401 +FJXM0_SX41 +FJXP0_SI1122 +FJXP0_SI1752 +FJXP0_SI492 +FJXP0_SX132 +FJXP0_SX222 +FJXP0_SX312 +FJXP0_SX402 +FJXP0_SX42 +FKAA0_SI1208 +FKAA0_SI1838 +FKAA0_SI578 +FKAA0_SX128 +FKAA0_SX218 +FKAA0_SX308 +FKAA0_SX38 +FKAA0_SX398 +FKDE0_SI1141 +FKDE0_SI1771 +FKDE0_SI2221 +FKDE0_SX151 +FKDE0_SX241 +FKDE0_SX331 +FKDE0_SX421 +FKDE0_SX61 +FKDW0_SI1207 +FKDW0_SI1891 +FKDW0_SI577 +FKDW0_SX127 +FKDW0_SX217 +FKDW0_SX307 +FKDW0_SX37 +FKDW0_SX397 +FKFB0_SI1608 +FKFB0_SI2238 +FKFB0_SI978 +FKFB0_SX168 +FKFB0_SX258 +FKFB0_SX348 +FKFB0_SX438 +FKFB0_SX78 +FKKH0_SI1290 +FKKH0_SI1920 +FKKH0_SI660 +FKKH0_SX120 +FKKH0_SX210 +FKKH0_SX30 +FKKH0_SX300 +FKKH0_SX390 +FKLC0_SI1615 +FKLC0_SI2245 +FKLC0_SI985 +FKLC0_SX175 +FKLC0_SX265 +FKLC0_SX355 +FKLC0_SX445 +FKLC0_SX85 +FKLC1_SI1048 +FKLC1_SI1678 +FKLC1_SI2308 +FKLC1_SX148 +FKLC1_SX238 +FKLC1_SX328 +FKLC1_SX418 +FKLC1_SX58 +FKLH0_SI1257 +FKLH0_SI1887 +FKLH0_SI627 +FKLH0_SX177 +FKLH0_SX267 +FKLH0_SX357 +FKLH0_SX447 
+FKLH0_SX87 +FKSR0_SI1117 +FKSR0_SI1747 +FKSR0_SI487 +FKSR0_SX161 +FKSR0_SX217 +FKSR0_SX366 +FKSR0_SX37 +FKSR0_SX397 +FLAC0_SI1339 +FLAC0_SI2161 +FLAC0_SI901 +FLAC0_SX181 +FLAC0_SX271 +FLAC0_SX361 +FLAC0_SX451 +FLAC0_SX91 +FLAG0_SI1464 +FLAG0_SI2094 +FLAG0_SI834 +FLAG0_SX114 +FLAG0_SX204 +FLAG0_SX24 +FLAG0_SX294 +FLAG0_SX384 +FLEH0_SI1051 +FLEH0_SI1681 +FLEH0_SI2311 +FLEH0_SX151 +FLEH0_SX241 +FLEH0_SX331 +FLEH0_SX421 +FLEH0_SX61 +FLET0_SI1137 +FLET0_SI1767 +FLET0_SI507 +FLET0_SX147 +FLET0_SX237 +FLET0_SX277 +FLET0_SX417 +FLET0_SX57 +FLHD0_SI1344 +FLHD0_SI1827 +FLHD0_SI1974 +FLHD0_SX174 +FLHD0_SX264 +FLHD0_SX354 +FLHD0_SX444 +FLHD0_SX84 +FLJA0_SI1078 +FLJA0_SI1708 +FLJA0_SI2338 +FLJA0_SX178 +FLJA0_SX268 +FLJA0_SX358 +FLJA0_SX448 +FLJA0_SX88 +FLJD0_SI1516 +FLJD0_SI2146 +FLJD0_SI886 +FLJD0_SX166 +FLJD0_SX256 +FLJD0_SX346 +FLJD0_SX436 +FLJD0_SX76 +FLJG0_SI1611 +FLJG0_SI2241 +FLJG0_SI981 +FLJG0_SX171 +FLJG0_SX261 +FLJG0_SX351 +FLJG0_SX441 +FLJG0_SX81 +FLKM0_SI1880 +FLKM0_SI620 +FLKM0_SI686 +FLKM0_SX116 +FLKM0_SX260 +FLKM0_SX350 +FLKM0_SX440 +FLKM0_SX80 +FLMA0_SI1243 +FLMA0_SI1873 +FLMA0_SI613 +FLMA0_SX163 +FLMA0_SX253 +FLMA0_SX343 +FLMA0_SX433 +FLMA0_SX73 +FLMC0_SI1372 +FLMC0_SI2002 +FLMC0_SI742 +FLMC0_SX112 +FLMC0_SX22 +FLMC0_SX292 +FLMC0_SX336 +FLMC0_SX382 +FLMK0_SI1035 +FLMK0_SI1229 +FLMK0_SI2295 +FLMK0_SX135 +FLMK0_SX225 +FLMK0_SX315 +FLMK0_SX405 +FLMK0_SX45 +FLOD0_SI1287 +FLOD0_SI1917 +FLOD0_SI657 +FLOD0_SX117 +FLOD0_SX171 +FLOD0_SX207 +FLOD0_SX297 +FLOD0_SX387 +FLTM0_SI1070 +FLTM0_SI1700 +FLTM0_SI2330 +FLTM0_SX170 +FLTM0_SX260 +FLTM0_SX350 +FLTM0_SX440 +FLTM0_SX80 +FMAH1_SI1509 +FMAH1_SI2139 +FMAH1_SI879 +FMAH1_SX159 +FMAH1_SX249 +FMAH1_SX339 +FMAH1_SX429 +FMAH1_SX69 +FMBG0_SI1160 +FMBG0_SI1790 +FMBG0_SI2264 +FMBG0_SX260 +FMBG0_SX3 +FMBG0_SX350 +FMBG0_SX440 +FMBG0_SX80 +FMEM0_SI1377 +FMEM0_SI2007 +FMEM0_SI747 +FMEM0_SX117 +FMEM0_SX207 +FMEM0_SX297 +FMEM0_SX333 +FMEM0_SX387 +FMJB0_SI1177 +FMJB0_SI1807 +FMJB0_SI547 +FMJB0_SX187 +FMJB0_SX277 +FMJB0_SX367 +FMJB0_SX7 
+FMJB0_SX97 +FMJF0_SI1254 +FMJF0_SI1884 +FMJF0_SI624 +FMJF0_SX174 +FMJF0_SX264 +FMJF0_SX354 +FMJF0_SX444 +FMJF0_SX84 +FMJU0_SI1389 +FMJU0_SI2019 +FMJU0_SI759 +FMJU0_SX129 +FMJU0_SX219 +FMJU0_SX309 +FMJU0_SX39 +FMJU0_SX399 +FMKC0_SI1041 +FMKC0_SI1072 +FMKC0_SI1702 +FMKC0_SX172 +FMKC0_SX262 +FMKC0_SX352 +FMKC0_SX442 +FMKC0_SX82 +FMKF0_SI1018 +FMKF0_SI1536 +FMKF0_SI906 +FMKF0_SX186 +FMKF0_SX276 +FMKF0_SX366 +FMKF0_SX6 +FMKF0_SX96 +FMMH0_SI1537 +FMMH0_SI2167 +FMMH0_SI907 +FMMH0_SX187 +FMMH0_SX367 +FMMH0_SX420 +FMMH0_SX7 +FMMH0_SX97 +FMPG0_SI1602 +FMPG0_SI2232 +FMPG0_SI972 +FMPG0_SX162 +FMPG0_SX252 +FMPG0_SX342 +FMPG0_SX432 +FMPG0_SX72 +FNKL0_SI1522 +FNKL0_SI2152 +FNKL0_SI892 +FNKL0_SX172 +FNKL0_SX196 +FNKL0_SX262 +FNKL0_SX442 +FNKL0_SX82 +FNTB0_SI1203 +FNTB0_SI573 +FNTB0_SI679 +FNTB0_SX123 +FNTB0_SX213 +FNTB0_SX303 +FNTB0_SX33 +FNTB0_SX393 +FPAB1_SI1471 +FPAB1_SI2101 +FPAB1_SI841 +FPAB1_SX121 +FPAB1_SX211 +FPAB1_SX301 +FPAB1_SX31 +FPAB1_SX391 +FPAC0_SI1921 +FPAC0_SI2011 +FPAC0_SI661 +FPAC0_SX121 +FPAC0_SX211 +FPAC0_SX301 +FPAC0_SX31 +FPAC0_SX391 +FPAD0_SI1346 +FPAD0_SI1976 +FPAD0_SI716 +FPAD0_SX176 +FPAD0_SX266 +FPAD0_SX356 +FPAD0_SX446 +FPAD0_SX86 +FPAF0_SI1054 +FPAF0_SI1684 +FPAF0_SI2314 +FPAF0_SX154 +FPAF0_SX244 +FPAF0_SX334 +FPAF0_SX424 +FPAF0_SX64 +FPAZ0_SI1593 +FPAZ0_SI2223 +FPAZ0_SI963 +FPAZ0_SX153 +FPAZ0_SX243 +FPAZ0_SX27 +FPAZ0_SX423 +FPAZ0_SX63 +FPJF0_SI1046 +FPJF0_SI1259 +FPJF0_SI1676 +FPJF0_SX146 +FPJF0_SX236 +FPJF0_SX326 +FPJF0_SX352 +FPJF0_SX56 +FPLS0_SI1590 +FPLS0_SI2220 +FPLS0_SI960 +FPLS0_SX150 +FPLS0_SX240 +FPLS0_SX3 +FPLS0_SX330 +FPLS0_SX60 +FPMY0_SI1153 +FPMY0_SI1783 +FPMY0_SI523 +FPMY0_SX163 +FPMY0_SX196 +FPMY0_SX253 +FPMY0_SX343 +FPMY0_SX73 +FREH0_SI1315 +FREH0_SI1945 +FREH0_SI685 +FREH0_SX145 +FREH0_SX235 +FREH0_SX325 +FREH0_SX415 +FREH0_SX55 +FRJB0_SI1427 +FRJB0_SI1470 +FRJB0_SI1794 +FRJB0_SX167 +FRJB0_SX257 +FRJB0_SX347 +FRJB0_SX437 +FRJB0_SX77 +FRLL0_SI1514 +FRLL0_SI805 +FRLL0_SI884 +FRLL0_SX164 +FRLL0_SX254 +FRLL0_SX344 +FRLL0_SX434 
+FRLL0_SX74 +FSAG0_SI1323 +FSAG0_SI1953 +FSAG0_SI693 +FSAG0_SX153 +FSAG0_SX243 +FSAG0_SX333 +FSAG0_SX423 +FSAG0_SX63 +FSAH0_SI1244 +FSAH0_SI1874 +FSAH0_SI614 +FSAH0_SX164 +FSAH0_SX327 +FSAH0_SX344 +FSAH0_SX434 +FSAH0_SX74 +FSAK0_SI1300 +FSAK0_SI1930 +FSAK0_SI670 +FSAK0_SX130 +FSAK0_SX220 +FSAK0_SX310 +FSAK0_SX40 +FSAK0_SX400 +FSBK0_SI1069 +FSBK0_SI1699 +FSBK0_SI2329 +FSBK0_SX169 +FSBK0_SX259 +FSBK0_SX349 +FSBK0_SX439 +FSBK0_SX79 +FSCN0_SI1886 +FSCN0_SI626 +FSCN0_SI705 +FSCN0_SX176 +FSCN0_SX266 +FSCN0_SX356 +FSCN0_SX446 +FSCN0_SX86 +FSDC0_SI1312 +FSDC0_SI1942 +FSDC0_SI2234 +FSDC0_SX142 +FSDC0_SX232 +FSDC0_SX322 +FSDC0_SX412 +FSDC0_SX52 +FSDJ0_SI1115 +FSDJ0_SI1745 +FSDJ0_SI485 +FSDJ0_SX125 +FSDJ0_SX215 +FSDJ0_SX305 +FSDJ0_SX35 +FSDJ0_SX395 +FSGF0_SI1557 +FSGF0_SI2187 +FSGF0_SI927 +FSGF0_SX117 +FSGF0_SX207 +FSGF0_SX27 +FSGF0_SX297 +FSGF0_SX387 +FSJG0_SI1570 +FSJG0_SI2200 +FSJG0_SI940 +FSJG0_SX130 +FSJG0_SX220 +FSJG0_SX310 +FSJG0_SX40 +FSJG0_SX400 +FSJK1_SI1025 +FSJK1_SI2285 +FSJK1_SI696 +FSJK1_SX125 +FSJK1_SX215 +FSJK1_SX305 +FSJK1_SX35 +FSJK1_SX395 +FSJS0_SI1171 +FSJS0_SI1801 +FSJS0_SI541 +FSJS0_SX181 +FSJS0_SX271 +FSJS0_SX361 +FSJS0_SX451 +FSJS0_SX91 +FSJW0_SI1333 +FSJW0_SI1963 +FSJW0_SI703 +FSJW0_SX163 +FSJW0_SX253 +FSJW0_SX343 +FSJW0_SX433 +FSJW0_SX73 +FSKC0_SI1416 +FSKC0_SI2046 +FSKC0_SI786 +FSKC0_SX156 +FSKC0_SX246 +FSKC0_SX336 +FSKC0_SX426 +FSKC0_SX66 +FSKL0_SI1529 +FSKL0_SI2159 +FSKL0_SI899 +FSKL0_SX179 +FSKL0_SX269 +FSKL0_SX359 +FSKL0_SX449 +FSKL0_SX89 +FSKP0_SI1098 +FSKP0_SI1728 +FSKP0_SI468 +FSKP0_SX108 +FSKP0_SX18 +FSKP0_SX198 +FSKP0_SX288 +FSKP0_SX378 +FSLS0_SI1056 +FSLS0_SI1686 +FSLS0_SI2316 +FSLS0_SX156 +FSLS0_SX202 +FSLS0_SX246 +FSLS0_SX426 +FSLS0_SX66 +FSMA0_SI1621 +FSMA0_SI2251 +FSMA0_SI991 +FSMA0_SX181 +FSMA0_SX271 +FSMA0_SX361 +FSMA0_SX451 +FSMA0_SX91 +FSMM0_SI1314 +FSMM0_SI1944 +FSMM0_SI684 +FSMM0_SX144 +FSMM0_SX234 +FSMM0_SX324 +FSMM0_SX414 +FSMM0_SX54 +FSMS1_SI1504 +FSMS1_SI2134 +FSMS1_SI874 +FSMS1_SX154 +FSMS1_SX244 +FSMS1_SX334 +FSMS1_SX347 
+FSMS1_SX64 +FSPM0_SI1241 +FSPM0_SI1871 +FSPM0_SI611 +FSPM0_SX161 +FSPM0_SX251 +FSPM0_SX341 +FSPM0_SX431 +FSPM0_SX71 +FSRH0_SI1719 +FSRH0_SI1931 +FSRH0_SI671 +FSRH0_SX131 +FSRH0_SX221 +FSRH0_SX311 +FSRH0_SX401 +FSRH0_SX41 +FSSB0_SI1082 +FSSB0_SI1712 +FSSB0_SI2342 +FSSB0_SX182 +FSSB0_SX272 +FSSB0_SX362 +FSSB0_SX452 +FSSB0_SX92 +FTAJ0_SI1329 +FTAJ0_SI474 +FTAJ0_SI699 +FTAJ0_SX159 +FTAJ0_SX249 +FTAJ0_SX339 +FTAJ0_SX429 +FTAJ0_SX69 +FTBR0_SI1402 +FTBR0_SI2181 +FTBR0_SI921 +FTBR0_SX111 +FTBR0_SX201 +FTBR0_SX21 +FTBR0_SX291 +FTBR0_SX381 +FTBW0_SI1345 +FTBW0_SI1975 +FTBW0_SI715 +FTBW0_SX175 +FTBW0_SX265 +FTBW0_SX355 +FTBW0_SX445 +FTBW0_SX85 +FTLG0_SI1743 +FTLG0_SI483 +FTLG0_SI840 +FTLG0_SX123 +FTLG0_SX213 +FTLG0_SX303 +FTLG0_SX33 +FTLG0_SX393 +FTMG0_SI1532 +FTMG0_SI2162 +FTMG0_SI902 +FTMG0_SX182 +FTMG0_SX272 +FTMG0_SX362 +FTMG0_SX452 +FTMG0_SX92 +FVFB0_SI1032 +FVFB0_SI1510 +FVFB0_SI2292 +FVFB0_SX132 +FVFB0_SX222 +FVFB0_SX312 +FVFB0_SX402 +FVFB0_SX42 +FVKB0_SI1159 +FVKB0_SI1789 +FVKB0_SI529 +FVKB0_SX169 +FVKB0_SX259 +FVKB0_SX349 +FVKB0_SX439 +FVKB0_SX79 +FVMH0_SI1466 +FVMH0_SI2096 +FVMH0_SI836 +FVMH0_SX116 +FVMH0_SX206 +FVMH0_SX26 +FVMH0_SX296 +FVMH0_SX386 +MABC0_SI1620 +MABC0_SI2041 +MABC0_SI781 +MABC0_SX151 +MABC0_SX241 +MABC0_SX331 +MABC0_SX421 +MABC0_SX61 +MADC0_SI1367 +MADC0_SI1997 +MADC0_SI737 +MADC0_SX107 +MADC0_SX17 +MADC0_SX197 +MADC0_SX287 +MADC0_SX377 +MADD0_SI1295 +MADD0_SI1798 +MADD0_SI538 +MADD0_SX178 +MADD0_SX268 +MADD0_SX358 +MADD0_SX448 +MADD0_SX88 +MAEB0_SI1411 +MAEB0_SI2250 +MAEB0_SI990 +MAEB0_SX180 +MAEB0_SX270 +MAEB0_SX360 +MAEB0_SX450 +MAEB0_SX90 +MAEO0_SI1326 +MAEO0_SI1655 +MAEO0_SI1956 +MAEO0_SX156 +MAEO0_SX246 +MAEO0_SX336 +MAEO0_SX426 +MAEO0_SX66 +MAFM0_SI1569 +MAFM0_SI2199 +MAFM0_SI939 +MAFM0_SX129 +MAFM0_SX219 +MAFM0_SX309 +MAFM0_SX39 +MAFM0_SX399 +MAJP0_SI1074 +MAJP0_SI1704 +MAJP0_SI2334 +MAJP0_SX174 +MAJP0_SX264 +MAJP0_SX354 +MAJP0_SX444 +MAJP0_SX84 +MAKB0_SI1016 +MAKB0_SI1646 +MAKB0_SI2276 +MAKB0_SX116 +MAKB0_SX206 +MAKB0_SX26 +MAKB0_SX296 
+MAKB0_SX386 +MAKR0_SI1352 +MAKR0_SI1982 +MAKR0_SI722 +MAKR0_SX182 +MAKR0_SX272 +MAKR0_SX362 +MAKR0_SX452 +MAKR0_SX92 +MAPV0_SI1293 +MAPV0_SI1923 +MAPV0_SI663 +MAPV0_SX123 +MAPV0_SX213 +MAPV0_SX303 +MAPV0_SX33 +MAPV0_SX393 +MARC0_SI1188 +MARC0_SI1818 +MARC0_SI558 +MARC0_SX108 +MARC0_SX18 +MARC0_SX198 +MARC0_SX288 +MARC0_SX378 +MARW0_SI1276 +MARW0_SI1906 +MARW0_SI646 +MARW0_SX106 +MARW0_SX16 +MARW0_SX286 +MARW0_SX349 +MARW0_SX376 +MBAR0_SI1319 +MBAR0_SI1949 +MBAR0_SI689 +MBAR0_SX149 +MBAR0_SX239 +MBAR0_SX329 +MBAR0_SX419 +MBAR0_SX59 +MBBR0_SI1055 +MBBR0_SI1685 +MBBR0_SI2315 +MBBR0_SX155 +MBBR0_SX245 +MBBR0_SX335 +MBBR0_SX425 +MBBR0_SX65 +MBCG0_SI2217 +MBCG0_SI486 +MBCG0_SI957 +MBCG0_SX147 +MBCG0_SX237 +MBCG0_SX327 +MBCG0_SX417 +MBCG0_SX57 +MBEF0_SI1281 +MBEF0_SI1911 +MBEF0_SI651 +MBEF0_SX111 +MBEF0_SX201 +MBEF0_SX21 +MBEF0_SX291 +MBEF0_SX381 +MBGT0_SI1341 +MBGT0_SI1841 +MBGT0_SI711 +MBGT0_SX171 +MBGT0_SX261 +MBGT0_SX351 +MBGT0_SX441 +MBGT0_SX81 +MBJV0_SI1247 +MBJV0_SI1877 +MBJV0_SI617 +MBJV0_SX167 +MBJV0_SX257 +MBJV0_SX347 +MBJV0_SX437 +MBJV0_SX77 +MBMA0_SI1222 +MBMA0_SI1852 +MBMA0_SI592 +MBMA0_SX142 +MBMA0_SX232 +MBMA0_SX322 +MBMA0_SX412 +MBMA0_SX52 +MBMA1_SI2207 +MBMA1_SI2214 +MBMA1_SI954 +MBMA1_SX144 +MBMA1_SX234 +MBMA1_SX324 +MBMA1_SX414 +MBMA1_SX54 +MBML0_SI1169 +MBML0_SI1799 +MBML0_SI539 +MBML0_SX179 +MBML0_SX269 +MBML0_SX359 +MBML0_SX449 +MBML0_SX89 +MBOM0_SI1014 +MBOM0_SI1644 +MBOM0_SI2274 +MBOM0_SX114 +MBOM0_SX204 +MBOM0_SX294 +MBOM0_SX311 +MBOM0_SX384 +MBSB0_SI1353 +MBSB0_SI1983 +MBSB0_SI723 +MBSB0_SX183 +MBSB0_SX273 +MBSB0_SX3 +MBSB0_SX363 +MBSB0_SX93 +MBTH0_SI2102 +MBTH0_SI505 +MBTH0_SI757 +MBTH0_SX122 +MBTH0_SX212 +MBTH0_SX302 +MBTH0_SX32 +MBTH0_SX392 +MBWP0_SI1531 +MBWP0_SI1969 +MBWP0_SI709 +MBWP0_SX169 +MBWP0_SX259 +MBWP0_SX349 +MBWP0_SX439 +MBWP0_SX79 +MCAE0_SI1447 +MCAE0_SI2077 +MCAE0_SI817 +MCAE0_SX187 +MCAE0_SX277 +MCAE0_SX367 +MCAE0_SX7 +MCAE0_SX97 +MCAL0_SI1138 +MCAL0_SI1768 +MCAL0_SI508 +MCAL0_SX148 +MCAL0_SX238 +MCAL0_SX328 +MCAL0_SX418 
+MCAL0_SX58 +MCDC0_SI1292 +MCDC0_SI1922 +MCDC0_SI662 +MCDC0_SX122 +MCDC0_SX212 +MCDC0_SX302 +MCDC0_SX32 +MCDC0_SX392 +MCDD0_SI1513 +MCDD0_SI2143 +MCDD0_SI883 +MCDD0_SX163 +MCDD0_SX253 +MCDD0_SX343 +MCDD0_SX433 +MCDD0_SX73 +MCDR0_SI1154 +MCDR0_SI1784 +MCDR0_SI524 +MCDR0_SX164 +MCDR0_SX254 +MCDR0_SX344 +MCDR0_SX434 +MCDR0_SX74 +MCEF0_SI1135 +MCEF0_SI1765 +MCEF0_SI842 +MCEF0_SX145 +MCEF0_SX235 +MCEF0_SX325 +MCEF0_SX415 +MCEF0_SX55 +MCEW0_SI1442 +MCEW0_SI2072 +MCEW0_SI812 +MCEW0_SX182 +MCEW0_SX272 +MCEW0_SX362 +MCEW0_SX452 +MCEW0_SX92 +MCHL0_SI1347 +MCHL0_SI1404 +MCHL0_SI1977 +MCHL0_SX177 +MCHL0_SX267 +MCHL0_SX357 +MCHL0_SX447 +MCHL0_SX87 +MCLK0_SI1660 +MCLK0_SI2290 +MCLK0_SI650 +MCLK0_SX130 +MCLK0_SX220 +MCLK0_SX310 +MCLK0_SX40 +MCLK0_SX400 +MCLM0_SI1456 +MCLM0_SI2086 +MCLM0_SI826 +MCLM0_SX106 +MCLM0_SX16 +MCLM0_SX196 +MCLM0_SX286 +MCLM0_SX376 +MCPM0_SI1194 +MCPM0_SI1824 +MCPM0_SI564 +MCPM0_SX114 +MCPM0_SX204 +MCPM0_SX24 +MCPM0_SX294 +MCPM0_SX384 +MCRE0_SI1121 +MCRE0_SI1725 +MCRE0_SI1751 +MCRE0_SX131 +MCRE0_SX221 +MCRE0_SX24 +MCRE0_SX401 +MCRE0_SX41 +MCSS0_SI1380 +MCSS0_SI688 +MCSS0_SI750 +MCSS0_SX120 +MCSS0_SX210 +MCSS0_SX30 +MCSS0_SX300 +MCSS0_SX390 +MCTH0_SI1209 +MCTH0_SI1839 +MCTH0_SI579 +MCTH0_SX129 +MCTH0_SX219 +MCTH0_SX309 +MCTH0_SX39 +MCTH0_SX399 +MCTM0_SI1350 +MCTM0_SI1980 +MCTM0_SI720 +MCTM0_SX180 +MCTM0_SX270 +MCTM0_SX360 +MCTM0_SX450 +MCTM0_SX90 +MCXM0_SI1351 +MCXM0_SI1981 +MCXM0_SI721 +MCXM0_SX181 +MCXM0_SX271 +MCXM0_SX361 +MCXM0_SX451 +MCXM0_SX91 +MDAC0_SI1261 +MDAC0_SI1837 +MDAC0_SI631 +MDAC0_SX181 +MDAC0_SX271 +MDAC0_SX361 +MDAC0_SX451 +MDAC0_SX91 +MDAS0_SI1266 +MDAS0_SI1896 +MDAS0_SI636 +MDAS0_SX186 +MDAS0_SX21 +MDAS0_SX276 +MDAS0_SX6 +MDAS0_SX96 +MDBB1_SI1006 +MDBB1_SI1636 +MDBB1_SI2056 +MDBB1_SX106 +MDBB1_SX16 +MDBB1_SX196 +MDBB1_SX286 +MDBB1_SX376 +MDBP0_SI1158 +MDBP0_SI1788 +MDBP0_SI528 +MDBP0_SX168 +MDBP0_SX258 +MDBP0_SX348 +MDBP0_SX438 +MDBP0_SX78 +MDCD0_SI1415 +MDCD0_SI2045 +MDCD0_SI785 +MDCD0_SX155 +MDCD0_SX245 +MDCD0_SX335 +MDCD0_SX425 
+MDCD0_SX65 +MDCM0_SI1480 +MDCM0_SI2110 +MDCM0_SI850 +MDCM0_SX130 +MDCM0_SX220 +MDCM0_SX310 +MDCM0_SX40 +MDCM0_SX400 +MDDC0_SI1419 +MDDC0_SI2049 +MDDC0_SI789 +MDDC0_SX159 +MDDC0_SX249 +MDDC0_SX339 +MDDC0_SX429 +MDDC0_SX69 +MDED0_SI1170 +MDED0_SI1800 +MDED0_SI540 +MDED0_SX180 +MDED0_SX270 +MDED0_SX360 +MDED0_SX450 +MDED0_SX90 +MDEF0_SI1123 +MDEF0_SI1563 +MDEF0_SI2193 +MDEF0_SX123 +MDEF0_SX213 +MDEF0_SX303 +MDEF0_SX33 +MDEF0_SX393 +MDEM0_SI1868 +MDEM0_SI608 +MDEM0_SI800 +MDEM0_SX158 +MDEM0_SX248 +MDEM0_SX338 +MDEM0_SX428 +MDEM0_SX68 +MDHL0_SI1439 +MDHL0_SI2069 +MDHL0_SI809 +MDHL0_SX179 +MDHL0_SX269 +MDHL0_SX359 +MDHL0_SX449 +MDHL0_SX89 +MDHS0_SI1530 +MDHS0_SI2160 +MDHS0_SI900 +MDHS0_SX180 +MDHS0_SX270 +MDHS0_SX360 +MDHS0_SX450 +MDHS0_SX90 +MDJM0_SI1455 +MDJM0_SI2085 +MDJM0_SI825 +MDJM0_SX105 +MDJM0_SX15 +MDJM0_SX195 +MDJM0_SX285 +MDJM0_SX375 +MDKS0_SI1066 +MDKS0_SI1696 +MDKS0_SI2326 +MDKS0_SX166 +MDKS0_SX256 +MDKS0_SX346 +MDKS0_SX436 +MDKS0_SX76 +MDLB0_SI1306 +MDLB0_SI1936 +MDLB0_SI676 +MDLB0_SX136 +MDLB0_SX226 +MDLB0_SX316 +MDLB0_SX406 +MDLB0_SX46 +MDLC0_SI1395 +MDLC0_SI2025 +MDLC0_SI765 +MDLC0_SX135 +MDLC0_SX225 +MDLC0_SX315 +MDLC0_SX405 +MDLC0_SX45 +MDLC1_SI1435 +MDLC1_SI2065 +MDLC1_SI2144 +MDLC1_SX175 +MDLC1_SX265 +MDLC1_SX355 +MDLC1_SX445 +MDLC1_SX85 +MDLC2_SI1614 +MDLC2_SI2244 +MDLC2_SI984 +MDLC2_SX174 +MDLC2_SX264 +MDLC2_SX354 +MDLC2_SX444 +MDLC2_SX84 +MDLH0_SI1960 +MDLH0_SI574 +MDLH0_SI700 +MDLH0_SX160 +MDLH0_SX250 +MDLH0_SX340 +MDLH0_SX430 +MDLH0_SX70 +MDLM0_SI1234 +MDLM0_SI1864 +MDLM0_SI604 +MDLM0_SX154 +MDLM0_SX244 +MDLM0_SX334 +MDLM0_SX424 +MDLM0_SX64 +MDLR0_SI1233 +MDLR0_SI1863 +MDLR0_SI603 +MDLR0_SX153 +MDLR0_SX243 +MDLR0_SX333 +MDLR0_SX423 +MDLR0_SX63 +MDLR1_SI1299 +MDLR1_SI1929 +MDLR1_SI669 +MDLR1_SX129 +MDLR1_SX219 +MDLR1_SX309 +MDLR1_SX39 +MDLR1_SX399 +MDMA0_SI1238 +MDMA0_SI1430 +MDMA0_SI2060 +MDMA0_SX170 +MDMA0_SX260 +MDMA0_SX350 +MDMA0_SX440 +MDMA0_SX80 +MDMT0_SI1832 +MDMT0_SI2341 +MDMT0_SI572 +MDMT0_SX122 +MDMT0_SX212 +MDMT0_SX302 +MDMT0_SX32 
+MDMT0_SX392 +MDNS0_SI1011 +MDNS0_SI2271 +MDNS0_SI873 +MDNS0_SX111 +MDNS0_SX201 +MDNS0_SX21 +MDNS0_SX291 +MDNS0_SX381 +MDPB0_SI1760 +MDPB0_SI2126 +MDPB0_SI866 +MDPB0_SX146 +MDPB0_SX236 +MDPB0_SX326 +MDPB0_SX416 +MDPB0_SX56 +MDPK0_SI1053 +MDPK0_SI1683 +MDPK0_SI552 +MDPK0_SX153 +MDPK0_SX243 +MDPK0_SX333 +MDPK0_SX423 +MDPK0_SX63 +MDPS0_SI1651 +MDPS0_SI1979 +MDPS0_SI719 +MDPS0_SX179 +MDPS0_SX269 +MDPS0_SX359 +MDPS0_SX449 +MDPS0_SX89 +MDRD0_SI1382 +MDRD0_SI2012 +MDRD0_SI752 +MDRD0_SX122 +MDRD0_SX212 +MDRD0_SX302 +MDRD0_SX32 +MDRD0_SX392 +MDSJ0_SI1462 +MDSJ0_SI2092 +MDSJ0_SI832 +MDSJ0_SX112 +MDSJ0_SX22 +MDSJ0_SX292 +MDSJ0_SX382 +MDSJ0_SX438 +MDSS0_SI1881 +MDSS0_SI2087 +MDSS0_SI621 +MDSS0_SX171 +MDSS0_SX261 +MDSS0_SX351 +MDSS0_SX441 +MDSS0_SX81 +MDSS1_SI1327 +MDSS1_SI1713 +MDSS1_SI697 +MDSS1_SX157 +MDSS1_SX247 +MDSS1_SX337 +MDSS1_SX427 +MDSS1_SX67 +MDTB0_SI1200 +MDTB0_SI1830 +MDTB0_SI570 +MDTB0_SX120 +MDTB0_SX210 +MDTB0_SX300 +MDTB0_SX321 +MDTB0_SX390 +MDWD0_SI1260 +MDWD0_SI1890 +MDWD0_SI557 +MDWD0_SX180 +MDWD0_SX270 +MDWD0_SX360 +MDWD0_SX450 +MDWD0_SX90 +MDWH0_SI1168 +MDWH0_SI1925 +MDWH0_SI665 +MDWH0_SX125 +MDWH0_SX215 +MDWH0_SX305 +MDWH0_SX35 +MDWH0_SX395 +MDWM0_SI1546 +MDWM0_SI2176 +MDWM0_SI916 +MDWM0_SX106 +MDWM0_SX16 +MDWM0_SX286 +MDWM0_SX376 +MDWM0_SX433 +MEAL0_SI1547 +MEAL0_SI2177 +MEAL0_SI917 +MEAL0_SX107 +MEAL0_SX197 +MEAL0_SX287 +MEAL0_SX347 +MEAL0_SX377 +MEDR0_SI1374 +MEDR0_SI2004 +MEDR0_SI744 +MEDR0_SX114 +MEDR0_SX204 +MEDR0_SX24 +MEDR0_SX294 +MEDR0_SX384 +MEFG0_SI465 +MEFG0_SI491 +MEFG0_SI598 +MEFG0_SX105 +MEFG0_SX15 +MEFG0_SX195 +MEFG0_SX285 +MEFG0_SX375 +MEGJ0_SI1337 +MEGJ0_SI1967 +MEGJ0_SI707 +MEGJ0_SX167 +MEGJ0_SX257 +MEGJ0_SX3 +MEGJ0_SX437 +MEGJ0_SX77 +MEJL0_SI1592 +MEJL0_SI1654 +MEJL0_SI962 +MEJL0_SX152 +MEJL0_SX242 +MEJL0_SX332 +MEJL0_SX422 +MEJL0_SX62 +MEJS0_SI1240 +MEJS0_SI1870 +MEJS0_SI610 +MEJS0_SX160 +MEJS0_SX250 +MEJS0_SX340 +MEJS0_SX430 +MEJS0_SX70 +MESG0_SI1332 +MESG0_SI1962 +MESG0_SI702 +MESG0_SX162 +MESG0_SX252 +MESG0_SX342 +MESG0_SX432 
+MESG0_SX72 +MESJ0_SI2039 +MESJ0_SI2257 +MESJ0_SI997 +MESJ0_SX187 +MESJ0_SX277 +MESJ0_SX367 +MESJ0_SX7 +MESJ0_SX97 +MEWM0_SI1348 +MEWM0_SI1978 +MEWM0_SI718 +MEWM0_SX178 +MEWM0_SX268 +MEWM0_SX358 +MEWM0_SX448 +MEWM0_SX88 +MFER0_SI1492 +MFER0_SI2122 +MFER0_SI862 +MFER0_SX142 +MFER0_SX232 +MFER0_SX322 +MFER0_SX412 +MFER0_SX52 +MFMC0_SI1132 +MFMC0_SI1762 +MFMC0_SI502 +MFMC0_SX142 +MFMC0_SX232 +MFMC0_SX322 +MFMC0_SX412 +MFMC0_SX52 +MFRM0_SI1155 +MFRM0_SI1717 +MFRM0_SI1785 +MFRM0_SX165 +MFRM0_SX255 +MFRM0_SX345 +MFRM0_SX435 +MFRM0_SX75 +MFWK0_SI1249 +MFWK0_SI1879 +MFWK0_SI619 +MFWK0_SX169 +MFWK0_SX259 +MFWK0_SX349 +MFWK0_SX439 +MFWK0_SX79 +MFXS0_SI1674 +MFXS0_SI2225 +MFXS0_SI2304 +MFXS0_SX144 +MFXS0_SX234 +MFXS0_SX324 +MFXS0_SX414 +MFXS0_SX54 +MFXV0_SI1005 +MFXV0_SI1342 +MFXV0_SI1635 +MFXV0_SX105 +MFXV0_SX15 +MFXV0_SX195 +MFXV0_SX285 +MFXV0_SX375 +MGAF0_SI1282 +MGAF0_SI1912 +MGAF0_SI652 +MGAF0_SX112 +MGAF0_SX202 +MGAF0_SX22 +MGAF0_SX292 +MGAF0_SX382 +MGAG0_SI1321 +MGAG0_SI645 +MGAG0_SI691 +MGAG0_SX151 +MGAG0_SX241 +MGAG0_SX331 +MGAG0_SX421 +MGAG0_SX61 +MGAK0_SI1036 +MGAK0_SI1666 +MGAK0_SI2296 +MGAK0_SX136 +MGAK0_SX226 +MGAK0_SX316 +MGAK0_SX406 +MGAK0_SX46 +MGAR0_SI1212 +MGAR0_SI1694 +MGAR0_SI1842 +MGAR0_SX132 +MGAR0_SX222 +MGAR0_SX312 +MGAR0_SX402 +MGAR0_SX42 +MGAW0_SI1165 +MGAW0_SI1802 +MGAW0_SI535 +MGAW0_SX175 +MGAW0_SX265 +MGAW0_SX355 +MGAW0_SX445 +MGAW0_SX85 +MGES0_SI1481 +MGES0_SI2111 +MGES0_SI851 +MGES0_SX131 +MGES0_SX221 +MGES0_SX311 +MGES0_SX401 +MGES0_SX41 +MGJC0_SI1256 +MGJC0_SI1335 +MGJC0_SI1965 +MGJC0_SX165 +MGJC0_SX255 +MGJC0_SX345 +MGJC0_SX435 +MGJC0_SX75 +MGRL0_SI1497 +MGRL0_SI2127 +MGRL0_SI867 +MGRL0_SX147 +MGRL0_SX237 +MGRL0_SX327 +MGRL0_SX417 +MGRL0_SX57 +MGRP0_SI1317 +MGRP0_SI1947 +MGRP0_SI687 +MGRP0_SX147 +MGRP0_SX237 +MGRP0_SX327 +MGRP0_SX417 +MGRP0_SX57 +MGSH0_SI1176 +MGSH0_SI1806 +MGSH0_SI546 +MGSH0_SX127 +MGSH0_SX186 +MGSH0_SX276 +MGSH0_SX6 +MGSH0_SX96 +MGSL0_SI1164 +MGSL0_SI534 +MGSL0_SI797 +MGSL0_SX174 +MGSL0_SX264 +MGSL0_SX354 +MGSL0_SX444 
+MGSL0_SX84 +MGXP0_SI1087 +MGXP0_SI457 +MGXP0_SI525 +MGXP0_SX187 +MGXP0_SX277 +MGXP0_SX367 +MGXP0_SX7 +MGXP0_SX97 +MHBS0_SI1575 +MHBS0_SI2205 +MHBS0_SI945 +MHBS0_SX135 +MHBS0_SX225 +MHBS0_SX315 +MHBS0_SX405 +MHBS0_SX45 +MHIT0_SI1613 +MHIT0_SI2243 +MHIT0_SI983 +MHIT0_SX173 +MHIT0_SX263 +MHIT0_SX353 +MHIT0_SX443 +MHIT0_SX83 +MHJB0_SI1017 +MHJB0_SI1647 +MHJB0_SI2277 +MHJB0_SX117 +MHJB0_SX207 +MHJB0_SX27 +MHJB0_SX297 +MHJB0_SX387 +MHMG0_SI1365 +MHMG0_SI1995 +MHMG0_SI735 +MHMG0_SX105 +MHMG0_SX15 +MHMG0_SX195 +MHMG0_SX285 +MHMG0_SX375 +MHMR0_SI1119 +MHMR0_SI1692 +MHMR0_SI489 +MHMR0_SX129 +MHMR0_SX219 +MHMR0_SX309 +MHMR0_SX39 +MHMR0_SX399 +MHRM0_SI1475 +MHRM0_SI2218 +MHRM0_SI958 +MHRM0_SX148 +MHRM0_SX238 +MHRM0_SX328 +MHRM0_SX418 +MHRM0_SX58 +MHXL0_SI1772 +MHXL0_SI512 +MHXL0_SI612 +MHXL0_SX152 +MHXL0_SX242 +MHXL0_SX332 +MHXL0_SX422 +MHXL0_SX62 +MILB0_SI2163 +MILB0_SI807 +MILB0_SI903 +MILB0_SX183 +MILB0_SX273 +MILB0_SX3 +MILB0_SX363 +MILB0_SX93 +MJAC0_SI1331 +MJAC0_SI2148 +MJAC0_SI701 +MJAC0_SX251 +MJAC0_SX307 +MJAC0_SX341 +MJAC0_SX431 +MJAC0_SX71 +MJAE0_SI1524 +MJAE0_SI1999 +MJAE0_SI2154 +MJAE0_SX174 +MJAE0_SX264 +MJAE0_SX354 +MJAE0_SX444 +MJAE0_SX84 +MJAI0_SI1604 +MJAI0_SI682 +MJAI0_SI710 +MJAI0_SX164 +MJAI0_SX254 +MJAI0_SX344 +MJAI0_SX434 +MJAI0_SX74 +MJBG0_SI1232 +MJBG0_SI1724 +MJBG0_SI1862 +MJBG0_SX152 +MJBG0_SX242 +MJBG0_SX332 +MJBG0_SX422 +MJBG0_SX62 +MJDA0_SI1031 +MJDA0_SI1661 +MJDA0_SI2291 +MJDA0_SX131 +MJDA0_SX221 +MJDA0_SX311 +MJDA0_SX401 +MJDA0_SX41 +MJDC0_SI1161 +MJDC0_SI2165 +MJDC0_SI531 +MJDC0_SX171 +MJDC0_SX261 +MJDC0_SX351 +MJDC0_SX441 +MJDC0_SX81 +MJDE0_SI1120 +MJDE0_SI463 +MJDE0_SI490 +MJDE0_SX130 +MJDE0_SX220 +MJDE0_SX310 +MJDE0_SX40 +MJDE0_SX400 +MJDG0_SI1042 +MJDG0_SI1672 +MJDG0_SI1705 +MJDG0_SX142 +MJDG0_SX232 +MJDG0_SX322 +MJDG0_SX412 +MJDG0_SX52 +MJDM0_SI1340 +MJDM0_SI1937 +MJDM0_SI974 +MJDM0_SX170 +MJDM0_SX260 +MJDM0_SX350 +MJDM0_SX440 +MJDM0_SX80 +MJEB0_SI1286 +MJEB0_SI1916 +MJEB0_SI656 +MJEB0_SX170 +MJEB0_SX206 +MJEB0_SX26 +MJEB0_SX296 
+MJEB0_SX386 +MJEB1_SI1467 +MJEB1_SI2097 +MJEB1_SI837 +MJEB1_SX117 +MJEB1_SX207 +MJEB1_SX27 +MJEB1_SX297 +MJEB1_SX387 +MJEE0_SI1237 +MJEE0_SI1867 +MJEE0_SI607 +MJEE0_SX157 +MJEE0_SX247 +MJEE0_SX337 +MJEE0_SX427 +MJEE0_SX67 +MJFH0_SI1107 +MJFH0_SI1737 +MJFH0_SI477 +MJFH0_SX117 +MJFH0_SX207 +MJFH0_SX27 +MJFH0_SX297 +MJFH0_SX387 +MJFR0_SI1605 +MJFR0_SI2235 +MJFR0_SI975 +MJFR0_SX165 +MJFR0_SX255 +MJFR0_SX345 +MJFR0_SX435 +MJFR0_SX75 +MJHI0_SI1328 +MJHI0_SI555 +MJHI0_SI698 +MJHI0_SX158 +MJHI0_SX248 +MJHI0_SX338 +MJHI0_SX428 +MJHI0_SX68 +MJJB0_SI1139 +MJJB0_SI1277 +MJJB0_SI1769 +MJJB0_SX149 +MJJB0_SX239 +MJJB0_SX329 +MJJB0_SX419 +MJJB0_SX59 +MJJJ0_SI1163 +MJJJ0_SI1793 +MJJJ0_SI533 +MJJJ0_SX173 +MJJJ0_SX263 +MJJJ0_SX353 +MJJJ0_SX443 +MJJJ0_SX83 +MJJM0_SI1251 +MJJM0_SI1457 +MJJM0_SI827 +MJJM0_SX107 +MJJM0_SX17 +MJJM0_SX197 +MJJM0_SX287 +MJJM0_SX377 +MJKR0_SI1201 +MJKR0_SI1831 +MJKR0_SI571 +MJKR0_SX121 +MJKR0_SX211 +MJKR0_SX301 +MJKR0_SX31 +MJKR0_SX391 +MJLB0_SI1616 +MJLB0_SI2246 +MJLB0_SI986 +MJLB0_SX176 +MJLB0_SX266 +MJLB0_SX356 +MJLB0_SX446 +MJLB0_SX86 +MJLG1_SI1012 +MJLG1_SI1642 +MJLG1_SI2272 +MJLG1_SX112 +MJLG1_SX202 +MJLG1_SX22 +MJLG1_SX292 +MJLG1_SX382 +MJLS0_SI1096 +MJLS0_SI1726 +MJLS0_SI466 +MJLS0_SX106 +MJLS0_SX16 +MJLS0_SX196 +MJLS0_SX286 +MJLS0_SX376 +MJMA0_SI1495 +MJMA0_SI2125 +MJMA0_SI865 +MJMA0_SX145 +MJMA0_SX235 +MJMA0_SX325 +MJMA0_SX415 +MJMA0_SX55 +MJMD0_SI1028 +MJMD0_SI1658 +MJMD0_SI2288 +MJMD0_SX128 +MJMD0_SX218 +MJMD0_SX308 +MJMD0_SX38 +MJMD0_SX398 +MJMM0_SI1255 +MJMM0_SI1885 +MJMM0_SI625 +MJMM0_SX175 +MJMM0_SX265 +MJMM0_SX355 +MJMM0_SX445 +MJMM0_SX85 +MJPG0_SI1191 +MJPG0_SI1821 +MJPG0_SI561 +MJPG0_SX111 +MJPG0_SX201 +MJPG0_SX21 +MJPG0_SX291 +MJPG0_SX381 +MJPM0_SI1368 +MJPM0_SI1998 +MJPM0_SI738 +MJPM0_SX108 +MJPM0_SX18 +MJPM0_SX198 +MJPM0_SX288 +MJPM0_SX378 +MJPM1_SI1897 +MJPM1_SI2280 +MJPM1_SI761 +MJPM1_SX131 +MJPM1_SX221 +MJPM1_SX311 +MJPM1_SX401 +MJPM1_SX41 +MJRA0_SI1236 +MJRA0_SI1866 +MJRA0_SI606 +MJRA0_SX156 +MJRA0_SX246 +MJRA0_SX336 +MJRA0_SX426 
+MJRA0_SX66 +MJRG0_SI1366 +MJRG0_SI1996 +MJRG0_SI736 +MJRG0_SX106 +MJRG0_SX16 +MJRG0_SX286 +MJRG0_SX352 +MJRG0_SX376 +MJRH0_SI1125 +MJRH0_SI1755 +MJRH0_SI1840 +MJRH0_SX135 +MJRH0_SX225 +MJRH0_SX315 +MJRH0_SX405 +MJRH0_SX45 +MJRH1_SI1558 +MJRH1_SI1774 +MJRH1_SI514 +MJRH1_SX154 +MJRH1_SX244 +MJRH1_SX334 +MJRH1_SX424 +MJRH1_SX64 +MJRK0_SI1662 +MJRK0_SI2103 +MJRK0_SI880 +MJRK0_SX160 +MJRK0_SX250 +MJRK0_SX340 +MJRK0_SX430 +MJRK0_SX70 +MJRP0_SI1835 +MJRP0_SI1845 +MJRP0_SI585 +MJRP0_SX135 +MJRP0_SX225 +MJRP0_SX315 +MJRP0_SX405 +MJRP0_SX45 +MJSR0_SI1424 +MJSR0_SI2054 +MJSR0_SI794 +MJSR0_SX164 +MJSR0_SX254 +MJSR0_SX344 +MJSR0_SX434 +MJSR0_SX74 +MJWG0_SI2155 +MJWG0_SI813 +MJWG0_SI895 +MJWG0_SX175 +MJWG0_SX265 +MJWG0_SX355 +MJWG0_SX445 +MJWG0_SX85 +MJWS0_SI1143 +MJWS0_SI1773 +MJWS0_SI513 +MJWS0_SX153 +MJWS0_SX243 +MJWS0_SX333 +MJWS0_SX423 +MJWS0_SX63 +MJWT0_SI1291 +MJWT0_SI1381 +MJWT0_SI751 +MJWT0_SX121 +MJWT0_SX211 +MJWT0_SX301 +MJWT0_SX31 +MJWT0_SX391 +MJXA0_SI1507 +MJXA0_SI2137 +MJXA0_SI877 +MJXA0_SX157 +MJXA0_SX247 +MJXA0_SX337 +MJXA0_SX427 +MJXA0_SX67 +MJXL0_SI1172 +MJXL0_SI1795 +MJXL0_SI542 +MJXL0_SX182 +MJXL0_SX272 +MJXL0_SX362 +MJXL0_SX452 +MJXL0_SX92 +MKAG0_SI1609 +MKAG0_SI2239 +MKAG0_SI979 +MKAG0_SX169 +MKAG0_SX259 +MKAG0_SX30 +MKAG0_SX439 +MKAG0_SX79 +MKAH0_SI1528 +MKAH0_SI2158 +MKAH0_SI898 +MKAH0_SX178 +MKAH0_SX268 +MKAH0_SX358 +MKAH0_SX448 +MKAH0_SX88 +MKAJ0_SI1414 +MKAJ0_SI2044 +MKAJ0_SI784 +MKAJ0_SX154 +MKAJ0_SX244 +MKAJ0_SX334 +MKAJ0_SX424 +MKAJ0_SX64 +MKAM0_SI1250 +MKAM0_SI1316 +MKAM0_SI1465 +MKAM0_SX146 +MKAM0_SX236 +MKAM0_SX326 +MKAM0_SX416 +MKAM0_SX56 +MKDB0_SI2132 +MKDB0_SI588 +MKDB0_SI872 +MKDB0_SX152 +MKDB0_SX242 +MKDB0_SX332 +MKDB0_SX422 +MKDB0_SX62 +MKDD0_SI1567 +MKDD0_SI2197 +MKDD0_SI937 +MKDD0_SX127 +MKDD0_SX217 +MKDD0_SX307 +MKDD0_SX37 +MKDD0_SX397 +MKDT0_SI2153 +MKDT0_SI814 +MKDT0_SI893 +MKDT0_SX173 +MKDT0_SX263 +MKDT0_SX353 +MKDT0_SX443 +MKDT0_SX83 +MKES0_SI1253 +MKES0_SI1883 +MKES0_SI623 +MKES0_SX173 +MKES0_SX263 +MKES0_SX353 +MKES0_SX443 
+MKES0_SX83 +MKJO0_SI1517 +MKJO0_SI2147 +MKJO0_SI887 +MKJO0_SX167 +MKJO0_SX257 +MKJO0_SX424 +MKJO0_SX437 +MKJO0_SX77 +MKLN0_SI1598 +MKLN0_SI2228 +MKLN0_SI968 +MKLN0_SX158 +MKLN0_SX248 +MKLN0_SX338 +MKLN0_SX428 +MKLN0_SX68 +MKLR0_SI1059 +MKLR0_SI1689 +MKLR0_SI2319 +MKLR0_SX159 +MKLR0_SX249 +MKLR0_SX339 +MKLR0_SX429 +MKLR0_SX69 +MKLS0_SI1437 +MKLS0_SI1533 +MKLS0_SI2067 +MKLS0_SX177 +MKLS0_SX267 +MKLS0_SX357 +MKLS0_SX447 +MKLS0_SX87 +MKLS1_SI1545 +MKLS1_SI2175 +MKLS1_SI915 +MKLS1_SX105 +MKLS1_SX15 +MKLS1_SX195 +MKLS1_SX285 +MKLS1_SX375 +MKLW0_SI1571 +MKLW0_SI1844 +MKLW0_SI2201 +MKLW0_SX131 +MKLW0_SX221 +MKLW0_SX311 +MKLW0_SX401 +MKLW0_SX41 +MKRG0_SI1491 +MKRG0_SI2121 +MKRG0_SI861 +MKRG0_SX141 +MKRG0_SX231 +MKRG0_SX31 +MKRG0_SX411 +MKRG0_SX51 +MKXL0_SI1185 +MKXL0_SI1815 +MKXL0_SI1958 +MKXL0_SX105 +MKXL0_SX15 +MKXL0_SX195 +MKXL0_SX285 +MKXL0_SX375 +MLBC0_SI1239 +MLBC0_SI1869 +MLBC0_SI609 +MLBC0_SX159 +MLBC0_SX249 +MLBC0_SX339 +MLBC0_SX429 +MLBC0_SX69 +MLEL0_SI1246 +MLEL0_SI1876 +MLEL0_SI616 +MLEL0_SX166 +MLEL0_SX256 +MLEL0_SX346 +MLEL0_SX436 +MLEL0_SX76 +MLJC0_SI1225 +MLJC0_SI1855 +MLJC0_SI595 +MLJC0_SX145 +MLJC0_SX235 +MLJC0_SX325 +MLJC0_SX415 +MLJC0_SX55 +MLJH0_SI1324 +MLJH0_SI1422 +MLJH0_SI694 +MLJH0_SX154 +MLJH0_SX244 +MLJH0_SX334 +MLJH0_SX424 +MLJH0_SX64 +MLNS0_SI1407 +MLNS0_SI2037 +MLNS0_SI777 +MLNS0_SX147 +MLNS0_SX237 +MLNS0_SX327 +MLNS0_SX417 +MLNS0_SX57 +MLSH0_SI1417 +MLSH0_SI2047 +MLSH0_SI787 +MLSH0_SX157 +MLSH0_SX247 +MLSH0_SX337 +MLSH0_SX427 +MLSH0_SX67 +MMAA0_SI1588 +MMAA0_SI2105 +MMAA0_SI845 +MMAA0_SX125 +MMAA0_SX215 +MMAA0_SX305 +MMAA0_SX35 +MMAA0_SX395 +MMAB1_SI1494 +MMAB1_SI2124 +MMAB1_SI864 +MMAB1_SX144 +MMAB1_SX234 +MMAB1_SX324 +MMAB1_SX414 +MMAB1_SX54 +MMAG0_SI1126 +MMAG0_SI1756 +MMAG0_SI496 +MMAG0_SX136 +MMAG0_SX226 +MMAG0_SX316 +MMAG0_SX406 +MMAG0_SX46 +MMAM0_SI1597 +MMAM0_SI1668 +MMAM0_SI2227 +MMAM0_SX157 +MMAM0_SX247 +MMAM0_SX337 +MMAM0_SX427 +MMAM0_SX67 +MMAR0_SI1336 +MMAR0_SI1966 +MMAR0_SI706 +MMAR0_SX166 +MMAR0_SX256 +MMAR0_SX346 +MMAR0_SX436 
+MMAR0_SX76 +MMBS0_SI1151 +MMBS0_SI1781 +MMBS0_SI521 +MMBS0_SX161 +MMBS0_SX251 +MMBS0_SX341 +MMBS0_SX431 +MMBS0_SX71 +MMCC0_SI1338 +MMCC0_SI1968 +MMCC0_SI708 +MMCC0_SX168 +MMCC0_SX258 +MMCC0_SX348 +MMCC0_SX438 +MMCC0_SX78 +MMDB0_SI1358 +MMDB0_SI1617 +MMDB0_SI987 +MMDB0_SX177 +MMDB0_SX267 +MMDB0_SX357 +MMDB0_SX447 +MMDB0_SX87 +MMDG0_SI1780 +MMDG0_SI2035 +MMDG0_SI520 +MMDG0_SX160 +MMDG0_SX250 +MMDG0_SX340 +MMDG0_SX430 +MMDG0_SX70 +MMDM0_SI1311 +MMDM0_SI1941 +MMDM0_SI681 +MMDM0_SX141 +MMDM0_SX231 +MMDM0_SX321 +MMDM0_SX411 +MMDM0_SX51 +MMDM1_SI1650 +MMDM1_SI2043 +MMDM1_SI783 +MMDM1_SX153 +MMDM1_SX243 +MMDM1_SX333 +MMDM1_SX423 +MMDM1_SX63 +MMDS0_SI1343 +MMDS0_SI1973 +MMDS0_SI713 +MMDS0_SX173 +MMDS0_SX263 +MMDS0_SX353 +MMDS0_SX443 +MMDS0_SX83 +MMEA0_SI1388 +MMEA0_SI2018 +MMEA0_SI758 +MMEA0_SX128 +MMEA0_SX218 +MMEA0_SX308 +MMEA0_SX38 +MMEA0_SX398 +MMEB0_SI1357 +MMEB0_SI1987 +MMEB0_SI727 +MMEB0_SX187 +MMEB0_SX327 +MMEB0_SX367 +MMEB0_SX7 +MMEB0_SX97 +MMGC0_SI1305 +MMGC0_SI1935 +MMGC0_SI2184 +MMGC0_SX135 +MMGC0_SX225 +MMGC0_SX315 +MMGC0_SX405 +MMGC0_SX45 +MMGG0_SI1079 +MMGG0_SI1709 +MMGG0_SI2339 +MMGG0_SX179 +MMGG0_SX269 +MMGG0_SX359 +MMGG0_SX449 +MMGG0_SX89 +MMGK0_SI1322 +MMGK0_SI1952 +MMGK0_SI692 +MMGK0_SX152 +MMGK0_SX242 +MMGK0_SX332 +MMGK0_SX422 +MMGK0_SX62 +MMJB1_SI1408 +MMJB1_SI2038 +MMJB1_SI778 +MMJB1_SX148 +MMJB1_SX238 +MMJB1_SX328 +MMJB1_SX418 +MMJB1_SX58 +MMLM0_SI1527 +MMLM0_SI2150 +MMLM0_SI897 +MMLM0_SX177 +MMLM0_SX267 +MMLM0_SX357 +MMLM0_SX447 +MMLM0_SX87 +MMPM0_SI1061 +MMPM0_SI1691 +MMPM0_SI2321 +MMPM0_SX161 +MMPM0_SX251 +MMPM0_SX341 +MMPM0_SX431 +MMPM0_SX71 +MMRP0_SI2034 +MMRP0_SI717 +MMRP0_SI774 +MMRP0_SX144 +MMRP0_SX234 +MMRP0_SX324 +MMRP0_SX414 +MMRP0_SX54 +MMSM0_SI1106 +MMSM0_SI1736 +MMSM0_SI476 +MMSM0_SX116 +MMSM0_SX206 +MMSM0_SX26 +MMSM0_SX296 +MMSM0_SX386 +MMVP0_SI1284 +MMVP0_SI1914 +MMVP0_SI654 +MMVP0_SX114 +MMVP0_SX204 +MMVP0_SX294 +MMVP0_SX347 +MMVP0_SX384 +MMWB0_SI1619 +MMWB0_SI2249 +MMWB0_SI989 +MMWB0_SX179 +MMWB0_SX269 +MMWB0_SX359 +MMWB0_SX449 
+MMWB0_SX89 +MMWS0_SI1518 +MMWS0_SI559 +MMWS0_SI888 +MMWS0_SX168 +MMWS0_SX258 +MMWS0_SX348 +MMWS0_SX438 +MMWS0_SX78 +MMWS1_SI1071 +MMWS1_SI1701 +MMWS1_SI2331 +MMWS1_SX261 +MMWS1_SX27 +MMWS1_SX351 +MMWS1_SX441 +MMWS1_SX81 +MMXS0_SI2136 +MMXS0_SI629 +MMXS0_SI876 +MMXS0_SX156 +MMXS0_SX246 +MMXS0_SX336 +MMXS0_SX426 +MMXS0_SX66 +MNET0_SI1446 +MNET0_SI2076 +MNET0_SI816 +MNET0_SX186 +MNET0_SX276 +MNET0_SX366 +MNET0_SX6 +MNET0_SX96 +MNTW0_SI1068 +MNTW0_SI1698 +MNTW0_SI2328 +MNTW0_SX168 +MNTW0_SX202 +MNTW0_SX258 +MNTW0_SX348 +MNTW0_SX78 +MPAR0_SI1576 +MPAR0_SI2206 +MPAR0_SI946 +MPAR0_SX136 +MPAR0_SX226 +MPAR0_SX316 +MPAR0_SX406 +MPAR0_SX46 +MPEB0_SI1034 +MPEB0_SI1860 +MPEB0_SI600 +MPEB0_SX150 +MPEB0_SX240 +MPEB0_SX330 +MPEB0_SX420 +MPEB0_SX60 +MPFU0_SI1258 +MPFU0_SI1888 +MPFU0_SI628 +MPFU0_SX178 +MPFU0_SX268 +MPFU0_SX358 +MPFU0_SX448 +MPFU0_SX88 +MPGH0_SI1554 +MPGH0_SI675 +MPGH0_SI924 +MPGH0_SX114 +MPGH0_SX204 +MPGH0_SX24 +MPGH0_SX294 +MPGH0_SX384 +MPGR0_SI1410 +MPGR0_SI2040 +MPGR0_SI780 +MPGR0_SX150 +MPGR0_SX240 +MPGR0_SX330 +MPGR0_SX420 +MPGR0_SX60 +MPGR1_SI1269 +MPGR1_SI1499 +MPGR1_SI2129 +MPGR1_SX149 +MPGR1_SX239 +MPGR1_SX329 +MPGR1_SX419 +MPGR1_SX59 +MPMB0_SI1501 +MPMB0_SI2131 +MPMB0_SI871 +MPMB0_SX151 +MPMB0_SX241 +MPMB0_SX331 +MPMB0_SX421 +MPMB0_SX61 +MPPC0_SI1412 +MPPC0_SI2042 +MPPC0_SI782 +MPPC0_SX152 +MPPC0_SX242 +MPPC0_SX332 +MPPC0_SX422 +MPPC0_SX62 +MPRB0_SI1205 +MPRB0_SI1215 +MPRB0_SI575 +MPRB0_SX125 +MPRB0_SX215 +MPRB0_SX305 +MPRB0_SX35 +MPRB0_SX395 +MPRD0_SI1431 +MPRD0_SI2061 +MPRD0_SI801 +MPRD0_SX171 +MPRD0_SX261 +MPRD0_SX351 +MPRD0_SX441 +MPRD0_SX81 +MPRK0_SI1097 +MPRK0_SI1727 +MPRK0_SI467 +MPRK0_SX107 +MPRK0_SX17 +MPRK0_SX197 +MPRK0_SX287 +MPRK0_SX377 +MPRT0_SI1210 +MPRT0_SI495 +MPRT0_SI580 +MPRT0_SX130 +MPRT0_SX220 +MPRT0_SX310 +MPRT0_SX40 +MPRT0_SX400 +MPSW0_SI1067 +MPSW0_SI1697 +MPSW0_SI2327 +MPSW0_SX167 +MPSW0_SX24 +MPSW0_SX257 +MPSW0_SX437 +MPSW0_SX77 +MRAB0_SI1224 +MRAB0_SI1854 +MRAB0_SI594 +MRAB0_SX144 +MRAB0_SX234 +MRAB0_SX324 +MRAB0_SX414 
+MRAB0_SX54 +MRAB1_SI1478 +MRAB1_SI2108 +MRAB1_SI848 +MRAB1_SX128 +MRAB1_SX218 +MRAB1_SX308 +MRAB1_SX38 +MRAB1_SX398 +MRAI0_SI1954 +MRAI0_SI2052 +MRAI0_SI792 +MRAI0_SX162 +MRAI0_SX252 +MRAI0_SX342 +MRAI0_SX432 +MRAI0_SX72 +MRAM0_SI1275 +MRAM0_SI1905 +MRAM0_SI1951 +MRAM0_SX105 +MRAM0_SX15 +MRAM0_SX195 +MRAM0_SX285 +MRAM0_SX375 +MRAV0_SI1008 +MRAV0_SI1638 +MRAV0_SI2268 +MRAV0_SX108 +MRAV0_SX18 +MRAV0_SX198 +MRAV0_SX288 +MRAV0_SX378 +MRBC0_SI1665 +MRBC0_SI1859 +MRBC0_SI599 +MRBC0_SX149 +MRBC0_SX239 +MRBC0_SX329 +MRBC0_SX419 +MRBC0_SX59 +MRCG0_SI1428 +MRCG0_SI2058 +MRCG0_SI798 +MRCG0_SX168 +MRCG0_SX258 +MRCG0_SX348 +MRCG0_SX438 +MRCG0_SX78 +MRCW0_SI1371 +MRCW0_SI2001 +MRCW0_SI741 +MRCW0_SX111 +MRCW0_SX201 +MRCW0_SX21 +MRCW0_SX291 +MRCW0_SX381 +MRDD0_SI1050 +MRDD0_SI1680 +MRDD0_SI2310 +MRDD0_SX150 +MRDD0_SX240 +MRDD0_SX277 +MRDD0_SX330 +MRDD0_SX60 +MRDM0_SI1044 +MRDM0_SI1595 +MRDM0_SI965 +MRDM0_SX155 +MRDM0_SX245 +MRDM0_SX335 +MRDM0_SX425 +MRDM0_SX65 +MRDS0_SI1167 +MRDS0_SI1797 +MRDS0_SI537 +MRDS0_SX177 +MRDS0_SX267 +MRDS0_SX357 +MRDS0_SX447 +MRDS0_SX87 +MREE0_SI1104 +MREE0_SI1734 +MREE0_SI1959 +MREE0_SX114 +MREE0_SX204 +MREE0_SX24 +MREE0_SX294 +MREE0_SX384 +MREH1_SI1599 +MREH1_SI2229 +MREH1_SI969 +MREH1_SX159 +MREH1_SX249 +MREH1_SX339 +MREH1_SX429 +MREH1_SX69 +MREM0_SI1591 +MREM0_SI511 +MREM0_SI961 +MREM0_SX151 +MREM0_SX241 +MREM0_SX331 +MREM0_SX421 +MREM0_SX61 +MREW1_SI1500 +MREW1_SI2130 +MREW1_SI870 +MREW1_SX150 +MREW1_SX240 +MREW1_SX330 +MREW1_SX420 +MREW1_SX60 +MRFK0_SI1076 +MRFK0_SI1706 +MRFK0_SI2336 +MRFK0_SX176 +MRFK0_SX266 +MRFK0_SX356 +MRFK0_SX446 +MRFK0_SX86 +MRFL0_SI1156 +MRFL0_SI1786 +MRFL0_SI526 +MRFL0_SX166 +MRFL0_SX256 +MRFL0_SX346 +MRFL0_SX436 +MRFL0_SX76 +MRGM0_SI1162 +MRGM0_SI1792 +MRGM0_SI532 +MRGM0_SX172 +MRGM0_SX262 +MRGM0_SX416 +MRGM0_SX442 +MRGM0_SX82 +MRGS0_SI1356 +MRGS0_SI1986 +MRGS0_SI726 +MRGS0_SX186 +MRGS0_SX276 +MRGS0_SX366 +MRGS0_SX6 +MRGS0_SX96 +MRHL0_SI1515 +MRHL0_SI2145 +MRHL0_SI885 +MRHL0_SX165 +MRHL0_SX255 +MRHL0_SX345 +MRHL0_SX435 
+MRHL0_SX75 +MRJB1_SI1020 +MRJB1_SI1413 +MRJB1_SI2021 +MRJB1_SX120 +MRJB1_SX210 +MRJB1_SX30 +MRJB1_SX300 +MRJB1_SX390 +MRJH0_SI1519 +MRJH0_SI889 +MRJH0_SI914 +MRJH0_SX169 +MRJH0_SX259 +MRJH0_SX307 +MRJH0_SX439 +MRJH0_SX79 +MRJM0_SI1095 +MRJM0_SI1228 +MRJM0_SI1858 +MRJM0_SX148 +MRJM0_SX238 +MRJM0_SX328 +MRJM0_SX418 +MRJM0_SX58 +MRJM1_SI1298 +MRJM1_SI1928 +MRJM1_SI668 +MRJM1_SX128 +MRJM1_SX218 +MRJM1_SX308 +MRJM1_SX38 +MRJM1_SX398 +MRJT0_SI1498 +MRJT0_SI1805 +MRJT0_SI868 +MRJT0_SX148 +MRJT0_SX238 +MRJT0_SX328 +MRJT0_SX418 +MRJT0_SX58 +MRKM0_SI1267 +MRKM0_SI1391 +MRKM0_SI637 +MRKM0_SX187 +MRKM0_SX277 +MRKM0_SX367 +MRKM0_SX7 +MRKM0_SX97 +MRLD0_SI1594 +MRLD0_SI2224 +MRLD0_SI964 +MRLD0_SX154 +MRLD0_SX244 +MRLD0_SX334 +MRLD0_SX424 +MRLD0_SX64 +MRLJ0_SI1420 +MRLJ0_SI2050 +MRLJ0_SI790 +MRLJ0_SX160 +MRLJ0_SX250 +MRLJ0_SX340 +MRLJ0_SX430 +MRLJ0_SX70 +MRLJ1_SI1671 +MRLJ1_SI2301 +MRLJ1_SI2332 +MRLJ1_SX141 +MRLJ1_SX231 +MRLJ1_SX321 +MRLJ1_SX411 +MRLJ1_SX51 +MRLK0_SI1468 +MRLK0_SI2140 +MRLK0_SI843 +MRLK0_SX123 +MRLK0_SX213 +MRLK0_SX303 +MRLK0_SX33 +MRLK0_SX393 +MRLR0_SI1196 +MRLR0_SI1826 +MRLR0_SI566 +MRLR0_SX116 +MRLR0_SX206 +MRLR0_SX26 +MRLR0_SX296 +MRLR0_SX386 +MRMB0_SI1581 +MRMB0_SI2211 +MRMB0_SI951 +MRMB0_SX141 +MRMB0_SX231 +MRMB0_SX321 +MRMB0_SX411 +MRMB0_SX51 +MRMG0_SI1080 +MRMG0_SI1710 +MRMG0_SI2340 +MRMG0_SX180 +MRMG0_SX270 +MRMG0_SX360 +MRMG0_SX450 +MRMG0_SX90 +MRMH0_SI1021 +MRMH0_SI1349 +MRMH0_SI2281 +MRMH0_SX121 +MRMH0_SX211 +MRMH0_SX301 +MRMH0_SX31 +MRMH0_SX391 +MRML0_SI1421 +MRML0_SI2051 +MRML0_SI791 +MRML0_SX161 +MRML0_SX251 +MRML0_SX341 +MRML0_SX431 +MRML0_SX71 +MRMS0_SI1113 +MRMS0_SI2057 +MRMS0_SI2100 +MRMS0_SX120 +MRMS0_SX210 +MRMS0_SX30 +MRMS0_SX300 +MRMS0_SX390 +MRPC1_SI1482 +MRPC1_SI2026 +MRPC1_SI2112 +MRPC1_SX132 +MRPC1_SX222 +MRPC1_SX312 +MRPC1_SX402 +MRPC1_SX42 +MRRE0_SI1334 +MRRE0_SI704 +MRRE0_SI952 +MRRE0_SX164 +MRRE0_SX254 +MRRE0_SX344 +MRRE0_SX434 +MRRE0_SX74 +MRSO0_SI1206 +MRSO0_SI1659 +MRSO0_SI2289 +MRSO0_SX129 +MRSO0_SX219 +MRSO0_SX309 +MRSO0_SX39 
+MRSO0_SX399 +MRSP0_SI1429 +MRSP0_SI2059 +MRSP0_SI799 +MRSP0_SX169 +MRSP0_SX196 +MRSP0_SX259 +MRSP0_SX439 +MRSP0_SX79 +MRTC0_SI1458 +MRTC0_SI2088 +MRTC0_SI828 +MRTC0_SX108 +MRTC0_SX18 +MRTC0_SX198 +MRTC0_SX288 +MRTC0_SX378 +MRTJ0_SI1551 +MRTJ0_SI2032 +MRTJ0_SI772 +MRTJ0_SX142 +MRTJ0_SX232 +MRTJ0_SX322 +MRTJ0_SX412 +MRTJ0_SX52 +MRVG0_SI1140 +MRVG0_SI1770 +MRVG0_SI510 +MRVG0_SX150 +MRVG0_SX240 +MRVG0_SX330 +MRVG0_SX420 +MRVG0_SX60 +MRWA0_SI1603 +MRWA0_SI2233 +MRWA0_SI973 +MRWA0_SX163 +MRWA0_SX253 +MRWA0_SX343 +MRWA0_SX433 +MRWA0_SX73 +MRWS0_SI1102 +MRWS0_SI1732 +MRWS0_SI472 +MRWS0_SX112 +MRWS0_SX202 +MRWS0_SX22 +MRWS0_SX292 +MRWS0_SX382 +MRXB0_SI1585 +MRXB0_SI2215 +MRXB0_SI955 +MRXB0_SX145 +MRXB0_SX235 +MRXB0_SX325 +MRXB0_SX415 +MRXB0_SX55 +MSAH1_SI1049 +MSAH1_SI1679 +MSAH1_SI2309 +MSAH1_SX149 +MSAH1_SX239 +MSAH1_SX329 +MSAH1_SX419 +MSAH1_SX59 +MSAS0_SI1376 +MSAS0_SI2006 +MSAS0_SI746 +MSAS0_SX116 +MSAS0_SX206 +MSAS0_SX26 +MSAS0_SX296 +MSAS0_SX386 +MSAT0_SI1526 +MSAT0_SI2156 +MSAT0_SI896 +MSAT0_SX176 +MSAT0_SX266 +MSAT0_SX356 +MSAT0_SX446 +MSAT0_SX86 +MSAT1_SI1073 +MSAT1_SI1703 +MSAT1_SI2333 +MSAT1_SX173 +MSAT1_SX263 +MSAT1_SX353 +MSAT1_SX443 +MSAT1_SX83 +MSDB0_SI1007 +MSDB0_SI1637 +MSDB0_SI2267 +MSDB0_SX107 +MSDB0_SX17 +MSDB0_SX197 +MSDB0_SX287 +MSDB0_SX377 +MSDH0_SI2113 +MSDH0_SI2240 +MSDH0_SI980 +MSDH0_SX170 +MSDH0_SX260 +MSDH0_SX350 +MSDH0_SX440 +MSDH0_SX80 +MSDS0_SI1077 +MSDS0_SI1707 +MSDS0_SI2337 +MSDS0_SX177 +MSDS0_SX267 +MSDS0_SX357 +MSDS0_SX447 +MSDS0_SX87 +MSEM1_SI1440 +MSEM1_SI2070 +MSEM1_SI810 +MSEM1_SX180 +MSEM1_SX270 +MSEM1_SX360 +MSEM1_SX450 +MSEM1_SX90 +MSES0_SI1589 +MSES0_SI2216 +MSES0_SI2219 +MSES0_SX149 +MSES0_SX239 +MSES0_SX329 +MSES0_SX419 +MSES0_SX59 +MSFH0_SI1216 +MSFH0_SI1738 +MSFH0_SI586 +MSFH0_SX136 +MSFH0_SX226 +MSFH0_SX316 +MSFH0_SX406 +MSFH0_SX46 +MSFV0_SI1262 +MSFV0_SI1892 +MSFV0_SI632 +MSFV0_SX182 +MSFV0_SX272 +MSFV0_SX362 +MSFV0_SX452 +MSFV0_SX92 +MSJK0_SI1596 +MSJK0_SI2226 +MSJK0_SI966 +MSJK0_SX156 +MSJK0_SX246 +MSJK0_SX336 
+MSJK0_SX426 +MSJK0_SX66 +MSMC0_SI1907 +MSMC0_SI509 +MSMC0_SI647 +MSMC0_SX107 +MSMC0_SX17 +MSMC0_SX197 +MSMC0_SX287 +MSMC0_SX377 +MSMR0_SI1150 +MSMR0_SI1405 +MSMR0_SI775 +MSMR0_SX145 +MSMR0_SX235 +MSMR0_SX325 +MSMR0_SX415 +MSMR0_SX55 +MSMS0_SI1433 +MSMS0_SI2063 +MSMS0_SI803 +MSMS0_SX173 +MSMS0_SX263 +MSMS0_SX353 +MSMS0_SX443 +MSMS0_SX83 +MSRG0_SI1221 +MSRG0_SI1851 +MSRG0_SI591 +MSRG0_SX141 +MSRG0_SX231 +MSRG0_SX321 +MSRG0_SX411 +MSRG0_SX51 +MSRR0_SI1131 +MSRR0_SI1761 +MSRR0_SI501 +MSRR0_SX141 +MSRR0_SX231 +MSRR0_SX30 +MSRR0_SX411 +MSRR0_SX51 +MSTF0_SI1396 +MSTF0_SI766 +MSTF0_SI852 +MSTF0_SX136 +MSTF0_SX226 +MSTF0_SX316 +MSTF0_SX406 +MSTF0_SX46 +MSVS0_SI1568 +MSVS0_SI2198 +MSVS0_SI938 +MSVS0_SX128 +MSVS0_SX218 +MSVS0_SX308 +MSVS0_SX38 +MSVS0_SX398 +MTAB0_SI1572 +MTAB0_SI2202 +MTAB0_SI942 +MTAB0_SX132 +MTAB0_SX222 +MTAB0_SX312 +MTAB0_SX402 +MTAB0_SX42 +MTAS0_SI1385 +MTAS0_SI2015 +MTAS0_SI755 +MTAS0_SX125 +MTAS0_SX215 +MTAS0_SX305 +MTAS0_SX35 +MTAS0_SX395 +MTAT0_SI1110 +MTAT0_SI1740 +MTAT0_SI811 +MTAT0_SX120 +MTAT0_SX210 +MTAT0_SX30 +MTAT0_SX300 +MTAT0_SX390 +MTAT1_SI1409 +MTAT1_SI1627 +MTAT1_SI779 +MTAT1_SX149 +MTAT1_SX239 +MTAT1_SX329 +MTAT1_SX419 +MTAT1_SX59 +MTBC0_SI1173 +MTBC0_SI1803 +MTBC0_SI543 +MTBC0_SX183 +MTBC0_SX273 +MTBC0_SX347 +MTBC0_SX363 +MTBC0_SX93 +MTCS0_SI1972 +MTCS0_SI2265 +MTCS0_SI712 +MTCS0_SX172 +MTCS0_SX262 +MTCS0_SX352 +MTCS0_SX442 +MTCS0_SX82 +MTDB0_SI1401 +MTDB0_SI2031 +MTDB0_SI771 +MTDB0_SX141 +MTDB0_SX231 +MTDB0_SX321 +MTDB0_SX411 +MTDB0_SX51 +MTDP0_SI1274 +MTDP0_SI1521 +MTDP0_SI2151 +MTDP0_SX171 +MTDP0_SX261 +MTDP0_SX351 +MTDP0_SX441 +MTDP0_SX81 +MTER0_SI1157 +MTER0_SI1787 +MTER0_SI527 +MTER0_SX167 +MTER0_SX17 +MTER0_SX257 +MTER0_SX437 +MTER0_SX77 +MTJG0_SI1520 +MTJG0_SI2157 +MTJG0_SI890 +MTJG0_SX170 +MTJG0_SX260 +MTJG0_SX350 +MTJG0_SX440 +MTJG0_SX80 +MTJM0_SI1226 +MTJM0_SI1856 +MTJM0_SI655 +MTJM0_SX146 +MTJM0_SX236 +MTJM0_SX326 +MTJM0_SX416 +MTJM0_SX56 +MTJS0_SI1192 +MTJS0_SI1822 +MTJS0_SI562 +MTJS0_SX112 +MTJS0_SX202 +MTJS0_SX22 
+MTJS0_SX292 +MTJS0_SX382 +MTJU0_SI2020 +MTJU0_SI2269 +MTJU0_SI760 +MTJU0_SX130 +MTJU0_SX220 +MTJU0_SX310 +MTJU0_SX40 +MTJU0_SX400 +MTKD0_SI1187 +MTKD0_SI1817 +MTKD0_SI630 +MTKD0_SX107 +MTKD0_SX17 +MTKD0_SX197 +MTKD0_SX287 +MTKD0_SX377 +MTKP0_SI1023 +MTKP0_SI2283 +MTKP0_SI454 +MTKP0_SX123 +MTKP0_SX213 +MTKP0_SX303 +MTKP0_SX33 +MTKP0_SX393 +MTLB0_SI1134 +MTLB0_SI1764 +MTLB0_SI504 +MTLB0_SX144 +MTLB0_SX234 +MTLB0_SX324 +MTLB0_SX414 +MTLB0_SX54 +MTLC0_SI1313 +MTLC0_SI1477 +MTLC0_SI847 +MTLC0_SX127 +MTLC0_SX217 +MTLC0_SX307 +MTLC0_SX37 +MTLC0_SX397 +MTML0_SI1065 +MTML0_SI1695 +MTML0_SI2325 +MTML0_SX165 +MTML0_SX255 +MTML0_SX345 +MTML0_SX435 +MTML0_SX75 +MTMN0_SI1064 +MTMN0_SI2324 +MTMN0_SI582 +MTMN0_SX164 +MTMN0_SX254 +MTMN0_SX344 +MTMN0_SX434 +MTMN0_SX74 +MTMT0_SI1118 +MTMT0_SI1748 +MTMT0_SI488 +MTMT0_SX128 +MTMT0_SX218 +MTMT0_SX308 +MTMT0_SX38 +MTMT0_SX398 +MTPF0_SI1235 +MTPF0_SI1865 +MTPF0_SI605 +MTPF0_SX155 +MTPF0_SX245 +MTPF0_SX335 +MTPF0_SX425 +MTPF0_SX65 +MTPG0_SI1383 +MTPG0_SI2013 +MTPG0_SI753 +MTPG0_SX123 +MTPG0_SX213 +MTPG0_SX303 +MTPG0_SX33 +MTPG0_SX393 +MTPP0_SI1508 +MTPP0_SI2138 +MTPP0_SI878 +MTPP0_SX158 +MTPP0_SX248 +MTPP0_SX338 +MTPP0_SX428 +MTPP0_SX68 +MTPR0_SI1600 +MTPR0_SI2230 +MTPR0_SI506 +MTPR0_SX160 +MTPR0_SX250 +MTPR0_SX340 +MTPR0_SX430 +MTPR0_SX70 +MTQC0_SI1441 +MTQC0_SI2071 +MTQC0_SI480 +MTQC0_SX181 +MTQC0_SX271 +MTQC0_SX361 +MTQC0_SX451 +MTQC0_SX91 +MTRC0_SI1623 +MTRC0_SI589 +MTRC0_SI993 +MTRC0_SX170 +MTRC0_SX183 +MTRC0_SX273 +MTRC0_SX363 +MTRC0_SX93 +MTRR0_SI1548 +MTRR0_SI2178 +MTRR0_SI918 +MTRR0_SX108 +MTRR0_SX18 +MTRR0_SX198 +MTRR0_SX288 +MTRR0_SX378 +MTRT0_SI1227 +MTRT0_SI1857 +MTRT0_SI597 +MTRT0_SX147 +MTRT0_SX237 +MTRT0_SX254 +MTRT0_SX417 +MTRT0_SX57 +MTWH1_SI1512 +MTWH1_SI2142 +MTWH1_SI882 +MTWH1_SX162 +MTWH1_SX252 +MTWH1_SX342 +MTWH1_SX432 +MTWH1_SX72 +MTXS0_SI1060 +MTXS0_SI1690 +MTXS0_SI2320 +MTXS0_SX160 +MTXS0_SX250 +MTXS0_SX340 +MTXS0_SX430 +MTXS0_SX70 +MVJH0_SI1556 +MVJH0_SI2186 +MVJH0_SI926 +MVJH0_SX116 +MVJH0_SX206 +MVJH0_SX26 
+MVJH0_SX296 +MVJH0_SX386 +MVLO0_SI1147 +MVLO0_SI1777 +MVLO0_SI517 +MVLO0_SX157 +MVLO0_SX247 +MVLO0_SX337 +MVLO0_SX427 +MVLO0_SX67 +MVRW0_SI1485 +MVRW0_SI2115 +MVRW0_SI855 +MVRW0_SX135 +MVRW0_SX225 +MVRW0_SX315 +MVRW0_SX405 +MVRW0_SX45 +MWAC0_SI1601 +MWAC0_SI2231 +MWAC0_SI971 +MWAC0_SX161 +MWAC0_SX251 +MWAC0_SX341 +MWAC0_SX431 +MWAC0_SX71 +MWAD0_SI1062 +MWAD0_SI1749 +MWAD0_SI2322 +MWAD0_SX162 +MWAD0_SX252 +MWAD0_SX342 +MWAD0_SX432 +MWAD0_SX72 +MWAR0_SI1045 +MWAR0_SI1675 +MWAR0_SI2305 +MWAR0_SX145 +MWAR0_SX235 +MWAR0_SX325 +MWAR0_SX415 +MWAR0_SX55 +MWCH0_SI1622 +MWCH0_SI1895 +MWCH0_SI2252 +MWCH0_SX182 +MWCH0_SX272 +MWCH0_SX362 +MWCH0_SX452 +MWCH0_SX92 +MWDK0_SI1436 +MWDK0_SI2017 +MWDK0_SI806 +MWDK0_SX176 +MWDK0_SX266 +MWDK0_SX356 +MWDK0_SX446 +MWDK0_SX86 +MWEM0_SI1320 +MWEM0_SI1393 +MWEM0_SI1950 +MWEM0_SX150 +MWEM0_SX240 +MWEM0_SX330 +MWEM0_SX420 +MWEM0_SX60 +MWGR0_SI1606 +MWGR0_SI2236 +MWGR0_SI976 +MWGR0_SX166 +MWGR0_SX256 +MWGR0_SX346 +MWGR0_SX436 +MWGR0_SX76 +MWRE0_SI1057 +MWRE0_SI1687 +MWRE0_SI2317 +MWRE0_SX157 +MWRE0_SX247 +MWRE0_SX337 +MWRE0_SX427 +MWRE0_SX67 +MWRP0_SI1443 +MWRP0_SI1525 +MWRP0_SI2073 +MWRP0_SX183 +MWRP0_SX273 +MWRP0_SX3 +MWRP0_SX363 +MWRP0_SX93 +MWSB0_SI1626 +MWSB0_SI2256 +MWSB0_SI996 +MWSB0_SX186 +MWSB0_SX276 +MWSB0_SX366 +MWSB0_SX6 +MWSB0_SX96 +MWSH0_SI1426 +MWSH0_SI2266 +MWSH0_SI796 +MWSH0_SX166 +MWSH0_SX256 +MWSH0_SX346 +MWSH0_SX436 +MWSH0_SX76 +MZMB0_SI1166 +MZMB0_SI1796 +MZMB0_SI536 +MZMB0_SX176 +MZMB0_SX266 +MZMB0_SX356 +MZMB0_SX446 +MZMB0_SX86 diff --git a/examples/wav2vec/unsupervised/config/timit_matched/valid.uid b/examples/wav2vec/unsupervised/config/timit_matched/valid.uid new file mode 100644 index 0000000000..ab5ef381ab --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_matched/valid.uid @@ -0,0 +1,400 @@ +FADG0_SI1279 +FADG0_SI1909 +FADG0_SI649 +FADG0_SX109 +FADG0_SX19 +FADG0_SX199 +FADG0_SX289 +FADG0_SX379 +FAKS0_SI1573 +FAKS0_SI2203 +FAKS0_SI943 +FAKS0_SX133 +FAKS0_SX223 +FAKS0_SX313 +FAKS0_SX403 +FAKS0_SX43 
+FCAL1_SI1403 +FCAL1_SI2033 +FCAL1_SI773 +FCAL1_SX143 +FCAL1_SX233 +FCAL1_SX323 +FCAL1_SX413 +FCAL1_SX53 +FCMH0_SI1454 +FCMH0_SI2084 +FCMH0_SI824 +FCMH0_SX104 +FCMH0_SX14 +FCMH0_SX194 +FCMH0_SX284 +FCMH0_SX374 +FDAC1_SI1474 +FDAC1_SI2104 +FDAC1_SI844 +FDAC1_SX124 +FDAC1_SX214 +FDAC1_SX304 +FDAC1_SX34 +FDAC1_SX394 +FDMS0_SI1218 +FDMS0_SI1502 +FDMS0_SI1848 +FDMS0_SX138 +FDMS0_SX228 +FDMS0_SX318 +FDMS0_SX408 +FDMS0_SX48 +FDRW0_SI1283 +FDRW0_SI1423 +FDRW0_SI653 +FDRW0_SX113 +FDRW0_SX203 +FDRW0_SX23 +FDRW0_SX293 +FDRW0_SX383 +FEDW0_SI1084 +FEDW0_SI1653 +FEDW0_SI1714 +FEDW0_SX184 +FEDW0_SX274 +FEDW0_SX364 +FEDW0_SX4 +FEDW0_SX94 +FGJD0_SI1179 +FGJD0_SI549 +FGJD0_SI818 +FGJD0_SX189 +FGJD0_SX279 +FGJD0_SX369 +FGJD0_SX9 +FGJD0_SX99 +FJEM0_SI1264 +FJEM0_SI1894 +FJEM0_SI634 +FJEM0_SX184 +FJEM0_SX274 +FJEM0_SX364 +FJEM0_SX4 +FJEM0_SX94 +FJMG0_SI1181 +FJMG0_SI1811 +FJMG0_SI551 +FJMG0_SX101 +FJMG0_SX11 +FJMG0_SX191 +FJMG0_SX281 +FJMG0_SX371 +FJSJ0_SI1484 +FJSJ0_SI2114 +FJSJ0_SI854 +FJSJ0_SX134 +FJSJ0_SX224 +FJSJ0_SX314 +FJSJ0_SX404 +FJSJ0_SX44 +FKMS0_SI1490 +FKMS0_SI2120 +FKMS0_SI860 +FKMS0_SX140 +FKMS0_SX230 +FKMS0_SX320 +FKMS0_SX410 +FKMS0_SX50 +FMAH0_SI1289 +FMAH0_SI1919 +FMAH0_SI659 +FMAH0_SX119 +FMAH0_SX209 +FMAH0_SX29 +FMAH0_SX299 +FMAH0_SX389 +FMML0_SI1040 +FMML0_SI1670 +FMML0_SI2300 +FMML0_SX140 +FMML0_SX230 +FMML0_SX320 +FMML0_SX410 +FMML0_SX50 +FNMR0_SI1399 +FNMR0_SI2029 +FNMR0_SI769 +FNMR0_SX139 +FNMR0_SX229 +FNMR0_SX319 +FNMR0_SX409 +FNMR0_SX49 +FREW0_SI1030 +FREW0_SI1280 +FREW0_SI1910 +FREW0_SX110 +FREW0_SX20 +FREW0_SX200 +FREW0_SX290 +FREW0_SX380 +FSEM0_SI1198 +FSEM0_SI1828 +FSEM0_SI568 +FSEM0_SX118 +FSEM0_SX208 +FSEM0_SX28 +FSEM0_SX298 +FSEM0_SX388 +MAJC0_SI1946 +MAJC0_SI2095 +MAJC0_SI835 +MAJC0_SX115 +MAJC0_SX205 +MAJC0_SX25 +MAJC0_SX295 +MAJC0_SX385 +MBDG0_SI1463 +MBDG0_SI2093 +MBDG0_SI833 +MBDG0_SX113 +MBDG0_SX203 +MBDG0_SX23 +MBDG0_SX293 +MBDG0_SX383 +MBNS0_SI1220 +MBNS0_SI1850 +MBNS0_SI590 +MBNS0_SX140 +MBNS0_SX230 +MBNS0_SX320 +MBNS0_SX410 +MBNS0_SX50 
+MBWM0_SI1304 +MBWM0_SI1934 +MBWM0_SI674 +MBWM0_SX134 +MBWM0_SX224 +MBWM0_SX314 +MBWM0_SX404 +MBWM0_SX44 +MCSH0_SI1549 +MCSH0_SI2179 +MCSH0_SI919 +MCSH0_SX109 +MCSH0_SX19 +MCSH0_SX199 +MCSH0_SX289 +MCSH0_SX379 +MDLF0_SI1583 +MDLF0_SI2213 +MDLF0_SI953 +MDLF0_SX143 +MDLF0_SX233 +MDLF0_SX323 +MDLF0_SX413 +MDLF0_SX53 +MDLS0_SI1628 +MDLS0_SI2258 +MDLS0_SI998 +MDLS0_SX188 +MDLS0_SX278 +MDLS0_SX368 +MDLS0_SX8 +MDLS0_SX98 +MDVC0_SI2174 +MDVC0_SI2196 +MDVC0_SI936 +MDVC0_SX126 +MDVC0_SX216 +MDVC0_SX306 +MDVC0_SX36 +MDVC0_SX396 +MERS0_SI1019 +MERS0_SI1649 +MERS0_SI497 +MERS0_SX119 +MERS0_SX209 +MERS0_SX29 +MERS0_SX299 +MERS0_SX389 +MGJF0_SI1901 +MGJF0_SI641 +MGJF0_SI776 +MGJF0_SX101 +MGJF0_SX11 +MGJF0_SX191 +MGJF0_SX281 +MGJF0_SX371 +MGLB0_SI1534 +MGLB0_SI2164 +MGLB0_SI904 +MGLB0_SX184 +MGLB0_SX274 +MGLB0_SX364 +MGLB0_SX4 +MGLB0_SX94 +MGWT0_SI1539 +MGWT0_SI2169 +MGWT0_SI909 +MGWT0_SX189 +MGWT0_SX279 +MGWT0_SX369 +MGWT0_SX9 +MGWT0_SX99 +MJAR0_SI1988 +MJAR0_SI2247 +MJAR0_SI728 +MJAR0_SX188 +MJAR0_SX278 +MJAR0_SX368 +MJAR0_SX8 +MJAR0_SX98 +MJFC0_SI1033 +MJFC0_SI1663 +MJFC0_SI2293 +MJFC0_SX133 +MJFC0_SX223 +MJFC0_SX313 +MJFC0_SX403 +MJFC0_SX43 +MJSW0_SI1010 +MJSW0_SI1640 +MJSW0_SI2270 +MJSW0_SX110 +MJSW0_SX20 +MJSW0_SX200 +MJSW0_SX290 +MJSW0_SX380 +MMDB1_SI1625 +MMDB1_SI2255 +MMDB1_SI995 +MMDB1_SX185 +MMDB1_SX275 +MMDB1_SX365 +MMDB1_SX5 +MMDB1_SX95 +MMDM2_SI1452 +MMDM2_SI1555 +MMDM2_SI2082 +MMDM2_SX102 +MMDM2_SX12 +MMDM2_SX192 +MMDM2_SX282 +MMDM2_SX372 +MMJR0_SI1648 +MMJR0_SI2166 +MMJR0_SI2278 +MMJR0_SX118 +MMJR0_SX208 +MMJR0_SX28 +MMJR0_SX298 +MMJR0_SX388 +MMWH0_SI1089 +MMWH0_SI1301 +MMWH0_SI459 +MMWH0_SX189 +MMWH0_SX279 +MMWH0_SX369 +MMWH0_SX9 +MMWH0_SX99 +MPDF0_SI1542 +MPDF0_SI2172 +MPDF0_SI912 +MPDF0_SX102 +MPDF0_SX12 +MPDF0_SX192 +MPDF0_SX282 +MPDF0_SX372 +MRCS0_SI1223 +MRCS0_SI1853 +MRCS0_SI593 +MRCS0_SX143 +MRCS0_SX233 +MRCS0_SX323 +MRCS0_SX413 +MRCS0_SX53 +MREB0_SI1375 +MREB0_SI2005 +MREB0_SI745 +MREB0_SX115 +MREB0_SX205 +MREB0_SX25 +MREB0_SX295 +MREB0_SX385 +MRJM4_SI1489 
+MRJM4_SI2119 +MRJM4_SI859 +MRJM4_SX139 +MRJM4_SX229 +MRJM4_SX319 +MRJM4_SX409 +MRJM4_SX49 +MRJR0_SI1182 +MRJR0_SI1812 +MRJR0_SI2313 +MRJR0_SX102 +MRJR0_SX12 +MRJR0_SX192 +MRJR0_SX282 +MRJR0_SX372 +MROA0_SI1307 +MROA0_SI1970 +MROA0_SI677 +MROA0_SX137 +MROA0_SX227 +MROA0_SX317 +MROA0_SX407 +MROA0_SX47 +MRTK0_SI1093 +MRTK0_SI1723 +MRTK0_SI1750 +MRTK0_SX103 +MRTK0_SX13 +MRTK0_SX193 +MRTK0_SX283 +MRTK0_SX373 +MRWS1_SI1130 +MRWS1_SI1496 +MRWS1_SI500 +MRWS1_SX140 +MRWS1_SX230 +MRWS1_SX320 +MRWS1_SX410 +MRWS1_SX50 +MTAA0_SI1285 +MTAA0_SI1915 +MTAA0_SI596 +MTAA0_SX115 +MTAA0_SX205 +MTAA0_SX25 +MTAA0_SX295 +MTAA0_SX385 +MTDT0_SI1994 +MTDT0_SI2254 +MTDT0_SI994 +MTDT0_SX184 +MTDT0_SX274 +MTDT0_SX364 +MTDT0_SX4 +MTDT0_SX94 +MTEB0_SI1133 +MTEB0_SI2064 +MTEB0_SI503 +MTEB0_SX143 +MTEB0_SX233 +MTEB0_SX323 +MTEB0_SX413 +MTEB0_SX53 +MTHC0_SI1015 +MTHC0_SI1645 +MTHC0_SI2275 +MTHC0_SX115 +MTHC0_SX205 +MTHC0_SX25 +MTHC0_SX295 +MTHC0_SX385 +MWJG0_SI1124 +MWJG0_SI1754 +MWJG0_SI494 +MWJG0_SX134 +MWJG0_SX224 +MWJG0_SX314 +MWJG0_SX404 +MWJG0_SX44 diff --git a/examples/wav2vec/unsupervised/config/timit_unmatched/test.uid b/examples/wav2vec/unsupervised/config/timit_unmatched/test.uid new file mode 100644 index 0000000000..e3967e4242 --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_unmatched/test.uid @@ -0,0 +1,1680 @@ +FADG0_SA1 +FADG0_SA2 +FADG0_SI1279 +FADG0_SI1909 +FADG0_SI649 +FADG0_SX109 +FADG0_SX19 +FADG0_SX199 +FADG0_SX289 +FADG0_SX379 +FAKS0_SA1 +FAKS0_SA2 +FAKS0_SI1573 +FAKS0_SI2203 +FAKS0_SI943 +FAKS0_SX133 +FAKS0_SX223 +FAKS0_SX313 +FAKS0_SX403 +FAKS0_SX43 +FASW0_SA1 +FASW0_SA2 +FASW0_SI1550 +FASW0_SI2180 +FASW0_SI920 +FASW0_SX110 +FASW0_SX20 +FASW0_SX200 +FASW0_SX290 +FASW0_SX380 +FAWF0_SA1 +FAWF0_SA2 +FAWF0_SI1000 +FAWF0_SI1630 +FAWF0_SI2260 +FAWF0_SX10 +FAWF0_SX100 +FAWF0_SX190 +FAWF0_SX280 +FAWF0_SX370 +FCAL1_SA1 +FCAL1_SA2 +FCAL1_SI1403 +FCAL1_SI2033 +FCAL1_SI773 +FCAL1_SX143 +FCAL1_SX233 +FCAL1_SX323 +FCAL1_SX413 +FCAL1_SX53 +FCAU0_SA1 +FCAU0_SA2 +FCAU0_SI1037 
+FCAU0_SI1667 +FCAU0_SI2297 +FCAU0_SX137 +FCAU0_SX227 +FCAU0_SX317 +FCAU0_SX407 +FCAU0_SX47 +FCFT0_SA1 +FCFT0_SA2 +FCFT0_SI1178 +FCFT0_SI1808 +FCFT0_SI548 +FCFT0_SX188 +FCFT0_SX278 +FCFT0_SX368 +FCFT0_SX8 +FCFT0_SX98 +FCMH0_SA1 +FCMH0_SA2 +FCMH0_SI1454 +FCMH0_SI2084 +FCMH0_SI824 +FCMH0_SX104 +FCMH0_SX14 +FCMH0_SX194 +FCMH0_SX284 +FCMH0_SX374 +FCMH1_SA1 +FCMH1_SA2 +FCMH1_SI1493 +FCMH1_SI2123 +FCMH1_SI863 +FCMH1_SX143 +FCMH1_SX233 +FCMH1_SX323 +FCMH1_SX413 +FCMH1_SX53 +FCMR0_SA1 +FCMR0_SA2 +FCMR0_SI1105 +FCMR0_SI1735 +FCMR0_SI475 +FCMR0_SX115 +FCMR0_SX205 +FCMR0_SX25 +FCMR0_SX295 +FCMR0_SX385 +FCRH0_SA1 +FCRH0_SA2 +FCRH0_SI1088 +FCRH0_SI1718 +FCRH0_SI458 +FCRH0_SX188 +FCRH0_SX278 +FCRH0_SX368 +FCRH0_SX8 +FCRH0_SX98 +FDAC1_SA1 +FDAC1_SA2 +FDAC1_SI1474 +FDAC1_SI2104 +FDAC1_SI844 +FDAC1_SX124 +FDAC1_SX214 +FDAC1_SX304 +FDAC1_SX34 +FDAC1_SX394 +FDHC0_SA1 +FDHC0_SA2 +FDHC0_SI1559 +FDHC0_SI2189 +FDHC0_SI929 +FDHC0_SX119 +FDHC0_SX209 +FDHC0_SX29 +FDHC0_SX299 +FDHC0_SX389 +FDMS0_SA1 +FDMS0_SA2 +FDMS0_SI1218 +FDMS0_SI1502 +FDMS0_SI1848 +FDMS0_SX138 +FDMS0_SX228 +FDMS0_SX318 +FDMS0_SX408 +FDMS0_SX48 +FDRD1_SA1 +FDRD1_SA2 +FDRD1_SI1544 +FDRD1_SI1566 +FDRD1_SI2149 +FDRD1_SX104 +FDRD1_SX14 +FDRD1_SX194 +FDRD1_SX284 +FDRD1_SX374 +FDRW0_SA1 +FDRW0_SA2 +FDRW0_SI1283 +FDRW0_SI1423 +FDRW0_SI653 +FDRW0_SX113 +FDRW0_SX203 +FDRW0_SX23 +FDRW0_SX293 +FDRW0_SX383 +FEDW0_SA1 +FEDW0_SA2 +FEDW0_SI1084 +FEDW0_SI1653 +FEDW0_SI1714 +FEDW0_SX184 +FEDW0_SX274 +FEDW0_SX364 +FEDW0_SX4 +FEDW0_SX94 +FELC0_SA1 +FELC0_SA2 +FELC0_SI1386 +FELC0_SI2016 +FELC0_SI756 +FELC0_SX126 +FELC0_SX216 +FELC0_SX306 +FELC0_SX36 +FELC0_SX396 +FGJD0_SA1 +FGJD0_SA2 +FGJD0_SI1179 +FGJD0_SI549 +FGJD0_SI818 +FGJD0_SX189 +FGJD0_SX279 +FGJD0_SX369 +FGJD0_SX9 +FGJD0_SX99 +FGMD0_SA1 +FGMD0_SA2 +FGMD0_SI1943 +FGMD0_SI2107 +FGMD0_SI683 +FGMD0_SX143 +FGMD0_SX233 +FGMD0_SX323 +FGMD0_SX413 +FGMD0_SX53 +FGWR0_SA1 +FGWR0_SA2 +FGWR0_SI1578 +FGWR0_SI2208 +FGWR0_SI948 +FGWR0_SX138 +FGWR0_SX228 +FGWR0_SX318 +FGWR0_SX408 +FGWR0_SX48 
+FHES0_SA1 +FHES0_SA2 +FHES0_SI1109 +FHES0_SI1739 +FHES0_SI479 +FHES0_SX119 +FHES0_SX209 +FHES0_SX29 +FHES0_SX299 +FHES0_SX389 +FHEW0_SA1 +FHEW0_SA2 +FHEW0_SI2023 +FHEW0_SI690 +FHEW0_SI763 +FHEW0_SX133 +FHEW0_SX223 +FHEW0_SX313 +FHEW0_SX403 +FHEW0_SX43 +FISB0_SA1 +FISB0_SA2 +FISB0_SI1579 +FISB0_SI2209 +FISB0_SI949 +FISB0_SX139 +FISB0_SX229 +FISB0_SX319 +FISB0_SX409 +FISB0_SX49 +FJAS0_SA1 +FJAS0_SA2 +FJAS0_SI1400 +FJAS0_SI2030 +FJAS0_SI770 +FJAS0_SX140 +FJAS0_SX230 +FJAS0_SX320 +FJAS0_SX410 +FJAS0_SX50 +FJCS0_SA1 +FJCS0_SA2 +FJCS0_SI1309 +FJCS0_SI1833 +FJCS0_SI1939 +FJCS0_SX139 +FJCS0_SX229 +FJCS0_SX319 +FJCS0_SX409 +FJCS0_SX49 +FJEM0_SA1 +FJEM0_SA2 +FJEM0_SI1264 +FJEM0_SI1894 +FJEM0_SI634 +FJEM0_SX184 +FJEM0_SX274 +FJEM0_SX364 +FJEM0_SX4 +FJEM0_SX94 +FJLM0_SA1 +FJLM0_SA2 +FJLM0_SI1043 +FJLM0_SI1673 +FJLM0_SI2303 +FJLM0_SX143 +FJLM0_SX233 +FJLM0_SX323 +FJLM0_SX413 +FJLM0_SX53 +FJMG0_SA1 +FJMG0_SA2 +FJMG0_SI1181 +FJMG0_SI1811 +FJMG0_SI551 +FJMG0_SX101 +FJMG0_SX11 +FJMG0_SX191 +FJMG0_SX281 +FJMG0_SX371 +FJRE0_SA1 +FJRE0_SA2 +FJRE0_SI1116 +FJRE0_SI1587 +FJRE0_SI1746 +FJRE0_SX126 +FJRE0_SX216 +FJRE0_SX306 +FJRE0_SX36 +FJRE0_SX396 +FJSA0_SA1 +FJSA0_SA2 +FJSA0_SI1379 +FJSA0_SI2009 +FJSA0_SI749 +FJSA0_SX119 +FJSA0_SX209 +FJSA0_SX29 +FJSA0_SX299 +FJSA0_SX389 +FJSJ0_SA1 +FJSJ0_SA2 +FJSJ0_SI1484 +FJSJ0_SI2114 +FJSJ0_SI854 +FJSJ0_SX134 +FJSJ0_SX224 +FJSJ0_SX314 +FJSJ0_SX404 +FJSJ0_SX44 +FJWB0_SA1 +FJWB0_SA2 +FJWB0_SI1265 +FJWB0_SI635 +FJWB0_SI992 +FJWB0_SX185 +FJWB0_SX275 +FJWB0_SX365 +FJWB0_SX5 +FJWB0_SX95 +FKMS0_SA1 +FKMS0_SA2 +FKMS0_SI1490 +FKMS0_SI2120 +FKMS0_SI860 +FKMS0_SX140 +FKMS0_SX230 +FKMS0_SX320 +FKMS0_SX410 +FKMS0_SX50 +FLAS0_SA1 +FLAS0_SA2 +FLAS0_SI1026 +FLAS0_SI1488 +FLAS0_SI858 +FLAS0_SX138 +FLAS0_SX228 +FLAS0_SX318 +FLAS0_SX408 +FLAS0_SX48 +FLBW0_SA1 +FLBW0_SA2 +FLBW0_SI1219 +FLBW0_SI1849 +FLBW0_SI2253 +FLBW0_SX139 +FLBW0_SX229 +FLBW0_SX319 +FLBW0_SX409 +FLBW0_SX49 +FLKD0_SA1 +FLKD0_SA2 +FLKD0_SI1369 +FLKD0_SI739 +FLKD0_SI894 +FLKD0_SX109 +FLKD0_SX19 
+FLKD0_SX199 +FLKD0_SX289 +FLKD0_SX379 +FLNH0_SA1 +FLNH0_SA2 +FLNH0_SI1214 +FLNH0_SI584 +FLNH0_SI941 +FLNH0_SX134 +FLNH0_SX224 +FLNH0_SX314 +FLNH0_SX404 +FLNH0_SX44 +FMAF0_SA1 +FMAF0_SA2 +FMAF0_SI1459 +FMAF0_SI2089 +FMAF0_SI829 +FMAF0_SX109 +FMAF0_SX19 +FMAF0_SX199 +FMAF0_SX289 +FMAF0_SX379 +FMAH0_SA1 +FMAH0_SA2 +FMAH0_SI1289 +FMAH0_SI1919 +FMAH0_SI659 +FMAH0_SX119 +FMAH0_SX209 +FMAH0_SX29 +FMAH0_SX299 +FMAH0_SX389 +FMCM0_SA1 +FMCM0_SA2 +FMCM0_SI1180 +FMCM0_SI1810 +FMCM0_SI550 +FMCM0_SX10 +FMCM0_SX100 +FMCM0_SX190 +FMCM0_SX280 +FMCM0_SX370 +FMGD0_SA1 +FMGD0_SA2 +FMGD0_SI1564 +FMGD0_SI2194 +FMGD0_SI934 +FMGD0_SX124 +FMGD0_SX214 +FMGD0_SX304 +FMGD0_SX34 +FMGD0_SX394 +FMLD0_SA1 +FMLD0_SA2 +FMLD0_SI2185 +FMLD0_SI822 +FMLD0_SI925 +FMLD0_SX115 +FMLD0_SX205 +FMLD0_SX25 +FMLD0_SX295 +FMLD0_SX385 +FMML0_SA1 +FMML0_SA2 +FMML0_SI1040 +FMML0_SI1670 +FMML0_SI2300 +FMML0_SX140 +FMML0_SX230 +FMML0_SX320 +FMML0_SX410 +FMML0_SX50 +FNLP0_SA1 +FNLP0_SA2 +FNLP0_SI1308 +FNLP0_SI1938 +FNLP0_SI678 +FNLP0_SX138 +FNLP0_SX228 +FNLP0_SX318 +FNLP0_SX408 +FNLP0_SX48 +FNMR0_SA1 +FNMR0_SA2 +FNMR0_SI1399 +FNMR0_SI2029 +FNMR0_SI769 +FNMR0_SX139 +FNMR0_SX229 +FNMR0_SX319 +FNMR0_SX409 +FNMR0_SX49 +FPAS0_SA1 +FPAS0_SA2 +FPAS0_SI1272 +FPAS0_SI2204 +FPAS0_SI944 +FPAS0_SX134 +FPAS0_SX224 +FPAS0_SX314 +FPAS0_SX404 +FPAS0_SX44 +FPKT0_SA1 +FPKT0_SA2 +FPKT0_SI1538 +FPKT0_SI2168 +FPKT0_SI908 +FPKT0_SX188 +FPKT0_SX278 +FPKT0_SX368 +FPKT0_SX8 +FPKT0_SX98 +FRAM1_SA1 +FRAM1_SA2 +FRAM1_SI1360 +FRAM1_SI522 +FRAM1_SI730 +FRAM1_SX10 +FRAM1_SX100 +FRAM1_SX190 +FRAM1_SX280 +FRAM1_SX370 +FREW0_SA1 +FREW0_SA2 +FREW0_SI1030 +FREW0_SI1280 +FREW0_SI1910 +FREW0_SX110 +FREW0_SX20 +FREW0_SX200 +FREW0_SX290 +FREW0_SX380 +FRNG0_SA1 +FRNG0_SA2 +FRNG0_SI1355 +FRNG0_SI1985 +FRNG0_SI725 +FRNG0_SX185 +FRNG0_SX275 +FRNG0_SX365 +FRNG0_SX5 +FRNG0_SX95 +FSEM0_SA1 +FSEM0_SA2 +FSEM0_SI1198 +FSEM0_SI1828 +FSEM0_SI568 +FSEM0_SX118 +FSEM0_SX208 +FSEM0_SX28 +FSEM0_SX298 +FSEM0_SX388 +FSLB1_SA1 +FSLB1_SA2 +FSLB1_SI1904 +FSLB1_SI644 
+FSLB1_SI891 +FSLB1_SX104 +FSLB1_SX14 +FSLB1_SX194 +FSLB1_SX284 +FSLB1_SX374 +FSXA0_SA1 +FSXA0_SA2 +FSXA0_SI1108 +FSXA0_SI1846 +FSXA0_SI478 +FSXA0_SX118 +FSXA0_SX208 +FSXA0_SX28 +FSXA0_SX298 +FSXA0_SX388 +FTLH0_SA1 +FTLH0_SA2 +FTLH0_SI1009 +FTLH0_SI1390 +FTLH0_SI1639 +FTLH0_SX109 +FTLH0_SX19 +FTLH0_SX199 +FTLH0_SX289 +FTLH0_SX379 +FUTB0_SA1 +FUTB0_SA2 +FUTB0_SI1204 +FUTB0_SI1330 +FUTB0_SI1834 +FUTB0_SX124 +FUTB0_SX214 +FUTB0_SX304 +FUTB0_SX34 +FUTB0_SX394 +MABW0_SA1 +MABW0_SA2 +MABW0_SI1230 +MABW0_SI1664 +MABW0_SI2294 +MABW0_SX134 +MABW0_SX224 +MABW0_SX314 +MABW0_SX404 +MABW0_SX44 +MAHH0_SA1 +MAHH0_SA2 +MAHH0_SI1294 +MAHH0_SI1924 +MAHH0_SI664 +MAHH0_SX124 +MAHH0_SX214 +MAHH0_SX304 +MAHH0_SX34 +MAHH0_SX394 +MAJC0_SA1 +MAJC0_SA2 +MAJC0_SI1946 +MAJC0_SI2095 +MAJC0_SI835 +MAJC0_SX115 +MAJC0_SX205 +MAJC0_SX25 +MAJC0_SX295 +MAJC0_SX385 +MBDG0_SA1 +MBDG0_SA2 +MBDG0_SI1463 +MBDG0_SI2093 +MBDG0_SI833 +MBDG0_SX113 +MBDG0_SX203 +MBDG0_SX23 +MBDG0_SX293 +MBDG0_SX383 +MBJK0_SA1 +MBJK0_SA2 +MBJK0_SI1175 +MBJK0_SI2128 +MBJK0_SI545 +MBJK0_SX185 +MBJK0_SX275 +MBJK0_SX365 +MBJK0_SX5 +MBJK0_SX95 +MBNS0_SA1 +MBNS0_SA2 +MBNS0_SI1220 +MBNS0_SI1850 +MBNS0_SI590 +MBNS0_SX140 +MBNS0_SX230 +MBNS0_SX320 +MBNS0_SX410 +MBNS0_SX50 +MBPM0_SA1 +MBPM0_SA2 +MBPM0_SI1577 +MBPM0_SI1584 +MBPM0_SI947 +MBPM0_SX137 +MBPM0_SX227 +MBPM0_SX317 +MBPM0_SX407 +MBPM0_SX47 +MBWM0_SA1 +MBWM0_SA2 +MBWM0_SI1304 +MBWM0_SI1934 +MBWM0_SI674 +MBWM0_SX134 +MBWM0_SX224 +MBWM0_SX314 +MBWM0_SX404 +MBWM0_SX44 +MCCS0_SA1 +MCCS0_SA2 +MCCS0_SI1469 +MCCS0_SI2099 +MCCS0_SI839 +MCCS0_SX119 +MCCS0_SX209 +MCCS0_SX29 +MCCS0_SX299 +MCCS0_SX389 +MCEM0_SA1 +MCEM0_SA2 +MCEM0_SI1398 +MCEM0_SI2028 +MCEM0_SI768 +MCEM0_SX138 +MCEM0_SX228 +MCEM0_SX318 +MCEM0_SX408 +MCEM0_SX48 +MCHH0_SA1 +MCHH0_SA2 +MCHH0_SI1004 +MCHH0_SI1634 +MCHH0_SI530 +MCHH0_SX104 +MCHH0_SX14 +MCHH0_SX194 +MCHH0_SX284 +MCHH0_SX374 +MCMB0_SA1 +MCMB0_SA2 +MCMB0_SI1268 +MCMB0_SI1898 +MCMB0_SI638 +MCMB0_SX188 +MCMB0_SX278 +MCMB0_SX368 +MCMB0_SX8 +MCMB0_SX98 +MCMJ0_SA1 
+MCMJ0_SA2 +MCMJ0_SI1094 +MCMJ0_SI464 +MCMJ0_SI602 +MCMJ0_SX104 +MCMJ0_SX14 +MCMJ0_SX194 +MCMJ0_SX284 +MCMJ0_SX374 +MCRC0_SA1 +MCRC0_SA2 +MCRC0_SI1092 +MCRC0_SI1722 +MCRC0_SI462 +MCRC0_SX102 +MCRC0_SX12 +MCRC0_SX192 +MCRC0_SX282 +MCRC0_SX372 +MCSH0_SA1 +MCSH0_SA2 +MCSH0_SI1549 +MCSH0_SI2179 +MCSH0_SI919 +MCSH0_SX109 +MCSH0_SX19 +MCSH0_SX199 +MCSH0_SX289 +MCSH0_SX379 +MCTT0_SA1 +MCTT0_SA2 +MCTT0_SI1144 +MCTT0_SI2188 +MCTT0_SI928 +MCTT0_SX118 +MCTT0_SX208 +MCTT0_SX28 +MCTT0_SX298 +MCTT0_SX388 +MCTW0_SA1 +MCTW0_SA2 +MCTW0_SI1373 +MCTW0_SI2003 +MCTW0_SI743 +MCTW0_SX113 +MCTW0_SX203 +MCTW0_SX23 +MCTW0_SX293 +MCTW0_SX383 +MDAB0_SA1 +MDAB0_SA2 +MDAB0_SI1039 +MDAB0_SI1669 +MDAB0_SI2299 +MDAB0_SX139 +MDAB0_SX229 +MDAB0_SX319 +MDAB0_SX409 +MDAB0_SX49 +MDAC2_SA1 +MDAC2_SA2 +MDAC2_SI2259 +MDAC2_SI560 +MDAC2_SI999 +MDAC2_SX189 +MDAC2_SX279 +MDAC2_SX369 +MDAC2_SX9 +MDAC2_SX99 +MDAW1_SA1 +MDAW1_SA2 +MDAW1_SI1453 +MDAW1_SI2083 +MDAW1_SI823 +MDAW1_SX103 +MDAW1_SX13 +MDAW1_SX193 +MDAW1_SX283 +MDAW1_SX373 +MDBB0_SA1 +MDBB0_SA2 +MDBB0_SI1195 +MDBB0_SI1825 +MDBB0_SI565 +MDBB0_SX115 +MDBB0_SX205 +MDBB0_SX25 +MDBB0_SX295 +MDBB0_SX385 +MDLD0_SA1 +MDLD0_SA2 +MDLD0_SI1543 +MDLD0_SI2173 +MDLD0_SI913 +MDLD0_SX103 +MDLD0_SX13 +MDLD0_SX193 +MDLD0_SX283 +MDLD0_SX373 +MDLF0_SA1 +MDLF0_SA2 +MDLF0_SI1583 +MDLF0_SI2213 +MDLF0_SI953 +MDLF0_SX143 +MDLF0_SX233 +MDLF0_SX323 +MDLF0_SX413 +MDLF0_SX53 +MDLS0_SA1 +MDLS0_SA2 +MDLS0_SI1628 +MDLS0_SI2258 +MDLS0_SI998 +MDLS0_SX188 +MDLS0_SX278 +MDLS0_SX368 +MDLS0_SX8 +MDLS0_SX98 +MDRB0_SA1 +MDRB0_SA2 +MDRB0_SI1174 +MDRB0_SI2109 +MDRB0_SI544 +MDRB0_SX184 +MDRB0_SX274 +MDRB0_SX364 +MDRB0_SX4 +MDRB0_SX94 +MDRM0_SA1 +MDRM0_SA2 +MDRM0_SI1013 +MDRM0_SI1643 +MDRM0_SI2273 +MDRM0_SX113 +MDRM0_SX203 +MDRM0_SX23 +MDRM0_SX293 +MDRM0_SX383 +MDSC0_SA1 +MDSC0_SA2 +MDSC0_SI1038 +MDSC0_SI2298 +MDSC0_SI967 +MDSC0_SX138 +MDSC0_SX228 +MDSC0_SX318 +MDSC0_SX408 +MDSC0_SX48 +MDVC0_SA1 +MDVC0_SA2 +MDVC0_SI2174 +MDVC0_SI2196 +MDVC0_SI936 +MDVC0_SX126 +MDVC0_SX216 +MDVC0_SX306 
+MDVC0_SX36 +MDVC0_SX396 +MDWA0_SA1 +MDWA0_SA2 +MDWA0_SI1146 +MDWA0_SI1445 +MDWA0_SI519 +MDWA0_SX185 +MDWA0_SX275 +MDWA0_SX365 +MDWA0_SX5 +MDWA0_SX95 +MDWK0_SA1 +MDWK0_SA2 +MDWK0_SI1540 +MDWK0_SI2170 +MDWK0_SI910 +MDWK0_SX10 +MDWK0_SX100 +MDWK0_SX190 +MDWK0_SX280 +MDWK0_SX370 +MERS0_SA1 +MERS0_SA2 +MERS0_SI1019 +MERS0_SI1649 +MERS0_SI497 +MERS0_SX119 +MERS0_SX209 +MERS0_SX29 +MERS0_SX299 +MERS0_SX389 +MESD0_SA1 +MESD0_SA2 +MESD0_SI1002 +MESD0_SI1632 +MESD0_SI2262 +MESD0_SX102 +MESD0_SX12 +MESD0_SX192 +MESD0_SX282 +MESD0_SX372 +MFGK0_SA1 +MFGK0_SA2 +MFGK0_SI1451 +MFGK0_SI1744 +MFGK0_SI484 +MFGK0_SX124 +MFGK0_SX214 +MFGK0_SX304 +MFGK0_SX34 +MFGK0_SX394 +MGJF0_SA1 +MGJF0_SA2 +MGJF0_SI1901 +MGJF0_SI641 +MGJF0_SI776 +MGJF0_SX101 +MGJF0_SX11 +MGJF0_SX191 +MGJF0_SX281 +MGJF0_SX371 +MGLB0_SA1 +MGLB0_SA2 +MGLB0_SI1534 +MGLB0_SI2164 +MGLB0_SI904 +MGLB0_SX184 +MGLB0_SX274 +MGLB0_SX364 +MGLB0_SX4 +MGLB0_SX94 +MGMM0_SA1 +MGMM0_SA2 +MGMM0_SI1129 +MGMM0_SI1759 +MGMM0_SI499 +MGMM0_SX139 +MGMM0_SX229 +MGMM0_SX319 +MGMM0_SX409 +MGMM0_SX49 +MGRT0_SA1 +MGRT0_SA2 +MGRT0_SI1450 +MGRT0_SI2080 +MGRT0_SI820 +MGRT0_SX10 +MGRT0_SX100 +MGRT0_SX190 +MGRT0_SX280 +MGRT0_SX370 +MGWT0_SA1 +MGWT0_SA2 +MGWT0_SI1539 +MGWT0_SI2169 +MGWT0_SI909 +MGWT0_SX189 +MGWT0_SX279 +MGWT0_SX369 +MGWT0_SX9 +MGWT0_SX99 +MHPG0_SA1 +MHPG0_SA2 +MHPG0_SI1090 +MHPG0_SI1720 +MHPG0_SI460 +MHPG0_SX10 +MHPG0_SX100 +MHPG0_SX190 +MHPG0_SX280 +MHPG0_SX370 +MJAR0_SA1 +MJAR0_SA2 +MJAR0_SI1988 +MJAR0_SI2247 +MJAR0_SI728 +MJAR0_SX188 +MJAR0_SX278 +MJAR0_SX368 +MJAR0_SX8 +MJAR0_SX98 +MJBR0_SA1 +MJBR0_SA2 +MJBR0_SI1001 +MJBR0_SI1631 +MJBR0_SI2261 +MJBR0_SX101 +MJBR0_SX11 +MJBR0_SX191 +MJBR0_SX281 +MJBR0_SX371 +MJDH0_SA1 +MJDH0_SA2 +MJDH0_SI1354 +MJDH0_SI1984 +MJDH0_SI724 +MJDH0_SX184 +MJDH0_SX274 +MJDH0_SX364 +MJDH0_SX4 +MJDH0_SX94 +MJDM1_SA1 +MJDM1_SA2 +MJDM1_SI1085 +MJDM1_SI1715 +MJDM1_SI455 +MJDM1_SX185 +MJDM1_SX275 +MJDM1_SX365 +MJDM1_SX5 +MJDM1_SX95 +MJES0_SA1 +MJES0_SA2 +MJES0_SI1384 +MJES0_SI2014 +MJES0_SI754 +MJES0_SX124 
+MJES0_SX214 +MJES0_SX304 +MJES0_SX34 +MJES0_SX394 +MJFC0_SA1 +MJFC0_SA2 +MJFC0_SI1033 +MJFC0_SI1663 +MJFC0_SI2293 +MJFC0_SX133 +MJFC0_SX223 +MJFC0_SX313 +MJFC0_SX403 +MJFC0_SX43 +MJJG0_SA1 +MJJG0_SA2 +MJJG0_SI1003 +MJJG0_SI1633 +MJJG0_SI2263 +MJJG0_SX103 +MJJG0_SX13 +MJJG0_SX193 +MJJG0_SX283 +MJJG0_SX373 +MJLN0_SA1 +MJLN0_SA2 +MJLN0_SI1449 +MJLN0_SI2079 +MJLN0_SI819 +MJLN0_SX189 +MJLN0_SX279 +MJLN0_SX369 +MJLN0_SX9 +MJLN0_SX99 +MJMP0_SA1 +MJMP0_SA2 +MJMP0_SI1535 +MJMP0_SI1791 +MJMP0_SI905 +MJMP0_SX185 +MJMP0_SX275 +MJMP0_SX365 +MJMP0_SX5 +MJMP0_SX95 +MJRF0_SA1 +MJRF0_SA2 +MJRF0_SI1114 +MJRF0_SI2081 +MJRF0_SI821 +MJRF0_SX101 +MJRF0_SX11 +MJRF0_SX191 +MJRF0_SX281 +MJRF0_SX371 +MJSW0_SA1 +MJSW0_SA2 +MJSW0_SI1010 +MJSW0_SI1640 +MJSW0_SI2270 +MJSW0_SX110 +MJSW0_SX20 +MJSW0_SX200 +MJSW0_SX290 +MJSW0_SX380 +MJTC0_SA1 +MJTC0_SA2 +MJTC0_SI1460 +MJTC0_SI2090 +MJTC0_SI830 +MJTC0_SX110 +MJTC0_SX20 +MJTC0_SX200 +MJTC0_SX290 +MJTC0_SX380 +MJTH0_SA1 +MJTH0_SA2 +MJTH0_SI1296 +MJTH0_SI1926 +MJTH0_SI666 +MJTH0_SX126 +MJTH0_SX216 +MJTH0_SX306 +MJTH0_SX36 +MJTH0_SX396 +MJVW0_SA1 +MJVW0_SA2 +MJVW0_SI1733 +MJVW0_SI1758 +MJVW0_SI473 +MJVW0_SX113 +MJVW0_SX203 +MJVW0_SX23 +MJVW0_SX293 +MJVW0_SX383 +MKCH0_SA1 +MKCH0_SA2 +MKCH0_SI1378 +MKCH0_SI1425 +MKCH0_SI2008 +MKCH0_SX118 +MKCH0_SX208 +MKCH0_SX28 +MKCH0_SX298 +MKCH0_SX388 +MKCL0_SA1 +MKCL0_SA2 +MKCL0_SI1091 +MKCL0_SI1721 +MKCL0_SI461 +MKCL0_SX101 +MKCL0_SX11 +MKCL0_SX191 +MKCL0_SX281 +MKCL0_SX371 +MKDR0_SA1 +MKDR0_SA2 +MKDR0_SI1273 +MKDR0_SI1903 +MKDR0_SI643 +MKDR0_SX103 +MKDR0_SX13 +MKDR0_SX193 +MKDR0_SX283 +MKDR0_SX373 +MKJL0_SA1 +MKJL0_SA2 +MKJL0_SI1100 +MKJL0_SI1730 +MKJL0_SI470 +MKJL0_SX110 +MKJL0_SX20 +MKJL0_SX200 +MKJL0_SX290 +MKJL0_SX380 +MKLT0_SA1 +MKLT0_SA2 +MKLT0_SI1213 +MKLT0_SI1843 +MKLT0_SI583 +MKLT0_SX133 +MKLT0_SX223 +MKLT0_SX313 +MKLT0_SX403 +MKLT0_SX43 +MLIH0_SA1 +MLIH0_SA2 +MLIH0_SI1183 +MLIH0_SI1813 +MLIH0_SI553 +MLIH0_SX103 +MLIH0_SX13 +MLIH0_SX193 +MLIH0_SX283 +MLIH0_SX373 +MLJB0_SA1 +MLJB0_SA2 +MLJB0_SI1310 
+MLJB0_SI1940 +MLJB0_SI680 +MLJB0_SX140 +MLJB0_SX230 +MLJB0_SX320 +MLJB0_SX410 +MLJB0_SX50 +MLLL0_SA1 +MLLL0_SA2 +MLLL0_SI1363 +MLLL0_SI1993 +MLLL0_SI733 +MLLL0_SX103 +MLLL0_SX13 +MLLL0_SX193 +MLLL0_SX283 +MLLL0_SX373 +MLNT0_SA1 +MLNT0_SA2 +MLNT0_SI1574 +MLNT0_SI1902 +MLNT0_SI642 +MLNT0_SX102 +MLNT0_SX12 +MLNT0_SX192 +MLNT0_SX282 +MLNT0_SX372 +MMAB0_SA1 +MMAB0_SA2 +MMAB0_SI1362 +MMAB0_SI1992 +MMAB0_SI732 +MMAB0_SX102 +MMAB0_SX12 +MMAB0_SX192 +MMAB0_SX282 +MMAB0_SX372 +MMDB1_SA1 +MMDB1_SA2 +MMDB1_SI1625 +MMDB1_SI2255 +MMDB1_SI995 +MMDB1_SX185 +MMDB1_SX275 +MMDB1_SX365 +MMDB1_SX5 +MMDB1_SX95 +MMDH0_SA1 +MMDH0_SA2 +MMDH0_SI1656 +MMDH0_SI2118 +MMDH0_SI2286 +MMDH0_SX126 +MMDH0_SX216 +MMDH0_SX306 +MMDH0_SX36 +MMDH0_SX396 +MMDM2_SA1 +MMDM2_SA2 +MMDM2_SI1452 +MMDM2_SI1555 +MMDM2_SI2082 +MMDM2_SX102 +MMDM2_SX12 +MMDM2_SX192 +MMDM2_SX282 +MMDM2_SX372 +MMJR0_SA1 +MMJR0_SA2 +MMJR0_SI1648 +MMJR0_SI2166 +MMJR0_SI2278 +MMJR0_SX118 +MMJR0_SX208 +MMJR0_SX28 +MMJR0_SX298 +MMJR0_SX388 +MMWH0_SA1 +MMWH0_SA2 +MMWH0_SI1089 +MMWH0_SI1301 +MMWH0_SI459 +MMWH0_SX189 +MMWH0_SX279 +MMWH0_SX369 +MMWH0_SX9 +MMWH0_SX99 +MNJM0_SA1 +MNJM0_SA2 +MNJM0_SI1580 +MNJM0_SI2210 +MNJM0_SI950 +MNJM0_SX140 +MNJM0_SX230 +MNJM0_SX320 +MNJM0_SX410 +MNJM0_SX50 +MNLS0_SA1 +MNLS0_SA2 +MNLS0_SI1483 +MNLS0_SI1610 +MNLS0_SI853 +MNLS0_SX133 +MNLS0_SX223 +MNLS0_SX313 +MNLS0_SX403 +MNLS0_SX43 +MPAB0_SA1 +MPAB0_SA2 +MPAB0_SI1103 +MPAB0_SI1128 +MPAB0_SI498 +MPAB0_SX138 +MPAB0_SX228 +MPAB0_SX318 +MPAB0_SX408 +MPAB0_SX48 +MPAM0_SA1 +MPAM0_SA2 +MPAM0_SI1189 +MPAM0_SI1819 +MPAM0_SI1961 +MPAM0_SX109 +MPAM0_SX19 +MPAM0_SX199 +MPAM0_SX289 +MPAM0_SX379 +MPAM1_SA1 +MPAM1_SA2 +MPAM1_SI1029 +MPAM1_SI1836 +MPAM1_SI576 +MPAM1_SX126 +MPAM1_SX216 +MPAM1_SX306 +MPAM1_SX36 +MPAM1_SX396 +MPCS0_SA1 +MPCS0_SA2 +MPCS0_SI1359 +MPCS0_SI1989 +MPCS0_SI729 +MPCS0_SX189 +MPCS0_SX279 +MPCS0_SX369 +MPCS0_SX9 +MPCS0_SX99 +MPDF0_SA1 +MPDF0_SA2 +MPDF0_SI1542 +MPDF0_SI2172 +MPDF0_SI912 +MPDF0_SX102 +MPDF0_SX12 +MPDF0_SX192 +MPDF0_SX282 +MPDF0_SX372 
+MPGL0_SA1 +MPGL0_SA2 +MPGL0_SI1099 +MPGL0_SI1729 +MPGL0_SI469 +MPGL0_SX109 +MPGL0_SX19 +MPGL0_SX199 +MPGL0_SX289 +MPGL0_SX379 +MPLB0_SA1 +MPLB0_SA2 +MPLB0_SI1394 +MPLB0_SI2024 +MPLB0_SI764 +MPLB0_SX134 +MPLB0_SX224 +MPLB0_SX314 +MPLB0_SX404 +MPLB0_SX44 +MPWM0_SA1 +MPWM0_SA2 +MPWM0_SI1127 +MPWM0_SI1757 +MPWM0_SI2279 +MPWM0_SX137 +MPWM0_SX227 +MPWM0_SX317 +MPWM0_SX407 +MPWM0_SX47 +MRCS0_SA1 +MRCS0_SA2 +MRCS0_SI1223 +MRCS0_SI1853 +MRCS0_SI593 +MRCS0_SX143 +MRCS0_SX233 +MRCS0_SX323 +MRCS0_SX413 +MRCS0_SX53 +MRCZ0_SA1 +MRCZ0_SA2 +MRCZ0_SI1541 +MRCZ0_SI2171 +MRCZ0_SI911 +MRCZ0_SX101 +MRCZ0_SX11 +MRCZ0_SX191 +MRCZ0_SX281 +MRCZ0_SX371 +MREB0_SA1 +MREB0_SA2 +MREB0_SI1375 +MREB0_SI2005 +MREB0_SI745 +MREB0_SX115 +MREB0_SX205 +MREB0_SX25 +MREB0_SX295 +MREB0_SX385 +MRES0_SA1 +MRES0_SA2 +MRES0_SI1217 +MRES0_SI1847 +MRES0_SI587 +MRES0_SX137 +MRES0_SX227 +MRES0_SX317 +MRES0_SX407 +MRES0_SX47 +MRGG0_SA1 +MRGG0_SA2 +MRGG0_SI1199 +MRGG0_SI1829 +MRGG0_SI569 +MRGG0_SX119 +MRGG0_SX209 +MRGG0_SX29 +MRGG0_SX299 +MRGG0_SX389 +MRJM3_SA1 +MRJM3_SA2 +MRJM3_SI1448 +MRJM3_SI1809 +MRJM3_SI2078 +MRJM3_SX188 +MRJM3_SX278 +MRJM3_SX368 +MRJM3_SX8 +MRJM3_SX98 +MRJM4_SA1 +MRJM4_SA2 +MRJM4_SI1489 +MRJM4_SI2119 +MRJM4_SI859 +MRJM4_SX139 +MRJM4_SX229 +MRJM4_SX319 +MRJM4_SX409 +MRJM4_SX49 +MRJO0_SA1 +MRJO0_SA2 +MRJO0_SI1364 +MRJO0_SI1624 +MRJO0_SI734 +MRJO0_SX104 +MRJO0_SX14 +MRJO0_SX194 +MRJO0_SX284 +MRJO0_SX374 +MRJR0_SA1 +MRJR0_SA2 +MRJR0_SI1182 +MRJR0_SI1812 +MRJR0_SI2313 +MRJR0_SX102 +MRJR0_SX12 +MRJR0_SX192 +MRJR0_SX282 +MRJR0_SX372 +MRJS0_SA1 +MRJS0_SA2 +MRJS0_SI1444 +MRJS0_SI1523 +MRJS0_SI2074 +MRJS0_SX184 +MRJS0_SX274 +MRJS0_SX364 +MRJS0_SX4 +MRJS0_SX94 +MRKO0_SA1 +MRKO0_SA2 +MRKO0_SI1397 +MRKO0_SI2027 +MRKO0_SI767 +MRKO0_SX137 +MRKO0_SX227 +MRKO0_SX317 +MRKO0_SX407 +MRKO0_SX47 +MRMS1_SA1 +MRMS1_SA2 +MRMS1_SI1487 +MRMS1_SI2117 +MRMS1_SI857 +MRMS1_SX137 +MRMS1_SX227 +MRMS1_SX317 +MRMS1_SX407 +MRMS1_SX47 +MROA0_SA1 +MROA0_SA2 +MROA0_SI1307 +MROA0_SI1970 +MROA0_SI677 +MROA0_SX137 +MROA0_SX227 
+MROA0_SX317 +MROA0_SX407 +MROA0_SX47 +MRPC0_SA1 +MRPC0_SA2 +MRPC0_SI1753 +MRPC0_SI493 +MRPC0_SI933 +MRPC0_SX133 +MRPC0_SX223 +MRPC0_SX313 +MRPC0_SX403 +MRPC0_SX43 +MRPP0_SA1 +MRPP0_SA2 +MRPP0_SI1184 +MRPP0_SI1814 +MRPP0_SI554 +MRPP0_SX104 +MRPP0_SX14 +MRPP0_SX194 +MRPP0_SX284 +MRPP0_SX374 +MRRK0_SA1 +MRRK0_SA2 +MRRK0_SI1288 +MRRK0_SI1716 +MRRK0_SI1918 +MRRK0_SX118 +MRRK0_SX208 +MRRK0_SX28 +MRRK0_SX298 +MRRK0_SX388 +MRTK0_SA1 +MRTK0_SA2 +MRTK0_SI1093 +MRTK0_SI1723 +MRTK0_SI1750 +MRTK0_SX103 +MRTK0_SX13 +MRTK0_SX193 +MRTK0_SX283 +MRTK0_SX373 +MRWS1_SA1 +MRWS1_SA2 +MRWS1_SI1130 +MRWS1_SI1496 +MRWS1_SI500 +MRWS1_SX140 +MRWS1_SX230 +MRWS1_SX320 +MRWS1_SX410 +MRWS1_SX50 +MSFH1_SA1 +MSFH1_SA2 +MSFH1_SI1270 +MSFH1_SI1900 +MSFH1_SI640 +MSFH1_SX10 +MSFH1_SX100 +MSFH1_SX190 +MSFH1_SX280 +MSFH1_SX370 +MSJS1_SA1 +MSJS1_SA2 +MSJS1_SI1899 +MSJS1_SI639 +MSJS1_SI869 +MSJS1_SX189 +MSJS1_SX279 +MSJS1_SX369 +MSJS1_SX9 +MSJS1_SX99 +MSLB0_SA1 +MSLB0_SA2 +MSLB0_SI1193 +MSLB0_SI1823 +MSLB0_SI563 +MSLB0_SX113 +MSLB0_SX203 +MSLB0_SX23 +MSLB0_SX293 +MSLB0_SX383 +MSTK0_SA1 +MSTK0_SA2 +MSTK0_SI1024 +MSTK0_SI2222 +MSTK0_SI2284 +MSTK0_SX124 +MSTK0_SX214 +MSTK0_SX304 +MSTK0_SX34 +MSTK0_SX394 +MTAA0_SA1 +MTAA0_SA2 +MTAA0_SI1285 +MTAA0_SI1915 +MTAA0_SI596 +MTAA0_SX115 +MTAA0_SX205 +MTAA0_SX25 +MTAA0_SX295 +MTAA0_SX385 +MTAS1_SA1 +MTAS1_SA2 +MTAS1_SI1473 +MTAS1_SI2098 +MTAS1_SI838 +MTAS1_SX118 +MTAS1_SX208 +MTAS1_SX28 +MTAS1_SX298 +MTAS1_SX388 +MTDT0_SA1 +MTDT0_SA2 +MTDT0_SI1994 +MTDT0_SI2254 +MTDT0_SI994 +MTDT0_SX184 +MTDT0_SX274 +MTDT0_SX364 +MTDT0_SX4 +MTDT0_SX94 +MTEB0_SA1 +MTEB0_SA2 +MTEB0_SI1133 +MTEB0_SI2064 +MTEB0_SI503 +MTEB0_SX143 +MTEB0_SX233 +MTEB0_SX323 +MTEB0_SX413 +MTEB0_SX53 +MTHC0_SA1 +MTHC0_SA2 +MTHC0_SI1015 +MTHC0_SI1645 +MTHC0_SI2275 +MTHC0_SX115 +MTHC0_SX205 +MTHC0_SX25 +MTHC0_SX295 +MTHC0_SX385 +MTLS0_SA1 +MTLS0_SA2 +MTLS0_SI1370 +MTLS0_SI2000 +MTLS0_SI740 +MTLS0_SX110 +MTLS0_SX20 +MTLS0_SX200 +MTLS0_SX290 +MTLS0_SX380 +MTMR0_SA1 +MTMR0_SA2 +MTMR0_SI1303 +MTMR0_SI1933 
+MTMR0_SI673 +MTMR0_SX133 +MTMR0_SX223 +MTMR0_SX313 +MTMR0_SX403 +MTMR0_SX43 +MTWH0_SA1 +MTWH0_SA2 +MTWH0_SI1190 +MTWH0_SI1629 +MTWH0_SI1820 +MTWH0_SX110 +MTWH0_SX20 +MTWH0_SX200 +MTWH0_SX290 +MTWH0_SX380 +MWBT0_SA1 +MWBT0_SA2 +MWBT0_SI1553 +MWBT0_SI2183 +MWBT0_SI923 +MWBT0_SX113 +MWBT0_SX203 +MWBT0_SX23 +MWBT0_SX293 +MWBT0_SX383 +MWEW0_SA1 +MWEW0_SA2 +MWEW0_SI1361 +MWEW0_SI1991 +MWEW0_SI731 +MWEW0_SX101 +MWEW0_SX11 +MWEW0_SX191 +MWEW0_SX281 +MWEW0_SX371 +MWJG0_SA1 +MWJG0_SA2 +MWJG0_SI1124 +MWJG0_SI1754 +MWJG0_SI494 +MWJG0_SX134 +MWJG0_SX224 +MWJG0_SX314 +MWJG0_SX404 +MWJG0_SX44 +MWVW0_SA1 +MWVW0_SA2 +MWVW0_SI1476 +MWVW0_SI2106 +MWVW0_SI846 +MWVW0_SX126 +MWVW0_SX216 +MWVW0_SX306 +MWVW0_SX36 +MWVW0_SX396 diff --git a/examples/wav2vec/unsupervised/config/timit_unmatched/train.uid b/examples/wav2vec/unsupervised/config/timit_unmatched/train.uid new file mode 100644 index 0000000000..35b02e7f82 --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_unmatched/train.uid @@ -0,0 +1,3000 @@ +FAEM0_SA1 +FAEM0_SA2 +FAEM0_SI2022 +FAEM0_SX132 +FAEM0_SX222 +FAEM0_SX312 +FAEM0_SX402 +FAJW0_SA2 +FAJW0_SI1893 +FAJW0_SX183 +FAJW0_SX273 +FAJW0_SX363 +FALK0_SA1 +FALK0_SA2 +FALK0_SI1086 +FALK0_SI456 +FALK0_SX276 +FALK0_SX366 +FALK0_SX96 +FALR0_SA1 +FALR0_SA2 +FALR0_SI1955 +FALR0_SI695 +FALR0_SX155 +FALR0_SX245 +FALR0_SX425 +FALR0_SX65 +FAPB0_SA1 +FAPB0_SA2 +FAPB0_SI1693 +FAPB0_SX163 +FAPB0_SX253 +FAPB0_SX343 +FAPB0_SX73 +FBAS0_SA2 +FBAS0_SI1387 +FBAS0_SX127 +FBAS0_SX307 +FBAS0_SX37 +FBAS0_SX397 +FBCG1_SA2 +FBCG1_SI1612 +FBCG1_SI2242 +FBCG1_SI982 +FBCG1_SX262 +FBCG1_SX82 +FBCH0_SA1 +FBCH0_SA2 +FBCH0_SI1586 +FBCH0_SI956 +FBCH0_SX146 +FBCH0_SX326 +FBCH0_SX56 +FBJL0_SA1 +FBJL0_SA2 +FBJL0_SI1552 +FBJL0_SI2182 +FBJL0_SX112 +FBJL0_SX202 +FBJL0_SX22 +FBJL0_SX292 +FBJL0_SX382 +FBLV0_SA2 +FBLV0_SI2318 +FBLV0_SX158 +FBLV0_SX248 +FBLV0_SX428 +FBMH0_SA2 +FBMH0_SI1766 +FBMH0_SX146 +FBMH0_SX236 +FBMH0_SX326 +FBMH0_SX416 +FBMH0_SX56 +FBMJ0_SA2 +FBMJ0_SX156 +FBMJ0_SX246 +FBMJ0_SX426 
+FBMJ0_SX66 +FCAG0_SA2 +FCAG0_SI1503 +FCAG0_SI1641 +FCAG0_SI2133 +FCAG0_SX333 +FCAG0_SX423 +FCAG0_SX63 +FCAJ0_SA1 +FCAJ0_SA2 +FCAJ0_SI1804 +FCAJ0_SI849 +FCAJ0_SX129 +FCAJ0_SX219 +FCAJ0_SX39 +FCAJ0_SX399 +FCDR1_SA1 +FCDR1_SA2 +FCDR1_SX16 +FCDR1_SX376 +FCEG0_SA1 +FCEG0_SI1248 +FCEG0_SI1878 +FCEG0_SI618 +FCEG0_SX168 +FCEG0_SX258 +FCEG0_SX348 +FCEG0_SX438 +FCEG0_SX78 +FCJF0_SA2 +FCJF0_SI1027 +FCJF0_SI1657 +FCJF0_SI648 +FCJF0_SX217 +FCJF0_SX307 +FCJF0_SX37 +FCJF0_SX397 +FCJS0_SA1 +FCJS0_SA2 +FCJS0_SI977 +FCJS0_SX167 +FCJS0_SX347 +FCJS0_SX437 +FCJS0_SX77 +FCKE0_SA1 +FCKE0_SI1111 +FCKE0_SX211 +FCKE0_SX301 +FCKE0_SX31 +FCKE0_SX391 +FCLT0_SA1 +FCLT0_SA2 +FCLT0_SI1438 +FCLT0_SX178 +FCLT0_SX268 +FCLT0_SX358 +FCMG0_SA1 +FCMG0_SI1242 +FCMG0_SX162 +FCMG0_SX252 +FCMG0_SX342 +FCMM0_SI1083 +FCMM0_SI453 +FCMM0_SX273 +FCMM0_SX363 +FCMM0_SX93 +FCRZ0_SA1 +FCRZ0_SA2 +FCRZ0_SI1913 +FCRZ0_SI793 +FCRZ0_SX163 +FCRZ0_SX253 +FCRZ0_SX343 +FCRZ0_SX73 +FCYL0_SA2 +FCYL0_SI1297 +FCYL0_SI1927 +FCYL0_SX127 +FCYL0_SX217 +FCYL0_SX397 +FDAS1_SA1 +FDAS1_SA2 +FDAS1_SX111 +FDAS1_SX21 +FDAS1_SX291 +FDAW0_SA1 +FDAW0_SA2 +FDAW0_SX146 +FDAW0_SX236 +FDAW0_SX326 +FDAW0_SX416 +FDAW0_SX56 +FDFB0_SI1318 +FDFB0_SI1948 +FDFB0_SX148 +FDFB0_SX238 +FDFB0_SX328 +FDFB0_SX418 +FDJH0_SA1 +FDJH0_SA2 +FDJH0_SI1565 +FDJH0_SI2195 +FDJH0_SX125 +FDJH0_SX215 +FDJH0_SX35 +FDJH0_SX395 +FDKN0_SA1 +FDKN0_SA2 +FDKN0_SI1081 +FDKN0_SI1711 +FDKN0_SX271 +FDKN0_SX361 +FDKN0_SX91 +FDML0_SA1 +FDML0_SI1149 +FDML0_SI1779 +FDML0_SI2075 +FDML0_SX339 +FDML0_SX69 +FDMY0_SI1197 +FDMY0_SX117 +FDMY0_SX207 +FDMY0_SX297 +FDNC0_SA1 +FDNC0_SA2 +FDNC0_SI2287 +FDNC0_SX108 +FDNC0_SX18 +FDNC0_SX378 +FDTD0_SA2 +FDTD0_SI1561 +FDTD0_SI2191 +FDTD0_SI931 +FDTD0_SX121 +FDTD0_SX301 +FDTD0_SX391 +FDXW0_SA2 +FDXW0_SI1511 +FDXW0_SI2141 +FDXW0_SI881 +FDXW0_SX161 +FDXW0_SX431 +FEAC0_SA1 +FEAC0_SA2 +FEAC0_SI1245 +FEAC0_SI1875 +FEAC0_SX255 +FEAC0_SX345 +FEAC0_SX435 +FEAR0_SA1 +FEAR0_SA2 +FEAR0_SI1252 +FEAR0_SI1882 +FEAR0_SX172 +FEAR0_SX262 +FEAR0_SX442 +FEAR0_SX82 
+FECD0_SA2 +FECD0_SI2048 +FECD0_SX158 +FECD0_SX248 +FECD0_SX338 +FECD0_SX428 +FEEH0_SA2 +FEEH0_SI1112 +FEEH0_SX212 +FEEH0_SX302 +FEEH0_SX32 +FEEH0_SX392 +FEME0_SA2 +FEME0_SI1505 +FEME0_SI2135 +FEME0_SX245 +FEME0_SX425 +FETB0_SA2 +FETB0_SI1778 +FETB0_SI518 +FETB0_SX248 +FETB0_SX338 +FETB0_SX428 +FETB0_SX68 +FEXM0_SA2 +FEXM0_SI1731 +FEXM0_SX111 +FEXM0_SX201 +FEXM0_SX291 +FEXM0_SX381 +FGCS0_SA1 +FGCS0_SA2 +FGCS0_SI1486 +FGCS0_SI2116 +FGCS0_SI856 +FGCS0_SX46 +FGDP0_SA2 +FGDP0_SI1618 +FGDP0_SI2248 +FGDP0_SX178 +FGDP0_SX268 +FGDP0_SX358 +FGDP0_SX448 +FGMB0_SA1 +FGMB0_SA2 +FGMB0_SI515 +FGMB0_SX155 +FGMB0_SX425 +FGMB0_SX65 +FGRW0_SA2 +FGRW0_SI1782 +FGRW0_SI1990 +FGRW0_SX252 +FGRW0_SX342 +FGRW0_SX72 +FHLM0_SA1 +FHLM0_SA2 +FHLM0_SI1560 +FHLM0_SI2190 +FHLM0_SI930 +FHLM0_SX210 +FHLM0_SX300 +FHXS0_SI2335 +FHXS0_SX265 +FHXS0_SX355 +FHXS0_SX85 +FJDM2_SI1582 +FJDM2_SI1964 +FJDM2_SI2212 +FJDM2_SX322 +FJDM2_SX412 +FJEN0_SA2 +FJEN0_SI1047 +FJEN0_SI1677 +FJEN0_SI2307 +FJEN0_SX147 +FJEN0_SX237 +FJEN0_SX57 +FJHK0_SA1 +FJHK0_SA2 +FJHK0_SI1022 +FJHK0_SI1652 +FJHK0_SX122 +FJHK0_SX212 +FJHK0_SX32 +FJHK0_SX392 +FJKL0_SA1 +FJKL0_SA2 +FJKL0_SI1562 +FJKL0_SI2192 +FJKL0_SX122 +FJKL0_SX302 +FJKL0_SX32 +FJLG0_SA1 +FJLG0_SA2 +FJLG0_SI1506 +FJLG0_SX179 +FJLG0_SX269 +FJLG0_SX359 +FJLG0_SX449 +FJLG0_SX89 +FJLR0_SA2 +FJLR0_SI1861 +FJLR0_SI601 +FJLR0_SX151 +FJLR0_SX241 +FJLR0_SX331 +FJLR0_SX421 +FJLR0_SX61 +FJRB0_SA1 +FJRB0_SA2 +FJRB0_SI1302 +FJRB0_SI1932 +FJRB0_SI672 +FJRB0_SX132 +FJRB0_SX222 +FJRB0_SX312 +FJRB0_SX42 +FJRP1_SA2 +FJRP1_SI802 +FJRP1_SX172 +FJRP1_SX442 +FJSK0_SA2 +FJSK0_SI1682 +FJSK0_SI2312 +FJSK0_SX152 +FJSK0_SX242 +FJSK0_SX332 +FJSK0_SX422 +FJSK0_SX62 +FJSP0_SA1 +FJSP0_SA2 +FJSP0_SI1763 +FJSP0_SI804 +FJSP0_SX174 +FJSP0_SX84 +FJWB1_SA2 +FJWB1_SI2055 +FJWB1_SI795 +FJWB1_SX165 +FJWB1_SX255 +FJWB1_SX75 +FJXM0_SA2 +FJXM0_SI1211 +FJXM0_SI1971 +FJXM0_SX131 +FJXM0_SX221 +FJXP0_SA2 +FJXP0_SI492 +FJXP0_SX222 +FJXP0_SX312 +FJXP0_SX402 +FJXP0_SX42 +FKAA0_SA2 +FKAA0_SI1208 +FKAA0_SI1838 +FKAA0_SI578 
+FKAA0_SX218 +FKAA0_SX308 +FKAA0_SX38 +FKDE0_SA2 +FKDE0_SI2221 +FKDE0_SX331 +FKDW0_SA1 +FKDW0_SA2 +FKDW0_SI577 +FKDW0_SX127 +FKDW0_SX217 +FKDW0_SX307 +FKDW0_SX37 +FKFB0_SA1 +FKFB0_SI2238 +FKFB0_SI978 +FKFB0_SX168 +FKFB0_SX258 +FKKH0_SI660 +FKKH0_SX210 +FKKH0_SX30 +FKKH0_SX300 +FKLC0_SA1 +FKLC0_SA2 +FKLC0_SI1615 +FKLC0_SI2245 +FKLC0_SX265 +FKLC0_SX445 +FKLC0_SX85 +FKLC1_SA1 +FKLC1_SA2 +FKLC1_SI1678 +FKLC1_SX148 +FKLC1_SX58 +FKLH0_SA1 +FKLH0_SI1887 +FKLH0_SI627 +FKLH0_SX267 +FKLH0_SX357 +FKLH0_SX447 +FKLH0_SX87 +FKSR0_SI1117 +FKSR0_SX161 +FKSR0_SX37 +FKSR0_SX397 +FLAC0_SA1 +FLAC0_SA2 +FLAC0_SI2161 +FLAC0_SI901 +FLAC0_SX181 +FLAC0_SX271 +FLAC0_SX361 +FLAC0_SX91 +FLAG0_SA1 +FLAG0_SI2094 +FLAG0_SX294 +FLEH0_SA1 +FLEH0_SA2 +FLEH0_SX151 +FLEH0_SX241 +FLEH0_SX421 +FLEH0_SX61 +FLET0_SA2 +FLET0_SI1137 +FLET0_SI1767 +FLET0_SX147 +FLET0_SX237 +FLET0_SX277 +FLET0_SX417 +FLET0_SX57 +FLHD0_SA1 +FLHD0_SA2 +FLHD0_SI1344 +FLHD0_SI1974 +FLHD0_SX174 +FLHD0_SX264 +FLHD0_SX444 +FLHD0_SX84 +FLJA0_SA2 +FLJA0_SI1708 +FLJA0_SX268 +FLJA0_SX358 +FLJA0_SX448 +FLJA0_SX88 +FLJD0_SA1 +FLJD0_SA2 +FLJD0_SI2146 +FLJD0_SX166 +FLJD0_SX256 +FLJD0_SX346 +FLJD0_SX436 +FLJG0_SA1 +FLJG0_SI1611 +FLJG0_SI2241 +FLJG0_SX261 +FLJG0_SX441 +FLJG0_SX81 +FLKM0_SI1880 +FLKM0_SX116 +FLMA0_SA2 +FLMA0_SI1243 +FLMA0_SI1873 +FLMA0_SX163 +FLMA0_SX253 +FLMA0_SX343 +FLMC0_SA1 +FLMC0_SA2 +FLMC0_SI2002 +FLMC0_SI742 +FLMC0_SX112 +FLMC0_SX292 +FLMC0_SX336 +FLMC0_SX382 +FLMK0_SA2 +FLMK0_SI2295 +FLMK0_SX135 +FLMK0_SX225 +FLMK0_SX45 +FLOD0_SA1 +FLOD0_SA2 +FLOD0_SI1287 +FLOD0_SI657 +FLOD0_SX207 +FLOD0_SX387 +FLTM0_SA2 +FLTM0_SI1700 +FLTM0_SX260 +FLTM0_SX80 +FMAH1_SA1 +FMAH1_SI1509 +FMAH1_SI2139 +FMAH1_SX249 +FMAH1_SX339 +FMAH1_SX429 +FMAH1_SX69 +FMBG0_SA1 +FMBG0_SI1790 +FMBG0_SX260 +FMBG0_SX3 +FMBG0_SX350 +FMBG0_SX440 +FMBG0_SX80 +FMEM0_SA2 +FMEM0_SI1377 +FMEM0_SI2007 +FMEM0_SX117 +FMEM0_SX207 +FMEM0_SX297 +FMJB0_SA1 +FMJB0_SA2 +FMJB0_SI1807 +FMJB0_SX187 +FMJB0_SX277 +FMJB0_SX367 +FMJB0_SX7 +FMJF0_SA1 +FMJF0_SI1254 +FMJF0_SI1884 
+FMJF0_SX264 +FMJF0_SX354 +FMJF0_SX444 +FMJU0_SA1 +FMJU0_SA2 +FMJU0_SI2019 +FMJU0_SI759 +FMJU0_SX129 +FMJU0_SX219 +FMJU0_SX39 +FMKC0_SA1 +FMKC0_SA2 +FMKC0_SI1072 +FMKC0_SX172 +FMKC0_SX262 +FMKC0_SX352 +FMKF0_SA1 +FMKF0_SA2 +FMKF0_SI1536 +FMKF0_SI906 +FMKF0_SX276 +FMKF0_SX366 +FMKF0_SX6 +FMKF0_SX96 +FMMH0_SA1 +FMMH0_SA2 +FMMH0_SI1537 +FMMH0_SI2167 +FMMH0_SI907 +FMMH0_SX187 +FMMH0_SX367 +FMMH0_SX420 +FMMH0_SX7 +FMMH0_SX97 +FMPG0_SI1602 +FMPG0_SI2232 +FMPG0_SX252 +FMPG0_SX72 +FNKL0_SA1 +FNKL0_SA2 +FNKL0_SI2152 +FNKL0_SX172 +FNKL0_SX196 +FNKL0_SX262 +FNKL0_SX442 +FNKL0_SX82 +FNTB0_SA1 +FNTB0_SA2 +FNTB0_SX123 +FNTB0_SX213 +FNTB0_SX33 +FNTB0_SX393 +FPAB1_SA2 +FPAB1_SX121 +FPAB1_SX301 +FPAB1_SX31 +FPAB1_SX391 +FPAC0_SA1 +FPAC0_SI2011 +FPAC0_SX121 +FPAC0_SX211 +FPAC0_SX301 +FPAC0_SX31 +FPAC0_SX391 +FPAD0_SA1 +FPAD0_SI1346 +FPAD0_SI1976 +FPAD0_SX266 +FPAD0_SX446 +FPAF0_SI1684 +FPAF0_SI2314 +FPAF0_SX244 +FPAF0_SX334 +FPAF0_SX424 +FPAF0_SX64 +FPAZ0_SI1593 +FPAZ0_SX153 +FPAZ0_SX27 +FPAZ0_SX423 +FPAZ0_SX63 +FPJF0_SA2 +FPJF0_SI1046 +FPJF0_SI1676 +FPJF0_SX236 +FPJF0_SX326 +FPLS0_SA1 +FPLS0_SA2 +FPLS0_SI2220 +FPLS0_SX150 +FPLS0_SX240 +FPLS0_SX3 +FPLS0_SX60 +FPMY0_SA2 +FPMY0_SI1783 +FPMY0_SX163 +FPMY0_SX196 +FPMY0_SX253 +FPMY0_SX73 +FREH0_SI1315 +FREH0_SI685 +FREH0_SX145 +FREH0_SX235 +FREH0_SX325 +FREH0_SX55 +FRJB0_SA1 +FRJB0_SA2 +FRJB0_SI1427 +FRJB0_SI1470 +FRJB0_SI1794 +FRJB0_SX167 +FRJB0_SX257 +FRJB0_SX437 +FRJB0_SX77 +FRLL0_SA1 +FRLL0_SA2 +FRLL0_SI1514 +FRLL0_SI884 +FRLL0_SX164 +FRLL0_SX254 +FRLL0_SX344 +FRLL0_SX74 +FSAG0_SA2 +FSAG0_SI1953 +FSAG0_SI693 +FSAG0_SX63 +FSAH0_SI1244 +FSAH0_SI1874 +FSAH0_SX344 +FSAH0_SX74 +FSAK0_SA1 +FSAK0_SA2 +FSAK0_SI1930 +FSAK0_SI670 +FSAK0_SX130 +FSAK0_SX220 +FSAK0_SX310 +FSAK0_SX40 +FSAK0_SX400 +FSBK0_SA1 +FSBK0_SI1699 +FSBK0_SI2329 +FSBK0_SX259 +FSBK0_SX439 +FSBK0_SX79 +FSCN0_SI1886 +FSCN0_SX356 +FSDC0_SA1 +FSDC0_SI1942 +FSDC0_SI2234 +FSDC0_SX232 +FSDC0_SX412 +FSDJ0_SA1 +FSDJ0_SA2 +FSDJ0_SI1745 +FSDJ0_SX125 +FSDJ0_SX35 +FSGF0_SA1 +FSGF0_SA2 
+FSGF0_SI1557 +FSGF0_SX207 +FSGF0_SX27 +FSGF0_SX297 +FSGF0_SX387 +FSJG0_SI1570 +FSJG0_SI2200 +FSJG0_SX310 +FSJK1_SA1 +FSJK1_SI1025 +FSJK1_SI2285 +FSJK1_SI696 +FSJK1_SX215 +FSJK1_SX305 +FSJK1_SX395 +FSJS0_SA2 +FSJS0_SI1171 +FSJS0_SI1801 +FSJS0_SI541 +FSJS0_SX271 +FSJS0_SX361 +FSJS0_SX91 +FSJW0_SA1 +FSJW0_SA2 +FSJW0_SI703 +FSJW0_SX163 +FSJW0_SX253 +FSJW0_SX343 +FSJW0_SX73 +FSKC0_SA1 +FSKC0_SA2 +FSKC0_SI2046 +FSKC0_SX156 +FSKC0_SX336 +FSKC0_SX426 +FSKC0_SX66 +FSKL0_SA1 +FSKL0_SA2 +FSKL0_SI2159 +FSKL0_SI899 +FSKL0_SX179 +FSKL0_SX269 +FSKL0_SX359 +FSKL0_SX89 +FSKP0_SA1 +FSKP0_SI1728 +FSKP0_SI468 +FSKP0_SX108 +FSKP0_SX18 +FSKP0_SX198 +FSKP0_SX288 +FSKP0_SX378 +FSLS0_SA1 +FSLS0_SA2 +FSLS0_SI1056 +FSLS0_SI1686 +FSLS0_SI2316 +FSLS0_SX202 +FSLS0_SX246 +FSLS0_SX66 +FSMA0_SA1 +FSMA0_SI1621 +FSMA0_SI2251 +FSMA0_SX271 +FSMA0_SX361 +FSMA0_SX91 +FSMM0_SA1 +FSMM0_SA2 +FSMM0_SI1314 +FSMM0_SI1944 +FSMM0_SI684 +FSMM0_SX414 +FSMM0_SX54 +FSMS1_SA1 +FSMS1_SA2 +FSMS1_SI1504 +FSMS1_SI2134 +FSMS1_SI874 +FSMS1_SX154 +FSMS1_SX334 +FSMS1_SX64 +FSPM0_SA1 +FSPM0_SI1871 +FSPM0_SI611 +FSPM0_SX341 +FSPM0_SX431 +FSRH0_SA1 +FSRH0_SA2 +FSRH0_SI1719 +FSRH0_SX131 +FSRH0_SX41 +FSSB0_SA1 +FSSB0_SA2 +FSSB0_SI1082 +FSSB0_SI2342 +FSSB0_SX182 +FSSB0_SX272 +FSSB0_SX452 +FSSB0_SX92 +FTAJ0_SA1 +FTAJ0_SA2 +FTAJ0_SI1329 +FTAJ0_SI474 +FTAJ0_SX339 +FTAJ0_SX69 +FTBR0_SA1 +FTBR0_SA2 +FTBR0_SI2181 +FTBR0_SX111 +FTBR0_SX201 +FTBR0_SX291 +FTBR0_SX381 +FTBW0_SA2 +FTBW0_SI1345 +FTBW0_SI1975 +FTBW0_SX265 +FTBW0_SX355 +FTBW0_SX445 +FTBW0_SX85 +FTLG0_SA1 +FTLG0_SA2 +FTLG0_SI840 +FTLG0_SX123 +FTLG0_SX213 +FTLG0_SX303 +FTLG0_SX33 +FTLG0_SX393 +FTMG0_SA1 +FTMG0_SA2 +FTMG0_SX182 +FTMG0_SX272 +FTMG0_SX362 +FTMG0_SX92 +FVFB0_SA1 +FVFB0_SI1032 +FVFB0_SI2292 +FVFB0_SX222 +FVFB0_SX312 +FVFB0_SX402 +FVKB0_SA2 +FVKB0_SI1159 +FVKB0_SI1789 +FVKB0_SI529 +FVKB0_SX169 +FVKB0_SX259 +FVKB0_SX439 +FVKB0_SX79 +FVMH0_SA1 +FVMH0_SI2096 +FVMH0_SX206 +FVMH0_SX296 +FVMH0_SX386 +MABC0_SA1 +MABC0_SA2 +MABC0_SX151 +MABC0_SX241 +MABC0_SX331 +MABC0_SX421 
+MABC0_SX61 +MADC0_SA1 +MADC0_SA2 +MADC0_SI1997 +MADC0_SX17 +MADC0_SX197 +MADC0_SX287 +MADD0_SA1 +MADD0_SI1798 +MADD0_SI538 +MADD0_SX358 +MADD0_SX448 +MAEB0_SA1 +MAEB0_SA2 +MAEB0_SI2250 +MAEB0_SI990 +MAEB0_SX180 +MAEB0_SX270 +MAEB0_SX360 +MAEB0_SX90 +MAEO0_SA2 +MAEO0_SI1655 +MAEO0_SI1956 +MAEO0_SX156 +MAEO0_SX246 +MAEO0_SX336 +MAEO0_SX426 +MAEO0_SX66 +MAFM0_SA1 +MAFM0_SA2 +MAFM0_SI1569 +MAFM0_SI2199 +MAFM0_SX219 +MAFM0_SX39 +MAFM0_SX399 +MAJP0_SA1 +MAJP0_SI1074 +MAJP0_SI2334 +MAJP0_SX264 +MAJP0_SX354 +MAJP0_SX444 +MAJP0_SX84 +MAKB0_SA1 +MAKB0_SX206 +MAKB0_SX296 +MAKR0_SA1 +MAKR0_SA2 +MAKR0_SI1352 +MAKR0_SI1982 +MAKR0_SI722 +MAKR0_SX182 +MAKR0_SX272 +MAKR0_SX452 +MAPV0_SA1 +MAPV0_SA2 +MAPV0_SI1923 +MAPV0_SX123 +MAPV0_SX303 +MAPV0_SX33 +MAPV0_SX393 +MARC0_SA1 +MARC0_SI1188 +MARC0_SI1818 +MARC0_SI558 +MARC0_SX288 +MARC0_SX378 +MARW0_SA1 +MARW0_SA2 +MARW0_SI1276 +MARW0_SI646 +MARW0_SX106 +MARW0_SX16 +MARW0_SX376 +MBAR0_SA2 +MBAR0_SI1319 +MBAR0_SI1949 +MBAR0_SI689 +MBAR0_SX149 +MBAR0_SX239 +MBAR0_SX329 +MBBR0_SA1 +MBBR0_SA2 +MBBR0_SI1685 +MBBR0_SX155 +MBBR0_SX245 +MBBR0_SX425 +MBCG0_SA2 +MBCG0_SI2217 +MBCG0_SX147 +MBCG0_SX237 +MBCG0_SX417 +MBCG0_SX57 +MBEF0_SA1 +MBEF0_SA2 +MBEF0_SX111 +MBEF0_SX201 +MBEF0_SX291 +MBGT0_SA1 +MBGT0_SI1341 +MBGT0_SI711 +MBGT0_SX81 +MBJV0_SA2 +MBJV0_SI1247 +MBJV0_SI1877 +MBJV0_SX167 +MBJV0_SX257 +MBJV0_SX437 +MBJV0_SX77 +MBMA0_SA1 +MBMA0_SA2 +MBMA0_SI1852 +MBMA0_SX142 +MBMA0_SX322 +MBMA0_SX412 +MBMA1_SA1 +MBMA1_SA2 +MBMA1_SI2207 +MBMA1_SX144 +MBMA1_SX234 +MBMA1_SX414 +MBML0_SA1 +MBML0_SI1799 +MBML0_SI539 +MBML0_SX179 +MBML0_SX269 +MBML0_SX359 +MBML0_SX449 +MBOM0_SA1 +MBOM0_SI1014 +MBOM0_SI1644 +MBOM0_SX114 +MBOM0_SX204 +MBOM0_SX311 +MBOM0_SX384 +MBSB0_SA2 +MBSB0_SI1353 +MBSB0_SI1983 +MBSB0_SI723 +MBSB0_SX183 +MBSB0_SX273 +MBSB0_SX363 +MBSB0_SX93 +MBTH0_SA1 +MBTH0_SI505 +MBTH0_SI757 +MBTH0_SX212 +MBTH0_SX302 +MBTH0_SX392 +MBWP0_SA1 +MBWP0_SA2 +MBWP0_SI1531 +MBWP0_SI1969 +MBWP0_SI709 +MBWP0_SX169 +MBWP0_SX259 +MBWP0_SX439 +MBWP0_SX79 +MCAE0_SA1 
+MCAE0_SA2 +MCAE0_SX187 +MCAE0_SX367 +MCAE0_SX7 +MCAE0_SX97 +MCAL0_SA1 +MCAL0_SI508 +MCAL0_SX148 +MCAL0_SX238 +MCAL0_SX328 +MCAL0_SX418 +MCAL0_SX58 +MCDC0_SA2 +MCDC0_SI1292 +MCDC0_SI1922 +MCDC0_SI662 +MCDC0_SX122 +MCDC0_SX302 +MCDC0_SX32 +MCDC0_SX392 +MCDD0_SA1 +MCDD0_SI1513 +MCDD0_SI2143 +MCDD0_SX163 +MCDD0_SX343 +MCDD0_SX73 +MCDR0_SA1 +MCDR0_SA2 +MCDR0_SX164 +MCDR0_SX254 +MCDR0_SX344 +MCDR0_SX434 +MCDR0_SX74 +MCEF0_SA1 +MCEF0_SA2 +MCEF0_SI1135 +MCEF0_SI1765 +MCEF0_SX145 +MCEF0_SX325 +MCEF0_SX55 +MCEW0_SI1442 +MCEW0_SX182 +MCEW0_SX272 +MCEW0_SX92 +MCHL0_SA1 +MCHL0_SA2 +MCHL0_SI1977 +MCHL0_SX177 +MCHL0_SX267 +MCHL0_SX357 +MCHL0_SX447 +MCLK0_SA1 +MCLK0_SA2 +MCLK0_SI1660 +MCLK0_SX130 +MCLK0_SX220 +MCLK0_SX40 +MCLK0_SX400 +MCLM0_SA2 +MCLM0_SI1456 +MCLM0_SX106 +MCLM0_SX16 +MCLM0_SX196 +MCLM0_SX286 +MCLM0_SX376 +MCPM0_SA2 +MCPM0_SI1194 +MCPM0_SI564 +MCPM0_SX204 +MCPM0_SX24 +MCRE0_SA1 +MCRE0_SA2 +MCRE0_SI1121 +MCRE0_SI1725 +MCRE0_SI1751 +MCRE0_SX131 +MCRE0_SX221 +MCRE0_SX24 +MCRE0_SX401 +MCRE0_SX41 +MCSS0_SA1 +MCSS0_SA2 +MCSS0_SX120 +MCSS0_SX210 +MCSS0_SX30 +MCSS0_SX300 +MCSS0_SX390 +MCTH0_SA2 +MCTH0_SI1209 +MCTH0_SI1839 +MCTH0_SI579 +MCTH0_SX129 +MCTH0_SX219 +MCTH0_SX309 +MCTH0_SX399 +MCTM0_SA1 +MCTM0_SA2 +MCTM0_SI720 +MCTM0_SX180 +MCTM0_SX270 +MCTM0_SX360 +MCTM0_SX450 +MCTM0_SX90 +MCXM0_SA1 +MCXM0_SA2 +MCXM0_SI1351 +MCXM0_SI1981 +MCXM0_SI721 +MCXM0_SX181 +MCXM0_SX271 +MCXM0_SX361 +MCXM0_SX451 +MDAC0_SA2 +MDAC0_SI1261 +MDAC0_SI1837 +MDAC0_SX271 +MDAC0_SX451 +MDAC0_SX91 +MDAS0_SA1 +MDAS0_SA2 +MDAS0_SI1266 +MDAS0_SX186 +MDAS0_SX21 +MDAS0_SX276 +MDAS0_SX96 +MDBB1_SA1 +MDBB1_SA2 +MDBB1_SI1006 +MDBB1_SI1636 +MDBB1_SI2056 +MDBB1_SX196 +MDBB1_SX286 +MDBP0_SA1 +MDBP0_SA2 +MDBP0_SI1158 +MDBP0_SI1788 +MDBP0_SX258 +MDBP0_SX348 +MDBP0_SX78 +MDCD0_SA1 +MDCD0_SA2 +MDCD0_SI2045 +MDCD0_SX155 +MDCD0_SX65 +MDCM0_SA1 +MDCM0_SA2 +MDCM0_SI2110 +MDCM0_SI850 +MDCM0_SX130 +MDCM0_SX220 +MDCM0_SX310 +MDDC0_SA1 +MDDC0_SA2 +MDDC0_SX249 +MDDC0_SX339 +MDDC0_SX429 +MDED0_SI1170 +MDED0_SI1800 
+MDED0_SX180 +MDED0_SX270 +MDED0_SX360 +MDED0_SX450 +MDED0_SX90 +MDEF0_SA1 +MDEF0_SA2 +MDEF0_SI1563 +MDEF0_SI2193 +MDEF0_SX213 +MDEF0_SX33 +MDEF0_SX393 +MDEM0_SA2 +MDEM0_SI1868 +MDEM0_SX158 +MDEM0_SX248 +MDEM0_SX338 +MDEM0_SX68 +MDHL0_SA1 +MDHL0_SA2 +MDHL0_SI2069 +MDHL0_SI809 +MDHL0_SX179 +MDHL0_SX359 +MDHL0_SX89 +MDHS0_SX180 +MDHS0_SX270 +MDHS0_SX360 +MDHS0_SX450 +MDHS0_SX90 +MDJM0_SA1 +MDJM0_SA2 +MDJM0_SI2085 +MDJM0_SI825 +MDJM0_SX195 +MDJM0_SX285 +MDJM0_SX375 +MDKS0_SA1 +MDKS0_SA2 +MDKS0_SI1066 +MDKS0_SI1696 +MDKS0_SI2326 +MDKS0_SX256 +MDKS0_SX76 +MDLB0_SA1 +MDLB0_SI1936 +MDLB0_SI676 +MDLB0_SX226 +MDLB0_SX316 +MDLB0_SX46 +MDLC0_SA1 +MDLC0_SA2 +MDLC0_SI765 +MDLC0_SX135 +MDLC0_SX225 +MDLC0_SX315 +MDLC0_SX45 +MDLC1_SA1 +MDLC1_SX175 +MDLC1_SX265 +MDLC1_SX355 +MDLC1_SX85 +MDLC2_SA1 +MDLC2_SA2 +MDLC2_SI1614 +MDLC2_SI984 +MDLC2_SX174 +MDLC2_SX264 +MDLC2_SX444 +MDLC2_SX84 +MDLH0_SA1 +MDLH0_SI1960 +MDLH0_SI574 +MDLH0_SI700 +MDLH0_SX250 +MDLH0_SX340 +MDLH0_SX70 +MDLM0_SA1 +MDLM0_SA2 +MDLM0_SX244 +MDLM0_SX334 +MDLM0_SX64 +MDLR0_SI1233 +MDLR0_SX243 +MDLR0_SX423 +MDLR0_SX63 +MDLR1_SI1299 +MDLR1_SI1929 +MDLR1_SX129 +MDLR1_SX219 +MDLR1_SX309 +MDLR1_SX39 +MDLR1_SX399 +MDMA0_SA1 +MDMA0_SA2 +MDMA0_SI1238 +MDMA0_SI2060 +MDMT0_SI2341 +MDMT0_SI572 +MDMT0_SX212 +MDMT0_SX302 +MDMT0_SX392 +MDNS0_SA1 +MDNS0_SX111 +MDNS0_SX291 +MDNS0_SX381 +MDPB0_SA1 +MDPB0_SA2 +MDPB0_SI2126 +MDPB0_SX146 +MDPB0_SX236 +MDPB0_SX326 +MDPB0_SX56 +MDPK0_SA1 +MDPK0_SA2 +MDPK0_SI1683 +MDPK0_SI552 +MDPK0_SX153 +MDPK0_SX243 +MDPK0_SX63 +MDPS0_SA1 +MDPS0_SA2 +MDPS0_SI1651 +MDPS0_SI1979 +MDPS0_SX179 +MDPS0_SX269 +MDPS0_SX449 +MDPS0_SX89 +MDRD0_SA2 +MDRD0_SI1382 +MDRD0_SI2012 +MDRD0_SX122 +MDRD0_SX212 +MDRD0_SX302 +MDRD0_SX392 +MDSJ0_SA1 +MDSJ0_SA2 +MDSJ0_SI832 +MDSJ0_SX112 +MDSJ0_SX22 +MDSJ0_SX292 +MDSJ0_SX382 +MDSS0_SA1 +MDSS0_SI1881 +MDSS0_SI2087 +MDSS0_SI621 +MDSS0_SX171 +MDSS0_SX261 +MDSS0_SX351 +MDSS0_SX81 +MDSS1_SA2 +MDSS1_SI1713 +MDSS1_SX247 +MDSS1_SX337 +MDSS1_SX427 +MDTB0_SA1 +MDTB0_SA2 +MDTB0_SI570 
+MDTB0_SX210 +MDTB0_SX300 +MDTB0_SX321 +MDTB0_SX390 +MDWD0_SA1 +MDWD0_SI1890 +MDWD0_SI557 +MDWD0_SX180 +MDWD0_SX360 +MDWD0_SX450 +MDWH0_SA2 +MDWH0_SI1925 +MDWH0_SX125 +MDWH0_SX35 +MDWH0_SX395 +MDWM0_SI1546 +MDWM0_SI2176 +MDWM0_SX106 +MDWM0_SX376 +MDWM0_SX433 +MEAL0_SA1 +MEAL0_SI1547 +MEAL0_SI917 +MEAL0_SX197 +MEAL0_SX287 +MEAL0_SX377 +MEDR0_SI744 +MEDR0_SX114 +MEDR0_SX204 +MEDR0_SX24 +MEDR0_SX294 +MEDR0_SX384 +MEFG0_SA2 +MEFG0_SI465 +MEFG0_SX105 +MEFG0_SX15 +MEFG0_SX195 +MEFG0_SX285 +MEFG0_SX375 +MEGJ0_SI1967 +MEGJ0_SX437 +MEGJ0_SX77 +MEJL0_SA2 +MEJL0_SI1592 +MEJL0_SI1654 +MEJL0_SI962 +MEJL0_SX332 +MEJL0_SX422 +MEJL0_SX62 +MEJS0_SA1 +MEJS0_SA2 +MEJS0_SI1870 +MEJS0_SX250 +MEJS0_SX430 +MEJS0_SX70 +MESG0_SA1 +MESG0_SA2 +MESG0_SI1332 +MESG0_SI1962 +MESG0_SX162 +MESG0_SX252 +MESG0_SX342 +MESG0_SX72 +MESJ0_SA1 +MESJ0_SA2 +MESJ0_SI2257 +MESJ0_SI997 +MESJ0_SX277 +MESJ0_SX367 +MESJ0_SX7 +MEWM0_SA1 +MEWM0_SA2 +MEWM0_SI1348 +MEWM0_SI1978 +MEWM0_SX268 +MEWM0_SX358 +MEWM0_SX448 +MFER0_SA1 +MFER0_SA2 +MFER0_SI1492 +MFER0_SI2122 +MFER0_SX232 +MFER0_SX322 +MFER0_SX412 +MFER0_SX52 +MFMC0_SA1 +MFMC0_SA2 +MFMC0_SI1132 +MFMC0_SI1762 +MFMC0_SI502 +MFMC0_SX142 +MFMC0_SX232 +MFMC0_SX322 +MFMC0_SX412 +MFMC0_SX52 +MFRM0_SA1 +MFRM0_SA2 +MFRM0_SI1155 +MFRM0_SI1717 +MFRM0_SI1785 +MFRM0_SX165 +MFRM0_SX255 +MFRM0_SX75 +MFWK0_SA1 +MFWK0_SA2 +MFWK0_SI1249 +MFWK0_SI619 +MFWK0_SX259 +MFWK0_SX439 +MFWK0_SX79 +MFXS0_SA1 +MFXS0_SA2 +MFXS0_SI1674 +MFXS0_SI2225 +MFXS0_SI2304 +MFXS0_SX144 +MFXS0_SX234 +MFXS0_SX414 +MFXV0_SA1 +MFXV0_SI1635 +MFXV0_SX15 +MFXV0_SX195 +MFXV0_SX285 +MFXV0_SX375 +MGAF0_SA2 +MGAF0_SI1912 +MGAF0_SI652 +MGAF0_SX112 +MGAF0_SX202 +MGAF0_SX292 +MGAG0_SA1 +MGAG0_SI1321 +MGAG0_SI645 +MGAG0_SX151 +MGAG0_SX241 +MGAG0_SX331 +MGAG0_SX421 +MGAG0_SX61 +MGAK0_SA1 +MGAK0_SA2 +MGAK0_SI1666 +MGAK0_SI2296 +MGAK0_SX316 +MGAK0_SX406 +MGAR0_SA1 +MGAR0_SA2 +MGAR0_SI1212 +MGAR0_SI1694 +MGAR0_SI1842 +MGAR0_SX222 +MGAR0_SX402 +MGAR0_SX42 +MGAW0_SA1 +MGAW0_SA2 +MGAW0_SI1802 +MGAW0_SX265 +MGAW0_SX355 
+MGAW0_SX445 +MGAW0_SX85 +MGES0_SA2 +MGES0_SI1481 +MGES0_SX131 +MGES0_SX221 +MGES0_SX401 +MGES0_SX41 +MGJC0_SA1 +MGJC0_SI1256 +MGJC0_SI1335 +MGJC0_SI1965 +MGJC0_SX165 +MGJC0_SX255 +MGJC0_SX345 +MGRL0_SA1 +MGRL0_SA2 +MGRL0_SI1497 +MGRL0_SX237 +MGRL0_SX417 +MGRL0_SX57 +MGRP0_SA1 +MGRP0_SI1947 +MGRP0_SI687 +MGRP0_SX147 +MGRP0_SX237 +MGRP0_SX417 +MGRP0_SX57 +MGSH0_SA1 +MGSH0_SX186 +MGSH0_SX96 +MGSL0_SA2 +MGSL0_SI1164 +MGSL0_SX174 +MGSL0_SX354 +MGSL0_SX444 +MGSL0_SX84 +MGXP0_SA1 +MGXP0_SA2 +MGXP0_SI457 +MGXP0_SX277 +MGXP0_SX367 +MGXP0_SX97 +MHBS0_SA1 +MHBS0_SA2 +MHBS0_SI1575 +MHBS0_SI2205 +MHBS0_SX135 +MHBS0_SX225 +MHBS0_SX405 +MHIT0_SA2 +MHIT0_SI1613 +MHIT0_SI2243 +MHIT0_SX173 +MHIT0_SX263 +MHIT0_SX353 +MHIT0_SX443 +MHIT0_SX83 +MHJB0_SA2 +MHJB0_SI1647 +MHJB0_SI2277 +MHJB0_SX117 +MHJB0_SX207 +MHJB0_SX27 +MHJB0_SX297 +MHJB0_SX387 +MHMG0_SA1 +MHMG0_SA2 +MHMG0_SI1365 +MHMG0_SI1995 +MHMG0_SX105 +MHMG0_SX15 +MHMG0_SX285 +MHMG0_SX375 +MHMR0_SA2 +MHMR0_SI1119 +MHMR0_SX129 +MHMR0_SX219 +MHMR0_SX309 +MHMR0_SX39 +MHMR0_SX399 +MHRM0_SA2 +MHRM0_SI1475 +MHRM0_SI2218 +MHRM0_SX238 +MHRM0_SX328 +MHRM0_SX418 +MHXL0_SA1 +MHXL0_SA2 +MHXL0_SI512 +MHXL0_SI612 +MHXL0_SX152 +MHXL0_SX332 +MHXL0_SX422 +MHXL0_SX62 +MILB0_SA1 +MILB0_SI2163 +MILB0_SI807 +MILB0_SX183 +MILB0_SX273 +MILB0_SX3 +MILB0_SX363 +MILB0_SX93 +MJAC0_SA1 +MJAC0_SA2 +MJAC0_SI1331 +MJAC0_SI2148 +MJAC0_SX341 +MJAC0_SX431 +MJAE0_SA1 +MJAE0_SA2 +MJAE0_SI1524 +MJAE0_SI1999 +MJAE0_SI2154 +MJAE0_SX264 +MJAE0_SX354 +MJAE0_SX444 +MJAI0_SI1604 +MJAI0_SX164 +MJAI0_SX254 +MJAI0_SX344 +MJAI0_SX434 +MJAI0_SX74 +MJBG0_SA1 +MJBG0_SA2 +MJBG0_SI1232 +MJBG0_SI1724 +MJBG0_SI1862 +MJBG0_SX152 +MJBG0_SX242 +MJBG0_SX332 +MJBG0_SX422 +MJDA0_SA1 +MJDA0_SA2 +MJDA0_SI1661 +MJDA0_SI2291 +MJDA0_SX131 +MJDA0_SX221 +MJDA0_SX401 +MJDA0_SX41 +MJDC0_SA1 +MJDC0_SA2 +MJDC0_SI1161 +MJDC0_SI2165 +MJDC0_SX171 +MJDC0_SX261 +MJDC0_SX351 +MJDC0_SX441 +MJDC0_SX81 +MJDE0_SA2 +MJDE0_SX130 +MJDE0_SX310 +MJDE0_SX40 +MJDE0_SX400 +MJDG0_SA1 +MJDG0_SI1672 +MJDG0_SX142 
+MJDG0_SX232 +MJDG0_SX322 +MJDG0_SX412 +MJDG0_SX52 +MJDM0_SA2 +MJDM0_SI1937 +MJDM0_SX260 +MJDM0_SX440 +MJDM0_SX80 +MJEB0_SA1 +MJEB0_SA2 +MJEB0_SI1286 +MJEB0_SI1916 +MJEB0_SX206 +MJEB0_SX26 +MJEB0_SX386 +MJEB1_SA1 +MJEB1_SI2097 +MJEB1_SX117 +MJEB1_SX27 +MJEB1_SX297 +MJEE0_SA2 +MJEE0_SI1237 +MJEE0_SI1867 +MJEE0_SI607 +MJEE0_SX157 +MJEE0_SX427 +MJEE0_SX67 +MJFH0_SA1 +MJFH0_SI1737 +MJFH0_SI477 +MJFH0_SX117 +MJFH0_SX207 +MJFH0_SX27 +MJFH0_SX297 +MJFH0_SX387 +MJFR0_SA2 +MJFR0_SI1605 +MJFR0_SI2235 +MJFR0_SI975 +MJFR0_SX165 +MJFR0_SX255 +MJFR0_SX345 +MJHI0_SA2 +MJHI0_SI555 +MJHI0_SI698 +MJHI0_SX248 +MJHI0_SX338 +MJHI0_SX428 +MJHI0_SX68 +MJJB0_SA2 +MJJB0_SI1139 +MJJB0_SI1277 +MJJB0_SI1769 +MJJB0_SX149 +MJJB0_SX329 +MJJB0_SX419 +MJJB0_SX59 +MJJJ0_SA1 +MJJJ0_SA2 +MJJJ0_SI1793 +MJJJ0_SI533 +MJJJ0_SX173 +MJJJ0_SX263 +MJJJ0_SX353 +MJJJ0_SX83 +MJJM0_SA1 +MJJM0_SI1457 +MJJM0_SX17 +MJJM0_SX197 +MJJM0_SX287 +MJJM0_SX377 +MJKR0_SA2 +MJKR0_SI1201 +MJKR0_SI1831 +MJKR0_SX121 +MJKR0_SX211 +MJKR0_SX301 +MJKR0_SX31 +MJKR0_SX391 +MJLB0_SA1 +MJLB0_SA2 +MJLB0_SI2246 +MJLB0_SI986 +MJLB0_SX266 +MJLB0_SX356 +MJLB0_SX446 +MJLB0_SX86 +MJLG1_SA1 +MJLG1_SA2 +MJLG1_SI1012 +MJLG1_SI1642 +MJLG1_SI2272 +MJLG1_SX112 +MJLG1_SX202 +MJLG1_SX22 +MJLG1_SX382 +MJLS0_SA1 +MJLS0_SA2 +MJLS0_SI1096 +MJLS0_SI466 +MJLS0_SX16 +MJLS0_SX196 +MJLS0_SX286 +MJLS0_SX376 +MJMA0_SI1495 +MJMA0_SI865 +MJMA0_SX145 +MJMA0_SX235 +MJMA0_SX325 +MJMA0_SX415 +MJMA0_SX55 +MJMD0_SA1 +MJMD0_SI1028 +MJMD0_SI1658 +MJMD0_SX128 +MJMD0_SX218 +MJMD0_SX398 +MJMM0_SA1 +MJMM0_SA2 +MJMM0_SI1885 +MJMM0_SI625 +MJMM0_SX265 +MJMM0_SX355 +MJMM0_SX445 +MJPG0_SA1 +MJPG0_SA2 +MJPG0_SI561 +MJPG0_SX291 +MJPG0_SX381 +MJPM0_SA1 +MJPM0_SI1998 +MJPM0_SI738 +MJPM0_SX108 +MJPM0_SX18 +MJPM0_SX198 +MJPM0_SX288 +MJPM1_SA1 +MJPM1_SA2 +MJPM1_SI1897 +MJPM1_SI761 +MJPM1_SX131 +MJPM1_SX221 +MJPM1_SX41 +MJRA0_SI606 +MJRA0_SX156 +MJRA0_SX246 +MJRA0_SX66 +MJRG0_SA1 +MJRG0_SA2 +MJRG0_SX106 +MJRG0_SX16 +MJRG0_SX286 +MJRH0_SA1 +MJRH0_SA2 +MJRH0_SI1125 +MJRH0_SI1755 
+MJRH0_SX135 +MJRH0_SX315 +MJRH0_SX405 +MJRH0_SX45 +MJRH1_SA2 +MJRH1_SI1774 +MJRH1_SX334 +MJRH1_SX64 +MJRK0_SI2103 +MJRK0_SX340 +MJRK0_SX70 +MJRP0_SI1835 +MJRP0_SI585 +MJRP0_SX135 +MJRP0_SX315 +MJRP0_SX405 +MJRP0_SX45 +MJSR0_SA2 +MJSR0_SX164 +MJSR0_SX254 +MJSR0_SX434 +MJSR0_SX74 +MJWG0_SA2 +MJWG0_SI2155 +MJWG0_SX355 +MJWG0_SX445 +MJWG0_SX85 +MJWS0_SA1 +MJWS0_SA2 +MJWS0_SI1143 +MJWS0_SI1773 +MJWS0_SX243 +MJWS0_SX423 +MJWT0_SA2 +MJWT0_SI751 +MJXA0_SA1 +MJXA0_SA2 +MJXA0_SI1507 +MJXA0_SI2137 +MJXA0_SI877 +MJXA0_SX157 +MJXA0_SX247 +MJXA0_SX337 +MJXA0_SX67 +MJXL0_SA1 +MJXL0_SA2 +MJXL0_SI1795 +MJXL0_SX182 +MJXL0_SX272 +MJXL0_SX362 +MJXL0_SX452 +MJXL0_SX92 +MKAG0_SA2 +MKAG0_SI1609 +MKAG0_SI2239 +MKAG0_SX169 +MKAG0_SX30 +MKAG0_SX439 +MKAG0_SX79 +MKAH0_SA1 +MKAH0_SA2 +MKAH0_SI1528 +MKAH0_SI2158 +MKAH0_SI898 +MKAH0_SX268 +MKAH0_SX358 +MKAH0_SX448 +MKAH0_SX88 +MKAJ0_SA1 +MKAJ0_SI1414 +MKAJ0_SI2044 +MKAJ0_SI784 +MKAJ0_SX244 +MKAJ0_SX334 +MKAJ0_SX424 +MKAJ0_SX64 +MKAM0_SA2 +MKAM0_SI1316 +MKAM0_SX236 +MKAM0_SX416 +MKDB0_SI2132 +MKDB0_SI588 +MKDB0_SI872 +MKDB0_SX242 +MKDB0_SX332 +MKDB0_SX422 +MKDB0_SX62 +MKDD0_SA1 +MKDD0_SX127 +MKDD0_SX217 +MKDD0_SX307 +MKDD0_SX37 +MKDD0_SX397 +MKDT0_SA1 +MKDT0_SA2 +MKDT0_SI2153 +MKDT0_SI893 +MKDT0_SX173 +MKDT0_SX263 +MKDT0_SX353 +MKDT0_SX443 +MKDT0_SX83 +MKES0_SA2 +MKES0_SX263 +MKES0_SX353 +MKES0_SX443 +MKES0_SX83 +MKJO0_SA1 +MKJO0_SA2 +MKJO0_SI2147 +MKJO0_SX167 +MKJO0_SX257 +MKJO0_SX424 +MKJO0_SX77 +MKLN0_SA1 +MKLN0_SA2 +MKLN0_SI1598 +MKLN0_SI2228 +MKLN0_SX158 +MKLN0_SX338 +MKLN0_SX428 +MKLN0_SX68 +MKLR0_SA1 +MKLR0_SI1059 +MKLR0_SI2319 +MKLR0_SX159 +MKLR0_SX249 +MKLR0_SX339 +MKLR0_SX429 +MKLR0_SX69 +MKLS0_SA2 +MKLS0_SI1533 +MKLS0_SX177 +MKLS0_SX267 +MKLS0_SX447 +MKLS1_SI1545 +MKLS1_SI2175 +MKLS1_SX105 +MKLS1_SX15 +MKLS1_SX195 +MKLS1_SX285 +MKLW0_SA2 +MKLW0_SI1844 +MKLW0_SI2201 +MKLW0_SX131 +MKLW0_SX221 +MKLW0_SX401 +MKLW0_SX41 +MKRG0_SA1 +MKRG0_SA2 +MKRG0_SI1491 +MKRG0_SI2121 +MKRG0_SX141 +MKRG0_SX231 +MKRG0_SX31 +MKRG0_SX51 +MKXL0_SA1 
+MKXL0_SI1185 +MKXL0_SX105 +MKXL0_SX195 +MKXL0_SX285 +MLBC0_SA2 +MLBC0_SI609 +MLBC0_SX159 +MLBC0_SX339 +MLBC0_SX429 +MLBC0_SX69 +MLEL0_SI1876 +MLEL0_SX346 +MLEL0_SX76 +MLJC0_SA1 +MLJC0_SA2 +MLJC0_SI1855 +MLJC0_SI595 +MLJC0_SX235 +MLJC0_SX325 +MLJC0_SX55 +MLJH0_SI1324 +MLJH0_SX154 +MLJH0_SX334 +MLJH0_SX424 +MLNS0_SA1 +MLNS0_SA2 +MLNS0_SI1407 +MLNS0_SI777 +MLNS0_SX147 +MLNS0_SX237 +MLNS0_SX327 +MLNS0_SX417 +MLNS0_SX57 +MLSH0_SA1 +MLSH0_SA2 +MLSH0_SI2047 +MLSH0_SI787 +MLSH0_SX157 +MLSH0_SX337 +MLSH0_SX427 +MLSH0_SX67 +MMAA0_SI2105 +MMAA0_SX125 +MMAA0_SX215 +MMAA0_SX305 +MMAA0_SX395 +MMAB1_SA1 +MMAB1_SA2 +MMAB1_SI2124 +MMAB1_SX144 +MMAB1_SX414 +MMAB1_SX54 +MMAG0_SI496 +MMAG0_SX226 +MMAG0_SX406 +MMAG0_SX46 +MMAM0_SA1 +MMAM0_SA2 +MMAM0_SI1597 +MMAM0_SI1668 +MMAM0_SX247 +MMAM0_SX337 +MMAM0_SX67 +MMAR0_SA1 +MMAR0_SA2 +MMAR0_SI1336 +MMAR0_SI706 +MMAR0_SX436 +MMAR0_SX76 +MMBS0_SA1 +MMBS0_SA2 +MMBS0_SI1151 +MMBS0_SX251 +MMBS0_SX341 +MMBS0_SX431 +MMBS0_SX71 +MMCC0_SA1 +MMCC0_SI1968 +MMCC0_SI708 +MMCC0_SX168 +MMCC0_SX258 +MMCC0_SX348 +MMCC0_SX438 +MMCC0_SX78 +MMDB0_SA1 +MMDB0_SA2 +MMDB0_SI1358 +MMDB0_SI1617 +MMDB0_SX267 +MMDB0_SX357 +MMDB0_SX447 +MMDB0_SX87 +MMDG0_SI2035 +MMDG0_SX340 +MMDG0_SX430 +MMDG0_SX70 +MMDM0_SA1 +MMDM0_SA2 +MMDM0_SX231 +MMDM0_SX321 +MMDM0_SX411 +MMDM0_SX51 +MMDM1_SA1 +MMDM1_SI1650 +MMDM1_SI783 +MMDM1_SX243 +MMDS0_SA2 +MMDS0_SI1343 +MMDS0_SI1973 +MMDS0_SI713 +MMDS0_SX173 +MMDS0_SX263 +MMDS0_SX353 +MMDS0_SX443 +MMDS0_SX83 +MMEA0_SA2 +MMEA0_SI1388 +MMEA0_SI2018 +MMEA0_SI758 +MMEA0_SX218 +MMEA0_SX308 +MMEA0_SX38 +MMEB0_SA1 +MMEB0_SI1357 +MMEB0_SI1987 +MMEB0_SI727 +MMEB0_SX7 +MMEB0_SX97 +MMGC0_SA1 +MMGC0_SI1935 +MMGC0_SI2184 +MMGC0_SX315 +MMGC0_SX405 +MMGC0_SX45 +MMGG0_SA1 +MMGG0_SA2 +MMGG0_SI1709 +MMGG0_SI2339 +MMGG0_SX179 +MMGG0_SX359 +MMGG0_SX89 +MMGK0_SA1 +MMGK0_SA2 +MMGK0_SI1322 +MMGK0_SI1952 +MMGK0_SI692 +MMGK0_SX152 +MMGK0_SX242 +MMGK0_SX422 +MMJB1_SA1 +MMJB1_SI1408 +MMJB1_SI2038 +MMJB1_SI778 +MMJB1_SX148 +MMJB1_SX238 +MMJB1_SX328 +MMJB1_SX418 
+MMJB1_SX58 +MMLM0_SA1 +MMLM0_SA2 +MMLM0_SI1527 +MMLM0_SI897 +MMLM0_SX177 +MMLM0_SX267 +MMLM0_SX357 +MMLM0_SX447 +MMLM0_SX87 +MMPM0_SA1 +MMPM0_SA2 +MMPM0_SI1061 +MMPM0_SI1691 +MMPM0_SI2321 +MMPM0_SX251 +MMPM0_SX341 +MMPM0_SX431 +MMPM0_SX71 +MMRP0_SA1 +MMRP0_SI2034 +MMRP0_SI717 +MMRP0_SI774 +MMRP0_SX234 +MMRP0_SX414 +MMRP0_SX54 +MMSM0_SA1 +MMSM0_SA2 +MMSM0_SI1736 +MMSM0_SX26 +MMSM0_SX296 +MMSM0_SX386 +MMVP0_SI1284 +MMVP0_SI1914 +MMVP0_SX114 +MMVP0_SX204 +MMVP0_SX294 +MMVP0_SX384 +MMWB0_SA2 +MMWB0_SI1619 +MMWB0_SX179 +MMWB0_SX269 +MMWS0_SA1 +MMWS0_SI1518 +MMWS0_SI559 +MMWS0_SI888 +MMWS0_SX258 +MMWS0_SX78 +MMWS1_SA1 +MMWS1_SA2 +MMWS1_SI1071 +MMWS1_SI2331 +MMWS1_SX261 +MMWS1_SX27 +MMWS1_SX351 +MMWS1_SX441 +MMWS1_SX81 +MMXS0_SA1 +MMXS0_SA2 +MMXS0_SI629 +MMXS0_SI876 +MMXS0_SX156 +MMXS0_SX336 +MMXS0_SX66 +MNET0_SA1 +MNET0_SA2 +MNET0_SI1446 +MNET0_SI2076 +MNET0_SX186 +MNET0_SX276 +MNET0_SX366 +MNET0_SX96 +MNTW0_SA1 +MNTW0_SI2328 +MNTW0_SX202 +MNTW0_SX258 +MNTW0_SX348 +MPAR0_SA1 +MPAR0_SA2 +MPAR0_SI1576 +MPAR0_SX226 +MPAR0_SX406 +MPAR0_SX46 +MPEB0_SA1 +MPEB0_SA2 +MPEB0_SX150 +MPEB0_SX420 +MPEB0_SX60 +MPFU0_SA1 +MPFU0_SA2 +MPFU0_SI1888 +MPFU0_SX178 +MPFU0_SX268 +MPFU0_SX358 +MPFU0_SX88 +MPGH0_SA1 +MPGH0_SA2 +MPGH0_SI1554 +MPGH0_SI924 +MPGH0_SX204 +MPGH0_SX294 +MPGH0_SX384 +MPGR0_SA1 +MPGR0_SA2 +MPGR0_SI2040 +MPGR0_SI780 +MPGR0_SX150 +MPGR0_SX420 +MPGR0_SX60 +MPGR1_SA1 +MPGR1_SA2 +MPGR1_SI1269 +MPGR1_SI2129 +MPGR1_SX239 +MPGR1_SX329 +MPGR1_SX419 +MPGR1_SX59 +MPMB0_SX241 +MPPC0_SA2 +MPPC0_SI2042 +MPPC0_SI782 +MPPC0_SX152 +MPPC0_SX242 +MPPC0_SX332 +MPPC0_SX422 +MPPC0_SX62 +MPRB0_SA1 +MPRB0_SA2 +MPRB0_SI1205 +MPRB0_SX125 +MPRB0_SX215 +MPRB0_SX305 +MPRB0_SX35 +MPRB0_SX395 +MPRD0_SA2 +MPRD0_SI1431 +MPRD0_SI2061 +MPRK0_SA2 +MPRK0_SX17 +MPRK0_SX197 +MPRT0_SA2 +MPRT0_SI1210 +MPRT0_SI495 +MPRT0_SI580 +MPRT0_SX130 +MPRT0_SX220 +MPRT0_SX40 +MPRT0_SX400 +MPSW0_SA1 +MPSW0_SA2 +MPSW0_SI1697 +MPSW0_SI2327 +MPSW0_SX24 +MPSW0_SX257 +MPSW0_SX77 +MRAB0_SA1 +MRAB0_SA2 +MRAB0_SI1224 +MRAB0_SI594 
+MRAB0_SX144 +MRAB0_SX234 +MRAB0_SX324 +MRAB0_SX414 +MRAB0_SX54 +MRAB1_SA1 +MRAB1_SA2 +MRAB1_SI1478 +MRAB1_SI2108 +MRAB1_SX218 +MRAB1_SX38 +MRAB1_SX398 +MRAI0_SI1954 +MRAI0_SX162 +MRAI0_SX252 +MRAI0_SX342 +MRAM0_SI1275 +MRAM0_SI1905 +MRAM0_SX105 +MRAM0_SX195 +MRAM0_SX285 +MRAM0_SX375 +MRAV0_SA1 +MRAV0_SA2 +MRAV0_SI1008 +MRAV0_SI1638 +MRAV0_SI2268 +MRAV0_SX108 +MRAV0_SX18 +MRAV0_SX198 +MRAV0_SX288 +MRAV0_SX378 +MRBC0_SA1 +MRBC0_SA2 +MRBC0_SI1665 +MRBC0_SI599 +MRBC0_SX149 +MRBC0_SX239 +MRBC0_SX59 +MRCG0_SA1 +MRCG0_SI2058 +MRCG0_SX258 +MRCG0_SX78 +MRCW0_SA2 +MRCW0_SI1371 +MRCW0_SI2001 +MRCW0_SX111 +MRCW0_SX201 +MRCW0_SX21 +MRCW0_SX381 +MRDD0_SA1 +MRDD0_SA2 +MRDD0_SI1050 +MRDD0_SI2310 +MRDD0_SX240 +MRDD0_SX330 +MRDM0_SA1 +MRDM0_SA2 +MRDM0_SI965 +MRDM0_SX155 +MRDM0_SX245 +MRDM0_SX425 +MRDS0_SA2 +MRDS0_SI1167 +MRDS0_SI1797 +MRDS0_SI537 +MRDS0_SX177 +MRDS0_SX267 +MRDS0_SX357 +MRDS0_SX447 +MRDS0_SX87 +MREE0_SA1 +MREE0_SA2 +MREE0_SI1734 +MREE0_SX114 +MREE0_SX204 +MREE0_SX294 +MREE0_SX384 +MREH1_SA2 +MREH1_SI2229 +MREH1_SX159 +MREH1_SX339 +MREH1_SX429 +MREM0_SA1 +MREM0_SI1591 +MREM0_SI961 +MREM0_SX151 +MREM0_SX241 +MREM0_SX331 +MREM0_SX421 +MREM0_SX61 +MREW1_SA1 +MREW1_SA2 +MREW1_SI1500 +MREW1_SI2130 +MREW1_SX150 +MREW1_SX240 +MREW1_SX330 +MREW1_SX420 +MREW1_SX60 +MRFK0_SA1 +MRFK0_SA2 +MRFK0_SI1706 +MRFK0_SI2336 +MRFK0_SX176 +MRFK0_SX266 +MRFK0_SX356 +MRFK0_SX86 +MRFL0_SA2 +MRFL0_SI1786 +MRFL0_SX346 +MRGM0_SA1 +MRGM0_SI1162 +MRGM0_SI1792 +MRGM0_SX416 +MRGM0_SX82 +MRGS0_SA1 +MRGS0_SI1986 +MRGS0_SX276 +MRGS0_SX366 +MRGS0_SX96 +MRHL0_SA1 +MRHL0_SA2 +MRHL0_SI1515 +MRHL0_SI2145 +MRHL0_SX165 +MRHL0_SX255 +MRHL0_SX75 +MRJB1_SI1020 +MRJB1_SX300 +MRJH0_SA1 +MRJH0_SI914 +MRJH0_SX259 +MRJH0_SX439 +MRJM0_SA1 +MRJM0_SA2 +MRJM0_SI1095 +MRJM0_SI1228 +MRJM0_SI1858 +MRJM0_SX238 +MRJM0_SX328 +MRJM0_SX418 +MRJM0_SX58 +MRJM1_SA1 +MRJM1_SI668 +MRJM1_SX218 +MRJM1_SX308 +MRJM1_SX38 +MRJM1_SX398 +MRJT0_SA1 +MRJT0_SI1805 +MRJT0_SX148 +MRJT0_SX238 +MRKM0_SA1 +MRKM0_SX187 +MRKM0_SX277 +MRKM0_SX7 
+MRKM0_SX97 +MRLD0_SA1 +MRLD0_SI1594 +MRLD0_SI964 +MRLD0_SX244 +MRLD0_SX334 +MRLD0_SX64 +MRLJ0_SA2 +MRLJ0_SI1420 +MRLJ0_SI2050 +MRLJ0_SX160 +MRLJ0_SX430 +MRLJ0_SX70 +MRLJ1_SI1671 +MRLJ1_SI2332 +MRLJ1_SX141 +MRLJ1_SX231 +MRLJ1_SX411 +MRLJ1_SX51 +MRLK0_SA1 +MRLK0_SA2 +MRLK0_SI2140 +MRLK0_SX303 +MRLK0_SX33 +MRLK0_SX393 +MRLR0_SA1 +MRLR0_SA2 +MRLR0_SI1826 +MRLR0_SI566 +MRLR0_SX116 +MRLR0_SX206 +MRLR0_SX26 +MRLR0_SX296 +MRLR0_SX386 +MRMB0_SA1 +MRMB0_SI2211 +MRMB0_SI951 +MRMB0_SX141 +MRMB0_SX231 +MRMB0_SX321 +MRMB0_SX51 +MRMG0_SA2 +MRMG0_SI1710 +MRMG0_SI2340 +MRMG0_SX180 +MRMG0_SX270 +MRMG0_SX360 +MRMG0_SX90 +MRMH0_SA1 +MRMH0_SA2 +MRMH0_SI1021 +MRMH0_SX211 +MRMH0_SX301 +MRMH0_SX31 +MRMH0_SX391 +MRML0_SI2051 +MRML0_SI791 +MRML0_SX431 +MRML0_SX71 +MRMS0_SA1 +MRMS0_SA2 +MRMS0_SI1113 +MRMS0_SI2100 +MRMS0_SX120 +MRMS0_SX210 +MRMS0_SX30 +MRMS0_SX300 +MRMS0_SX390 +MRPC1_SA1 +MRPC1_SA2 +MRPC1_SI1482 +MRPC1_SI2026 +MRPC1_SX132 +MRPC1_SX222 +MRPC1_SX312 +MRPC1_SX402 +MRPC1_SX42 +MRRE0_SI704 +MRRE0_SX254 +MRRE0_SX434 +MRSO0_SA1 +MRSO0_SA2 +MRSO0_SI1659 +MRSO0_SI2289 +MRSO0_SX219 +MRSO0_SX309 +MRSO0_SX399 +MRSP0_SA1 +MRSP0_SA2 +MRSP0_SI2059 +MRSP0_SI799 +MRSP0_SX169 +MRSP0_SX196 +MRSP0_SX439 +MRSP0_SX79 +MRTC0_SA1 +MRTC0_SA2 +MRTC0_SI2088 +MRTC0_SI828 +MRTC0_SX108 +MRTC0_SX18 +MRTC0_SX198 +MRTC0_SX288 +MRTJ0_SA2 +MRTJ0_SI1551 +MRTJ0_SI2032 +MRTJ0_SX322 +MRTJ0_SX412 +MRVG0_SA1 +MRVG0_SA2 +MRVG0_SI1770 +MRVG0_SI510 +MRVG0_SX150 +MRVG0_SX330 +MRVG0_SX420 +MRVG0_SX60 +MRWA0_SA1 +MRWA0_SA2 +MRWA0_SI1603 +MRWA0_SI2233 +MRWA0_SX253 +MRWA0_SX343 +MRWA0_SX433 +MRWS0_SA1 +MRWS0_SA2 +MRWS0_SX112 +MRWS0_SX202 +MRWS0_SX292 +MRXB0_SA1 +MRXB0_SI1585 +MRXB0_SX145 +MRXB0_SX235 +MRXB0_SX325 +MRXB0_SX55 +MSAH1_SA1 +MSAH1_SA2 +MSAH1_SI1049 +MSAH1_SI2309 +MSAH1_SX149 +MSAH1_SX239 +MSAH1_SX329 +MSAH1_SX419 +MSAH1_SX59 +MSAS0_SA1 +MSAS0_SA2 +MSAS0_SI2006 +MSAS0_SX26 +MSAS0_SX296 +MSAT0_SA2 +MSAT0_SI1526 +MSAT0_SI2156 +MSAT0_SI896 +MSAT0_SX176 +MSAT0_SX266 +MSAT0_SX356 +MSAT0_SX446 +MSAT0_SX86 +MSAT1_SA1 
+MSAT1_SA2 +MSAT1_SI1073 +MSAT1_SI1703 +MSAT1_SI2333 +MSAT1_SX173 +MSAT1_SX353 +MSDB0_SA1 +MSDB0_SA2 +MSDB0_SI1007 +MSDB0_SI1637 +MSDB0_SI2267 +MSDB0_SX107 +MSDB0_SX17 +MSDH0_SA1 +MSDH0_SA2 +MSDH0_SI2113 +MSDH0_SX260 +MSDH0_SX350 +MSDS0_SA2 +MSDS0_SI1707 +MSDS0_SI2337 +MSDS0_SX177 +MSDS0_SX447 +MSDS0_SX87 +MSEM1_SA1 +MSEM1_SA2 +MSEM1_SX360 +MSEM1_SX450 +MSEM1_SX90 +MSES0_SA1 +MSES0_SA2 +MSES0_SI2216 +MSES0_SI2219 +MSES0_SX149 +MSES0_SX329 +MSES0_SX59 +MSFH0_SA2 +MSFH0_SI1216 +MSFH0_SI586 +MSFH0_SX226 +MSFH0_SX46 +MSFV0_SA1 +MSFV0_SA2 +MSFV0_SI1262 +MSFV0_SX182 +MSFV0_SX272 +MSFV0_SX452 +MSJK0_SA1 +MSJK0_SA2 +MSJK0_SI2226 +MSJK0_SI966 +MSJK0_SX156 +MSJK0_SX246 +MSJK0_SX426 +MSJK0_SX66 +MSMC0_SA1 +MSMC0_SA2 +MSMC0_SI1907 +MSMC0_SI647 +MSMC0_SX107 +MSMC0_SX17 +MSMC0_SX197 +MSMC0_SX287 +MSMC0_SX377 +MSMR0_SA1 +MSMR0_SA2 +MSMR0_SI1405 +MSMR0_SI775 +MSMR0_SX145 +MSMR0_SX235 +MSMR0_SX325 +MSMR0_SX55 +MSMS0_SA2 +MSMS0_SI2063 +MSMS0_SI803 +MSMS0_SX263 +MSMS0_SX353 +MSMS0_SX443 +MSRG0_SA2 +MSRG0_SI1851 +MSRG0_SI591 +MSRG0_SX141 +MSRG0_SX231 +MSRG0_SX321 +MSRG0_SX411 +MSRG0_SX51 +MSRR0_SA1 +MSRR0_SA2 +MSRR0_SI1131 +MSRR0_SX141 +MSRR0_SX231 +MSRR0_SX30 +MSRR0_SX411 +MSRR0_SX51 +MSTF0_SA1 +MSTF0_SA2 +MSTF0_SI1396 +MSTF0_SX136 +MSTF0_SX226 +MSTF0_SX406 +MSVS0_SA1 +MSVS0_SI1568 +MSVS0_SX128 +MSVS0_SX218 +MSVS0_SX38 +MTAB0_SA1 +MTAB0_SA2 +MTAB0_SI2202 +MTAB0_SI942 +MTAB0_SX132 +MTAB0_SX222 +MTAB0_SX402 +MTAB0_SX42 +MTAS0_SA1 +MTAS0_SA2 +MTAS0_SI1385 +MTAS0_SI2015 +MTAS0_SI755 +MTAS0_SX125 +MTAS0_SX305 +MTAT0_SA2 +MTAT0_SI1740 +MTAT0_SX120 +MTAT0_SX210 +MTAT0_SX30 +MTAT0_SX300 +MTAT1_SA1 +MTAT1_SA2 +MTAT1_SI1409 +MTAT1_SI1627 +MTAT1_SX239 +MTAT1_SX419 +MTBC0_SA1 +MTBC0_SA2 +MTBC0_SI1173 +MTBC0_SX183 +MTBC0_SX273 +MTBC0_SX347 +MTBC0_SX363 +MTBC0_SX93 +MTCS0_SA1 +MTCS0_SI1972 +MTCS0_SX172 +MTCS0_SX262 +MTCS0_SX352 +MTCS0_SX442 +MTDB0_SA1 +MTDB0_SA2 +MTDB0_SI2031 +MTDB0_SX141 +MTDB0_SX231 +MTDB0_SX321 +MTDB0_SX411 +MTDB0_SX51 +MTDP0_SI1274 +MTDP0_SI2151 +MTDP0_SX261 +MTDP0_SX441 
+MTDP0_SX81 +MTER0_SI527 +MTER0_SX167 +MTER0_SX17 +MTER0_SX257 +MTER0_SX77 +MTJG0_SA2 +MTJG0_SI1520 +MTJG0_SI890 +MTJG0_SX350 +MTJG0_SX440 +MTJG0_SX80 +MTJM0_SA1 +MTJM0_SA2 +MTJM0_SI1226 +MTJM0_SI655 +MTJM0_SX236 +MTJM0_SX326 +MTJM0_SX416 +MTJM0_SX56 +MTJS0_SA1 +MTJS0_SI1192 +MTJS0_SX112 +MTJS0_SX202 +MTJS0_SX22 +MTJS0_SX292 +MTJU0_SA1 +MTJU0_SA2 +MTJU0_SI2269 +MTJU0_SI760 +MTJU0_SX220 +MTJU0_SX310 +MTJU0_SX40 +MTKD0_SA1 +MTKD0_SA2 +MTKD0_SI1187 +MTKD0_SI1817 +MTKD0_SX17 +MTKD0_SX197 +MTKD0_SX377 +MTKP0_SA1 +MTKP0_SA2 +MTKP0_SX123 +MTKP0_SX213 +MTKP0_SX303 +MTKP0_SX33 +MTKP0_SX393 +MTLB0_SA2 +MTLB0_SI1764 +MTLB0_SI504 +MTLB0_SX144 +MTLB0_SX414 +MTLB0_SX54 +MTLC0_SA2 +MTLC0_SI847 +MTLC0_SX127 +MTLC0_SX217 +MTLC0_SX307 +MTLC0_SX37 +MTLC0_SX397 +MTML0_SA1 +MTML0_SA2 +MTML0_SI1065 +MTML0_SI1695 +MTML0_SX255 +MTML0_SX345 +MTML0_SX75 +MTMN0_SA1 +MTMN0_SX164 +MTMN0_SX254 +MTMN0_SX344 +MTMN0_SX74 +MTMT0_SA1 +MTMT0_SI1118 +MTMT0_SX128 +MTMT0_SX218 +MTMT0_SX308 +MTMT0_SX38 +MTMT0_SX398 +MTPF0_SA1 +MTPF0_SA2 +MTPF0_SI1235 +MTPF0_SI1865 +MTPF0_SI605 +MTPF0_SX155 +MTPF0_SX245 +MTPF0_SX335 +MTPF0_SX425 +MTPG0_SA1 +MTPG0_SA2 +MTPG0_SI2013 +MTPG0_SX123 +MTPG0_SX213 +MTPG0_SX33 +MTPG0_SX393 +MTPP0_SA1 +MTPP0_SA2 +MTPP0_SI2138 +MTPP0_SI878 +MTPP0_SX158 +MTPP0_SX248 +MTPP0_SX428 +MTPP0_SX68 +MTPR0_SA1 +MTPR0_SA2 +MTPR0_SI1600 +MTPR0_SI506 +MTPR0_SX250 +MTPR0_SX70 +MTQC0_SA2 +MTQC0_SI2071 +MTQC0_SX271 +MTQC0_SX361 +MTRC0_SA1 +MTRC0_SA2 +MTRC0_SI1623 +MTRC0_SI993 +MTRC0_SX170 +MTRC0_SX183 +MTRC0_SX273 +MTRC0_SX363 +MTRC0_SX93 +MTRR0_SA1 +MTRR0_SA2 +MTRR0_SI1548 +MTRR0_SI2178 +MTRR0_SX108 +MTRR0_SX18 +MTRR0_SX378 +MTRT0_SA1 +MTRT0_SI1857 +MTRT0_SI597 +MTRT0_SX147 +MTRT0_SX237 +MTRT0_SX417 +MTWH1_SA1 +MTWH1_SA2 +MTWH1_SI1512 +MTWH1_SI2142 +MTWH1_SI882 +MTWH1_SX162 +MTWH1_SX252 +MTWH1_SX342 +MTWH1_SX432 +MTXS0_SI1690 +MTXS0_SX250 +MTXS0_SX340 +MTXS0_SX70 +MVJH0_SA1 +MVJH0_SA2 +MVJH0_SI2186 +MVJH0_SX116 +MVJH0_SX26 +MVJH0_SX386 +MVLO0_SA2 +MVLO0_SI1147 +MVLO0_SI1777 +MVLO0_SX157 
+MVLO0_SX247 +MVLO0_SX337 +MVLO0_SX427 +MVLO0_SX67 +MVRW0_SA1 +MVRW0_SI1485 +MVRW0_SI2115 +MVRW0_SI855 +MVRW0_SX315 +MVRW0_SX405 +MVRW0_SX45 +MWAC0_SA1 +MWAC0_SI2231 +MWAC0_SI971 +MWAC0_SX71 +MWAD0_SA1 +MWAD0_SA2 +MWAD0_SI1062 +MWAD0_SI1749 +MWAD0_SI2322 +MWAD0_SX162 +MWAD0_SX252 +MWAD0_SX342 +MWAR0_SA2 +MWAR0_SI2305 +MWAR0_SX145 +MWAR0_SX235 +MWAR0_SX325 +MWAR0_SX415 +MWAR0_SX55 +MWCH0_SA1 +MWCH0_SA2 +MWCH0_SI1622 +MWCH0_SX272 +MWCH0_SX362 +MWCH0_SX92 +MWDK0_SX266 +MWDK0_SX356 +MWDK0_SX446 +MWEM0_SA1 +MWEM0_SI1950 +MWEM0_SX240 +MWEM0_SX330 +MWEM0_SX60 +MWGR0_SA1 +MWGR0_SA2 +MWGR0_SI1606 +MWGR0_SI2236 +MWGR0_SI976 +MWGR0_SX166 +MWGR0_SX256 +MWGR0_SX436 +MWGR0_SX76 +MWRE0_SA1 +MWRE0_SI1687 +MWRE0_SI2317 +MWRE0_SX157 +MWRP0_SA2 +MWRP0_SI1525 +MWRP0_SI2073 +MWRP0_SX183 +MWRP0_SX3 +MWRP0_SX93 +MWSB0_SA1 +MWSB0_SA2 +MWSB0_SI1626 +MWSB0_SI2256 +MWSB0_SX186 +MWSB0_SX366 +MWSB0_SX6 +MWSB0_SX96 +MWSH0_SA1 +MWSH0_SA2 +MWSH0_SI2266 +MWSH0_SX346 +MWSH0_SX436 +MZMB0_SA2 +MZMB0_SI1166 +MZMB0_SI1796 +MZMB0_SI536 +MZMB0_SX176 +MZMB0_SX266 +MZMB0_SX356 +MZMB0_SX446 +MZMB0_SX86 diff --git a/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid b/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid new file mode 100644 index 0000000000..0e0c2517c9 --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid @@ -0,0 +1,1000 @@ +FAEM0_SI762 +FAEM0_SX42 +FAJW0_SA1 +FAJW0_SX3 +FAJW0_SX93 +FALK0_SX186 +FALK0_SX6 +FALR0_SI1325 +FBAS0_SA1 +FBAS0_SX217 +FBCG1_SA1 +FBCG1_SX172 +FBCG1_SX442 +FBCH0_SX236 +FBCH0_SX416 +FBLV0_SA1 +FBLV0_SI1058 +FBLV0_SX338 +FBLV0_SX68 +FBMH0_SA1 +FBMJ0_SI815 +FCAG0_SA1 +FCAG0_SX153 +FCAG0_SX243 +FCAJ0_SI1479 +FCAJ0_SX309 +FCDR1_SX106 +FCDR1_SX196 +FCEG0_SA2 +FCJF0_SA1 +FCJF0_SX127 +FCJS0_SI1607 +FCJS0_SI2237 +FCJS0_SX257 +FCKE0_SA2 +FCKE0_SX121 +FCLT0_SI2068 +FCLT0_SX448 +FCLT0_SX88 +FCMG0_SA2 +FCMG0_SI1872 +FCMG0_SX72 +FCMM0_SA1 +FCMM0_SA2 +FCMM0_SX183 +FCRZ0_SI2053 +FCRZ0_SX433 +FCYL0_SA1 +FCYL0_SX37 
+FDAS1_SI2091 +FDAS1_SX201 +FDAS1_SX381 +FDAW0_SI1406 +FDFB0_SA1 +FDFB0_SA2 +FDFB0_SI2010 +FDFB0_SX58 +FDJH0_SX305 +FDML0_SA2 +FDML0_SX159 +FDML0_SX249 +FDML0_SX429 +FDMY0_SA2 +FDMY0_SX27 +FDNC0_SX198 +FDNC0_SX288 +FDTD0_SX211 +FDXW0_SA1 +FDXW0_SX251 +FDXW0_SX341 +FDXW0_SX71 +FEAC0_SX165 +FEAC0_SX75 +FEAR0_SI622 +FECD0_SX68 +FEEH0_SA1 +FEEH0_SI1742 +FEEH0_SI471 +FEEH0_SX122 +FEME0_SA1 +FEME0_SX155 +FEME0_SX65 +FETB0_SA1 +FETB0_SI1148 +FETB0_SX158 +FEXM0_SI1101 +FGCS0_SX136 +FGCS0_SX226 +FGCS0_SX316 +FGCS0_SX406 +FGDP0_SA1 +FGMB0_SI1775 +FGMB0_SX245 +FHLM0_SX390 +FHXS0_SA2 +FHXS0_SX445 +FJDM2_SA1 +FJDM2_SX232 +FJDM2_SX52 +FJHK0_SX302 +FJKL0_SX212 +FJKL0_SX392 +FJLG0_SI2306 +FJLR0_SA1 +FJRP1_SI2062 +FJRP1_SX82 +FJSK0_SA1 +FJSP0_SX264 +FJSP0_SX354 +FJSP0_SX444 +FJWB1_SA1 +FJWB1_SX345 +FJWB1_SX435 +FJXM0_SA1 +FJXM0_SI581 +FJXM0_SX401 +FJXP0_SA1 +FJXP0_SI1122 +FJXP0_SX132 +FKAA0_SX128 +FKAA0_SX398 +FKDE0_SA1 +FKDE0_SX151 +FKDE0_SX241 +FKDE0_SX421 +FKDE0_SX61 +FKDW0_SX397 +FKFB0_SA2 +FKFB0_SX348 +FKFB0_SX78 +FKKH0_SA1 +FKKH0_SA2 +FKKH0_SX120 +FKKH0_SX390 +FKLC0_SX355 +FKLC1_SI2308 +FKLC1_SX238 +FKLC1_SX328 +FKLC1_SX418 +FKLH0_SA2 +FKLH0_SX177 +FKSR0_SA1 +FKSR0_SA2 +FKSR0_SI1747 +FKSR0_SI487 +FKSR0_SX217 +FLAC0_SX451 +FLAG0_SA2 +FLAG0_SX114 +FLAG0_SX204 +FLAG0_SX24 +FLAG0_SX384 +FLEH0_SI1681 +FLEH0_SI2311 +FLEH0_SX331 +FLET0_SA1 +FLHD0_SI1827 +FLHD0_SX354 +FLJA0_SA1 +FLJA0_SI2338 +FLJD0_SI886 +FLJD0_SX76 +FLJG0_SA2 +FLKM0_SA2 +FLKM0_SI686 +FLKM0_SX260 +FLKM0_SX80 +FLMA0_SA1 +FLMA0_SI613 +FLMA0_SX433 +FLMA0_SX73 +FLMC0_SX22 +FLMK0_SI1035 +FLMK0_SX315 +FLMK0_SX405 +FLOD0_SI1917 +FLOD0_SX117 +FLOD0_SX171 +FLOD0_SX297 +FLTM0_SA1 +FLTM0_SI1070 +FLTM0_SI2330 +FMAH1_SA2 +FMAH1_SX159 +FMBG0_SA2 +FMBG0_SI2264 +FMEM0_SI747 +FMEM0_SX387 +FMJB0_SI547 +FMJB0_SX97 +FMJF0_SA2 +FMJU0_SX309 +FMJU0_SX399 +FMKC0_SI1702 +FMKC0_SX442 +FMKC0_SX82 +FMKF0_SX186 +FMPG0_SA2 +FNKL0_SI1522 +FNTB0_SI1203 +FNTB0_SI573 +FNTB0_SX303 +FPAB1_SI1471 +FPAB1_SX211 +FPAC0_SA2 +FPAD0_SA2 +FPAD0_SX356 
+FPAD0_SX86 +FPAF0_SA2 +FPAF0_SX154 +FPAZ0_SA1 +FPAZ0_SA2 +FPAZ0_SX243 +FPJF0_SA1 +FPJF0_SX146 +FPJF0_SX56 +FPLS0_SI1590 +FPLS0_SX330 +FPMY0_SA1 +FPMY0_SX343 +FREH0_SA1 +FREH0_SA2 +FREH0_SX415 +FRJB0_SX347 +FRLL0_SX434 +FSAG0_SA1 +FSAG0_SX243 +FSAH0_SA1 +FSAH0_SA2 +FSAH0_SX164 +FSAH0_SX434 +FSBK0_SA2 +FSBK0_SI1069 +FSBK0_SX169 +FSCN0_SA2 +FSCN0_SI626 +FSCN0_SX266 +FSCN0_SX446 +FSCN0_SX86 +FSDC0_SA2 +FSDC0_SX142 +FSDC0_SX322 +FSDC0_SX52 +FSDJ0_SI485 +FSDJ0_SX215 +FSDJ0_SX305 +FSDJ0_SX395 +FSGF0_SX117 +FSJG0_SX130 +FSJK1_SA2 +FSJK1_SX125 +FSJK1_SX35 +FSJS0_SX181 +FSJW0_SI1963 +FSJW0_SX433 +FSKC0_SI1416 +FSKC0_SI786 +FSKC0_SX246 +FSKL0_SI1529 +FSKL0_SX449 +FSKP0_SA2 +FSLS0_SX156 +FSLS0_SX426 +FSMA0_SA2 +FSMA0_SX181 +FSMM0_SX144 +FSMM0_SX234 +FSMS1_SX244 +FSMS1_SX347 +FSPM0_SA2 +FSPM0_SX161 +FSPM0_SX71 +FSRH0_SI1931 +FSRH0_SI671 +FSRH0_SX221 +FSRH0_SX401 +FTAJ0_SI699 +FTAJ0_SX159 +FTAJ0_SX249 +FTAJ0_SX429 +FTBR0_SX21 +FTBW0_SA1 +FTMG0_SI1532 +FTMG0_SI2162 +FTMG0_SX452 +FVFB0_SA2 +FVFB0_SX132 +FVFB0_SX42 +FVKB0_SA1 +FVMH0_SA2 +FVMH0_SX116 +FVMH0_SX26 +MABC0_SI1620 +MABC0_SI2041 +MABC0_SI781 +MADC0_SX107 +MADC0_SX377 +MADD0_SA2 +MADD0_SI1295 +MADD0_SX178 +MADD0_SX268 +MADD0_SX88 +MAEB0_SX450 +MAEO0_SA1 +MAFM0_SI939 +MAFM0_SX129 +MAFM0_SX309 +MAJP0_SA2 +MAKB0_SI1646 +MAKB0_SX26 +MAKB0_SX386 +MAKR0_SX362 +MAKR0_SX92 +MAPV0_SX213 +MARC0_SA2 +MARC0_SX108 +MARC0_SX18 +MARC0_SX198 +MARW0_SI1906 +MBAR0_SA1 +MBAR0_SX419 +MBAR0_SX59 +MBBR0_SI2315 +MBBR0_SX65 +MBCG0_SA1 +MBCG0_SI486 +MBEF0_SI1281 +MBEF0_SI1911 +MBEF0_SI651 +MBEF0_SX21 +MBEF0_SX381 +MBGT0_SA2 +MBGT0_SX261 +MBGT0_SX351 +MBGT0_SX441 +MBJV0_SA1 +MBJV0_SI617 +MBJV0_SX347 +MBMA0_SI592 +MBMA0_SX232 +MBMA0_SX52 +MBMA1_SI2214 +MBMA1_SX54 +MBML0_SA2 +MBML0_SI1169 +MBML0_SX89 +MBOM0_SA2 +MBOM0_SI2274 +MBOM0_SX294 +MBSB0_SA1 +MBSB0_SX3 +MBTH0_SA2 +MBTH0_SX122 +MBTH0_SX32 +MCAE0_SX277 +MCAL0_SA2 +MCAL0_SI1768 +MCDC0_SA1 +MCDC0_SX212 +MCDD0_SA2 +MCDD0_SI883 +MCDD0_SX253 +MCDD0_SX433 +MCDR0_SI1154 +MCEF0_SX235 +MCEF0_SX415 
+MCEW0_SA2 +MCHL0_SX87 +MCLK0_SX310 +MCLM0_SA1 +MCLM0_SI2086 +MCLM0_SI826 +MCPM0_SA1 +MCPM0_SX114 +MCPM0_SX294 +MCPM0_SX384 +MCSS0_SI750 +MCTH0_SA1 +MCTH0_SX39 +MCXM0_SX91 +MDAC0_SA1 +MDAC0_SX181 +MDAC0_SX361 +MDAS0_SX6 +MDBB1_SX106 +MDBB1_SX16 +MDBB1_SX376 +MDBP0_SX168 +MDCD0_SI1415 +MDCD0_SX245 +MDCD0_SX425 +MDCM0_SX40 +MDCM0_SX400 +MDDC0_SI2049 +MDDC0_SI789 +MDDC0_SX159 +MDDC0_SX69 +MDED0_SA1 +MDED0_SA2 +MDEF0_SX123 +MDEF0_SX303 +MDHL0_SI1439 +MDHL0_SX269 +MDHL0_SX449 +MDHS0_SA1 +MDHS0_SA2 +MDHS0_SI1530 +MDHS0_SI2160 +MDJM0_SX105 +MDJM0_SX15 +MDKS0_SX436 +MDLB0_SA2 +MDLC0_SX405 +MDLC1_SA2 +MDLC1_SI2065 +MDLC1_SI2144 +MDLC1_SX445 +MDLC2_SI2244 +MDLC2_SX354 +MDLH0_SA2 +MDLM0_SI1234 +MDLM0_SI1864 +MDLM0_SX154 +MDLM0_SX424 +MDLR0_SA1 +MDLR0_SA2 +MDLR0_SI1863 +MDLR0_SI603 +MDLR0_SX153 +MDLR1_SA1 +MDLR1_SA2 +MDMA0_SI1430 +MDMA0_SX260 +MDMA0_SX80 +MDMT0_SA1 +MDMT0_SA2 +MDMT0_SI1832 +MDMT0_SX122 +MDMT0_SX32 +MDNS0_SA2 +MDNS0_SI2271 +MDNS0_SX201 +MDNS0_SX21 +MDPB0_SX416 +MDPK0_SI1053 +MDPK0_SX333 +MDPK0_SX423 +MDPS0_SI719 +MDPS0_SX359 +MDRD0_SA1 +MDRD0_SX32 +MDSJ0_SI2092 +MDSS0_SA2 +MDSS0_SX441 +MDSS1_SA1 +MDSS1_SI1327 +MDSS1_SI697 +MDSS1_SX157 +MDSS1_SX67 +MDTB0_SI1200 +MDTB0_SI1830 +MDTB0_SX120 +MDWD0_SA2 +MDWD0_SX270 +MDWD0_SX90 +MDWH0_SX215 +MDWH0_SX305 +MDWM0_SA1 +MDWM0_SA2 +MDWM0_SX16 +MDWM0_SX286 +MEAL0_SA2 +MEAL0_SI2177 +MEAL0_SX107 +MEAL0_SX347 +MEDR0_SA1 +MEDR0_SA2 +MEDR0_SI1374 +MEFG0_SA1 +MEGJ0_SA2 +MEGJ0_SX257 +MEGJ0_SX3 +MEJL0_SA1 +MEJL0_SX152 +MEJL0_SX242 +MEJS0_SI610 +MEJS0_SX160 +MEJS0_SX340 +MESG0_SX432 +MESJ0_SX187 +MESJ0_SX97 +MEWM0_SI718 +MEWM0_SX178 +MEWM0_SX88 +MFER0_SI862 +MFER0_SX142 +MFRM0_SX345 +MFRM0_SX435 +MFWK0_SI1879 +MFWK0_SX169 +MFXS0_SX54 +MFXV0_SA2 +MFXV0_SX105 +MGAF0_SA1 +MGAF0_SX22 +MGAF0_SX382 +MGAG0_SA2 +MGAK0_SX226 +MGAK0_SX46 +MGAR0_SX132 +MGAW0_SI535 +MGAW0_SX175 +MGES0_SA1 +MGES0_SI2111 +MGES0_SI851 +MGJC0_SA2 +MGJC0_SX75 +MGRL0_SI2127 +MGRL0_SI867 +MGRL0_SX147 +MGRP0_SA2 +MGSH0_SA2 +MGSH0_SI1806 +MGSH0_SX127 +MGSH0_SX276 
+MGSH0_SX6 +MGSL0_SA1 +MGSL0_SI534 +MGSL0_SX264 +MGXP0_SX187 +MGXP0_SX7 +MHBS0_SX315 +MHBS0_SX45 +MHIT0_SA1 +MHJB0_SA1 +MHJB0_SI1017 +MHMG0_SX195 +MHMR0_SA1 +MHMR0_SI489 +MHRM0_SA1 +MHRM0_SI958 +MHRM0_SX148 +MHRM0_SX58 +MHXL0_SI1772 +MHXL0_SX242 +MILB0_SA2 +MJAC0_SX307 +MJAC0_SX71 +MJAE0_SX174 +MJAI0_SA1 +MJAI0_SA2 +MJBG0_SX62 +MJDA0_SI1031 +MJDA0_SX311 +MJDE0_SI463 +MJDG0_SA2 +MJDG0_SI1042 +MJDG0_SI1705 +MJDM0_SA1 +MJDM0_SI974 +MJEB0_SI656 +MJEB0_SX296 +MJEB1_SA2 +MJEB1_SX207 +MJEB1_SX387 +MJEE0_SA1 +MJEE0_SX247 +MJEE0_SX337 +MJFH0_SA2 +MJFH0_SI1107 +MJFR0_SX75 +MJHI0_SA1 +MJHI0_SX158 +MJJB0_SA1 +MJJB0_SX239 +MJJJ0_SX443 +MJJM0_SA2 +MJJM0_SI827 +MJJM0_SX107 +MJKR0_SA1 +MJKR0_SI571 +MJLB0_SX176 +MJLG1_SX292 +MJLS0_SX106 +MJMA0_SA1 +MJMA0_SA2 +MJMD0_SA2 +MJMD0_SX308 +MJMD0_SX38 +MJMM0_SX85 +MJPG0_SI1191 +MJPG0_SX111 +MJPG0_SX201 +MJPG0_SX21 +MJPM0_SA2 +MJPM0_SX378 +MJPM1_SI2280 +MJPM1_SX401 +MJRA0_SA1 +MJRA0_SA2 +MJRA0_SI1236 +MJRA0_SI1866 +MJRA0_SX426 +MJRG0_SI1366 +MJRG0_SI1996 +MJRG0_SX376 +MJRH0_SX225 +MJRH1_SA1 +MJRH1_SI514 +MJRH1_SX154 +MJRH1_SX244 +MJRH1_SX424 +MJRK0_SA1 +MJRK0_SA2 +MJRK0_SI1662 +MJRK0_SX160 +MJRK0_SX250 +MJRK0_SX430 +MJRP0_SA1 +MJRP0_SA2 +MJRP0_SX225 +MJSR0_SA1 +MJSR0_SI1424 +MJSR0_SX344 +MJWG0_SA1 +MJWG0_SX265 +MJWS0_SI513 +MJWS0_SX153 +MJWS0_SX63 +MJWT0_SA1 +MJWT0_SX121 +MJWT0_SX211 +MJWT0_SX301 +MJWT0_SX31 +MJWT0_SX391 +MJXA0_SX427 +MJXL0_SI542 +MKAG0_SA1 +MKAG0_SX259 +MKAJ0_SA2 +MKAJ0_SX154 +MKAM0_SA1 +MKAM0_SX146 +MKAM0_SX326 +MKAM0_SX56 +MKDB0_SA1 +MKDB0_SA2 +MKDB0_SX152 +MKDD0_SA2 +MKES0_SA1 +MKES0_SI1253 +MKES0_SI1883 +MKES0_SX173 +MKJO0_SI1517 +MKJO0_SI887 +MKJO0_SX437 +MKLN0_SI968 +MKLN0_SX248 +MKLR0_SA2 +MKLR0_SI1689 +MKLS0_SA1 +MKLS0_SX357 +MKLS0_SX87 +MKLS1_SA1 +MKLS1_SA2 +MKLS1_SX375 +MKLW0_SA1 +MKRG0_SX411 +MKXL0_SA2 +MKXL0_SX15 +MKXL0_SX375 +MLBC0_SA1 +MLBC0_SI1869 +MLBC0_SX249 +MLEL0_SA1 +MLEL0_SA2 +MLEL0_SI1246 +MLEL0_SX256 +MLEL0_SX436 +MLJC0_SX145 +MLJC0_SX415 +MLJH0_SX64 +MLNS0_SI2037 +MMAA0_SA1 +MMAA0_SA2 +MMAA0_SX35 
+MMAB1_SI1494 +MMAB1_SX234 +MMAG0_SA2 +MMAG0_SI1126 +MMAG0_SX316 +MMAM0_SI2227 +MMAM0_SX157 +MMAM0_SX427 +MMAR0_SX256 +MMBS0_SI1781 +MMCC0_SA2 +MMDB0_SX177 +MMDG0_SA1 +MMDG0_SA2 +MMDG0_SI520 +MMDG0_SX160 +MMDG0_SX250 +MMDM0_SI1941 +MMDM0_SI681 +MMDM0_SX141 +MMDM1_SA2 +MMDM1_SI2043 +MMDM1_SX423 +MMDM1_SX63 +MMDS0_SA1 +MMEA0_SA1 +MMEA0_SX128 +MMEA0_SX398 +MMEB0_SA2 +MMEB0_SX187 +MMEB0_SX367 +MMGC0_SA2 +MMGC0_SX135 +MMGC0_SX225 +MMGG0_SX269 +MMGK0_SX332 +MMGK0_SX62 +MMJB1_SA2 +MMRP0_SA2 +MMRP0_SX144 +MMSM0_SX116 +MMSM0_SX206 +MMVP0_SA1 +MMVP0_SA2 +MMWB0_SI989 +MMWB0_SX89 +MMWS0_SA2 +MMWS0_SX168 +MMWS0_SX348 +MMWS0_SX438 +MMWS1_SI1701 +MMXS0_SI2136 +MMXS0_SX246 +MMXS0_SX426 +MNET0_SI816 +MNET0_SX6 +MNTW0_SA2 +MNTW0_SX168 +MNTW0_SX78 +MPAR0_SI2206 +MPAR0_SI946 +MPAR0_SX136 +MPAR0_SX316 +MPEB0_SI1034 +MPEB0_SI1860 +MPEB0_SX240 +MPEB0_SX330 +MPFU0_SI628 +MPFU0_SX448 +MPGH0_SX114 +MPGH0_SX24 +MPGR0_SX240 +MPGR0_SX330 +MPGR1_SX149 +MPPC0_SA1 +MPRD0_SA1 +MPRD0_SX261 +MPRD0_SX351 +MPRD0_SX441 +MPRD0_SX81 +MPRK0_SI1727 +MPRK0_SX107 +MPRK0_SX377 +MPRT0_SA1 +MPRT0_SX310 +MPSW0_SI1067 +MPSW0_SX167 +MPSW0_SX437 +MRAB1_SX128 +MRAB1_SX308 +MRAI0_SA1 +MRAI0_SA2 +MRAI0_SX72 +MRAM0_SA1 +MRAM0_SA2 +MRAM0_SX15 +MRBC0_SI1859 +MRBC0_SX329 +MRBC0_SX419 +MRCG0_SI798 +MRCG0_SX168 +MRCW0_SA1 +MRCW0_SX291 +MRDD0_SI1680 +MRDD0_SX150 +MRDD0_SX277 +MRDD0_SX60 +MRDM0_SI1595 +MRDM0_SX65 +MRDS0_SA1 +MREE0_SX24 +MREH1_SX249 +MREH1_SX69 +MREM0_SA2 +MREW1_SI870 +MRFK0_SX446 +MRFL0_SA1 +MRFL0_SX256 +MRFL0_SX436 +MRFL0_SX76 +MRGM0_SA2 +MRGM0_SX262 +MRGS0_SA2 +MRGS0_SX186 +MRHL0_SI885 +MRHL0_SX345 +MRHL0_SX435 +MRJB1_SA1 +MRJB1_SA2 +MRJB1_SX210 +MRJB1_SX30 +MRJB1_SX390 +MRJH0_SA2 +MRJH0_SX307 +MRJH0_SX79 +MRJM0_SX148 +MRJM1_SA2 +MRJM1_SI1298 +MRJM1_SI1928 +MRJM1_SX128 +MRJT0_SA2 +MRJT0_SI1498 +MRJT0_SX328 +MRJT0_SX418 +MRKM0_SA2 +MRKM0_SX367 +MRLD0_SA2 +MRLD0_SI2224 +MRLD0_SX154 +MRLD0_SX424 +MRLJ0_SA1 +MRLJ0_SX250 +MRLJ0_SX340 +MRLJ1_SA1 +MRLJ1_SA2 +MRLJ1_SX321 +MRLK0_SI843 +MRLK0_SX123 +MRLK0_SX213 
+MRMB0_SA2 +MRMB0_SI1581 +MRMB0_SX411 +MRMG0_SA1 +MRMG0_SI1080 +MRMG0_SX450 +MRMH0_SI1349 +MRMH0_SI2281 +MRMH0_SX121 +MRML0_SA2 +MRML0_SX341 +MRPC1_SI2112 +MRRE0_SA2 +MRRE0_SX164 +MRRE0_SX344 +MRRE0_SX74 +MRSO0_SX129 +MRSO0_SX39 +MRSP0_SX259 +MRTC0_SX378 +MRVG0_SI1140 +MRVG0_SX240 +MRWA0_SI973 +MRWA0_SX163 +MRWA0_SX73 +MRWS0_SI1732 +MRWS0_SI472 +MRWS0_SX22 +MRWS0_SX382 +MRXB0_SA2 +MRXB0_SX415 +MSAH1_SI1679 +MSAS0_SX116 +MSAS0_SX206 +MSAS0_SX386 +MSAT0_SA1 +MSAT1_SX263 +MSAT1_SX443 +MSAT1_SX83 +MSDB0_SX197 +MSDB0_SX287 +MSDB0_SX377 +MSDH0_SI2240 +MSDH0_SX440 +MSDH0_SX80 +MSDS0_SA1 +MSEM1_SI1440 +MSEM1_SX180 +MSEM1_SX270 +MSES0_SI1589 +MSES0_SX239 +MSES0_SX419 +MSFH0_SX316 +MSFV0_SI1892 +MSFV0_SX362 +MSFV0_SX92 +MSMR0_SX415 +MSMS0_SA1 +MSMS0_SX173 +MSMS0_SX83 +MSRG0_SA1 +MSRG0_SI1221 +MSTF0_SI766 +MSTF0_SX316 +MSTF0_SX46 +MSVS0_SA2 +MSVS0_SX308 +MTAS0_SX215 +MTAS0_SX35 +MTAS0_SX395 +MTAT0_SX390 +MTAT1_SX59 +MTBC0_SI1803 +MTCS0_SA2 +MTCS0_SI2265 +MTCS0_SX82 +MTDP0_SA2 +MTER0_SA2 +MTER0_SI1787 +MTJG0_SA1 +MTJG0_SI2157 +MTJG0_SX260 +MTJM0_SI1856 +MTJM0_SX146 +MTJU0_SX130 +MTJU0_SX400 +MTKD0_SX107 +MTKD0_SX287 +MTKP0_SI1023 +MTLB0_SA1 +MTLB0_SX234 +MTLC0_SA1 +MTML0_SI2325 +MTML0_SX165 +MTMN0_SA2 +MTMN0_SI1064 +MTMN0_SI2324 +MTMN0_SX434 +MTMT0_SA2 +MTMT0_SI1748 +MTPF0_SX65 +MTPG0_SI1383 +MTPG0_SI753 +MTPG0_SX303 +MTPP0_SX338 +MTPR0_SX340 +MTQC0_SI480 +MTQC0_SX91 +MTRR0_SX198 +MTRR0_SX288 +MTRT0_SA2 +MTRT0_SX254 +MTRT0_SX57 +MTWH1_SX72 +MTXS0_SA1 +MTXS0_SA2 +MVJH0_SI926 +MVJH0_SX206 +MVJH0_SX296 +MVLO0_SA1 +MVRW0_SA2 +MVRW0_SX135 +MVRW0_SX225 +MWAC0_SA2 +MWAC0_SX341 +MWAC0_SX431 +MWAD0_SX432 +MWAD0_SX72 +MWAR0_SA1 +MWAR0_SI1675 +MWCH0_SI1895 +MWCH0_SI2252 +MWCH0_SX182 +MWCH0_SX452 +MWDK0_SA1 +MWDK0_SA2 +MWDK0_SI2017 +MWDK0_SI806 +MWDK0_SX176 +MWDK0_SX86 +MWEM0_SA2 +MWEM0_SI1320 +MWEM0_SI1393 +MWEM0_SX150 +MWGR0_SX346 +MWRE0_SX247 +MWRE0_SX337 +MWRE0_SX427 +MWRP0_SA1 +MWRP0_SX273 +MWRP0_SX363 +MWSB0_SX276 +MWSH0_SX256 +MWSH0_SX76 +MZMB0_SA1 diff --git 
a/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid b/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid new file mode 100644 index 0000000000..e99edfe937 --- /dev/null +++ b/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid @@ -0,0 +1,620 @@ +FAEM0_SI1392 +FAJW0_SI1263 +FAJW0_SI633 +FALK0_SI658 +FALR0_SX335 +FAPB0_SI1063 +FAPB0_SI2323 +FAPB0_SX433 +FBAS0_SI1472 +FBAS0_SI2066 +FBCG1_SX352 +FBCH0_SI959 +FBJL0_SI922 +FBLV0_SI1688 +FBMH0_SI1136 +FBMH0_SI970 +FBMJ0_SA1 +FBMJ0_SI1776 +FBMJ0_SI516 +FBMJ0_SX336 +FCDR1_SI1186 +FCDR1_SI1816 +FCDR1_SI556 +FCDR1_SX286 +FCKE0_SI1741 +FCKE0_SI481 +FCLT0_SI808 +FCMG0_SI1142 +FCMG0_SX432 +FCMM0_SI1957 +FCMM0_SX420 +FCYL0_SI667 +FCYL0_SX349 +FDAS1_SI1461 +FDAS1_SI831 +FDAW0_SI1271 +FDAW0_SI2036 +FDJH0_SI935 +FDKN0_SI1202 +FDKN0_SX181 +FDKN0_SX451 +FDMY0_SA1 +FDMY0_SI567 +FDMY0_SI714 +FDMY0_SX387 +FDNC0_SI1278 +FDNC0_SI1908 +FDTD0_SA1 +FDTD0_SX321 +FEAC0_SI615 +FEAR0_SX352 +FECD0_SA1 +FECD0_SI1418 +FECD0_SI788 +FEME0_SI875 +FEME0_SX335 +FEXM0_SA1 +FEXM0_SI482 +FEXM0_SX366 +FGDP0_SI988 +FGDP0_SX88 +FGMB0_SI1145 +FGMB0_SX335 +FGRW0_SA1 +FGRW0_SI1152 +FGRW0_SX162 +FGRW0_SX432 +FHLM0_SX120 +FHLM0_SX349 +FHXS0_SA1 +FHXS0_SI1075 +FHXS0_SI2302 +FHXS0_SX175 +FJDM2_SA2 +FJDM2_SX142 +FJEN0_SA1 +FJEN0_SX327 +FJEN0_SX417 +FJHK0_SI2282 +FJKL0_SI932 +FJLG0_SI1889 +FJLR0_SI1231 +FJRB0_SX402 +FJRP1_SA1 +FJRP1_SI1432 +FJRP1_SX262 +FJRP1_SX352 +FJSK0_SI1052 +FJSP0_SI1434 +FJWB1_SI748 +FJXM0_SX311 +FJXM0_SX41 +FJXP0_SI1752 +FKAA0_SA1 +FKDE0_SI1141 +FKDE0_SI1771 +FKDW0_SI1207 +FKDW0_SI1891 +FKFB0_SI1608 +FKFB0_SX438 +FKKH0_SI1290 +FKKH0_SI1920 +FKLC0_SI985 +FKLC0_SX175 +FKLC1_SI1048 +FKLH0_SI1257 +FKSR0_SX366 +FLAC0_SI1339 +FLAG0_SI1464 +FLAG0_SI834 +FLEH0_SI1051 +FLET0_SI507 +FLJA0_SI1078 +FLJA0_SX178 +FLJD0_SI1516 +FLJG0_SI981 +FLJG0_SX171 +FLJG0_SX351 +FLKM0_SA1 +FLKM0_SI620 +FLKM0_SX350 +FLKM0_SX440 +FLMC0_SI1372 +FLMK0_SA1 +FLMK0_SI1229 +FLTM0_SX170 +FLTM0_SX350 +FLTM0_SX440 +FMAH1_SI879 +FMBG0_SI1160 
+FMEM0_SA1 +FMEM0_SX333 +FMJB0_SI1177 +FMJF0_SI624 +FMJF0_SX174 +FMJF0_SX84 +FMJU0_SI1389 +FMKC0_SI1041 +FMKF0_SI1018 +FMPG0_SA1 +FMPG0_SI972 +FMPG0_SX162 +FMPG0_SX342 +FMPG0_SX432 +FNKL0_SI892 +FNTB0_SI679 +FPAB1_SA1 +FPAB1_SI2101 +FPAB1_SI841 +FPAC0_SI1921 +FPAC0_SI661 +FPAD0_SI716 +FPAD0_SX176 +FPAF0_SA1 +FPAF0_SI1054 +FPAZ0_SI2223 +FPAZ0_SI963 +FPJF0_SI1259 +FPJF0_SX352 +FPLS0_SI960 +FPMY0_SI1153 +FPMY0_SI523 +FREH0_SI1945 +FRLL0_SI805 +FSAG0_SI1323 +FSAG0_SX153 +FSAG0_SX333 +FSAG0_SX423 +FSAH0_SI614 +FSAH0_SX327 +FSAK0_SI1300 +FSBK0_SX349 +FSCN0_SA1 +FSCN0_SI705 +FSCN0_SX176 +FSDC0_SI1312 +FSDJ0_SI1115 +FSGF0_SI2187 +FSGF0_SI927 +FSJG0_SA1 +FSJG0_SA2 +FSJG0_SI940 +FSJG0_SX220 +FSJG0_SX40 +FSJG0_SX400 +FSJS0_SA1 +FSJS0_SX451 +FSJW0_SI1333 +FSKP0_SI1098 +FSMA0_SI991 +FSMA0_SX451 +FSMM0_SX324 +FSPM0_SI1241 +FSPM0_SX251 +FSRH0_SX311 +FSSB0_SI1712 +FSSB0_SX362 +FTBR0_SI1402 +FTBR0_SI921 +FTBW0_SI715 +FTBW0_SX175 +FTLG0_SI1743 +FTLG0_SI483 +FTMG0_SI902 +FVFB0_SI1510 +FVKB0_SX349 +FVMH0_SI1466 +FVMH0_SI836 +MADC0_SI1367 +MADC0_SI737 +MAEB0_SI1411 +MAEO0_SI1326 +MAJP0_SI1704 +MAJP0_SX174 +MAKB0_SA2 +MAKB0_SI1016 +MAKB0_SI2276 +MAKB0_SX116 +MAPV0_SI1293 +MAPV0_SI663 +MARW0_SX286 +MARW0_SX349 +MBBR0_SI1055 +MBBR0_SX335 +MBCG0_SI957 +MBCG0_SX327 +MBGT0_SI1841 +MBGT0_SX171 +MBMA0_SI1222 +MBMA1_SI954 +MBMA1_SX324 +MBTH0_SI2102 +MBWP0_SX349 +MCAE0_SI1447 +MCAE0_SI2077 +MCAE0_SI817 +MCAL0_SI1138 +MCDR0_SI1784 +MCDR0_SI524 +MCEF0_SI842 +MCEW0_SA1 +MCEW0_SI2072 +MCEW0_SI812 +MCEW0_SX362 +MCEW0_SX452 +MCHL0_SI1347 +MCHL0_SI1404 +MCLK0_SI2290 +MCLK0_SI650 +MCPM0_SI1824 +MCSS0_SI1380 +MCSS0_SI688 +MCTM0_SI1350 +MCTM0_SI1980 +MDAC0_SI631 +MDAS0_SI1896 +MDAS0_SI636 +MDBP0_SI528 +MDBP0_SX438 +MDCD0_SI785 +MDCD0_SX335 +MDCM0_SI1480 +MDDC0_SI1419 +MDED0_SI540 +MDEF0_SI1123 +MDEM0_SA1 +MDEM0_SI608 +MDEM0_SI800 +MDEM0_SX428 +MDHS0_SI900 +MDJM0_SI1455 +MDKS0_SX166 +MDKS0_SX346 +MDLB0_SI1306 +MDLB0_SX136 +MDLB0_SX406 +MDLC0_SI1395 +MDLC0_SI2025 +MDLC1_SI1435 +MDLH0_SX160 +MDLH0_SX430 
+MDLM0_SI604 +MDLR0_SX333 +MDLR1_SI669 +MDMA0_SX170 +MDMA0_SX350 +MDMA0_SX440 +MDNS0_SI1011 +MDNS0_SI873 +MDPB0_SI1760 +MDPB0_SI866 +MDRD0_SI752 +MDSJ0_SI1462 +MDSJ0_SX438 +MDWD0_SI1260 +MDWH0_SA1 +MDWH0_SI1168 +MDWH0_SI665 +MDWM0_SI916 +MEDR0_SI2004 +MEFG0_SI491 +MEFG0_SI598 +MEGJ0_SA1 +MEGJ0_SI1337 +MEGJ0_SI707 +MEGJ0_SX167 +MEJS0_SI1240 +MESG0_SI702 +MESJ0_SI2039 +MFWK0_SX349 +MFXS0_SX324 +MFXV0_SI1005 +MFXV0_SI1342 +MGAF0_SI1282 +MGAG0_SI691 +MGAK0_SI1036 +MGAK0_SX136 +MGAR0_SX312 +MGAW0_SI1165 +MGES0_SX311 +MGJC0_SX435 +MGRL0_SX327 +MGRP0_SI1317 +MGRP0_SX327 +MGSH0_SI1176 +MGSH0_SI546 +MGSL0_SI797 +MGXP0_SI1087 +MGXP0_SI525 +MHBS0_SI945 +MHIT0_SI983 +MHMG0_SI735 +MHMR0_SI1692 +MILB0_SI903 +MJAC0_SI701 +MJAC0_SX251 +MJAE0_SX84 +MJAI0_SI682 +MJAI0_SI710 +MJDC0_SI531 +MJDE0_SA1 +MJDE0_SI1120 +MJDE0_SI490 +MJDE0_SX220 +MJDM0_SI1340 +MJDM0_SX170 +MJDM0_SX350 +MJEB0_SX170 +MJEB1_SI1467 +MJEB1_SI837 +MJFR0_SA1 +MJFR0_SX435 +MJHI0_SI1328 +MJJJ0_SI1163 +MJJM0_SI1251 +MJLB0_SI1616 +MJLS0_SI1726 +MJMA0_SI2125 +MJMD0_SI2288 +MJMM0_SI1255 +MJMM0_SX175 +MJPG0_SI1821 +MJPM0_SI1368 +MJPM1_SX311 +MJRA0_SX336 +MJRG0_SI736 +MJRG0_SX352 +MJRH0_SI1840 +MJRH1_SI1558 +MJRK0_SI880 +MJRP0_SI1845 +MJSR0_SI2054 +MJSR0_SI794 +MJWG0_SI813 +MJWG0_SI895 +MJWG0_SX175 +MJWS0_SX333 +MJWT0_SI1291 +MJWT0_SI1381 +MJXL0_SI1172 +MKAG0_SI979 +MKAH0_SX178 +MKAM0_SI1250 +MKAM0_SI1465 +MKDD0_SI1567 +MKDD0_SI2197 +MKDD0_SI937 +MKDT0_SI814 +MKES0_SI623 +MKLS0_SI1437 +MKLS0_SI2067 +MKLS1_SI915 +MKLW0_SI1571 +MKLW0_SX311 +MKRG0_SI861 +MKXL0_SI1815 +MKXL0_SI1958 +MLBC0_SI1239 +MLEL0_SI616 +MLEL0_SX166 +MLJC0_SI1225 +MLJH0_SA1 +MLJH0_SA2 +MLJH0_SI1422 +MLJH0_SI694 +MLJH0_SX244 +MLSH0_SI1417 +MLSH0_SX247 +MMAA0_SI1588 +MMAA0_SI845 +MMAB1_SI864 +MMAB1_SX324 +MMAG0_SA1 +MMAG0_SI1756 +MMAG0_SX136 +MMAR0_SI1966 +MMAR0_SX166 +MMAR0_SX346 +MMBS0_SI521 +MMBS0_SX161 +MMCC0_SI1338 +MMDB0_SI987 +MMDG0_SI1780 +MMDM0_SI1311 +MMDM1_SX153 +MMDM1_SX333 +MMEB0_SX327 +MMGC0_SI1305 +MMGG0_SI1079 +MMGG0_SX449 +MMLM0_SI2150 
+MMPM0_SX161 +MMRP0_SX324 +MMSM0_SI1106 +MMSM0_SI476 +MMVP0_SI654 +MMVP0_SX347 +MMWB0_SA1 +MMWB0_SI2249 +MMWB0_SX359 +MMWB0_SX449 +MNTW0_SI1068 +MNTW0_SI1698 +MPEB0_SI600 +MPFU0_SI1258 +MPGH0_SI675 +MPGR0_SI1410 +MPGR1_SI1499 +MPMB0_SA1 +MPMB0_SA2 +MPMB0_SI1501 +MPMB0_SI2131 +MPMB0_SI871 +MPMB0_SX151 +MPMB0_SX331 +MPMB0_SX421 +MPMB0_SX61 +MPPC0_SI1412 +MPRB0_SI1215 +MPRB0_SI575 +MPRD0_SI801 +MPRD0_SX171 +MPRK0_SA1 +MPRK0_SI1097 +MPRK0_SI467 +MPRK0_SX287 +MRAB0_SI1854 +MRAB1_SI848 +MRAI0_SI2052 +MRAI0_SI792 +MRAI0_SX432 +MRAM0_SI1951 +MRCG0_SA2 +MRCG0_SI1428 +MRCG0_SX348 +MRCG0_SX438 +MRCW0_SI741 +MRDM0_SI1044 +MRDM0_SX335 +MREE0_SI1104 +MREE0_SI1959 +MREH1_SA1 +MREH1_SI1599 +MREH1_SI969 +MREM0_SI511 +MRFK0_SI1076 +MRFL0_SI1156 +MRFL0_SI526 +MRFL0_SX166 +MRGM0_SI532 +MRGM0_SX172 +MRGM0_SX442 +MRGS0_SI1356 +MRGS0_SI726 +MRGS0_SX6 +MRJB1_SI1413 +MRJB1_SI2021 +MRJB1_SX120 +MRJH0_SI1519 +MRJH0_SI889 +MRJH0_SX169 +MRJT0_SI868 +MRJT0_SX58 +MRKM0_SI1267 +MRKM0_SI1391 +MRKM0_SI637 +MRLJ0_SI790 +MRLJ1_SI2301 +MRLK0_SI1468 +MRLR0_SI1196 +MRML0_SA1 +MRML0_SI1421 +MRML0_SX161 +MRML0_SX251 +MRMS0_SI2057 +MRRE0_SA1 +MRRE0_SI1334 +MRRE0_SI952 +MRSO0_SI1206 +MRSP0_SI1429 +MRTC0_SI1458 +MRTJ0_SA1 +MRTJ0_SI772 +MRTJ0_SX142 +MRTJ0_SX232 +MRTJ0_SX52 +MRWS0_SI1102 +MRXB0_SI2215 +MRXB0_SI955 +MSAS0_SI1376 +MSAS0_SI746 +MSDH0_SI980 +MSDH0_SX170 +MSDS0_SI1077 +MSDS0_SX267 +MSDS0_SX357 +MSEM1_SI2070 +MSEM1_SI810 +MSFH0_SA1 +MSFH0_SI1738 +MSFH0_SX136 +MSFH0_SX406 +MSFV0_SI632 +MSJK0_SI1596 +MSJK0_SX336 +MSMC0_SI509 +MSMR0_SI1150 +MSMS0_SI1433 +MSRR0_SI1761 +MSRR0_SI501 +MSTF0_SI852 +MSVS0_SI2198 +MSVS0_SI938 +MSVS0_SX398 +MTAB0_SI1572 +MTAB0_SX312 +MTAT0_SA1 +MTAT0_SI1110 +MTAT0_SI811 +MTAT1_SI779 +MTAT1_SX149 +MTAT1_SX329 +MTBC0_SI543 +MTCS0_SI712 +MTDB0_SI1401 +MTDB0_SI771 +MTDP0_SA1 +MTDP0_SI1521 +MTDP0_SX171 +MTDP0_SX351 +MTER0_SA1 +MTER0_SI1157 +MTER0_SX437 +MTJG0_SX170 +MTJS0_SA2 +MTJS0_SI1822 +MTJS0_SI562 +MTJS0_SX382 +MTJU0_SI2020 +MTKD0_SI630 +MTKP0_SI2283 +MTKP0_SI454 +MTLB0_SI1134 
+MTLB0_SX324 +MTLC0_SI1313 +MTLC0_SI1477 +MTML0_SX435 +MTMN0_SI582 +MTMT0_SI488 +MTPP0_SI1508 +MTPR0_SI2230 +MTPR0_SX160 +MTPR0_SX430 +MTQC0_SA1 +MTQC0_SI1441 +MTQC0_SX181 +MTQC0_SX451 +MTRC0_SI589 +MTRR0_SI918 +MTRT0_SI1227 +MTXS0_SI1060 +MTXS0_SI2320 +MTXS0_SX160 +MTXS0_SX430 +MVJH0_SI1556 +MVLO0_SI517 +MWAC0_SI1601 +MWAC0_SX161 +MWAC0_SX251 +MWAR0_SI1045 +MWDK0_SI1436 +MWEM0_SX420 +MWRE0_SA2 +MWRE0_SI1057 +MWRE0_SX67 +MWRP0_SI1443 +MWSB0_SI996 +MWSH0_SI1426 +MWSH0_SI796 +MWSH0_SX166 diff --git a/examples/wav2vec/unsupervised/scripts/prepare_timit.sh b/examples/wav2vec/unsupervised/scripts/prepare_timit.sh new file mode 100644 index 0000000000..d8f5d596b4 --- /dev/null +++ b/examples/wav2vec/unsupervised/scripts/prepare_timit.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +timit_root=$1 # assume it is the upper-cased version +tgt_dir=$2 +model=$3 + +set -eu + +setups="matched unmatched" +splits="test valid train train_text" + +tgt_dir=$(realpath $tgt_dir) +sph2wav=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +wav_dir=$tgt_dir/wav + + +mkdir -p $tgt_dir $wav_dir +find $timit_root/{TRAIN,TEST} -iname "*.WAV" > $tgt_dir/all_sph.flist +cat $tgt_dir/all_sph.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).WAV#\1_\2#g' > $tgt_dir/all.uid +paste -d' ' $tgt_dir/{all_sph.flist,all.uid} | \ + awk -v sph2wav=$sph2wav -v wav_dir=$wav_dir '{print sph2wav " -f wav " $1 " > " wav_dir "/" $2 ".wav"}' \ + > $tgt_dir/sph2wav.sh +bash $tgt_dir/sph2wav.sh +cat $tgt_dir/all.uid | awk -v wav_dir=$(pwd)/$wav_dir '{print $1" "wav_dir"/"$1".wav"}' | sort > $tgt_dir/all_wav.scp +cut -d' ' -f2 $tgt_dir/all_wav.scp | xargs -I{} soxi -s {} > $tgt_dir/all.dur +paste -d' ' $tgt_dir/{all_wav.scp,all.dur} > $tgt_dir/all_wav_dur.scp +rm $tgt_dir/{all.uid,all_sph.flist,sph2wav.sh} + +find $timit_root/{TRAIN,TEST} -iname "*.PHN" 
> $tgt_dir/all_phn60.flist +while read line; do + if [ ! -f $line ]; then + >&2 echo "Cannot find transcription file '$line'" && exit 1; + fi + cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' +done < $tgt_dir/all_phn60.flist > $tgt_dir/all.phn60 +cat $tgt_dir/all_phn60.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).PHN#\1_\2#g' | \ + paste -d' ' - $tgt_dir/all.phn60 | \ + $KALDI_ROOT/egs/timit/s5/local/timit_norm_trans.pl -i - -m $KALDI_ROOT/egs/timit/s5/conf/phones.60-48-39.map -to 39 | \ + sort > $tgt_dir/all.phn +echo "done preparing wav and 39-phone transcripts" + + +for s in $setups; do + mkdir -p $tgt_dir/$s + for x in $splits; do + uid_path=config/timit_${s}/${x}.uid + grep -w -f $uid_path $tgt_dir/all.phn | cut -d' ' -f2- > $tgt_dir/$s/$x.phn + ln -sf $(realpath $tgt_dir/$s/$x.phn) $tgt_dir/$s/$x.wrd + + echo "/" > $tgt_dir/$s/$x.tsv && grep -w -f $uid_path $tgt_dir/all_wav_dur.scp | cut -d' ' -f2- | sed 's# #\t#' >> $tgt_dir/$s/$x.tsv + done + + for x in $splits; do + cat $tgt_dir/$s/$x.phn + done | tr ' ' '\n' | sort -u | awk '{print $1" "1}' > $tgt_dir/$s/dict.phn.txt + ln -sf $(realpath $tgt_dir/$s/dict.phn.txt) $tgt_dir/$s/dict.wrd.txt +done +echo "done preparing unmatched and matched setups for TIMIT" + + +for s in $setups; do + zsh scripts/prepare_audio.sh $tgt_dir/$s $tgt_dir/$s/feat $model + + lm_dir=$tgt_dir/$s/phones + fst_dir=$tgt_dir/$s/fst/phn_to_phn + + python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $tgt_dir/$s/train_text.phn --workers 10 --only-source --destdir $lm_dir --srcdict $tgt_dir/$s/dict.phn.txt + $KENLM_ROOT/lmplz -o 3 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.03.arpa + $KENLM_ROOT/build_binary $lm_dir/train_text_phn.03.arpa $lm_dir/train_text_phn.03.bin + $KENLM_ROOT/lmplz -o 4 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.04.arpa + $KENLM_ROOT/build_binary $lm_dir/train_text_phn.04.arpa $lm_dir/train_text_phn.04.bin + + 
python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$fst_dir lm_arpa=$lm_dir/train_text_phn.03.arpa data_dir=$tgt_dir/$s in_labels=phn +done +echo "done preprocessing audio and text for wav2vec-U" From 3c4a8e41559fa50b6c907fbefa1dab55d57bda5c Mon Sep 17 00:00:00 2001 From: Nithin-Holla <nithin.holla7@gmail.com> Date: Mon, 21 Jun 2021 20:16:08 -0700 Subject: [PATCH 622/707] Enabling word-level timestamps for Wav2Vec 2.0 (#3627) Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3371. Currently, the output from Wav2Vec 2.0 decoding does not contain word-level start/end times, which can be useful for certain applications of ASR. Based on the discussion [here](https://github.com/flashlight/flashlight/issues/618), they could be computed based on the output from the Flashlight decoder. For the KenLM decoder, we could first obtain the frame number corresponding to each non-blank token. Next, the timestamp of each character could be computed as `segment_start + frame_no/total_frames * segment_duration`. Finally, the start and end time of each word could be calculated based on the timestamp of the word boundary characters. In order to enable this, the frame number of each non-blank character is returned as a result of KenLM decoding. This is similar to the `timesteps` output from the [ctcdecode](https://github.com/parlance/ctcdecode#outputs-from-the-decode-method) library. 
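The timestamp calculation described above could be sketched roughly as follows. This is only an illustrative sketch and not part of the patch: the helper names `frame_to_time` and `word_times`, the `"|"` word-boundary symbol, and the choice to end a word at the following boundary token are all assumptions made for the example.

```python
from typing import List, Tuple


def frame_to_time(
    frame_no: int,
    total_frames: int,
    segment_start: float,
    segment_duration: float,
) -> float:
    """segment_start + frame_no / total_frames * segment_duration."""
    return segment_start + frame_no / total_frames * segment_duration


def word_times(
    chars: List[str],
    timesteps: List[int],
    total_frames: int,
    segment_start: float,
    segment_duration: float,
    boundary: str = "|",  # assumed word-boundary symbol (hypothetical default)
) -> List[Tuple[str, float, float]]:
    """Group non-blank characters into words with (start, end) times.

    `chars` and `timesteps` are parallel lists: one frame number per
    non-blank decoded character, as returned under "timesteps".
    A word starts at its first character and ends at the boundary token
    that follows it (or at its last character if no boundary follows).
    """
    words: List[Tuple[str, float, float]] = []
    current: List[str] = []
    start_frame = 0
    last_frame = 0

    def to_time(frame: int) -> float:
        return frame_to_time(frame, total_frames, segment_start, segment_duration)

    for ch, frame in zip(chars, timesteps):
        if ch == boundary:
            if current:
                # Close the open word at the boundary's frame.
                words.append(("".join(current), to_time(start_frame), to_time(frame)))
                current = []
            continue
        if not current:
            start_frame = frame
        current.append(ch)
        last_frame = frame

    if current:
        # No trailing boundary: fall back to the last character's frame.
        words.append(("".join(current), to_time(start_frame), to_time(last_frame)))
    return words
```

Given the decoded characters and the `timesteps` list for one hypothesis, `word_times(chars, timesteps, total_frames, segment_start, segment_duration)` returns one `(word, start, end)` tuple per word, in seconds if the segment start and duration are in seconds.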
## PR review alexeib Pull Request resolved: https://github.com/pytorch/fairseq/pull/3627 Reviewed By: michaelauli Differential Revision: D29282488 Pulled By: alexeib fbshipit-source-id: b5fe64bf50abd7ef8e9539f4e338937c866eb0ca --- .../new/decoders/flashlight_decoder.py | 22 +++++++++++++++++++ examples/speech_recognition/w2l_decoder.py | 22 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/examples/speech_recognition/new/decoders/flashlight_decoder.py b/examples/speech_recognition/new/decoders/flashlight_decoder.py index 8a548bdf66..38c7ac492f 100644 --- a/examples/speech_recognition/new/decoders/flashlight_decoder.py +++ b/examples/speech_recognition/new/decoders/flashlight_decoder.py @@ -118,6 +118,27 @@ def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None: self.decoder_opts, self.lm, self.silence, self.blank, [] ) + def get_timesteps(self, token_idxs: List[int]) -> List[int]: + """Returns frame numbers corresponding to every non-blank token. + + Parameters + ---------- + token_idxs : List[int] + IDs of decoded tokens. + + Returns + ------- + List[int] + Frame numbers corresponding to every non-blank token. 
+ """ + timesteps = [] + for i, token_idx in enumerate(token_idxs): + if token_idx == self.blank: + continue + if i == 0 or token_idx != token_idxs[i-1]: + timesteps.append(i) + return timesteps + def decode( self, emissions: torch.FloatTensor, @@ -134,6 +155,7 @@ def decode( { "tokens": self.get_tokens(result.tokens), "score": result.score, + "timesteps": self.get_timesteps(result.tokens), "words": [ self.word_dict.get_entry(x) for x in result.words if x >= 0 ], diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index aef4481593..fbf2d3524e 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -12,6 +12,7 @@ import gc import itertools as it import os.path as osp +from typing import List import warnings from collections import deque, namedtuple @@ -194,6 +195,26 @@ def __init__(self, args, tgt_dict): self.decoder_opts, self.lm, self.silence, self.blank, [] ) + def get_timesteps(self, token_idxs: List[int]) -> List[int]: + """Returns frame numbers corresponding to every non-blank token. + + Parameters + ---------- + token_idxs : List[int] + IDs of decoded tokens. + + Returns + ------- + List[int] + Frame numbers corresponding to every non-blank token. 
+ """ + timesteps = [] + for i, token_idx in enumerate(token_idxs): + if token_idx == self.blank: + continue + if i == 0 or token_idx != token_idxs[i-1]: + timesteps.append(i) + return timesteps def decode(self, emissions): B, T, N = emissions.size() @@ -208,6 +229,7 @@ def decode(self, emissions): { "tokens": self.get_tokens(result.tokens), "score": result.score, + "timesteps": self.get_timesteps(result.tokens), "words": [ self.word_dict.get_entry(x) for x in result.words if x >= 0 ], From 7ca8bc12c09d91187d95117094f6b31b3342cd17 Mon Sep 17 00:00:00 2001 From: Eduardo Romero <eduardoromero@fb.com> Date: Tue, 22 Jun 2021 09:11:13 -0700 Subject: [PATCH 623/707] KMeans Attention Summary: KMeans attention main file Reviewed By: yiq-liu Differential Revision: D28478149 fbshipit-source-id: 97ef1408cfa239bdf13ee5d54d5d31b61a7f2236 --- fairseq/modules/kmeans_attention.py | 609 ++++++++++++++++++++++++++++ 1 file changed, 609 insertions(+) create mode 100644 fairseq/modules/kmeans_attention.py diff --git a/fairseq/modules/kmeans_attention.py b/fairseq/modules/kmeans_attention.py new file mode 100644 index 0000000000..11a7debcf2 --- /dev/null +++ b/fairseq/modules/kmeans_attention.py @@ -0,0 +1,609 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from inspect import isfunction +from operator import mul +from functools import reduce, wraps + +from aml.multimodal_video.utils.einops.lib import rearrange, repeat +from aml.multimodal_video.utils.einops.lib.layers.torch import Rearrange + +from fairseq.modules.local_attention import LocalAttention + +# constants + +TOKEN_SELF_ATTN_VALUE = -5e4 +KMEAN_INIT_ITERS = 10 + +# helper functions + + +def exists(val): + return val is not None + + +def identity(x, *args, **kwargs): + return x + + +def default(x, d): + if not exists(x): + return d if not isfunction(d) else d() + return x + + +def cast_tuple(x): + return x if isinstance(x, tuple) else (x,) + + +def cache_fn(f): + cache = None + + 
@wraps(f) + def cached_fn(*args, **kwargs): + nonlocal cache + if exists(cache): + return cache + cache = f(*args, **kwargs) + return cache + return cached_fn + + +def to(t): + return {'device': t.device, 'dtype': t.dtype} + + +def find_modules(nn_module, type): + return [module for module in nn_module.modules() if isinstance(module, type)] + + +def is_empty(t): + return t.nelement() == 0 + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def batched_index_select(values, indices): + last_dim = values.shape[-1] + return values.gather(2, expand_dim(indices, -1, last_dim)) + + +def merge_dims(ind_from, ind_to, tensor): + shape = list(tensor.shape) + arr_slice = slice(ind_from, ind_to + 1) + shape[arr_slice] = [reduce(mul, shape[arr_slice])] + return tensor.reshape(*shape) + + +def expand_dim(t, dim, k): + t = t.unsqueeze(dim) + expand_shape = [-1] * len(t.shape) + expand_shape[dim] = k + return t.expand(*expand_shape) + + +def scatter_mean(src, t, index, dim, eps=1e-5): + numer = src.scatter_add(dim, index, t) + denom = src.scatter_add(dim, index, torch.ones_like(t)) + return numer / (denom + eps) + + +def split_at_index(dim, index, t): + pre_slices = (slice(None),) * dim + l = (*pre_slices, slice(None, index)) + r = (*pre_slices, slice(index, None)) + return t[l], t[r] + + +def reshape_dim(t, dim, split_dims): + shape = list(t.shape) + num_dims = len(shape) + dim = (dim + num_dims) % num_dims + shape[dim:dim+1] = split_dims + return t.reshape(shape) + + +def ema(old, new, decay): + if not exists(old): + return new + return old * decay + new * (1 - decay) + + +def ema_inplace(moving_avg, new, decay): + if is_empty(moving_avg): + moving_avg.data.copy_(new) + return + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + +# helper classes + + +def map_first_tuple_or_el(x, fn): + if isinstance(x, tuple): + return (fn(x[0]),) + x[1:] + return fn(x) + + +class Chunk(nn.Module): + def __init__(self, chunks, fn, along_dim=-1): + 
super().__init__() + self.dim = along_dim + self.chunks = chunks + self.fn = fn + + def forward(self, x, **kwargs): + if self.chunks <= 1: + return self.fn(x, **kwargs) + chunks = x.chunk(self.chunks, dim=self.dim) + return torch.cat([self.fn(c, **kwargs) for c in chunks], dim=self.dim) + + +class PreNorm(nn.ModuleList): + def __init__(self, norm_class, dim, fn): + super().__init__() + self.norm = norm_class(dim) + self.fn = fn + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + + +class ReZero(nn.Module): + def __init__(self, fn): + super().__init__() + self.residual_weight = nn.Parameter(torch.zeros(1)) + self.fn = fn + + def forward(self, x, **kwargs): + x = self.fn(x, **kwargs) + return map_first_tuple_or_el(x, lambda t: t * self.residual_weight) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.g = nn.Parameter(torch.ones(1)) + self.eps = eps + + def forward(self, x): + def norm(t): + n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps) + return t / n * self.g + return map_first_tuple_or_el(x, norm) + + +class ProjectInOut(nn.Module): + def __init__(self, fn, dim_in, dim_out, project_out=True): + super().__init__() + self.fn = fn + self.project_in = nn.Linear(dim_in, dim_out) + self.project_out = nn.Linear(dim_out, dim_in) if project_out else identity + + def forward(self, x, **kwargs): + x = self.project_in(x) + x, loss = self.fn(x, **kwargs) + x = self.project_out(x) + return x, loss + + +class MatrixMultiply(nn.Module): + def __init__(self, tensor, transpose=False): + super().__init__() + self.tensor = tensor + self.transpose = transpose + + def forward(self, x): + tensor = self.tensor + if self.transpose: + tensor = tensor.t() + return x @ tensor + +# positional embeddings + + +class DepthWiseConv1d(nn.Module): + def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False): + super().__init__() + self.padding = ((kernel_size - 1), 0) if causal 
else (kernel_size // 2, kernel_size // 2) + + self.net = nn.Sequential( + nn.Conv1d(dim_in, dim_in, kernel_size=kernel_size, groups=dim_in, stride=stride, bias=bias), + nn.Conv1d(dim_in, dim_out, 1, bias=bias) + ) + + def forward(self, x): + x = F.pad(x, self.padding, value=0.) + return self.net(x) + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + position = torch.arange(0, max_seq_len, dtype=torch.float) + sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + self.register_buffer('emb', emb) + + def forward(self, x): + return self.emb[None, :x.shape[1], :].to(x) + + +def rotate_every_two(x): + x = rearrange(x, '... (d j) -> ... d j', j=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, '... d j -> ... (d j)') + + +def apply_rotary_pos_emb(q, k, sinu_pos): + sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j=2) + sin, cos = sinu_pos.unbind(dim=-2) + sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2), (sin, cos)) + q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) + return q, k + +# kmeans related function and class + + +def update_kmeans_on_backwards(module): + module.kmean_modules = find_modules(module, Kmeans) + + def hook(_, grad_in, grad_out): + for m in module.kmean_modules: + m.update() + + return module.register_backward_hook(hook) + + +def similarity(x, means): + return torch.einsum('bhld,hcd->bhlc', x, means) + + +def dists_and_buckets(x, means): + dists = similarity(x, means) + _, buckets = torch.max(dists, dim=-1) + return dists, buckets + + +def batched_bincount(index, num_classes, dim=-1): + shape = list(index.shape) + shape[dim] = num_classes + out = index.new_zeros(shape) + out.scatter_add_(dim, index, torch.ones_like(index, dtype=index.dtype)) + return out + + +def 
kmeans_iter(x, means, buckets=None): + b, h, _, d, dtype, num_clusters = *x.shape, x.dtype, means.shape[1] + + if not exists(buckets): + _, buckets = dists_and_buckets(x, means) + + bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True) + zero_mask = bins.long() == 0 + + means_ = buckets.new_zeros(b, h, num_clusters, d, dtype=dtype) + means_.scatter_add_(-2, expand_dim(buckets, -1, d), x) + means_ = F.normalize(means_.sum(0, keepdim=True), dim=-1).type(dtype) + + means = torch.where(zero_mask.unsqueeze(-1), means, means_) + means = means.squeeze(0) + return means + + +def distribution(dists, window_size): + _, topk_indices = dists.topk(k=window_size, dim=-2) + indices = topk_indices.transpose(-2, -1) + return indices.reshape(*indices.size()[:2], -1) + + +class Kmeans(nn.Module): + def __init__(self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4): + super().__init__() + self.commitment = commitment + self.ema_decay = ema_decay + + self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim)) + self.register_buffer('initted', torch.tensor(False)) + self.num_new_means = 0 + self.new_means = None + + @torch.no_grad() + def init(self, x): + if self.initted: + return + _, h, _, d, device, _ = *x.shape, x.device, x.dtype + + num_clusters = self.means.shape[1] + + means = x.transpose(0, 1).contiguous().view(h, -1, d) + num_samples = means.shape[1] + + if num_samples >= num_clusters: + indices = torch.randperm(num_samples, device=device)[:num_clusters] + else: + indices = torch.randint(0, num_samples, (num_clusters,), device=device) + + means = means[:, indices] + + for _ in range(KMEAN_INIT_ITERS): + means = kmeans_iter(x, means) + + self.num_new_means = 0 + self.means.data.copy_(means) + self.initted.data.copy_(torch.tensor(True)) + + @torch.no_grad() + def update(self, new_means=None): + new_means = default(new_means, self.new_means) + assert exists(new_means), 'new kmeans has not been supplied' + ema_inplace(self.means, 
new_means, self.ema_decay) + + del self.new_means + self.new_means = None + self.num_new_means = 0 + + def forward(self, x, update_means=False): + self.init(x) + + b, dtype = x.shape[0], x.dtype + means = self.means.type(dtype) + x = F.normalize(x, 2, dim=-1).type(dtype) + + with torch.no_grad(): + dists, buckets = dists_and_buckets(x, means) + + routed_means = batched_index_select(expand_dim(means, 0, b), buckets) + loss = F.mse_loss(x, routed_means) * self.commitment + + if update_means: + with torch.no_grad(): + means = kmeans_iter(x, means, buckets) + self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1)) + self.num_new_means += 1 + + return dists, loss + +# kmeans attention class + + +class KmeansAttention(nn.Module): + def __init__(self, num_clusters, window_size, num_heads, head_dim, causal=False, dropout=0., ema_decay=0.999, commitment=1e-4, context_window_size=None, receives_context=False, num_mem_kv=0, shared_qk=False): + super().__init__() + self.num_heads = num_heads + self.num_clusters = num_clusters + self.head_dim = head_dim + + self.window_size = window_size + self.context_window_size = default(context_window_size, window_size) + self.causal = causal + + self.shared_qk = shared_qk + self.receives_context = receives_context + self.kmeans = Kmeans(num_heads, head_dim, num_clusters, ema_decay, commitment) + self.dropout = nn.Dropout(dropout) + + self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0) + self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)) + self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)) + + def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs): + b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype + is_reverse = kwargs.pop('_reverse', False) + + out = torch.zeros_like(q, dtype=dtype) + 
+ update_kmeans = self.training and not is_reverse + + key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask + kv_wsz = wsz if not self.receives_context else c_wsz + + wsz = min(wsz, t) + kv_wsz = min(kv_wsz, kv_t) + + if not self.shared_qk or self.receives_context: + dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans) + q_dists, k_dists = split_at_index(2, t, dists) + indices = distribution(q_dists, wsz) + kv_indices = distribution(k_dists, kv_wsz) + else: + dists, aux_loss = self.kmeans(q, update_kmeans) + k = F.normalize(k, dim=-1).to(q) + indices = distribution(dists, wsz) + kv_indices = indices + + q = batched_index_select(q, indices) + k = batched_index_select(k, kv_indices) + v = batched_index_select(v, kv_indices) + + reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d) + q, k, v = map(reshape_with_window, (q, k, v)) + + m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)) + k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v))) + + dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5) + + mask_value = max_neg_value(dots) + + if exists(query_mask) or exists(key_mask): + query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool()) + key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool()) + + q_mask = expand_dim(query_mask, 1, h).gather(2, indices) + kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices) + q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask)) + mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :] + mask = F.pad(mask, (self.num_mem_kv, 0), value=1) + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) + mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :] + mask = F.pad(mask, (self.num_mem_kv, 0), value=1) + dots.masked_fill_(~mask, mask_value) + del mask + + 
if self.shared_qk: + q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) + mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :] + mask = F.pad(mask, (self.num_mem_kv, 0), value=0) + dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE) + del mask + + dots = dots.softmax(dim=-1) + dots = self.dropout(dots) + + bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v) + so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype) + out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2) + return out, aux_loss + +# feedforward + + +class GELU_(nn.Module): + def forward(self, x): + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_ + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False): + super().__init__() + activation = default(activation, GELU) + + self.glu = glu + self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) + self.act = activation() + self.dropout = nn.Dropout(dropout) + self.w2 = nn.Linear(dim * mult, dim) + + def forward(self, x, **kwargs): + if not self.glu: + x = self.w1(x) + x = self.act(x) + else: + x, v = self.w1(x).chunk(2, dim=-1) + x = self.act(x) * v + + x = self.dropout(x) + x = self.w2(x) + return x + +# self attention + + +class SelfAttention(nn.Module): + def __init__(self, dim, max_seq_len, heads, local_attn_heads, window_size, dim_head=None, local_attn_window_size=None, local_attn_radius_blocks=1, causal=False, attn_dropout=0., dropout=0., kmeans_ema_decay=0.999, commitment_factor=1e-4, receives_context=False, context_window_size=None, rel_pos_emb=True, num_mem_kv=0, shared_qk=False, conv_query_kernel=9): + super().__init__() + assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads' + assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size' + assert 
local_attn_heads <= heads, 'number of local attention heads must be less than total heads' + assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context' + assert not (receives_context and causal), 'contextual attention layer cannot be causal' + + local_attn_window_size = default(local_attn_window_size, window_size) + context_window_size = default(context_window_size, window_size) + + self.shared_qk = shared_qk + self.receives_context = receives_context + self.heads = heads + self.local_attn_heads = local_attn_heads + self.global_attn_heads = heads - local_attn_heads + + self.causal = causal + self.window_size = window_size + + dim_head = default(dim_head, dim // heads) + dim_heads = dim_head * heads + self.dim_head = dim_head + + num_clusters = max_seq_len // window_size + + # local + + local_dim_heads = dim_head * self.local_attn_heads + + if self.local_attn_heads > 0: + rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None + self.local_attn = LocalAttention(dim=dim_head, window_size=local_attn_window_size, causal=causal, dropout=attn_dropout, rel_pos_emb_config=rel_pos_emb_config, look_backward=local_attn_radius_blocks, look_forward=0 if causal else local_attn_radius_blocks) + self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads) + + # global + + global_dim_heads = dim_head * self.global_attn_heads + + if self.global_attn_heads > 0: + self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal=causal, dropout=attn_dropout, ema_decay=kmeans_ema_decay, commitment=commitment_factor, receives_context=receives_context, num_mem_kv=num_mem_kv, shared_qk=shared_qk) + + self.to_q = nn.Sequential( + Rearrange('b n c -> b c n'), + DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal), + Rearrange('b c n -> b n c') + ) + + self.to_v = nn.Linear(dim, global_dim_heads, bias=False) + + if not self.shared_qk: + self.to_k = 
nn.Linear(dim, global_dim_heads, bias=False) + + # out + + self.to_out = nn.Linear(dim_heads, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, query, key, value, context=None, key_padding_mask=None, context_mask=None, pos_emb=None, **kwargs): + assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context' + input_mask = key_padding_mask + x = query.transpose(0, 1) + b, t, _, h, dh = *x.shape, self.heads, self.dim_head + has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)) + + split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous() + + if has_local: + local_qkv = self.local_to_qkv(x).chunk(3, dim=-1) + lq, lk, lv = map(split_heads, local_qkv) + + if has_global: + kv_input = x if not self.receives_context else context + + q, v = self.to_q(x), self.to_v(kv_input) + + if not self.shared_qk: + k = self.to_k(kv_input) + else: + k = self.to_q(kv_input) if self.receives_context else q + + q, k, v = map(split_heads, (q, k, v)) + + out = [] + total_loss = torch.tensor(0., requires_grad=True, **to(x)) + + if has_local: + local_out = self.local_attn(lq, lk, lv, input_mask=input_mask) + out.append(local_out) + + if has_global: + if not self.receives_context and exists(pos_emb): + q, k = apply_rotary_pos_emb(q, k, pos_emb) + + global_out, loss = self.global_attn(q, k, v, query_mask=input_mask, key_mask=context_mask) + total_loss = total_loss + loss + + out.append(global_out) + + out = torch.cat(out, dim=1) + out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1) + out = self.dropout(out.transpose(0, 1)) + # out = self.to_out(out) + return out, total_loss From 7818f6148da4ea04f0b4b3a2df780004c3580dad Mon Sep 17 00:00:00 2001 From: Ashwyn Sharma <ashwynsharma@fb.com> Date: Wed, 23 Jun 2021 11:13:05 -0700 Subject: [PATCH 624/707] Tuna integration and model packaging Reviewed By: sravyapopuri388 Differential 
Revision: D29118016 fbshipit-source-id: d183c821e5d8eb1b37dda48ded9e24e5efc65dc7 --- .../models/transformer_monotonic_attention.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index 1062e9b955..77c0350d2d 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -34,6 +34,7 @@ "TransformerMonotonicDecoderOut", [ ("action", int), + ("p_choose", Optional[Tensor]), ("attn_list", Optional[List[Optional[Dict[str, Tensor]]]]), ("step_list", Optional[List[Optional[Tensor]]]), ("encoder_out", Optional[Dict[str, List[Tensor]]]), @@ -150,12 +151,16 @@ def pre_attention( x = x.transpose(0, 1) encoder_out = encoder_out_dict["encoder_out"][0] - encoder_padding_mask = ( - encoder_out_dict["encoder_padding_mask"][0] - if encoder_out_dict["encoder_padding_mask"] - and len(encoder_out_dict["encoder_padding_mask"]) > 0 - else None - ) + + if "encoder_padding_mask" in encoder_out_dict: + encoder_padding_mask = ( + encoder_out_dict["encoder_padding_mask"][0] + if encoder_out_dict["encoder_padding_mask"] + and len(encoder_out_dict["encoder_padding_mask"]) > 0 + else None + ) + else: + encoder_padding_mask = None return x, encoder_out, encoder_padding_mask @@ -215,6 +220,8 @@ def extract_features( attn_list: List[Optional[Dict[str, Tensor]]] = [] step_list: List[Optional[Tensor]] = [] + p_choose = torch.tensor([1.0]) + for i, layer in enumerate(self.layers): x, attn, _ = layer( @@ -255,6 +262,7 @@ def extract_features( return x, TransformerMonotonicDecoderOut( action=0, + p_choose=p_choose, attn_list=None, step_list=None, encoder_out=None, @@ -265,6 +273,7 @@ def extract_features( return x, TransformerMonotonicDecoderOut( action=1, + p_choose=p_choose, attn_list=attn_list, 
step_list=step_list, encoder_out=encoder_out, From 520d9d3ba68d06e56f0b9e1d331ed444f48755b2 Mon Sep 17 00:00:00 2001 From: Alex Liu <alexliu36@gmail.com> Date: Thu, 24 Jun 2021 13:59:58 -0700 Subject: [PATCH 625/707] remove debug code from w2vu gen (#1997) Summary: see title Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1997 Reviewed By: wnhsu, Alexander-H-Liu Differential Revision: D29371459 Pulled By: alexeib fbshipit-source-id: 874e36462f919aa4ba698a0dd49531c89f7e27cf --- examples/wav2vec/unsupervised/w2vu_generate.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/wav2vec/unsupervised/w2vu_generate.py b/examples/wav2vec/unsupervised/w2vu_generate.py index 2bad873616..6177239dc7 100644 --- a/examples/wav2vec/unsupervised/w2vu_generate.py +++ b/examples/wav2vec/unsupervised/w2vu_generate.py @@ -428,8 +428,6 @@ def build_generator(cfg: UnsupGenerateConfig): ) hypo_futures.append(hypos) samples.append(sample) - if cfg.debug: - break itr = list(zip(hypo_futures, samples)) start = 0 end = len(itr) From 81046fc13ef05c7b9bbfb7a4cd66e59033918dc3 Mon Sep 17 00:00:00 2001 From: Shiyan Deng <dsy842974287@fb.com> Date: Thu, 24 Jun 2021 14:01:26 -0700 Subject: [PATCH 626/707] Add decoder and decoding wrapper for nmt Summary: Add a decoder class `FairSeqNVFasterTransformerDecoder` that could replace `TransformerDecoder` in nmt. Add a decoding class `FairSeqNVFasterTransformerDecoding` that does `decoding + beam search`. We can't use `FairSeqNVFasterTransformerDecoding` right now in nmt because nmt ensembles decoders and calculates average probabilities across those decoders. Follow ups: 1. Currently `FairSeqNVFasterTransformerDecoder` doesn't produce "attn" https://fburl.com/code/pom5vhr5. 2. Move mem_cache and cache to incremental_state. 3. Benchmark fairseq ft encoder decoder. 4. E2E tests currently get stuck somewhere.
Differential Revision: D29166310 fbshipit-source-id: 36360cfff1d22ed4f12f89068ee30dec835d2141 --- fairseq/models/transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f4f6bea27b..c2726af34d 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -541,13 +541,14 @@ def forward_scriptable( # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. + src_lengths = src_tokens.ne(self.padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1).contiguous() return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_embedding": [encoder_embedding], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], - "src_lengths": [], + "src_lengths": [src_lengths], } @torch.jit.export From f8871521f7b2496bbfce58ff72ea611c4f6ec244 Mon Sep 17 00:00:00 2001 From: Kushal Lakhotia <kushall@fb.com> Date: Sat, 26 Jun 2021 08:59:02 -0700 Subject: [PATCH 627/707] Load dict from pretrained hubert model in HubertEncoder (#1999) Summary: ## What does this PR do? Load dict from pretrained hubert model in HubertEncoder so that the dictionary is not constructed for the labels. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1999 Test Plan: Tested with the cmdline below. ASR training progresses as expected without any exception. ``` PYTHONPATH=. 
HYDRA_FULL_ERROR=1 python fairseq_cli/hydra_train.py -m \ --config-dir examples/hubert/config/finetune \ --config-name base_10h \ dataset.num_workers=0 \ task.data=/checkpoint/kushall/data/librispeech/10h/raw \ task.label_dir=/checkpoint/kushall/data/librispeech/10h/raw \ model.w2v_path=/checkpoint/kushall/final_model_checkpoints/hubert/hubert_base_ls960_updated.pt \ hydra.sweep.dir=/checkpoint/kushall/experiments/hubert_test/base_asr_10h ``` Reviewed By: Abdel-rahmanMohamed Differential Revision: D29405491 Pulled By: hikushalhere fbshipit-source-id: be168a0ce27f8fcfea3dc980a192ba43fdf23871 --- examples/hubert/README.md | 5 ++--- fairseq/models/hubert/hubert_asr.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/hubert/README.md b/examples/hubert/README.md index 3254b754f0..b501a6eb2a 100644 --- a/examples/hubert/README.md +++ b/examples/hubert/README.md @@ -9,13 +9,12 @@ HuBERT Extra Large (~1B params) | [Libri-Light](https://github.com/facebookresea HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k_finetune_ls960.pt) HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k_finetune_ls960.pt) -## Load a pretrained model +## Load a model ``` ckpt_path = "/path/to/the/checkpoint.pt" -models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path], strict=False) +models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) model = models[0] ``` -** We will follow-up with a patch such that you wouldn't need to pass `strict=False` for loading the checkpoint in future. 
## Train a new model diff --git a/fairseq/models/hubert/hubert_asr.py b/fairseq/models/hubert/hubert_asr.py index 4cb3fb7153..dce899c9de 100644 --- a/fairseq/models/hubert/hubert_asr.py +++ b/fairseq/models/hubert/hubert_asr.py @@ -281,6 +281,9 @@ def __init__(self, cfg: HubertAsrConfig, tgt_dict=None): w2v_args.task.data = cfg.data task = tasks.setup_task(w2v_args.task) + if state is not None and "task_state" in state: + # This will load the stored "dictionaries" object + task.load_state_dict(state["task_state"]) model = task.build_model(w2v_args.model) if state is not None and not cfg.no_pretrained_weights: From 53bf2b12934aa5d38ff2d700221457ca34b55cab Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Mon, 28 Jun 2021 01:45:15 -0700 Subject: [PATCH 628/707] Extract File Chunking to its own utils (#1955) Summary: ## What does this PR do? There are a few places where we do file chunking for multiprocessing a single file. However, the code is partly in Binarizer and partly just duplicated here and there. This PR extracts the file chunking/reading logic. The multiprocessing logic could probably be extracted too, but I haven't found a good abstraction yet. # Testing Added testing for this reading logic + maybe fixed a bug where the last part of a file might get dropped (even if it's unclear with the current stopping logic) Tested by running the preprocessing script as follows: ``` python -m fairseq_cli.preprocess --source-lang de --target-lang en --trainpref ...train.spm.clean.de_en --srcdict ...fairseq.dict --tgtdict .../fairseq.dict --destdir ...
--workers 60 ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1955 Reviewed By: myleott Differential Revision: D29065473 Pulled By: Mortimerp9 fbshipit-source-id: c60843de8cfd45a63b3dbb8290f57ef3df3bf983 --- fairseq/binarizer.py | 56 +++++---------------- fairseq/data/dictionary.py | 55 +++++++++++---------- fairseq/file_chunker_utils.py | 84 ++++++++++++++++++++++++++++++++ fairseq_cli/preprocess.py | 38 +++++++++------ tests/test_dictionary.py | 29 +++++++++++ tests/test_file_chunker_utils.py | 63 ++++++++++++++++++++++++ 6 files changed, 239 insertions(+), 86 deletions(-) create mode 100644 fairseq/file_chunker_utils.py create mode 100644 tests/test_file_chunker_utils.py diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index 18ae67bf25..ae4d02a6db 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -3,23 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import os from collections import Counter +from typing import Dict import torch + +from fairseq.file_chunker_utils import Chunker from fairseq.file_io import PathManager from fairseq.tokenizer import tokenize_line -from typing import List, Dict - - -def safe_readline(f): - pos = f.tell() - while True: - try: - return f.readline() - except UnicodeDecodeError: - pos -= 1 - f.seek(pos) # search where this character begins class Binarizer: @@ -42,19 +33,10 @@ def replaced_consumer(word, idx): if idx == dict.unk_index and word != dict.unk_word: replaced.update([word]) - with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: - f.seek(offset) - # next(f) breaks f.tell(), hence readline() must be used - line = safe_readline(f) - while line: - # f.tell() does not always give the byte position in the file - # sometimes it skips to a very large number - # it is unlikely that through a normal read we go from - # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely 
- # that the procedure breaks by the undeterministic behavior of - # f.tell() - if end > 0 and f.tell() > end and f.tell() < end + 2 ** 32: - break + with Chunker( + PathManager.get_local_path(filename), offset, end + ) as line_iterator: + for line in line_iterator: if already_numberized: id_strings = line.strip().split() id_list = [int(id_string) for id_string in id_strings] @@ -75,7 +57,6 @@ def replaced_consumer(word, idx): nseq += 1 ntok += len(ids) consumer(ids) - line = f.readline() return { "nseq": nseq, "nunk": sum(replaced.values()), @@ -89,26 +70,11 @@ def binarize_alignments( ) -> Dict[str, int]: nseq = 0 - with open(PathManager.get_local_path(filename), "r") as f: - f.seek(offset) - line = safe_readline(f) - while line: - if end > 0 and f.tell() > end: - break + with Chunker( + PathManager.get_local_path(filename), offset, end + ) as line_iterator: + for line in line_iterator: ids = alignment_parser(line) nseq += 1 consumer(ids) - line = f.readline() return {"nseq": nseq} - - @staticmethod - def find_offsets(filename, num_chunks) -> List[int]: - with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: - size = os.fstat(f.fileno()).st_size - chunk_size = size // num_chunks - offsets = [0 for _ in range(num_chunks + 1)] - for i in range(1, num_chunks): - f.seek(chunk_size * i) - safe_readline(f) - offsets[i] = f.tell() - return offsets diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 0d8308a811..6876b461d7 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -9,8 +9,8 @@ import torch from fairseq import utils -from fairseq.binarizer import safe_readline from fairseq.data import data_utils +from fairseq.file_chunker_utils import Chunker, find_offsets from fairseq.file_io import PathManager from fairseq.tokenizer import tokenize_line @@ -48,6 +48,9 @@ def __getitem__(self, idx): return self.symbols[idx] return self.unk_word + def get_count(self, idx): + return self.count[idx] + def 
__len__(self): """Returns the number of symbols in the dictionary""" return len(self.symbols) @@ -78,7 +81,13 @@ def string( """ if torch.is_tensor(tensor) and tensor.dim() == 2: return "\n".join( - self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore, include_eos=include_eos) + self.string( + t, + bpe_symbol, + escape_unk, + extra_symbols_to_ignore, + include_eos=include_eos, + ) for t in tensor ) @@ -320,31 +329,18 @@ def encode_line( @staticmethod def _add_file_to_dictionary_single_worker( - filename, tokenize, eos_word, worker_id=0, num_workers=1 + filename, + tokenize, + eos_word, + start_offset, + end_offset, ): counter = Counter() - with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: - size = os.fstat(f.fileno()).st_size - chunk_size = size // num_workers - offset = worker_id * chunk_size - end = offset + chunk_size - f.seek(offset) - if offset > 0: - safe_readline(f) # drop first incomplete line - line = f.readline() - while line: + with Chunker(filename, start_offset, end_offset) as line_iterator: + for line in line_iterator: for word in tokenize(line): counter.update([word]) counter.update([eos_word]) - # f.tell() returns only an opaque number which can - # return to the position in the file via f.seek() - # and does not necessarily represent a byte position - # in the file. However, f.tell() is faithful to the - # byte position _most of the time_. Thus we can just - # check against the file size to prevent early exit. 
- if f.tell() > end and f.tell() < size: - break - line = f.readline() return counter @staticmethod @@ -353,14 +349,23 @@ def merge_result(counter): for w, c in sorted(counter.items()): dict.add_symbol(w, c) + local_file = PathManager.get_local_path(filename) + offsets = find_offsets(local_file, num_workers) if num_workers > 1: + chunks = zip(offsets, offsets[1:]) pool = Pool(processes=num_workers) results = [] - for worker_id in range(num_workers): + for (start_offset, end_offset) in chunks: results.append( pool.apply_async( Dictionary._add_file_to_dictionary_single_worker, - (filename, tokenize, dict.eos_word, worker_id, num_workers), + ( + local_file, + tokenize, + dict.eos_word, + start_offset, + end_offset, + ), ) ) pool.close() @@ -370,7 +375,7 @@ def merge_result(counter): else: merge_result( Dictionary._add_file_to_dictionary_single_worker( - filename, tokenize, dict.eos_word + local_file, tokenize, dict.eos_word, offsets[0], offsets[1] ) ) diff --git a/fairseq/file_chunker_utils.py b/fairseq/file_chunker_utils.py new file mode 100644 index 0000000000..443100c61a --- /dev/null +++ b/fairseq/file_chunker_utils.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import typing as tp + + +def _safe_readline(fd) -> str: + pos = fd.tell() + while True: + try: + return fd.readline() + except UnicodeDecodeError: + pos -= 1 + fd.seek(pos) # search where this character begins + + +def find_offsets(filename: str, num_chunks: int) -> tp.List[int]: + """ + given a file and a number of chuncks, find the offsets in the file + to be able to chunk around full lines. 
+ """ + with open(filename, "r", encoding="utf-8") as f: + size = os.fstat(f.fileno()).st_size + chunk_size = size // num_chunks + offsets = [0 for _ in range(num_chunks + 1)] + for i in range(1, num_chunks): + f.seek(chunk_size * i) + _safe_readline(f) + offsets[i] = f.tell() + offsets[-1] = size + return offsets + + +class ChunkLineIterator: + """ + Iterator to properly iterate over lines of a file chunck. + """ + + def __init__(self, fd, start_offset: int, end_offset: int): + self._fd = fd + self._start_offset = start_offset + self._end_offset = end_offset + + def __iter__(self) -> tp.Iterable[str]: + self._fd.seek(self._start_offset) + # next(f) breaks f.tell(), hence readline() must be used + line = _safe_readline(self._fd) + while line: + pos = self._fd.tell() + # f.tell() does not always give the byte position in the file + # sometimes it skips to a very large number + # it is unlikely that through a normal read we go from + # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely + # that the procedure breaks by the undeterministic behavior of + # f.tell() + if ( + self._end_offset > 0 + and pos > self._end_offset + and pos < self._end_offset + 2 ** 32 + ): + break + yield line + line = self._fd.readline() + + +class Chunker: + """ + contextmanager to read a chunck of a file line by line. 
+ """ + + def __init__(self, path: str, start_offset: int, end_offset: int): + self.path = path + self.start_offset = start_offset + self.end_offset = end_offset + + def __enter__(self) -> ChunkLineIterator: + self.fd = open(self.path, "r", encoding="utf-8") + return ChunkLineIterator(self.fd, self.start_offset, self.end_offset) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.fd.close() diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py index b788900d30..f7170eb00f 100644 --- a/fairseq_cli/preprocess.py +++ b/fairseq_cli/preprocess.py @@ -18,7 +18,7 @@ from fairseq import options, tasks, utils from fairseq.binarizer import Binarizer from fairseq.data import indexed_dataset - +from fairseq.file_chunker_utils import find_offsets logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -133,11 +133,14 @@ def merge_result(worker_result): input_file = "{}{}".format( input_prefix, ("." + lang) if lang is not None else "" ) - offsets = Binarizer.find_offsets(input_file, num_workers) + offsets = find_offsets(input_file, num_workers) + (first_chunk, *more_chunks) = zip(offsets, offsets[1:]) pool = None if num_workers > 1: pool = Pool(processes=num_workers - 1) - for worker_id in range(1, num_workers): + for worker_id, (start_offset, end_offset) in enumerate( + more_chunks, start=1 + ): prefix = "{}{}".format(output_prefix, worker_id) pool.apply_async( binarize, @@ -147,8 +150,8 @@ def merge_result(worker_result): vocab, prefix, lang, - offsets[worker_id], - offsets[worker_id + 1], + start_offset, + end_offset, ), callback=merge_result, ) @@ -161,7 +164,11 @@ def merge_result(worker_result): ) merge_result( Binarizer.binarize( - input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1] + input_file, + vocab, + lambda t: ds.add_item(t), + offset=first_chunk[0], + end=first_chunk[1], ) ) if num_workers > 1: @@ -193,11 +200,14 @@ def merge_result(worker_result): nseq[0] += worker_result["nseq"] 
input_file = input_prefix - offsets = Binarizer.find_offsets(input_file, num_workers) + offsets = find_offsets(input_file, num_workers) + (first_chunk, *more_chunks) = zip(offsets, offsets[1:]) pool = None if num_workers > 1: pool = Pool(processes=num_workers - 1) - for worker_id in range(1, num_workers): + for worker_id, (start_offset, end_offset) in enumerate( + more_chunks, start=1 + ): prefix = "{}{}".format(output_prefix, worker_id) pool.apply_async( binarize_alignments, @@ -206,8 +216,8 @@ def merge_result(worker_result): input_file, utils.parse_alignment, prefix, - offsets[worker_id], - offsets[worker_id + 1], + start_offset, + end_offset, ), callback=merge_result, ) @@ -222,8 +232,8 @@ def merge_result(worker_result): input_file, utils.parse_alignment, lambda t: ds.add_item(t), - offset=0, - end=offsets[1], + offset=first_chunk[0], + end=first_chunk[1], ) ) if num_workers > 1: @@ -387,10 +397,6 @@ def dataset_dest_file(args, output_prefix, lang, extension): return "{}.{}".format(base, extension) -def get_offsets(input_file, num_workers): - return Binarizer.find_offsets(input_file, num_workers) - - def cli_main(): parser = options.get_preprocessing_parser() args = parser.parse_args() diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index 81ce102f4f..dc9d71b3c7 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -4,10 +4,13 @@ # LICENSE file in the root directory of this source tree. 
import io +import os +import string import tempfile import unittest import torch +from fairseq import tokenizer from fairseq.data import Dictionary @@ -111,6 +114,32 @@ def test_space(self): self.assertEqual(d.index("a"), 5) self.assertEqual(d.index("b"), 6) + def test_add_file_to_dict(self): + counts = {} + num_lines = 100 + per_line = 10 + with tempfile.TemporaryDirectory("test_sampling") as data_dir: + filename = os.path.join(data_dir, "dummy.txt") + with open(filename, "w", encoding="utf-8") as data: + for c in string.ascii_letters: + line = f"{c} " * per_line + for _ in range(num_lines): + data.write(f"{line}\n") + counts[c] = per_line * num_lines + per_line += 5 + + dict = Dictionary() + Dictionary.add_file_to_dictionary( + filename, dict, tokenizer.tokenize_line, 10 + ) + dict.finalize(threshold=0, nwords=-1, padding_factor=8) + + for c in string.ascii_letters: + count = dict.get_count(dict.index(c)) + self.assertEqual( + counts[c], count, f"{c} count is {count} but should be {counts[c]}" + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_file_chunker_utils.py b/tests/test_file_chunker_utils.py new file mode 100644 index 0000000000..5cded04572 --- /dev/null +++ b/tests/test_file_chunker_utils.py @@ -0,0 +1,63 @@ +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import shutil +import tempfile +import unittest +from typing import Optional + + +class TestFileChunker(unittest.TestCase): + _tmpdir: Optional[str] = None + _tmpfile: Optional[str] = None + _line_content = "Hello, World\n" + _num_bytes = None + _num_lines = 200 + _num_splits = 20 + + @classmethod + def setUpClass(cls) -> None: + cls._num_bytes = len(cls._line_content.encode("utf-8")) + cls._tmpdir = tempfile.mkdtemp() + with open(os.path.join(cls._tmpdir, "test.txt"), "w") as f: + cls._tmpfile = f.name + for _i in range(cls._num_lines): + f.write(cls._line_content) + f.flush() + + @classmethod + def tearDownClass(cls) -> None: + # Cleanup temp working dir. + if cls._tmpdir is not None: + shutil.rmtree(cls._tmpdir) # type: ignore + + def test_find_offsets(self): + from fairseq.file_chunker_utils import find_offsets + + offsets = find_offsets(self._tmpfile, self._num_splits) + self.assertEqual(len(offsets), self._num_splits + 1) + (zero, *real_offsets, last) = offsets + self.assertEqual(zero, 0) + for i, o in enumerate(real_offsets): + self.assertEqual( + o, + self._num_bytes + + ((i + 1) * self._num_bytes * self._num_lines / self._num_splits), + ) + self.assertEqual(last, self._num_bytes * self._num_lines) + + def test_readchunks(self): + from fairseq.file_chunker_utils import Chunker, find_offsets + + offsets = find_offsets(self._tmpfile, self._num_splits) + for start, end in zip(offsets, offsets[1:]): + with Chunker(self._tmpfile, start, end) as lines: + all_lines = list(lines) + num_lines = self._num_lines / self._num_splits + self.assertAlmostEqual( + len(all_lines), num_lines, delta=1 + ) # because we split on the bites, we might end up with one more/less line in a chunk + self.assertListEqual( + all_lines, [self._line_content for _ in range(len(all_lines))] + ) From 0972dde844e39540faf53b6d9afe76b38c7e2fd6 Mon Sep 17 00:00:00 2001 From: Liang Luo <liangluo@fb.com> Date: Tue, 29 Jun 2021 00:02:06 -0700 Subject: [PATCH 629/707] apply nonblocking 
H/D transfer optimizations Summary: merge D27701492 + D27701493 * make checkpoint activation cpu offloading nonblocking * make gradient cpu offloading nonblocking * synchronize cpu/gpu stream before applying optimizer update Reviewed By: myleott Differential Revision: D28047171 fbshipit-source-id: f862eca64049acc045026aa4f5e6dbe8d0f03244 --- fairseq/modules/checkpoint_activations.py | 4 ++-- fairseq/optim/cpu_adam.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index b44fc346ce..7489e09eb7 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -163,7 +163,7 @@ def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): if parent_ctx_dict["offload"]: ctx.fwd_device = tuple(x.device for x in tensor_inputs) ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) - tensor_inputs = tuple(x.cpu() for x in tensor_inputs) + tensor_inputs = tuple(x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs) else: ctx.fwd_device, ctx.grad_requirements = None, None @@ -196,7 +196,7 @@ def backward(ctx, *args): tensor_inputs = checkpoint.detach_variable(tensor_inputs) if ctx.fwd_device is not None: tensor_inputs = [ - t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs) + t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs) ] for i, need_grad in enumerate(ctx.grad_requirements): tensor_inputs[i].requires_grad = need_grad diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index e36bccf123..211c376756 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -127,6 +127,8 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + torch.cuda.synchronize() + for group_id, group in enumerate(self.param_groups): for param_id, p in enumerate(group["params"]): if p.grad is None: From 
0794f9ae21bb50f940ff9b1bc3f28be08dfa7b76 Mon Sep 17 00:00:00 2001 From: Edan Tessel Sneh <edan@fb.com> Date: Tue, 29 Jun 2021 15:09:13 -0700 Subject: [PATCH 630/707] Back out "Adding FBSequenceGenerator" Summary: Original commit changeset: b7a83bbc719d Reverts commit D26228721 (https://github.com/pytorch/fairseq/commit/6381aa2bb24f125d271e241c726a2fea581bc3c4) Reviewed By: theweiho Differential Revision: D29369494 fbshipit-source-id: 9e745b11bc532ca8ced2816326aa94afbb46ba2d --- fairseq/tasks/fairseq_task.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index fbec9bb2a5..99bf2c3fe9 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -379,10 +379,6 @@ def build_generator( SequenceGenerator, SequenceGeneratorWithAlignment, ) - try: - from fairseq.fb_sequence_generator import FBSequenceGenerator - except ModuleNotFoundError: - pass # Choose search strategy. Defaults to Beam Search. sampling = getattr(args, "sampling", False) @@ -450,8 +446,6 @@ def build_generator( if getattr(args, "print_alignment", False): seq_gen_cls = SequenceGeneratorWithAlignment extra_gen_cls_kwargs["print_alignment"] = args.print_alignment - elif getattr(args, "fb_seq_gen", False): - seq_gen_cls = FBSequenceGenerator else: seq_gen_cls = SequenceGenerator From 9bee82e4a7b73249a88f2e2d286e991493ef13c2 Mon Sep 17 00:00:00 2001 From: Omry Yadan <omry@fb.com> Date: Thu, 1 Jul 2021 06:37:02 -0700 Subject: [PATCH 631/707] Hydra 1.1 compatibility: Use an explicit schema for the primary config (#3659) Summary: ## What does this PR do? Fixes compatibility with Hydra 1.1. The result is compatible with both Hydra 1.0 and Hydra 1.1, and will allow a smoother migration to Hydra 1.1. At this point I am not yet removing the restriction on the Hydra version from setup.py: 1. It depends on some Hydra 1.1 changes that are not yet released (It will be compatible with 1.1.1). 2. 
Upgrading will result in deprecation warnings, and fixing them will break compatibility with Hydra 1.0. There will be some followup to make the code fully compatible with 1.1 once Hydra 1.1 is the default version in fbcode. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3659 Reviewed By: omry Differential Revision: D29498036 Pulled By: lematt1991 fbshipit-source-id: 96999cde5daad6749ef4d3ddf6a36a1e984ff201 --- fairseq/checkpoint_utils.py | 2 +- fairseq/config/config.yaml | 1 + fairseq/dataclass/initialize.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 627f14160d..d22d987020 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -203,7 +203,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join( - cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.get("save_dir"), "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path) if cfg.finetune_from_model is not None and first_launch: diff --git a/fairseq/config/config.yaml b/fairseq/config/config.yaml index e20d914b9b..087083e88a 100644 --- a/fairseq/config/config.yaml +++ b/fairseq/config/config.yaml @@ -5,6 +5,7 @@ hydra: dir: . 
defaults: + - config_schema - task: null - model: null - criterion: cross_entropy diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index 479aeb8b16..1d5c90eefd 100644 --- a/fairseq/dataclass/initialize.py +++ b/fairseq/dataclass/initialize.py @@ -16,7 +16,7 @@ def hydra_init(cfg_name="config") -> None: cs = ConfigStore.instance() - cs.store(name=cfg_name, node=FairseqConfig) + cs.store(name=f"{cfg_name}_schema", node=FairseqConfig) for k in FairseqConfig.__dataclass_fields__: v = FairseqConfig.__dataclass_fields__[k].default From 096f492a224e14ed6628d96700f1ae8b534d86a8 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 1 Jul 2021 08:36:02 -0700 Subject: [PATCH 632/707] fix xlsr checkpoint finetuning saving issues (#2013) Summary: fixes an issue with some old checkpoints that had deep nested namespaces containing choices enum - most prominently xlsr 53 checkpoint fixes #3634 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2013 Reviewed By: xuqiantong Differential Revision: D29511325 Pulled By: alexeib fbshipit-source-id: 79df978afa7482b4ce3aaf7396e193626181aa17 --- fairseq/dataclass/utils.py | 82 +++++++++++++++----------- fairseq/models/wav2vec/wav2vec2_asr.py | 2 + 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 89206125d1..1ed28b7ccc 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -17,7 +17,7 @@ from fairseq.dataclass.configs import FairseqConfig from hydra.core.global_hydra import GlobalHydra from hydra.experimental import compose, initialize -from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf import DictConfig, OmegaConf, open_dict, _utils logger = logging.getLogger(__name__) @@ -341,6 +341,17 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: return overrides, deletes +class omegaconf_no_object_check: + def __init__(self): + 
self.old_is_primitive = _utils.is_primitive_type + + def __enter__(self): + _utils.is_primitive_type = lambda _: True + + def __exit__(self, type, value, traceback): + _utils.is_primitive_type = self.old_is_primitive + + def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: """Convert a flat argparse.Namespace to a structured DictConfig.""" @@ -370,41 +381,40 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: # omegaconf version that supports object flags, or when we migrate all existing models from omegaconf import _utils - old_primitive = _utils.is_primitive_type - _utils.is_primitive_type = lambda _: True - - if cfg.task is None and getattr(args, "task", None): - cfg.task = Namespace(**vars(args)) - from fairseq.tasks import TASK_REGISTRY - - _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) - cfg.task._name = args.task - if cfg.model is None and getattr(args, "arch", None): - cfg.model = Namespace(**vars(args)) - from fairseq.models import ARCH_MODEL_REGISTRY - - _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) - cfg.model._name = args.arch - if cfg.optimizer is None and getattr(args, "optimizer", None): - cfg.optimizer = Namespace(**vars(args)) - from fairseq.optim import OPTIMIZER_REGISTRY - - _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) - cfg.optimizer._name = args.optimizer - if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): - cfg.lr_scheduler = Namespace(**vars(args)) - from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY - - _set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]) - cfg.lr_scheduler._name = args.lr_scheduler - if cfg.criterion is None and getattr(args, "criterion", None): - cfg.criterion = Namespace(**vars(args)) - from fairseq.criterions import CRITERION_REGISTRY - - _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) - cfg.criterion._name = args.criterion - - _utils.is_primitive_type = 
old_primitive + with omegaconf_no_object_check(): + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults( + cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler] + ) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + OmegaConf.set_struct(cfg, True) return cfg diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 405d1e613a..04307e8771 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -337,6 +337,8 @@ def __init__(self, cfg: Wav2Vec2AsrConfig, output_size=None): w2v_args = state.get("cfg", None) if w2v_args is None: w2v_args = convert_namespace_to_omegaconf(state["args"]) + w2v_args.criterion = None + w2v_args.lr_scheduler = None cfg.w2v_args = w2v_args else: state = None From cdc1a553eb2af4fac720880aff4ee2566a28ad21 Mon Sep 17 00:00:00 
2001 From: Wei-Ning Hsu <wnhsu@csail.mit.edu> Date: Thu, 1 Jul 2021 13:11:40 -0700 Subject: [PATCH 633/707] query tgt_dict after loading task_state (#2019) Summary: # Before submitting `self.task.target_dictionary` is queried before `task_state` is loaded (in `self.load_model_ensemble()`). ## What does this PR do? Fix the bug above Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2019 Reviewed By: alexeib Differential Revision: D29523921 Pulled By: wnhsu fbshipit-source-id: 763b504dc1b4899e623eaa5c19972cec9d0a8985 --- examples/speech_recognition/new/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_recognition/new/infer.py b/examples/speech_recognition/new/infer.py index 79afbc426d..3fb67151e0 100644 --- a/examples/speech_recognition/new/infer.py +++ b/examples/speech_recognition/new/infer.py @@ -99,11 +99,11 @@ class InferenceProcessor: def __init__(self, cfg: InferConfig) -> None: self.cfg = cfg self.task = tasks.setup_task(cfg.task) - self.tgt_dict = self.task.target_dictionary models, saved_cfg = self.load_model_ensemble() self.models = models self.saved_cfg = saved_cfg + self.tgt_dict = self.task.target_dictionary self.task.load_dataset( self.cfg.dataset.gen_subset, From dd106d9534b22e7db859a6b87ffd7780c38341f8 Mon Sep 17 00:00:00 2001 From: Omry Yadan <omry@fb.com> Date: Tue, 6 Jul 2021 15:06:07 -0700 Subject: [PATCH 634/707] fixes tests/test_train.py to mock checkpoint.save_dir config node (#3675) Summary: ## What does this PR do? Some downstream users reported errors when passing a Namespace to load_checkpoint(). A recent change assumed that the passed object is dict-like (dict or DictConfig) and has a get function. This change reverts that assumption and makes sure the mocked config has checkpoint.save_dir so that the test can run.
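The failure mode described above can be sketched with the standard library alone. This is only an illustration under stated assumptions: `DictLikeCfg` and `checkpoint_last_path` are hypothetical stand-ins, not fairseq or OmegaConf APIs. The sketch shows that attribute access works for both an `argparse.Namespace` and a dict-like config, whereas `.get` exists only on the dict-like one.

```python
import os
from argparse import Namespace


class DictLikeCfg(dict):
    """Hypothetical stand-in for a dict-like config (not OmegaConf's DictConfig)."""

    def __getattr__(self, name):
        # Fall back to key lookup so that cfg.save_dir behaves like cfg["save_dir"].
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


def checkpoint_last_path(cfg, suffix=""):
    # Hypothetical helper: attribute access is the common denominator
    # of Namespace objects and dict-like configs.
    return os.path.join(cfg.save_dir, "checkpoint_last{}.pt".format(suffix))


ns_cfg = Namespace(save_dir="ckpt")
dict_cfg = DictLikeCfg(save_dir="ckpt")

# Both config styles resolve to the same path through attribute access.
same_path = checkpoint_last_path(ns_cfg) == checkpoint_last_path(dict_cfg)

# Only the dict-like config offers .get(); a Namespace does not.
namespace_has_get = hasattr(ns_cfg, "get")
dict_has_get = hasattr(dict_cfg, "get")
```

Calling `ns_cfg.get("save_dir")` would raise `AttributeError`, which is presumably the kind of error downstream users hit when passing a `Namespace`.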
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3675 Reviewed By: omry Differential Revision: D29564805 Pulled By: lematt1991 fbshipit-source-id: 89308811da382667f6c5d3152ee2d6480416ee62 --- fairseq/checkpoint_utils.py | 2 +- tests/test_train.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index d22d987020..627f14160d 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -203,7 +203,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join( - cfg.get("save_dir"), "checkpoint_last{}.pt".format(suffix) + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path) if cfg.finetune_from_model is not None and first_launch: diff --git a/tests/test_train.py b/tests/test_train.py index 65f4683bc6..02ef94cc5b 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -61,6 +61,7 @@ def get_mock_cfg(finetune_from_model): cfg_mock = OmegaConf.create( { "checkpoint": { + "save_dir": None, "optimizer_overrides": "{}", "reset_dataloader": False, "reset_meters": False, From 01576be513488601e936df67e65e2817d18f576e Mon Sep 17 00:00:00 2001 From: Henry Hu <henryhu6@fb.com> Date: Wed, 7 Jul 2021 15:41:37 -0700 Subject: [PATCH 635/707] Add tracing annotations Summary: Add profiler record_function annotations at several locations in decoderlib and fairseq to annotate tracing. Most of the code changes are due to indentation and auto-formatting.
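For readers unfamiliar with the annotation style, here is a minimal sketch of how a labeled region shows up in a profile. It assumes PyTorch is installed. `toy_forward` and the label `"toy: forward"` are made up for illustration and are not part of fairseq; the patch itself wraps `forward_encoder` and `forward_decoder` in the same way.

```python
import torch
from torch.autograd import profiler


def toy_forward(x):
    # Everything inside this block is grouped under the given label
    # in the profiler output, much like the annotations in the patch.
    with profiler.record_function("toy: forward"):
        return (x @ x).relu()


# Collect a CPU profile around the annotated call.
with profiler.profile() as prof:
    out = toy_forward(torch.ones(4, 4))

# The custom label appears alongside the built-in operator names.
labels = [event.key for event in prof.key_averages()]
```

Printing `prof.key_averages().table()` should list the labeled region next to the operators it encloses, which is what makes such annotations useful when reading traces.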
Reviewed By: mikekgfb Differential Revision: D29531358 fbshipit-source-id: 59934079c0ddc75b5b97922f585f4863680f1041 --- fairseq/sequence_generator.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 8a3858563e..ac04dc7db8 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -250,7 +250,8 @@ def _generate( self.min_len <= max_len ), "min_len cannot be larger than max_len, please adjust these!" # compute the encoder output for each beam - encoder_outs = self.model.forward_encoder(net_input) + with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): + encoder_outs = self.model.forward_encoder(net_input) # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) @@ -327,13 +328,13 @@ def _generate( encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state ) - - lprobs, avg_attn_scores = self.model.forward_decoder( - tokens[:, : step + 1], - encoder_outs, - incremental_states, - self.temperature, - ) + with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"): + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + ) if self.lm_model is not None: lm_out = self.lm_model(tokens[:, : step + 1]) From 7b710acc9e0106e0359b809f285efc45be6fbfd1 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 8 Jul 2021 15:14:45 -0700 Subject: [PATCH 636/707] Fix static container (#2036) Summary: fixes StatefulContainer being static which is a problem when you load a checkpoint with task that already has the same keys in the container also print full path to checkpoints when saving (useful with hydra) and crash if repeatedly failing to save a checkpoint Pull Request resolved: 
https://github.com/fairinternal/fairseq-py/pull/2036 Reviewed By: arbabu123 Differential Revision: D29608430 Pulled By: alexeib fbshipit-source-id: 1b65c8f839e02de9110af3ec53f1e7d48a4908f7 --- fairseq/checkpoint_utils.py | 12 ++++++++---- fairseq/tasks/fairseq_task.py | 11 ++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 627f14160d..8ec967397f 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -92,8 +92,9 @@ def is_better(a, b): # add random digits to resolve ties rand_sfx = randint(0, cfg.keep_best_checkpoints) checkpoint_conds[ - "checkpoint.best_{}_{:.3f}{}.pt".format(cfg.best_checkpoint_metric, - val_loss, rand_sfx) + "checkpoint.best_{}_{:.3f}{}.pt".format( + cfg.best_checkpoint_metric, val_loss, rand_sfx + ) ] = worst_best is None or is_better(val_loss, worst_best) checkpoint_conds[ "checkpoint_last{}.pt".format(suffix) @@ -104,7 +105,7 @@ def is_better(a, b): extra_state.update({"best": save_checkpoint.best}) checkpoints = [ - os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + os.path.abspath(os.path.join(cfg.save_dir, fn)) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) @@ -452,7 +453,9 @@ def load_model_ensemble_and_task( state = None if shard_idx % 10 == 0 and shard_idx > 0: elapsed = time.time() - st - logger.info(f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard") + logger.info( + f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard" + ) # build model for ensemble ensemble.append(model) @@ -508,6 +511,7 @@ def _torch_persistent_save(obj, f): except Exception: if i == 2: logger.error(traceback.format_exc()) + raise def _upgrade_state_dict(state): diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 99bf2c3fe9..8148c77fe1 100644 --- 
a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -23,8 +23,9 @@ class StatefulContainer(object): - _state: Dict[str, Any] = dict() - _factories: Dict[str, Callable[[], Any]] = dict() + def __init__(self): + self._state = dict() + self._factories = dict() def add_factory(self, name, factory: Callable[[], Any]): self._factories[name] = factory @@ -78,11 +79,6 @@ def logging_outputs_can_be_summed(criterion) -> bool: """ return criterion.logging_outputs_can_be_summed() - cfg: FairseqDataclass - datasets: Dict[str, FairseqDataset] - dataset_to_epoch_iter: Dict[FairseqDataset, Any] - state: StatefulContainer = None - def __init__(self, cfg: FairseqDataclass, **kwargs): self.cfg = cfg self.datasets = dict() @@ -622,6 +618,7 @@ def get_interactive_tokens_and_lengths(self, lines, encode_fn): class LegacyFairseqTask(FairseqTask): def __init__(self, args: Namespace): + super().__init__(None) self.args = args self.datasets = {} self.dataset_to_epoch_iter = {} From 58201a15cceccd9b8f8c6463d7338e4252963b33 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 8 Jul 2021 16:56:58 -0700 Subject: [PATCH 637/707] migrate roberta glue finetuning to hydra (#2035) Summary: this allows roberta finetuning on different tasks using yaml config files + hydra entry point Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2035 Reviewed By: Mortimerp9 Differential Revision: D29601732 Pulled By: alexeib fbshipit-source-id: 774ef974b4b40ad0ced76874c62047d0c46520e7 --- examples/roberta/README.glue.md | 45 +---- examples/roberta/config/finetuning/cola.yaml | 56 ++++++ examples/roberta/config/finetuning/mnli.yaml | 56 ++++++ examples/roberta/config/finetuning/mrpc.yaml | 56 ++++++ examples/roberta/config/finetuning/qnli.yaml | 56 ++++++ examples/roberta/config/finetuning/qqp.yaml | 56 ++++++ examples/roberta/config/finetuning/rte.yaml | 56 ++++++ examples/roberta/config/finetuning/sst_2.yaml | 56 ++++++ 
examples/roberta/config/finetuning/sts_b.yaml | 56 ++++++ .../rxf/rxf_src/sentence_prediction_r3f.py | 1 + fairseq/criterions/sentence_prediction.py | 29 +-- fairseq/models/roberta/model.py | 91 +++++---- fairseq/tasks/__init__.py | 2 +- fairseq/tasks/sentence_prediction.py | 180 +++++++++--------- 14 files changed, 617 insertions(+), 179 deletions(-) create mode 100644 examples/roberta/config/finetuning/cola.yaml create mode 100644 examples/roberta/config/finetuning/mnli.yaml create mode 100644 examples/roberta/config/finetuning/mrpc.yaml create mode 100644 examples/roberta/config/finetuning/qnli.yaml create mode 100644 examples/roberta/config/finetuning/qqp.yaml create mode 100644 examples/roberta/config/finetuning/rte.yaml create mode 100644 examples/roberta/config/finetuning/sst_2.yaml create mode 100644 examples/roberta/config/finetuning/sts_b.yaml diff --git a/examples/roberta/README.glue.md b/examples/roberta/README.glue.md index 77015d2e2f..4f596d55af 100644 --- a/examples/roberta/README.glue.md +++ b/examples/roberta/README.glue.md @@ -17,54 +17,19 @@ Use `ALL` for preprocessing all the glue tasks. ### 3) Fine-tuning on GLUE task: Example fine-tuning cmd for `RTE` task ```bash -TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16 -WARMUP_UPDATES=122 # 6 percent of the number of updates -LR=2e-05 # Peak LR for polynomial LR scheduler. -NUM_CLASSES=2 -MAX_SENTENCES=16 # Batch size. 
ROBERTA_PATH=/path/to/roberta/model.pt -CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \ - --restore-file $ROBERTA_PATH \ - --max-positions 512 \ - --batch-size $MAX_SENTENCES \ - --max-tokens 4400 \ - --task sentence_prediction \ - --reset-optimizer --reset-dataloader --reset-meters \ - --required-batch-size-multiple 1 \ - --init-token 0 --separator-token 2 \ - --arch roberta_large \ - --criterion sentence_prediction \ - --num-classes $NUM_CLASSES \ - --dropout 0.1 --attention-dropout 0.1 \ - --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ - --clip-norm 0.0 \ - --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ - --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ - --max-epoch 10 \ - --find-unused-parameters \ - --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; +CUDA_VISIBLE_DEVICES=0 fairseq-hydra-train --config-dir examples/roberta/config/finetuning --config-name rte \ +task.data=RTE-bin checkpoint.restore_file=$ROBERTA_PATH ``` -For each of the GLUE task, you will need to use following cmd-line arguments: - -Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B ----|---|---|---|---|---|---|---|--- -`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1 -`--lr` | 1e-5 | 1e-5 | 1e-5 | 2e-5 | 1e-5 | 1e-5 | 1e-5 | 2e-5 -`--batch-size` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16 -`--total-num-update` | 123873 | 33112 | 113272 | 2036 | 20935 | 2296 | 5336 | 3598 -`--warmup-updates` | 7432 | 1986 | 28318 | 122 | 1256 | 137 | 320 | 214 - -For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`. +There are additional config files for each of the GLUE tasks in the examples/roberta/config/finetuning directory. 
**Note:** -a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=16/32` depending on the task. - -b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--batch-size`. +a) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can increase `--update-freq` and reduce `--batch-size`. -c) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search. +b) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search. ### Inference on GLUE task After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet: diff --git a/examples/roberta/config/finetuning/cola.yaml b/examples/roberta/config/finetuning/cola.yaml new file mode 100644 index 0000000000..717069d407 --- /dev/null +++ b/examples/roberta/config/finetuning/cola.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 320 + +optimization: + clip_norm: 0.0 + lr: [1e-05] + max_update: 5336 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/mnli.yaml b/examples/roberta/config/finetuning/mnli.yaml new file mode 100644 index 0000000000..4bfc02bed9 --- /dev/null +++ b/examples/roberta/config/finetuning/mnli.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 3 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 7432 + +optimization: + clip_norm: 0.0 + lr: [1e-05] + max_update: 123873 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/mrpc.yaml b/examples/roberta/config/finetuning/mrpc.yaml new file mode 100644 index 0000000000..907b4639c1 --- /dev/null +++ b/examples/roberta/config/finetuning/mrpc.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 137 + +optimization: + clip_norm: 0.0 + lr: [1e-05] + max_update: 2296 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/qnli.yaml b/examples/roberta/config/finetuning/qnli.yaml new file mode 100644 index 0000000000..00aea91e56 --- /dev/null +++ b/examples/roberta/config/finetuning/qnli.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 1986 + +optimization: + clip_norm: 0.0 + lr: [1e-05] + max_update: 33112 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/qqp.yaml b/examples/roberta/config/finetuning/qqp.yaml new file mode 100644 index 0000000000..dc0296d26e --- /dev/null +++ b/examples/roberta/config/finetuning/qqp.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 28318 + +optimization: + clip_norm: 0.0 + lr: [1e-05] + max_update: 113272 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/rte.yaml b/examples/roberta/config/finetuning/rte.yaml new file mode 100644 index 0000000000..40dfd76169 --- /dev/null +++ b/examples/roberta/config/finetuning/rte.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 122 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 2036 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/sst_2.yaml b/examples/roberta/config/finetuning/sst_2.yaml new file mode 100644 index 0000000000..b808a850cb --- /dev/null +++ b/examples/roberta/config/finetuning/sst_2.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 1256 + +optimization: + clip_norm: 0.0 + lr: [1e-05] + max_update: 20935 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/sts_b.yaml b/examples/roberta/config/finetuning/sts_b.yaml new file mode 100644 index 0000000000..d354bb97dd --- /dev/null +++ b/examples/roberta/config/finetuning/sts_b.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 1 + max_positions: 512 + +checkpoint: + restore_file: ??? 
+ reset_optimizer: true + reset_dataloader: true + reset_meters: true + best_checkpoint_metric: accuracy + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + +criterion: + _name: sentence_prediction + regression_target: true + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 214 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 3598 + max_epoch: 10 + +model: + _name: roberta_large + dropout: 0.1 + attention_dropout: 0.1 diff --git a/examples/rxf/rxf_src/sentence_prediction_r3f.py b/examples/rxf/rxf_src/sentence_prediction_r3f.py index 62dd63390c..6ecffd6b14 100644 --- a/examples/rxf/rxf_src/sentence_prediction_r3f.py +++ b/examples/rxf/rxf_src/sentence_prediction_r3f.py @@ -52,6 +52,7 @@ def add_args(parser): parser.add_argument('--classification-head-name', default='sentence_classification_head', help='name of the classification head to use') + parser.add_argument('--regression-target', action='store_true') # fmt: on def _get_symm_kl(self, noised_logits, input_logits): diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py index 9519fdc56d..482b97985a 100644 --- a/fairseq/criterions/sentence_prediction.py +++ b/fairseq/criterions/sentence_prediction.py @@ -4,27 +4,32 @@ # LICENSE file in the root directory of this source tree. 
import math +from dataclasses import dataclass, field import torch import torch.nn.functional as F from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass -@register_criterion("sentence_prediction") +@dataclass +class SentencePredictionConfig(FairseqDataclass): + classification_head_name: str = field( + default="sentence_classification_head", + metadata={"help": "name of the classification head to use"}, + ) + regression_target: bool = field( + default=False, + ) + + +@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig) class SentencePredictionCriterion(FairseqCriterion): - def __init__(self, task, classification_head_name, regression_target): + def __init__(self, cfg: SentencePredictionConfig, task): super().__init__(task) - self.classification_head_name = classification_head_name - self.regression_target = regression_target - - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--classification-head-name', - default='sentence_classification_head', - help='name of the classification head to use') - # fmt: on + self.classification_head_name = cfg.classification_head_name + self.regression_target = cfg.regression_target def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index d9d0f324cf..39a1cdd951 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -189,13 +189,24 @@ def add_args(parser): def build_model(cls, args, task): """Build a new model instance.""" + from omegaconf import OmegaConf + + if OmegaConf.is_config(args): + OmegaConf.set_struct(args, False) + # make sure all arguments are present base_architecture(args) if not hasattr(args, "max_positions"): + if not hasattr(args, "tokens_per_sample"): + args.tokens_per_sample = task.max_positions() args.max_positions = args.tokens_per_sample encoder = RobertaEncoder(args, task.source_dictionary) + + if OmegaConf.is_config(args): + OmegaConf.set_struct(args, True) + return cls(args, encoder) def forward( @@ -508,54 +519,62 @@ def max_positions(self): return self.args.max_positions +def safe_getattr(obj, k, default=None): + from omegaconf import OmegaConf + + if OmegaConf.is_config(obj): + return obj[k] if k in obj and obj[k] is not None else default + + return getattr(obj, k, default) + @register_model_architecture("roberta", "roberta") def base_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - - args.max_source_positions = getattr(args, "max_positions", 512) - args.no_token_positional_embeddings = getattr( + args.encoder_layers = safe_getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = 
safe_getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 12) + + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_dropout = safe_getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = safe_getattr(args, "pooler_dropout", 0.0) + + args.max_source_positions = safe_getattr(args, "max_positions", 512) + args.no_token_positional_embeddings = safe_getattr( args, "no_token_positional_embeddings", False ) # BERT has a few structural differences compared to the original Transformer - args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) - args.layernorm_embedding = getattr(args, "layernorm_embedding", True) - args.no_scale_embedding = getattr(args, "no_scale_embedding", True) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) + args.encoder_learned_pos = safe_getattr(args, "encoder_learned_pos", True) + args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True) + args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", False) + args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh") + args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False) # Adaptive input config - args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input = safe_getattr(args, "adaptive_input", False) # LayerDrop config - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - args.encoder_layers_to_keep = getattr(args, 
"encoder_layers_to_keep", None) + args.encoder_layerdrop = safe_getattr(args, "encoder_layerdrop", 0.0) + args.encoder_layers_to_keep = safe_getattr(args, "encoder_layers_to_keep", None) # Quantization noise config - args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) - args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) - args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + args.quant_noise_pq = safe_getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = safe_getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = safe_getattr(args, "quant_noise_scalar", 0) # R4F config - args.spectral_norm_classification_head = getattr( + args.spectral_norm_classification_head = safe_getattr( args, "spectral_norm_classification_head", False ) @register_model_architecture("roberta", "roberta_prenorm") def roberta_prenorm_architecture(args): - args.layernorm_embedding = getattr(args, "layernorm_embedding", False) - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False) + args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", True) base_architecture(args) @@ -566,17 +585,17 @@ def roberta_base_architecture(args): @register_model_architecture("roberta", "roberta_large") def roberta_large_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 24) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_layers = safe_getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 16) 
base_architecture(args) @register_model_architecture("roberta", "xlm") def xlm_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 16) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1280 * 4) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_layers = safe_getattr(args, "encoder_layers", 16) + args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 1280) + args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 1280 * 4) + args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 16) base_architecture(args) diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 79dde74057..28305aa247 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -41,7 +41,7 @@ def setup_task(cfg: FairseqDataclass, **kwargs): assert ( task is not None - ), f"Could not infer task type from {cfg}. Available tasks: {TASK_REGISTRY.keys()}" + ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. 
Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}" return task.setup_task(cfg, **kwargs) diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 6732728de9..d5f9302c10 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -6,8 +6,12 @@ import logging import os +import contextlib +from dataclasses import dataclass, field +from typing import Optional +from omegaconf import MISSING, II, open_dict, OmegaConf + import numpy as np -from fairseq import utils from fairseq.data import ( ConcatSentencesDataset, Dictionary, @@ -25,14 +29,63 @@ data_utils, ) from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.tasks import FairseqDataclass, FairseqTask, register_task +from fairseq.dataclass import ChoiceEnum logger = logging.getLogger(__name__) - - -@register_task("sentence_prediction") -class SentencePredictionTask(LegacyFairseqTask): +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) + + +@dataclass +class SentencePredictionConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + num_classes: int = field( + default=-1, + metadata={"help": "number of classes or regression targets"}, + ) + init_token: Optional[int] = field( + default=None, + metadata={"help": "add token at the beginning of each batch item"}, + ) + separator_token: Optional[int] = field( + default=None, + metadata={"help": "add separator token between inputs"}, + ) + no_shuffle: bool = field( + default=False, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed tokens_per_sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + 
}, + ) + add_prev_output_tokens: bool = field( + default=False, + metadata={ + "help": "add prev_output_tokens to sample, used for encoder-decoder arch" + }, + ) + max_positions: int = field( + default=512, + metadata={"help": "max tokens per example"}, + ) + + regression_target: bool = II("criterion.regression_target") + classification_head_name: str = II("criterion.classification_head_name") + seed: int = II("common.seed") + + +@register_task("sentence_prediction", dataclass=SentencePredictionConfig) +class SentencePredictionTask(FairseqTask): """ Sentence (or sentence pair) prediction (classification or regression) task. @@ -40,64 +93,13 @@ class SentencePredictionTask(LegacyFairseqTask): dictionary (Dictionary): the dictionary for the input of the task """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - parser.add_argument("data", metavar="FILE", help="file prefix for data") - parser.add_argument( - "--num-classes", - type=int, - default=-1, - help="number of classes or regression targets", - ) - parser.add_argument( - "--init-token", - type=int, - default=None, - help="add token at the beginning of each batch item", - ) - parser.add_argument( - "--separator-token", - type=int, - default=None, - help="add separator token between inputs", - ) - parser.add_argument("--regression-target", action="store_true", default=False) - parser.add_argument("--no-shuffle", action="store_true", default=False) - parser.add_argument( - "--shorten-method", - default="none", - choices=["none", "truncate", "random_crop"], - help="if not none, shorten sequences that exceed --tokens-per-sample", - ) - parser.add_argument( - "--shorten-data-split-list", - default="", - help="comma-separated list of dataset splits to apply shortening to, " - 'e.g., "train,valid" (default: all dataset splits)', - ) - parser.add_argument( - "--add-prev-output-tokens", - action="store_true", - default=False, - help="add prev_output_tokens to sample, used for 
encoder-decoder arch", - ) - - def __init__(self, args, data_dictionary, label_dictionary): - super().__init__(args) + def __init__(self, cfg, data_dictionary, label_dictionary): + super().__init__(cfg) self.dictionary = data_dictionary self._label_dictionary = label_dictionary - if not hasattr(args, "max_positions"): - self._max_positions = ( - args.max_source_positions, - args.max_target_positions, - ) - else: - self._max_positions = args.max_positions - args.tokens_per_sample = self._max_positions @classmethod - def load_dictionary(cls, args, filename, source=True): + def load_dictionary(cls, filename): """Load the dictionary from the filename Args: @@ -108,34 +110,30 @@ def load_dictionary(cls, args, filename, source=True): return dictionary @classmethod - def setup_task(cls, args, **kwargs): - assert args.num_classes > 0, "Must set --num-classes" + def setup_task(cls, cfg, **kwargs): + assert cfg.num_classes > 0, "Must set task.num_classes" # load data dictionary data_dict = cls.load_dictionary( - args, - os.path.join(args.data, "input0", "dict.txt"), - source=True, + os.path.join(cfg.data, "input0", "dict.txt"), ) logger.info("[input] dictionary: {} types".format(len(data_dict))) # load label dictionary - if not args.regression_target: + if not cfg.regression_target: label_dict = cls.load_dictionary( - args, - os.path.join(args.data, "label", "dict.txt"), - source=False, + os.path.join(cfg.data, "label", "dict.txt"), ) logger.info("[label] dictionary: {} types".format(len(label_dict))) else: label_dict = data_dict - return cls(args, data_dict, label_dict) + return cls(cfg, data_dict, label_dict) def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(key, split): - return os.path.join(self.args.data, key, split) + return os.path.join(self.cfg.data, key, split) def make_dataset(key, dictionary): split_path = get_path(key, split) @@ -144,7 +142,6 @@ def make_dataset(key, dictionary): 
dataset = data_utils.load_indexed_dataset( split_path, dictionary, - self.args.dataset_impl, combine=combine, ) except Exception as e: @@ -161,27 +158,27 @@ def make_dataset(key, dictionary): ) input1 = make_dataset("input1", self.source_dictionary) - if self.args.init_token is not None: - input0 = PrependTokenDataset(input0, self.args.init_token) + if self.cfg.init_token is not None: + input0 = PrependTokenDataset(input0, self.cfg.init_token) if input1 is None: src_tokens = input0 else: - if self.args.separator_token is not None: - input1 = PrependTokenDataset(input1, self.args.separator_token) + if self.cfg.separator_token is not None: + input1 = PrependTokenDataset(input1, self.cfg.separator_token) src_tokens = ConcatSentencesDataset(input0, input1) - with data_utils.numpy_seed(self.args.seed): + with data_utils.numpy_seed(self.cfg.seed): shuffle = np.random.permutation(len(src_tokens)) src_tokens = maybe_shorten_dataset( src_tokens, split, - self.args.shorten_data_split_list, - self.args.shorten_method, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, self.max_positions(), - self.args.seed, + self.cfg.seed, ) dataset = { @@ -197,7 +194,7 @@ def make_dataset(key, dictionary): "ntokens": NumelDataset(src_tokens, reduce=True), } - if self.args.add_prev_output_tokens: + if self.cfg.add_prev_output_tokens: prev_tokens_dataset = RightPadDataset( RollDataset(src_tokens, 1), pad_idx=self.dictionary.pad(), @@ -206,7 +203,7 @@ def make_dataset(key, dictionary): prev_output_tokens=prev_tokens_dataset, ) - if not self.args.regression_target: + if not self.cfg.regression_target: label_dataset = make_dataset("label", self.label_dictionary) if label_dataset is not None: dataset.update( @@ -225,8 +222,8 @@ def make_dataset(key, dictionary): def parse_regression_target(i, line): values = line.split() assert ( - len(values) == self.args.num_classes - ), f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"' + 
len(values) == self.cfg.num_classes + ), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"' return [float(x) for x in values] with open(label_path) as h: @@ -244,7 +241,7 @@ def parse_regression_target(i, line): sizes=[src_tokens.sizes], ) - if self.args.no_shuffle: + if self.cfg.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( @@ -258,20 +255,23 @@ def parse_regression_target(i, line): self.datasets[split] = dataset return self.datasets[split] - def build_model(self, args): + def build_model(self, cfg): from fairseq import models - model = models.build_model(args, self) + with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack(): + cfg.max_positions = self.cfg.max_positions + + model = models.build_model(cfg, self) model.register_classification_head( - getattr(args, "classification_head_name", "sentence_classification_head"), - num_classes=self.args.num_classes, + self.cfg.classification_head_name, + num_classes=self.cfg.num_classes, ) return model def max_positions(self): - return self._max_positions + return self.cfg.max_positions @property def source_dictionary(self): From 605e1ceaa8be442726b0df9351dd28b979761d6e Mon Sep 17 00:00:00 2001 From: Pierce Chuang <pichuang@fb.com> Date: Thu, 8 Jul 2021 17:07:51 -0700 Subject: [PATCH 638/707] change denoising setup_task so that it can read from multiple shards Summary: Follow Roberta data handling to support | based data separation Reviewed By: myleott Differential Revision: D29619263 fbshipit-source-id: 6912df178965c1b0f859604c8f5ad8aff443198a --- fairseq/tasks/denoising.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index cbf01e14df..d1dff26c36 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -147,7 +147,9 @@ def __init__(self, args, dictionary): @classmethod def setup_task(cls, args, **kwargs): """Setup the task.""" - 
dictionary = Dictionary.load(os.path.join(args.data, "dict.txt")) + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) logger.info("dictionary: {} types".format(len(dictionary))) if not hasattr(args, "shuffle_instance"): args.shuffle_instance = False @@ -196,6 +198,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): break_mode=self.args.sample_break_mode, document_sep_len=0, ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) From d18e44a28994a1558e9f8dc988a23bd6f77a55d5 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Thu, 8 Jul 2021 19:09:58 -0700 Subject: [PATCH 639/707] add robust w2v model (#2046) Summary: add robust wav2vec model Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2046 Reviewed By: wnhsu Differential Revision: D29628639 Pulled By: alexeib fbshipit-source-id: 296cd2da579a969a71a0f9ffe1062002b73a8d86 --- README.md | 3 +++ examples/wav2vec/README.md | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 82b6ba7cd8..147714bae5 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,8 @@ We provide reference implementations of various sequence modeling papers: + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) + + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) + + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 
2021)](https://arxiv.org/abs/2104.01027) * **Non-autoregressive Transformers** + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -61,6 +63,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md) * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) * February 2021 [Added LASER training code](examples/laser/README.md) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 238639a9ba..c543b6b97b 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -6,6 +6,8 @@ We learned speech representations in multiple languages as well in [Unsupervised We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). 
+We combined speech data from multiple domains in [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027) + ## Pre-trained models Model | Finetuning split | Dataset | Model @@ -25,8 +27,10 @@ Wav2Vec 2.0 Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebo Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) +Wav2Vec 2.0 Large (LV-60 + CV + SWBD + FSH) ** | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) + [CommonVoice](https://commonvoice.mozilla.org/en/languages) + [Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) + [Fisher](https://catalog.ldc.upenn.edu/LDC2004T19) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv.pt) -\* updated (Oct. 24, 2020) +\* updated (Oct. 24, 2020)\ +** updated (Jul. 
8, 2021) We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: From cffc057b58d9d4b36260d30861f86b2ab8817ac1 Mon Sep 17 00:00:00 2001 From: Wei Ho <weiho@fb.com> Date: Fri, 9 Jul 2021 16:13:51 -0700 Subject: [PATCH 640/707] Roll back os.path.abspath change Reviewed By: donhusa Differential Revision: D29641968 fbshipit-source-id: eca379158055f3e38e9c053b06db56842265e53a --- fairseq/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 8ec967397f..35cce7fda7 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -105,7 +105,7 @@ def is_better(a, b): extra_state.update({"best": save_checkpoint.best}) checkpoints = [ - os.path.abspath(os.path.join(cfg.save_dir, fn)) for fn, cond in checkpoint_conds.items() if cond + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) From 7f2fb5caa872ce06f7cc5d95956f4eca2a6211fe Mon Sep 17 00:00:00 2001 From: Ann Lee <an918tw@users.noreply.github.com> Date: Fri, 9 Jul 2021 16:45:40 -0700 Subject: [PATCH 641/707] Release code for the paper "Discriminative Reranking for Neural Machine Translation" (#2044) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Release the code for the paper "Discriminative Reranking for Neural Machine Translation" ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. 
## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2044 Reviewed By: michaelauli Differential Revision: D29628590 Pulled By: an918tw fbshipit-source-id: 7a52602d495b736573187cc721829aa545d24770 --- README.md | 1 + .../discriminative_reranking_nmt/README.md | 200 ++++++++ .../discriminative_reranking_nmt/__init__.py | 1 + .../config/deen.yaml | 56 +++ .../criterions/__init__.py | 6 + .../discriminative_reranking_criterion.py | 138 +++++ .../drnmt_rerank.py | 364 ++++++++++++++ .../models/__init__.py | 6 + .../models/discriminative_reranking_model.py | 365 ++++++++++++++ .../scripts/prep_data.py | 136 +++++ .../tasks/__init__.py | 6 + .../tasks/discriminative_reranking_task.py | 475 ++++++++++++++++++ 12 files changed, 1754 insertions(+) create mode 100644 examples/discriminative_reranking_nmt/README.md create mode 100644 examples/discriminative_reranking_nmt/__init__.py create mode 100644 examples/discriminative_reranking_nmt/config/deen.yaml create mode 100644 examples/discriminative_reranking_nmt/criterions/__init__.py create mode 100644 examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py create mode 100644 examples/discriminative_reranking_nmt/drnmt_rerank.py create mode 100644 examples/discriminative_reranking_nmt/models/__init__.py create mode 100644 examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py create mode 100755 examples/discriminative_reranking_nmt/scripts/prep_data.py create mode 100644 examples/discriminative_reranking_nmt/tasks/__init__.py create mode 100644 examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py diff --git a/README.md b/README.md index 147714bae5..460f3439fb 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* July 2021 [Released DrNMT
code](examples/discriminative_reranking_nmt/README.md) * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md) * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) diff --git a/examples/discriminative_reranking_nmt/README.md b/examples/discriminative_reranking_nmt/README.md new file mode 100644 index 0000000000..aba0090370 --- /dev/null +++ b/examples/discriminative_reranking_nmt/README.md @@ -0,0 +1,200 @@ +# Discriminative Reranking for Neural Machine Translation +This folder contains source code for training DrNMT, a discriminatively trained reranker for neural machine translation. + +## Data preparation +1. Follow the instructions under `examples/translation` to build a base MT model. Prepare three files, one with source sentences, one with ground truth target sentences, and one with hypotheses generated from the base MT model. Each line in the file contains one sentence in raw text (i.e. no sentencepiece, etc.). Below is an example of the files with _N_ hypotheses for each source sentence. + +``` +# Example of the source sentence file: (The file should contain L lines.) + +source_sentence_1 +source_sentence_2 +source_sentence_3 +... +source_sentence_L + +# Example of the target sentence file: (The file should contain L lines.) + +target_sentence_1 +target_sentence_2 +target_sentence_3 +... +target_sentence_L + +# Example of the hypotheses file: (The file should contain L*N lines.) + +source_sentence_1_hypo_1 +source_sentence_1_hypo_2 +... +source_sentence_1_hypo_N +source_sentence_2_hypo_1 +... +source_sentence_2_hypo_N +... +source_sentence_L_hypo_1 +... +source_sentence_L_hypo_N +``` + +2. Download the [XLMR model](https://github.com/fairinternal/fairseq-py/tree/master/examples/xlmr#pre-trained-models). 
+``` +wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz +tar zxvf xlmr.base.tar.gz + +# The folder should contain dict.txt, model.pt and sentencepiece.bpe.model. +``` + +3. Prepare scores and BPE data. +* `N`: Number of hypotheses per source sentence. We use 50 in the paper. +* `SPLIT`: Name of the data split, e.g. train, valid, test. Use split_name, split_name1, split_name2, ..., if there are multiple datasets for a split, e.g. train, train1, valid, valid1. +* `NUM_SHARDS`: Number of shards. Set this to 1 for non-train splits. +* `METRIC`: The metric for DrNMT to optimize for. We support either `bleu` or `ter`. +``` +# For each data split, e.g. train, valid, test, etc., run the following: + +SOURCE_FILE=/path/to/source_sentence_file +TARGET_FILE=/path/to/target_sentence_file +HYPO_FILE=/path/to/hypo_file +XLMR_DIR=/path/to/xlmr +OUTPUT_DIR=/path/to/output + +python scripts/prep_data.py \ + --input-source ${SOURCE_FILE} \ + --input-target ${TARGET_FILE} \ + --input-hypo ${HYPO_FILE} \ + --output-dir ${OUTPUT_DIR} \ + --split $SPLIT \ + --beam $N \ + --sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \ + --metric $METRIC \ + --num-shards ${NUM_SHARDS} + +# The script will create ${OUTPUT_DIR}/$METRIC with ${NUM_SHARDS} splits. +# Under split*/input_src, split*/input_tgt and split*/$METRIC, there will be $SPLIT.bpe and $SPLIT.$METRIC files, respectively. + +``` + +4. Pre-process the data into fairseq format.
+``` +# separate with commas if there is more than one train or valid set +for suffix in src tgt ; do + fairseq-preprocess --only-source \ + --trainpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/train.bpe \ + --validpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid.bpe \ + --destdir ${OUTPUT_DIR}/$METRIC/split1/input_${suffix} \ + --workers 60 \ + --srcdict ${XLMR_DIR}/dict.txt +done + +for i in `seq 2 ${NUM_SHARDS}`; do + for suffix in src tgt ; do + fairseq-preprocess --only-source \ + --trainpref ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/train.bpe \ + --destdir ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix} \ + --workers 60 \ + --srcdict ${XLMR_DIR}/dict.txt + + ln -s ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid* ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/. + done + + ln -s ${OUTPUT_DIR}/$METRIC/split1/$METRIC/valid* ${OUTPUT_DIR}/$METRIC/split${i}/$METRIC/. +done +``` + +## Training + +``` +EXP_DIR=/path/to/exp + +# An example of training the model with the config for the De-En experiment in the paper. +# The config uses 16 GPUs and 50 hypotheses. +# For training with fewer GPUs, set +# distributed_training.distributed_world_size=k +optimization.update_freq='[x]' where x = 16/k +# For training with fewer hypotheses, set +# task.mt_beam=N dataset.batch_size=N dataset.required_batch_size_multiple=N + +fairseq-hydra-train -m \ + --config-dir config/ --config-name deen \ + task.data=${OUTPUT_DIR}/$METRIC/split1/ \ + task.num_data_splits=${NUM_SHARDS} \ + model.pretrained_model=${XLMR_DIR}/model.pt \ + common.user_dir=${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \ + checkpoint.save_dir=${EXP_DIR} + +``` + +## Inference & scoring +Perform DrNMT reranking (fw + reranker score) +1. Tune weights on valid sets.
+``` +# generate N hypotheses with the base MT model (fw score) +VALID_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model +VALID_TARGET_FILE=/path/to/target_sentences # one sentence per line in raw text, i.e. no sentencepiece and tokenization +MT_MODEL=/path/to/mt_model +MT_DATA_PATH=/path/to/mt_data + +cat ${VALID_SOURCE_FILE} | \ + fairseq-interactive ${MT_DATA_PATH} \ + --max-tokens 4000 --buffer-size 16 \ + --num-workers 32 --path ${MT_MODEL} \ + --beam $N --nbest $N \ + --post-process sentencepiece &> valid-hypo.out + +# replace "bleu" with "ter" to optimize for TER +python drnmt_rerank.py \ + ${OUTPUT_DIR}/$METRIC/split1/ \ + --path ${EXP_DIR}/checkpoint_best.pt \ + --in-text valid-hypo.out \ + --results-path ${EXP_DIR} \ + --gen-subset valid \ + --target-text ${VALID_TARGET_FILE} \ + --user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \ + --bpe sentencepiece \ + --sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \ + --beam $N \ + --batch-size $N \ + --metric bleu \ + --tune + +``` + +2. Apply best weights on test sets. +``` +# generate N hypotheses with the base MT model (fw score) +TEST_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model + +cat ${TEST_SOURCE_FILE} | \ + fairseq-interactive ${MT_DATA_PATH} \ + --max-tokens 4000 --buffer-size 16 \ + --num-workers 32 --path ${MT_MODEL} \ + --beam $N --nbest $N \ + --post-process sentencepiece &> test-hypo.out + +# replace "bleu" with "ter" to evaluate TER +# Add --target-text for evaluating BLEU/TER, +# otherwise the script will only output the hypotheses with the highest scores.
+python drnmt_rerank.py \ + ${OUTPUT_DIR}/$METRIC/split1/ \ + --path ${EXP_DIR}/checkpoint_best.pt \ + --in-text test-hypo.out \ + --results-path ${EXP_DIR} \ + --gen-subset test \ + --user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \ + --bpe sentencepiece \ + --sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \ + --beam $N \ + --batch-size $N \ + --metric bleu \ + --fw-weight ${BEST_FW_WEIGHT} \ + --lenpen ${BEST_LENPEN} +``` + +## Citation +```bibtex +@inproceedings{lee2021discriminative, + title={Discriminative Reranking for Neural Machine Translation}, + author={Lee, Ann and Auli, Michael and Ranzato, Marc'Aurelio}, + booktitle={ACL}, + year={2021} +} +``` diff --git a/examples/discriminative_reranking_nmt/__init__.py b/examples/discriminative_reranking_nmt/__init__.py new file mode 100644 index 0000000000..0278f6a273 --- /dev/null +++ b/examples/discriminative_reranking_nmt/__init__.py @@ -0,0 +1 @@ +from . import criterions, models, tasks # noqa diff --git a/examples/discriminative_reranking_nmt/config/deen.yaml b/examples/discriminative_reranking_nmt/config/deen.yaml new file mode 100644 index 0000000000..3fc2d5fcf5 --- /dev/null +++ b/examples/discriminative_reranking_nmt/config/deen.yaml @@ -0,0 +1,56 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 50 + seed: 2 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: bleu + maximize_best_checkpoint_metric: true + +task: + _name: discriminative_reranking_nmt + data: ??? + num_data_splits: ??? + include_src: true + mt_beam: 50 + eval_target_metric: true + target_metric: bleu + +dataset: + batch_size: 50 + num_workers: 6 + required_batch_size_multiple: 50 + valid_subset: ??? 
+ +criterion: + _name: kl_divergence_rereanking + target_dist_norm: minmax + temperature: 0.5 + +optimization: + max_epoch: 200 + lr: [0.00005] + update_freq: [32] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 8000 + total_num_update: 320000 + +model: + _name: discriminative_nmt_reranker + pretrained_model: ??? + classifier_dropout: 0.2 + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 16 diff --git a/examples/discriminative_reranking_nmt/criterions/__init__.py b/examples/discriminative_reranking_nmt/criterions/__init__.py new file mode 100644 index 0000000000..7c257c2700 --- /dev/null +++ b/examples/discriminative_reranking_nmt/criterions/__init__.py @@ -0,0 +1,6 @@ +from .discriminative_reranking_criterion import KLDivergenceRerankingCriterion + + +__all__ = [ + "KLDivergenceRerankingCriterion", +] diff --git a/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py b/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py new file mode 100644 index 0000000000..0b02ce1877 --- /dev/null +++ b/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import ChoiceEnum, FairseqDataclass + + +_EPSILON = torch.finfo(torch.float32).eps +TARGET_DIST_NORM_CHOICES = ChoiceEnum(["none", "minmax"]) + + +@dataclass +class KLDivergenceRerankingCriterionConfig(FairseqDataclass): + target_dist_norm: TARGET_DIST_NORM_CHOICES = field( + default="none", + metadata={"help": "method to normalize the range of target scores"}, + ) + temperature: float = field( + default=1.0, + metadata={"help": "temperature in softmax for target distributions"}, + ) + forward_batch_size: int = field( + default=32, + metadata={ + "help": "number of hypotheses per batch for model forward (set a value smaller than --mt-beam to avoid OOM when training with a large beam size)" + }, + ) + + +@register_criterion( + "kl_divergence_rereanking", dataclass=KLDivergenceRerankingCriterionConfig +) +class KLDivergenceRerankingCriterion(FairseqCriterion): + def __init__( + self, task, target_dist_norm, temperature, forward_batch_size, + ): + super().__init__(task) + self.target_dist_norm = target_dist_norm + self.temperature = temperature + self.forward_batch_size = forward_batch_size + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + + sample_size = sample["id"].numel() + assert sample_size % self.task.cfg.mt_beam == 0, ( + f"sample_size ({sample_size}) cannot be divided by beam size ({self.task.cfg.mt_beam}). " + f"Please set --required-batch-size-multiple={self.task.cfg.mt_beam}."
+ ) + + # split into smaller batches for model forward + batch_out = [] + for i in range(0, sample_size, self.forward_batch_size): + j = min(i + self.forward_batch_size, sample_size) + + out = model( + src_tokens=sample["net_input"]["src_tokens"][i:j, :], + src_lengths=sample["net_input"]["src_lengths"][i:j], + ) + + batch_out.append( + model.sentence_forward(out, sample["net_input"]["src_tokens"][i:j, :]) + ) + + batch_out = torch.cat(batch_out, dim=0).view( + self.task.cfg.mt_beam, sample_size // self.task.cfg.mt_beam, -1 + ) # T x B x C + if model.joint_classification == "sent": + batch_out = model.joint_forward(batch_out) + scores = model.classification_forward(batch_out.view(sample_size, 1, -1)).view( + -1, self.task.cfg.mt_beam + ) # input: B x T x C + + loss = self.compute_kl_loss( + scores, sample["target"][:, 0].view(-1, self.task.cfg.mt_beam) + ) + + sample_size = sample_size // self.task.cfg.mt_beam + + logging_output = { + "loss": loss.detach(), + "ntokens": sample["ntokens"], + "nsentences": sample_size * self.task.cfg.mt_beam, + "sample_size": sample_size, + "scores": scores.detach(), + } + + return loss, sample_size, logging_output + + def compute_kl_loss(self, logits, target): + norm_target = target + if self.target_dist_norm == "minmax": + min_v = torch.min(target, 1, keepdim=True).values + max_v = torch.max(target, 1, keepdim=True).values + norm_target = (target - min_v) / (max_v - min_v + _EPSILON) + + target_dist = F.softmax( + norm_target / self.temperature, dim=-1, dtype=torch.float32 + ) + model_dist = F.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = -(target_dist * model_dist - target_dist * target_dist.log()).sum() + return loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + loss 
= loss_sum / sample_size / math.log(2) + metrics.log_scalar("loss", loss, sample_size, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/examples/discriminative_reranking_nmt/drnmt_rerank.py b/examples/discriminative_reranking_nmt/drnmt_rerank.py new file mode 100644 index 0000000000..2e0fc2bd29 --- /dev/null +++ b/examples/discriminative_reranking_nmt/drnmt_rerank.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Score raw text with a trained model. +""" + +from collections import namedtuple +import logging +from multiprocessing import Pool +import sys +import os +import random + +import numpy as np +import sacrebleu +import torch + +from fairseq import checkpoint_utils, options, utils + + +logger = logging.getLogger("fairseq_cli.drnmt_rerank") +logger.setLevel(logging.INFO) + +Batch = namedtuple("Batch", "ids src_tokens src_lengths") + + +pool_init_variables = {} + + +def init_loaded_scores(mt_scores, model_scores, hyp, ref): + global pool_init_variables + pool_init_variables["mt_scores"] = mt_scores + pool_init_variables["model_scores"] = model_scores + pool_init_variables["hyp"] = hyp + pool_init_variables["ref"] = ref + + +def parse_fairseq_gen(filename, task): + source = {} + hypos = {} + scores = {} + with open(filename, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line.startswith("S-"): # source + uid, text = line.split("\t", 1) + uid = int(uid[2:]) + source[uid] = text + elif line.startswith("D-"): # hypo + uid, score, text = line.split("\t", 2) + uid = int(uid[2:]) + if uid not in hypos: + hypos[uid] = [] 
+ scores[uid] = [] + hypos[uid].append(text) + scores[uid].append(float(score)) + else: + continue + + source_out = [source[i] for i in range(len(hypos))] + hypos_out = [h for i in range(len(hypos)) for h in hypos[i]] + scores_out = [s for i in range(len(scores)) for s in scores[i]] + + return source_out, hypos_out, scores_out + + +def read_target(filename): + with open(filename, "r", encoding="utf-8") as f: + output = [line.strip() for line in f] + return output + + +def make_batches(args, src, hyp, task, max_positions, encode_fn): + assert len(src) * args.beam == len( + hyp + ), f"Expect {len(src) * args.beam} hypotheses for {len(src)} source sentences with beam size {args.beam}. Got {len(hyp)} hypotheses instead." + hyp_encode = [ + task.source_dictionary.encode_line(encode_fn(h), add_if_not_exist=False).long() + for h in hyp + ] + if task.cfg.include_src: + src_encode = [ + task.source_dictionary.encode_line( + encode_fn(s), add_if_not_exist=False + ).long() + for s in src + ] + tokens = [(src_encode[i // args.beam], h) for i, h in enumerate(hyp_encode)] + lengths = [(t1.numel(), t2.numel()) for t1, t2 in tokens] + else: + tokens = [(h,) for h in hyp_encode] + lengths = [(h.numel(),) for h in hyp_encode] + + itr = task.get_batch_iterator( + dataset=task.build_dataset_for_inference(tokens, lengths), + max_tokens=args.max_tokens, + max_sentences=args.batch_size, + max_positions=max_positions, + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + ).next_epoch_itr(shuffle=False) + + for batch in itr: + yield Batch( + ids=batch["id"], + src_tokens=batch["net_input"]["src_tokens"], + src_lengths=batch["net_input"]["src_lengths"], + ) + + +def decode_rerank_scores(args): + if args.max_tokens is None and args.batch_size is None: + args.batch_size = 1 + + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load ensemble + logger.info("loading model(s) from {}".format(args.path)) + models, _model_args, task =
checkpoint_utils.load_model_ensemble_and_task( + [args.path], arg_overrides=eval(args.model_overrides), + ) + + for model in models: + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + # Initialize generator + generator = task.build_generator(args) + + # Handle tokenization and BPE + tokenizer = task.build_tokenizer(args) + bpe = task.build_bpe(args) + + def encode_fn(x): + if tokenizer is not None: + x = tokenizer.encode(x) + if bpe is not None: + x = bpe.encode(x) + return x + + max_positions = utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ) + + src, hyp, mt_scores = parse_fairseq_gen(args.in_text, task) + model_scores = {} + logger.info("decode reranker score") + for batch in make_batches(args, src, hyp, task, max_positions, encode_fn): + src_tokens = batch.src_tokens + src_lengths = batch.src_lengths + if use_cuda: + src_tokens = src_tokens.cuda() + src_lengths = src_lengths.cuda() + + sample = { + "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}, + } + scores = task.inference_step(generator, models, sample) + + for id, sc in zip(batch.ids.tolist(), scores.tolist()): + model_scores[id] = sc[0] + + model_scores = [model_scores[i] for i in range(len(model_scores))] + + return src, hyp, mt_scores, model_scores + + +def get_score(mt_s, md_s, w1, lp, tgt_len): + return mt_s / (tgt_len ** lp) * w1 + md_s + + +def get_best_hyps(mt_scores, md_scores, hypos, fw_weight, lenpen, beam): + assert len(mt_scores) == len(md_scores) and len(mt_scores) == len(hypos) + hypo_scores = [] + best_hypos = [] + best_scores = [] + offset = 0 + for i in range(len(hypos)): + tgt_len = len(hypos[i].split()) + hypo_scores.append( + get_score(mt_scores[i], md_scores[i], fw_weight, lenpen, tgt_len) + ) + + if (i + 1) % beam == 0: + max_i = np.argmax(hypo_scores) + best_hypos.append(hypos[offset + max_i]) + best_scores.append(hypo_scores[max_i]) + hypo_scores = [] + offset += beam + return best_hypos, 
best_scores + + +def eval_metric(args, hypos, ref): + if args.metric == "bleu": + score = sacrebleu.corpus_bleu(hypos, [ref]).score + else: + score = sacrebleu.corpus_ter(hypos, [ref]).score + + return score + + +def score_target_hypo(args, fw_weight, lp): + mt_scores = pool_init_variables["mt_scores"] + model_scores = pool_init_variables["model_scores"] + hyp = pool_init_variables["hyp"] + ref = pool_init_variables["ref"] + best_hypos, _ = get_best_hyps( + mt_scores, model_scores, hyp, fw_weight, lp, args.beam + ) + rerank_eval = None + if ref: + rerank_eval = eval_metric(args, best_hypos, ref) + print(f"fw_weight {fw_weight}, lenpen {lp}, eval {rerank_eval}") + + return rerank_eval + + +def print_result(best_scores, best_hypos, output_file): + for i, (s, h) in enumerate(zip(best_scores, best_hypos)): + print(f"{i}\t{s}\t{h}", file=output_file) + + +def main(args): + utils.import_user_module(args) + + src, hyp, mt_scores, model_scores = decode_rerank_scores(args) + + assert ( + not args.tune or args.target_text is not None + ), "--target-text has to be set when tuning weights" + if args.target_text: + ref = read_target(args.target_text) + assert len(src) == len( + ref + ), f"different numbers of source and target sentences ({len(src)} vs. 
{len(ref)})" + + orig_best_hypos = [hyp[i] for i in range(0, len(hyp), args.beam)] + orig_eval = eval_metric(args, orig_best_hypos, ref) + + if args.tune: + logger.info("tune weights for reranking") + + random_params = np.array( + [ + [ + random.uniform( + args.lower_bound_fw_weight, args.upper_bound_fw_weight + ), + random.uniform(args.lower_bound_lenpen, args.upper_bound_lenpen), + ] + for k in range(args.num_trials) + ] + ) + + logger.info("launching pool") + with Pool( + 32, + initializer=init_loaded_scores, + initargs=(mt_scores, model_scores, hyp, ref), + ) as p: + rerank_scores = p.starmap( + score_target_hypo, + [ + (args, random_params[i][0], random_params[i][1],) + for i in range(args.num_trials) + ], + ) + if args.metric == "bleu": + best_index = np.argmax(rerank_scores) + else: + best_index = np.argmin(rerank_scores) + best_fw_weight = random_params[best_index][0] + best_lenpen = random_params[best_index][1] + else: + assert ( + args.lenpen is not None and args.fw_weight is not None + ), "--lenpen and --fw-weight should be set" + best_fw_weight, best_lenpen = args.fw_weight, args.lenpen + + best_hypos, best_scores = get_best_hyps( + mt_scores, model_scores, hyp, best_fw_weight, best_lenpen, args.beam + ) + + if args.results_path is not None: + os.makedirs(args.results_path, exist_ok=True) + output_path = os.path.join( + args.results_path, "generate-{}.txt".format(args.gen_subset), + ) + with open(output_path, "w", buffering=1, encoding="utf-8") as o: + print_result(best_scores, best_hypos, o) + else: + print_result(best_scores, best_hypos, sys.stdout) + + if args.target_text: + rerank_eval = eval_metric(args, best_hypos, ref) + print(f"before reranking, {args.metric.upper()}:", orig_eval) + print( + f"after reranking with fw_weight={best_fw_weight}, lenpen={best_lenpen}, {args.metric.upper()}:", + rerank_eval, + ) + + +def cli_main(): + parser = options.get_generation_parser(interactive=True) + + parser.add_argument( + "--in-text", + default=None, + 
required=True, + help="text from fairseq-interactive output, containing source sentences and hypotheses", + ) + parser.add_argument("--target-text", default=None, help="reference text") + parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu") + parser.add_argument( + "--tune", + action="store_true", + help="if set, tune weights on fw scores and lenpen instead of applying fixed weights for reranking", + ) + parser.add_argument( + "--lower-bound-fw-weight", + default=0.0, + type=float, + help="lower bound of search space", + ) + parser.add_argument( + "--upper-bound-fw-weight", + default=3, + type=float, + help="upper bound of search space", + ) + parser.add_argument( + "--lower-bound-lenpen", + default=0.0, + type=float, + help="lower bound of search space", + ) + parser.add_argument( + "--upper-bound-lenpen", + default=3, + type=float, + help="upper bound of search space", + ) + parser.add_argument( + "--fw-weight", type=float, default=None, help="weight on the fw model score" + ) + parser.add_argument( + "--num-trials", + default=1000, + type=int, + help="number of trials to do for random search", + ) + + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/examples/discriminative_reranking_nmt/models/__init__.py b/examples/discriminative_reranking_nmt/models/__init__.py new file mode 100644 index 0000000000..c593ea5f18 --- /dev/null +++ b/examples/discriminative_reranking_nmt/models/__init__.py @@ -0,0 +1,6 @@ +from .discriminative_reranking_model import DiscriminativeNMTReranker + + +__all__ = [ + "DiscriminativeNMTReranker", +] diff --git a/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py b/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py new file mode 100644 index 0000000000..e4b5887f82 --- /dev/null +++ b/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py @@ -0,0 +1,365 @@ +from 
dataclasses import dataclass, field +import os + +import torch +import torch.nn as nn + +from fairseq import utils +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import ( + BaseFairseqModel, + register_model, +) + +from fairseq.models.roberta.model import RobertaClassificationHead + +from fairseq.modules import ( + LayerNorm, + TransformerSentenceEncoder, + TransformerSentenceEncoderLayer, +) + + +ACTIVATION_FN_CHOICES = ChoiceEnum(utils.get_available_activation_fns()) +JOINT_CLASSIFICATION_CHOICES = ChoiceEnum(["none", "sent"]) +SENTENCE_REP_CHOICES = ChoiceEnum(["head", "meanpool", "maxpool"]) + + +def update_init_roberta_model_state(state): + """ + update the state_dict of a Roberta model for initializing + weights of the BertRanker + """ + for k in list(state.keys()): + if ".lm_head." in k or "version" in k: + del state[k] + continue + # remove 'encoder/decoder.sentence_encoder.' from the key + assert k.startswith("encoder.sentence_encoder.") or k.startswith( + "decoder.sentence_encoder." 
+ ), f"Cannot recognize parameter name {k}" + if "layernorm_embedding" in k: + new_k = k.replace(".layernorm_embedding.", ".emb_layer_norm.") + state[new_k[25:]] = state[k] + else: + state[k[25:]] = state[k] + del state[k] + + +class BaseRanker(nn.Module): + def __init__(self, args, task): + super().__init__() + + self.separator_token = task.dictionary.eos() + self.padding_idx = task.dictionary.pad() + + def forward(self, src_tokens): + raise NotImplementedError + + def get_segment_labels(self, src_tokens): + segment_boundary = (src_tokens == self.separator_token).long() + segment_labels = ( + segment_boundary.cumsum(dim=1) + - segment_boundary + - (src_tokens == self.padding_idx).long() + ) + + return segment_labels + + def get_positions(self, src_tokens, segment_labels): + segment_positions = ( + torch.arange(src_tokens.shape[1]) + .to(src_tokens.device) + .repeat(src_tokens.shape[0], 1) + ) + segment_boundary = (src_tokens == self.separator_token).long() + _, col_idx = (segment_positions * segment_boundary).nonzero(as_tuple=True) + col_idx = torch.cat([torch.zeros(1).type_as(col_idx), col_idx]) + offset = torch.cat( + [ + torch.zeros(1).type_as(segment_boundary), + segment_boundary.sum(dim=1).cumsum(dim=0)[:-1], + ] + ) + segment_positions -= col_idx[segment_labels + offset.unsqueeze(1)] * ( + segment_labels != 0 + ) + + padding_mask = src_tokens.ne(self.padding_idx) + segment_positions = (segment_positions + 1) * padding_mask.type_as( + segment_positions + ) + self.padding_idx + + return segment_positions + + +class BertRanker(BaseRanker): + def __init__(self, args, task): + super(BertRanker, self).__init__(args, task) + + init_model = getattr(args, "pretrained_model", "") + self.joint_layers = nn.ModuleList() + if os.path.isfile(init_model): + print(f"initialize weight from {init_model}") + + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + os.path.dirname(init_model), + checkpoint_file=os.path.basename(init_model), + ) + + in_state_dict = 
x["models"][0].state_dict() + init_args = x["args"].model + + num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1 + + # follow the setup in roberta + self.model = TransformerSentenceEncoder( + padding_idx=task.dictionary.pad(), + vocab_size=len(task.dictionary), + num_encoder_layers=getattr( + args, "encoder_layers", init_args.encoder_layers + ), + embedding_dim=init_args.encoder_embed_dim, + ffn_embedding_dim=init_args.encoder_ffn_embed_dim, + num_attention_heads=init_args.encoder_attention_heads, + dropout=init_args.dropout, + attention_dropout=init_args.attention_dropout, + activation_dropout=init_args.activation_dropout, + num_segments=2, # add language embeddings + max_seq_len=num_positional_emb, + offset_positions_by_padding=False, + encoder_normalize_before=True, + apply_bert_init=True, + activation_fn=init_args.activation_fn, + freeze_embeddings=args.freeze_embeddings, + n_trans_layers_to_freeze=args.n_trans_layers_to_freeze, + ) + + # still need to learn segment embeddings as we added a second language embedding + if args.freeze_embeddings: + for p in self.model.segment_embeddings.parameters(): + p.requires_grad = False + + update_init_roberta_model_state(in_state_dict) + print("loading weights from the pretrained model") + self.model.load_state_dict( + in_state_dict, strict=False + ) # ignore mismatch in language embeddings + + ffn_embedding_dim = init_args.encoder_ffn_embed_dim + num_attention_heads = init_args.encoder_attention_heads + dropout = init_args.dropout + attention_dropout = init_args.attention_dropout + activation_dropout = init_args.activation_dropout + activation_fn = init_args.activation_fn + + classifier_embed_dim = getattr( + args, "embed_dim", init_args.encoder_embed_dim + ) + if classifier_embed_dim != init_args.encoder_embed_dim: + self.transform_layer = nn.Linear( + init_args.encoder_embed_dim, classifier_embed_dim + ) + else: + self.model = TransformerSentenceEncoder( + padding_idx=task.dictionary.pad(), + 
vocab_size=len(task.dictionary), + num_encoder_layers=args.encoder_layers, + embedding_dim=args.embed_dim, + ffn_embedding_dim=args.ffn_embed_dim, + num_attention_heads=args.attention_heads, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + max_seq_len=task.max_positions() + if task.max_positions() + else args.tokens_per_sample, + num_segments=2, + offset_positions_by_padding=False, + encoder_normalize_before=args.encoder_normalize_before, + apply_bert_init=args.apply_bert_init, + activation_fn=args.activation_fn, + ) + + classifier_embed_dim = args.embed_dim + ffn_embedding_dim = args.ffn_embed_dim + num_attention_heads = args.attention_heads + dropout = args.dropout + attention_dropout = args.attention_dropout + activation_dropout = args.activation_dropout + activation_fn = args.activation_fn + + self.joint_classification = args.joint_classification + if args.joint_classification == "sent": + if args.joint_normalize_before: + self.joint_layer_norm = LayerNorm(classifier_embed_dim) + else: + self.joint_layer_norm = None + + self.joint_layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=classifier_embed_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + ) + for _ in range(args.num_joint_layers) + ] + ) + + self.classifier = RobertaClassificationHead( + classifier_embed_dim, + classifier_embed_dim, + 1, # num_classes + "tanh", + args.classifier_dropout, + ) + + def forward(self, src_tokens, src_lengths): + segment_labels = self.get_segment_labels(src_tokens) + positions = self.get_positions(src_tokens, segment_labels) + + inner_states, _ = self.model( + tokens=src_tokens, + segment_labels=segment_labels, + last_state_only=True, + positions=positions, + ) + + return inner_states[-1].transpose(0, 1) # T x B x C 
-> B x T x C + + def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"): + # encoder_out: B x T x C + if sentence_rep == "head": + x = encoder_out[:, :1, :] + else: # 'meanpool', 'maxpool' + assert src_tokens is not None, "meanpool/maxpool requires src_tokens input" + segment_labels = self.get_segment_labels(src_tokens) + padding_mask = src_tokens.ne(self.padding_idx) + encoder_mask = segment_labels * padding_mask.type_as(segment_labels) + + if sentence_rep == "meanpool": + ntokens = torch.sum(encoder_mask, dim=1, keepdim=True) + x = torch.sum( + encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True + ) / ntokens.unsqueeze(2).type_as(encoder_out) + else: # 'maxpool' + encoder_out[ + (encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1]) + ] = -float("inf") + x, _ = torch.max(encoder_out, dim=1, keepdim=True) + + if hasattr(self, "transform_layer"): + x = self.transform_layer(x) + + return x # B x 1 x C + + def joint_forward(self, x): + # x: T x B x C + if self.joint_layer_norm: + x = self.joint_layer_norm(x.transpose(0, 1)) + x = x.transpose(0, 1) + + for layer in self.joint_layers: + x, _ = layer(x, self_attn_padding_mask=None) + return x + + def classification_forward(self, x): + # x: B x T x C + return self.classifier(x) + + +@dataclass +class DiscriminativeNMTRerankerConfig(FairseqDataclass): + pretrained_model: str = field( + default="", metadata={"help": "pretrained model to load"} + ) + sentence_rep: SENTENCE_REP_CHOICES = field( + default="head", + metadata={ + "help": "method to transform the output of the transformer stack to a sentence-level representation" + }, + ) + + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + attention_dropout: float = field( + default=0.0, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN"} + ) + classifier_dropout: float = field( + 
default=0.0, metadata={"help": "classifier dropout probability"} + ) + embed_dim: int = field(default=768, metadata={"help": "embedding dimension"}) + ffn_embed_dim: int = field( + default=2048, metadata={"help": "embedding dimension for FFN"} + ) + encoder_layers: int = field(default=12, metadata={"help": "num encoder layers"}) + attention_heads: int = field(default=8, metadata={"help": "num attention heads"}) + encoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each encoder block"} + ) + apply_bert_init: bool = field( + default=False, metadata={"help": "use custom param initialization for BERT"} + ) + activation_fn: ACTIVATION_FN_CHOICES = field( + default="relu", metadata={"help": "activation function to use"} + ) + freeze_embeddings: bool = field( + default=False, metadata={"help": "freeze embeddings in the pretrained model"} + ) + n_trans_layers_to_freeze: int = field( + default=0, + metadata={ + "help": "number of layers to freeze in the pretrained transformer model" + }, + ) + + # joint classification + joint_classification: JOINT_CLASSIFICATION_CHOICES = field( + default="none", + metadata={"help": "method to compute joint features for classification"}, + ) + num_joint_layers: int = field( + default=1, metadata={"help": "number of joint layers"} + ) + joint_normalize_before: bool = field( + default=False, + metadata={"help": "apply layer norm on the input to the joint layer"}, + ) + + +@register_model( + "discriminative_nmt_reranker", dataclass=DiscriminativeNMTRerankerConfig +) +class DiscriminativeNMTReranker(BaseFairseqModel): + @classmethod + def build_model(cls, args, task): + model = BertRanker(args, task) + return DiscriminativeNMTReranker(args, model) + + def __init__(self, args, model): + super().__init__() + + self.model = model + self.sentence_rep = args.sentence_rep + self.joint_classification = args.joint_classification + + def forward(self, src_tokens, src_lengths, **kwargs): + return 
self.model(src_tokens, src_lengths) + + def sentence_forward(self, encoder_out, src_tokens): + return self.model.sentence_forward(encoder_out, src_tokens, self.sentence_rep) + + def joint_forward(self, x): + return self.model.joint_forward(x) + + def classification_forward(self, x): + return self.model.classification_forward(x) diff --git a/examples/discriminative_reranking_nmt/scripts/prep_data.py b/examples/discriminative_reranking_nmt/scripts/prep_data.py new file mode 100755 index 0000000000..7aa7d37edc --- /dev/null +++ b/examples/discriminative_reranking_nmt/scripts/prep_data.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python + +import argparse +from multiprocessing import Pool +from pathlib import Path + +import sacrebleu +import sentencepiece as spm + + +def read_text_file(filename): + with open(filename, "r") as f: + output = [line.strip() for line in f] + + return output + + +def get_bleu(in_sent, target_sent): + bleu = sacrebleu.corpus_bleu([in_sent], [[target_sent]]) + out = " ".join( + map(str, [bleu.score, bleu.sys_len, bleu.ref_len] + bleu.counts + bleu.totals) + ) + return out + + +def get_ter(in_sent, target_sent): + ter = sacrebleu.corpus_ter([in_sent], [[target_sent]]) + out = " ".join(map(str, [ter.score, ter.num_edits, ter.ref_length])) + return out + + +def init(sp_model): + global sp + sp = spm.SentencePieceProcessor() + sp.Load(sp_model) + + +def process(source_sent, target_sent, hypo_sent, metric): + source_bpe = " ".join(sp.EncodeAsPieces(source_sent)) + hypo_bpe = [" ".join(sp.EncodeAsPieces(h)) for h in hypo_sent] + + if metric == "bleu": + score_str = [get_bleu(h, target_sent) for h in hypo_sent] + else: # ter + score_str = [get_ter(h, target_sent) for h in hypo_sent] + + return source_bpe, hypo_bpe, score_str + + +def main(args): + assert ( + args.split.startswith("train") or args.num_shards == 1 + ), "--num-shards should be set to 1 for valid and test sets" + assert ( + args.split.startswith("train") + or args.split.startswith("valid") + or 
args.split.startswith("test") + ), "--split should be set to train[n]/valid[n]/test[n]" + + source_sents = read_text_file(args.input_source) + target_sents = read_text_file(args.input_target) + + num_sents = len(source_sents) + assert num_sents == len( + target_sents + ), f"{args.input_source} and {args.input_target} should have the same number of sentences." + + hypo_sents = read_text_file(args.input_hypo) + assert ( + len(hypo_sents) % args.beam == 0 + ), f"Number of hypotheses ({len(hypo_sents)}) cannot be divided by beam size ({args.beam})." + + hypo_sents = [ + hypo_sents[i : i + args.beam] for i in range(0, len(hypo_sents), args.beam) + ] + assert num_sents == len( + hypo_sents + ), f"{args.input_hypo} should contain {num_sents * args.beam} hypotheses but only has {len(hypo_sents) * args.beam}. (--beam={args.beam})" + + output_dir = args.output_dir / args.metric + for ns in range(args.num_shards): + print(f"processing shard {ns+1}/{args.num_shards}") + shard_output_dir = output_dir / f"split{ns+1}" + source_output_dir = shard_output_dir / "input_src" + hypo_output_dir = shard_output_dir / "input_tgt" + metric_output_dir = shard_output_dir / args.metric + + source_output_dir.mkdir(parents=True, exist_ok=True) + hypo_output_dir.mkdir(parents=True, exist_ok=True) + metric_output_dir.mkdir(parents=True, exist_ok=True) + + if args.n_proc > 1: + with Pool( + args.n_proc, initializer=init, initargs=(args.sentencepiece_model,) + ) as p: + output = p.starmap( + process, + [ + (source_sents[i], target_sents[i], hypo_sents[i], args.metric) + for i in range(ns, num_sents, args.num_shards) + ], + ) + else: + init(args.sentencepiece_model) + output = [ + process(source_sents[i], target_sents[i], hypo_sents[i], args.metric) + for i in range(ns, num_sents, args.num_shards) + ] + + with open(source_output_dir / f"{args.split}.bpe", "w") as s_o, open( + hypo_output_dir / f"{args.split}.bpe", "w" + ) as h_o, open(metric_output_dir / f"{args.split}.{args.metric}", "w") as m_o: + 
for source_bpe, hypo_bpe, score_str in output: + assert len(hypo_bpe) == len(score_str) + for h, m in zip(hypo_bpe, score_str): + s_o.write(f"{source_bpe}\n") + h_o.write(f"{h}\n") + m_o.write(f"{m}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-source", type=Path, required=True) + parser.add_argument("--input-target", type=Path, required=True) + parser.add_argument("--input-hypo", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--split", type=str, required=True) + parser.add_argument("--beam", type=int, required=True) + parser.add_argument("--sentencepiece-model", type=str, required=True) + parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu") + parser.add_argument("--num-shards", type=int, default=1) + parser.add_argument("--n-proc", type=int, default=8) + + args = parser.parse_args() + + main(args) diff --git a/examples/discriminative_reranking_nmt/tasks/__init__.py b/examples/discriminative_reranking_nmt/tasks/__init__.py new file mode 100644 index 0000000000..2d78ca9870 --- /dev/null +++ b/examples/discriminative_reranking_nmt/tasks/__init__.py @@ -0,0 +1,6 @@ +from .discriminative_reranking_task import DiscriminativeRerankingNMTTask + + +__all__ = [ + "DiscriminativeRerankingNMTTask", +] diff --git a/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py b/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py new file mode 100644 index 0000000000..0e7fbba888 --- /dev/null +++ b/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py @@ -0,0 +1,475 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass, field + +import itertools +import logging +import os + +import numpy as np +import torch + +from fairseq import metrics +from fairseq.data import ( + ConcatDataset, + ConcatSentencesDataset, + data_utils, + Dictionary, + IdDataset, + indexed_dataset, + NestedDictionaryDataset, + NumSamplesDataset, + NumelDataset, + PrependTokenDataset, + RawLabelDataset, + RightPadDataset, + SortDataset, + TruncateDataset, + TokenBlockDataset, +) +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from omegaconf import II, MISSING + + +EVAL_BLEU_ORDER = 4 +TARGET_METRIC_CHOICES = ChoiceEnum(["bleu", "ter"]) + +logger = logging.getLogger(__name__) + + +@dataclass +class DiscriminativeRerankingNMTConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + num_data_splits: int = field( + default=1, metadata={"help": "total number of data splits"} + ) + no_shuffle: bool = field( + default=False, metadata={"help": "do not shuffle training data"} + ) + max_positions: int = field( + default=512, metadata={"help": "number of positional embeddings to learn"} + ) + include_src: bool = field( + default=False, metadata={"help": "include source sentence"} + ) + mt_beam: int = field(default=50, metadata={"help": "beam size of input hypotheses"}) + eval_target_metric: bool = field( + default=False, + metadata={"help": "evaluation with the target metric during validation"}, + ) + target_metric: TARGET_METRIC_CHOICES = field( + default="bleu", metadata={"help": "name of the target metric to optimize for"} + ) + train_subset: str = field( + default=II("dataset.train_subset"), + metadata={"help": "data subset to use for training (e.g. 
train, valid, test)"}, + ) + seed: int = field( + default=II("common.seed"), + metadata={"help": "pseudo random number generator seed"}, + ) + + +class RerankerScorer(object): + """Scores the target for a given (source (optional), target) input.""" + + def __init__(self, args, mt_beam): + self.mt_beam = mt_beam + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + """Score a batch of translations.""" + net_input = sample["net_input"] + + assert len(models) == 1, "does not support model ensemble" + model = models[0] + + bs = net_input["src_tokens"].shape[0] + assert ( + model.joint_classification == "none" or bs % self.mt_beam == 0 + ), f"invalid batch size ({bs}) for joint classification with beam size ({self.mt_beam})" + + model.eval() + logits = model(**net_input) + + batch_out = model.sentence_forward(logits, net_input["src_tokens"]) + if model.joint_classification == "sent": + batch_out = model.joint_forward( + batch_out.view(self.mt_beam, bs // self.mt_beam, -1) + ) + scores = model.classification_forward( + batch_out.view(bs, 1, -1) + ) # input: B x T x C + + return scores + + +@register_task( + "discriminative_reranking_nmt", dataclass=DiscriminativeRerankingNMTConfig +) +class DiscriminativeRerankingNMTTask(FairseqTask): + """ + Translation rerank task. + The input can be either (src, tgt) sentence pairs or tgt sentence only. 
+ """ + + cfg: DiscriminativeRerankingNMTConfig + + def __init__(self, cfg: DiscriminativeRerankingNMTConfig, data_dictionary=None): + super().__init__(cfg) + self.dictionary = data_dictionary + self._max_positions = cfg.max_positions + # args.tokens_per_sample = self._max_positions + # self.num_classes = 1 # for model + + @classmethod + def load_dictionary(cls, cfg, filename): + """Load the dictionary from the filename""" + dictionary = Dictionary.load(filename) + dictionary.add_symbol("<mask>") # for loading pretrained XLMR model + + return dictionary + + @classmethod + def setup_task(cls, cfg: DiscriminativeRerankingNMTConfig, **kwargs): + # load data dictionary (assume joint dictionary) + data_path = cfg.data + data_dict = cls.load_dictionary( + cfg, os.path.join(data_path, "input_src/dict.txt") + ) + + logger.info("[input] src dictionary: {} types".format(len(data_dict))) + + return DiscriminativeRerankingNMTTask(cfg, data_dict) + + def load_dataset(self, split, epoch=0, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + if self.cfg.data.endswith("1"): + data_shard = (epoch - 1) % self.cfg.num_data_splits + 1 + data_path = self.cfg.data[:-1] + str(data_shard) + else: + data_path = self.cfg.data + + def get_path(type, data_split): + return os.path.join(data_path, str(type), data_split) + + def make_dataset(type, dictionary, data_split, combine): + split_path = get_path(type, data_split) + + dataset = data_utils.load_indexed_dataset( + split_path, dictionary, combine=combine, + ) + return dataset + + def load_split(data_split, metric): + input_src = None + if self.cfg.include_src: + input_src = make_dataset( + "input_src", self.dictionary, data_split, combine=False + ) + assert input_src is not None, "could not find dataset: {}".format( + get_path("input_src", data_split) + ) + + input_tgt = make_dataset( + "input_tgt", self.dictionary, data_split, combine=False + ) + assert input_tgt is not None, "could not find dataset: 
{}".format( + get_path("input_tgt", data_split) + ) + + label_path = f"{get_path(metric, data_split)}.{metric}" + assert os.path.exists(label_path), f"could not find dataset: {label_path}" + + np_labels = np.loadtxt(label_path) + if self.cfg.target_metric == "ter": + np_labels = -np_labels + label = RawLabelDataset(np_labels) + + return input_src, input_tgt, label + + src_datasets = [] + tgt_datasets = [] + label_datasets = [] + + if split == self.cfg.train_subset: + for k in itertools.count(): + split_k = "train" + (str(k) if k > 0 else "") + prefix = os.path.join(data_path, "input_tgt", split_k) + if not indexed_dataset.dataset_exists(prefix, impl=None): + if k > 0: + break + else: + raise FileNotFoundError(f"Dataset not found: {prefix}") + input_src, input_tgt, label = load_split( + split_k, self.cfg.target_metric + ) + src_datasets.append(input_src) + tgt_datasets.append(input_tgt) + label_datasets.append(label) + else: + input_src, input_tgt, label = load_split(split, self.cfg.target_metric) + src_datasets.append(input_src) + tgt_datasets.append(input_tgt) + label_datasets.append(label) + + if len(tgt_datasets) == 1: + input_tgt, label = tgt_datasets[0], label_datasets[0] + if self.cfg.include_src: + input_src = src_datasets[0] + else: + input_tgt = ConcatDataset(tgt_datasets) + label = ConcatDataset(label_datasets) + if self.cfg.include_src: + input_src = ConcatDataset(src_datasets) + + input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions) + if self.cfg.include_src: + input_src = PrependTokenDataset(input_src, self.dictionary.bos()) + input_src = TruncateDataset(input_src, self.cfg.max_positions) + src_lengths = NumelDataset(input_src, reduce=False) + src_tokens = ConcatSentencesDataset(input_src, input_tgt) + else: + src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos()) + src_lengths = NumelDataset(src_tokens, reduce=False) + + dataset = { + "id": IdDataset(), + "net_input": { + "src_tokens": RightPadDataset( + src_tokens, 
pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": src_lengths, + }, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens, reduce=True), + "target": label, + } + + dataset = NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],) + + assert len(dataset) % self.cfg.mt_beam == 0, ( + "dataset size (%d) is not a multiple of beam size (%d)" + % (len(dataset), self.cfg.mt_beam) + ) + + # no need to shuffle valid/test sets + if not self.cfg.no_shuffle and split == self.cfg.train_subset: + + # need to keep all hypotheses together + start_idx = np.arange(0, len(dataset), self.cfg.mt_beam) + with data_utils.numpy_seed(self.cfg.seed + epoch): + np.random.shuffle(start_idx) + + idx = np.arange(0, self.cfg.mt_beam) + shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile( + start_idx, (self.cfg.mt_beam, 1) + ).transpose().reshape(-1) + + dataset = SortDataset(dataset, sort_order=[shuffle],) + + logger.info(f"Loaded {split} with #samples: {len(dataset)}") + + self.datasets[split] = dataset + return self.datasets[split] + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + assert not self.cfg.include_src or len(src_tokens[0]) == 2 + input_src = None + if self.cfg.include_src: + input_src = TokenBlockDataset( + [t[0] for t in src_tokens], + [l[0] for l in src_lengths], + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ) + input_src = PrependTokenDataset(input_src, self.dictionary.bos()) + input_src = TruncateDataset(input_src, self.cfg.max_positions) + + input_tgt = TokenBlockDataset( + [t[-1] for t in src_tokens], + [l[-1] for l in src_lengths], + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ) + input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions) + if self.cfg.include_src: + src_tokens = 
ConcatSentencesDataset(input_src, input_tgt) + src_lengths = NumelDataset(input_src, reduce=False) + else: + input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos()) + src_tokens = input_tgt + src_lengths = NumelDataset(src_tokens, reduce=False) + + dataset = { + "id": IdDataset(), + "net_input": { + "src_tokens": RightPadDataset( + src_tokens, pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": src_lengths, + }, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens, reduce=True), + } + + return NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],) + + def build_model(self, cfg: FairseqDataclass): + return super().build_model(cfg) + + def build_generator(self, args): + return RerankerScorer(args, mt_beam=self.cfg.mt_beam) + + def max_positions(self): + return self._max_positions + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + def create_dummy_batch(self, device): + dummy_target = ( + torch.zeros(self.cfg.mt_beam, EVAL_BLEU_ORDER * 2 + 3).long().to(device) + if self.cfg.target_metric != "ter" + else torch.zeros(self.cfg.mt_beam, 3).long().to(device) + ) + + return { + "id": torch.zeros(self.cfg.mt_beam, 1).long().to(device), + "net_input": { + "src_tokens": torch.zeros(self.cfg.mt_beam, 4).long().to(device), + "src_lengths": torch.ones(self.cfg.mt_beam, 1).long().to(device), + }, + "nsentences": 0, + "ntokens": 0, + "target": dummy_target, + } + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + if ignore_grad and sample is None: + sample = self.create_dummy_batch(model.device) + + return super().train_step( + sample, model, criterion, optimizer, update_num, ignore_grad + ) + + def valid_step(self, sample, model, criterion): + if sample is None: + sample = self.create_dummy_batch(model.device) + + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + if not 
self.cfg.eval_target_metric: + return loss, sample_size, logging_output + + scores = logging_output["scores"] + + if self.cfg.target_metric == "bleu": + assert sample["target"].shape[1] == EVAL_BLEU_ORDER * 2 + 3, ( + "target does not contain enough information (" + + str(sample["target"].shape[1]) + + ") for evaluating BLEU" + ) + + max_id = torch.argmax(scores, dim=1) + select_id = max_id + torch.arange( + 0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam + ).to(max_id.device) + bleu_data = sample["target"][select_id, 1:].sum(0).data + + logging_output["_bleu_sys_len"] = bleu_data[0] + logging_output["_bleu_ref_len"] = bleu_data[1] + + for i in range(EVAL_BLEU_ORDER): + logging_output["_bleu_counts_" + str(i)] = bleu_data[2 + i] + logging_output["_bleu_totals_" + str(i)] = bleu_data[ + 2 + EVAL_BLEU_ORDER + i + ] + + elif self.cfg.target_metric == "ter": + assert sample["target"].shape[1] == 3, ( + "target does not contain enough information (" + + str(sample["target"].shape[1]) + + ") for evaluating TER" + ) + + max_id = torch.argmax(scores, dim=1) + select_id = max_id + torch.arange( + 0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam + ).to(max_id.device) + ter_data = sample["target"][select_id, 1:].sum(0).data + + logging_output["_ter_num_edits"] = -ter_data[0] + logging_output["_ter_ref_len"] = -ter_data[1] + + return loss, sample_size, logging_output + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if not self.cfg.eval_target_metric: + return + + def sum_logs(key): + return sum(log.get(key, 0) for log in logging_outputs) + + if self.cfg.target_metric == "bleu": + counts, totals = [], [] + for i in range(EVAL_BLEU_ORDER): + counts.append(sum_logs("_bleu_counts_" + str(i))) + totals.append(sum_logs("_bleu_totals_" + str(i))) + + if max(totals) > 0: + # log counts as numpy arrays -- log_scalar will sum them correctly + 
metrics.log_scalar("_bleu_totals", np.array(totals)) + metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len")) + metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len")) + + def compute_bleu(meters): + import inspect + import sacrebleu + + fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0] + if "smooth_method" in fn_sig: + smooth = {"smooth_method": "exp"} + else: + smooth = {"smooth": "exp"} + bleu = sacrebleu.compute_bleu( + correct=meters["_bleu_counts"].sum, + total=meters["_bleu_totals"].sum, + sys_len=meters["_bleu_sys_len"].sum, + ref_len=meters["_bleu_ref_len"].sum, + **smooth, + ) + return round(bleu.score, 2) + + metrics.log_derived("bleu", compute_bleu) + elif self.cfg.target_metric == "ter": + num_edits = sum_logs("_ter_num_edits") + ref_len = sum_logs("_ter_ref_len") + + if ref_len > 0: + metrics.log_scalar("_ter_num_edits", num_edits) + metrics.log_scalar("_ter_ref_len", ref_len) + + def compute_ter(meters): + score = meters["_ter_num_edits"].sum / meters["_ter_ref_len"].sum + return round(score.item(), 2) + + metrics.log_derived("ter", compute_ter) From 4497897fa506187db96cd9f00614f6dd7080b550 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Sun, 11 Jul 2021 15:21:06 -0700 Subject: [PATCH 642/707] respect roberta model name to init arch (#2049) Summary: previous PR was always applying base architecture when finetuning roberta Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2049 Reviewed By: arbabu123 Differential Revision: D29642089 Pulled By: alexeib fbshipit-source-id: e8c6ec6e1fa68bef233f7221e0e02787d7ad2c06 --- examples/roberta/config/finetuning/cola.yaml | 5 ++++- examples/roberta/config/finetuning/mnli.yaml | 5 ++++- examples/roberta/config/finetuning/mrpc.yaml | 5 ++++- examples/roberta/config/finetuning/qnli.yaml | 5 ++++- examples/roberta/config/finetuning/qqp.yaml | 5 ++++- examples/roberta/config/finetuning/rte.yaml | 5 ++++- examples/roberta/config/finetuning/sst_2.yaml | 5 
++++- examples/roberta/config/finetuning/sts_b.yaml | 6 ++++-- fairseq/models/__init__.py | 11 +++++++++++ fairseq/models/roberta/model.py | 2 +- 10 files changed, 44 insertions(+), 10 deletions(-) diff --git a/examples/roberta/config/finetuning/cola.yaml b/examples/roberta/config/finetuning/cola.yaml index 717069d407..ac76611201 100644 --- a/examples/roberta/config/finetuning/cola.yaml +++ b/examples/roberta/config/finetuning/cola.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 + log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -21,6 +23,7 @@ checkpoint: reset_meters: true best_checkpoint_metric: accuracy maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +54,6 @@ optimization: max_epoch: 10 model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/mnli.yaml b/examples/roberta/config/finetuning/mnli.yaml index 4bfc02bed9..5be10c362f 100644 --- a/examples/roberta/config/finetuning/mnli.yaml +++ b/examples/roberta/config/finetuning/mnli.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 + log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -21,6 +23,7 @@ checkpoint: reset_meters: true best_checkpoint_metric: accuracy maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +54,6 @@ optimization: max_epoch: 10 model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/mrpc.yaml b/examples/roberta/config/finetuning/mrpc.yaml index 907b4639c1..aa8b7db393 100644 --- a/examples/roberta/config/finetuning/mrpc.yaml +++ b/examples/roberta/config/finetuning/mrpc.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 
+ log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -21,6 +23,7 @@ checkpoint: reset_meters: true best_checkpoint_metric: accuracy maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +54,6 @@ optimization: max_epoch: 10 model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/qnli.yaml b/examples/roberta/config/finetuning/qnli.yaml index 00aea91e56..b4595b090e 100644 --- a/examples/roberta/config/finetuning/qnli.yaml +++ b/examples/roberta/config/finetuning/qnli.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 + log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -21,6 +23,7 @@ checkpoint: reset_meters: true best_checkpoint_metric: accuracy maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +54,6 @@ optimization: max_epoch: 10 model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/qqp.yaml b/examples/roberta/config/finetuning/qqp.yaml index dc0296d26e..5a2b2ed743 100644 --- a/examples/roberta/config/finetuning/qqp.yaml +++ b/examples/roberta/config/finetuning/qqp.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 + log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -21,6 +23,7 @@ checkpoint: reset_meters: true best_checkpoint_metric: accuracy maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +54,6 @@ optimization: max_epoch: 10 model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/rte.yaml b/examples/roberta/config/finetuning/rte.yaml index 40dfd76169..7318465011 
100644 --- a/examples/roberta/config/finetuning/rte.yaml +++ b/examples/roberta/config/finetuning/rte.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 + log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -21,6 +23,7 @@ checkpoint: reset_meters: true best_checkpoint_metric: accuracy maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +54,6 @@ optimization: max_epoch: 10 model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/sst_2.yaml b/examples/roberta/config/finetuning/sst_2.yaml index b808a850cb..a93ad2f22c 100644 --- a/examples/roberta/config/finetuning/sst_2.yaml +++ b/examples/roberta/config/finetuning/sst_2.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 + log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -21,6 +23,7 @@ checkpoint: reset_meters: true best_checkpoint_metric: accuracy maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +54,6 @@ optimization: max_epoch: 10 model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/examples/roberta/config/finetuning/sts_b.yaml b/examples/roberta/config/finetuning/sts_b.yaml index d354bb97dd..2d495221ad 100644 --- a/examples/roberta/config/finetuning/sts_b.yaml +++ b/examples/roberta/config/finetuning/sts_b.yaml @@ -5,6 +5,8 @@ common: fp16_init_scale: 4 threshold_loss_scale: 1 fp16_scale_window: 128 + log_format: json + log_interval: 200 task: _name: sentence_prediction @@ -19,7 +21,7 @@ checkpoint: reset_optimizer: true reset_dataloader: true reset_meters: true - best_checkpoint_metric: accuracy + no_epoch_checkpoints: true distributed_training: find_unused_parameters: true @@ -51,6 +53,6 @@ optimization: max_epoch: 10 
model: - _name: roberta_large + _name: roberta dropout: 0.1 attention_dropout: 0.1 diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 61425c8ef5..05d2a9d087 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -7,10 +7,12 @@ import argparse import importlib import os +from contextlib import ExitStack from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import merge_with_parent, populate_dataclass from hydra.core.config_store import ConfigStore +from omegaconf import open_dict, OmegaConf from .composite_encoder import CompositeEncoder from .distributed_fairseq_model import DistributedFairseqModel @@ -80,10 +82,19 @@ def build_model(cfg: FairseqDataclass, task): if model_type in MODEL_DATACLASS_REGISTRY: # set defaults from dataclass. note that arch name and model name can be the same dc = MODEL_DATACLASS_REGISTRY[model_type] + if isinstance(cfg, argparse.Namespace): cfg = populate_dataclass(dc(), cfg) else: cfg = merge_with_parent(dc(), cfg) + else: + if model_type in ARCH_CONFIG_REGISTRY: + with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack(): + # this calls the different "arch" functions (like base_architecture()) that you indicate + # if you specify --arch on the command line. this is only applicable to the old argparse based models + # hydra models should expose different architectures via different config files + # it will modify the cfg object and default parameters according to the arch + ARCH_CONFIG_REGISTRY[model_type](cfg) assert model is not None, ( f"Could not infer model type from {cfg}. 
" diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 39a1cdd951..3337616be6 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -523,7 +523,7 @@ def safe_getattr(obj, k, default=None): from omegaconf import OmegaConf if OmegaConf.is_config(obj): - return obj.k if k in obj and obj.k is not None else default + return obj[k] if k in obj and obj[k] is not None else default return getattr(obj, k, default) From 1a1380e5a8b0cce49090676d95044626f208c48c Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Sun, 11 Jul 2021 21:25:45 -0700 Subject: [PATCH 643/707] fix hydra upgrade (#2051) Summary: commit 72c0e4f36d150b8244260fba6228b94ff95914bf added "config_schema" to config and initialize.py which did not exist, and caused crashes for any interpreted keys (e.g. "common.tpu" in task) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2051 Reviewed By: arbabu123 Differential Revision: D29654730 Pulled By: alexeib fbshipit-source-id: bac18e3d2011de7c822ffcd7e6d4d4364a8edc52 --- fairseq/config/config.yaml | 1 - fairseq/dataclass/initialize.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/fairseq/config/config.yaml b/fairseq/config/config.yaml index 087083e88a..e20d914b9b 100644 --- a/fairseq/config/config.yaml +++ b/fairseq/config/config.yaml @@ -5,7 +5,6 @@ hydra: dir: . 
defaults: - - config_schema - task: null - model: null - criterion: cross_entropy diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index 1d5c90eefd..e43b31790e 100644 --- a/fairseq/dataclass/initialize.py +++ b/fairseq/dataclass/initialize.py @@ -16,7 +16,7 @@ def hydra_init(cfg_name="config") -> None: cs = ConfigStore.instance() - cs.store(name=f"{cfg_name}_schema", node=FairseqConfig) + cs.store(name=f"{cfg_name}", node=FairseqConfig) for k in FairseqConfig.__dataclass_fields__: v = FairseqConfig.__dataclass_fields__[k].default From d26dcedaf97e556da8dcab10b5de08cea72c3459 Mon Sep 17 00:00:00 2001 From: Jialu Joann Li <jialuli3@fb.com> Date: Mon, 12 Jul 2021 13:55:29 -0700 Subject: [PATCH 644/707] pyspeech embedding extractor Summary: Transformer to extract embeddings from pyspeech encoder models Reviewed By: vimalmanohar Differential Revision: D29492082 fbshipit-source-id: abf2432f43c75b93dcccb1a389039e5628f95e92 --- fairseq/models/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 05d2a9d087..c5a4bbc831 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -98,10 +98,10 @@ def build_model(cfg: FairseqDataclass, task): assert model is not None, ( f"Could not infer model type from {cfg}. " - f"Available models: " - + str(MODEL_DATACLASS_REGISTRY.keys()) - + " Requested model type: " - + model_type + "Available models: {}".format( + MODEL_DATACLASS_REGISTRY.keys() + ) + + f" Requested model type: {model_type}" ) return model.build_model(cfg, task) From 313ff0581561c7725ea9430321d6af2901573dfb Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Mon, 12 Jul 2021 17:58:00 -0700 Subject: [PATCH 645/707] migrate masked lm task and criterion (#2050) Summary: migrates masked lm task and criterion. 
old command line flags still continue to work as before Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2050 Reviewed By: arbabu123 Differential Revision: D29648056 Pulled By: alexeib fbshipit-source-id: 33079a461826ab6d1736d6016ee5e3522bcc5427 --- examples/roberta/README.pretraining.md | 28 +-- examples/roberta/config/pretraining/base.yaml | 42 ++++ fairseq/criterions/masked_lm.py | 15 +- fairseq/tasks/masked_lm.py | 201 +++++++++--------- 4 files changed, 159 insertions(+), 127 deletions(-) create mode 100644 examples/roberta/config/pretraining/base.yaml diff --git a/examples/roberta/README.pretraining.md b/examples/roberta/README.pretraining.md index 8b6e10c08c..a4e7453529 100644 --- a/examples/roberta/README.pretraining.md +++ b/examples/roberta/README.pretraining.md @@ -48,35 +48,21 @@ fairseq-preprocess \ ### 2) Train RoBERTa base ```bash -TOTAL_UPDATES=125000 # Total number of training steps -WARMUP_UPDATES=10000 # Warmup the learning rate over this many updates -PEAK_LR=0.0005 # Peak learning rate, adjust as needed -TOKENS_PER_SAMPLE=512 # Max sequence length -MAX_POSITIONS=512 # Num. 
positional embeddings (usually same as above) -MAX_SENTENCES=16 # Number of sequences per batch (batch size) -UPDATE_FREQ=16 # Increase the batch size 16x - DATA_DIR=data-bin/wikitext-103 -fairseq-train --fp16 $DATA_DIR \ - --task masked_lm --criterion masked_lm \ - --arch roberta_base --sample-break-mode complete --tokens-per-sample $TOKENS_PER_SAMPLE \ - --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm 0.0 \ - --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \ - --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ - --batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ \ - --max-update $TOTAL_UPDATES --log-format simple --log-interval 1 +fairseq-hydra-train -m --config-dir examples/roberta/config/pretraining \ +--config-name base task.data=$DATA_DIR ``` **Note:** You can optionally resume training the released RoBERTa base model by -adding `--restore-file /path/to/roberta.base/model.pt`. +adding `checkpoint.restore_file=/path/to/roberta.base/model.pt`. **Note:** The above command assumes training on 8x32GB V100 GPUs. Each GPU uses -a batch size of 16 sequences (`$MAX_SENTENCES`) and accumulates gradients to -further increase the batch size by 16x (`$UPDATE_FREQ`), for a total batch size +a batch size of 16 sequences (`dataset.batch_size`) and accumulates gradients to +further increase the batch size by 16x (`optimization.update_freq`), for a total batch size of 2048 sequences. If you have fewer GPUs or GPUs with less memory you may need -to reduce `$MAX_SENTENCES` and increase `$UPDATE_FREQ` to compensate. -Alternatively if you have more GPUs you can decrease `$UPDATE_FREQ` accordingly +to reduce `dataset.batch_size` and increase `optimization.update_freq` to compensate. +Alternatively if you have more GPUs you can decrease `optimization.update_freq` accordingly to increase training speed.
**Note:** The learning rate and batch size are tightly connected and need to be diff --git a/examples/roberta/config/pretraining/base.yaml b/examples/roberta/config/pretraining/base.yaml new file mode 100644 index 0000000000..97829908f7 --- /dev/null +++ b/examples/roberta/config/pretraining/base.yaml @@ -0,0 +1,42 @@ +# @package _group_ +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + +task: + _name: masked_lm + data: ??? + sample_break_mode: complete + tokens_per_sample: 512 + +criterion: masked_lm + +dataset: + batch_size: 16 + ignore_unused_valid_subsets: true + +optimizer: + _name: adam + weight_decay: 0.01 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 10000 + +optimization: + clip_norm: 0 + lr: [0.0005] + max_update: 125000 + update_freq: [16] + +model: + _name: roberta + max_positions: 512 + dropout: 0.1 + attention_dropout: 0.1 diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py index b04cfbff6d..279458f317 100644 --- a/fairseq/criterions/masked_lm.py +++ b/fairseq/criterions/masked_lm.py @@ -3,23 +3,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass import math +from omegaconf import II import torch -import torch.nn.functional as F from fairseq import metrics, modules, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass -@register_criterion("masked_lm") +@dataclass +class MaskedLmConfig(FairseqDataclass): + tpu: bool = II("common.tpu") + + +@register_criterion("masked_lm", dataclass=MaskedLmConfig) class MaskedLmLoss(FairseqCriterion): """ Implementation for the loss used in masked language model (MLM) training. 
""" - def __init__(self, task, tpu=False): + def __init__(self, cfg: MaskedLmConfig, task): super().__init__(task) - self.tpu = tpu + self.tpu = cfg.tpu def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index fd2ea6ade1..0c08132fb7 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -3,9 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field import logging import os +from omegaconf import MISSING, II, OmegaConf + import numpy as np from fairseq import utils from fairseq.data import ( @@ -23,108 +26,103 @@ ) from fairseq.data.encoders.utils import get_whole_word_mask from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +from .language_modeling import SAMPLE_BREAK_MODE_CHOICES, SHORTEN_METHOD_CHOICES logger = logging.getLogger(__name__) -@register_task("masked_lm") -class MaskedLMTask(LegacyFairseqTask): - """Task for training masked language models (e.g., BERT, RoBERTa).""" - - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - parser.add_argument( - "data", - help="colon separated path to data directories list, \ - will be iterated upon during epochs in round-robin manner", - ) - parser.add_argument( - "--sample-break-mode", - default="complete", - choices=["none", "complete", "complete_doc", "eos"], - help='If omitted or "none", fills each sample with tokens-per-sample ' +@dataclass +class MaskedLMConfig(FairseqDataclass): + data: str = field( + default=MISSING, + metadata={ + "help": "colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner" + }, + ) + 
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' 'tokens. If set to "complete", splits samples only at the end ' "of sentence, but may include multiple sentences per sample. " '"complete_doc" is similar but respects doc boundaries. ' - 'If set to "eos", includes only one sentence per sample.', - ) - parser.add_argument( - "--tokens-per-sample", - default=512, - type=int, - help="max number of total tokens over all segments " - "per sample for BERT dataset", - ) - parser.add_argument( - "--mask-prob", - default=0.15, - type=float, - help="probability of replacing a token with mask", - ) - parser.add_argument( - "--leave-unmasked-prob", - default=0.1, - type=float, - help="probability that a masked token is unmasked", - ) - parser.add_argument( - "--random-token-prob", - default=0.1, - type=float, - help="probability of replacing a token with a random token", - ) - parser.add_argument( - "--freq-weighted-replacement", - default=False, - action="store_true", - help="sample random replacement words based on word frequencies", - ) - parser.add_argument( - "--mask-whole-words", - default=False, - action="store_true", - help="mask whole words; you may also want to set --bpe", - ) - parser.add_argument( - "--mask-multiple-length", - default=1, - type=int, - help="repeat the mask indices multiple times", - ) - parser.add_argument( - "--mask-stdev", default=0.0, type=float, help="stdev of the mask length" - ) - parser.add_argument( - "--shorten-method", - default="none", - choices=["none", "truncate", "random_crop"], - help="if not none, shorten sequences that exceed --tokens-per-sample", - ) - parser.add_argument( - "--shorten-data-split-list", - default="", - help="comma-separated list of dataset splits to apply shortening to, " - 'e.g., "train,valid" (default: all dataset splits)', - ) + 'If set to "eos", includes only one sentence per sample.' 
+ }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + mask_prob: float = field( + default=0.15, + metadata={"help": "probability of replacing a token with mask"}, + ) + leave_unmasked_prob: float = field( + default=0.1, + metadata={"help": "probability that a masked token is unmasked"}, + ) + random_token_prob: float = field( + default=0.1, + metadata={"help": "probability of replacing a token with a random token"}, + ) + freq_weighted_replacement: bool = field( + default=False, + metadata={"help": "sample random replacement words based on word frequencies"}, + ) + mask_whole_words: bool = field( + default=False, + metadata={"help": "mask whole words; you may also want to set --bpe"}, + ) + mask_multiple_length: int = field( + default=1, + metadata={"help": "repeat the mask indices multiple times"}, + ) + mask_stdev: float = field( + default=0.0, + metadata={"help": "stdev of the mask length"}, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + seed: int = II("common.seed") + + +@register_task("masked_lm", dataclass=MaskedLMConfig) +class MaskedLMTask(FairseqTask): + + cfg: MaskedLMConfig + + """Task for training masked language models (e.g., BERT, RoBERTa).""" - def __init__(self, args, dictionary): - super().__init__(args) + def __init__(self, cfg: MaskedLMConfig, dictionary): + super().__init__(cfg) self.dictionary = dictionary - self.seed = args.seed # add mask token self.mask_idx = dictionary.add_symbol("<mask>") @classmethod - def setup_task(cls, args, **kwargs): - paths = utils.split_paths(args.data) + def setup_task(cls, cfg: MaskedLMConfig, **kwargs): + 
paths = utils.split_paths(cfg.data) assert len(paths) > 0 dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) logger.info("dictionary: {} types".format(len(dictionary))) - return cls(args, dictionary) + return cls(cfg, dictionary) def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. @@ -132,7 +130,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) @@ -140,7 +138,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, - self.args.dataset_impl, combine=combine, ) if dataset is None: @@ -151,20 +148,20 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): dataset = maybe_shorten_dataset( dataset, split, - self.args.shorten_data_split_list, - self.args.shorten_method, - self.args.tokens_per_sample, - self.args.seed, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, ) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, - self.args.tokens_per_sample - 1, # one less for <s> + self.cfg.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), - break_mode=self.args.sample_break_mode, + break_mode=self.cfg.sample_break_mode, ) logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) @@ -174,7 +171,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): # create masked input and targets mask_whole_words = ( get_whole_word_mask(self.args, self.source_dictionary) - if self.args.mask_whole_words + if self.cfg.mask_whole_words else None ) @@ -183,17 +180,17 @@ 
def load_dataset(self, split, epoch=1, combine=False, **kwargs): self.source_dictionary, pad_idx=self.source_dictionary.pad(), mask_idx=self.mask_idx, - seed=self.args.seed, - mask_prob=self.args.mask_prob, - leave_unmasked_prob=self.args.leave_unmasked_prob, - random_token_prob=self.args.random_token_prob, - freq_weighted_replacement=self.args.freq_weighted_replacement, + seed=self.cfg.seed, + mask_prob=self.cfg.mask_prob, + leave_unmasked_prob=self.cfg.leave_unmasked_prob, + random_token_prob=self.cfg.random_token_prob, + freq_weighted_replacement=self.cfg.freq_weighted_replacement, mask_whole_words=mask_whole_words, - mask_multiple_length=self.args.mask_multiple_length, - mask_stdev=self.args.mask_stdev, + mask_multiple_length=self.cfg.mask_multiple_length, + mask_stdev=self.cfg.mask_stdev, ) - with data_utils.numpy_seed(self.args.seed): + with data_utils.numpy_seed(self.cfg.seed): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( @@ -227,7 +224,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): TokenBlockDataset( src_tokens, src_lengths, - self.args.tokens_per_sample - 1, # one less for <s> + self.cfg.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode="eos", From f2146bdc7abf293186de9449bfa2272775e39e1d Mon Sep 17 00:00:00 2001 From: Wei-Ning Hsu <wnhsu@csail.mit.edu> Date: Mon, 12 Jul 2021 20:28:51 -0700 Subject: [PATCH 646/707] refactor clustering in hubert example (#2018) Summary: ## What does this PR do? - Fix edge cases of sharded feature extraction in example/hubert/simple_kmeans, where some shard has 0 samples. 
- Refactor `example/hubert/simple_kmeans/dump_*_features.py` ## Test Tested on 2 iterations of clustering from MFCC and HUBERT features Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2018 Reviewed By: hikushalhere Differential Revision: D29523927 Pulled By: wnhsu fbshipit-source-id: b089f3d38bd464a940954d358cd5ef091e4892a2 --- .../simple_kmeans/dump_hubert_feature.py | 54 ++--------- .../simple_kmeans/dump_hubert_feature_s2t.py | 54 ++--------- .../hubert/simple_kmeans/dump_mfcc_feature.py | 48 +--------- .../hubert/simple_kmeans/dump_w2v2_feature.py | 95 +++++++++++++++++++ .../hubert/simple_kmeans/feature_utils.py | 66 +++++++++++++ 5 files changed, 183 insertions(+), 134 deletions(-) create mode 100644 examples/hubert/simple_kmeans/dump_w2v2_feature.py create mode 100644 examples/hubert/simple_kmeans/feature_utils.py diff --git a/examples/hubert/simple_kmeans/dump_hubert_feature.py b/examples/hubert/simple_kmeans/dump_hubert_feature.py index cd242890e5..5c7b67f8b1 100644 --- a/examples/hubert/simple_kmeans/dump_hubert_feature.py +++ b/examples/hubert/simple_kmeans/dump_hubert_feature.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
import logging -import math import os import sys @@ -12,8 +11,9 @@ import soundfile as sf import torch import torch.nn.functional as F -import tqdm -from npy_append_array import NpyAppendArray + +from feature_utils import get_path_iterator, dump_feature + logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -66,53 +66,13 @@ def get_feats(self, path, ref_len=None): output_layer=self.layer, ) feat.append(feat_chunk) - return torch.cat(feat, 1).squeeze(0) - - -def get_path_iterator(tsv, nshard, rank): - with open(tsv, "r") as f: - root = f.readline().rstrip() - lines = [line.rstrip() for line in f] - tot = len(lines) - shard_size = math.ceil(tot / nshard) - start, end = rank * shard_size, min((rank + 1) * shard_size, tot) - assert start < end, "start={start}, end={end}" - logger.info( - f"rank {rank} of {nshard}, process {end-start} " - f"({start}-{end}) out of {tot}" - ) - - lines = lines[start:end] + return torch.cat(feat, 1).squeeze(0) - def iterate(): - for line in lines: - subpath, nsample = line.split("\t") - yield f"{root}/{subpath}", int(nsample) - return iterate, len(lines) - - -def dump_feature( - tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk -): +def main(tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk): reader = HubertFeatureReader(ckpt_path, layer, max_chunk) generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank) - iterator = generator() - - feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy" - leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len" - - os.makedirs(feat_dir, exist_ok=True) - if os.path.exists(feat_path): - os.remove(feat_path) - - feat_f = NpyAppendArray(feat_path) - with open(leng_path, "w") as leng_f: - for path, nsample in tqdm.tqdm(iterator, total=num): - feat = reader.get_feats(path, nsample) - feat_f.append(feat.cpu().numpy()) - leng_f.write(f"{len(feat)}\n") - logger.info("finished successfully") + dump_feature(reader, generator, num, 
split, nshard, rank, feat_dir) if __name__ == "__main__": @@ -130,4 +90,4 @@ def dump_feature( args = parser.parse_args() logger.info(args) - dump_feature(**vars(args)) + main(**vars(args)) diff --git a/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py b/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py index 7ec8a7311b..6fff4faf44 100644 --- a/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py +++ b/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py @@ -6,18 +6,17 @@ import csv import io import logging -import math import os import os.path as op import sys -import tqdm from dump_hubert_feature import HubertFeatureReader +from feature_utils import get_shard_range, dump_feature from fairseq.data.audio.audio_utils import get_waveform from fairseq.data.audio.speech_to_text_dataset import ( read_from_uncompressed_zip, ) -from npy_append_array import NpyAppendArray + logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -57,54 +56,21 @@ def get_path_iterator(root, tsv, nshard, rank): quoting=csv.QUOTE_NONE, ) subpaths = [op.join(root, e["audio"]) for e in reader] - - tot = len(subpaths) - shard_size = math.ceil(tot / nshard) - start, end = rank * shard_size, min((rank + 1) * shard_size, tot) - assert start < end, "start={start}, end={end}" - logger.info( - f"rank {rank} of {nshard}, process {end-start} " - f"({start}-{end}) out of {tot}" - ) - + start, end = get_shard_range(len(subpaths), nshard, rank) subpaths = subpaths[start:end] - def iterate(): for subpath in subpaths: - yield op.join(root, subpath) - - return iterate, len(subpaths) + yield op.join(root, subpath), None + return iterate, len(subpaths) -def dump_feature( - root, - tsv_path, - ckpt_path, - layer, - nshard, - rank, - feat_dir, - feat_name, - max_chunk, +def main( + root, tsv_path, ckpt_path, layer, nshard, rank, feat_dir, split, max_chunk ): reader = HubertFeatureReaderS2T(ckpt_path, layer, max_chunk) generator, num = get_path_iterator(root, 
tsv_path, nshard, rank) - iterator = generator() - - feat_path = f"{feat_dir}/{feat_name}_{rank}_{nshard}.npy" - leng_path = f"{feat_dir}/{feat_name}_{rank}_{nshard}.len" - - os.makedirs(feat_dir, exist_ok=True) - if op.exists(feat_path): - os.remove(feat_path) + dump_feature(reader, generator, num, split, nshard, rank, feat_dir) - feat_f = NpyAppendArray(feat_path) - with open(leng_path, "w") as leng_f: - for path in tqdm.tqdm(iterator, total=num): - feat = reader.get_feats(path) - feat_f.append(feat.cpu().numpy()) - leng_f.write(f"{len(feat)}\n") - logger.info("finished successfully") if __name__ == "__main__": @@ -118,9 +84,9 @@ def dump_feature( parser.add_argument("nshard", type=int) parser.add_argument("rank", type=int) parser.add_argument("feat_dir") - parser.add_argument("feat_name") + parser.add_argument("split") parser.add_argument("--max_chunk", type=int, default=1600000) args = parser.parse_args() logger.info(args) - dump_feature(**vars(args)) + main(**vars(args)) diff --git a/examples/hubert/simple_kmeans/dump_mfcc_feature.py b/examples/hubert/simple_kmeans/dump_mfcc_feature.py index a36fa643bd..70d0016663 100644 --- a/examples/hubert/simple_kmeans/dump_mfcc_feature.py +++ b/examples/hubert/simple_kmeans/dump_mfcc_feature.py @@ -4,15 +4,14 @@ # LICENSE file in the root directory of this source tree. 
import logging -import math import os import sys import soundfile as sf import torch import torchaudio -import tqdm -from npy_append_array import NpyAppendArray + +from feature_utils import get_path_iterator, dump_feature logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -56,48 +55,11 @@ def get_feats(self, path, ref_len=None): return concat -def get_path_iterator(tsv, nshard, rank): - with open(tsv, "r") as f: - root = f.readline().rstrip() - lines = [line.rstrip() for line in f] - tot = len(lines) - shard_size = math.ceil(tot / nshard) - start, end = rank * shard_size, min((rank + 1) * shard_size, tot) - assert start < end, "start={start}, end={end}" - logger.info( - f"rank {rank} of {nshard}, process {end-start} " - f"({start}-{end}) out of {tot}" - ) - - lines = lines[start:end] - - def iterate(): - for line in lines: - subpath, nsample = line.split("\t") - yield f"{root}/{subpath}", int(nsample) - - return iterate, len(lines) - - -def dump_feature(tsv_dir, split, sample_rate, nshard, rank, feat_dir): +def main(tsv_dir, split, nshard, rank, feat_dir, sample_rate): reader = MfccFeatureReader(sample_rate) generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank) - iterator = generator() - - feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy" - leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len" - - os.makedirs(feat_dir, exist_ok=True) - if os.path.exists(feat_path): - os.remove(feat_path) + dump_feature(reader, generator, num, split, nshard, rank, feat_dir) - feat_f = NpyAppendArray(feat_path) - with open(leng_path, "w") as leng_f: - for path, nsample in tqdm.tqdm(iterator, total=num): - feat = reader.get_feats(path, nsample) - feat_f.append(feat.cpu().numpy()) - leng_f.write(f"{len(feat)}\n") - logger.info("finished successfully") if __name__ == "__main__": @@ -113,4 +75,4 @@ def dump_feature(tsv_dir, split, sample_rate, nshard, rank, feat_dir): args = parser.parse_args() logger.info(args) - 
dump_feature(**vars(args)) + main(**vars(args)) diff --git a/examples/hubert/simple_kmeans/dump_w2v2_feature.py b/examples/hubert/simple_kmeans/dump_w2v2_feature.py new file mode 100644 index 0000000000..a1f0d902ac --- /dev/null +++ b/examples/hubert/simple_kmeans/dump_w2v2_feature.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +import fairseq +import soundfile as sf +import torch +import torch.nn.functional as F + +from feature_utils import get_path_iterator, dump_feature + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("dump_w2v2_feature") + + +class Wav2Vec2FeatureReader(object): + def __init__(self, ckpt_path, layer, max_chunk=1600000): + ( + model, + cfg, + task, + ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) + self.model = model[0].eval().cuda() + self.task = task + self.layer = layer # assume this is 1-based like HuBERT + self.max_chunk = max_chunk + logger.info(f"TASK CONFIG:\n{self.task.cfg}") + logger.info(f" max_chunk = {self.max_chunk}") + logger.info(f" model:\n{self.model}") + + def read_audio(self, path, ref_len=None): + wav, sr = sf.read(path) + assert sr == self.task.cfg.sample_rate, sr + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + if ref_len is not None and abs(ref_len - len(wav)) > 160: + logging.warning(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + def get_feats(self, path, ref_len=None): + x = self.read_audio(path, ref_len) + with torch.no_grad(): + x = torch.from_numpy(x).float().cuda() + if self.task.cfg.normalize: + x = F.layer_norm(x, x.shape) + x = x.view(1, -1) + + feat = [] + for start in range(0, 
x.size(1), self.max_chunk): + x_chunk = x[:, start: start + self.max_chunk] + res = self.model.extract_features( + source=x_chunk, + padding_mask=None, + mask=False, + layer=self.layer - 1, + ) + feat_chunk = res["x"] + feat.append(feat_chunk) + return torch.cat(feat, 1).squeeze(0) + + +def main(tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk): + reader = Wav2Vec2FeatureReader(ckpt_path, layer, max_chunk) + generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank) + dump_feature(reader, generator, num, split, nshard, rank, feat_dir) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("tsv_dir") + parser.add_argument("split") + parser.add_argument("ckpt_path") + parser.add_argument("layer", type=int) + parser.add_argument("nshard", type=int) + parser.add_argument("rank", type=int) + parser.add_argument("feat_dir") + parser.add_argument("--max_chunk", type=int, default=1600000) + args = parser.parse_args() + logger.info(args) + + main(**vars(args)) diff --git a/examples/hubert/simple_kmeans/feature_utils.py b/examples/hubert/simple_kmeans/feature_utils.py new file mode 100644 index 0000000000..f80bc45697 --- /dev/null +++ b/examples/hubert/simple_kmeans/feature_utils.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +import sys + +import tqdm +from npy_append_array import NpyAppendArray + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("feature_utils") + + +def get_shard_range(tot, nshard, rank): + assert rank < nshard and rank >= 0, f"invaid rank/nshard {rank}/{nshard}" + start = round(tot / nshard * rank) + end = round(tot / nshard * (rank + 1)) + assert start < end, f"start={start}, end={end}" + logger.info( + f"rank {rank} of {nshard}, process {end-start} " + f"({start}-{end}) out of {tot}" + ) + return start, end + + +def get_path_iterator(tsv, nshard, rank): + with open(tsv, "r") as f: + root = f.readline().rstrip() + lines = [line.rstrip() for line in f] + start, end = get_shard_range(len(lines), nshard, rank) + lines = lines[start:end] + def iterate(): + for line in lines: + subpath, nsample = line.split("\t") + yield f"{root}/{subpath}", int(nsample) + return iterate, len(lines) + + +def dump_feature(reader, generator, num, split, nshard, rank, feat_dir): + iterator = generator() + + feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy" + leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len" + + os.makedirs(feat_dir, exist_ok=True) + if os.path.exists(feat_path): + os.remove(feat_path) + + feat_f = NpyAppendArray(feat_path) + with open(leng_path, "w") as leng_f: + for path, nsample in tqdm.tqdm(iterator, total=num): + feat = reader.get_feats(path, nsample) + feat_f.append(feat.cpu().numpy()) + leng_f.write(f"{len(feat)}\n") + logger.info("finished successfully") + + From 597a8fc5e36110450c6a7251506d17b42b635643 Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Fri, 16 Jul 2021 04:54:46 -0700 Subject: [PATCH 647/707] Only import MultiheadAttention when doing typechecking Summary: This might create dependency loops in some cases and we don't use it 
otherwise. Reviewed By: myleott, dianaml0 Differential Revision: D29521387 fbshipit-source-id: b4c27426a5965cd864a0c7b353082756099fead5 --- fairseq/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fairseq/utils.py b/fairseq/utils.py index 4fe95b9e8b..d1ec9a274c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -12,13 +12,14 @@ import sys import warnings from itertools import accumulate -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, TYPE_CHECKING import torch import torch.nn.functional as F -from fairseq.modules.multihead_attention import MultiheadAttention from torch import Tensor +if TYPE_CHECKING: + from fairseq.modules.multihead_attention import MultiheadAttention try: from amp_C import multi_tensor_l2norm @@ -130,7 +131,7 @@ def _move_to_tpu(tensor): def get_incremental_state( - module: MultiheadAttention, + module: "MultiheadAttention", incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, ) -> Optional[Dict[str, Optional[Tensor]]]: @@ -139,7 +140,7 @@ def get_incremental_state( def set_incremental_state( - module: MultiheadAttention, + module: "MultiheadAttention", incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, value: Dict[str, Optional[Tensor]], From aa15dc9a1b4cd586dc832a971a5e31ad668f9e22 Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Fri, 16 Jul 2021 04:54:46 -0700 Subject: [PATCH 648/707] Transformer File Split Summary: transformer.py in fairseq is enormous. This doesn't change any behavior; it just splits that file into a model file, an encoder file and a decoder file. We then adjust a few imports that were importing from the wrong modules or creating dependency loops.
Reviewed By: myleott Differential Revision: D29533995 fbshipit-source-id: 776f6bbdadb08d729b0b521fd767b7ac6116a723 --- fairseq/models/nat/levenshtein_transformer.py | 3 +- fairseq/models/transformer.py | 1188 ----------------- fairseq/models/transformer/__init__.py | 44 + .../models/transformer/transformer_decoder.py | 452 +++++++ .../models/transformer/transformer_encoder.py | 322 +++++ .../models/transformer/transformer_model.py | 452 +++++++ 6 files changed, 1272 insertions(+), 1189 deletions(-) delete mode 100644 fairseq/models/transformer.py create mode 100644 fairseq/models/transformer/__init__.py create mode 100644 fairseq/models/transformer/transformer_decoder.py create mode 100644 fairseq/models/transformer/transformer_encoder.py create mode 100644 fairseq/models/transformer/transformer_model.py diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py index 9377c3c7f5..d60d3c52d5 100644 --- a/fairseq/models/nat/levenshtein_transformer.py +++ b/fairseq/models/nat/levenshtein_transformer.py @@ -9,7 +9,8 @@ from fairseq.iterative_refinement_generator import DecoderOut from fairseq.models import register_model, register_model_architecture from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder -from fairseq.models.transformer import Embedding, TransformerDecoderLayer +from fairseq.models.transformer import Embedding +from fairseq.modules import TransformerDecoderLayer from fairseq.modules.transformer_sentence_encoder import init_bert_params from .levenshtein_utils import ( diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py deleted file mode 100644 index c2726af34d..0000000000 --- a/fairseq/models/transformer.py +++ /dev/null @@ -1,1188 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import math -from typing import Any, Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from fairseq import utils -from fairseq.distributed import fsdp_wrap -from fairseq.models import ( - FairseqEncoder, - FairseqEncoderDecoderModel, - FairseqIncrementalDecoder, - register_model, - register_model_architecture, -) -from fairseq.modules import ( - AdaptiveSoftmax, - BaseLayer, - FairseqDropout, - LayerDropModuleList, - LayerNorm, - PositionalEmbedding, - SinusoidalPositionalEmbedding, - TransformerDecoderLayer, - TransformerEncoderLayer, -) -from fairseq.modules.checkpoint_activations import checkpoint_wrapper -from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ -from torch import Tensor - - -DEFAULT_MAX_SOURCE_POSITIONS = 1024 -DEFAULT_MAX_TARGET_POSITIONS = 1024 - - -DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) - - -@register_model("transformer") -class TransformerModel(FairseqEncoderDecoderModel): - """ - Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) - <https://arxiv.org/abs/1706.03762>`_. - - Args: - encoder (TransformerEncoder): the encoder - decoder (TransformerDecoder): the decoder - - The Transformer model provides the following named architectures and - command-line arguments: - - .. 
argparse:: - :ref: fairseq.models.transformer_parser - :prog: - """ - - @classmethod - def hub_models(cls): - # fmt: off - - def moses_subword(path): - return { - 'path': path, - 'tokenizer': 'moses', - 'bpe': 'subword_nmt', - } - - def moses_fastbpe(path): - return { - 'path': path, - 'tokenizer': 'moses', - 'bpe': 'fastbpe', - } - - def spm(path): - return { - 'path': path, - 'bpe': 'sentencepiece', - 'tokenizer': 'space', - } - - return { - 'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'), - 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', - 'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'), - 'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'), - 'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'), - 'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'), - 'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'), - 'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'), - 'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'), - 'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'), - 'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'), - 'transformer.wmt20.en-ta': 
spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'), - 'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'), - 'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'), - 'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'), - 'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'), - 'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'), - 'transformer.flores101.mm100.615M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz'), - 'transformer.flores101.mm100.175M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz'), - } - # fmt: on - - def __init__(self, args, encoder, decoder): - super().__init__(encoder, decoder) - self.args = args - self.supports_align_args = True - - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--activation-fn', - choices=utils.get_available_activation_fns(), - help='activation function to use') - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', type=float, metavar='D', - help='dropout probability for attention weights') - parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', - help='dropout probability after activation in FFN.') - parser.add_argument('--encoder-embed-path', type=str, metavar='STR', - help='path to pre-trained encoder embedding') - parser.add_argument('--encoder-embed-dim', type=int, metavar='N', - help='encoder embedding dimension') - parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', - help='encoder embedding dimension 
for FFN') - parser.add_argument('--encoder-layers', type=int, metavar='N', - help='num encoder layers') - parser.add_argument('--encoder-attention-heads', type=int, metavar='N', - help='num encoder attention heads') - parser.add_argument('--encoder-normalize-before', action='store_true', - help='apply layernorm before each encoder block') - parser.add_argument('--encoder-learned-pos', action='store_true', - help='use learned positional embeddings in the encoder') - parser.add_argument('--decoder-embed-path', type=str, metavar='STR', - help='path to pre-trained decoder embedding') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', - help='decoder embedding dimension for FFN') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='num decoder layers') - parser.add_argument('--decoder-attention-heads', type=int, metavar='N', - help='num decoder attention heads') - parser.add_argument('--decoder-learned-pos', action='store_true', - help='use learned positional embeddings in the decoder') - parser.add_argument('--decoder-normalize-before', action='store_true', - help='apply layernorm before each decoder block') - parser.add_argument('--decoder-output-dim', type=int, metavar='N', - help='decoder output dimension (extra linear layer ' - 'if different from decoder embed dim') - parser.add_argument('--share-decoder-input-output-embed', action='store_true', - help='share decoder input and output embeddings') - parser.add_argument('--share-all-embeddings', action='store_true', - help='share encoder, decoder and output embeddings' - ' (requires shared dictionary and embed dim)') - parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', - help='if set, disables positional embeddings (outside self attention)') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated 
list of adaptive softmax cutoff points. ' - 'Must be used with adaptive_loss criterion'), - parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', - help='sets adaptive softmax dropout for the tail projections') - parser.add_argument('--layernorm-embedding', action='store_true', - help='add layernorm to embedding') - parser.add_argument('--no-scale-embedding', action='store_true', - help='if True, dont scale embeddings') - parser.add_argument('--checkpoint-activations', action='store_true', - help='checkpoint activations at each layer, which saves GPU ' - 'memory usage at the cost of some additional compute') - parser.add_argument('--offload-activations', action='store_true', - help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.') - # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) - parser.add_argument('--no-cross-attention', default=False, action='store_true', - help='do not perform cross-attention') - parser.add_argument('--cross-self-attention', default=False, action='store_true', - help='perform cross+self-attention') - # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) - parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, - help='LayerDrop probability for encoder') - parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, - help='LayerDrop probability for decoder') - parser.add_argument('--encoder-layers-to-keep', default=None, - help='which layers to *keep* when pruning as a comma-separated list') - parser.add_argument('--decoder-layers-to-keep', default=None, - help='which layers to *keep* when pruning as a comma-separated list') - # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) - parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, - help='iterative PQ quantization noise at training time') - 
parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, - help='block size of quantization noise at training time') - parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, - help='scalar quantization noise and scalar quantization at training time') - # args for Fully Sharded Data Parallel (FSDP) training - parser.add_argument( - '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, - help=( - 'minimum number of params for a layer to be wrapped with FSDP() when ' - 'training with --ddp-backend=fully_sharded. Smaller values will ' - 'improve memory efficiency, but may make torch.distributed ' - 'communication less efficient due to smaller input sizes. This option ' - 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' - '--offload-activations are passed.' - ) - ) - # fmt: on - - @classmethod - def build_model(cls, args, task): - """Build a new model instance.""" - - # make sure all arguments are present in older models - base_architecture(args) - - if args.encoder_layers_to_keep: - args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - if args.decoder_layers_to_keep: - args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) - - if getattr(args, "max_source_positions", None) is None: - args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS - if getattr(args, "max_target_positions", None) is None: - args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS - - src_dict, tgt_dict = task.source_dictionary, task.target_dictionary - - if args.share_all_embeddings: - if src_dict != tgt_dict: - raise ValueError("--share-all-embeddings requires a joined dictionary") - if args.encoder_embed_dim != args.decoder_embed_dim: - raise ValueError( - "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" - ) - if args.decoder_embed_path and ( - args.decoder_embed_path != args.encoder_embed_path - ): - raise ValueError( - 
"--share-all-embeddings not compatible with --decoder-embed-path" - ) - encoder_embed_tokens = cls.build_embedding( - args, src_dict, args.encoder_embed_dim, args.encoder_embed_path - ) - decoder_embed_tokens = encoder_embed_tokens - args.share_decoder_input_output_embed = True - else: - encoder_embed_tokens = cls.build_embedding( - args, src_dict, args.encoder_embed_dim, args.encoder_embed_path - ) - decoder_embed_tokens = cls.build_embedding( - args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path - ) - if getattr(args, "offload_activations", False): - args.checkpoint_activations = True # offloading implies checkpointing - encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) - decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) - if not args.share_all_embeddings: - min_params_to_wrap = getattr( - args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP - ) - # fsdp_wrap is a no-op when --ddp-backend != fully_sharded - encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) - decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) - return cls(args, encoder, decoder) - - @classmethod - def build_embedding(cls, args, dictionary, embed_dim, path=None): - num_embeddings = len(dictionary) - padding_idx = dictionary.pad() - - emb = Embedding(num_embeddings, embed_dim, padding_idx) - # if provided, load from preloaded dictionaries - if path: - embed_dict = utils.parse_embedding(path) - utils.load_embedding(embed_dict, dictionary, emb) - return emb - - @classmethod - def build_encoder(cls, args, src_dict, embed_tokens): - return TransformerEncoder(args, src_dict, embed_tokens) - - @classmethod - def build_decoder(cls, args, tgt_dict, embed_tokens): - return TransformerDecoder( - args, - tgt_dict, - embed_tokens, - no_encoder_attn=getattr(args, "no_cross_attention", False), - ) - - # TorchScript doesn't support optional arguments with variable length (**kwargs). 
- # Current workaround is to add union of all arguments in child classes. - def forward( - self, - src_tokens, - src_lengths, - prev_output_tokens, - return_all_hiddens: bool = True, - features_only: bool = False, - alignment_layer: Optional[int] = None, - alignment_heads: Optional[int] = None, - ): - """ - Run the forward pass for an encoder-decoder model. - - Copied from the base class, but without ``**kwargs``, - which are not supported by TorchScript. - """ - encoder_out = self.encoder( - src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens - ) - decoder_out = self.decoder( - prev_output_tokens, - encoder_out=encoder_out, - features_only=features_only, - alignment_layer=alignment_layer, - alignment_heads=alignment_heads, - src_lengths=src_lengths, - return_all_hiddens=return_all_hiddens, - ) - return decoder_out - - # Since get_normalized_probs is in the Fairseq Model which is not scriptable, - # I rewrite the get_normalized_probs from Base Class to call the - # helper function in the Base Class. - @torch.jit.export - def get_normalized_probs( - self, - net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], - log_probs: bool, - sample: Optional[Dict[str, Tensor]] = None, - ): - """Get normalized probabilities (or log probs) from a net's output.""" - return self.get_normalized_probs_scriptable(net_output, log_probs, sample) - - -class TransformerEncoder(FairseqEncoder): - """ - Transformer encoder consisting of *args.encoder_layers* layers. Each layer - is a :class:`TransformerEncoderLayer`. 
- - Args: - args (argparse.Namespace): parsed command-line arguments - dictionary (~fairseq.data.Dictionary): encoding dictionary - embed_tokens (torch.nn.Embedding): input embedding - """ - - def __init__(self, args, dictionary, embed_tokens): - self.args = args - super().__init__(dictionary) - self.register_buffer("version", torch.Tensor([3])) - - self.dropout_module = FairseqDropout( - args.dropout, module_name=self.__class__.__name__ - ) - self.encoder_layerdrop = args.encoder_layerdrop - - embed_dim = embed_tokens.embedding_dim - self.padding_idx = embed_tokens.padding_idx - self.max_source_positions = args.max_source_positions - - self.embed_tokens = embed_tokens - - self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) - - self.embed_positions = ( - PositionalEmbedding( - args.max_source_positions, - embed_dim, - self.padding_idx, - learned=args.encoder_learned_pos, - ) - if not args.no_token_positional_embeddings - else None - ) - export = getattr(args, "export", False) - if getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim, export=export) - else: - self.layernorm_embedding = None - - if not args.adaptive_input and args.quant_noise_pq > 0: - self.quant_noise = apply_quant_noise_( - nn.Linear(embed_dim, embed_dim, bias=False), - args.quant_noise_pq, - args.quant_noise_pq_block_size, - ) - else: - self.quant_noise = None - - if self.encoder_layerdrop > 0.0: - self.layers = LayerDropModuleList(p=self.encoder_layerdrop) - else: - self.layers = nn.ModuleList([]) - self.layers.extend( - [self.build_encoder_layer(args) for i in range(args.encoder_layers)] - ) - self.num_layers = len(self.layers) - - if args.encoder_normalize_before: - self.layer_norm = LayerNorm(embed_dim, export=export) - else: - self.layer_norm = None - - def build_encoder_layer(self, args): - layer = TransformerEncoderLayer(args) - checkpoint = getattr(args, "checkpoint_activations", False) - if checkpoint: - offload_to_cpu = 
getattr(args, "offload_activations", False) - layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - # if we are checkpointing, enforce that FSDP always wraps the - # checkpointed layer, regardless of layer size - min_params_to_wrap = ( - getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) - if not checkpoint - else 0 - ) - layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) - return layer - - def forward_embedding( - self, src_tokens, token_embedding: Optional[torch.Tensor] = None - ): - # embed tokens and positions - if token_embedding is None: - token_embedding = self.embed_tokens(src_tokens) - x = embed = self.embed_scale * token_embedding - if self.embed_positions is not None: - x = embed + self.embed_positions(src_tokens) - if self.layernorm_embedding is not None: - x = self.layernorm_embedding(x) - x = self.dropout_module(x) - if self.quant_noise is not None: - x = self.quant_noise(x) - return x, embed - - def forward( - self, - src_tokens, - src_lengths: Optional[torch.Tensor] = None, - return_all_hiddens: bool = False, - token_embeddings: Optional[torch.Tensor] = None, - ): - """ - Args: - src_tokens (LongTensor): tokens in the source language of shape - `(batch, src_len)` - src_lengths (torch.LongTensor): lengths of each source sentence of - shape `(batch)` - return_all_hiddens (bool, optional): also return all of the - intermediate hidden states (default: False). - token_embeddings (torch.Tensor, optional): precomputed embeddings - default `None` will recompute embeddings - - Returns: - dict: - - **encoder_out** (Tensor): the last encoder layer's output of - shape `(src_len, batch, embed_dim)` - - **encoder_padding_mask** (ByteTensor): the positions of - padding elements of shape `(batch, src_len)` - - **encoder_embedding** (Tensor): the (scaled) embedding lookup - of shape `(batch, src_len, embed_dim)` - - **encoder_states** (List[Tensor]): all intermediate - hidden states of shape `(src_len, batch, embed_dim)`. 
- Only populated if *return_all_hiddens* is True. - """ - return self.forward_scriptable( - src_tokens, src_lengths, return_all_hiddens, token_embeddings - ) - - # TorchScript doesn't support super() method so that the scriptable Subclass - # can't access the base class model in Torchscript. - # Current workaround is to add a helper function with different name and - # call the helper function from scriptable Subclass. - def forward_scriptable( - self, - src_tokens, - src_lengths: Optional[torch.Tensor] = None, - return_all_hiddens: bool = False, - token_embeddings: Optional[torch.Tensor] = None, - ): - """ - Args: - src_tokens (LongTensor): tokens in the source language of shape - `(batch, src_len)` - src_lengths (torch.LongTensor): lengths of each source sentence of - shape `(batch)` - return_all_hiddens (bool, optional): also return all of the - intermediate hidden states (default: False). - token_embeddings (torch.Tensor, optional): precomputed embeddings - default `None` will recompute embeddings - - Returns: - dict: - - **encoder_out** (Tensor): the last encoder layer's output of - shape `(src_len, batch, embed_dim)` - - **encoder_padding_mask** (ByteTensor): the positions of - padding elements of shape `(batch, src_len)` - - **encoder_embedding** (Tensor): the (scaled) embedding lookup - of shape `(batch, src_len, embed_dim)` - - **encoder_states** (List[Tensor]): all intermediate - hidden states of shape `(src_len, batch, embed_dim)`. - Only populated if *return_all_hiddens* is True. 
- """ - # compute padding mask - encoder_padding_mask = src_tokens.eq(self.padding_idx) - has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() - - x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) - - # account for padding while computing the representation - if has_pads: - x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) - - # B x T x C -> T x B x C - x = x.transpose(0, 1) - - encoder_states = [] - - if return_all_hiddens: - encoder_states.append(x) - - # encoder layers - for layer in self.layers: - x = layer( - x, encoder_padding_mask=encoder_padding_mask if has_pads else None - ) - if return_all_hiddens: - assert encoder_states is not None - encoder_states.append(x) - - if self.layer_norm is not None: - x = self.layer_norm(x) - - # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in - # `forward` so we use a dictionary instead. - # TorchScript does not support mixed values so the values are all lists. - # The empty list is equivalent to None. - src_lengths = src_tokens.ne(self.padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1).contiguous() - return { - "encoder_out": [x], # T x B x C - "encoder_padding_mask": [encoder_padding_mask], # B x T - "encoder_embedding": [encoder_embedding], # B x T x C - "encoder_states": encoder_states, # List[T x B x C] - "src_tokens": [], - "src_lengths": [src_lengths], - } - - @torch.jit.export - def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): - """ - Reorder encoder output according to *new_order*. 
- - Args: - encoder_out: output from the ``forward()`` method - new_order (LongTensor): desired order - - Returns: - *encoder_out* rearranged according to *new_order* - """ - if len(encoder_out["encoder_out"]) == 0: - new_encoder_out = [] - else: - new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] - if len(encoder_out["encoder_padding_mask"]) == 0: - new_encoder_padding_mask = [] - else: - new_encoder_padding_mask = [ - encoder_out["encoder_padding_mask"][0].index_select(0, new_order) - ] - if len(encoder_out["encoder_embedding"]) == 0: - new_encoder_embedding = [] - else: - new_encoder_embedding = [ - encoder_out["encoder_embedding"][0].index_select(0, new_order) - ] - - if len(encoder_out["src_tokens"]) == 0: - src_tokens = [] - else: - src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] - - if len(encoder_out["src_lengths"]) == 0: - src_lengths = [] - else: - src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] - - encoder_states = encoder_out["encoder_states"] - if len(encoder_states) > 0: - for idx, state in enumerate(encoder_states): - encoder_states[idx] = state.index_select(1, new_order) - - return { - "encoder_out": new_encoder_out, # T x B x C - "encoder_padding_mask": new_encoder_padding_mask, # B x T - "encoder_embedding": new_encoder_embedding, # B x T x C - "encoder_states": encoder_states, # List[T x B x C] - "src_tokens": src_tokens, # B x T - "src_lengths": src_lengths, # B x 1 - } - - def max_positions(self): - """Maximum input length supported by the encoder.""" - if self.embed_positions is None: - return self.max_source_positions - return min(self.max_source_positions, self.embed_positions.max_positions) - - def upgrade_state_dict_named(self, state_dict, name): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" - if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): - weights_key = "{}.embed_positions.weights".format(name) - if weights_key 
in state_dict: - print("deleting {0}".format(weights_key)) - del state_dict[weights_key] - state_dict[ - "{}.embed_positions._float_tensor".format(name) - ] = torch.FloatTensor(1) - for i in range(self.num_layers): - # update layer norms - self.layers[i].upgrade_state_dict_named( - state_dict, "{}.layers.{}".format(name, i) - ) - - version_key = "{}.version".format(name) - if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: - # earlier checkpoints did not normalize after the stack of layers - self.layer_norm = None - self.normalize = False - state_dict[version_key] = torch.Tensor([1]) - return state_dict - - -class TransformerDecoder(FairseqIncrementalDecoder): - """ - Transformer decoder consisting of *args.decoder_layers* layers. Each layer - is a :class:`TransformerDecoderLayer`. - - Args: - args (argparse.Namespace): parsed command-line arguments - dictionary (~fairseq.data.Dictionary): decoding dictionary - embed_tokens (torch.nn.Embedding): output embedding - no_encoder_attn (bool, optional): whether to attend to encoder outputs - (default: False). 
- """ - - def __init__( - self, - args, - dictionary, - embed_tokens, - no_encoder_attn=False, - output_projection=None, - ): - self.args = args - super().__init__(dictionary) - self.register_buffer("version", torch.Tensor([3])) - self._future_mask = torch.empty(0) - - self.dropout_module = FairseqDropout( - args.dropout, module_name=self.__class__.__name__ - ) - self.decoder_layerdrop = args.decoder_layerdrop - self.share_input_output_embed = args.share_decoder_input_output_embed - - input_embed_dim = embed_tokens.embedding_dim - embed_dim = args.decoder_embed_dim - self.embed_dim = embed_dim - self.output_embed_dim = args.decoder_output_dim - - self.padding_idx = embed_tokens.padding_idx - self.max_target_positions = args.max_target_positions - - self.embed_tokens = embed_tokens - - self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) - - if not args.adaptive_input and args.quant_noise_pq > 0: - self.quant_noise = apply_quant_noise_( - nn.Linear(embed_dim, embed_dim, bias=False), - args.quant_noise_pq, - args.quant_noise_pq_block_size, - ) - else: - self.quant_noise = None - - self.project_in_dim = ( - Linear(input_embed_dim, embed_dim, bias=False) - if embed_dim != input_embed_dim - else None - ) - self.embed_positions = ( - PositionalEmbedding( - self.max_target_positions, - embed_dim, - self.padding_idx, - learned=args.decoder_learned_pos, - ) - if not args.no_token_positional_embeddings - else None - ) - export = getattr(args, "export", False) - if getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim, export=export) - else: - self.layernorm_embedding = None - - self.cross_self_attention = getattr(args, "cross_self_attention", False) - - if self.decoder_layerdrop > 0.0: - self.layers = LayerDropModuleList(p=self.decoder_layerdrop) - else: - self.layers = nn.ModuleList([]) - self.layers.extend( - [ - self.build_decoder_layer(args, no_encoder_attn) - for _ in range(args.decoder_layers) - ] - ) - 
self.num_layers = len(self.layers) - - if args.decoder_normalize_before and not getattr( - args, "no_decoder_final_norm", False - ): - self.layer_norm = LayerNorm(embed_dim, export=export) - else: - self.layer_norm = None - - self.project_out_dim = ( - Linear(embed_dim, self.output_embed_dim, bias=False) - if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights - else None - ) - - self.adaptive_softmax = None - self.output_projection = output_projection - if self.output_projection is None: - self.build_output_projection(args, dictionary, embed_tokens) - - def build_output_projection(self, args, dictionary, embed_tokens): - if args.adaptive_softmax_cutoff is not None: - self.adaptive_softmax = AdaptiveSoftmax( - len(dictionary), - self.output_embed_dim, - utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), - dropout=args.adaptive_softmax_dropout, - adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, - factor=args.adaptive_softmax_factor, - tie_proj=args.tie_adaptive_proj, - ) - elif self.share_input_output_embed: - self.output_projection = nn.Linear( - self.embed_tokens.weight.shape[1], - self.embed_tokens.weight.shape[0], - bias=False, - ) - self.output_projection.weight = self.embed_tokens.weight - else: - self.output_projection = nn.Linear( - self.output_embed_dim, len(dictionary), bias=False - ) - nn.init.normal_( - self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 - ) - num_base_layers = getattr(args, "base_layers", 0) - for i in range(num_base_layers): - self.layers.insert( - ((i + 1) * args.decoder_layers) // (num_base_layers + 1), - BaseLayer(args), - ) - - def build_decoder_layer(self, args, no_encoder_attn=False): - layer = TransformerDecoderLayer(args, no_encoder_attn) - checkpoint = getattr(args, "checkpoint_activations", False) - if checkpoint: - offload_to_cpu = getattr(args, "offload_activations", False) - layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - # if we are 
checkpointing, enforce that FSDP always wraps the - # checkpointed layer, regardless of layer size - min_params_to_wrap = ( - getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) - if not checkpoint - else 0 - ) - layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) - return layer - - def forward( - self, - prev_output_tokens, - encoder_out: Optional[Dict[str, List[Tensor]]] = None, - incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - features_only: bool = False, - full_context_alignment: bool = False, - alignment_layer: Optional[int] = None, - alignment_heads: Optional[int] = None, - src_lengths: Optional[Any] = None, - return_all_hiddens: bool = False, - ): - """ - Args: - prev_output_tokens (LongTensor): previous decoder outputs of shape - `(batch, tgt_len)`, for teacher forcing - encoder_out (optional): output from the encoder, used for - encoder-side attention, should be of size T x B x C - incremental_state (dict): dictionary used for storing state during - :ref:`Incremental decoding` - features_only (bool, optional): only return features without - applying output layer (default: False). - full_context_alignment (bool, optional): don't apply - auto-regressive mask to self-attention (default: False). 
- - Returns: - tuple: - - the decoder's output of shape `(batch, tgt_len, vocab)` - - a dictionary with any model-specific outputs - """ - - x, extra = self.extract_features( - prev_output_tokens, - encoder_out=encoder_out, - incremental_state=incremental_state, - full_context_alignment=full_context_alignment, - alignment_layer=alignment_layer, - alignment_heads=alignment_heads, - ) - - if not features_only: - x = self.output_layer(x) - return x, extra - - def extract_features( - self, - prev_output_tokens, - encoder_out: Optional[Dict[str, List[Tensor]]], - incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - full_context_alignment: bool = False, - alignment_layer: Optional[int] = None, - alignment_heads: Optional[int] = None, - ): - return self.extract_features_scriptable( - prev_output_tokens, - encoder_out, - incremental_state, - full_context_alignment, - alignment_layer, - alignment_heads, - ) - - """ - A scriptable subclass of this class has an extract_features method and calls - super().extract_features, but super() is not supported in torchscript. A copy of - this function is made to be used in the subclass instead. - """ - - def extract_features_scriptable( - self, - prev_output_tokens, - encoder_out: Optional[Dict[str, List[Tensor]]], - incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - full_context_alignment: bool = False, - alignment_layer: Optional[int] = None, - alignment_heads: Optional[int] = None, - ): - """ - Similar to *forward* but only return features. - - Includes several features from "Jointly Learning to Align and - Translate with Transformer Models" (Garg et al., EMNLP 2019). - - Args: - full_context_alignment (bool, optional): don't apply - auto-regressive mask to self-attention (default: False). - alignment_layer (int, optional): return mean alignment over - heads at this layer (default: last layer). 
- alignment_heads (int, optional): only average alignment over - this many heads (default: all heads). - - Returns: - tuple: - - the decoder's features of shape `(batch, tgt_len, embed_dim)` - - a dictionary with any model-specific outputs - """ - bs, slen = prev_output_tokens.size() - if alignment_layer is None: - alignment_layer = self.num_layers - 1 - - enc: Optional[Tensor] = None - padding_mask: Optional[Tensor] = None - if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: - enc = encoder_out["encoder_out"][0] - assert ( - enc.size()[1] == bs - ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" - if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: - padding_mask = encoder_out["encoder_padding_mask"][0] - - # embed positions - positions = None - if self.embed_positions is not None: - positions = self.embed_positions( - prev_output_tokens, incremental_state=incremental_state - ) - - if incremental_state is not None: - prev_output_tokens = prev_output_tokens[:, -1:] - if positions is not None: - positions = positions[:, -1:] - - # embed tokens and positions - x = self.embed_scale * self.embed_tokens(prev_output_tokens) - - if self.quant_noise is not None: - x = self.quant_noise(x) - - if self.project_in_dim is not None: - x = self.project_in_dim(x) - - if positions is not None: - x += positions - - if self.layernorm_embedding is not None: - x = self.layernorm_embedding(x) - - x = self.dropout_module(x) - - # B x T x C -> T x B x C - x = x.transpose(0, 1) - - self_attn_padding_mask: Optional[Tensor] = None - if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): - self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) - - # decoder layers - attn: Optional[Tensor] = None - inner_states: List[Optional[Tensor]] = [x] - for idx, layer in enumerate(self.layers): - if incremental_state is None and not full_context_alignment: - self_attn_mask = self.buffered_future_mask(x) - else: - 
self_attn_mask = None - - x, layer_attn, _ = layer( - x, - enc, - padding_mask, - incremental_state, - self_attn_mask=self_attn_mask, - self_attn_padding_mask=self_attn_padding_mask, - need_attn=bool((idx == alignment_layer)), - need_head_weights=bool((idx == alignment_layer)), - ) - inner_states.append(x) - if layer_attn is not None and idx == alignment_layer: - attn = layer_attn.float().to(x) - - if attn is not None: - if alignment_heads is not None: - attn = attn[:alignment_heads] - - # average probabilities over heads - attn = attn.mean(dim=0) - - if self.layer_norm is not None: - x = self.layer_norm(x) - - # T x B x C -> B x T x C - x = x.transpose(0, 1) - - if self.project_out_dim is not None: - x = self.project_out_dim(x) - - return x, {"attn": [attn], "inner_states": inner_states} - - def output_layer(self, features): - """Project features to the vocabulary size.""" - if self.adaptive_softmax is None: - # project back to size of vocabulary - return self.output_projection(features) - else: - return features - - def max_positions(self): - """Maximum output length supported by the decoder.""" - if self.embed_positions is None: - return self.max_target_positions - return min(self.max_target_positions, self.embed_positions.max_positions) - - def buffered_future_mask(self, tensor): - dim = tensor.size(0) - # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 
- if ( - self._future_mask.size(0) == 0 - or (not self._future_mask.device == tensor.device) - or self._future_mask.size(0) < dim - ): - self._future_mask = torch.triu( - utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 - ) - self._future_mask = self._future_mask.to(tensor) - return self._future_mask[:dim, :dim] - - def upgrade_state_dict_named(self, state_dict, name): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" - if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): - weights_key = "{}.embed_positions.weights".format(name) - if weights_key in state_dict: - del state_dict[weights_key] - state_dict[ - "{}.embed_positions._float_tensor".format(name) - ] = torch.FloatTensor(1) - - if f"{name}.output_projection.weight" not in state_dict: - if self.share_input_output_embed: - embed_out_key = f"{name}.embed_tokens.weight" - else: - embed_out_key = f"{name}.embed_out" - if embed_out_key in state_dict: - state_dict[f"{name}.output_projection.weight"] = state_dict[ - embed_out_key - ] - if not self.share_input_output_embed: - del state_dict[embed_out_key] - - for i in range(self.num_layers): - # update layer norms - layer_norm_map = { - "0": "self_attn_layer_norm", - "1": "encoder_attn_layer_norm", - "2": "final_layer_norm", - } - for old, new in layer_norm_map.items(): - for m in ("weight", "bias"): - k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) - if k in state_dict: - state_dict[ - "{}.layers.{}.{}.{}".format(name, i, new, m) - ] = state_dict[k] - del state_dict[k] - - version_key = "{}.version".format(name) - if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: - # earlier checkpoints did not normalize after the stack of layers - self.layer_norm = None - self.normalize = False - state_dict[version_key] = torch.Tensor([1]) - - return state_dict - - -def Embedding(num_embeddings, embedding_dim, padding_idx): - m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - 
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) - nn.init.constant_(m.weight[padding_idx], 0) - return m - - -def Linear(in_features, out_features, bias=True): - m = nn.Linear(in_features, out_features, bias) - nn.init.xavier_uniform_(m.weight) - if bias: - nn.init.constant_(m.bias, 0.0) - return m - - -@register_model_architecture("transformer", "transformer_tiny") -def tiny_architecture(args): - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64) - args.encoder_layers = getattr(args, "encoder_layers", 2) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2) - args.decoder_layers = getattr(args, "decoder_layers", 2) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2) - return base_architecture(args) - - -@register_model_architecture("transformer", "transformer") -def base_architecture(args): - args.encoder_embed_path = getattr(args, "encoder_embed_path", None) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) - args.encoder_layers = getattr(args, "encoder_layers", 6) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) - args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) - args.decoder_embed_path = getattr(args, "decoder_embed_path", None) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) - args.decoder_ffn_embed_dim = getattr( - args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim - ) - args.decoder_layers = getattr(args, "decoder_layers", 6) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) - args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) - args.decoder_learned_pos = getattr(args, 
"decoder_learned_pos", False) - args.attention_dropout = getattr(args, "attention_dropout", 0.0) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.activation_fn = getattr(args, "activation_fn", "relu") - args.dropout = getattr(args, "dropout", 0.1) - args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) - args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) - args.share_decoder_input_output_embed = getattr( - args, "share_decoder_input_output_embed", False - ) - args.share_all_embeddings = getattr(args, "share_all_embeddings", False) - args.no_token_positional_embeddings = getattr( - args, "no_token_positional_embeddings", False - ) - args.adaptive_input = getattr(args, "adaptive_input", False) - args.no_cross_attention = getattr(args, "no_cross_attention", False) - args.cross_self_attention = getattr(args, "cross_self_attention", False) - - args.decoder_output_dim = getattr( - args, "decoder_output_dim", args.decoder_embed_dim - ) - args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) - - args.no_scale_embedding = getattr(args, "no_scale_embedding", False) - args.layernorm_embedding = getattr(args, "layernorm_embedding", False) - args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) - args.checkpoint_activations = getattr(args, "checkpoint_activations", False) - args.offload_activations = getattr(args, "offload_activations", False) - if args.offload_activations: - args.checkpoint_activations = True - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) - args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) - args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) - args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) - args.quant_noise_scalar 
= getattr(args, "quant_noise_scalar", 0) - - -@register_model_architecture("transformer", "transformer_iwslt_de_en") -def transformer_iwslt_de_en(args): - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) - args.encoder_layers = getattr(args, "encoder_layers", 6) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) - args.decoder_layers = getattr(args, "decoder_layers", 6) - base_architecture(args) - - -@register_model_architecture("transformer", "transformer_wmt_en_de") -def transformer_wmt_en_de(args): - base_architecture(args) - - -# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) -@register_model_architecture("transformer", "transformer_vaswani_wmt_en_de_big") -def transformer_vaswani_wmt_en_de_big(args): - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) - args.dropout = getattr(args, "dropout", 0.3) - base_architecture(args) - - -@register_model_architecture("transformer", "transformer_vaswani_wmt_en_fr_big") -def transformer_vaswani_wmt_en_fr_big(args): - args.dropout = getattr(args, "dropout", 0.1) - transformer_vaswani_wmt_en_de_big(args) - - -@register_model_architecture("transformer", "transformer_wmt_en_de_big") -def 
transformer_wmt_en_de_big(args): - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - transformer_vaswani_wmt_en_de_big(args) - - -# default parameters used in tensor2tensor implementation -@register_model_architecture("transformer", "transformer_wmt_en_de_big_t2t") -def transformer_wmt_en_de_big_t2t(args): - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) - args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.1) - transformer_vaswani_wmt_en_de_big(args) diff --git a/fairseq/models/transformer/__init__.py b/fairseq/models/transformer/__init__.py new file mode 100644 index 0000000000..6809adeab7 --- /dev/null +++ b/fairseq/models/transformer/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+"""isort:skip_file""" + +from .transformer_decoder import TransformerDecoder, Linear +from .transformer_encoder import TransformerEncoder +from .transformer_model import ( + DEFAULT_MAX_SOURCE_POSITIONS, + DEFAULT_MAX_TARGET_POSITIONS, + DEFAULT_MIN_PARAMS_TO_WRAP, + TransformerModel, + base_architecture, + tiny_architecture, + transformer_iwslt_de_en, + transformer_wmt_en_de, + transformer_vaswani_wmt_en_de_big, + transformer_vaswani_wmt_en_fr_big, + transformer_wmt_en_de_big, + transformer_wmt_en_de_big_t2t, + Embedding, +) + + +__all__ = [ + "TransformerConfig", + "TransformerDecoder", + "TransformerEncoder", + "TransformerModel", + "Embedding", + "Linear", + "base_architecture", + "tiny_architecture", + "transformer_iwslt_de_en", + "transformer_wmt_en_de", + "transformer_vaswani_wmt_en_de_big", + "transformer_vaswani_wmt_en_fr_big", + "transformer_wmt_en_de_big", + "transformer_wmt_en_de_big_t2t", + "DEFAULT_MAX_SOURCE_POSITIONS", + "DEFAULT_MAX_TARGET_POSITIONS", + "DEFAULT_MIN_PARAMS_TO_WRAP", +] diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py new file mode 100644 index 0000000000..ca9e737e60 --- /dev/null +++ b/fairseq/models/transformer/transformer_decoder.py @@ -0,0 +1,452 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models import FairseqIncrementalDecoder +from fairseq.models.transformer import transformer_model +from fairseq.modules import ( + AdaptiveSoftmax, + BaseLayer, + FairseqDropout, + LayerDropModuleList, + LayerNorm, + PositionalEmbedding, + SinusoidalPositionalEmbedding, +) +from fairseq.modules import transformer_layer +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ +from torch import Tensor + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). 
+ """ + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + self.args = args + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + self._future_mask = torch.empty(0) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.decoder_layerdrop = args.decoder_layerdrop + self.share_input_output_embed = args.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + self.embed_dim = embed_dim + self.output_embed_dim = args.decoder_output_dim + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + + if not args.adaptive_input and args.quant_noise_pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + self.quant_noise = None + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + self.embed_positions = ( + PositionalEmbedding( + self.max_target_positions, + embed_dim, + self.padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + export = getattr(args, "export", False) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim, export=export) + else: + self.layernorm_embedding = None + + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) + 
self.num_layers = len(self.layers) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim, export=export) + else: + self.layer_norm = None + + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + else None + ) + + self.adaptive_softmax = None + self.output_projection = output_projection + if self.output_projection is None: + self.build_output_projection(args, dictionary, embed_tokens) + + def build_output_projection(self, args, dictionary, embed_tokens): + if args.adaptive_softmax_cutoff is not None: + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif self.share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + self.output_projection = nn.Linear( + self.output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 + ) + num_base_layers = getattr(args, "base_layers", 0) + for i in range(num_base_layers): + self.layers.insert( + ((i + 1) * args.decoder_layers) // (num_base_layers + 1), + BaseLayer(args), + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + layer = transformer_layer.TransformerDecoderLayer(args, no_encoder_attn) + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: + offload_to_cpu = getattr(args, "offload_activations", False) + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) 
+ # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", transformer_model.DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint + else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). 
+ + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + x = self.output_layer(x) + return x, extra + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy of + this function is made to be used in the subclass instead. + """ + + def extract_features_scriptable( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). 
+ alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + bs, slen = prev_output_tokens.size() + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: + enc = encoder_out["encoder_out"][0] + assert ( + enc.size()[1] == bs + ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + 
self_attn_mask = None + + x, layer_attn, _ = layer( + x, + enc, + padding_mask, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def output_layer(self, features): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + return self.output_projection(features) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 
+ if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 + ) + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = "{}.embed_positions.weights".format(name) + if weights_key in state_dict: + del state_dict[weights_key] + state_dict[ + "{}.embed_positions._float_tensor".format(name) + ] = torch.FloatTensor(1) + + if f"{name}.output_projection.weight" not in state_dict: + if self.share_input_output_embed: + embed_out_key = f"{name}.embed_tokens.weight" + else: + embed_out_key = f"{name}.embed_out" + if embed_out_key in state_dict: + state_dict[f"{name}.output_projection.weight"] = state_dict[ + embed_out_key + ] + if not self.share_input_output_embed: + del state_dict[embed_out_key] + + for i in range(self.num_layers): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: 
+ nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py new file mode 100644 index 0000000000..6e57af5433 --- /dev/null +++ b/fairseq/models/transformer/transformer_encoder.py @@ -0,0 +1,322 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models import FairseqEncoder +from fairseq.modules import ( + FairseqDropout, + LayerDropModuleList, + LayerNorm, + PositionalEmbedding, + SinusoidalPositionalEmbedding, +) +from fairseq.modules.transformer_layer import TransformerEncoderLayer +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ +from torch import Tensor +from fairseq.models.transformer import transformer_model + + +class TransformerEncoder(FairseqEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens): + self.args = args + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.encoder_layerdrop = args.encoder_layerdrop + + embed_dim = embed_tokens.embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = args.max_source_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + export = getattr(args, "export", False) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim, export=export) + else: + self.layernorm_embedding = None + + if not args.adaptive_input and args.quant_noise_pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + self.quant_noise = None + + if self.encoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.encoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [self.build_encoder_layer(args) for i in range(args.encoder_layers)] + ) + self.num_layers = len(self.layers) + + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim, export=export) + else: + self.layer_norm = None + + def build_encoder_layer(self, args): + layer = TransformerEncoderLayer(args) + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: + offload_to_cpu = 
getattr(args, "offload_activations", False) + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", transformer_model.DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint + else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward_embedding( + self, src_tokens, token_embedding: Optional[torch.Tensor] = None + ): + # embed tokens and positions + if token_embedding is None: + token_embedding = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * token_embedding + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + if self.quant_noise is not None: + x = self.quant_noise(x) + return x, embed + + def forward( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). 
+ token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + return self.forward_scriptable( + src_tokens, src_lengths, return_all_hiddens, token_embeddings + ) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def forward_scriptable( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). 
+ token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() + + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) + + # account for padding while computing the representation + if has_pads: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + encoder_states = [] + + if return_all_hiddens: + encoder_states.append(x) + + # encoder layers + for layer in self.layers: + x = layer( + x, encoder_padding_mask=encoder_padding_mask if has_pads else None + ) + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # The PyTorch Mobile lite interpreter does not support returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None.
+ src_lengths = src_tokens.ne(self.padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1).contiguous() + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [src_lengths], + } + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + encoder_out["encoder_embedding"][0].index_select(0, new_order) + ] + + if len(encoder_out["src_tokens"]) == 0: + src_tokens = [] + else: + src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] + + if len(encoder_out["src_lengths"]) == 0: + src_lengths = [] + else: + src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": src_tokens, # B x T + 
"src_lengths": src_lengths, # B x 1 + } + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embed_positions is None: + return self.max_source_positions + return min(self.max_source_positions, self.embed_positions.max_positions) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = "{}.embed_positions.weights".format(name) + if weights_key in state_dict: + print("deleting {0}".format(weights_key)) + del state_dict[weights_key] + state_dict[ + "{}.embed_positions._float_tensor".format(name) + ] = torch.FloatTensor(1) + for i in range(self.num_layers): + # update layer norms + self.layers[i].upgrade_state_dict_named( + state_dict, "{}.layers.{}".format(name, i) + ) + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + return state_dict diff --git a/fairseq/models/transformer/transformer_model.py b/fairseq/models/transformer/transformer_model.py new file mode 100644 index 0000000000..7cc5b64ad1 --- /dev/null +++ b/fairseq/models/transformer/transformer_model.py @@ -0,0 +1,452 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict, List, Optional, Tuple + +import torch.nn as nn +import torch +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models import ( + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from torch import Tensor +from fairseq.models import transformer + + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + + +@register_model("transformer") +class TransformerModel(FairseqEncoderDecoderModel): + """ + Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) + <https://arxiv.org/abs/1706.03762>`_. + + Args: + encoder (TransformerEncoder): the encoder + decoder (TransformerDecoder): the decoder + + The Transformer model provides the following named architectures and + command-line arguments: + + .. argparse:: + :ref: fairseq.models.transformer_parser + :prog: + """ + + @classmethod + def hub_models(cls): + # fmt: off + + def moses_subword(path): + return { + 'path': path, + 'tokenizer': 'moses', + 'bpe': 'subword_nmt', + } + + def moses_fastbpe(path): + return { + 'path': path, + 'tokenizer': 'moses', + 'bpe': 'fastbpe', + } + + def spm(path): + return { + 'path': path, + 'bpe': 'sentencepiece', + 'tokenizer': 'space', + } + + return { + 'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'), + 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', + 'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'), + 'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'), + 'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'), + 'transformer.wmt19.de-en': 
moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'), + 'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'), + 'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'), + 'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'), + 'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'), + 'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'), + 'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'), + 'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'), + 'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'), + 'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'), + 'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'), + 'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'), + 'transformer.flores101.mm100.615M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz'), + 'transformer.flores101.mm100.175M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz'), + } + # fmt: on + + def __init__(self, args, encoder, decoder): + super().__init__(encoder, decoder) + self.args = args + self.supports_align_args = True + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the 
parser.""" + # fmt: off + parser.add_argument('--activation-fn', + choices=utils.get_available_activation_fns(), + help='activation function to use') + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', + help='dropout probability after activation in FFN.') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', + help='encoder embedding dimension for FFN') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='num encoder layers') + parser.add_argument('--encoder-attention-heads', type=int, metavar='N', + help='num encoder attention heads') + parser.add_argument('--encoder-normalize-before', action='store_true', + help='apply layernorm before each encoder block') + parser.add_argument('--encoder-learned-pos', action='store_true', + help='use learned positional embeddings in the encoder') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + 
parser.add_argument('--decoder-normalize-before', action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--decoder-output-dim', type=int, metavar='N', + help='decoder output dimension (extra linear layer ' + 'if different from decoder embed dim)') + parser.add_argument('--share-decoder-input-output-embed', action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', + help='if set, disables positional embeddings (outside self attention)') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion'), + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + parser.add_argument('--layernorm-embedding', action='store_true', + help='add layernorm to embedding') + parser.add_argument('--no-scale-embedding', action='store_true', + help='if set, do not scale embeddings') + parser.add_argument('--checkpoint-activations', action='store_true', + help='checkpoint activations at each layer, which saves GPU ' + 'memory usage at the cost of some additional compute') + parser.add_argument('--offload-activations', action='store_true', + help='checkpoint activations at each layer, then offload them to CPU. 
Sets --checkpoint-activations.') + # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) + parser.add_argument('--no-cross-attention', default=False, action='store_true', + help='do not perform cross-attention') + parser.add_argument('--cross-self-attention', default=False, action='store_true', + help='perform cross+self-attention') + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for encoder') + parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for decoder') + parser.add_argument('--encoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') + parser.add_argument('--decoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') + # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) + parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, + help='iterative PQ quantization noise at training time') + parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, + help='block size of quantization noise at training time') + parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, + help='scalar quantization noise and scalar quantization at training time') + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + 'minimum number of params for a layer to be wrapped with FSDP() when ' + 'training with --ddp-backend=fully_sharded. Smaller values will ' + 'improve memory efficiency, but may make torch.distributed ' + 'communication less efficient due to smaller input sizes. 
This option ' + 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' + '--offload-activations are passed.' + ) + ) + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_source_positions", None) is None: + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = cls.build_embedding( + args, src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = cls.build_embedding( + args, src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = cls.build_embedding( + args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path + ) + if getattr(args, "offload_activations", False): + args.checkpoint_activations = True # offloading implies checkpointing + encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(args, tgt_dict, 
decoder_embed_tokens) + if not args.share_all_embeddings: + min_params_to_wrap = getattr( + args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP + ) + # fsdp_wrap is a no-op when --ddp-backend != fully_sharded + encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) + decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) + return cls(args, encoder, decoder) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return transformer.TransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return transformer.TransformerDecoder( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + ) + + # TorchScript doesn't support optional arguments with variable length (**kwargs). + # Current workaround is to add union of all arguments in child classes. + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + return_all_hiddens: bool = True, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Run the forward pass for an encoder-decoder model. + + Copied from the base class, but without ``**kwargs``, + which are not supported by TorchScript. 
+ """ + encoder_out = self.encoder( + src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens + ) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + ) + return decoder_out + + # Since get_normalized_probs is in the Fairseq Model which is not scriptable, + # I rewrite the get_normalized_probs from Base Class to call the + # helper function in the Base Class. + @torch.jit.export + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + +@register_model_architecture("transformer", "transformer_tiny") +def tiny_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64) + args.encoder_layers = getattr(args, "encoder_layers", 2) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2) + args.decoder_layers = getattr(args, "decoder_layers", 2) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2) + return base_architecture(args) + + +@register_model_architecture("transformer", "transformer") +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + 
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.no_cross_attention = getattr(args, "no_cross_attention", False) + args.cross_self_attention = getattr(args, "cross_self_attention", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) 
+ args.offload_activations = getattr(args, "offload_activations", False) + if args.offload_activations: + args.checkpoint_activations = True + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + +@register_model_architecture("transformer", "transformer_iwslt_de_en") +def transformer_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + base_architecture(args) + + +@register_model_architecture("transformer", "transformer_wmt_en_de") +def transformer_wmt_en_de(args): + base_architecture(args) + + +# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) +@register_model_architecture("transformer", "transformer_vaswani_wmt_en_de_big") +def transformer_vaswani_wmt_en_de_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + 
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + base_architecture(args) + + +@register_model_architecture("transformer", "transformer_vaswani_wmt_en_fr_big") +def transformer_vaswani_wmt_en_fr_big(args): + args.dropout = getattr(args, "dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) + + +@register_model_architecture("transformer", "transformer_wmt_en_de_big") +def transformer_wmt_en_de_big(args): + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) + + +# default parameters used in tensor2tensor implementation +@register_model_architecture("transformer", "transformer_wmt_en_de_big_t2t") +def transformer_wmt_en_de_big_t2t(args): + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m From 7ebdc24909eaa1478e3423a942f5e41c35b9614c Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Fri, 16 Jul 2021 04:54:46 -0700 Subject: [PATCH 649/707] delegate namespace conversion to DC Summary: `populate_dataclass` is very basic in how it populates the dataclass. We might want more specific behaviour for some config dataclasses (like the hierarchical behaviour in TransformerConfig, see rest of stack). 
This diff moves the populate logic to a `from_namespace` method in `FairseqDataclass` so that a specific dataclass can reimplement it. Reviewed By: myleott Differential Revision: D29521388 fbshipit-source-id: f3a6dc80e4ddfc9563c6e85c37c563173f193f4d --- fairseq/dataclass/configs.py | 16 ++++++++++++++++ fairseq/dataclass/utils.py | 14 -------------- fairseq/models/__init__.py | 4 ++-- fairseq/registry.py | 4 ++-- fairseq/tasks/__init__.py | 4 ++-- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index b0146fa4c7..6a86ea0192 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -79,6 +79,22 @@ def _get_argparse_alias(self, attribute_name: str) -> Any: def _get_choices(self, attribute_name: str) -> Any: return self._get_meta(attribute_name, "choices") + @classmethod + def from_namespace(cls, args): + if isinstance(args, cls): + return args + else: + config = cls() + for k in config.__dataclass_fields__.keys(): + if k.startswith("_"): + # private member, skip + continue + if hasattr(args, k): + setattr(config, k, getattr(args, k)) + + return config + + @dataclass class CommonConfig(FairseqDataclass): diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 1ed28b7ccc..5c25a2b3d4 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -419,20 +419,6 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: return cfg -def populate_dataclass( - dataclass: FairseqDataclass, - args: Namespace, -) -> FairseqDataclass: - for k in dataclass.__dataclass_fields__.keys(): - if k.startswith("_"): - # private member, skip - continue - if hasattr(args, k): - setattr(dataclass, k, getattr(args, k)) - - return dataclass - - def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): # this will be deprecated when we get rid of argparse and model_overrides logic diff --git a/fairseq/models/__init__.py
b/fairseq/models/__init__.py index c5a4bbc831..337c77ac7b 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -10,7 +10,7 @@ from contextlib import ExitStack from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import merge_with_parent, populate_dataclass +from fairseq.dataclass.utils import merge_with_parent from hydra.core.config_store import ConfigStore from omegaconf import open_dict, OmegaConf @@ -84,7 +84,7 @@ def build_model(cfg: FairseqDataclass, task): dc = MODEL_DATACLASS_REGISTRY[model_type] if isinstance(cfg, argparse.Namespace): - cfg = populate_dataclass(dc(), cfg) + cfg = dc.from_namespace(cfg) else: cfg = merge_with_parent(dc(), cfg) else: diff --git a/fairseq/registry.py b/fairseq/registry.py index 3fbaeac301..f3b9406043 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -7,7 +7,7 @@ from typing import Union from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import populate_dataclass, merge_with_parent +from fairseq.dataclass.utils import merge_with_parent from hydra.core.config_store import ConfigStore from omegaconf import DictConfig @@ -45,7 +45,7 @@ def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs) else: choice = getattr(cfg, registry_name, None) if choice in DATACLASS_REGISTRY: - cfg = populate_dataclass(DATACLASS_REGISTRY[choice](), cfg) + cfg = DATACLASS_REGISTRY[choice].from_namespace(cfg) if choice is None: if required: diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 28305aa247..9a46b012c5 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -9,7 +9,7 @@ import os from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import merge_with_parent, populate_dataclass +from fairseq.dataclass.utils import merge_with_parent from hydra.core.config_store import ConfigStore from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa @@ -30,7 +30,7 @@ def 
setup_task(cfg: FairseqDataclass, **kwargs): task = TASK_REGISTRY[task_name] if task_name in TASK_DATACLASS_REGISTRY: dc = TASK_DATACLASS_REGISTRY[task_name] - cfg = populate_dataclass(dc(), cfg) + cfg = dc.from_namespace(cfg) else: task_name = getattr(cfg, "_name", None) From bc1504d4d709fd2157b6ba15f754c0307eb734f3 Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Fri, 16 Jul 2021 04:54:46 -0700 Subject: [PATCH 650/707] Hierarchical Configs Summary: This is a precursor to D29232595 The current behaviour to convert a dataclass to a namespace is that all the fields from all DCs in the field hierarchy are flattened at the top. This is also the legacy behaviour with `add_args`. This is kind of cumbersome to build reusable Dataclasses as we need to make sure that each field has a unique name. In the case of Transformer for instance, we have a Decoder and Encoder config that share a large part of their fields (embed_dim, layers, etc.). We can build a single dataclass for this that can be reused and extended in other implementations. To be then able to have a flat namespace, instead of adding all subfields as is to the root namespace, we introduce the name of the field as prefix to the arg in the namespace. So: `model.decoder.embed_dim` becomes `decoder_embed_dim` and `model.encoder.embed_dim` becomes `encoder_embed_dim`. 
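
A minimal sketch of the prefixing scheme described above, using only `argparse` and `dataclasses`. The `EncDec` and `Model` classes and the `add_flat_args` helper are hypothetical names chosen for illustration; this is not the `gen_parser_from_dataclass` code from this diff, only the same idea in a few lines.

```python
from argparse import ArgumentParser
from dataclasses import dataclass, field, fields, is_dataclass


@dataclass
class EncDec:
    # One reusable config shared by the encoder and the decoder.
    embed_dim: int = 512
    layers: int = 6


@dataclass
class Model:
    encoder: EncDec = field(default_factory=EncDec)
    decoder: EncDec = field(default_factory=EncDec)
    dropout: float = 0.1


def add_flat_args(parser, dc_type, prefix=""):
    """Register every leaf field, prefixing nested fields with their parent's name."""
    for f in fields(dc_type):
        name = f"{prefix}_{f.name}" if prefix else f.name
        if is_dataclass(f.type):
            # Nested config: recurse and carry the field name as the prefix.
            add_flat_args(parser, f.type, name)
        else:
            parser.add_argument(
                "--" + name.replace("_", "-"), type=f.type, default=f.default
            )


parser = ArgumentParser()
add_flat_args(parser, Model)
args = parser.parse_args(["--decoder-embed-dim", "1024"])
print(args.decoder_embed_dim, args.encoder_embed_dim, args.dropout)
```

Both sub-configs reuse the same field names, yet the flat namespace stays unambiguous because each argument carries the name of the field that holds it.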
Reviewed By: myleott, dianaml0 Differential Revision: D29521386 fbshipit-source-id: f4bef036f0eeb620c6d8709ce97f96ae288848ef --- fairseq/dataclass/utils.py | 31 +++++++++++-- tests/test_dataclass_utils.py | 87 +++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 tests/test_dataclass_utils.py diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 5c25a2b3d4..1320ec4737 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -54,17 +54,27 @@ def gen_parser_from_dataclass( parser: ArgumentParser, dataclass_instance: FairseqDataclass, delete_default: bool = False, + with_prefix: Optional[str] = None, ) -> None: - """convert a dataclass instance to tailing parser arguments""" + """ + convert a dataclass instance to tailing parser arguments. + + If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are + building a flat namespace from a structured dataclass (see transformer_config.py for example). 
+ """ def argparse_name(name: str): - if name == "data": - # normally data is positional args + if name == "data" and (with_prefix is None or with_prefix == ''): + # normally data is positional args, so we don't add the -- nor the prefix return name if name == "_name": # private member, skip return None - return "--" + name.replace("_", "-") + full_name = "--" + name.replace("_", "-") + if with_prefix is not None and with_prefix != '': + # if a prefix is specified, construct the prefixed arg name + full_name = with_prefix + "-" + full_name[2:] # strip -- when composing + return full_name def get_kwargs_from_dc( dataclass_instance: FairseqDataclass, k: str @@ -132,6 +142,10 @@ def get_kwargs_from_dc( if field_default is not MISSING: kwargs["default"] = field_default + # build the help with the hierarchical prefix + if with_prefix is not None and with_prefix != '' and field_help is not None: + field_help = with_prefix[2:] + ': ' + field_help + kwargs["help"] = field_help if field_const is not None: kwargs["const"] = field_const @@ -145,7 +159,14 @@ def get_kwargs_from_dc( if field_name is None: continue elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass): - gen_parser_from_dataclass(parser, field_type(), delete_default) + # for fields that are of type FairseqDataclass, we can recursively + # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace) + prefix = None + if with_prefix is not None: + # if a prefix is specified, then we don't want to copy the subfields directly to the root namespace + # but we prefix them with the name of the current field. 
+ prefix = field_name + gen_parser_from_dataclass(parser, field_type(), delete_default, prefix) continue kwargs = get_kwargs_from_dc(dataclass_instance, k) diff --git a/tests/test_dataclass_utils.py b/tests/test_dataclass_utils.py new file mode 100644 index 0000000000..45fc391a97 --- /dev/null +++ b/tests/test_dataclass_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from argparse import ArgumentParser +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass + + +@dataclass +class A(FairseqDataclass): + data: str = field(default="test", metadata={"help": "the data input"}) + num_layers: int = field(default=200, metadata={"help": "more layers is better?"}) + + +@dataclass +class B(FairseqDataclass): + bar: A = field(default=A()) + foo: int = field(default=0, metadata={"help": "not a bar"}) + + +@dataclass +class D(FairseqDataclass): + arch: A = field(default=A()) + foo: int = field(default=0, metadata={"help": "not a bar"}) + + +@dataclass +class C(FairseqDataclass): + data: str = field(default="test", metadata={"help": "root level data input"}) + encoder: D = field(default=D()) + decoder: A = field(default=A()) + lr: int = field(default=0, metadata={"help": "learning rate"}) + + +class TestDataclassUtils(unittest.TestCase): + def test_argparse_convert_basic(self): + parser = ArgumentParser() + gen_parser_from_dataclass(parser, A(), True) + args = parser.parse_args(["--num-layers", '10', "the/data/path"]) + self.assertEqual(args.num_layers, 10) + self.assertEqual(args.data, "the/data/path") + + def test_argparse_recursive(self): + parser = ArgumentParser() + gen_parser_from_dataclass(parser, B(), True) + args = parser.parse_args(["--num-layers", "10", "--foo", "10", "the/data/path"]) + 
self.assertEqual(args.num_layers, 10) + self.assertEqual(args.foo, 10) + self.assertEqual(args.data, "the/data/path") + + def test_argparse_recursive_prefixing(self): + self.maxDiff = None + parser = ArgumentParser() + gen_parser_from_dataclass(parser, C(), True, "") + args = parser.parse_args( + [ + "--encoder-arch-data", + "ENCODER_ARCH_DATA", + "--encoder-arch-num-layers", + "10", + "--encoder-foo", + "10", + "--decoder-data", + "DECODER_DATA", + "--decoder-num-layers", + "10", + "--lr", + "10", + "the/data/path", + ] + ) + self.assertEqual(args.encoder_arch_data, "ENCODER_ARCH_DATA") + self.assertEqual(args.encoder_arch_num_layers, 10) + self.assertEqual(args.encoder_foo, 10) + self.assertEqual(args.decoder_data, "DECODER_DATA") + self.assertEqual(args.decoder_num_layers, 10) + self.assertEqual(args.lr, 10) + self.assertEqual(args.data, "the/data/path") + + +if __name__ == "__main__": + unittest.main() From 059187f5abbbe8a8118f01daaa4a1feb3a827249 Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Fri, 16 Jul 2021 04:54:46 -0700 Subject: [PATCH 651/707] ArgumentError when double declaration Summary: A few subclasses of TransformerModel try to add the same params that have already been added. A workaround is to catch the argparse exception and just ignore it as we know the field was already added if it's thrown. 
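
A small sketch of the workaround, assuming the parser uses argparse's default conflict handling (which raises `ArgumentError` on a duplicate option string). The `add_argument_once` helper is a hypothetical name, not part of fairseq. It wraps each call on its own, since a single `try` around several `add_argument` calls skips the later ones as soon as the first one conflicts.

```python
from argparse import ArgumentError, ArgumentParser


def add_argument_once(parser, *names, **kwargs):
    """Hypothetical helper: add one argument, ignoring a duplicate declaration."""
    try:
        parser.add_argument(*names, **kwargs)
        return True
    except ArgumentError:
        # The option was already declared elsewhere; keep the first definition.
        return False


parser = ArgumentParser()
first = add_argument_once(parser, "--max-source-positions", default=1024, type=int)
second = add_argument_once(parser, "--max-source-positions", default=2048, type=int)
third = add_argument_once(parser, "--max-target-positions", default=1024, type=int)
args = parser.parse_args([])
print(first, second, third, args.max_source_positions)
```

The duplicate declaration is dropped before it is registered, so the default of the first declaration is the one that applies.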
Reviewed By: myleott, dianaml0 Differential Revision: D29521389 fbshipit-source-id: 7912a260c4f55fbfe486794c9726d6289370400e --- examples/laser/laser_src/laser_task.py | 33 +++++++++++-------- .../multilingual/multilingual_data_manager.py | 33 +++++++++++-------- fairseq/tasks/multilingual_translation.py | 13 +++++--- fairseq/tasks/online_backtranslation.py | 13 +++++--- 4 files changed, 56 insertions(+), 36 deletions(-) diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py index c8ac805f54..e4152fde68 100644 --- a/examples/laser/laser_src/laser_task.py +++ b/examples/laser/laser_src/laser_task.py @@ -8,6 +8,7 @@ import json import os import logging +from argparse import ArgumentError from fairseq import options, models from fairseq.data import ( @@ -59,20 +60,24 @@ def add_args(parser): metavar="BOOL", help="pad the target on the left (default: False)", ) - parser.add_argument( - "--max-source-positions", - default=1024, - type=int, - metavar="N", - help="max number of tokens in the source sequence", - ) - parser.add_argument( - "--max-target-positions", - default=1024, - type=int, - metavar="N", - help="max number of tokens in the target sequence", - ) + try: + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. 
+ pass def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks): super().__init__(args) diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index a2fae5bf52..137481b449 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -9,6 +9,7 @@ import math import os from collections import OrderedDict, defaultdict +from argparse import ArgumentError from fairseq import utils from fairseq.data import ( @@ -141,20 +142,24 @@ def add_args(parser): metavar="BOOL", help="pad the target on the left", ) - parser.add_argument( - "--max-source-positions", - default=1024, - type=int, - metavar="N", - help="max number of tokens in the source sequence", - ) - parser.add_argument( - "--max-target-positions", - default=1024, - type=int, - metavar="N", - help="max number of tokens in the target sequence", - ) + try: + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. 
+ pass parser.add_argument( "--upsample-primary", default=1, diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 26e0b529d5..4f85ab4832 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -7,6 +7,7 @@ import logging import os from collections import OrderedDict +from argparse import ArgumentError import torch from fairseq import metrics, options, utils @@ -77,10 +78,14 @@ def add_args(parser): help='pad the source on the left (default: True)') parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', help='pad the target on the left (default: False)') - parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the source sequence') - parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the target sequence') + try: + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. 
+ pass parser.add_argument('--upsample-primary', default=1, type=int, help='amount to upsample primary dataset') parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'], diff --git a/fairseq/tasks/online_backtranslation.py b/fairseq/tasks/online_backtranslation.py index 2545624cd4..2e27ca237c 100644 --- a/fairseq/tasks/online_backtranslation.py +++ b/fairseq/tasks/online_backtranslation.py @@ -12,6 +12,7 @@ from collections import OrderedDict, defaultdict from pathlib import Path from typing import Dict, Sequence, Tuple +from argparse import ArgumentError import numpy as np import torch @@ -110,10 +111,14 @@ def add_args(parser): help='pad the target on the left') parser.add_argument('--upsample-primary', default=1, type=int, help='amount to upsample primary dataset') - parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the source sequence') - parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the target sequence') + try: + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. + pass parser.add_argument('--truncate-source', action='store_true', default=False, help='truncate source to max-source-positions') parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N', From 129d8594ccdc6644be84dc249e16489e049f4bfd Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Fri, 16 Jul 2021 04:54:46 -0700 Subject: [PATCH 652/707] Transformer Hydration pt1. (#1984) Summary: ## What does this PR do? 
In https://github.com/fairinternal/fairseq-py/tree/hydra-transformer I tried to convert TransformerModel to hydra directly, but then had to deal with upgrading a lot of downstream classes and this got out of hand.

I am trying a different approach here. This is my strategy in this PR:

0- Make the argparse backward converter support hierarchical configs. This way, I can clean up the config and split it into "sub-configs" for Encoder/Decoder. This simplifies the config object and allows for code reusability. In the future, it should simplify creating more specific configs for Enc/Dec.
1- Have a base class that is hydrated (but not registered as a model). This also means hydrating the Encoder/Decoder/Layer classes.
2- First hide this behind a legacy model that is still called TransformerModel (to not break imports, etc.). This legacy model transforms the argparse namespace into the dataclass config when it needs to.
3- Make the dataclass look like the argparse namespace so that it can be used without subclassing/hydrating all the downstream classes under TransformerModel (see the other branch to see what this involves).

test_binaries seems to run fine except for one state loading issue. I am still digging into that but wanted to make sure this was a good approach.

(I also decided to split the `transformer.py` file as it was getting a bit too large and I like making merges with master miserable)

## Next Steps
- Separate PR to create a registered Hydra model, ideally renaming TransformerModel to TransformerModelLegacy and codemodding everything to inherit from TransformerModelLegacy.
- Add equivalent test_binaries tests that run with hydra main (already done in the other branch).
- Start converting some downstream models to hydra where needed.
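
Steps 2 and 3 of the strategy above can be sketched roughly as follows. This is an illustration under assumptions, not the `TransformerConfig` code from this PR: `SubConfig` and `FlatViewConfig` are hypothetical names, and the regex is a simplified version of the `_NAME_PARSER` pattern in `transformer_config.py`.

```python
import re
from argparse import Namespace
from dataclasses import dataclass, field

# Simplified pattern: split "decoder_embed_dim" into ("decoder", "embed_dim").
_NAME_PARSER = r"(decoder|encoder)_(.*)"


@dataclass
class SubConfig:
    embed_dim: int = 512
    layers: int = 6


@dataclass
class FlatViewConfig:
    encoder: SubConfig = field(default_factory=SubConfig)
    decoder: SubConfig = field(default_factory=SubConfig)
    dropout: float = 0.1

    def __getattr__(self, name):
        # Only called when normal lookup fails, e.g. for "decoder_embed_dim".
        match = re.match(_NAME_PARSER, name)
        if match:
            return getattr(getattr(self, match[1]), match[2])
        raise AttributeError(name)

    @classmethod
    def from_namespace(cls, args):
        """Build the nested config from a flat argparse namespace."""
        cfg = cls()
        for key, value in vars(args).items():
            match = re.match(_NAME_PARSER, key)
            if match and hasattr(getattr(cfg, match[1]), match[2]):
                setattr(getattr(cfg, match[1]), match[2], value)
            elif key in cfg.__dataclass_fields__:
                setattr(cfg, key, value)
        return cfg


cfg = FlatViewConfig.from_namespace(Namespace(decoder_embed_dim=1024, dropout=0.3))
print(cfg.decoder.embed_dim, cfg.decoder_embed_dim, cfg.encoder_embed_dim, cfg.dropout)
```

Downstream code that still reads `cfg.decoder_embed_dim` keeps working, while new code can use `cfg.decoder.embed_dim`.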
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1984 Reviewed By: dianaml0 Differential Revision: D29232595 Pulled By: Mortimerp9 fbshipit-source-id: f12eadfa9ccff29f28c67b527f9996311f791359 --- fairseq/models/transformer/__init__.py | 14 +- .../models/transformer/transformer_base.py | 169 ++++++++++ .../models/transformer/transformer_config.py | 318 ++++++++++++++++++ .../models/transformer/transformer_decoder.py | 126 ++++--- .../models/transformer/transformer_encoder.py | 83 +++-- ...sformer_model.py => transformer_legacy.py} | 245 ++------------ fairseq/modules/multihead_attention.py | 2 +- fairseq/modules/transformer_layer.py | 153 +++++---- 8 files changed, 751 insertions(+), 359 deletions(-) create mode 100644 fairseq/models/transformer/transformer_base.py create mode 100644 fairseq/models/transformer/transformer_config.py rename fairseq/models/transformer/{transformer_model.py => transformer_legacy.py} (50%) diff --git a/fairseq/models/transformer/__init__.py b/fairseq/models/transformer/__init__.py index 6809adeab7..681fca3d45 100644 --- a/fairseq/models/transformer/__init__.py +++ b/fairseq/models/transformer/__init__.py @@ -4,12 +4,15 @@ # LICENSE file in the root directory of this source tree. 
"""isort:skip_file""" -from .transformer_decoder import TransformerDecoder, Linear -from .transformer_encoder import TransformerEncoder -from .transformer_model import ( +from .transformer_config import ( + TransformerConfig, DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS, DEFAULT_MIN_PARAMS_TO_WRAP, +) +from .transformer_decoder import TransformerDecoder, TransformerDecoderBase, Linear +from .transformer_encoder import TransformerEncoder, TransformerEncoderBase +from .transformer_legacy import ( TransformerModel, base_architecture, tiny_architecture, @@ -19,14 +22,17 @@ transformer_vaswani_wmt_en_fr_big, transformer_wmt_en_de_big, transformer_wmt_en_de_big_t2t, - Embedding, ) +from .transformer_base import TransformerModelBase, Embedding __all__ = [ + "TransformerModelBase", "TransformerConfig", "TransformerDecoder", + "TransformerDecoderBase", "TransformerEncoder", + "TransformerEncoderBase", "TransformerModel", "Embedding", "Linear", diff --git a/fairseq/models/transformer/transformer_base.py b/fairseq/models/transformer/transformer_base.py new file mode 100644 index 0000000000..e3ceb3c317 --- /dev/null +++ b/fairseq/models/transformer/transformer_base.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.distributed import fsdp_wrap +from fairseq.models import FairseqEncoderDecoderModel +from torch import Tensor +from fairseq.models.transformer import (TransformerEncoderBase, TransformerDecoderBase, TransformerConfig) + + +class TransformerModelBase(FairseqEncoderDecoderModel): + """ + Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) + <https://arxiv.org/abs/1706.03762>`_. 
+ + Args: + encoder (TransformerEncoder): the encoder + decoder (TransformerDecoder): the decoder + + The Transformer model provides the following named architectures and + command-line arguments: + + .. argparse:: + :ref: fairseq.models.transformer_parser + :prog: + """ + + def __init__(self, cfg, encoder, decoder): + super().__init__(encoder, decoder) + self.cfg = cfg + self.supports_align_args = True + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # we want to build the args recursively in this case. + gen_parser_from_dataclass( + parser, TransformerConfig(), delete_default=False, with_prefix="" + ) + + @classmethod + def build_model(cls, cfg, task): + """Build a new model instance.""" + + if cfg.encoder.layers_to_keep: + cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(',')) + if cfg.decoder.layers_to_keep: + cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(',')) + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + if cfg.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if cfg.encoder.embed_dim != cfg.decoder.embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if cfg.decoder.embed_path and ( + cfg.decoder.embed_path != cfg.encoder.embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = cls.build_embedding( + cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path + ) + decoder_embed_tokens = encoder_embed_tokens + cfg.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = cls.build_embedding( + cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path + ) + decoder_embed_tokens = cls.build_embedding( + cfg, tgt_dict, cfg.decoder.embed_dim, cfg.decoder.embed_path + ) + if cfg.offload_activations: + cfg.checkpoint_activations = 
True # offloading implies checkpointing + encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) + if not cfg.share_all_embeddings: + # fsdp_wrap is a no-op when --ddp-backend != fully_sharded + encoder = fsdp_wrap(encoder, min_num_params=cfg.min_params_to_wrap) + decoder = fsdp_wrap(decoder, min_num_params=cfg.min_params_to_wrap) + return cls(cfg, encoder, decoder) + + @classmethod + def build_embedding(cls, cfg, dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + @classmethod + def build_encoder(cls, cfg, src_dict, embed_tokens): + return TransformerEncoderBase(cfg, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, cfg, tgt_dict, embed_tokens): + return TransformerDecoderBase( + cfg, + tgt_dict, + embed_tokens, + no_encoder_attn=cfg.no_cross_attention, + ) + + # TorchScript doesn't support optional arguments with variable length (**kwargs). + # Current workaround is to add union of all arguments in child classes. + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + return_all_hiddens: bool = True, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Run the forward pass for an encoder-decoder model. + + Copied from the base class, but without ``**kwargs``, + which are not supported by TorchScript. 
+ """ + encoder_out = self.encoder( + src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens + ) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + ) + return decoder_out + + # Since get_normalized_probs is in the Fairseq Model which is not scriptable, + # I rewrite the get_normalized_probs from Base Class to call the + # helper function in the Base Class. + @torch.jit.export + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m diff --git a/fairseq/models/transformer/transformer_config.py b/fairseq/models/transformer/transformer_config.py new file mode 100644 index 0000000000..2580d20aac --- /dev/null +++ b/fairseq/models/transformer/transformer_config.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import re +from dataclasses import dataclass, field, fields +from typing import List, Optional + +from fairseq import utils +from fairseq.dataclass import FairseqDataclass, ChoiceEnum +from omegaconf import II + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + +_NAME_PARSER = r"(decoder|encoder|quant_noise)_(.*)" + + +@dataclass +class EncDecBaseConfig(FairseqDataclass): + embed_path: Optional[str] = field( + default=None, metadata={"help": "path to pre-trained embedding"} + ) + embed_dim: Optional[int] = field( + default=512, metadata={"help": "embedding dimension"} + ) + ffn_embed_dim: int = field( + default=2048, metadata={"help": "embedding dimension for FFN"} + ) + layers: int = field(default=6, metadata={"help": "number of layers"}) + attention_heads: int = field( + default=8, metadata={"help": "number of attention heads"} + ) + normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each block"} + ) + learned_pos: bool = field( + default=False, metadata={"help": "use learned positional embeddings"} + ) + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + layerdrop: float = field(default=0, metadata={"help": "LayerDrop probability"}) + layers_to_keep: Optional[List[int]] = field( + default=None, metadata={"help": "which layers to *keep* when pruning"} + ) + + +@dataclass +class DecoderConfig(EncDecBaseConfig): + input_dim: int = II("model.decoder.embed_dim") + output_dim: int = field( + default=II("model.decoder.embed_dim"), + metadata={ + "help": "decoder output dimension (extra linear layer if different from decoder embed dim)" + }, + ) + + def __post_init__(self): + # II doesn't work if we are just creating the object outside of hydra so fix that + if self.input_dim == II("model.decoder.embed_dim"): + self.input_dim = self.embed_dim + if self.output_dim == II("model.decoder.embed_dim"): + self.output_dim = 
self.embed_dim + + +@dataclass +class QuantNoiseConfig(FairseqDataclass): + pq: float = field( + default=0.0, + metadata={"help": "iterative PQ quantization noise at training time"}, + ) + pq_block_size: int = field( + default=8, + metadata={"help": "block size of quantization noise at training time"}, + ) + scalar: float = field( + default=0.0, + metadata={ + "help": "scalar quantization noise and scalar quantization at training time" + }, + ) + + +@dataclass +class TransformerConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", + metadata={"help": "activation function to use"}, + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + attention_dropout: float = field( + default=0.0, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN.", + "alias": "--relu-dropout", + }, + ) + adaptive_input: bool = False + encoder: EncDecBaseConfig = EncDecBaseConfig() + # TODO should really be in the encoder config + max_source_positions: int = field( + default=DEFAULT_MAX_SOURCE_POSITIONS, + metadata={"help": "Maximum input length supported by the encoder"}, + ) + decoder: DecoderConfig = DecoderConfig() + # TODO should really be in the decoder config + max_target_positions: int = field( + default=DEFAULT_MAX_TARGET_POSITIONS, + metadata={"help": "Maximum output length supported by the decoder"}, + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + share_all_embeddings: bool = field( + default=False, + metadata={ + "help": "share encoder, decoder and output embeddings (requires shared dictionary and embed dim)" + }, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if True, disables positional embeddings (outside self 
attention)" + }, + ) + adaptive_softmax_cutoff: Optional[List[int]] = field( + default=None, + metadata={ + "help": "list of adaptive softmax cutoff points. Must be used with adaptive_loss criterion" + }, + ) + adaptive_softmax_dropout: float = field( + default=0.0, + metadata={"help": "sets adaptive softmax dropout for the tail projections"}, + ) + adaptive_softmax_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + layernorm_embedding: bool = field( + default=False, metadata={"help": "add layernorm to embedding"} + ) + tie_adaptive_weights: bool = field( + default=False, + metadata={ + "help": "if set, ties the weights of adaptive softmax and adaptive input" + }, + ) + tie_adaptive_proj: bool = field( + default=False, + metadata={ + "help": "if set, ties the projection weights of adaptive softmax and adaptive input" + }, + ) + no_scale_embedding: bool = field( + default=False, metadata={"help": "if True, don't scale embeddings"} + ) + checkpoint_activations: bool = field( + default=False, + metadata={ + "help": "checkpoint activations at each layer, which saves GPU memory usage at the cost of some additional compute" + }, + ) + offload_activations: bool = field( + default=False, + metadata={ + "help": "checkpoint activations at each layer, then offload them to CPU. Sets --checkpoint-activations."
+ }, + ) + # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) + no_cross_attention: bool = field( + default=False, metadata={"help": "do not perform cross-attention"} + ) + cross_self_attention: bool = field( + default=False, metadata={"help": "perform cross+self-attention"} + ) + # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) + quant_noise: QuantNoiseConfig = field(default=QuantNoiseConfig()) + min_params_to_wrap: int = field( + default=DEFAULT_MIN_PARAMS_TO_WRAP, + metadata={ + "help": "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + }, + ) + # DEPRECATED field, but some old checkpoints might have it + char_inputs: bool = field( + default=False, metadata={"help": "if set, model takes character ids as input"} + ) + relu_dropout: float = 0.0 + # config for "BASE Layers: Simplifying Training of Large, Sparse Models" + base_layers: Optional[int] = field( + default=0, metadata={"help": "number of BASE layers in total"} + ) + base_sublayers: Optional[int] = field( + default=1, metadata={"help": "number of sublayers in each BASE layer"} + ) + base_shuffle: Optional[int] = field( + default=1, + metadata={"help": "shuffle tokens between workers before computing assignment"}, + ) + + export: bool = field( + default=False, + metadata={"help": "make the layernorm exportable with torchscript."}, + ) + + # copied from transformer_lm but expected in transformer_decoder: + no_decoder_final_norm: bool = field( + default=False, + metadata={"help": "don't add an extra layernorm after the last decoder block"}, + ) + + # We need to make this hierarchical dataclass like 
the flat namespace + # __getattr__ and __setattr__ here allow backward compatibility + # for subclasses of Transformer(Legacy) that depend on read/write on + # the flat namespace. + + def __getattr__(self, name): + match = re.match(_NAME_PARSER, name) + if match: + sub = getattr(self, match[1]) + return getattr(sub, match[2]) + raise AttributeError(f"invalid argument {name}.") + + def __setattr__(self, name, value): + match = re.match(_NAME_PARSER, name) + if match: + sub = getattr(self, match[1]) + setattr(sub, match[2], value) + else: + super().__setattr__(name, value) + + @staticmethod + def _copy_keys(args, cls, prefix, seen): + """ + copy the prefixed keys (decoder_embed_dim) to the DC fields: decoder.embed_dim + """ + cfg = cls() + for fld in fields(cls): + # for all the fields in the DC, find the fields (e.g. embed_dim) + # in the namespace with the prefix (e.g. decoder) + # and set it on the dc. + args_key = f"{prefix}_{fld.name}" + if hasattr(args, args_key): + seen.add(args_key) + setattr(cfg, fld.name, getattr(args, args_key)) + if hasattr(args, fld.name): + seen.add(fld.name) + setattr(cfg, fld.name, getattr(args, fld.name)) + return cfg + + @classmethod + def from_namespace(cls, args): + if args is None: + return None + if not isinstance(args, cls): + seen = set() + config = cls() + # currently, we can go generically from DC fields to args hierarchically + # but we can't easily deconstruct a flat namespace to a hierarchical + # DC. Mostly because we could have a sub-dc called `decoder-foo` that should not + # go to the sub struct called `decoder`. There are ways to go around this, but let's keep it simple + # for now. 
+ for fld in fields(cls): + # concretely, the transformer_config knows what sub-dc it has, so we go through all the dc fields + # and if it's one that has a sub-dc, we build that sub-dc with `_copy_keys()` + if fld.name == "decoder": + if hasattr(args, "decoder"): + # in some cases, the args we receive are already structured (as DictConfigs), so let's just build the correct DC + seen.add("decoder") + config.decoder = DecoderConfig(**args.decoder) + else: + config.decoder = cls._copy_keys( + args, DecoderConfig, "decoder", seen + ) + elif fld.name == "encoder": + # same but for encoder + if hasattr(args, "encoder"): + seen.add("encoder") + config.encoder = EncDecBaseConfig(**args.encoder) + else: + config.encoder = cls._copy_keys( + args, EncDecBaseConfig, "encoder", seen + ) + elif fld.name == "quant_noise": + # same but for quant_noise + if hasattr(args, "quant_noise"): + seen.add("quant_noise") + config.quant_noise = QuantNoiseConfig(**args.quant_noise) + else: + config.quant_noise = cls._copy_keys( + args, QuantNoiseConfig, "quant_noise", seen + ) + elif hasattr(args, fld.name): + # if it's not a structured field, it's just a normal field, copy it over + seen.add(fld.name) + setattr(config, fld.name, getattr(args, fld.name)) + # we got all the fields defined in the dataclass, but + # the argparse namespace might have extra args for two reasons: + # - we are in a legacy class so not all the args are declared in the dataclass.
Ideally once everyone has defined a dataclass for their model, we won't need this + # - some places expect args to be there but never define them + args_dict = args._asdict() if hasattr(args, '_asdict') else vars(args) if hasattr(args, '__dict__') else {} # namedtuple doesn't have __dict__ :-/ + for key, value in args_dict.items(): + if key not in seen: + setattr(config, key, value) + return config + else: + return args diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py index ca9e737e60..49e37917cc 100644 --- a/fairseq/models/transformer/transformer_decoder.py +++ b/fairseq/models/transformer/transformer_decoder.py @@ -11,7 +11,7 @@ from fairseq import utils from fairseq.distributed import fsdp_wrap from fairseq.models import FairseqIncrementalDecoder -from fairseq.models.transformer import transformer_model +from fairseq.models.transformer import TransformerConfig from fairseq.modules import ( AdaptiveSoftmax, BaseLayer, @@ -27,9 +27,17 @@ from torch import Tensor -class TransformerDecoder(FairseqIncrementalDecoder): +# rewrite name for backward compatibility in `make_generation_fast_` +def module_name_fordropout(module_name: str) -> str: + if module_name == 'TransformerDecoderBase': + return 'TransformerDecoder' + else: + return module_name + + +class TransformerDecoderBase(FairseqIncrementalDecoder): """ - Transformer decoder consisting of *args.decoder_layers* layers. Each layer + Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer is a :class:`TransformerDecoderLayer`.
Args: @@ -42,40 +50,40 @@ class TransformerDecoder(FairseqIncrementalDecoder): def __init__( self, - args, + cfg, dictionary, embed_tokens, no_encoder_attn=False, output_projection=None, ): - self.args = args + self.cfg = cfg super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) self.dropout_module = FairseqDropout( - args.dropout, module_name=self.__class__.__name__ + cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) ) - self.decoder_layerdrop = args.decoder_layerdrop - self.share_input_output_embed = args.share_decoder_input_output_embed + self.decoder_layerdrop = cfg.decoder.layerdrop + self.share_input_output_embed = cfg.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim - embed_dim = args.decoder_embed_dim + embed_dim = cfg.decoder.embed_dim self.embed_dim = embed_dim - self.output_embed_dim = args.decoder_output_dim + self.output_embed_dim = cfg.decoder.output_dim self.padding_idx = embed_tokens.padding_idx - self.max_target_positions = args.max_target_positions + self.max_target_positions = cfg.max_target_positions self.embed_tokens = embed_tokens - self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) - if not args.adaptive_input and args.quant_noise_pq > 0: + if not cfg.adaptive_input and cfg.quant_noise.pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), - args.quant_noise_pq, - args.quant_noise_pq_block_size, + cfg.quant_noise.pq, + cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None @@ -90,18 +98,17 @@ def __init__( self.max_target_positions, embed_dim, self.padding_idx, - learned=args.decoder_learned_pos, + learned=cfg.decoder.learned_pos, ) - if not args.no_token_positional_embeddings + if not cfg.no_token_positional_embeddings else None ) - export = getattr(args, "export", False) - if 
getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim, export=export) + if cfg.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None - self.cross_self_attention = getattr(args, "cross_self_attention", False) + self.cross_self_attention = cfg.cross_self_attention if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) @@ -109,40 +116,38 @@ def __init__( self.layers = nn.ModuleList([]) self.layers.extend( [ - self.build_decoder_layer(args, no_encoder_attn) - for _ in range(args.decoder_layers) + self.build_decoder_layer(cfg, no_encoder_attn) + for _ in range(cfg.decoder.layers) ] ) self.num_layers = len(self.layers) - if args.decoder_normalize_before and not getattr( - args, "no_decoder_final_norm", False - ): - self.layer_norm = LayerNorm(embed_dim, export=export) + if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm: + self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None self.project_out_dim = ( Linear(embed_dim, self.output_embed_dim, bias=False) - if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights else None ) self.adaptive_softmax = None self.output_projection = output_projection if self.output_projection is None: - self.build_output_projection(args, dictionary, embed_tokens) + self.build_output_projection(cfg, dictionary, embed_tokens) - def build_output_projection(self, args, dictionary, embed_tokens): - if args.adaptive_softmax_cutoff is not None: + def build_output_projection(self, cfg, dictionary, embed_tokens): + if cfg.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, - utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), - dropout=args.adaptive_softmax_dropout, - adaptive_inputs=embed_tokens 
if args.tie_adaptive_weights else None, - factor=args.adaptive_softmax_factor, - tie_proj=args.tie_adaptive_proj, + utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int), + dropout=cfg.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None, + factor=cfg.adaptive_softmax_factor, + tie_proj=cfg.tie_adaptive_proj, ) elif self.share_input_output_embed: self.output_projection = nn.Linear( @@ -158,26 +163,22 @@ def build_output_projection(self, args, dictionary, embed_tokens): nn.init.normal_( self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 ) - num_base_layers = getattr(args, "base_layers", 0) + num_base_layers = cfg.base_layers for i in range(num_base_layers): self.layers.insert( - ((i + 1) * args.decoder_layers) // (num_base_layers + 1), - BaseLayer(args), + ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1), + BaseLayer(cfg), ) - def build_decoder_layer(self, args, no_encoder_attn=False): - layer = transformer_layer.TransformerDecoderLayer(args, no_encoder_attn) - checkpoint = getattr(args, "checkpoint_activations", False) + def build_decoder_layer(self, cfg, no_encoder_attn=False): + layer = transformer_layer.TransformerDecoderLayerBase(cfg, no_encoder_attn) + checkpoint = cfg.checkpoint_activations if checkpoint: - offload_to_cpu = getattr(args, "offload_activations", False) + offload_to_cpu = cfg.offload_activations layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size - min_params_to_wrap = ( - getattr(args, "min_params_to_wrap", transformer_model.DEFAULT_MIN_PARAMS_TO_WRAP) - if not checkpoint - else 0 - ) + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer @@ -450,3 +451,32 @@ def Linear(in_features, out_features, bias=True): if bias: nn.init.constant_(m.bias, 0.0) return m + + 
+class TransformerDecoder(TransformerDecoderBase): + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + self.args = args + super().__init__( + TransformerConfig.from_namespace(args), + dictionary, + embed_tokens, + no_encoder_attn=no_encoder_attn, + output_projection=output_projection, + ) + + def build_output_projection(self, args, dictionary, embed_tokens): + super().build_output_projection( + TransformerConfig.from_namespace(args), dictionary, embed_tokens + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + return super().build_decoder_layer( + TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn + ) diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py index 6e57af5433..f007776a6f 100644 --- a/fairseq/models/transformer/transformer_encoder.py +++ b/fairseq/models/transformer/transformer_encoder.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- import math from typing import Dict, List, Optional @@ -19,16 +18,26 @@ PositionalEmbedding, SinusoidalPositionalEmbedding, ) -from fairseq.modules.transformer_layer import TransformerEncoderLayer +from fairseq.modules import transformer_layer from fairseq.modules.checkpoint_activations import checkpoint_wrapper from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from torch import Tensor -from fairseq.models.transformer import transformer_model +from fairseq.models.transformer import ( + TransformerConfig, +) + + +# rewrite name for backward compatibility in `make_generation_fast_` +def module_name_fordropout(module_name: str) -> str: + if module_name == 'TransformerEncoderBase': + return 'TransformerEncoder' + else: + return module_name -class TransformerEncoder(FairseqEncoder): +class TransformerEncoderBase(FairseqEncoder): """ - Transformer encoder consisting of *args.encoder_layers* layers. Each layer + Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer is a :class:`TransformerEncoderLayer`. 
Args: @@ -37,45 +46,44 @@ class TransformerEncoder(FairseqEncoder): embed_tokens (torch.nn.Embedding): input embedding """ - def __init__(self, args, dictionary, embed_tokens): - self.args = args + def __init__(self, cfg, dictionary, embed_tokens): + self.cfg = cfg super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( - args.dropout, module_name=self.__class__.__name__ + cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) ) - self.encoder_layerdrop = args.encoder_layerdrop + self.encoder_layerdrop = cfg.encoder.layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx - self.max_source_positions = args.max_source_positions + self.max_source_positions = cfg.max_source_positions self.embed_tokens = embed_tokens - self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) self.embed_positions = ( PositionalEmbedding( - args.max_source_positions, + cfg.max_source_positions, embed_dim, self.padding_idx, - learned=args.encoder_learned_pos, + learned=cfg.encoder.learned_pos, ) - if not args.no_token_positional_embeddings + if not cfg.no_token_positional_embeddings else None ) - export = getattr(args, "export", False) - if getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim, export=export) + if cfg.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None - if not args.adaptive_input and args.quant_noise_pq > 0: + if not cfg.adaptive_input and cfg.quant_noise.pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), - args.quant_noise_pq, - args.quant_noise_pq_block_size, + cfg.quant_noise.pq, + cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None @@ -85,28 +93,24 @@ def __init__(self, args, dictionary, 
embed_tokens): else: self.layers = nn.ModuleList([]) self.layers.extend( - [self.build_encoder_layer(args) for i in range(args.encoder_layers)] + [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)] ) self.num_layers = len(self.layers) - if args.encoder_normalize_before: - self.layer_norm = LayerNorm(embed_dim, export=export) + if cfg.encoder.normalize_before: + self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None - def build_encoder_layer(self, args): - layer = TransformerEncoderLayer(args) - checkpoint = getattr(args, "checkpoint_activations", False) + def build_encoder_layer(self, cfg): + layer = transformer_layer.TransformerEncoderLayerBase(cfg) + checkpoint = cfg.checkpoint_activations if checkpoint: - offload_to_cpu = getattr(args, "offload_activations", False) + offload_to_cpu = cfg.offload_activations layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size - min_params_to_wrap = ( - getattr(args, "min_params_to_wrap", transformer_model.DEFAULT_MIN_PARAMS_TO_WRAP) - if not checkpoint - else 0 - ) + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer @@ -320,3 +324,18 @@ def upgrade_state_dict_named(self, state_dict, name): self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict + + +class TransformerEncoder(TransformerEncoderBase): + def __init__(self, args, dictionary, embed_tokens): + self.args = args + super().__init__( + TransformerConfig.from_namespace(args), + dictionary, + embed_tokens, + ) + + def build_encoder_layer(self, args): + return super().build_encoder_layer( + TransformerConfig.from_namespace(args), + ) diff --git a/fairseq/models/transformer/transformer_model.py b/fairseq/models/transformer/transformer_legacy.py similarity index 50% rename from 
fairseq/models/transformer/transformer_model.py rename to fairseq/models/transformer/transformer_legacy.py index 7cc5b64ad1..9534e400b5 100644 --- a/fairseq/models/transformer/transformer_model.py +++ b/fairseq/models/transformer/transformer_legacy.py @@ -3,44 +3,26 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Tuple - -import torch.nn as nn -import torch -from fairseq import utils -from fairseq.distributed import fsdp_wrap from fairseq.models import ( - FairseqEncoderDecoderModel, register_model, register_model_architecture, ) -from torch import Tensor -from fairseq.models import transformer - - -DEFAULT_MAX_SOURCE_POSITIONS = 1024 -DEFAULT_MAX_TARGET_POSITIONS = 1024 - - -DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) +from fairseq.models.transformer.transformer_config import ( + TransformerConfig, + DEFAULT_MAX_SOURCE_POSITIONS, + DEFAULT_MAX_TARGET_POSITIONS, + DEFAULT_MIN_PARAMS_TO_WRAP, +) +from fairseq.models.transformer.transformer_base import ( + TransformerModelBase, +) @register_model("transformer") -class TransformerModel(FairseqEncoderDecoderModel): +class TransformerModel(TransformerModelBase): """ - Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) - <https://arxiv.org/abs/1706.03762>`_. - - Args: - encoder (TransformerEncoder): the encoder - decoder (TransformerDecoder): the decoder - - The Transformer model provides the following named architectures and - command-line arguments: - - .. argparse:: - :ref: fairseq.models.transformer_parser - :prog: + This is the legacy implementation of the transformer model that + uses argparse for configuration. 
""" @classmethod @@ -92,109 +74,9 @@ def spm(path): # fmt: on def __init__(self, args, encoder, decoder): - super().__init__(encoder, decoder) + cfg = TransformerConfig.from_namespace(args) + super().__init__(cfg, encoder, decoder) self.args = args - self.supports_align_args = True - - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--activation-fn', - choices=utils.get_available_activation_fns(), - help='activation function to use') - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', type=float, metavar='D', - help='dropout probability for attention weights') - parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', - help='dropout probability after activation in FFN.') - parser.add_argument('--encoder-embed-path', type=str, metavar='STR', - help='path to pre-trained encoder embedding') - parser.add_argument('--encoder-embed-dim', type=int, metavar='N', - help='encoder embedding dimension') - parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', - help='encoder embedding dimension for FFN') - parser.add_argument('--encoder-layers', type=int, metavar='N', - help='num encoder layers') - parser.add_argument('--encoder-attention-heads', type=int, metavar='N', - help='num encoder attention heads') - parser.add_argument('--encoder-normalize-before', action='store_true', - help='apply layernorm before each encoder block') - parser.add_argument('--encoder-learned-pos', action='store_true', - help='use learned positional embeddings in the encoder') - parser.add_argument('--decoder-embed-path', type=str, metavar='STR', - help='path to pre-trained decoder embedding') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', - help='decoder embedding 
dimension for FFN') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='num decoder layers') - parser.add_argument('--decoder-attention-heads', type=int, metavar='N', - help='num decoder attention heads') - parser.add_argument('--decoder-learned-pos', action='store_true', - help='use learned positional embeddings in the decoder') - parser.add_argument('--decoder-normalize-before', action='store_true', - help='apply layernorm before each decoder block') - parser.add_argument('--decoder-output-dim', type=int, metavar='N', - help='decoder output dimension (extra linear layer ' - 'if different from decoder embed dim') - parser.add_argument('--share-decoder-input-output-embed', action='store_true', - help='share decoder input and output embeddings') - parser.add_argument('--share-all-embeddings', action='store_true', - help='share encoder, decoder and output embeddings' - ' (requires shared dictionary and embed dim)') - parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', - help='if set, disables positional embeddings (outside self attention)') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. ' - 'Must be used with adaptive_loss criterion'), - parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', - help='sets adaptive softmax dropout for the tail projections') - parser.add_argument('--layernorm-embedding', action='store_true', - help='add layernorm to embedding') - parser.add_argument('--no-scale-embedding', action='store_true', - help='if True, dont scale embeddings') - parser.add_argument('--checkpoint-activations', action='store_true', - help='checkpoint activations at each layer, which saves GPU ' - 'memory usage at the cost of some additional compute') - parser.add_argument('--offload-activations', action='store_true', - help='checkpoint activations at each layer, then save to gpu. 
Sets --checkpoint-activations.') - # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) - parser.add_argument('--no-cross-attention', default=False, action='store_true', - help='do not perform cross-attention') - parser.add_argument('--cross-self-attention', default=False, action='store_true', - help='perform cross+self-attention') - # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) - parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, - help='LayerDrop probability for encoder') - parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, - help='LayerDrop probability for decoder') - parser.add_argument('--encoder-layers-to-keep', default=None, - help='which layers to *keep* when pruning as a comma-separated list') - parser.add_argument('--decoder-layers-to-keep', default=None, - help='which layers to *keep* when pruning as a comma-separated list') - # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) - parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, - help='iterative PQ quantization noise at training time') - parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, - help='block size of quantization noise at training time') - parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, - help='scalar quantization noise and scalar quantization at training time') - # args for Fully Sharded Data Parallel (FSDP) training - parser.add_argument( - '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, - help=( - 'minimum number of params for a layer to be wrapped with FSDP() when ' - 'training with --ddp-backend=fully_sharded. Smaller values will ' - 'improve memory efficiency, but may make torch.distributed ' - 'communication less efficient due to smaller input sizes. 
This option ' - 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' - '--offload-activations are passed.' - ) - ) - # fmt: on @classmethod def build_model(cls, args, task): @@ -228,100 +110,38 @@ def build_model(cls, args, task): raise ValueError( "--share-all-embeddings not compatible with --decoder-embed-path" ) - encoder_embed_tokens = cls.build_embedding( - args, src_dict, args.encoder_embed_dim, args.encoder_embed_path - ) - decoder_embed_tokens = encoder_embed_tokens args.share_decoder_input_output_embed = True - else: - encoder_embed_tokens = cls.build_embedding( - args, src_dict, args.encoder_embed_dim, args.encoder_embed_path - ) - decoder_embed_tokens = cls.build_embedding( - args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path - ) + if getattr(args, "offload_activations", False): args.checkpoint_activations = True # offloading implies checkpointing - encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) - decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + if not args.share_all_embeddings: - min_params_to_wrap = getattr( + args.min_params_to_wrap = getattr( args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP ) - # fsdp_wrap is a no-op when --ddp-backend != fully_sharded - encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) - decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) - return cls(args, encoder, decoder) + cfg = TransformerConfig.from_namespace(args) + return super().build_model(cfg, task) @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): - num_embeddings = len(dictionary) - padding_idx = dictionary.pad() - - emb = Embedding(num_embeddings, embed_dim, padding_idx) - # if provided, load from preloaded dictionaries - if path: - embed_dict = utils.parse_embedding(path) - utils.load_embedding(embed_dict, dictionary, emb) - return emb + return super().build_embedding( + TransformerConfig.from_namespace(args), dictionary, embed_dim, path + ) 
@classmethod def build_encoder(cls, args, src_dict, embed_tokens): - return transformer.TransformerEncoder(args, src_dict, embed_tokens) + return super().build_encoder( + TransformerConfig.from_namespace(args), src_dict, embed_tokens + ) @classmethod def build_decoder(cls, args, tgt_dict, embed_tokens): - return transformer.TransformerDecoder( - args, - tgt_dict, - embed_tokens, - no_encoder_attn=getattr(args, "no_cross_attention", False), + return super().build_decoder( + TransformerConfig.from_namespace(args), tgt_dict, embed_tokens ) - # TorchScript doesn't support optional arguments with variable length (**kwargs). - # Current workaround is to add union of all arguments in child classes. - def forward( - self, - src_tokens, - src_lengths, - prev_output_tokens, - return_all_hiddens: bool = True, - features_only: bool = False, - alignment_layer: Optional[int] = None, - alignment_heads: Optional[int] = None, - ): - """ - Run the forward pass for an encoder-decoder model. - - Copied from the base class, but without ``**kwargs``, - which are not supported by TorchScript. - """ - encoder_out = self.encoder( - src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens - ) - decoder_out = self.decoder( - prev_output_tokens, - encoder_out=encoder_out, - features_only=features_only, - alignment_layer=alignment_layer, - alignment_heads=alignment_heads, - src_lengths=src_lengths, - return_all_hiddens=return_all_hiddens, - ) - return decoder_out - - # Since get_normalized_probs is in the Fairseq Model which is not scriptable, - # I rewrite the get_normalized_probs from Base Class to call the - # helper function in the Base Class. 
- @torch.jit.export - def get_normalized_probs( - self, - net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], - log_probs: bool, - sample: Optional[Dict[str, Tensor]] = None, - ): - """Get normalized probabilities (or log probs) from a net's output.""" - return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + +# architectures @register_model_architecture("transformer", "transformer_tiny") @@ -443,10 +263,3 @@ def transformer_wmt_en_de_big_t2t(args): args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_dropout = getattr(args, "activation_dropout", 0.1) transformer_vaswani_wmt_en_de_big(args) - - -def Embedding(num_embeddings, embedding_dim, padding_idx): - m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) - nn.init.constant_(m.weight[padding_idx], 0) - return m diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 9bdca0f6af..a251635611 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -148,7 +148,7 @@ def forward( tgt_len, bsz, embed_dim = query.size() src_len = tgt_len - assert embed_dim == self.embed_dim + assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" assert list(query.size()) == [tgt_len, bsz, embed_dim] if key is not None: src_len, key_bsz, _ = key.size() diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index aa06a42935..de25de6564 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -12,9 +12,12 @@ from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.quant_noise import quant_noise from torch import Tensor +from fairseq.models.transformer import ( + TransformerConfig, +) -class TransformerEncoderLayer(nn.Module): +class TransformerEncoderLayerBase(nn.Module): """Encoder 
layer block. In the original paper each operation (multi-head attention or FFN) is @@ -23,49 +26,46 @@ class TransformerEncoderLayer(nn.Module): preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting - *args.encoder_normalize_before* to ``True``. + *cfg.encoder.normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments """ - def __init__(self, args): + def __init__(self, cfg): super().__init__() - self.args = args - self.embed_dim = args.encoder_embed_dim - self.quant_noise = getattr(args, "quant_noise_pq", 0) - self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) or 8 - self.self_attn = self.build_self_attention(self.embed_dim, args) - export = getattr(args, "export", False) - self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + self.cfg = cfg + self.embed_dim = cfg.encoder.embed_dim + self.quant_noise = cfg.quant_noise.pq + self.quant_noise_block_size = cfg.quant_noise.pq_block_size + self.self_attn = self.build_self_attention(self.embed_dim, cfg) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.dropout_module = FairseqDropout( - args.dropout, module_name=self.__class__.__name__ - ) - self.activation_fn = utils.get_activation_fn( - activation=getattr(args, "activation_fn", "relu") or "relu" + cfg.dropout, module_name=self.__class__.__name__ ) - activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn) + activation_dropout_p = cfg.activation_dropout if activation_dropout_p == 0: - # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + # for backwards compatibility with models that use cfg.relu_dropout + activation_dropout_p = cfg.relu_dropout or 0 
self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) - self.normalize_before = args.encoder_normalize_before + self.normalize_before = cfg.encoder.normalize_before self.fc1 = self.build_fc1( self.embed_dim, - args.encoder_ffn_embed_dim, + cfg.encoder.ffn_embed_dim, self.quant_noise, self.quant_noise_block_size, ) self.fc2 = self.build_fc2( - args.encoder_ffn_embed_dim, + cfg.encoder.ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size, ) - self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise( @@ -77,11 +77,11 @@ def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size ) - def build_self_attention(self, embed_dim, args): + def build_self_attention(self, embed_dim, cfg): return MultiheadAttention( embed_dim, - args.encoder_attention_heads, - dropout=args.attention_dropout, + cfg.encoder.attention_heads, + dropout=cfg.attention_dropout, self_attention=True, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, @@ -162,7 +162,19 @@ def forward( return x -class TransformerDecoderLayer(nn.Module): +# backward compatible with the legacy argparse format +class TransformerEncoderLayer(TransformerEncoderLayerBase): + def __init__(self, args): + super().__init__(TransformerConfig.from_namespace(args)) + self.args = args + + def build_self_attention(self, embed_dim, args): + return super().build_self_attention( + embed_dim, TransformerConfig.from_namespace(args) + ) + + +class TransformerDecoderLayerBase(nn.Module): """Decoder layer block. 
In the original paper each operation (multi-head attention, encoder @@ -171,7 +183,7 @@ class TransformerDecoderLayer(nn.Module): robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting - *args.decoder_normalize_before* to ``True``. + *cfg.decoder.normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments @@ -180,63 +192,58 @@ class TransformerDecoderLayer(nn.Module): """ def __init__( - self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False ): super().__init__() - self.embed_dim = args.decoder_embed_dim + self.embed_dim = cfg.decoder.embed_dim self.dropout_module = FairseqDropout( - args.dropout, module_name=self.__class__.__name__ + cfg.dropout, module_name=self.__class__.__name__ ) - self.quant_noise = getattr(args, "quant_noise_pq", 0) - self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + self.quant_noise = cfg.quant_noise.pq + self.quant_noise_block_size = cfg.quant_noise.pq_block_size - self.cross_self_attention = getattr(args, "cross_self_attention", False) + self.cross_self_attention = cfg.cross_self_attention self.self_attn = self.build_self_attention( self.embed_dim, - args, + cfg, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) - self.activation_fn = utils.get_activation_fn( - activation=str(args.activation_fn) - if getattr(args, "activation_fn", None) is not None - else "relu" - ) - activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn) + activation_dropout_p = cfg.activation_dropout if activation_dropout_p == 0: - # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + # for backwards 
compatibility with models that use cfg.relu_dropout + activation_dropout_p = cfg.relu_dropout or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) - self.normalize_before = args.decoder_normalize_before + self.normalize_before = cfg.decoder.normalize_before - export = getattr(args, "export", False) - self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: - self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) - self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.fc1 = self.build_fc1( self.embed_dim, - args.decoder_ffn_embed_dim, + cfg.decoder.ffn_embed_dim, self.quant_noise, self.quant_noise_block_size, ) self.fc2 = self.build_fc2( - args.decoder_ffn_embed_dim, + cfg.decoder.ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size, ) - self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.need_attn = True self.onnx_trace = False @@ -248,26 +255,26 @@ def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) def build_self_attention( - self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False ): return MultiheadAttention( embed_dim, - args.decoder_attention_heads, - dropout=args.attention_dropout, + cfg.decoder.attention_heads, + dropout=cfg.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, - self_attention=not getattr(args, "cross_self_attention", False), + 
self_attention=not cfg.cross_self_attention, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, ) - def build_encoder_attention(self, embed_dim, args): + def build_encoder_attention(self, embed_dim, cfg): return MultiheadAttention( embed_dim, - args.decoder_attention_heads, - kdim=getattr(args, "encoder_embed_dim", None), - vdim=getattr(args, "encoder_embed_dim", None), - dropout=args.attention_dropout, + cfg.decoder.attention_heads, + kdim=cfg.encoder.embed_dim, + vdim=cfg.encoder.embed_dim, + dropout=cfg.attention_dropout, encoder_decoder_attention=True, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, @@ -417,3 +424,33 @@ def forward( def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn + + +# backward compatible with the legacy argparse format +class TransformerDecoderLayer(TransformerDecoderLayerBase): + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__( + TransformerConfig.from_namespace(args), + no_encoder_attn=no_encoder_attn, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + self.args = args + + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return super().build_self_attention( + embed_dim, + TransformerConfig.from_namespace(args), + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + def build_encoder_attention(self, embed_dim, args): + return super().build_encoder_attention( + embed_dim, + TransformerConfig.from_namespace(args), + ) From c1624b273b206cc7c0a1529be4d2f35b38607ec5 Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Fri, 16 Jul 2021 05:45:38 -0700 Subject: [PATCH 653/707] Criterions to Hydra Summary: convert a couple of criterions to hydra Reviewed By: dianaml0 Differential Revision: D29585608 fbshipit-source-id: 7790b767ed55f58bbcc0c237cfa689684b9bf5e2 --- ...l_smoothed_cross_entropy_with_alignment.py | 33 
+++++++++++-------- fairseq/criterions/nat_loss.py | 24 +++++++------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py index 73cfa05310..2ea37c16b4 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py @@ -8,10 +8,27 @@ from fairseq import metrics, utils from fairseq.criterions import register_criterion -from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion +from .label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) +from dataclasses import dataclass, field -@register_criterion("label_smoothed_cross_entropy_with_alignment") + +@dataclass +class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + alignment_lambda: float = field( + default=0.05, metadata={"help": "weight for the alignment loss"} + ) + + +@register_criterion( + "label_smoothed_cross_entropy_with_alignment", + dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig, +) class LabelSmoothedCrossEntropyCriterionWithAlignment( LabelSmoothedCrossEntropyCriterion ): @@ -19,18 +36,6 @@ def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda): super().__init__(task, sentence_avg, label_smoothing) self.alignment_lambda = alignment_lambda - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - LabelSmoothedCrossEntropyCriterion.add_args(parser) - parser.add_argument( - "--alignment-lambda", - default=0.05, - type=float, - metavar="D", - help="weight for the alignment loss", - ) - def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py index cdc7da861d..7dac32fbaf 100644 --- a/fairseq/criterions/nat_loss.py +++ b/fairseq/criterions/nat_loss.py @@ -9,26 +9,26 @@ import torch.nn.functional as F from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass from torch import Tensor +from dataclasses import dataclass, field -@register_criterion("nat_loss") + +@dataclass +class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass): + label_smoothing: float = field( + default=0.0, + metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, + ) + + +@register_criterion("nat_loss", dataclass=LabelSmoothedDualImitationCriterionConfig) class LabelSmoothedDualImitationCriterion(FairseqCriterion): def __init__(self, task, label_smoothing): super().__init__(task) self.label_smoothing = label_smoothing - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - parser.add_argument( - "--label-smoothing", - default=0.0, - type=float, - metavar="D", - help="epsilon for label smoothing, 0 means no label smoothing", - ) - def _compute_loss( self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0 ): From 72323586aeae75e2b704c1c936784471bfa75019 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Tue, 20 Jul 2021 11:30:28 -0700 Subject: [PATCH 654/707] Add warning when combining --ddp-backend=fully_sharded and --update-freq (#2076) Summary: Add warning when combining `--ddp-backend=fully_sharded` and `--update-freq`, since that will result in increased memory usage. 
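A minimal sketch of the condition described in the summary above, kept separate from the trainer code. The helper name `warn_if_fsdp_with_update_freq` and its return value are assumptions made for illustration and are not part of fairseq; the sketch also assumes `update_freq` is a per-epoch list whose first entry is the one inspected.

```python
import logging
from typing import Sequence

logger = logging.getLogger(__name__)


def warn_if_fsdp_with_update_freq(ddp_backend: str, update_freq: Sequence[int]) -> bool:
    """Hypothetical helper (not a fairseq API).

    Logs a warning and returns True when fully sharded training is combined
    with gradient accumulation (first update_freq entry greater than 1).
    """
    if ddp_backend == "fully_sharded" and update_freq[0] > 1:
        logger.warning(
            "Combining --update-freq with FullyShardedDataParallel will "
            "result in increased memory usage, since full-sized gradients "
            "will be accumulated on each GPU!"
        )
        return True
    return False
```

Returning a boolean is only a convenience for checking the sketch; the warning text is the part that matters to the user.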
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2076 Reviewed By: xianxl Differential Revision: D29791364 Pulled By: myleott fbshipit-source-id: 5748f20484840f61a16448f1287b0f2e3b3ce9d8 --- fairseq/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 1deb14326f..1602688671 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -72,6 +72,12 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): "FullyShardedDataParallel is not compatible with --zero-sharding " "option (it's already built in)" ) + if self.cfg.optimization.update_freq[0] > 1: + logger.warning( + "Combining --update-freq with FullyShardedDataParallel will " + "result in increased memory usage, since full-sized gradients " + "will be accumulated on each GPU!" + ) else: if ( hasattr(self.cfg.distributed_training, "cpu_offload") From 804b49397606328221dd7296026c9bcc04a967d1 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Wed, 21 Jul 2021 13:16:15 -0700 Subject: [PATCH 655/707] Several updates on simultaneous translation (#1831) Summary: This pull request includes several updates and refactorings related to simultaneous translation: 1. Add mixed precision training for simultaneous translation decoder (avoiding NaN errors) 2. Add unit test for simultaneous decoders `cd examples/simultaneous_translation; python -m unittest test_text_models.py` 3. Simplify the inference code (simuleval only) 4. Reorganize code structure 5. Remove duplicated / deprecated code 6. Fixed a bug in waitk p_choose generation 7. Fixed an issue when using fixed_pre_decision + mma. The update won't affect current training. Old checkpoints can still be loaded and used for inference. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1831 Test Plan: Imported from GitHub, without a `Test Plan:` line. 
f286829279 f286829286 Reviewed By: sravyapopuri388 Differential Revision: D28082398 Pulled By: xutaima fbshipit-source-id: 882d077e7f1b94870f8328dba89660a4f3bd5d9c --- .../models/transformer_monotonic_attention.py | 98 +- .../modules/__init__.py | 5 +- .../modules/fixed_pre_decision.py | 90 +- .../modules/monotonic_multihead_attention.py | 1011 +++++------------ .../modules/monotonic_transformer_layer.py | 46 +- .../tests/test_text_models.py | 407 +++++++ .../utils/functions.py | 81 +- .../utils/monotonic_attention.py | 196 ++++ .../utils/p_choose_strategy.py | 82 +- 9 files changed, 1049 insertions(+), 967 deletions(-) create mode 100644 examples/simultaneous_translation/tests/test_text_models.py create mode 100644 examples/simultaneous_translation/utils/monotonic_attention.py diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index 77c0350d2d..b0cdc43483 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from examples.simultaneous_translation.modules.monotonic_transformer_layer import ( TransformerMonotonicDecoderLayer, TransformerMonotonicEncoderLayer, @@ -23,12 +22,14 @@ base_architecture, transformer_iwslt_de_en, transformer_vaswani_wmt_en_de_big, - transformer_vaswani_wmt_en_fr_big, + tiny_architecture ) from torch import Tensor DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 +READ_ACTION = 0 +WRITE_ACTION = 1 TransformerMonotonicDecoderOut = NamedTuple( "TransformerMonotonicDecoderOut", @@ -36,7 +37,6 @@ ("action", int), ("p_choose", Optional[Tensor]), ("attn_list", Optional[List[Optional[Dict[str, Tensor]]]]), - ("step_list", Optional[List[Optional[Tensor]]]), ("encoder_out", Optional[Dict[str, List[Tensor]]]), 
("encoder_padding_mask", Optional[Tensor]), ], @@ -60,26 +60,6 @@ def build_encoder(cls, args, src_dict, embed_tokens): def build_decoder(cls, args, tgt_dict, embed_tokens): return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens) - def _indices_from_states(self, states): - if type(states["indices"]["src"]) == list: - if next(self.parameters()).is_cuda: - tensor = torch.cuda.LongTensor - else: - tensor = torch.LongTensor - - src_indices = tensor( - [states["indices"]["src"][: 1 + states["steps"]["src"]]] - ) - - tgt_indices = tensor( - [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]] - ) - else: - src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]] - tgt_indices = states["indices"]["tgt"] - - return src_indices, None, tgt_indices - class TransformerMonotonicEncoder(TransformerEncoder): def __init__(self, args, dictionary, embed_tokens): @@ -88,7 +68,10 @@ def __init__(self, args, dictionary, embed_tokens): self.dictionary = dictionary self.layers = nn.ModuleList([]) self.layers.extend( - [TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)] + [ + TransformerMonotonicEncoderLayer(args) + for i in range(args.encoder_layers) + ] ) @@ -112,10 +95,11 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.layers = nn.ModuleList([]) self.layers.extend( [ - TransformerMonotonicDecoderLayer(args, no_encoder_attn) + TransformerMonotonicDecoderLayer(args) for _ in range(args.decoder_layers) ] ) + self.policy_criterion = getattr(args, "policy_criterion", "any") def pre_attention( self, @@ -176,14 +160,15 @@ def post_attention(self, x): return x - def clear_cache( + def clean_cache( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], end_id: Optional[int] = None, ): """ - Clear cache in the monotonic layers. - The cache is generated because of a forward pass of decode but no prediction. + Clean cache in the monotonic layers. 
+ The cache is generated because of a forward pass of decoder has run but no prediction, + so that the self attention key value in decoder is written in the incremental state. end_id is the last idx of the layers """ if end_id is None: @@ -218,7 +203,6 @@ def extract_features( attn = None inner_states = [x] attn_list: List[Optional[Dict[str, Tensor]]] = [] - step_list: List[Optional[Tensor]] = [] p_choose = torch.tensor([1.0]) @@ -238,36 +222,28 @@ def extract_features( attn_list.append(attn) if incremental_state is not None: - curr_steps = layer.get_head_steps(incremental_state) - step_list.append(curr_steps) if_online = incremental_state["online"]["only"] assert if_online is not None if if_online.to(torch.bool): # Online indicates that the encoder states are still changing assert attn is not None - assert curr_steps is not None - p_choose = ( - attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t()) - ) - - new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps) - src = incremental_state["steps"]["src"] - assert src is not None - - if (new_steps >= src).any(): - # We need to prune the last self_attn saved_state - # if model decide not to read - # otherwise there will be duplicated saved_state - self.clear_cache(incremental_state, i + 1) - - return x, TransformerMonotonicDecoderOut( - action=0, - p_choose=p_choose, - attn_list=None, - step_list=None, - encoder_out=None, - encoder_padding_mask=None, - ) + if self.policy_criterion == "any": + # Any head decide to read than read + head_read = layer.encoder_attn._get_monotonic_buffer(incremental_state)["head_read"] + assert head_read is not None + if head_read.any(): + # We need to prune the last self_attn saved_state + # if model decide not to read + # otherwise there will be duplicated saved_state + self.clean_cache(incremental_state, i + 1) + + return x, TransformerMonotonicDecoderOut( + action=0, + p_choose=p_choose, + attn_list=None, + encoder_out=None, + encoder_padding_mask=None, + ) x = 
self.post_attention(x) @@ -275,18 +251,10 @@ def extract_features( action=1, p_choose=p_choose, attn_list=attn_list, - step_list=step_list, encoder_out=encoder_out, encoder_padding_mask=encoder_padding_mask, ) - def reorder_incremental_state(self, incremental_state, new_order): - super().reorder_incremental_state(incremental_state, new_order) - if "fastest_step" in incremental_state: - incremental_state["fastest_step"] = incremental_state[ - "fastest_step" - ].index_select(0, new_order) - @register_model_architecture("transformer_monotonic", "transformer_monotonic") def base_monotonic_architecture(args): @@ -322,3 +290,9 @@ def transformer_monotonic_vaswani_wmt_en_fr_big(args): ) def transformer_unidirectional_iwslt_de_en(args): transformer_iwslt_de_en(args) + + +@register_model_architecture("transformer_monotonic", "transformer_monotonic_tiny") +def monotonic_tiny_architecture(args): + tiny_architecture(args) + base_monotonic_architecture(args) diff --git a/examples/simultaneous_translation/modules/__init__.py b/examples/simultaneous_translation/modules/__init__.py index c695850c04..f5ea180f9b 100644 --- a/examples/simultaneous_translation/modules/__init__.py +++ b/examples/simultaneous_translation/modules/__init__.py @@ -3,12 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import importlib -import os +import os +import importlib from fairseq import registry - ( build_monotonic_attention, register_monotonic_attention, diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py index dd29c031b3..3991414aed 100644 --- a/examples/simultaneous_translation/modules/fixed_pre_decision.py +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -7,12 +7,12 @@ from . 
import register_monotonic_attention from .monotonic_multihead_attention import ( - MonotonicMultiheadAttentionWaitK, - MonotonicMultiheadAttentionHardAligned, - MonotonicMultiheadAttentionInfiniteLookback, + MonotonicAttention, + MonotonicInfiniteLookbackAttention, + WaitKAttention ) from typing import Dict, Optional -from examples.simultaneous_translation.utils import p_choose_strategy + def fixed_pooling_monotonic_attention(monotonic_attention): def create_model(monotonic_attention, klass): @@ -26,10 +26,7 @@ def __init__(self, args): self.pre_decision_type = args.fixed_pre_decision_type self.pre_decision_ratio = args.fixed_pre_decision_ratio self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold - if self.pre_decision_ratio == 1: - return - - self.strategy = args.simul_type + assert self.pre_decision_ratio > 1 if args.fixed_pre_decision_type == "average": self.pooling_layer = torch.nn.AvgPool1d( @@ -46,7 +43,7 @@ def last(key): k = key[ :, :, - self.pre_decision_ratio - 1 :: self.pre_decision_ratio, + self.pre_decision_ratio - 1:: self.pre_decision_ratio, ].contiguous() if key.size(-1) % self.pre_decision_ratio != 0: k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous() @@ -97,45 +94,6 @@ def insert_zeros(self, x): ) return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1) - def p_choose_waitk( - self, query, key, key_padding_mask: Optional[Tensor] = None, - incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None - ): - """ - query: bsz, tgt_len - key: bsz, src_len - key_padding_mask: bsz, src_len - """ - if incremental_state is not None: - # Retrieve target length from incremental states - # For inference the length of query is always 1 - tgt = incremental_state["steps"]["tgt"] - assert tgt is not None - tgt_len = int(tgt) - else: - tgt_len, bsz, _ = query.size() - - src_len, bsz, _ = key.size() - - p_choose = torch.ones(bsz, tgt_len, src_len).to(query) - p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) 
- p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) - - if incremental_state is not None: - p_choose = p_choose[:, -1:] - tgt_len = 1 - - # Extend to each head - p_choose = ( - p_choose.contiguous() - .unsqueeze(1) - .expand(-1, self.num_heads, -1, -1) - .contiguous() - .view(-1, tgt_len, src_len) - ) - - return p_choose - def p_choose( self, query: Optional[Tensor], @@ -149,28 +107,6 @@ def p_choose( tgt_len = query.size(0) batch_size = query.size(1) - if self.pre_decision_ratio == 1: - if self.strategy == "waitk": - return p_choose_strategy.waitk( - query, - key, - self.waitk_lagging, - self.num_heads, - key_padding_mask, - incremental_state=incremental_state, - ) - else: # hard_aligned or infinite_lookback - q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic") - attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) - return p_choose_strategy.hard_aligned( - q_proj, - k_proj, - attn_energy, - self.noise_mean, - self.noise_var, - self.training - ) - key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2) if key_padding_mask is not None: @@ -194,7 +130,7 @@ def p_choose( if key_padding_mask_pool is not None: key_padding_mask_pool = key_padding_mask_pool[:-1] - p_choose_pooled = self.p_choose_waitk( + p_choose_pooled = self.p_choose_from_qk( query, key_pool, key_padding_mask_pool, @@ -237,18 +173,18 @@ def p_choose( @register_monotonic_attention("waitk_fixed_pre_decision") -@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionWaitK) -class MonotonicMultiheadAttentionWaitkFixedStride: +@fixed_pooling_monotonic_attention(WaitKAttention) +class WaitKAttentionFixedStride: pass @register_monotonic_attention("hard_aligned_fixed_pre_decision") -@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionHardAligned) -class MonotonicMultiheadAttentionHardFixedStride: +@fixed_pooling_monotonic_attention(MonotonicAttention) +class MonotonicAttentionFixedStride: pass 
@register_monotonic_attention("infinite_lookback_fixed_pre_decision") -@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionInfiniteLookback) -class MonotonicMultiheadAttentionInfiniteLookbackFixedStride: +@fixed_pooling_monotonic_attention(MonotonicInfiniteLookbackAttention) +class MonotonicInfiniteLookbackAttentionFixedStride: pass diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index f49b1daa2f..2b8a48b1de 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -9,27 +9,44 @@ from torch import Tensor import torch.nn as nn -from examples.simultaneous_translation.utils.functions import ( - exclusive_cumprod, - lengths_to_mask, +from examples.simultaneous_translation.utils.p_choose_strategy import ( + learnable_p_choose, + waitk_p_choose +) + +from examples.simultaneous_translation.utils.monotonic_attention import ( + expected_alignment_from_p_choose, + expected_soft_attention, + mass_preservation, ) -from fairseq.incremental_decoding_utils import with_incremental_state from fairseq.modules import MultiheadAttention from . 
import register_monotonic_attention from typing import Dict, Optional -from examples.simultaneous_translation.utils import p_choose_strategy -@with_incremental_state -class MonotonicAttention(nn.Module): +@register_monotonic_attention("hard_aligned") +class MonotonicAttention(MultiheadAttention): """ Abstract class of monotonic attentions """ + k_in_proj: Dict[str, nn.Linear] + q_in_proj: Dict[str, nn.Linear] def __init__(self, args): - self.eps = args.attention_eps - self.mass_preservation = args.mass_preservation + super().__init__( + embed_dim=args.decoder_embed_dim, + num_heads=args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) + + self.soft_attention = False + + self.eps = getattr(args, "attention_eps", True) + self.mass_preservation = getattr(args, "mass_preservation", True) self.noise_type = args.noise_type self.noise_mean = args.noise_mean @@ -42,6 +59,10 @@ def __init__(self, args): else 0 ) + self.k_in_proj = {"monotonic": self.k_proj} + self.q_in_proj = {"monotonic": self.q_proj} + self.chunk_size = None + @staticmethod def add_args(parser): # fmt: off @@ -66,567 +87,316 @@ def add_args(parser): parser.add_argument('--attention-eps', type=float, default=1e-6, help='Epsilon when calculating expected attention') - def p_choose(self, *args): - raise NotImplementedError - - def input_projections(self, *args): - raise NotImplementedError - - def attn_energy( - self, q_proj, k_proj, key_padding_mask=None, attn_mask=None + def energy_from_qk( + self, + query: Tensor, + key: Tensor, + energy_type: str, + key_padding_mask: Optional[Tensor] = None, + bias: int = 0 ): """ - Calculating monotonic energies - - ============================================================ - Expected input size - q_proj: bsz * num_heads, tgt_len, self.head_dim - k_proj: bsz * num_heads, src_len, self.head_dim - key_padding_mask: bsz, 
src_len - attn_mask: tgt_len, src_len + Compute energy from query and key + q_func_value is a tuple looks like + (q_proj_func, q_tensor) + q_tensor size: bsz, tgt_len, emb_dim + k_tensor size: bsz, src_len, emb_dim + key_padding_mask size: bsz, src_len + attn_mask: bsz, src_len """ - bsz, tgt_len, embed_dim = q_proj.size() - bsz = bsz // self.num_heads - src_len = k_proj.size(1) - attn_energy = ( - torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + length, bsz, _ = query.size() + q = self.q_in_proj[energy_type].forward(query) + q = ( + q.contiguous() + .view(length, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + q = q * self.scaling + length, bsz, _ = key.size() + k = self.k_in_proj[energy_type].forward(key) + k = ( + k.contiguous() + .view(length, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) ) - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(0) - attn_energy += attn_mask - - attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len) + energy = torch.bmm(q, k.transpose(1, 2)) + bias if key_padding_mask is not None: - attn_energy = attn_energy.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), - float("-inf"), + energy = energy.masked_fill( + key_padding_mask.unsqueeze(1).to(torch.bool), + - float("inf") ) - return attn_energy - - def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]): - """ - Calculating expected alignment for MMA - Mask is not need because p_choose will be 0 if masked - - q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} - a_ij = p_ij q_ij - - Parallel solution: - ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) - - ============================================================ - Expected input size - p_choose: bsz * num_heads, tgt_len, src_len - """ - - # p_choose: bsz * num_heads, tgt_len, src_len - bsz_num_heads, tgt_len, src_len = p_choose.size() - - # cumprod_1mp : bsz * num_heads, tgt_len, src_len - cumprod_1mp = exclusive_cumprod(1 
- p_choose, dim=2, eps=self.eps) - cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0) - - init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len]) - init_attention[:, :, 0] = 1.0 + return energy - previous_attn = [init_attention] + def p_choose_from_qk(self, query, key, key_padding_mask): + monotonic_energy = self.energy_from_qk( + query, + key, + "monotonic", + key_padding_mask=key_padding_mask, + bias=self.energy_bias, + ) - for i in range(tgt_len): - # p_choose: bsz * num_heads, tgt_len, src_len - # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len - # previous_attn[i]: bsz * num_heads, 1, src_len - # alpha_i: bsz * num_heads, src_len - alpha_i = ( - p_choose[:, i] - * cumprod_1mp[:, i] - * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1) - ).clamp(0, 1.0) - previous_attn.append(alpha_i.unsqueeze(1)) + p_choose = learnable_p_choose( + monotonic_energy, + self.noise_mean, + self.noise_var, + self.training + ) + return p_choose - # alpha: bsz * num_heads, tgt_len, src_len - alpha = torch.cat(previous_attn[1:], dim=1) + def p_choose(self, query, key, key_padding_mask): + return self.p_choose_from_qk(self, query, key, key_padding_mask) - if self.mass_preservation: - # Last token has the residual probabilities - if key_padding_mask is not None and key_padding_mask[:, -1].any(): - # right padding - batch_size = key_padding_mask.size(0) - residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0) - src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True) - src_lens = src_lens.expand( - batch_size, self.num_heads - ).contiguous().view(-1, 1) - src_lens = src_lens.expand(-1, tgt_len).contiguous() - # add back the last value - residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1) - alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals) - else: - residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) - alpha[:, :, -1] = residuals - - if torch.isnan(alpha).any(): - # Something is wrong - raise 
RuntimeError("NaN in alpha.") - - return alpha - - def expected_alignment_infer( - self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + def monotonic_attention_process_infer( + self, + query: Optional[Tensor], + key: Optional[Tensor], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], ): - # TODO modify this function """ - Calculating mo alignment for MMA during inference time - - ============================================================ - Expected input size - p_choose: bsz * num_heads, tgt_len, src_len - incremental_state: dict - encodencoder_padding_mask: bsz * src_len + Monotonic attention at inference time + Notice that this function is designed for simuleval not sequence_generator """ - # p_choose: bsz * self.num_heads, src_len - bsz_num_heads, tgt_len, src_len = p_choose.size() - # One token at a time - assert tgt_len == 1 - p_choose = p_choose[:, 0, :] + assert query is not None + assert key is not None - monotonic_cache = self._get_monotonic_buffer(incremental_state) + if query.size(1) != 1: + raise RuntimeError( + "Simultaneous translation models don't support batch decoding." + ) + # 1. compute stepwise probability + p_choose = self.p_choose( + query, key, None, incremental_state + ).squeeze(1) - # prev_monotonic_step: bsz, num_heads - bsz = bsz_num_heads // self.num_heads - prev_monotonic_step = monotonic_cache.get( - "head_step", - p_choose.new_zeros([bsz, self.num_heads]).long() + # 2. 
Compute the alpha + src_len = key.size(0) + # Maximum steps allows in this iteration + max_steps = src_len - 1 if self.mass_preservation else src_len + monotonic_cache = self._get_monotonic_buffer(incremental_state) + # Step for each head + monotonic_step = monotonic_cache.get( + 'head_step', + p_choose.new_zeros(1, self.num_heads).long() ) - assert prev_monotonic_step is not None - bsz, num_heads = prev_monotonic_step.size() - assert num_heads == self.num_heads - assert bsz * num_heads == bsz_num_heads - - # p_choose: bsz, num_heads, src_len - p_choose = p_choose.view(bsz, num_heads, src_len) + assert monotonic_step is not None + finish_read = monotonic_step.eq(max_steps) + p_choose_i = torch.tensor(1) - if encoder_padding_mask is not None: - src_lengths = src_len - \ - encoder_padding_mask.sum(dim=1, keepdim=True).long() - else: - src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len - - # src_lengths: bsz, num_heads - src_lengths = src_lengths.expand_as(prev_monotonic_step) - # new_monotonic_step: bsz, num_heads - new_monotonic_step = prev_monotonic_step - - step_offset = 0 - if encoder_padding_mask is not None: - if encoder_padding_mask[:, 0].any(): - # left_pad_source = True: - step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True) - - max_steps = src_lengths - 1 if self.mass_preservation else src_lengths - - # finish_read: bsz, num_heads - finish_read = new_monotonic_step.eq(max_steps) - p_choose_i = 1 - while finish_read.sum().item() < bsz * self.num_heads: - # p_choose: bsz * self.num_heads, src_len + while finish_read.sum().item() < self.num_heads: + # p_choose: self.num_heads, src_len # only choose the p at monotonic steps - # p_choose_i: bsz , self.num_heads + # p_choose_i: 1, self.num_heads p_choose_i = ( p_choose.gather( - 2, - (step_offset + new_monotonic_step) - .unsqueeze(2) + 1, + monotonic_step .clamp(0, src_len - 1), ) - ).squeeze(2) + ) - action = ( + read_one_step = ( (p_choose_i < 0.5) - .type_as(prev_monotonic_step) + 
.type_as(monotonic_step) .masked_fill(finish_read, 0) ) # 1 x bsz # sample actions on unfinished seq - # 1 means stay, finish reading - # 0 means leave, continue reading - # dist = torch.distributions.bernoulli.Bernoulli(p_choose) - # action = dist.sample().type_as(finish_read) * (1 - finish_read) + # 0 means stay, finish reading + # 1 means leave, continue reading + + monotonic_step += read_one_step - new_monotonic_step += action + finish_read = monotonic_step.eq(max_steps) | (read_one_step == 0) - finish_read = new_monotonic_step.eq(max_steps) | (action == 0) + # p_choose at last steps + p_choose_i = ( + p_choose.gather( + 1, + monotonic_step + .clamp(0, src_len - 1), + ) + ) - monotonic_cache["head_step"] = new_monotonic_step + monotonic_cache["head_step"] = monotonic_step # Whether a head is looking for new input monotonic_cache["head_read"] = ( - new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5) + monotonic_step.eq(max_steps) & (p_choose_i < 0.5) ) + self._set_monotonic_buffer(incremental_state, monotonic_cache) - # alpha: bsz * num_heads, 1, src_len - # new_monotonic_step: bsz, num_heads + # 2. 
Update alpha alpha = ( p_choose - .new_zeros([bsz * self.num_heads, src_len]) + .new_zeros([self.num_heads, src_len]) .scatter( 1, - (step_offset + new_monotonic_step) - .view(bsz * self.num_heads, 1).clamp(0, src_len - 1), + (monotonic_step) + .view(self.num_heads, 1).clamp(0, src_len - 1), 1 ) ) if not self.mass_preservation: alpha = alpha.masked_fill( - (new_monotonic_step == max_steps) - .view(bsz * self.num_heads, 1), + (monotonic_step == max_steps) + .view(self.num_heads, 1), 0 ) - alpha = alpha.unsqueeze(1) - - self._set_monotonic_buffer(incremental_state, monotonic_cache) - - return alpha - - def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): - return self.get_incremental_state( - incremental_state, - 'monotonic', - ) or {} - - def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]): - self.set_incremental_state( - incremental_state, - 'monotonic', - buffer, - ) - - def v_proj_output(self, value): - raise NotImplementedError - - def forward( - self, query, key, value, - key_padding_mask=None, attn_mask=None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - need_weights=True, static_kv=False - ): - - tgt_len, bsz, embed_dim = query.size() - src_len = value.size(0) - - # stepwise prob - # p_choose: bsz * self.num_heads, tgt_len, src_len - p_choose = self.p_choose( - query, key, key_padding_mask, incremental_state, - ) - - # expected alignment alpha - # bsz * self.num_heads, tgt_len, src_len - if incremental_state is not None: - alpha = self.expected_alignment_infer( - p_choose, key_padding_mask, incremental_state) + # 4. 
Compute Beta + if self.soft_attention: + monotonic_step = monotonic_step.t() + beta_mask = torch.arange(src_len).expand_as(alpha).gt(monotonic_step).unsqueeze(1) + # If it's soft attention just do softmax on current context + soft_energy = self.energy_from_qk( + query, + key, + "soft" + ) + beta = torch.nn.functional.softmax( + soft_energy.masked_fill(beta_mask, -float("inf")), dim=-1 + ) + # It could happen that a head doesn't move at all + beta = beta.masked_fill(monotonic_step.eq(0).unsqueeze(1), 0) else: - alpha = self.expected_alignment_train( - p_choose, key_padding_mask) - - # expected attention beta - # bsz * self.num_heads, tgt_len, src_len - beta = self.expected_attention( - alpha, query, key, value, - key_padding_mask, attn_mask, - incremental_state - ) + # If it's hard attention just select the last state + beta = alpha - attn_weights = beta + return p_choose, alpha, beta - v_proj = self.v_proj_output(value) - - attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) - - attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - - attn = self.out_proj(attn) - - beta = beta.view(bsz, self.num_heads, tgt_len, src_len) - alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) - p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) - - return attn, { - "alpha": alpha, - "beta": beta, - "p_choose": p_choose, - } - - -@register_monotonic_attention("hard_aligned") -class MonotonicMultiheadAttentionHardAligned( - MonotonicAttention, MultiheadAttention -): - def __init__(self, args): - MultiheadAttention.__init__( - self, - embed_dim=args.decoder_embed_dim, - num_heads=args.decoder_attention_heads, - kdim=getattr(args, "encoder_embed_dim", None), - vdim=getattr(args, "encoder_embed_dim", None), - dropout=args.attention_dropout, - encoder_decoder_attention=True, - ) - - MonotonicAttention.__init__(self, args) - - self.k_in_proj = {"monotonic": self.k_proj} - self.q_in_proj = {"monotonic": self.q_proj} - self.v_in_proj = {"output": 
self.v_proj} - - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--no-mass-preservation', action="store_false", - dest="mass_preservation", - help='Do not stay on the last token when decoding') - parser.add_argument('--mass-preservation', action="store_true", - dest="mass_preservation", - help='Stay on the last token when decoding') - parser.set_defaults(mass_preservation=True) - parser.add_argument('--noise-var', type=float, default=1.0, - help='Variance of discretness noise') - parser.add_argument('--noise-mean', type=float, default=0.0, - help='Mean of discretness noise') - parser.add_argument('--noise-type', type=str, default="flat", - help='Type of discretness noise') - parser.add_argument('--energy-bias', action="store_true", - default=False, - help='Bias for energy') - parser.add_argument('--energy-bias-init', type=float, default=-2.0, - help='Initial value of the bias for energy') - parser.add_argument('--attention-eps', type=float, default=1e-6, - help='Epsilon when calculating expected attention') - - def attn_energy( - self, q_proj: Optional[Tensor], k_proj: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None + def monotonic_attention_process_train( + self, + query: Optional[Tensor], + key: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, ): """ - Calculating monotonic energies - - ============================================================ - Expected input size - q_proj: bsz * num_heads, tgt_len, self.head_dim - k_proj: bsz * num_heads, src_len, self.head_dim - key_padding_mask: bsz, src_len - attn_mask: tgt_len, src_len + Calculating monotonic attention process for training + Including: + stepwise probability: p_choose + expected hard alignment: alpha + expected soft attention: beta """ - assert q_proj is not None # Optional[Tensor] annotations in the signature above are to make the JIT compiler happy - assert k_proj is not None - bsz, tgt_len, embed_dim = 
q_proj.size() - bsz = bsz // self.num_heads - src_len = k_proj.size(1) - - attn_energy = ( - torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias - ) + assert query is not None + assert key is not None - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(0) - attn_energy += attn_mask + # 1. compute stepwise probability + p_choose = self.p_choose_from_qk(query, key, key_padding_mask) - attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len) + # 2. compute expected_alignment + alpha = expected_alignment_from_p_choose( + p_choose, + key_padding_mask, + eps=self.eps, + ) - if key_padding_mask is not None: - attn_energy = attn_energy.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), - float("-inf"), + if self.mass_preservation: + alpha = mass_preservation( + alpha, key_padding_mask ) - return attn_energy - - def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]): - """ - Calculating expected alignment for MMA - Mask is not need because p_choose will be 0 if masked - - q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} - a_ij = p_ij q_ij - - Parallel solution: - ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) - - ============================================================ - Expected input size - p_choose: bsz * num_heads, tgt_len, src_len - """ - - # p_choose: bsz * num_heads, tgt_len, src_len - bsz_num_heads, tgt_len, src_len = p_choose.size() - - # cumprod_1mp : bsz * num_heads, tgt_len, src_len - cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps) - cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0) - - init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len]) - init_attention[:, :, 0] = 1.0 - - previous_attn = [init_attention] + # 3. 
compute expected soft attention (soft aligned model only) + if self.soft_attention: + soft_energy = self.energy_from_qk( + query, + key, + "soft", + key_padding_mask=None, + ) - for i in range(tgt_len): - # p_choose: bsz * num_heads, tgt_len, src_len - # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len - # previous_attn[i]: bsz * num_heads, 1, src_len - # alpha_i: bsz * num_heads, src_len - alpha_i = ( - p_choose[:, i] - * cumprod_1mp[:, i] - * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1) - ).clamp(0, 1.0) - previous_attn.append(alpha_i.unsqueeze(1)) + beta = expected_soft_attention( + alpha, + soft_energy, + padding_mask=key_padding_mask, + chunk_size=self.chunk_size, + eps=self.eps, + ) + else: + beta = alpha + soft_energy = alpha - # alpha: bsz * num_heads, tgt_len, src_len - alpha = torch.cat(previous_attn[1:], dim=1) + return p_choose, alpha, beta, soft_energy - if self.mass_preservation: - # Last token has the residual probabilities - if key_padding_mask is not None and key_padding_mask[:, -1].any(): - # right padding - batch_size = key_padding_mask.size(0) - residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0) - src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True) - src_lens = src_lens.expand( - batch_size, self.num_heads - ).contiguous().view(-1, 1) - src_lens = src_lens.expand(-1, tgt_len).contiguous() - # add back the last value - residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1) - alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals) - else: - residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) - alpha[:, :, -1] = residuals - - if torch.isnan(alpha).any(): - # Something is wrong - raise RuntimeError("NaN in alpha.") - - return alpha - - def expected_alignment_infer( - self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + def forward( + self, + query: Optional[Tensor], + key: Optional[Tensor], + value: 
Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, static_kv: bool = False, need_head_weights: bool = False, ): - # TODO modify this function """ - Calculating mo alignment for MMA during inference time - - ============================================================ - Expected input size - p_choose: bsz * num_heads, tgt_len, src_len - incremental_state: dict - encodencoder_padding_mask: bsz * src_len + query: tgt_len, bsz, embed_dim + key: src_len, bsz, embed_dim + value: src_len, bsz, embed_dim """ - # p_choose: bsz * self.num_heads, src_len - bsz_num_heads, tgt_len, src_len = p_choose.size() - # One token at a time - assert tgt_len == 1 - p_choose = p_choose[:, 0, :] - monotonic_cache = self._get_monotonic_buffer(incremental_state) + assert attn_mask is None + assert query is not None + assert key is not None + assert value is not None - # prev_monotonic_step: bsz, num_heads - bsz = bsz_num_heads // self.num_heads - prev_monotonic_step = monotonic_cache.get( - "head_step", - p_choose.new_zeros([bsz, self.num_heads]).long() - ) - assert prev_monotonic_step is not None - bsz, num_heads = prev_monotonic_step.size() - assert num_heads == self.num_heads - assert bsz * num_heads == bsz_num_heads + tgt_len, bsz, embed_dim = query.size() + src_len = value.size(0) - # p_choose: bsz, num_heads, src_len - p_choose = p_choose.view(bsz, num_heads, src_len) + if key_padding_mask is not None: + assert not key_padding_mask[:, 0].any(), ( + "Only right padding is supported." 
+ ) + key_padding_mask = ( + key_padding_mask + .unsqueeze(1) + .expand([bsz, self.num_heads, src_len]) + .contiguous() + .view(-1, src_len) + ) - if encoder_padding_mask is not None: - src_lengths = src_len - \ - encoder_padding_mask.sum(dim=1, keepdim=True).long() + if incremental_state is not None: + # Inference + ( + p_choose, alpha, beta + ) = self.monotonic_attention_process_infer( + query, key, incremental_state + ) + soft_energy = beta else: - src_lengths = torch.ones(bsz, 1).to(prev_monotonic_step) * src_len - - # src_lengths: bsz, num_heads - src_lengths = src_lengths.expand_as(prev_monotonic_step) - # new_monotonic_step: bsz, num_heads - new_monotonic_step = prev_monotonic_step - - step_offset = torch.tensor(0) - if encoder_padding_mask is not None: - if encoder_padding_mask[:, 0].any(): - # left_pad_source = True: - step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True) - - max_steps = src_lengths - 1 if self.mass_preservation else src_lengths - - # finish_read: bsz, num_heads - finish_read = new_monotonic_step.eq(max_steps) - p_choose_i = torch.tensor(1) - while finish_read.sum().item() < bsz * self.num_heads: - # p_choose: bsz * self.num_heads, src_len - # only choose the p at monotonic steps - # p_choose_i: bsz , self.num_heads - p_choose_i = ( - p_choose.gather( - 2, - (step_offset + new_monotonic_step) - .unsqueeze(2) - .clamp(0, src_len - 1), - ) - ).squeeze(2) - - action = ( - (p_choose_i < 0.5) - .type_as(prev_monotonic_step) - .masked_fill(finish_read, 0) + # Train + ( + p_choose, alpha, beta, soft_energy + ) = self.monotonic_attention_process_train( + query, key, key_padding_mask ) - # 1 x bsz - # sample actions on unfinished seq - # 1 means stay, finish reading - # 0 means leave, continue reading - # dist = torch.distributions.bernoulli.Bernoulli(p_choose) - # action = dist.sample().type_as(finish_read) * (1 - finish_read) - - new_monotonic_step += action - - finish_read = new_monotonic_step.eq(max_steps) | (action == 0) - 
monotonic_cache["head_step"] = new_monotonic_step - # Whether a head is looking for new input - monotonic_cache["head_read"] = ( - new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5) + v = self.v_proj(value) + length, bsz, _ = v.size() + v = ( + v.contiguous() + .view(length, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) ) - # alpha: bsz * num_heads, 1, src_len - # new_monotonic_step: bsz, num_heads - alpha = ( - p_choose - .new_zeros([bsz * self.num_heads, src_len]) - .scatter( - 1, - (step_offset + new_monotonic_step) - .view(bsz * self.num_heads, 1).clamp(0, src_len - 1), - 1 - ) - ) + attn = torch.bmm(beta.type_as(v), v) - if not self.mass_preservation: - alpha = alpha.masked_fill( - (new_monotonic_step == max_steps) - .view(bsz * self.num_heads, 1), - 0 - ) + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - alpha = alpha.unsqueeze(1) + attn = self.out_proj(attn) - self._set_monotonic_buffer(incremental_state, monotonic_cache) + p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) + alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) + beta = beta.view(bsz, self.num_heads, tgt_len, src_len) - return alpha + return attn, { + "p_choose": p_choose, + "alpha": alpha, + "beta": beta, + } def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): maybe_incremental_state = self.get_incremental_state( @@ -646,147 +416,14 @@ def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, buffer, ) - def forward( - self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], - key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - need_weights: bool = True, static_kv: bool = False, need_head_weights: bool = False, - ): - assert query is not None - assert value is not None - tgt_len, bsz, embed_dim = query.size() - src_len = value.size(0) - - # 
stepwise prob - # p_choose: bsz * self.num_heads, tgt_len, src_len - p_choose = self.p_choose( - query, key, key_padding_mask, incremental_state, - ) - - # expected alignment alpha - # bsz * self.num_heads, tgt_len, src_len - if incremental_state is not None: - alpha = self.expected_alignment_infer( - p_choose, key_padding_mask, incremental_state) - else: - alpha = self.expected_alignment_train( - p_choose, key_padding_mask) - - # expected attention beta - # bsz * self.num_heads, tgt_len, src_len - beta = self.expected_attention( - alpha, query, key, value, - key_padding_mask, attn_mask, - incremental_state - ) - - attn_weights = beta - - v_proj = self.v_proj_output(value) - assert v_proj is not None - - attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) - - attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - - attn = self.out_proj(attn) - - beta = beta.view(bsz, self.num_heads, tgt_len, src_len) - alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) - p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) - - return attn, { - "alpha": alpha, - "beta": beta, - "p_choose": p_choose, - } - - def input_projections(self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], name: str): - """ - Prepare inputs for multihead attention - - ============================================================ - Expected input size - query: tgt_len, bsz, embed_dim - key: src_len, bsz, embed_dim - value: src_len, bsz, embed_dim - name: monotonic or soft - """ - - if query is not None: - bsz = query.size(1) - q = self.q_proj(query) - q *= self.scaling - q = q.contiguous().view( - -1, bsz * self.num_heads, self.head_dim - ).transpose(0, 1) - else: - q = None - - if key is not None: - bsz = key.size(1) - k = self.k_proj(key) - k = k.contiguous().view( - -1, bsz * self.num_heads, self.head_dim - ).transpose(0, 1) - else: - k = None - - if value is not None: - bsz = value.size(1) - v = self.v_proj(value) - v = v.contiguous().view( - 
-1, bsz * self.num_heads, self.head_dim - ).transpose(0, 1) - else: - v = None - - return q, k, v - - def p_choose( - self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, - incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None - ): - """ - Calculating step wise prob for reading and writing - 1 to read, 0 to write - - ============================================================ - Expected input size - query: bsz, tgt_len, embed_dim - key: bsz, src_len, embed_dim - value: bsz, src_len, embed_dim - key_padding_mask: bsz, src_len - attn_mask: bsz, src_len - query: bsz, tgt_len, embed_dim - """ - - # prepare inputs - q_proj, k_proj, _ = self.input_projections( - query, key, None, "monotonic" - ) - - # attention energy - attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) - - return p_choose_strategy.hard_aligned(q_proj, k_proj, attn_energy, self.noise_mean, self.noise_var, self.training) - - def expected_attention(self, alpha, *args): - """ - For MMA-H, beta = alpha - """ - return alpha - - def v_proj_output(self, value): - _, _, v_proj = self.input_projections(None, None, value, "output") - return v_proj - @register_monotonic_attention("infinite_lookback") -class MonotonicMultiheadAttentionInfiniteLookback( - MonotonicMultiheadAttentionHardAligned +class MonotonicInfiniteLookbackAttention( + MonotonicAttention ): def __init__(self, args): super().__init__(args) + self.soft_attention = True self.init_soft_attention() def init_soft_attention(self): @@ -808,80 +445,21 @@ def init_soft_attention(self): nn.init.xavier_uniform_(self.k_in_proj["soft"].weight) nn.init.xavier_uniform_(self.q_in_proj["soft"].weight) - def expected_attention( - self, alpha, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], - key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] - ): - # monotonic attention, we will 
calculate milk here - bsz_x_num_heads, tgt_len, src_len = alpha.size() - bsz = int(bsz_x_num_heads / self.num_heads) - - q, k, _ = self.input_projections(query, key, None, "soft") - soft_energy = self.attn_energy(q, k, key_padding_mask, attn_mask) - - assert list(soft_energy.size()) == \ - [bsz, self.num_heads, tgt_len, src_len] - - soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len) - - if incremental_state is not None: - monotonic_cache = self._get_monotonic_buffer(incremental_state) - head_step = monotonic_cache["head_step"] - assert head_step is not None - monotonic_length = head_step + 1 - step_offset = 0 - if key_padding_mask is not None: - if key_padding_mask[:, 0].any(): - # left_pad_source = True: - step_offset = key_padding_mask.sum(dim=-1, keepdim=True) - monotonic_length += step_offset - mask = lengths_to_mask( - monotonic_length.view(-1), - soft_energy.size(2), 1 - ).unsqueeze(1) - - soft_energy = soft_energy.masked_fill(~mask.to(torch.bool), float("-inf")) - soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] - exp_soft_energy = torch.exp(soft_energy) - exp_soft_energy_sum = exp_soft_energy.sum(dim=2) - beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2) - - else: - soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] - exp_soft_energy = torch.exp(soft_energy) + self.eps - inner_items = alpha / (torch.cumsum(exp_soft_energy, dim=2)) - - beta = ( - exp_soft_energy - * torch.cumsum(inner_items.flip(dims=[2]), dim=2) - .flip(dims=[2]) - ) - - beta = beta.view(bsz, self.num_heads, tgt_len, src_len) - - if key_padding_mask is not None: - beta = beta.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 0) - - beta = beta / beta.sum(dim=3, keepdim=True) - beta = beta.view(bsz * self.num_heads, tgt_len, src_len) - beta = self.dropout_module(beta) - - if torch.isnan(beta).any(): - # Something is wrong - raise RuntimeError("NaN in beta.") - - return beta - 
@register_monotonic_attention("waitk") -class MonotonicMultiheadAttentionWaitK( - MonotonicMultiheadAttentionInfiniteLookback +class WaitKAttention( + MonotonicInfiniteLookbackAttention ): + """ + STACL: Simultaneous Translation with Implicit Anticipation and + Controllable Latency using Prefix-to-Prefix Framework + https://www.aclweb.org/anthology/P19-1289/ + """ def __init__(self, args): super().__init__(args) self.q_in_proj["soft"] = self.q_in_proj["monotonic"] self.k_in_proj["soft"] = self.k_in_proj["monotonic"] + self.waitk_lagging = args.waitk_lagging assert self.waitk_lagging > 0, ( f"Lagging has to been larger than 0, get {self.waitk_lagging}." @@ -890,21 +468,52 @@ def __init__(self, args): @staticmethod def add_args(parser): super( - MonotonicMultiheadAttentionWaitK, - MonotonicMultiheadAttentionWaitK, + MonotonicInfiniteLookbackAttention, + MonotonicInfiniteLookbackAttention ).add_args(parser) parser.add_argument( "--waitk-lagging", type=int, required=True, help="Wait K lagging" ) - def p_choose( - self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, + def p_choose_from_qk( + self, + query: Optional[Tensor], + key: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): - """ - query: bsz, tgt_len - key: bsz, src_len - key_padding_mask: bsz, src_len - """ - return p_choose_strategy.waitk(query, key, self.waitk_lagging, self.num_heads, key_padding_mask, incremental_state) + assert query is not None + assert key is not None + + p_choose = waitk_p_choose( + tgt_len=query.size(0), + src_len=key.size(0), + bsz=query.size(1) * self.num_heads, + waitk_lagging=self.waitk_lagging, + key_padding_mask=key_padding_mask, + incremental_state=incremental_state, + ) + + return p_choose.to(query) + + +@register_monotonic_attention("chunkwise") +class ChunkwiseAttention( + MonotonicInfiniteLookbackAttention +): + def __init__(self, args): + 
super().__init__(args) + self.chunk_size = args.mocha_chunk_size + assert self.chunk_size > 1 + + @staticmethod + def add_args(parser): + super( + MonotonicInfiniteLookbackAttention + ).add_args(parser) + + parser.add_argument( + "--mocha-chunk-size", type=int, + required=True, help="Mocha chunk size" + ) diff --git a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py index bcd45aa8a6..94bd71fb9c 100644 --- a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py +++ b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py @@ -3,14 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer +from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer from . import build_monotonic_attention -from typing import Dict, List, Optional +from typing import Dict, Optional, List -import torch from torch import Tensor +import torch class TransformerMonotonicEncoderLayer(TransformerEncoderLayer): @@ -22,29 +22,16 @@ def forward(self, x, encoder_padding_mask): class TransformerMonotonicDecoderLayer(TransformerDecoderLayer): - def __init__( - self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False - ): - super().__init__( - args, - no_encoder_attn=True, - add_bias_kv=add_bias_kv, - add_zero_attn=add_zero_attn, - ) + def __init__(self, args): + super().__init__(args) assert args.simul_type is not None, "A --simul-type is needed." 
- self.encoder_attn = build_monotonic_attention(args) - self.encoder_attn_layer_norm = LayerNorm( - self.embed_dim, export=getattr(args, "char_inputs", False) - ) - def get_head_steps(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): - return self.encoder_attn._get_monotonic_buffer(incremental_state).get( - "head_step" - ) - - def prune_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): + def prune_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ): input_buffer = self.self_attn._get_input_buffer(incremental_state) for key in ["prev_key", "prev_value"]: input_buffer_key = input_buffer[key] @@ -58,19 +45,16 @@ def prune_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, input_buffer) - def get_steps(self, incremental_state): - return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0) - def forward( self, x, - encoder_out: Optional[torch.Tensor] = None, - encoder_padding_mask: Optional[torch.Tensor] = None, + encoder_out: Optional[Tensor] = None, + encoder_padding_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - prev_self_attn_state: Optional[List[torch.Tensor]] = None, - prev_attn_state: Optional[List[torch.Tensor]] = None, - self_attn_mask: Optional[torch.Tensor] = None, - self_attn_padding_mask: Optional[torch.Tensor] = None, + prev_self_attn_state: Optional[List[Tensor]] = None, + prev_attn_state: Optional[List[Tensor]] = None, + self_attn_mask: Optional[Tensor] = None, + self_attn_padding_mask: Optional[Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): diff --git a/examples/simultaneous_translation/tests/test_text_models.py b/examples/simultaneous_translation/tests/test_text_models.py new file mode 100644 index 
0000000000..127adfa633 --- /dev/null +++ b/examples/simultaneous_translation/tests/test_text_models.py @@ -0,0 +1,407 @@ +import argparse +import unittest +from typing import Any, Dict + +import torch +from examples.simultaneous_translation.models import ( + transformer_monotonic_attention +) + + +from tests.test_roberta import FakeTask + + +DEFAULT_CONFIG = { + "attention_eps": 1e-6, + "mass_preservation": True, + "noise_type": "flat", + "noise_mean": 0.0, + "noise_var": 1.0, + "energy_bias_init": -2, + "energy_bias": True +} + + +PAD_INDEX = 1 + + +def generate_config(overrides_kv): + new_dict = {key: value for key, value in DEFAULT_CONFIG.items()} + for key, value in overrides_kv.items(): + new_dict[key] = value + return new_dict + + +def make_sample_with_padding(longer_src=False) -> Dict[str, Any]: + tokens_1 = torch.LongTensor( + [ + [2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2], + [ + 2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2, + PAD_INDEX, PAD_INDEX + ], + ] + ) + tokens_2 = torch.LongTensor( + [ + [2, 11, 12, 13, 14, 2, PAD_INDEX, PAD_INDEX], + [2, 11, 22, 33, 2, PAD_INDEX, PAD_INDEX, PAD_INDEX] + ] + ) + if longer_src: + src_tokens = tokens_1[:, 1:] + prev_output_tokens = tokens_2 + else: + src_tokens = tokens_2[:, 1:8] + prev_output_tokens = tokens_1 + + src_lengths = src_tokens.ne(PAD_INDEX).sum(dim=1).long() + + sample = { + "net_input": { + "src_tokens": src_tokens, + "prev_output_tokens": prev_output_tokens, + "src_lengths": src_lengths, + }, + "target": prev_output_tokens[:, 1:], + } + return sample + + +def build_transformer_monotonic_attention(**extra_args: Any): + overrides = { + # Use characteristics dimensions + "encoder_embed_dim": 12, + "encoder_ffn_embed_dim": 14, + "decoder_embed_dim": 12, + "decoder_ffn_embed_dim": 14, + # Disable dropout so we have comparable tests. 
+ "dropout": 0, + "attention_dropout": 0, + "activation_dropout": 0, + "encoder_layerdrop": 0, + } + overrides.update(extra_args) + # Overrides the defaults from the parser + args = argparse.Namespace(**overrides) + transformer_monotonic_attention.monotonic_tiny_architecture(args) + + torch.manual_seed(0) + task = FakeTask(args) + return ( + transformer_monotonic_attention + .TransformerModelSimulTrans + .build_model(args, task) + ) + + +def expected_alignment_formula( + p_choose, + mass_perservation=True, + padding_mask=None +): + # Online and Linear-Time Attention by Enforcing Monotonic Alignments + # https://arxiv.org/pdf/1704.00784.pdf + # Eq 18, 19 + bsz, tgt_len, src_len = p_choose.size() + alpha = torch.zeros_like(p_choose) + + if padding_mask is not None: + bsz_pad = padding_mask.size(0) + num_heads = int(bsz / bsz_pad) + padding_mask = ( + padding_mask + .unsqueeze(1) + .expand([bsz_pad, num_heads, src_len]) + .contiguous() + .view(-1, src_len) + ) + + p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0) + + for bsz_i in range(bsz): + for i in range(tgt_len): + for j in range(src_len): + if i == 0: + if j == 0: + # First source token + alpha[bsz_i, i, j] = p_choose[bsz_i, i, j] + else: + # First target token + alpha[bsz_i, i, j] = ( + p_choose[bsz_i, i, j] + * torch.prod( + 1 - p_choose[bsz_i, i, :j] + ) + ) + else: + alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j] + for k in range(j): + alpha[bsz_i, i, j] += ( + alpha[bsz_i, i - 1, k] + * torch.prod( + 1 - p_choose[bsz_i, i, k:j] + ) + ) + alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j] + + alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0) + + if mass_perservation: + alpha = mass_perservation_formula(alpha, False, padding_mask) + + return alpha + + +def mass_perservation_formula(alpha, left_padding=False, padding_mask=None): + if padding_mask is None or alpha.size(-1) == 1: + if alpha.size(-1) > 1: + alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1) + return alpha + + src_lens = 
(padding_mask.logical_not()).sum(dim=1).long() + + bsz, tgt_len, src_len = alpha.size() + + assert ( + not left_padding + or (left_padding and (not padding_mask[:, 0].any())) + ) + + alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0) + + for bsz_i in range(bsz): + if left_padding: + alpha[bsz_i, :, -1] = ( + 1 - alpha[bsz_i, :, :-1].sum(dim=-1) + ) + else: + alpha[bsz_i, :, src_lens[bsz_i] - 1] = ( + 1 - alpha[bsz_i, :, :src_lens[bsz_i] - 1].sum(dim=-1) + ) + + return alpha + + +def expected_soft_attention_formula( + alpha, + soft_energy, + padding_mask=None, + chunksize=1e10, +): + # Monotonic Infinite Lookback Attention for Simultaneous Machine Translation + # https://arxiv.org/pdf/1906.05218.pdf + # Eq 14 + + # Monotonic Chunkwise Attention + # https://arxiv.org/abs/1712.05382 + # Eq 17 + bsz, tgt_len, src_len = alpha.size() + beta = torch.zeros_like(alpha) + + if padding_mask is not None: + bsz_pad = padding_mask.size(0) + num_heads = int(bsz / bsz_pad) + # Expanding for potential head dimension + padding_mask = ( + padding_mask + .unsqueeze(1) + .expand([bsz_pad, num_heads, src_len]) + .contiguous() + .view(-1, src_len) + ) + soft_energy = soft_energy.masked_fill(padding_mask.unsqueeze(1), float('-inf')) + + for bsz_i in range(bsz): + for i in range(tgt_len): + for j in range(src_len): + for k in range(j, min([src_len, j + chunksize])): + if not padding_mask[bsz_i, j]: + beta[bsz_i, i, j] += ( + alpha[bsz_i, i, k] * torch.exp(soft_energy[bsz_i, i, j]) + / torch.sum(torch.exp(soft_energy[bsz_i, i, max([0, k - chunksize + 1]):k + 1])) + ) + return beta + + +class MonotonicAttentionTestAbstractClass(object): + def test_forward(self): + sample = make_sample_with_padding() + out, _ = self.model.forward(**sample["net_input"]) + loss = out.sum() + loss.backward() + + def test_p_choose(self): + sample = make_sample_with_padding() + _, extra_out = self.model.forward(**sample["net_input"]) + for item in extra_out.attn_list: + p_choose = item["p_choose"] + 
self.assertTrue(p_choose.le(1.0).all()) + self.assertTrue(p_choose.ge(0.0).all()) + + def test_expected_alignment(self): + for longer_src in [True, False]: + sample = make_sample_with_padding(longer_src) + _, extra_out = self.model.forward(**sample["net_input"]) + for item in extra_out.attn_list: + p_choose = item["p_choose"] + alpha_system = item["alpha"] + self.assertTrue(p_choose.size() == alpha_system.size()) + bsz, num_head, tgt_len, src_len = alpha_system.size() + alpha_system = alpha_system.view(-1, tgt_len, src_len) + p_choose = p_choose.view(-1, tgt_len, src_len) + + alpha_real = expected_alignment_formula( + p_choose, + self.model.decoder.layers[0].encoder_attn.mass_preservation, + sample["net_input"]["src_tokens"].eq(PAD_INDEX) + ) + + self.assertTrue( + torch.abs(alpha_system - alpha_real).le(5e-5).all(), + ) + + +class HardMonotonicAttentionTestCase( + unittest.TestCase, + MonotonicAttentionTestAbstractClass +): + def setUp(self): + self.model = build_transformer_monotonic_attention( + **generate_config({"simul_type": "hard_aligned"}) + ) + + +class InfiniteLookbackTestCase( + unittest.TestCase, + MonotonicAttentionTestAbstractClass +): + def setUp(self): + self.model = build_transformer_monotonic_attention( + **generate_config( + { + "simul_type": "infinite_lookback" + } + ) + ) + self.model.train() + + def test_fp16_for_long_input(self): + sample = { + "net_input": { + "src_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0), + "prev_output_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0), + "src_lengths": torch.LongTensor([1000]).cuda(), + }, + "target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda() + } + self.model.cuda().half() + _, extra_out = self.model.forward(**sample["net_input"]) + for item in extra_out.attn_list: + for key in ["p_choose", "alpha", "beta", "soft_energy"]: + self.assertFalse(torch.isnan(item[key]).any()) + + def test_expected_attention(self): + for longer_src in [True, False]: + sample = 
make_sample_with_padding(longer_src) + _, extra_out = self.model.forward(**sample["net_input"]) + for item in extra_out.attn_list: + p_choose = item["p_choose"] + alpha_system = item["alpha"] + beta_system = item["beta"] + soft_energy_system = item["soft_energy"] + self.assertTrue(beta_system.size() == alpha_system.size()) + self.assertTrue(p_choose.size() == alpha_system.size()) + + bsz, num_head, tgt_len, src_len = alpha_system.size() + + alpha_system = alpha_system.view(-1, tgt_len, src_len) + beta_system = beta_system.view(-1, tgt_len, src_len) + p_choose = p_choose.view(-1, tgt_len, src_len) + soft_energy_system = soft_energy_system.view(-1, tgt_len, src_len) + + alpha_real = expected_alignment_formula( + p_choose, + self.model.decoder.layers[0].encoder_attn.mass_preservation, + sample["net_input"]["src_tokens"].eq(PAD_INDEX) + ) + + beta_real = expected_soft_attention_formula( + alpha_real, + soft_energy_system, + sample["net_input"]["src_tokens"].eq(PAD_INDEX), + chunksize=getattr( + self.model.decoder.layers[0].encoder_attn, + "chunk_size", + int(1e10) + ) + ) + + self.assertTrue( + torch.abs(beta_system - beta_real).le(1e-5).all(), + ) + + +class ChunkwiswTestCase( + InfiniteLookbackTestCase +): + def setUp(self): + self.model = build_transformer_monotonic_attention( + **generate_config( + { + "simul_type": "chunkwise", + "mocha_chunk_size": 3 + } + ) + ) + + +class WaitkTestCase(InfiniteLookbackTestCase): + def setUp(self): + self.model = build_transformer_monotonic_attention( + **generate_config( + { + "simul_type": "waitk", + "waitk_lagging": 3, + } + ) + ) + + def check_waitk(self, p_choose, lagging, padding_mask): + bsz, tgt_len, src_len = p_choose.size() + for bsz_i in range(bsz): + for i in range(tgt_len): + for j in range(src_len): + if not padding_mask[bsz_i, j]: + if j - i == lagging - 1: + self.assertTrue(p_choose[bsz_i, i, j] == 1) + else: + self.assertTrue(p_choose[bsz_i, i, j] == 0) + + def test_waitk_p_choose(self): + for longer_src in 
[True, False]: + for k in [1, 3, 10, 20, 100]: + sample = make_sample_with_padding(longer_src) + model = build_transformer_monotonic_attention( + **generate_config( + { + "simul_type": "waitk", + "waitk_lagging": k, + } + ) + ) + model.train() + _, extra_out = model.forward(**sample["net_input"]) + for item in extra_out.attn_list: + p_choose = item["p_choose"] + bsz, num_heads, tgt_len, src_len = p_choose.size() + padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX) + padding_mask = ( + padding_mask + .unsqueeze(1) + .expand([bsz, num_heads, src_len]) + .contiguous() + .view(-1, src_len) + ) + p_choose = p_choose.view(bsz * num_heads, tgt_len, src_len) + self.check_waitk(p_choose, k, padding_mask) diff --git a/examples/simultaneous_translation/utils/functions.py b/examples/simultaneous_translation/utils/functions.py index f795b5f31c..0ced35a9d5 100644 --- a/examples/simultaneous_translation/utils/functions.py +++ b/examples/simultaneous_translation/utils/functions.py @@ -6,12 +6,23 @@ import torch +def prob_check(tensor): + assert not torch.isnan(tensor).any(), ( + "Nan in a probability tensor." + ) + assert tensor.le(1.0).all() and tensor.ge(0.0).all(), ( + "Incorrect values in a probability tensor" + ", 0.0 <= tensor <= 1.0" + ) + + def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10): """ Implementing exclusive cumprod. There is cumprod in pytorch, however there is no exclusive mode. 
cumprod(x) = [x1, x1x2, x2x3x4, ..., prod_{i=1}^n x_i] - exclusive means cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i] + exclusive means + cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i] """ tensor_size = list(tensor.size()) tensor_size[dim] = 1 @@ -28,7 +39,9 @@ def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10): elif dim == 2: return return_tensor[:, :, :-1] else: - raise RuntimeError("Cumprod on dimension 3 and more is not implemented") + raise RuntimeError( + "Cumprod on dimension 3 and more is not implemented" + ) def safe_cumprod(tensor, dim: int, eps: float = 1e-10): @@ -52,42 +65,6 @@ def safe_cumprod(tensor, dim: int, eps: float = 1e-10): return exp_cumsum_log_tensor -def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False): - """ - Convert a tensor of lengths to mask - For example, lengths = [[2, 3, 4]], max_len = 5 - mask = - [[1, 1, 1], - [1, 1, 1], - [0, 1, 1], - [0, 0, 1], - [0, 0, 0]] - """ - assert len(lengths.size()) <= 2 - if len(lengths) == 2: - if dim == 1: - lengths = lengths.t() - lengths = lengths - else: - lengths = lengths.unsqueeze(1) - - # lengths : batch_size, 1 - lengths = lengths.view(-1, 1) - - batch_size = lengths.size(0) - # batch_size, max_len - mask = torch.arange(max_len).expand(batch_size, max_len).type_as(lengths) < lengths - - if negative_mask: - mask = ~mask - - if dim == 0: - # max_len, batch_size - mask = mask.t() - - return mask - - def moving_sum(x, start_idx: int, end_idx: int): """ From MONOTONIC CHUNKWISE ATTENTION @@ -126,24 +103,22 @@ def moving_sum(x, start_idx: int, end_idx: int): [ 7, 17, 27], [ 4, 9, 14]] """ + # TODO: Make dimension configurable assert start_idx > 0 and end_idx > 0 - assert len(x.size()) == 2 - src_len, batch_size = x.size() + batch_size, tgt_len, src_len = x.size() + x = x.view(-1, src_len).unsqueeze(1) # batch_size, 1, src_len - x = x.t().unsqueeze(1) - # batch_size, 1, src_len - moving_sum_weight = x.new_ones([1, 1, end_idx + 
start_idx - 1]) + moving_sum_weight = torch.ones([1, 1, end_idx + start_idx - 1]).type_as(x) - moving_sum = ( - torch.nn.functional.conv1d( - x, moving_sum_weight, padding=start_idx + end_idx - 1 - ) - .squeeze(1) - .t() - ) - moving_sum = moving_sum[end_idx:-start_idx] + moving_sum = torch.nn.functional.conv1d( + x, moving_sum_weight, padding=start_idx + end_idx - 1 + ).squeeze(1) + + moving_sum = moving_sum[:, end_idx:-start_idx] + + assert src_len == moving_sum.size(1) + assert batch_size * tgt_len == moving_sum.size(0) - assert src_len == moving_sum.size(0) - assert batch_size == moving_sum.size(1) + moving_sum = moving_sum.view(batch_size, tgt_len, src_len) return moving_sum diff --git a/examples/simultaneous_translation/utils/monotonic_attention.py b/examples/simultaneous_translation/utils/monotonic_attention.py new file mode 100644 index 0000000000..fd45137735 --- /dev/null +++ b/examples/simultaneous_translation/utils/monotonic_attention.py @@ -0,0 +1,196 @@ +from typing import Optional +import torch +from torch import Tensor + +from examples.simultaneous_translation.utils.functions import ( + exclusive_cumprod, + prob_check, + moving_sum, +) + + +def expected_alignment_from_p_choose( + p_choose: Tensor, + padding_mask: Optional[Tensor] = None, + eps: float = 1e-6 +): + """ + Calculating expected alignment from stepwise probability + + Reference: + Online and Linear-Time Attention by Enforcing Monotonic Alignments + https://arxiv.org/pdf/1704.00784.pdf + + q_ij = (1 − p_{ij−1})q_{ij−1} + a_{i−1j} + a_ij = p_ij q_ij + + Parallel solution: + a_i = p_i * cumprod(1 − p_i) * cumsum(a_{i−1} / cumprod(1 − p_i)) + + ============================================================ + Expected input size + p_choose: bsz, tgt_len, src_len + """ + prob_check(p_choose) + + # p_choose: bsz, tgt_len, src_len + bsz, tgt_len, src_len = p_choose.size() + dtype = p_choose.dtype + + p_choose = p_choose.float() + + if padding_mask is not None: + p_choose = 
p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0) + + # cumprod_1mp : bsz, tgt_len, src_len + cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps) + cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0) + + alpha_0 = p_choose.new_zeros([bsz, 1, src_len]) + alpha_0[:, :, 0] = 1.0 + + previous_alpha = [alpha_0] + + for i in range(tgt_len): + # p_choose: bsz , tgt_len, src_len + # cumprod_1mp_clamp : bsz, tgt_len, src_len + # previous_alpha[i]: bsz, 1, src_len + # alpha_i: bsz, src_len + alpha_i = ( + p_choose[:, i] + * cumprod_1mp[:, i] + * torch.cumsum( + previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1 + ) + ).clamp(0, 1.0) + + previous_alpha.append(alpha_i.unsqueeze(1)) + + # alpha: bsz * num_heads, tgt_len, src_len + alpha = torch.cat(previous_alpha[1:], dim=1) + + # Mix precision to prevent overflow for fp16 + alpha = alpha.type(dtype) + + prob_check(alpha) + + return alpha + + +def expected_soft_attention( + alpha: Tensor, + soft_energy: Tensor, + padding_mask: Optional[Tensor] = None, + chunk_size: Optional[int] = None, + eps: float = 1e-10 +): + """ + Function to compute expected soft attention for + monotonic infinite lookback attention from + expected alignment and soft energy. 
+ + Reference: + Monotonic Chunkwise Attention + https://arxiv.org/abs/1712.05382 + + Monotonic Infinite Lookback Attention for Simultaneous Machine Translation + https://arxiv.org/abs/1906.05218 + + alpha: bsz, tgt_len, src_len + soft_energy: bsz, tgt_len, src_len + padding_mask: bsz, src_len + left_padding: bool + """ + if padding_mask is not None: + alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0) + soft_energy = soft_energy.masked_fill( + padding_mask.unsqueeze(1), -float("inf") + ) + + prob_check(alpha) + + dtype = alpha.dtype + + alpha = alpha.float() + soft_energy = soft_energy.float() + + soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] + exp_soft_energy = torch.exp(soft_energy) + eps + + if chunk_size is not None: + # Chunkwise + beta = ( + exp_soft_energy + * moving_sum( + alpha / (eps + moving_sum(exp_soft_energy, chunk_size, 1)), + 1, chunk_size + ) + ) + else: + # Infinite lookback + # Notice that infinite lookback is a special case of chunkwise + # where chunksize = inf + inner_items = alpha / (eps + torch.cumsum(exp_soft_energy, dim=2)) + + beta = ( + exp_soft_energy + * torch.cumsum(inner_items.flip(dims=[2]), dim=2) + .flip(dims=[2]) + ) + + if padding_mask is not None: + beta = beta.masked_fill( + padding_mask.unsqueeze(1).to(torch.bool), 0.0) + + # Mix precision to prevent overflow for fp16 + beta = beta.type(dtype) + + prob_check(beta) + + return beta + + +def mass_preservation( + alpha: Tensor, + padding_mask: Optional[Tensor] = None, + left_padding: bool = False +): + """ + Function to compute the mass preservation for alpha. + This means that the residual weights of alpha will be assigned + to the last token. 
+ + Reference: + Monotonic Infinite Lookback Attention for Simultaneous Machine Translation + https://arxiv.org/abs/1906.05218 + + alpha: bsz, tgt_len, src_len + padding_mask: bsz, src_len + left_padding: bool + """ + + prob_check(alpha) + + if padding_mask is not None: + if not left_padding: + assert not padding_mask[:, 0].any(), ( + "Found padding at the beginning of the sequence." + ) + alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0) + + if left_padding or padding_mask is None: + residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0, 1) + alpha[:, :, -1] = residuals + else: + # right padding + _, tgt_len, src_len = alpha.size() + residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0, 1) + src_lens = src_len - padding_mask.sum(dim=1, keepdim=True) + src_lens = src_lens.expand(-1, tgt_len).contiguous() + # add back the last value + residuals += alpha.gather(2, src_lens.unsqueeze(2) - 1) + alpha = alpha.scatter(2, src_lens.unsqueeze(2) - 1, residuals) + + prob_check(alpha) + + return alpha diff --git a/examples/simultaneous_translation/utils/p_choose_strategy.py b/examples/simultaneous_translation/utils/p_choose_strategy.py index 308227ed96..724c6912a6 100644 --- a/examples/simultaneous_translation/utils/p_choose_strategy.py +++ b/examples/simultaneous_translation/utils/p_choose_strategy.py @@ -3,30 +3,34 @@ import torch -def waitk( - query, key, waitk_lagging: int, num_heads: int, key_padding_mask: Optional[Tensor] = None, +def waitk_p_choose( + tgt_len: int, + src_len: int, + bsz: int, + waitk_lagging: int, + key_padding_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None ): + + max_src_len = src_len if incremental_state is not None: # Retrieve target length from incremental states # For inference the length of query is always 1 - tgt_len = incremental_state["steps"]["tgt"] - assert tgt_len is not None - tgt_len = int(tgt_len) + max_tgt_len = incremental_state["steps"]["tgt"] + assert max_tgt_len is 
not None + max_tgt_len = int(max_tgt_len) else: - tgt_len, bsz, _ = query.size() - - max_src_len, bsz, _ = key.size() + max_tgt_len = tgt_len if max_src_len < waitk_lagging: if incremental_state is not None: - tgt_len = 1 - return query.new_zeros( - bsz * num_heads, tgt_len, max_src_len + max_tgt_len = 1 + return torch.zeros( + bsz, max_tgt_len, max_src_len ) # Assuming the p_choose looks like this for wait k=3 - # src_len = 6, tgt_len = 5 + # src_len = 6, max_tgt_len = 5 # [0, 0, 1, 0, 0, 0, 0] # [0, 0, 0, 1, 0, 0, 0] # [0, 0, 0, 0, 1, 0, 0] @@ -39,21 +43,20 @@ def waitk( # 3 + 6 * 1 # ... # n + src_len * n + k - 1 = n * (src_len + 1) + k - 1 - # n from 0 to tgt_len - 1 + # n from 0 to max_tgt_len - 1 # - # First, generate the indices (activate_indices_offset: bsz, tgt_len) - # Second, scatter a zeros tensor (bsz, tgt_len * src_len) + # First, generate the indices (activate_indices_offset: bsz, max_tgt_len) + # Second, scatter a zeros tensor (bsz, max_tgt_len * src_len) # with activate_indices_offset - # Third, resize the tensor to (bsz, tgt_len, src_len) + # Third, resize the tensor to (bsz, max_tgt_len, src_len) activate_indices_offset = ( ( - torch.arange(tgt_len) * (max_src_len + 1) + torch.arange(max_tgt_len) * (max_src_len + 1) + waitk_lagging - 1 ) .unsqueeze(0) - .expand(bsz, tgt_len) - .to(query) + .expand(bsz, max_tgt_len) .long() ) @@ -71,54 +74,53 @@ def waitk( 0, min( [ - tgt_len, + max_tgt_len, max_src_len - waitk_lagging + 1 ] ) * max_src_len - 1 ) ) - p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query) + p_choose = torch.zeros(bsz, max_tgt_len * max_src_len) p_choose = p_choose.scatter( 1, activate_indices_offset, 1.0 - ).view(bsz, tgt_len, max_src_len) + ).view(bsz, max_tgt_len, max_src_len) + + if key_padding_mask is not None: + p_choose = p_choose.to(key_padding_mask) + p_choose = p_choose.masked_fill(key_padding_mask.unsqueeze(1), 0) if incremental_state is not None: p_choose = p_choose[:, -1:] - tgt_len = 1 - - # Extend to each head - 
p_choose = ( - p_choose.contiguous() - .unsqueeze(1) - .expand(-1, num_heads, -1, -1) - .contiguous() - .view(-1, tgt_len, max_src_len) - ) - return p_choose + return p_choose.float() -def hard_aligned(q_proj: Optional[Tensor], k_proj: Optional[Tensor], attn_energy, noise_mean: float = 0.0, noise_var: float = 0.0, training: bool = True): +def learnable_p_choose( + energy, + noise_mean: float = 0.0, + noise_var: float = 0.0, + training: bool = True +): """ Calculating step wise prob for reading and writing 1 to read, 0 to write + energy: bsz, tgt_len, src_len """ noise = 0 if training: # add noise here to encourage discretness noise = ( - torch.normal(noise_mean, noise_var, attn_energy.size()) - .type_as(attn_energy) - .to(attn_energy.device) + torch.normal(noise_mean, noise_var, energy.size()) + .type_as(energy) + .to(energy.device) ) - p_choose = torch.sigmoid(attn_energy + noise) - _, _, tgt_len, src_len = p_choose.size() + p_choose = torch.sigmoid(energy + noise) # p_choose: bsz * self.num_heads, tgt_len, src_len - return p_choose.view(-1, tgt_len, src_len) + return p_choose From 95a9cb798dc8d6375428c6a9502234de8b4c3ec8 Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Wed, 21 Jul 2021 13:29:09 -0700 Subject: [PATCH 656/707] add text compression to FileAudioDataset + speed up AddTargetDataset Summary: # Add text compression to FileAudioDataset The FileAudioDataset stores the full manifest (mainly raw texts --- audio filenames + target texts for finetuning) in memory. This leads to large memory usage as data is duplicated in multiple workers. And it limits the scale of training data that we can use before having to switch to `BinarizedAudioDataset`. This diff aims at alleviating this limitation by in-memory text compression. This technique can be applied to any other dataset classes that store raw texts (e.g. `speech_to_text_dataset`). 
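
The idea described above can be illustrated with a minimal sketch, assuming each manifest entry is compressed independently with the standard-library `zlib` so that single entries remain randomly accessible. The helper names `compress_manifest` and `get_entry` and the example paths are hypothetical and are not part of fairseq; the actual gain depends on the data and the compression level.

```python
import zlib


def compress_manifest(lines, level=6):
    # Hypothetical helper: compress every entry on its own so that
    # one entry can be decompressed without touching the others.
    return [zlib.compress(line.encode("utf-8"), level) for line in lines]


def get_entry(compressed, index):
    # Hypothetical helper: decompress a single entry on access.
    return zlib.decompress(compressed[index]).decode("utf-8")


# Made-up manifest entries, for illustration only.
manifest = [
    f"/data/audio/speaker_{i % 7:03d}/utterance_{i:06d}.flac"
    for i in range(1000)
]
packed = compress_manifest(manifest)

raw_bytes = sum(len(s.encode("utf-8")) for s in manifest)
packed_bytes = sum(len(b) for b in packed)
print(f"raw: {raw_bytes} bytes, compressed: {packed_bytes} bytes")
print(get_entry(packed, 42))
```

Compressing per entry trades some compression ratio for cheap random access, which is what a dataset's `__getitem__` needs; very short strings may not shrink at all under a general-purpose compressor.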
### Implementation - `zlib` for low-level compression: built-in, relatively fast (as shown in the benchmarking below) - `unishox2` for high-level compression: optimized for short texts but relatively slower ### Benchmarking Tested with single process/thread. **On 3.6G TSV manifest (ASCII filenames) data** | Compression Level | CPU Mem Usage | Loading + Encoding | Decoding | |---|---|---|---| | No | 6782.18MB | 00:32 | - | | Low | 6450.04MB (95.10%) | 04:39 | 01:03 | | High | 4742.91MB (69.93%) | 08:49 | 01:31| **On 7.8G label (Arabic text) data** | Compression Level | CPU Mem Usage | Loading + Encoding | Decoding | |---|---|---|---| | No | 14623.57MB | 00:58 | - | | Low | 10773.65MB (73.67%) | 05:38 | 01:45 | | High | 7352.67MB (50.28%) | 25:03 | 04:31 | The difference on CPU memory usage will be enlarged when data is duplicated in multiple dataloading workers. # Speed up AddTargetDataset AddTargetDataset gets label length from tensorized data which is very slow --- leading to 6+hrs for batching ~8G text. Replaced it with a helper function that gets length from untokenized string data and reduced the time from 6+hrs to 10min. Reviewed By: cndn Differential Revision: D29093876 fbshipit-source-id: b6b9a8da61944771bb7c8cb57b207ed4d79d0764 --- fairseq/data/add_target_dataset.py | 21 ++++++---- fairseq/data/audio/raw_audio_dataset.py | 12 ++++-- fairseq/data/text_compressor.py | 56 +++++++++++++++++++++++++ fairseq/tasks/audio_pretraining.py | 26 +++++++++++- 4 files changed, 101 insertions(+), 14 deletions(-) create mode 100644 fairseq/data/text_compressor.py diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py index 9ef467058b..673963d0ed 100644 --- a/fairseq/data/add_target_dataset.py +++ b/fairseq/data/add_target_dataset.py @@ -6,6 +6,7 @@ import torch from . 
import BaseWrapperDataset, data_utils +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel class AddTargetDataset(BaseWrapperDataset): @@ -17,7 +18,9 @@ def __init__( eos, batch_targets, process_label=None, + label_len_fn=None, add_to_input=False, + text_compression_level=TextCompressionLevel.none ): super().__init__(dataset) self.labels = labels @@ -25,24 +28,24 @@ def __init__( self.pad = pad self.eos = eos self.process_label = process_label + self.label_len_fn = label_len_fn self.add_to_input = add_to_input + self.text_compressor = TextCompressor(level=text_compression_level) - def get_label(self, index): - return ( - self.labels[index] - if self.process_label is None - else self.process_label(self.labels[index]) - ) + def get_label(self, index, process_fn=None): + lbl = self.labels[index] + lbl = self.text_compressor.decompress(lbl) + return lbl if process_fn is None else process_fn(lbl) def __getitem__(self, index): item = self.dataset[index] - item["label"] = self.get_label(index) + item["label"] = self.get_label(index, process_fn=self.process_label) return item def size(self, index): sz = self.dataset.size(index) - own_sz = len(self.get_label(index)) - return (sz, own_sz) + own_sz = self.label_len_fn(self.get_label(index)) + return sz, own_sz def collater(self, samples): collated = self.dataset.collater(samples) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 9ce3f7e39d..f4e965493c 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -20,6 +20,7 @@ read_from_stored_zip, is_sf_audio_data, ) +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel logger = logging.getLogger(__name__) @@ -256,6 +257,7 @@ def __init__( normalize=False, num_buckets=0, compute_mask_indices=False, + text_compression_level=TextCompressionLevel.none, **mask_compute_kwargs, ): super().__init__( @@ -269,6 +271,8 @@ def __init__( 
**mask_compute_kwargs, ) + self.text_compressor = TextCompressor(level=text_compression_level) + skipped = 0 self.fnames = [] sizes = [] @@ -284,7 +288,7 @@ def __init__( skipped += 1 self.skipped_indices.add(i) continue - self.fnames.append(items[0]) + self.fnames.append(self.text_compressor.compress(items[0])) sizes.append(sz) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") @@ -304,8 +308,10 @@ def __init__( def __getitem__(self, index): import soundfile as sf - - path_or_fp = os.path.join(self.root_dir, str(self.fnames[index])) + fn = self.fnames[index] + fn = fn if isinstance(self.fnames, list) else fn.as_py() + fn = self.text_compressor.decompress(fn) + path_or_fp = os.path.join(self.root_dir, fn) _path, slice_ptr = parse_path(path_or_fp) if len(slice_ptr) == 2: byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) diff --git a/fairseq/data/text_compressor.py b/fairseq/data/text_compressor.py new file mode 100644 index 0000000000..561e9ac89a --- /dev/null +++ b/fairseq/data/text_compressor.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from enum import Enum + + +class TextCompressionLevel(Enum): + none = 0 + low = 1 + high = 2 + + +class TextCompressor(object): + def __init__( + self, level: TextCompressionLevel, + max_input_byte_length: int = 2 ** 16 + ): + self.level = level + self.max_input_length = max_input_byte_length + + def compress(self, text: str) -> bytes: + if self.level == TextCompressionLevel.low: + import zlib + # zlib: built-in, fast + return zlib.compress(text.encode(), level=0) + elif self.level == TextCompressionLevel.high: + try: + import unishox2 + # unishox2: optimized for short text but slower + except ImportError: + raise ImportError( + "Please install unishox2 for the text compression feature: " + "pip install unishox2-py3" + ) + assert len(text.encode()) <= self.max_input_length + return unishox2.compress(text)[0] + else: + return text.encode() + + def decompress(self, compressed: bytes) -> str: + if self.level == TextCompressionLevel.low: + import zlib + return zlib.decompress(compressed).decode() + elif self.level == TextCompressionLevel.high: + try: + import unishox2 + except ImportError: + raise ImportError( + "Please install unishox2 for the text compression feature: " + "pip install unishox2-py3" + ) + return unishox2.decompress(compressed, self.max_input_length) + else: + return compressed.decode() diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index c642ff5226..059e2d70c8 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -22,8 +22,9 @@ FileAudioDataset, encoders, ) -from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass import FairseqDataclass, ChoiceEnum from fairseq.dataclass.configs import GenerationConfig +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel from . import FairseqTask, register_task from .. 
import utils @@ -43,6 +44,10 @@ def __call__(self, label): ) +def label_len_fn(label): + return len(label.split(" ")) + + @dataclass class InferredW2vConfig: # The following are needed to precompute mask and mask channel indices @@ -143,6 +148,13 @@ class AudioPretrainingConfig(FairseqDataclass): ) tpu: bool = II("common.tpu") + text_compression_level: ChoiceEnum([x.name for x in TextCompressionLevel]) = field( + default="none", + metadata={ + "help": "compression level for texts (e.g. audio filenames, " + "target texts): none/low/high (default: none). " + } + ) @register_task("audio_pretraining", dataclass=AudioPretrainingConfig) @@ -198,6 +210,9 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if not hasattr(task_cfg, "autoregressive"): task_cfg.autoregressive = not task_cfg.criterion == "ctc" + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) if getattr(task_cfg, "binarized_dataset", False): self.datasets[split] = BinarizedAudioDataset( data_path, @@ -223,6 +238,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): normalize=task_cfg.normalize, num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu), + text_compression_level=text_compression_level, **self._get_mask_precompute_kwargs(task_cfg), ) @@ -236,8 +252,12 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) with open(label_path, "r") as f: - labels = [line for i, line in enumerate(f) if i not in skipped_indices] + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) if i not in skipped_indices + ] assert len(labels) == len(self.datasets[split]), ( 
f"labels length ({len(labels)}) and dataset length " @@ -253,7 +273,9 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, + label_len_fn=label_len_fn, add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level ) @property From 698961dc0d2eb17d15e9665f6f31fd4c1c4b58c7 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Wed, 21 Jul 2021 15:05:35 -0700 Subject: [PATCH 657/707] Make FSDP and --update-freq play nice (#3727) Summary: Previously combining FSDP with `--update-freq` would result in significant memory usage because full-size gradients would be accumulated on each GPU. We can instead skip the `no_sync` context manager in this case. The tradeoff is more communication (we do reduce-scatter on each backward), but the memory savings are likely to be worth it in most cases. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3727 Reviewed By: sshleifer Differential Revision: D29824021 Pulled By: myleott fbshipit-source-id: 2d942586cc1a9ac33fd34b8709df91dda870dd49 --- fairseq/trainer.py | 62 ++++++++++++++++------------------------------ 1 file changed, 22 insertions(+), 40 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 1602688671..d53e650b0a 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -61,7 +61,7 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): else: self.device = torch.device("cpu") - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.is_fsdp: if self.cfg.common.bf16: raise ValueError( "FullyShardedDataParallel is not compatible with --bf16 or " @@ -72,12 +72,6 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): "FullyShardedDataParallel is not compatible with --zero-sharding " "option (it's already built in)" ) - if self.cfg.optimization.update_freq[0] > 1: - 
logger.warning( - "Combining --update-freq with FullyShardedDataParallel will " - "result in increased memory usage, since full-sized gradients " - "will be accumulated on each GPU!" - ) else: if ( hasattr(self.cfg.distributed_training, "cpu_offload") @@ -88,7 +82,7 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): # copy model and criterion to current device/dtype self._criterion = criterion self._model = model - if cfg.distributed_training.ddp_backend != "fully_sharded": + if not self.is_fsdp: if cfg.common.fp16: assert not cfg.common.amp, "Cannot use fp16 and AMP together" self._criterion = self._criterion.half() @@ -197,16 +191,14 @@ def use_distributed_wrapper(self) -> bool: return ( self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf ) or ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and self.cfg.distributed_training.cpu_offload + self.is_fsdp and self.cfg.distributed_training.cpu_offload ) @property def should_save_checkpoint_on_current_rank(self) -> bool: """Indicates whether to save checkpoints on the current DDP rank.""" if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and self.cfg.distributed_training.use_sharded_state + self.is_fsdp and self.cfg.distributed_training.use_sharded_state ) or getattr(self.cfg.model, "base_layers", 0) > 0: return True else: @@ -214,10 +206,7 @@ def should_save_checkpoint_on_current_rank(self) -> bool: @property def always_call_state_dict_during_save_checkpoint(self) -> bool: - if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and not self.cfg.distributed_training.use_sharded_state - ): + if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state: # FSDP calls communication collective when consolidating checkpoints return True else: @@ -226,10 +215,7 @@ def always_call_state_dict_during_save_checkpoint(self) -> bool: @property def checkpoint_suffix(self) -> str: """Suffix to add to the checkpoint 
file name.""" - if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and self.cfg.distributed_training.use_sharded_state - ): + if self.is_fsdp and self.cfg.distributed_training.use_sharded_state: return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format( self.data_parallel_rank ) @@ -284,10 +270,7 @@ def _build_optimizer(self): ) ) - if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and self.cfg.common.fp16 - ): + if self.is_fsdp and self.cfg.common.fp16: # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, # mostly for the grad scaling. But if we don't have the # --memory-efficient-fp16 flag set, then we're effectively doing @@ -319,7 +302,7 @@ def _build_optimizer(self): logger.info("NOTE: your device may support faster training with --fp16 or --amp") self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.is_fsdp: assert ( not self.cfg.optimization.use_bmuf ), "--ddp-backend=fully_sharded is not compatible with BMUF" @@ -357,6 +340,10 @@ def _build_optimizer(self): ) self._lr_scheduler.step_update(0) + @property + def is_fsdp(self): + return self.cfg.distributed_training.ddp_backend == "fully_sharded" + def consolidate_optimizer(self): """For OSS, we need to consolidate the state dict.""" if self.cfg.checkpoint.no_save_optimizer_state: @@ -364,11 +351,7 @@ def consolidate_optimizer(self): self._gathered_optim_state = None if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() - - elif ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and not self.model.use_sharded_state - ): + elif self.is_fsdp and not self.model.use_sharded_state: st = self.model.gather_full_optim_state_dict( self.optimizer ) # only returns on rank 0 @@ -409,7 +392,7 @@ def state_dict(self): self._gathered_optim_state = None else: state_dict["last_optimizer_state"] = 
self.optimizer.state_dict() - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.is_fsdp: # save meta data for recombining checkpoint upon loading state_dict["fsdp_metadata"] = self.model.local_metadata_dict() return state_dict @@ -453,10 +436,7 @@ def load_checkpoint( # on every worker for now or self.tpu # FSDP requires loading checkpoint shards on all ranks - or ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and self.cfg.distributed_training.use_sharded_state - ) + or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state) or getattr(self.cfg.model, "base_layers", 0) > 0 ) @@ -527,10 +507,7 @@ def load_checkpoint( if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) - if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and not self.model.use_sharded_state - ): + if self.is_fsdp and not self.model.use_sharded_state: # if use_sharded_state, the last_optim_state is already sharded, skip this last_optim_state = self.model.get_shard_from_optim_state_dict( last_optim_state @@ -702,6 +679,11 @@ def maybe_no_sync(): self.data_parallel_world_size > 1 and hasattr(self.model, "no_sync") and i < len(samples) - 1 + # The no_sync context manager results in increased memory + # usage with FSDP, since full-size gradients will be + # accumulated on each GPU. It's typically a better tradeoff + # to do the extra communication with FSDP. 
+ and not self.is_fsdp ): return self.model.no_sync() else: @@ -1112,7 +1094,7 @@ def agg_norm_fn(total_norm): return total_norm ** 0.5 should_agg_norm = ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" + self.is_fsdp and ( self.data_parallel_process_group is not None or torch.distributed.is_initialized() From 5826a6855f05cab919f567a68b66d6d8c1817551 Mon Sep 17 00:00:00 2001 From: Ishani Karmarkar <ikarmarkar@fb.com> Date: Wed, 21 Jul 2021 23:11:59 -0700 Subject: [PATCH 658/707] Allow Moving Observer to Cuda Summary: Allow for moving observer to cuda in emulate_int8_{method} functions in order to make it compatible with pytext trainers. Reviewed By: huihuifan Differential Revision: D29686263 fbshipit-source-id: 7eddaadb17163bd4be33de1ae729621d043520d2 --- fairseq/modules/quantization/scalar/ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/modules/quantization/scalar/ops.py b/fairseq/modules/quantization/scalar/ops.py index 2a855159be..cb0120fa81 100644 --- a/fairseq/modules/quantization/scalar/ops.py +++ b/fairseq/modules/quantization/scalar/ops.py @@ -20,6 +20,7 @@ def quantize(w, scale, zero_point): def emulate_int8_histogram(w, scale=None, zero_point=None): if scale is None: obs = torch.quantization.observer.HistogramObserver() + obs.to(device=w.device) _ = obs(w.float()) scale, zero_point = obs.calculate_qparams() scale = scale.cuda().type_as(w) @@ -32,6 +33,7 @@ def emulate_int8_channel(w, scale=None, zero_point=None): obs = torch.quantization.observer.PerChannelMinMaxObserver( ch_axis=-1, qscheme=torch.per_channel_symmetric ) + obs.to(device=w.device) _ = obs(w) scale, zero_point, ch_axis = obs.get_qparams() scale = scale.cuda().type_as(w) @@ -42,6 +44,7 @@ def emulate_int8_channel(w, scale=None, zero_point=None): def emulate_int8_tensor(w, scale=None, zero_point=None): if scale is None: obs = torch.quantization.observer.MinMaxObserver() + obs.to(device=w.device) _ = obs(w) scale, zero_point = obs.calculate_qparams() scale = 
scale.cuda().type_as(w) From 7feb8747b2f88752a5a5c5a5c85b0b3e428ca549 Mon Sep 17 00:00:00 2001 From: Ishani Karmarkar <ikarmarkar@fb.com> Date: Thu, 22 Jul 2021 05:00:38 -0700 Subject: [PATCH 659/707] Support Quantization to variable number of bits Summary: allow quantization using quant noise with variable number of bits Reviewed By: huihuifan Differential Revision: D29686262 fbshipit-source-id: 661e7d684f006af891191b32639651214df79bc4 --- fairseq/modules/quantization/scalar/ops.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/fairseq/modules/quantization/scalar/ops.py b/fairseq/modules/quantization/scalar/ops.py index cb0120fa81..c74f530380 100644 --- a/fairseq/modules/quantization/scalar/ops.py +++ b/fairseq/modules/quantization/scalar/ops.py @@ -7,17 +7,19 @@ def emulate_int(w, bits, method, scale=None, zero_point=None): - q = globals()[f"emulate_int{bits}_{method}"] - return q(w, scale=scale, zero_point=zero_point) + q = globals()[f"emulate_int8_{method}"] + return q(w, scale=scale, zero_point=zero_point, bits=bits) -def quantize(w, scale, zero_point): +def quantize(w, scale, zero_point, bits=8): + # In the default behavior, max_val = 255. 
+ max_val = 2 ** bits - 1 return ( - torch.clamp(torch.round(w / scale + zero_point), 0, 255) - zero_point + torch.clamp(torch.round(w / scale + zero_point), 0, max_val) - zero_point ) * scale -def emulate_int8_histogram(w, scale=None, zero_point=None): +def emulate_int8_histogram(w, scale=None, zero_point=None, bits=8): if scale is None: obs = torch.quantization.observer.HistogramObserver() obs.to(device=w.device) @@ -25,10 +27,10 @@ def emulate_int8_histogram(w, scale=None, zero_point=None): scale, zero_point = obs.calculate_qparams() scale = scale.cuda().type_as(w) zero_point = zero_point.cuda().type_as(w) - return quantize(w, scale, zero_point), scale, zero_point + return quantize(w, scale, zero_point, bits=bits), scale, zero_point -def emulate_int8_channel(w, scale=None, zero_point=None): +def emulate_int8_channel(w, scale=None, zero_point=None, bits=8): if scale is None: obs = torch.quantization.observer.PerChannelMinMaxObserver( ch_axis=-1, qscheme=torch.per_channel_symmetric @@ -38,10 +40,10 @@ def emulate_int8_channel(w, scale=None, zero_point=None): scale, zero_point, ch_axis = obs.get_qparams() scale = scale.cuda().type_as(w) zero_point = zero_point.cuda().type_as(w) - return quantize(w, scale, zero_point), scale, zero_point + return quantize(w, scale, zero_point, bits=bits), scale, zero_point -def emulate_int8_tensor(w, scale=None, zero_point=None): +def emulate_int8_tensor(w, scale=None, zero_point=None, bits=8): if scale is None: obs = torch.quantization.observer.MinMaxObserver() obs.to(device=w.device) @@ -49,4 +51,4 @@ def emulate_int8_tensor(w, scale=None, zero_point=None): scale, zero_point = obs.calculate_qparams() scale = scale.cuda().type_as(w) zero_point = zero_point.cuda().type_as(w) - return quantize(w, scale, zero_point), scale, zero_point + return quantize(w, scale, zero_point, bits=bits), scale, zero_point From eff39d5d453497a5a6e5e998e2a920fb5f0618e1 Mon Sep 17 00:00:00 2001 From: Ishani Karmarkar <ikarmarkar@fb.com> Date: Thu, 22 Jul 
2021 05:16:11 -0700 Subject: [PATCH 660/707] quantize_layers_ using variable methods Summary: allow for moving average per channel and per channel observer to be customized when calling quantize_model_ Reviewed By: huihuifan Differential Revision: D29686457 fbshipit-source-id: 37b833d27d643e6963c9ab7e0af9b8f889e1a2fc --- fairseq/modules/quantization/pq/utils.py | 6 +++++- fairseq/modules/quantization/scalar/utils.py | 10 ++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/fairseq/modules/quantization/pq/utils.py b/fairseq/modules/quantization/pq/utils.py index 03b15e4b1b..3c5ea4155d 100644 --- a/fairseq/modules/quantization/pq/utils.py +++ b/fairseq/modules/quantization/pq/utils.py @@ -152,7 +152,7 @@ def quantize_model_( return quantized_layers -def get_layers(model, filter_regexp): +def get_layers(model, filter_regexp, remove_weights=False): """ Filters out the layers according to a regexp. Note that we omit biases. @@ -181,6 +181,10 @@ def get_layers(model, filter_regexp): # remove .weight in all other names (or .weight_orig is spectral norm) all_layers = map(lambda x: x.replace(".weight_orig", ""), all_layers) + # remove weights indicates whether the weights extension should be removed, in addition to + # weight_orig and weight extension on names + if remove_weights: + all_layers = map(lambda x: x.replace(".weights", ""), all_layers) all_layers = map(lambda x: x.replace(".weight", ""), all_layers) # return filtered layers diff --git a/fairseq/modules/quantization/scalar/utils.py b/fairseq/modules/quantization/scalar/utils.py index 32cf616568..76db40fec0 100644 --- a/fairseq/modules/quantization/scalar/utils.py +++ b/fairseq/modules/quantization/scalar/utils.py @@ -16,7 +16,7 @@ MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d} -def quantize_model_(model, p=0.2, bits=8, update_step=3000): +def quantize_model_(model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False): """ Replaces all 
modules with their scalar quantized counterpart and registers hooks to quantize the post-ativations of those modules. @@ -29,7 +29,9 @@ def quantize_model_(model, p=0.2, bits=8, update_step=3000): """ # quantize all layers - quantized_layers = get_layers(model, "(.*?)") + # remove weights indicates whether the weights extension should be removed, in addition to + # weight_orig and weight extension on names + quantized_layers = get_layers(model, "(.*?)", remove_weights=remove_weights) for layer in quantized_layers: @@ -50,7 +52,7 @@ def quantize_model_(model, p=0.2, bits=8, update_step=3000): "p": p, "update_step": update_step, "bits": bits, - "method": "histogram", + "method": method, "counter": 0, } @@ -68,7 +70,7 @@ def quantize_model_(model, p=0.2, bits=8, update_step=3000): continue # activation quantization - a_q = ActivationQuantizer(quantized_module, p=0, bits=bits, method="histogram") + a_q = ActivationQuantizer(quantized_module, p=0, bits=bits, method=method) # replace layer by its quantized counterpart attrsetter(layer)(model, quantized_module) From bc3bd55ec98c39af45ff7323ae49bcbdf93acc36 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield <kheafield@fb.com> Date: Thu, 22 Jul 2021 10:38:47 -0700 Subject: [PATCH 661/707] Fix anchor link for inference & evaluation Summary: The link in the documentation was wrong due to & in the anchor. 
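For context, a rough sketch of why the anchor changes, assuming GitHub-style heading slugs (lowercase the heading, drop punctuation, turn each remaining space into a hyphen). `github_style_anchor` is a hypothetical helper written for illustration; it is not part of fairseq or of any GitHub API, and the real renderer may differ in edge cases:

```python
import re


def github_style_anchor(heading: str) -> str:
    """Approximate a GitHub-style heading anchor (illustrative assumption)."""
    slug = heading.strip().lower()
    # Drop everything except word characters, spaces and hyphens,
    # so the "&" disappears but the spaces on both sides of it remain.
    slug = re.sub(r"[^\w\- ]", "", slug)
    # Each remaining space becomes its own hyphen, so "a & b" yields "a--b".
    return slug.replace(" ", "-")


print("#" + github_style_anchor("Inference & Evaluation"))
```

Under that assumption the `&` is removed rather than kept, which leaves two adjacent hyphens in the anchor; this matches the corrected link target in the diff below.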
Reviewed By: xutaima Differential Revision: D29850940 fbshipit-source-id: 5024802e868f4a2f5440a35f1792bf5337fd3abf --- examples/speech_to_text/docs/simulst_mustc_example.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index 52ca9ac062..f3b5a413a2 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -6,7 +6,7 @@ This is a tutorial of training and evaluating a transformer *wait-k* simultaneou ## Data Preparation This section introduces the data preparation for training and evaluation. -If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference-&-evaluation) +If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference--evaluation) [Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path `${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with From e1adc9d388c1af907366c5f328bb7672f618a73c Mon Sep 17 00:00:00 2001 From: Edan Tessel Sneh <edan@fb.com> Date: Fri, 23 Jul 2021 10:43:24 -0700 Subject: [PATCH 662/707] hacky fix for gen_parser bug Summary: hacky fix for gen_parser bug to unblock fbtranslate Reviewed By: theweiho Differential Revision: D29865385 fbshipit-source-id: 0d92c8c67c465cec6eb309185087aec469ade713 --- fairseq/models/transformer/transformer_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fairseq/models/transformer/transformer_base.py b/fairseq/models/transformer/transformer_base.py index e3ceb3c317..49fd30e502 100644 --- a/fairseq/models/transformer/transformer_base.py +++ b/fairseq/models/transformer/transformer_base.py @@ -49,6 +49,10 @@ def add_args(parser): def build_model(cls, cfg, task): """Build a new model instance.""" + # hacky fixes for issue with II + cfg.decoder.input_dim = int(cfg.decoder.input_dim) + cfg.decoder.output_dim = 
int(cfg.decoder.output_dim) + if cfg.encoder.layers_to_keep: cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(',')) if cfg.decoder.layers_to_keep: From 67ff6baa42c1208d0da85f5af2f01689034d1dfd Mon Sep 17 00:00:00 2001 From: Shashank Chaudhry <gandalf@fb.com> Date: Fri, 23 Jul 2021 11:07:47 -0700 Subject: [PATCH 663/707] Apply the CLANGFORMAT linter to fbcode/deeplearning/projects/fairseq-py/** Summary: Try to lint the folders in deeplearning/projects/**. Context: FBCode doesn't enable CLANGFORMAT linter over the entire repo because too many files currently don't conform to it. This is an attempt to migrate folder paths to use CLANGFORMAT conventions. Doc: https://fburl.com/clangformat-for-fbcode Actions: Please check the formatting recommendations and accept if they seem ok, or reply if there is a reason for concern. Command run: arc lint --take CLANGFORMAT -a --paths-cmd 'hg files deeplearning/projects/fairseq-py/**' Reviewed By: dianaml0 Differential Revision: D29838082 fbshipit-source-id: 3afbff42e239a6376543c4a849da4221dbf7eb32 --- .../kaldi/add-self-loop-simple.cc | 12 +- fairseq/clib/cuda/ngram_repeat_block_cuda.cpp | 32 +- .../cuda/ngram_repeat_block_cuda_kernel.cu | 32 +- fairseq/clib/libbase/balanced_assignment.cpp | 138 +++-- fairseq/clib/libbleu/libbleu.cpp | 86 +-- fairseq/clib/libbleu/module.cpp | 22 +- fairseq/clib/libnat/edit_dist.cpp | 4 +- fairseq/clib/libnat_cuda/binding.cpp | 69 ++- fairseq/clib/libnat_cuda/edit_dist.cu | 566 +++++++++--------- fairseq/clib/libnat_cuda/edit_dist.h | 16 +- fairseq/modules/cuda_utils.cu | 79 ++- .../dynamicconv_layer/dynamicconv_cuda.cpp | 45 +- .../dynamicconv_layer/dynamicconv_cuda.cuh | 29 +- .../dynamicconv_cuda_kernel.cu | 104 ++-- .../dynamicconv_layer/dynamiconv_cpu.cpp | 22 +- .../lightconv_layer/lightconv_cuda.cpp | 43 +- .../lightconv_layer/lightconv_cuda.cuh | 52 +- .../lightconv_layer/lightconv_cuda_kernel.cu | 173 +++--- 18 files changed, 798 insertions(+), 726 deletions(-) diff --git 
a/examples/speech_recognition/kaldi/add-self-loop-simple.cc b/examples/speech_recognition/kaldi/add-self-loop-simple.cc index 89754b925e..e18fb62df5 100644 --- a/examples/speech_recognition/kaldi/add-self-loop-simple.cc +++ b/examples/speech_recognition/kaldi/add-self-loop-simple.cc @@ -1,9 +1,9 @@ /* -* Copyright (c) Facebook, Inc. and its affiliates. -* -* This source code is licensed under the MIT license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ #include <iostream> #include "fstext/fstext-lib.h" // @manual @@ -91,4 +91,4 @@ int main(int argc, char** argv) { KALDI_LOG << "Writing FST to " << output << std::endl; delete fst; -} \ No newline at end of file +} diff --git a/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp b/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp index 4199cd6ea8..707219105a 100644 --- a/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp +++ b/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp @@ -11,10 +11,13 @@ CPP Binding for CUDA OP */ // CUDA forward declarations -torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens, - torch::Tensor lprobs, int bsz, - int step, int beam_size, - int no_repeat_ngram_size); +torch::Tensor ngram_repeat_block_cuda_forward( + torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size); #define CHECK_CUDA(x) \ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") @@ -26,10 +29,13 @@ torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens, // Input check and call to CUDA OP // Backward method not required -torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens, - torch::Tensor lprobs, int bsz, - int step, int beam_size, - int no_repeat_ngram_size) { +torch::Tensor ngram_repeat_block_forward( + torch::Tensor tokens, 
+ torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size) { CHECK_INPUT(tokens); CHECK_INPUT(lprobs); assert(bsz > 0); @@ -37,11 +43,13 @@ torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens, assert(beam_size > 0); assert(no_repeat_ngram_size > 0); - return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size, - no_repeat_ngram_size); + return ngram_repeat_block_cuda_forward( + tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &ngram_repeat_block_forward, - "No Repeat Ngram Block forward (CUDA)"); + m.def( + "forward", + &ngram_repeat_block_forward, + "No Repeat Ngram Block forward (CUDA)"); } diff --git a/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu index b458b0916a..bd6106cba0 100644 --- a/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu +++ b/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -14,10 +14,12 @@ Kernel implementation for blocking repeated n-grams. 
#include <vector> // Ban repeated ngrams of length = 'no_repeat_ngram_size' -__global__ void banRepeatedTokens(long* __restrict__ tokens, - float* __restrict__ lprobs, - int max_predict_len, int vocab_size, - int no_repeat_ngram_size) { +__global__ void banRepeatedTokens( + long* __restrict__ tokens, + float* __restrict__ lprobs, + int max_predict_len, + int vocab_size, + int no_repeat_ngram_size) { auto row = blockIdx.x; auto col = threadIdx.x; auto start = row * (max_predict_len) + col; @@ -30,10 +32,10 @@ __global__ void banRepeatedTokens(long* __restrict__ tokens, extern __shared__ long tokens_shm[]; tokens_shm[col] = tokens[start]; if (col == blockDim.x - 1) { - for (int i=1; i<no_repeat_ngram_size; i++){ - if (col+i < max_predict_len){ - tokens_shm[col + i] = tokens[start + i]; - } + for (int i = 1; i < no_repeat_ngram_size; i++) { + if (col + i < max_predict_len) { + tokens_shm[col + i] = tokens[start + i]; + } } } __syncthreads(); @@ -52,12 +54,16 @@ __global__ void banRepeatedTokens(long* __restrict__ tokens, // Allocate blocks and threads based on // batch size and sequence length and launch // kernel -torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor tokens, - torch::Tensor lprobs, int bsz, - int step, int beam_size, - int no_repeat_ngram_size) { +torch::Tensor ngram_repeat_block_cuda_forward( + const torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size) { int threads = step - no_repeat_ngram_size + 2; - if (threads <= 0) return lprobs; + if (threads <= 0) + return lprobs; int max_predict_len = tokens.size(1); int vocab_size = lprobs.size(1); auto token_ptr = tokens.data_ptr<long>(); diff --git a/fairseq/clib/libbase/balanced_assignment.cpp b/fairseq/clib/libbase/balanced_assignment.cpp index 296f03b6ae..1a5a1061f3 100644 --- a/fairseq/clib/libbase/balanced_assignment.cpp +++ b/fairseq/clib/libbase/balanced_assignment.cpp @@ -8,86 +8,100 @@ /* C++ code for solving the linear 
assignment problem. -Based on the Auction Algorithm from https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the implementation from: -https://github.com/bkj/auction-lap -Adapted to be more efficient when each worker is looking for k jobs instead of 1. +Based on the Auction Algorithm from +https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the +implementation from: https://github.com/bkj/auction-lap Adapted to be more +efficient when each worker is looking for k jobs instead of 1. */ #include <torch/extension.h> #include <iostream> using namespace torch::indexing; torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) { - int max_iterations = 100; - torch::Tensor epsilon = (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50; - epsilon.clamp_min_(1e-04); - torch::Tensor worker_and_job_to_score = job_and_worker_to_score.detach().transpose(0,1).contiguous(); - int num_workers = worker_and_job_to_score.size(0); - int num_jobs = worker_and_job_to_score.size(1); - auto device = worker_and_job_to_score.device(); - int jobs_per_worker = num_jobs / num_workers; - torch::Tensor value = worker_and_job_to_score.clone(); - int counter = 0; - torch::Tensor max_value = worker_and_job_to_score.max(); + int max_iterations = 100; + torch::Tensor epsilon = + (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50; + epsilon.clamp_min_(1e-04); + torch::Tensor worker_and_job_to_score = + job_and_worker_to_score.detach().transpose(0, 1).contiguous(); + int num_workers = worker_and_job_to_score.size(0); + int num_jobs = worker_and_job_to_score.size(1); + auto device = worker_and_job_to_score.device(); + int jobs_per_worker = num_jobs / num_workers; + torch::Tensor value = worker_and_job_to_score.clone(); + int counter = 0; + torch::Tensor max_value = worker_and_job_to_score.max(); - torch::Tensor bid_indices; - torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs}); - 
torch::Tensor bids = worker_and_job_to_score.new_empty({num_workers, num_jobs}); - torch::Tensor bid_increments = worker_and_job_to_score.new_empty({num_workers, jobs_per_worker}); - torch::Tensor top_values = worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1}); - torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs}); + torch::Tensor bid_indices; + torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs}); + torch::Tensor bids = + worker_and_job_to_score.new_empty({num_workers, num_jobs}); + torch::Tensor bid_increments = + worker_and_job_to_score.new_empty({num_workers, jobs_per_worker}); + torch::Tensor top_values = + worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1}); + torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs}); - torch::Tensor top_index = top_values.to(torch::kLong); - torch::Tensor high_bidders = top_index.new_empty({num_jobs}); - torch::Tensor have_bids = high_bidders.to(torch::kBool); - torch::Tensor jobs_indices = torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device)); - torch::Tensor true_tensor = torch::ones({1}, torch::dtype(torch::kBool).device(device)); + torch::Tensor top_index = top_values.to(torch::kLong); + torch::Tensor high_bidders = top_index.new_empty({num_jobs}); + torch::Tensor have_bids = high_bidders.to(torch::kBool); + torch::Tensor jobs_indices = + torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device)); + torch::Tensor true_tensor = + torch::ones({1}, torch::dtype(torch::kBool).device(device)); - while (true) { - bids.zero_(); - torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1); + while (true) { + bids.zero_(); + torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1); - // Each worker bids the difference in value between that job and the k+1th job - torch::sub_out(bid_increments, - top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}), - top_values.index({Slice(None, 
None), jobs_per_worker}).unsqueeze(1)); + // Each worker bids the difference in value between that job and the k+1th + // job + torch::sub_out( + bid_increments, + top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}), + top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1)); - bid_increments.add_(epsilon); - bids.scatter_(1, - top_index.index({Slice(None, None),Slice(0, jobs_per_worker)}), - bid_increments); + bid_increments.add_(epsilon); + bids.scatter_( + 1, + top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}), + bid_increments); - if (counter < max_iterations && counter > 0) { - // Put in a minimal bid to retain items from the last round if no-one else bids for them this round - bids.view(-1).index_put_({bid_indices}, epsilon); - } - - // Find the highest bidding worker per job - torch::max_out(high_bids, high_bidders, bids, 0); - torch::gt_out(have_bids, high_bids, 0); + if (counter < max_iterations && counter > 0) { + // Put in a minimal bid to retain items from the last round if no-one else + // bids for them this round + bids.view(-1).index_put_({bid_indices}, epsilon); + } - if (have_bids.all().item<bool>()) { - // All jobs were bid for - break; - } + // Find the highest bidding worker per job + torch::max_out(high_bids, high_bidders, bids, 0); + torch::gt_out(have_bids, high_bids, 0); - // Make popular items more expensive - cost.add_(high_bids); - torch::sub_out(value, worker_and_job_to_score, cost); + if (have_bids.all().item<bool>()) { + // All jobs were bid for + break; + } - bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids}); + // Make popular items more expensive + cost.add_(high_bids); + torch::sub_out(value, worker_and_job_to_score, cost); - if (counter < max_iterations) { - // Make sure that this item will be in the winning worker's top-k next time. 
- value.view(-1).index_put_({bid_indices}, max_value); - } - else { - // Suboptimal approximation that converges quickly from current solution - value.view(-1).index_put_({bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices})); - } + bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids}); - counter += 1; + if (counter < max_iterations) { + // Make sure that this item will be in the winning worker's top-k next + // time. + value.view(-1).index_put_({bid_indices}, max_value); + } else { + // Suboptimal approximation that converges quickly from current solution + value.view(-1).index_put_( + {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices})); } - return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}).reshape(-1); + counter += 1; + } + + return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}) + .reshape(-1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/fairseq/clib/libbleu/libbleu.cpp b/fairseq/clib/libbleu/libbleu.cpp index 3cf2d65b6d..939d9e1174 100644 --- a/fairseq/clib/libbleu/libbleu.cpp +++ b/fairseq/clib/libbleu/libbleu.cpp @@ -6,30 +6,32 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include <map> #include <array> -#include <cstring> #include <cstdio> +#include <cstring> +#include <map> -typedef struct -{ - size_t reflen; - size_t predlen; - size_t match1; - size_t count1; - size_t match2; - size_t count2; - size_t match3; - size_t count3; - size_t match4; - size_t count4; +// NOLINTNEXTLINE +typedef struct { + size_t reflen; + size_t predlen; + size_t match1; + size_t count1; + size_t match2; + size_t count2; + size_t match3; + size_t count3; + size_t match4; + size_t count4; } bleu_stat; // left trim (remove pad) void bleu_ltrim(size_t* len, int** sent, int pad) { size_t start = 0; - while(start < *len) { - if (*(*sent + start) != pad) { break; } + while (start < *len) { + if (*(*sent + start) != pad) { + break; + } start++; } *sent += start; @@ -40,7 +42,9 @@ void bleu_ltrim(size_t* len, int** sent, int pad) { void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { size_t end = *len - 1; while (end > 0) { - if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } + if (*(*sent + end) != eos && *(*sent + end) != pad) { + break; + } end--; } *len = end + 1; @@ -53,10 +57,10 @@ void bleu_trim(size_t* len, int** sent, int pad, int eos) { } size_t bleu_hash(int len, int* data) { - size_t h = 14695981039346656037ul; + size_t h = 14695981039346656037ul; size_t prime = 0x100000001b3; - char* b = (char*) data; - size_t blen = sizeof(int) * len; + char* b = (char*)data; + size_t blen = sizeof(int) * len; while (blen-- > 0) { h ^= *b++; @@ -67,15 +71,23 @@ size_t bleu_hash(int len, int* data) { } void bleu_addngram( - size_t *ntotal, size_t *nmatch, size_t n, - size_t reflen, int* ref, size_t predlen, int* pred) { - - if (predlen < n) { return; } + size_t* ntotal, + size_t* nmatch, + size_t n, + size_t reflen, + int* ref, + size_t predlen, + int* pred) { + if (predlen < n) { + return; + } predlen = predlen - n + 1; (*ntotal) += predlen; - if (reflen < n) { return; } + if (reflen < n) { + return; + } reflen = reflen - n + 1; @@ -90,7 
+102,7 @@ void bleu_addngram( size_t w = bleu_hash(n, ref++); if (count[w] > 0) { (*nmatch)++; - count[w] -=1; + count[w] -= 1; } reflen--; } @@ -99,16 +111,16 @@ void bleu_addngram( extern "C" { #ifdef _WIN64 -__declspec(dllexport) +__declspec(dllexport) #endif -void bleu_zero_init(bleu_stat* stat) { + void bleu_zero_init(bleu_stat* stat) { std::memset(stat, 0, sizeof(bleu_stat)); } #ifdef _WIN64 -__declspec(dllexport) +__declspec(dllexport) #endif -void bleu_one_init(bleu_stat* stat) { + void bleu_one_init(bleu_stat* stat) { bleu_zero_init(stat); stat->count1 = 0; stat->count2 = 1; @@ -121,11 +133,16 @@ void bleu_one_init(bleu_stat* stat) { } #ifdef _WIN64 -__declspec(dllexport) +__declspec(dllexport) #endif -void bleu_add( - bleu_stat* stat, - size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { + void bleu_add( + bleu_stat* stat, + size_t reflen, + int* ref, + size_t predlen, + int* pred, + int pad, + int eos) { bleu_trim(&reflen, &ref, pad, eos); bleu_trim(&predlen, &pred, pad, eos); @@ -137,5 +154,4 @@ void bleu_add( bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); } - } diff --git a/fairseq/clib/libbleu/module.cpp b/fairseq/clib/libbleu/module.cpp index 8ed9a84b1c..35288b3177 100644 --- a/fairseq/clib/libbleu/module.cpp +++ b/fairseq/clib/libbleu/module.cpp @@ -8,20 +8,16 @@ #include <Python.h> - -static PyMethodDef method_def[] = { - {NULL, NULL, 0, NULL} -}; +static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - "libbleu", /* name of module */ - NULL, /* module documentation, may be NULL */ - -1, /* size of per-interpreter state of the module, - or -1 if the module keeps state in global variables. 
*/ - method_def -}; - + PyModuleDef_HEAD_INIT, + "libbleu", /* name of module */ + // NOLINTNEXTLINE + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + method_def}; // NOLINT #if PY_MAJOR_VERSION == 2 PyMODINIT_FUNC init_libbleu() @@ -29,7 +25,7 @@ PyMODINIT_FUNC init_libbleu() PyMODINIT_FUNC PyInit_libbleu() #endif { - PyObject *m = PyModule_Create(&module_def); + PyObject* m = PyModule_Create(&module_def); if (!m) { return NULL; } diff --git a/fairseq/clib/libnat/edit_dist.cpp b/fairseq/clib/libnat/edit_dist.cpp index 6bc6a937d6..9ffb60569d 100644 --- a/fairseq/clib/libnat/edit_dist.cpp +++ b/fairseq/clib/libnat/edit_dist.cpp @@ -6,10 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -#include <torch/torch.h> // @manual=//caffe2:torch_extension #include <pybind11/detail/common.h> #include <pybind11/pybind11.h> -#include <vector> +#include <torch/torch.h> // @manual=//caffe2:torch_extension #include <algorithm> #include <cstdint> #include <iosfwd> @@ -17,6 +16,7 @@ #include <new> #include <string> #include <utility> +#include <vector> using namespace ::std; diff --git a/fairseq/clib/libnat_cuda/binding.cpp b/fairseq/clib/libnat_cuda/binding.cpp index aaa6244d5c..ced91c0d0a 100644 --- a/fairseq/clib/libnat_cuda/binding.cpp +++ b/fairseq/clib/libnat_cuda/binding.cpp @@ -7,54 +7,61 @@ */ /* - This code is partially adpoted from https://github.com/1ytic/pytorch-edit-distance + This code is partially adpoted from + https://github.com/1ytic/pytorch-edit-distance */ -#include "edit_dist.h" #include <torch/types.h> +#include "edit_dist.h" #ifndef TORCH_CHECK #define TORCH_CHECK AT_CHECK #endif -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - +#define CHECK_CUDA(x) \ 
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) torch::Tensor LevenshteinDistance( - torch::Tensor source, - torch::Tensor target, - torch::Tensor source_length, - torch::Tensor target_length) { - - CHECK_INPUT(source); - CHECK_INPUT(target); - CHECK_INPUT(source_length); - CHECK_INPUT(target_length); - return LevenshteinDistanceCuda(source, target, source_length, target_length); + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + CHECK_INPUT(source); + CHECK_INPUT(target); + CHECK_INPUT(source_length); + CHECK_INPUT(target_length); + return LevenshteinDistanceCuda(source, target, source_length, target_length); } torch::Tensor GenerateDeletionLabel( - torch::Tensor source, - torch::Tensor operations) { - - CHECK_INPUT(source); - CHECK_INPUT(operations); - return GenerateDeletionLabelCuda(source, operations); + torch::Tensor source, + torch::Tensor operations) { + CHECK_INPUT(source); + CHECK_INPUT(operations); + return GenerateDeletionLabelCuda(source, operations); } std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel( - torch::Tensor target, - torch::Tensor operations) { - - CHECK_INPUT(target); - CHECK_INPUT(operations); - return GenerateInsertionLabelCuda(target, operations); + torch::Tensor target, + torch::Tensor operations) { + CHECK_INPUT(target); + CHECK_INPUT(operations); + return GenerateInsertionLabelCuda(target, operations); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); - m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label"); - m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label"); + m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); + m.def( + 
"generate_deletion_labels", + &GenerateDeletionLabel, + "Generate Deletion Label"); + m.def( + "generate_insertion_labels", + &GenerateInsertionLabel, + "Generate Insertion Label"); } diff --git a/fairseq/clib/libnat_cuda/edit_dist.cu b/fairseq/clib/libnat_cuda/edit_dist.cu index 22de16b270..96569d46c8 100644 --- a/fairseq/clib/libnat_cuda/edit_dist.cu +++ b/fairseq/clib/libnat_cuda/edit_dist.cu @@ -1,332 +1,344 @@ /** -* Copyright 2017-present, Facebook, Inc. -* All rights reserved. -* -* This source code is licensed under the license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ #include "edit_dist.h" + #include <THC/THC.h> #include <cuda.h> #include <cuda_runtime.h> #include <device_launch_parameters.h> -#include <utility> // std::pair +#include <utility> // std::pair template <typename scalar_t> __global__ void generate_deletion_label_kernel( - const scalar_t* __restrict__ source, - const size_t source_size, - const size_t operation_size, - int* __restrict__ operations, - int* __restrict__ labels) { - - const int index = blockIdx.x; - const int offset = index * operation_size; - const int offset_label = index * source_size; - - for (int i = 0; i < source_size; i++) { - labels[offset_label + i] = 0; - } - - int k = 0; - for (int i = 0; i < operation_size; i++){ - if (operations[offset + i] == 0){ - break; - } else if (operations[offset + i] == 1){ - continue; - } else { - labels[offset_label + k] = 3 - operations[offset + i]; - k++; - } + const scalar_t* __restrict__ source, + const size_t source_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels) { + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * source_size; + + for (int i = 0; 
i < source_size; i++) { + labels[offset_label + i] = 0; + } + + int k = 0; + for (int i = 0; i < operation_size; i++) { + if (operations[offset + i] == 0) { + break; + } else if (operations[offset + i] == 1) { + continue; + } else { + labels[offset_label + k] = 3 - operations[offset + i]; + k++; } + } } template <typename scalar_t> __global__ void generate_insertion_label_kernel( - const scalar_t* __restrict__ target, - const size_t target_size, - const size_t operation_size, - int* __restrict__ operations, - int* __restrict__ labels, - int* __restrict__ masks) { - - const int index = blockIdx.x; - const int offset = index * operation_size; - const int offset_label = index * target_size; - - int k = 0; - int u = 0; - int m = 0; - - for (int i = 0; i < target_size; i++) { - labels[offset_label + i] = 0; - masks[offset_label + i] = 0; - } - - for (int i = 0; i < operation_size-1; i++){ - if (operations[offset + i] == 0){ - break; - } else if (operations[offset + i] == 2){ - continue; - } else if (operations[offset + i] == 1){ - masks[offset_label + m] = 1; - u++; m++; - } else { - labels[offset_label + k] = u; - masks[offset_label + m] = 0; - k++; m++; - u = 0; - } + const scalar_t* __restrict__ target, + const size_t target_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels, + int* __restrict__ masks) { + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * target_size; + + int k = 0; + int u = 0; + int m = 0; + + for (int i = 0; i < target_size; i++) { + labels[offset_label + i] = 0; + masks[offset_label + i] = 0; + } + + for (int i = 0; i < operation_size - 1; i++) { + if (operations[offset + i] == 0) { + break; + } else if (operations[offset + i] == 2) { + continue; + } else if (operations[offset + i] == 1) { + masks[offset_label + m] = 1; + u++; + m++; + } else { + labels[offset_label + k] = u; + masks[offset_label + m] = 0; + k++; + m++; + u = 0; } + } } 
template <typename scalar_t> __global__ void levenshtein_distance_kernel( - const scalar_t* __restrict__ source, - const scalar_t* __restrict__ target, - const int* __restrict__ source_length, - const int* __restrict__ target_length, - const size_t source_size, - const size_t target_size, - int* __restrict__ operations, - int* __restrict__ errors_curr) { - - const int index = blockIdx.x; - const int offset = index * (source_size + target_size); - const int d = index * (source_size + 1) * (target_size + 1); - const int t = target_size + 1; - - auto err_idx = [d, t](int i, int j) { return d + i * t + j; }; - auto opt_idx = [offset](int k) { return offset + k; }; - - const int hyp_len = source_length[index]; - const int ref_len = target_length[index]; - const scalar_t* hyp_begin = source + index * source_size; - const scalar_t* ref_begin = target + index * target_size; - - // dynamic programming - for (int i = 0; i <= hyp_len; i++){ - errors_curr[err_idx(i, 0)] = i; - } - for (int j = 0; j <= ref_len; j++){ - errors_curr[err_idx(0, j)] = j; - } - for (int i = 1; i <= hyp_len; i++){ - for (int j = 1; j <= ref_len; j++){ - errors_curr[err_idx(i, j)] = min( - min( - errors_curr[err_idx(i-1, j)], - errors_curr[err_idx(i, j-1)] - ) + 1, - errors_curr[err_idx(i-1, j-1)] + 2 * ( - *(hyp_begin+i-1) == *(ref_begin+j-1) ? 
0 : 1 - ) - ); - } + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations, + int* __restrict__ errors_curr) { + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int d = index * (source_size + 1) * (target_size + 1); + const int t = target_size + 1; + + auto err_idx = [d, t](int i, int j) { return d + i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++) { + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++) { + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++) { + for (int j = 1; j <= ref_len; j++) { + errors_curr[err_idx(i, j)] = min( + min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) + + 1, + errors_curr[err_idx(i - 1, j - 1)] + + 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 
0 : 1)); } + } - // back-tracing - int i = hyp_len; - int j = ref_len; - int o = hyp_len + ref_len; + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; - for (int k = 0; k < source_size + target_size; k++) { - operations[opt_idx(k)] = 0; - } + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } - while ((i >= 0) && (j >= 0)) { - if ((i == 0) && (j == 0)) { - break; - } - - if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) { - o--; operations[opt_idx(o)] = 1; j--; // insertion - } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) { - o--; operations[opt_idx(o)] = 2; i--; // deletion - } else { - o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing - } + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; } - // moving to the left - for (int k = 0; k < hyp_len + ref_len; k++) { - if (k + o < hyp_len + ref_len){ - operations[opt_idx(k)] = operations[opt_idx(k+o)]; - } else{ - operations[opt_idx(k)] = 0; // padding - } + if ((j > 0) && + (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 1; + j--; // insertion + } else if ( + (i > 0) && + (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 2; + i--; // deletion + } else { + o--; + operations[opt_idx(o)] = 3; + i--; + j--; // do nothing } + } + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len) { + operations[opt_idx(k)] = operations[opt_idx(k + o)]; + } else { + operations[opt_idx(k)] = 0; // padding + } + } } template <typename scalar_t> __global__ void faster_levenshtein_distance_kernel( - const scalar_t* __restrict__ source, - const scalar_t* __restrict__ target, - const int* __restrict__ source_length, - const int* __restrict__ target_length, - const size_t source_size, - const size_t target_size, - int* __restrict__ 
operations) { - - extern __shared__ short errors[]; - auto errors_curr = errors; - - const int index = blockIdx.x; - const int offset = index * (source_size + target_size); - const int t = target_size + 1; - - auto err_idx = [t](int i, int j) { return i * t + j; }; - auto opt_idx = [offset](int k) { return offset + k; }; - - const int hyp_len = source_length[index]; - const int ref_len = target_length[index]; - const scalar_t* hyp_begin = source + index * source_size; - const scalar_t* ref_begin = target + index * target_size; - - // dynamic programming - for (int i = 0; i <= hyp_len; i++){ - errors_curr[err_idx(i, 0)] = i; - } - for (int j = 0; j <= ref_len; j++){ - errors_curr[err_idx(0, j)] = j; - } - for (int i = 1; i <= hyp_len; i++){ - for (int j = 1; j <= ref_len; j++){ - errors_curr[err_idx(i, j)] = min( - min( - errors_curr[err_idx(i-1, j)], - errors_curr[err_idx(i, j-1)] - ) + 1, - errors_curr[err_idx(i-1, j-1)] + 2 * ( - *(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1 - ) - ); - } + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations) { + extern __shared__ short errors[]; + auto errors_curr = errors; + + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int t = target_size + 1; + + auto err_idx = [t](int i, int j) { return i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++) { + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++) { + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++) { + for 
(int j = 1; j <= ref_len; j++) { + errors_curr[err_idx(i, j)] = min( + min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) + + 1, + errors_curr[err_idx(i - 1, j - 1)] + + 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1)); } + } - // back-tracing - int i = hyp_len; - int j = ref_len; - int o = hyp_len + ref_len; + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; - for (int k = 0; k < source_size + target_size; k++) { - operations[opt_idx(k)] = 0; - } + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } - while ((i >= 0) && (j >= 0)) { - if ((i == 0) && (j == 0)) { - break; - } - - if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) { - o--; operations[opt_idx(o)] = 1; j--; // insertion - } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) { - o--; operations[opt_idx(o)] = 2; i--; // deletion - } else { - o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing - } + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; } - // moving to the left - for (int k = 0; k < hyp_len + ref_len; k++) { - if (k + o < hyp_len + ref_len){ - operations[opt_idx(k)] = operations[opt_idx(k+o)]; - } else{ - operations[opt_idx(k)] = 0; // padding - } + if ((j > 0) && + (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 1; + j--; // insertion + } else if ( + (i > 0) && + (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 2; + i--; // deletion + } else { + o--; + operations[opt_idx(o)] = 3; + i--; + j--; // do nothing } + } + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len) { + operations[opt_idx(k)] = operations[opt_idx(k + o)]; + } else { + operations[opt_idx(k)] = 0; // padding + } + } } - torch::Tensor GenerateDeletionLabelCuda( - torch::Tensor source, - torch::Tensor 
operations) { - - const auto batch_size = source.size(0); - at::TensorOptions options(source.device()); - options = options.dtype(at::ScalarType::Int); - auto labels = torch::empty({batch_size, source.size(1)}, options); - auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); - - AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] { - generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>( - source.data_ptr<scalar_t>(), - source.size(1), - operations.size(1), - operations.data_ptr<int>(), - labels.data_ptr<int>()); - })); - - return labels; + torch::Tensor source, + torch::Tensor operations) { + const auto batch_size = source.size(0); + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto labels = torch::empty({batch_size, source.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] { + generate_deletion_label_kernel<scalar_t> + <<<batch_size, 1, 0, stream>>>( + source.data_ptr<scalar_t>(), + source.size(1), + operations.size(1), + operations.data_ptr<int>(), + labels.data_ptr<int>()); + })); + + return labels; } std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda( torch::Tensor target, torch::Tensor operations) { + const auto batch_size = target.size(0); + at::TensorOptions options(target.device()); + options = options.dtype(at::ScalarType::Int); + auto labels = torch::empty({batch_size, target.size(1)}, options); + auto masks = torch::empty({batch_size, target.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(target.device().index()); + + AT_DISPATCH_ALL_TYPES( + target.scalar_type(), "generate_insertion_labels", ([&] { + generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>( + target.data_ptr<scalar_t>(), + target.size(1), + operations.size(1), + operations.data_ptr<int>(), + 
labels.data_ptr<int>(), + masks.data_ptr<int>()); + })); -const auto batch_size = target.size(0); -at::TensorOptions options(target.device()); -options = options.dtype(at::ScalarType::Int); -auto labels = torch::empty({batch_size, target.size(1)}, options); -auto masks = torch::empty({batch_size, target.size(1)}, options); -auto stream = at::cuda::getCurrentCUDAStream(target.device().index()); - -AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] { - generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>( - target.data_ptr<scalar_t>(), - target.size(1), - operations.size(1), - operations.data_ptr<int>(), - labels.data_ptr<int>(), - masks.data_ptr<int>()); -})); - -return std::make_pair(labels, masks); + return std::make_pair(labels, masks); } - torch::Tensor LevenshteinDistanceCuda( - torch::Tensor source, - torch::Tensor target, - torch::Tensor source_length, - torch::Tensor target_length) { - - const auto batch_size = source.size(0); - const auto shared_size = (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short); - - at::TensorOptions options(source.device()); - options = options.dtype(at::ScalarType::Int); - auto operations = torch::empty({batch_size, source.size(1) + target.size(1)}, options); - auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); - - if (shared_size > 40000) { - auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options); - AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] { - levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>( - source.data_ptr<scalar_t>(), - target.data_ptr<scalar_t>(), - source_length.data_ptr<int>(), - target_length.data_ptr<int>(), - source.size(1), - target.size(1), - operations.data_ptr<int>(), - distances.data_ptr<int>()); - })); - } else { - AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] { - 
faster_levenshtein_distance_kernel<scalar_t><<<batch_size, 1, shared_size, stream>>>( - source.data_ptr<scalar_t>(), - target.data_ptr<scalar_t>(), - source_length.data_ptr<int>(), - target_length.data_ptr<int>(), - source.size(1), - target.size(1), - operations.data_ptr<int>()); + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + const auto batch_size = source.size(0); + const auto shared_size = + (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short); + + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto operations = + torch::empty({batch_size, source.size(1) + target.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + if (shared_size > 40000) { + auto distances = torch::empty( + {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options); + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] { + levenshtein_distance_kernel<scalar_t> + <<<batch_size, 1, 0, stream>>>( + source.data_ptr<scalar_t>(), + target.data_ptr<scalar_t>(), + source_length.data_ptr<int>(), + target_length.data_ptr<int>(), + source.size(1), + target.size(1), + operations.data_ptr<int>(), + distances.data_ptr<int>()); + })); + } else { + AT_DISPATCH_ALL_TYPES( + source.scalar_type(), "faster_levenshtein_distance", ([&] { + faster_levenshtein_distance_kernel<scalar_t> + <<<batch_size, 1, shared_size, stream>>>( + source.data_ptr<scalar_t>(), + target.data_ptr<scalar_t>(), + source_length.data_ptr<int>(), + target_length.data_ptr<int>(), + source.size(1), + target.size(1), + operations.data_ptr<int>()); })); - } + } - return operations; + return operations; } diff --git a/fairseq/clib/libnat_cuda/edit_dist.h b/fairseq/clib/libnat_cuda/edit_dist.h index e3506cd34d..5220c52fd8 100644 --- a/fairseq/clib/libnat_cuda/edit_dist.h +++ b/fairseq/clib/libnat_cuda/edit_dist.h @@ -11,15 +11,15 @@ #include 
<torch/extension.h> torch::Tensor LevenshteinDistanceCuda( - torch::Tensor source, - torch::Tensor target, - torch::Tensor source_length, - torch::Tensor target_length); + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length); torch::Tensor GenerateDeletionLabelCuda( - torch::Tensor source, - torch::Tensor operations); + torch::Tensor source, + torch::Tensor operations); std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda( - torch::Tensor source, - torch::Tensor operations); + torch::Tensor source, + torch::Tensor operations); diff --git a/fairseq/modules/cuda_utils.cu b/fairseq/modules/cuda_utils.cu index 516f1d9244..924f852758 100644 --- a/fairseq/modules/cuda_utils.cu +++ b/fairseq/modules/cuda_utils.cu @@ -1,20 +1,17 @@ /** * Copyright (c) Facebook, Inc. and its affiliates. - * + * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ - -template <typename U, typename V> -constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { - return (a + b - 1) / b; +template <typename U, typename V> +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + return (a + b - 1) / b; } - -template<int FS, int SB, int padding_l, typename scalar_t> -__inline__ __device__ -void zeroSharedMem(scalar_t* data) { +template <int FS, int SB, int padding_l, typename scalar_t> +__inline__ __device__ void zeroSharedMem(scalar_t* data) { /* Given an array of length FS + SB, zero out the first padding_l and last (FS - padding_l) values in the array @@ -23,13 +20,11 @@ void zeroSharedMem(scalar_t* data) { int tid = threadIdx.x; if (FS < SB) { - // zero all if we have enough threads in a block to do all of them if (tid < padding_l || tid > SB - FS + padding_l - 1) { data[tid] = scalar_t(0.0); } } else { - // otherwise zero out one block at a time const int numIterations = divUp<int, int>(FS, SB); for (int i = 0; i < 
numIterations; i++) { @@ -43,9 +38,8 @@ void zeroSharedMem(scalar_t* data) { } } -template<typename scalar_t> -__inline__ __device__ -scalar_t warpReduce(scalar_t data) { +template <typename scalar_t> +__inline__ __device__ scalar_t warpReduce(scalar_t data) { /* Reduce an array within each warp. After processing all values in warp will caontain the sum of all original values in that warp. @@ -60,9 +54,8 @@ scalar_t warpReduce(scalar_t data) { return data; } -template<typename scalar_t> -__inline__ __device__ -scalar_t blockReduce(scalar_t data) { +template <typename scalar_t> +__inline__ __device__ scalar_t blockReduce(scalar_t data) { /* Reduce an entire array on the block level. After processing, the first value in the array will contain the reduced sum. @@ -82,7 +75,7 @@ scalar_t blockReduce(scalar_t data) { if (lane == 0) { warpSum[wid] = sum; } - + __syncthreads(); scalar_t v; @@ -102,21 +95,23 @@ scalar_t blockReduce(scalar_t data) { } void checkCudaStatus(cudaError_t status, int lineNumber = -1) { - if (status != cudaSuccess) { - std::cout << cudaGetErrorString(status) - << " at line " << lineNumber << std::endl; + std::cout << cudaGetErrorString(status) << " at line " << lineNumber + << std::endl; std::cout << "Exiting" << std::endl; exit(1); } } -template<int FS, int SB, int padding_l, typename scalar_t> -__device__ -void load_input_to_shared(const scalar_t* input, // global memory - int inputOffset, int sequenceLength, - int iteration, int numIterations, - bool no_prev, scalar_t* output /* shared memory */) { +template <int FS, int SB, int padding_l, typename scalar_t> +__device__ void load_input_to_shared( + const scalar_t* input, // global memory + int inputOffset, + int sequenceLength, + int iteration, + int numIterations, + bool no_prev, + scalar_t* output /* shared memory */) { /* Load a block size of input into shared memory with right and left overhang of total size FS. 
If previously @@ -138,19 +133,20 @@ void load_input_to_shared(const scalar_t* input, // global memory // Load the left "overhang" of input if (iteration > 0) { if (padding_l < SB) { - // load all at once if (tid < padding_l) { - output[tid] = (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB]; + output[tid] = + (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB]; } } else { - // load in chunks of size SB int numIterations = divUp<int, int>(padding_l, SB); for (int i = 0; i < numIterations; i++) { int offset = i * SB; if ((tid + offset) < padding_l) { - output[tid + offset] = (no_prev) ? input[inputOffset - padding_l + tid + offset] : output[tid + offset + SB]; + output[tid + offset] = (no_prev) + ? input[inputOffset - padding_l + tid + offset] + : output[tid + offset + SB]; } } } @@ -158,22 +154,25 @@ void load_input_to_shared(const scalar_t* input, // global memory // Load the right "overhang" of input if (iteration < (numIterations - 1)) { - const int elementsLeft = sequenceLength - (iteration+1) * SB; + const int elementsLeft = sequenceLength - (iteration + 1) * SB; if ((FS - padding_l) < SB) { - // load all at once if (tid < (FS - padding_l)) { - output[padding_l + SB + tid] = (tid < elementsLeft) ? input[inputOffset + SB + tid] : scalar_t(0.0); + output[padding_l + SB + tid] = (tid < elementsLeft) + ? input[inputOffset + SB + tid] + : scalar_t(0.0); } } else { - // load in chunks of size SB int numIterations = divUp<int, int>(FS - padding_l, SB); for (int i = 0; i < numIterations; i++) { int offset = i * SB; if ((tid + offset) < (FS - padding_l)) { - output[padding_l + SB + tid + offset] = ((tid + offset) < elementsLeft) ? input[inputOffset + SB + tid + offset] : scalar_t(0.0); + output[padding_l + SB + tid + offset] = + ((tid + offset) < elementsLeft) + ? 
input[inputOffset + SB + tid + offset] + : scalar_t(0.0); } } } @@ -182,13 +181,11 @@ void load_input_to_shared(const scalar_t* input, // global memory // We should also clear out the right "overhang" if (iteration == (numIterations - 1)) { if ((FS - padding_l) < SB) { - // clear out all at once if (tid < (FS - padding_l)) { - output[padding_l + SB + tid] = scalar_t(0.0); + output[padding_l + SB + tid] = scalar_t(0.0); } } else { - // clear in chunks of size SB int numIterations = divUp<int, int>(FS - padding_l, SB); for (int i = 0; i < numIterations; i++) { @@ -199,5 +196,7 @@ void load_input_to_shared(const scalar_t* input, // global memory } } } - output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) ? input[inputOffset + tid] : scalar_t(0.0); + output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) + ? input[inputOffset + tid] + : scalar_t(0.0); } diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp index ebd4df0e96..744c363e55 100644 --- a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp @@ -8,10 +8,8 @@ #include <torch/extension.h> #include <vector> -std::vector<at::Tensor> dynamicconv_cuda_forward( - at::Tensor input, - at::Tensor filters, - int padding_l); +std::vector<at::Tensor> +dynamicconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l); std::vector<at::Tensor> dynamicconv_cuda_backward( at::Tensor gradOutput, @@ -19,21 +17,20 @@ std::vector<at::Tensor> dynamicconv_cuda_backward( at::Tensor input, at::Tensor filters); +#define CHECK_CUDA(x) \ + AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) 
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector<at::Tensor> dynamicconv_forward( - at::Tensor input, - at::Tensor filters, - int padding_l) { +std::vector<at::Tensor> +dynamicconv_forward(at::Tensor input, at::Tensor filters, int padding_l) { + CHECK_INPUT(input); + CHECK_INPUT(filters); - CHECK_INPUT(input); - CHECK_INPUT(filters); - - return dynamicconv_cuda_forward(input, filters, - padding_l); + return dynamicconv_cuda_forward(input, filters, padding_l); } std::vector<at::Tensor> dynamicconv_backward( @@ -41,16 +38,14 @@ std::vector<at::Tensor> dynamicconv_backward( int padding_l, at::Tensor input, at::Tensor filters) { + CHECK_INPUT(gradOutput); + CHECK_INPUT(input); + CHECK_INPUT(filters); - CHECK_INPUT(gradOutput); - CHECK_INPUT(input); - CHECK_INPUT(filters); - - return dynamicconv_cuda_backward(gradOutput, padding_l, - input, filters); + return dynamicconv_cuda_backward(gradOutput, padding_l, input, filters); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); - m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); + m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); + m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); } diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh index 2196259433..44baf21bdd 100644 --- a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh @@ -1,6 +1,6 @@ /** * Copyright (c) Facebook, Inc. and its affiliates. - * + * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. 
*/ @@ -19,26 +19,25 @@ #include <utility> #include <vector> -#include <stdlib.h> #include <assert.h> #include <math.h> +#include <stdlib.h> #define SHFL_MASK 0xffffffff -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void dynamicconv_forward_kernel(const scalar_t* input, - const scalar_t* weight, - int minibatch, - int sequenceLength, - int numFeatures, - int numFiltersInBlock, - int numHeads, - scalar_t* output); +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void dynamicconv_forward_kernel( + const scalar_t* input, + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* output); -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void dynamicconv_backward_kernel( +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void dynamicconv_backward_kernel( const scalar_t* gradOutput, // B * C * T const scalar_t* input, // B * C * T const scalar_t* weight, diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu index 300d35b647..4630f1e982 100644 --- a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu @@ -1,26 +1,26 @@ /** * Copyright (c) Facebook, Inc. and its affiliates. - * + * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. 
*/ +#include "../cuda_utils.cu" #include "dynamicconv_cuda.cuh" -#include "dynamicconv_cuda_forward.cu" #include "dynamicconv_cuda_backward.cu" -#include "../cuda_utils.cu" +#include "dynamicconv_cuda_forward.cu" // FS is filter size and kernels are specialized for filter sizes -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void dynamicconv_forward_kernel(const scalar_t* input, - const scalar_t* weight, - int minibatch, - int sequenceLength, - int numFeatures, - int numFiltersInBlock, - int numHeads, - scalar_t* output) { +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void dynamicconv_forward_kernel( + const scalar_t* input, + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* output) { assert(blockDim.x == SB); const int tid = threadIdx.x; @@ -28,8 +28,8 @@ void dynamicconv_forward_kernel(const scalar_t* input, const int featureIdx = blockIdx.y; const int head = featureIdx / numFiltersInBlock; - const int IOOffset = batchIdx * numFeatures * sequenceLength - + featureIdx * sequenceLength; + const int IOOffset = + batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength; const scalar_t* inputFeature = &input[IOOffset]; scalar_t* outputFeature = &output[IOOffset]; @@ -43,36 +43,36 @@ void dynamicconv_forward_kernel(const scalar_t* input, for (int i = 0; i < numIterations; ++i) { __syncthreads(); const int inputOffset = i * SB; - load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, - sequenceLength, i, - numIterations, false, tempInput); + load_input_to_shared<FS, SB, padding_l>( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempInput); __syncthreads(); if (inputOffset + tid < sequenceLength) { - - #pragma unroll +#pragma unroll for (int k = 0; k < FS; ++k) { - const int filterOffset = batchIdx * numHeads * FS * sequenceLength - + head * FS * sequenceLength - + k * 
sequenceLength - + i * SB + tid; + const int filterOffset = batchIdx * numHeads * FS * sequenceLength + + head * FS * sequenceLength + k * sequenceLength + i * SB + tid; filter[k] = weight[filterOffset]; } scalar_t out = scalar_t(0.0); - #pragma unroll +#pragma unroll for (int k = 0; k < FS; ++k) { out += filter[k] * tempInput[tid + k]; } outputFeature[inputOffset + tid] = out; - } } } -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void dynamicconv_backward_kernel( +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void dynamicconv_backward_kernel( const scalar_t* gradOutput, // B * C * T const scalar_t* input, // B * C * T const scalar_t* weight, @@ -111,52 +111,60 @@ void dynamicconv_backward_kernel( int idxOffset = inputOffset + tid + k - padding; if (idxOffset >= 0 && idxOffset < sequenceLength) { - int bfilterOffset = batchIdx * numHeads * FS * sequenceLength - + headIdx * FS * sequenceLength - + (FS - k - 1) * sequenceLength - + idxOffset; + int bfilterOffset = batchIdx * numHeads * FS * sequenceLength + + headIdx * FS * sequenceLength + (FS - k - 1) * sequenceLength + + idxOffset; bfilter[k] = weight[bfilterOffset]; } else { bfilter[k] = scalar_t(0.0); } } - // iterate over filter block for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) { __syncthreads(); // load input and output gradient for this channel and chunk - const int IOOffset = batchIdx * numFeatures * sequenceLength - + (headIdx * numFiltersInBlock + featureIdx) * sequenceLength; + const int IOOffset = batchIdx * numFeatures * sequenceLength + + (headIdx * numFiltersInBlock + featureIdx) * sequenceLength; const scalar_t* inputFeature = &input[IOOffset]; const scalar_t* gradOutputFeature = &gradOutput[IOOffset]; scalar_t* gradInputFeature = &gradInput[IOOffset]; - load_input_to_shared<FS, SB, padding>(gradOutputFeature, inputOffset, - sequenceLength, chunkIdx, - numChunks, true, tempGradOutput); - load_input_to_shared<FS, SB, 
padding_l>(inputFeature, inputOffset, - sequenceLength, chunkIdx, - numChunks, true, tempInput); + load_input_to_shared<FS, SB, padding>( + gradOutputFeature, + inputOffset, + sequenceLength, + chunkIdx, + numChunks, + true, + tempGradOutput); + load_input_to_shared<FS, SB, padding_l>( + inputFeature, + inputOffset, + sequenceLength, + chunkIdx, + numChunks, + true, + tempInput); __syncthreads(); - + // sum input and weight gradients scalar_t out = scalar_t(0.0); - #pragma unroll +#pragma unroll for (int k = 0; k < FS; ++k) { tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding]; out += bfilter[k] * tempGradOutput[tid + k]; } - + if (inputOffset + tid < sequenceLength) { gradInputFeature[inputOffset + tid] = out; } } - const int gradOffset = batchIdx * numHeads * FS * sequenceLength - + headIdx * FS * sequenceLength; - scalar_t *gradWeightFeature = &gradWeight[gradOffset]; + const int gradOffset = + batchIdx * numHeads * FS * sequenceLength + headIdx * FS * sequenceLength; + scalar_t* gradWeightFeature = &gradWeight[gradOffset]; // write weight gradient if (inputOffset + tid < sequenceLength) { diff --git a/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp index 8a6af4285d..d7e57c8590 100644 --- a/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp +++ b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp @@ -1,10 +1,8 @@ #include <torch/torch.h> #include <vector> -std::vector<float*> dynamicconv_cpu_forward( - float* input, - float* filters, - int padding_l); +std::vector<float*> +dynamicconv_cpu_forward(float* input, float* filters, int padding_l); std::vector<float*> dynamicconv_cpu_backward( float* gradOutput, @@ -12,12 +10,9 @@ std::vector<float*> dynamicconv_cpu_backward( float* input, float* filters); -std::vector<float*> dynamicconv_forward( - float* input, - float* filters, - int padding_l) { - - return dynamicconv_cpu_forward(input, filters, padding_l); +std::vector<float*> 
+dynamicconv_forward(float* input, float* filters, int padding_l) { + return dynamicconv_cpu_forward(input, filters, padding_l); } std::vector<float*> dynamicconv_backward( @@ -25,11 +20,10 @@ std::vector<float*> dynamicconv_backward( int padding_l, float* input, float* filters) { - - return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); + return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); - m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); + m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); + m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); } diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cpp b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp index 4bf6b5ad36..ece47a8d90 100644 --- a/fairseq/modules/lightconv_layer/lightconv_cuda.cpp +++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp @@ -8,10 +8,8 @@ #include <torch/extension.h> #include <vector> -std::vector<at::Tensor> lightconv_cuda_forward( - at::Tensor input, - at::Tensor filters, - int padding_l); +std::vector<at::Tensor> +lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l); std::vector<at::Tensor> lightconv_cuda_backward( at::Tensor gradOutput, @@ -19,20 +17,20 @@ std::vector<at::Tensor> lightconv_cuda_backward( at::Tensor input, at::Tensor filters); +#define CHECK_CUDA(x) \ + AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector<at::Tensor> 
lightconv_forward( - at::Tensor input, - at::Tensor filters, - int padding_l) { +std::vector<at::Tensor> +lightconv_forward(at::Tensor input, at::Tensor filters, int padding_l) { + CHECK_INPUT(input); + CHECK_INPUT(filters); - CHECK_INPUT(input); - CHECK_INPUT(filters); - - return lightconv_cuda_forward(input, filters, padding_l); + return lightconv_cuda_forward(input, filters, padding_l); } std::vector<at::Tensor> lightconv_backward( @@ -40,15 +38,14 @@ std::vector<at::Tensor> lightconv_backward( int padding_l, at::Tensor input, at::Tensor filters) { + CHECK_INPUT(gradOutput); + CHECK_INPUT(input); + CHECK_INPUT(filters); - CHECK_INPUT(gradOutput); - CHECK_INPUT(input); - CHECK_INPUT(filters); - - return lightconv_cuda_backward(gradOutput, padding_l, input, filters); + return lightconv_cuda_backward(gradOutput, padding_l, input, filters); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); - m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); + m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); + m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); } diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cuh b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh index 3cae57b68f..610ab399e9 100644 --- a/fairseq/modules/lightconv_layer/lightconv_cuda.cuh +++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh @@ -1,6 +1,6 @@ /** * Copyright (c) Facebook, Inc. and its affiliates. - * + * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. 
*/ @@ -18,23 +18,24 @@ #include <utility> #include <vector> -#include <stdlib.h> #include <assert.h> +#include <stdlib.h> #define SHFL_MASK 0xffffffff -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_forward_kernel(const scalar_t* input, - const scalar_t* filters, - int minibatch, int sequenceLength, - int numFeatures, int numFiltersInBlock, - scalar_t* output); +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_forward_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output); -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_grad_wrt_input_kernel( - const scalar_t* input, +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_grad_wrt_input_kernel( + const scalar_t* input, const scalar_t* filters, int minibatch, int sequenceLength, @@ -42,9 +43,8 @@ void lightconv_grad_wrt_input_kernel( int numFiltersInBlock, scalar_t* output); -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_firstpass_short_kernel( +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_firstpass_short_kernel( const scalar_t* input, const scalar_t* gradInput, int minibatch, @@ -54,17 +54,15 @@ void lightconv_grad_wrt_weights_firstpass_short_kernel( int numHeads, float* output); -template<int FS, int SB, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_secondpass_short_kernel( +template <int FS, int SB, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_secondpass_short_kernel( const float* input, - const int minibatch, + const int minibatch, const int numFiltersInBlock, scalar_t* output); -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_firstpass_kernel( +template 
<int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_firstpass_kernel( const scalar_t* input, const scalar_t* gradInput, int minibatch, @@ -73,11 +71,9 @@ void lightconv_grad_wrt_weights_firstpass_kernel( int numFiltersInBlock, float* output); -template<int FS, int SB, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_secondpass_kernel( +template <int FS, int SB, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_secondpass_kernel( const float* input, - const int minibatch, + const int minibatch, const int numFiltersInBlock, scalar_t* output); - diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu index 8ee83a56c8..cdf31d5d2d 100644 --- a/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu +++ b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu @@ -1,29 +1,31 @@ /** * Copyright (c) Facebook, Inc. and its affiliates. - * + * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. 
*/ +#include "../cuda_utils.cu" #include "lightconv_cuda.cuh" -#include "lightconv_cuda_forward.cu" #include "lightconv_cuda_backward.cu" -#include "../cuda_utils.cu" - -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_forward_kernel(const scalar_t* input, - const scalar_t* filters, - int minibatch, int sequenceLength, - int numFeatures, int numFiltersInBlock, - scalar_t* output) { +#include "lightconv_cuda_forward.cu" +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_forward_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output) { const int tid = threadIdx.x; const int batchIdx = blockIdx.x; const int featureIdx = blockIdx.y; const int filterIdx = featureIdx / numFiltersInBlock; - const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const int IOOffset = + numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; const scalar_t* inputFeature = &input[IOOffset]; scalar_t* outputFeature = &output[IOOffset]; const scalar_t* inputFilter = &filters[filterIdx * FS]; @@ -31,7 +33,7 @@ void lightconv_forward_kernel(const scalar_t* input, assert(blockDim.x == SB); scalar_t filter[FS]; - #pragma unroll +#pragma unroll for (int i = 0; i < FS; ++i) { filter[i] = inputFilter[i]; } @@ -45,13 +47,19 @@ void lightconv_forward_kernel(const scalar_t* input, // Read input into shared memory const int inputOffset = i * SB; - load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength, - i, numIterations, (numIterations == 1), temp); + load_input_to_shared<FS, SB, padding_l>( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + (numIterations == 1), + temp); __syncthreads(); scalar_t out = 0; - #pragma unroll +#pragma unroll for (int j = 0; j < FS; ++j) { out += filter[j] * temp[tid + j]; } @@ -66,9 +74,8 @@ 
void lightconv_forward_kernel(const scalar_t* input, } } -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_grad_wrt_input_kernel( +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_grad_wrt_input_kernel( const scalar_t* input, const scalar_t* filters, int minibatch, @@ -76,14 +83,14 @@ void lightconv_grad_wrt_input_kernel( int numFeatures, int numFiltersInBlock, scalar_t* output) { - // input grad kernel is similar to forward kernel const int tid = threadIdx.x; const int batchIdx = blockIdx.x; const int featureIdx = blockIdx.y; const int filterIdx = featureIdx / numFiltersInBlock; - const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const int IOOffset = + numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; const scalar_t* inputFeature = &input[IOOffset]; scalar_t* outputFeature = &output[IOOffset]; const scalar_t* inputFilter = &filters[filterIdx * FS]; @@ -92,8 +99,8 @@ void lightconv_grad_wrt_input_kernel( scalar_t filter[FS]; - // The only change is loading the filter in reverse - #pragma unroll +// The only change is loading the filter in reverse +#pragma unroll for (int i = 0; i < FS; ++i) { filter[i] = inputFilter[FS - i - 1]; } @@ -110,13 +117,19 @@ void lightconv_grad_wrt_input_kernel( // Read input into shared memory const int inputOffset = i * SB; - load_input_to_shared<FS, SB, padding>(inputFeature, inputOffset, sequenceLength, - i, numIterations, false, temp); + load_input_to_shared<FS, SB, padding>( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + temp); __syncthreads(); scalar_t out = 0; - #pragma unroll +#pragma unroll for (int j = 0; j < FS; ++j) { out += filter[j] * temp[tid + j]; } @@ -133,9 +146,8 @@ void lightconv_grad_wrt_input_kernel( // This is by far the most expensive kernel in terms of time taken. 
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_firstpass_short_kernel( +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_firstpass_short_kernel( const scalar_t* input, const scalar_t* gradInput, int minibatch, @@ -144,7 +156,6 @@ void lightconv_grad_wrt_weights_firstpass_short_kernel( int numFiltersInBlock, int numHeads, float* output) { - const int tid = threadIdx.x; const int batchIdx = blockIdx.x; const int filterIdx = blockIdx.y; @@ -166,52 +177,60 @@ void lightconv_grad_wrt_weights_firstpass_short_kernel( accumWeights[i] = float(0.0); } - // loop over each sequence within filterblock - for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; ++idxInFilterBlock) { - - const int featureOffset = batchIdx * numFeatures * sequenceLength + (filterIdx * numFiltersInBlock + idxInFilterBlock) * sequenceLength; + for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; + ++idxInFilterBlock) { + const int featureOffset = batchIdx * numFeatures * sequenceLength + + (filterIdx * numFiltersInBlock + idxInFilterBlock) * sequenceLength; const scalar_t* inputFeature = &input[featureOffset]; const scalar_t* gradInputFeature = &gradInput[featureOffset]; zeroSharedMem<FS, SB, padding_l>(tempInput); - zeroSharedMem<FS, SB, (FS/2)>(tempGradInput); + zeroSharedMem<FS, SB, (FS / 2)>(tempGradInput); __syncthreads(); for (int i = 0; i < numIterations; ++i) { - const int inputOffset = i * SB; - load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength, - i, numIterations, false, tempInput); - load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength, - i, numIterations, false, tempGradInput); + load_input_to_shared<FS, SB, padding_l>( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempInput); 
+ load_input_to_shared<FS, SB, (FS / 2)>( + gradInputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempGradInput); __syncthreads(); - const int gradIndex = (FS/2) + tid; + const int gradIndex = (FS / 2) + tid; scalar_t tempGrad = tempGradInput[gradIndex]; - #pragma unroll +#pragma unroll for (int j = 0; j < FS; j++) { const int inputIndex = tid + j; accumWeights[j] += tempInput[inputIndex] * tempGrad; } __syncthreads(); - } - } // Row-major sum for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { - float temp; if (tid < sequenceLength) { - temp = accumWeights[filterWeightIdx]; + temp = accumWeights[filterWeightIdx]; } else { - temp = float(0.0); + temp = float(0.0); } const int outputOffset = filterWeightIdx * minibatch + batchIdx; @@ -224,14 +243,12 @@ void lightconv_grad_wrt_weights_firstpass_short_kernel( } } -template<int FS, int SB, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_secondpass_short_kernel( +template <int FS, int SB, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_secondpass_short_kernel( const float* input, const int minibatch, const int numFiltersInBlock, scalar_t* output) { - assert(blockDim.x == SB); const int tid = threadIdx.x; @@ -239,8 +256,8 @@ void lightconv_grad_wrt_weights_secondpass_short_kernel( const int filterIdx = blockIdx.x; const int filterWeightIdx = blockIdx.y; - const int inputOffset = filterIdx * FS * minibatch + - filterWeightIdx * minibatch; + const int inputOffset = + filterIdx * FS * minibatch + filterWeightIdx * minibatch; const float* tempInput = &input[inputOffset]; // read into shared memory for reduction @@ -261,9 +278,8 @@ void lightconv_grad_wrt_weights_secondpass_short_kernel( // This is by far the most expensive kernel in terms of time taken. 
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 -template<int FS, int SB, int padding_l, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_firstpass_kernel( +template <int FS, int SB, int padding_l, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_firstpass_kernel( const scalar_t* input, const scalar_t* gradInput, int minibatch, @@ -271,7 +287,6 @@ void lightconv_grad_wrt_weights_firstpass_kernel( int numFeatures, int numFiltersInBlock, float* output) { - assert(blockDim.x == SB); const int tid = threadIdx.x; @@ -287,7 +302,7 @@ void lightconv_grad_wrt_weights_firstpass_kernel( __shared__ scalar_t tempInput[SB + FS]; __shared__ scalar_t tempGradInput[SB + FS]; zeroSharedMem<FS, SB, padding_l>(tempInput); - zeroSharedMem<FS, SB, (FS/2)>(tempGradInput); + zeroSharedMem<FS, SB, (FS / 2)>(tempGradInput); __syncthreads(); float accumWeights[FS]; @@ -296,23 +311,37 @@ void lightconv_grad_wrt_weights_firstpass_kernel( accumWeights[i] = float(0.0); } - const int IOOffset = batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength; + const int IOOffset = + batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength; const scalar_t* inputFeature = &input[IOOffset]; const scalar_t* gradInputFeature = &gradInput[IOOffset]; - float* tempOutputGradWeight = &output[filterIdx * FS * minibatch * numFiltersInBlock]; + float* tempOutputGradWeight = + &output[filterIdx * FS * minibatch * numFiltersInBlock]; for (int i = 0; i < numIterations; ++i) { const int inputOffset = i * SB; - load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength, - i, numIterations, false, tempInput); - load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength, - i, numIterations, false, tempGradInput); + load_input_to_shared<FS, SB, padding_l>( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempInput); + load_input_to_shared<FS, SB, (FS / 2)>( + 
gradInputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempGradInput); __syncthreads(); - #pragma unroll +#pragma unroll for (int j = 0; j < FS; ++j) { - accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS/2)]; + accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS / 2)]; } __syncthreads(); @@ -320,7 +349,6 @@ void lightconv_grad_wrt_weights_firstpass_kernel( // Row-major sum for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { - // Write to shared memory before reduction if (tid < sequenceLength) { temp = accumWeights[filterWeightIdx]; @@ -331,8 +359,7 @@ void lightconv_grad_wrt_weights_firstpass_kernel( temp = blockReduce(temp); const int outputOffset = filterWeightIdx * minibatch * numFiltersInBlock + - batchIdx * numFiltersInBlock + - idxInFilterBlock; + batchIdx * numFiltersInBlock + idxInFilterBlock; if (tid == 0) { tempOutputGradWeight[outputOffset] = temp; @@ -340,14 +367,12 @@ void lightconv_grad_wrt_weights_firstpass_kernel( } } -template<int FS, int SB, typename scalar_t> -__global__ -void lightconv_grad_wrt_weights_secondpass_kernel( +template <int FS, int SB, typename scalar_t> +__global__ void lightconv_grad_wrt_weights_secondpass_kernel( const float* input, const int minibatch, const int numFiltersInBlock, scalar_t* output) { - assert(blockDim.x == SB); const int tid = threadIdx.x; @@ -356,7 +381,7 @@ void lightconv_grad_wrt_weights_secondpass_kernel( const int filterWeightIdx = blockIdx.y; const int inputOffset = filterIdx * FS * minibatch * numFiltersInBlock + - filterWeightIdx * minibatch * numFiltersInBlock; + filterWeightIdx * minibatch * numFiltersInBlock; const float* tempInput = &input[inputOffset]; int readIndex = tid; From 53802e781291b63e656c89818c38bfc49ff0f108 Mon Sep 17 00:00:00 2001 From: Omry Yadan <omry@fb.com> Date: Mon, 26 Jul 2021 16:35:40 -0700 Subject: [PATCH 664/707] Compatibility fix with Hydra 1.1 (#3722) Summary: One of the changes in Hydra 1.1 is that 
the default composition order is changing. This is documented [here](https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order). In Hydra 1.1, a config overrides values introduced by the defaults list, while in Hydra 1.0 the defaults list overrode the values in the config. fairseq currently depends on the previous behavior: the class `FairseqConfig` defines config values and expects them to be overridden by the defaults list. This results in a different config being created when running `fairseq_cli/hydra_train.py` with Hydra 1.0 and with 1.1. Hydra 1.1 introduced the `_self_` keyword in the defaults list to control the composition order. In order to achieve the behavior of Hydra 1.0, `_self_` should be added as the first item in the defaults list. To allow for a smoother migration, Hydra 1.0 ignores `_self_` starting from 1.0.7 (previous versions will issue an error). This diff adds `_self_` as the first item in the defaults list of the fairseq config, and introduces a dependency on a Hydra 1.0 version that is equal to or newer than 1.0.7. ### Testing: I ensured that the following yield the same composed config: Default config with Hydra 1.0.6, 1.0.7 and 1.1.0 `examples/wav2vec/config/finetuning/base_10h.yaml` with Hydra 1.0.6, 1.0.7 and 1.1.0. This can be achieved by outputting the generated config using `--cfg job` and comparing the outputs. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3722 Reviewed By: dianaml0 Differential Revision: D29917677 Pulled By: jieru-hu fbshipit-source-id: 7e645b83cccb03fc80a6702e302c4643d2b14a78 --- fairseq/config/config.yaml | 1 + fairseq/optim/fp16_optimizer.py | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fairseq/config/config.yaml b/fairseq/config/config.yaml index e20d914b9b..2ed7168cb7 100644 --- a/fairseq/config/config.yaml +++ b/fairseq/config/config.yaml @@ -5,6 +5,7 @@ hydra: dir: . 
defaults: + - _self_ - task: null - model: null - criterion: cross_entropy diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 370a910102..b84236e685 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -479,7 +479,7 @@ def __init__( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) - super().__init__(cfg.optimizer) + super().__init__(getattr(cfg, "optimizer", None)) self.wrapped_optimizer = optimizer if getattr(cfg.common, "fp16_scale_window", None) is None: diff --git a/setup.py b/setup.py index 51e555229c..7a19b73c9e 100644 --- a/setup.py +++ b/setup.py @@ -201,7 +201,7 @@ def do_setup(package_data): "cffi", "cython", 'dataclasses; python_version<"3.7"', - "hydra-core<1.1", + "hydra-core>=1.0.7,<1.1", "omegaconf<2.1", 'numpy<1.20.0; python_version<"3.7"', 'numpy; python_version>="3.7"', From 0769cfe2e9ecd8c2dd15cb2491474ef0b4b3d0e2 Mon Sep 17 00:00:00 2001 From: Vaibhav Singh <sivaibhav@google.com> Date: Tue, 27 Jul 2021 12:59:08 -0700 Subject: [PATCH 665/707] Fixed the reference to mask_channel_prob in task cfg (#3742) Summary: Updated the example config file for TPU to include mask parameters. Noted the current CLI bug in the README ## What does this PR do? Fixes # (3741) B. 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3742 Reviewed By: arbabu123 Differential Revision: D29938257 Pulled By: alexeib fbshipit-source-id: 6ab5cd2974949806621fb37cb13d918bea733a73 --- examples/wav2vec/README.md | 5 +++-- .../config/pretraining/wav2vec2_large_librivox_tpu.yaml | 5 +++++ fairseq/tasks/audio_pretraining.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index c543b6b97b..badfda8979 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -196,7 +196,7 @@ target_transcription = "A MAN SAID TO THE UNIVERSE I EXIST" # encode labels with processor.as_target_processor(): - labels = processor(target_transcription, return_tensors="pt").input_ids + labels = processor(target_transcription, return_tensors="pt").input_ids # compute loss by passing labels loss = model(input_values, labels=labels).loss @@ -263,6 +263,7 @@ $ OMP_NUM_THREADS=1 fairseq-hydra-train \ ``` #### Using command line arguments on a v3-8: +Note: Commandline arguments way of execution has a [known-problem](https://github.com/pytorch/fairseq/issues/3741) currently. ``` $ OMP_NUM_THREADS=1 python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ @@ -285,7 +286,7 @@ $ OMP_NUM_THREADS=1 fairseq-hydra-train \ ``` #### Using command line arguments on a pod slice (v3-N with N > 8): - +Note: Commandline arguments way of execution has a [known-problem](https://github.com/pytorch/fairseq/issues/3741) currently. 
``` $ python -m torch_xla.distributed.xla_dist \ diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml index 2036e23c6b..ee55bdab72 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml @@ -20,6 +20,11 @@ task: num_batch_buckets: 3 precompute_mask_indices: true enable_padding: true + inferred_w2v_config: + mask_prob: 0.65 + mask_selection: 'static' + mask_other: 0 + mask_channel_prob: 0.1 dataset: num_workers: 6 diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 059e2d70c8..b7d0f3da57 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -242,7 +242,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): **self._get_mask_precompute_kwargs(task_cfg), ) - if self.cfg.tpu and task_cfg["mask_channel_prob"] == 0.0: + if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0: logger.info( "Pretraining on TPUs may suffer convergence " "issues when training with `mask_channel_prob` value of " From 7ca95a66d64411deea49cb8710195ed2e0699f0a Mon Sep 17 00:00:00 2001 From: Yun Tang <yuntang@fb.com> Date: Tue, 27 Jul 2021 13:23:07 -0700 Subject: [PATCH 666/707] Add speech/text joint training for speech to text task (step 1) Summary: 1. adding feature to generate raw audio with target sampling rate 2. 
fix bugs a) empty sample in transform_eos_lang_pair_dataset.py b) S2T decoding with language model Reviewed By: kahne Differential Revision: D29699692 fbshipit-source-id: cc4b76618ef3b43dbba53a422f24597b9866d17f --- fairseq/data/audio/audio_utils.py | 82 +++-- fairseq/data/audio/speech_to_text_dataset.py | 346 +++++++++++------- .../data/transform_eos_lang_pair_dataset.py | 2 + .../models/speech_to_text/s2t_transformer.py | 16 +- fairseq/tasks/speech_to_text.py | 23 +- 5 files changed, 290 insertions(+), 179 deletions(-) diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index f51cb0cddc..7c2638dc0c 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -9,31 +9,46 @@ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"} -def _convert_to_mono( - waveform: torch.FloatTensor, sample_rate: int -) -> torch.FloatTensor: - if waveform.shape[0] > 1: - try: - import torchaudio.sox_effects as ta_sox - except ImportError: - raise ImportError( - "Please install torchaudio to convert multi-channel audios" - ) - effects = [['channels', '1']] - return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0] +def update_sample_rate( + waveform: np.ndarray, + sample_rate: int, + tgt_sample_rate: int, +) -> np.ndarray: + if tgt_sample_rate > 0 and tgt_sample_rate != sample_rate: + _waveform = torch.from_numpy(waveform) + effects = [["rate", f"{tgt_sample_rate}"]] + return _sox_convert(_waveform, sample_rate, effects).numpy() return waveform +def _sox_convert( + waveform: torch.FloatTensor, + sample_rate: int, + effects: List[List[str]], +) -> torch.FloatTensor: + try: + import torchaudio.sox_effects as ta_sox + except ImportError: + raise ImportError("Please install torchaudio to convert audios") + return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0] + + def convert_to_mono(waveform: np.ndarray, sample_rate: int) -> np.ndarray: if waveform.shape[0] > 1: _waveform = 
torch.from_numpy(waveform) - return _convert_to_mono(_waveform, sample_rate).numpy() + effects = [["channels", "1"]] + return _sox_convert(_waveform, sample_rate, effects).numpy() return waveform def get_waveform( - path_or_fp: Union[str, BinaryIO], normalization=True, mono=True, - frames=-1, start=0, always_2d=True + path_or_fp: Union[str, BinaryIO], + normalization=True, + mono=True, + frames=-1, + start=0, + always_2d=True, + output_sample_rate=-1, ) -> Tuple[np.ndarray, int]: """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. @@ -44,6 +59,7 @@ def get_waveform( frames (int): the number of frames to read. (-1 for reading all) start (int): Where to start reading. A negative value counts from the end. always_2d (bool): always return 2D array even for mono-channel audios + output_sample_rate (int): output sample rate, -1 using default Returns: waveform (numpy.ndarray): 1D or 2D waveform (channels x length) sample_rate (float): sample rate @@ -56,9 +72,7 @@ def get_waveform( try: import soundfile as sf except ImportError: - raise ImportError( - "Please install soundfile to load WAV/FLAC/OGG Vorbis audios" - ) + raise ImportError("Please install soundfile to load WAV/FLAC/OGG Vorbis audios") waveform, sample_rate = sf.read( path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start @@ -66,6 +80,9 @@ def get_waveform( waveform = waveform.T # T x C -> C x T if mono and waveform.shape[0] > 1: waveform = convert_to_mono(waveform, sample_rate) + if output_sample_rate > 0: + waveform = update_sample_rate(waveform, sample_rate, output_sample_rate) + sample_rate = output_sample_rate if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers if not always_2d: @@ -74,12 +91,12 @@ def get_waveform( def _get_kaldi_fbank( - waveform: np.ndarray, sample_rate: int, n_bins=80 + waveform: np.ndarray, sample_rate: int, n_bins=80 ) -> Optional[np.ndarray]: """Get mel-filter bank features via PyKaldi.""" try: - from 
kaldi.feat.mel import MelBanksOptions from kaldi.feat.fbank import FbankOptions, Fbank + from kaldi.feat.mel import MelBanksOptions from kaldi.feat.window import FrameExtractionOptions from kaldi.matrix import Vector @@ -98,11 +115,12 @@ def _get_kaldi_fbank( def _get_torchaudio_fbank( - waveform: np.ndarray, sample_rate, n_bins=80 + waveform: np.ndarray, sample_rate, n_bins=80 ) -> Optional[np.ndarray]: """Get mel-filter bank features via TorchAudio.""" try: import torchaudio.compliance.kaldi as ta_kaldi + waveform = torch.from_numpy(waveform) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate @@ -136,9 +154,9 @@ def is_npy_data(data: bytes) -> bool: def is_sf_audio_data(data: bytes) -> bool: - is_wav = (data[0] == 82 and data[1] == 73 and data[2] == 70) - is_flac = (data[0] == 102 and data[1] == 76 and data[2] == 97) - is_ogg = (data[0] == 79 and data[1] == 103 and data[2] == 103) + is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70 + is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97 + is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103 return is_wav or is_flac or is_ogg @@ -151,16 +169,16 @@ def read_from_stored_zip(zip_path: str, offset: int, file_size: int) -> bytes: def parse_path(path: str) -> Tuple[str, List[int]]: """Parse data path which is either a path to - 1. a .npy/.wav/.flac/.ogg file - 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]" + 1. a .npy/.wav/.flac/.ogg file + 2. 
a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]" - Args: - path (str): the data path to parse + Args: + path (str): the data path to parse - Returns: - file_path (str): the file path - slice_ptr (list of int): empty in case 1; - byte offset and length for the slice in case 2 + Returns: + file_path (str): the file path + slice_ptr (list of int): empty in case 1; + byte offset and length for the slice in case 2 """ if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index d4b5668d8f..ba6c28632e 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -6,9 +6,10 @@ import csv import io import logging -import os.path as op import re -from typing import Dict, List, Optional, Tuple +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, NamedTuple import numpy as np import torch @@ -20,8 +21,13 @@ data_utils as fairseq_data_utils, ) from fairseq.data.audio.audio_utils import ( - get_fbank, get_waveform, read_from_stored_zip, is_npy_data, - is_sf_audio_data, parse_path, FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS + get_fbank, + get_waveform, + read_from_stored_zip, + is_npy_data, + is_sf_audio_data, + parse_path, + FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS, ) from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform @@ -32,20 +38,22 @@ class S2TDataConfig(object): """Wrapper class for data config YAML""" - def __init__(self, yaml_path): + def __init__(self, yaml_path: Path): try: import yaml except ImportError: - print("Please install PyYAML to load YAML files for " "S2T data config") + print("Please install PyYAML to load YAML files for S2T data config") self.config = {} - if op.isfile(yaml_path): + if yaml_path.is_file(): try: with open(yaml_path) as f: self.config = yaml.load(f, Loader=yaml.FullLoader) except Exception as e: - 
raise Exception(f"Failed to load config from {yaml_path}: {e}") + raise Exception( + f"Failed to load config from {yaml_path.as_posix()}: {e}" + ) else: - raise FileNotFoundError(f"{yaml_path} not found") + raise FileNotFoundError(f"{yaml_path.as_posix()} not found") @property def vocab_filename(self): @@ -102,6 +110,12 @@ def use_audio_input(self): raw audio as inputs.""" return self.config.get("use_audio_input", False) + @property + def use_sample_rate(self): + """Needed by the dataset loader to see if the model requires + raw audio with specific sample rate as inputs.""" + return self.config.get("use_sample_rate", 16000) + @property def audio_root(self): """Audio paths in the manifest TSV can be relative and this provides @@ -124,14 +138,18 @@ def get_feature_transforms(self, split, is_train): def get_features_from_npy_or_audio(path): - ext = op.splitext(op.basename(path))[1] + ext = Path(path).suffix if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: raise ValueError(f'Unsupported file format for "{path}"') return np.load(path) if ext == ".npy" else get_fbank(path) def get_features_or_waveform_from_stored_zip( - path, byte_offset, byte_size, need_waveform=False + path, + byte_offset, + byte_size, + need_waveform=False, + use_sample_rate=-1, ): assert path.endswith(".zip") data = read_from_stored_zip(path, byte_offset, byte_size) @@ -139,14 +157,17 @@ def get_features_or_waveform_from_stored_zip( if is_npy_data(data): features_or_waveform = np.load(f) elif is_sf_audio_data(data): - features_or_waveform = \ - get_waveform(f, always_2d=False)[0] if need_waveform else get_fbank(f) + features_or_waveform = ( + get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0] + if need_waveform + else get_fbank(f) + ) else: raise ValueError(f'Unknown file format for "{path}"') return features_or_waveform -def get_features_or_waveform(path: str, need_waveform=False): +def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=-1): """Get speech 
features from .npy file or waveform from .wav/.flac file. The file may be inside an uncompressed ZIP file and is accessed via byte offset and length. @@ -155,6 +176,7 @@ def get_features_or_waveform(path: str, need_waveform=False): path (str): File path in the format of "<.npy/.wav/.flac path>" or "<zip path>:<byte offset>:<byte length>". need_waveform (bool): return waveform instead of features. + use_sample_rate (int): change sample rate for the input wave file Returns: features_or_waveform (numpy.ndarray): speech features or waveform. @@ -162,11 +184,17 @@ def get_features_or_waveform(path: str, need_waveform=False): _path, slice_ptr = parse_path(path) if len(slice_ptr) == 0: if need_waveform: - return get_waveform(_path, always_2d=False) + return get_waveform( + _path, always_2d=False, output_sample_rate=use_sample_rate + )[0] return get_features_from_npy_or_audio(_path) elif len(slice_ptr) == 2: features_or_waveform = get_features_or_waveform_from_stored_zip( - _path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform + _path, + slice_ptr[0], + slice_ptr[1], + need_waveform=need_waveform, + use_sample_rate=use_sample_rate, ) else: raise ValueError(f"Invalid path: {path}") @@ -195,6 +223,12 @@ def _collate_frames( return out +class SpeechToTextDatasetItem(NamedTuple): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + + class SpeechToTextDataset(FairseqDataset): LANG_TAG_TEMPLATE = "<lang:{}>" @@ -202,7 +236,7 @@ def __init__( self, split: str, is_train_split: bool, - data_cfg: S2TDataConfig, + cfg: S2TDataConfig, audio_paths: List[str], n_frames: List[int], src_texts: Optional[List[str]] = None, @@ -216,7 +250,7 @@ def __init__( bpe_tokenizer=None, ): self.split, self.is_train_split = split, is_train_split - self.data_cfg = data_cfg + self.cfg = cfg self.audio_paths, self.n_frames = audio_paths, n_frames self.n_samples = len(audio_paths) assert len(n_frames) == self.n_samples > 0 @@ -234,22 +268,42 @@ def __init__( self.tgt_dict 
= tgt_dict self.check_tgt_lang_tag() self.ids = ids - self.shuffle = data_cfg.shuffle if is_train_split else False + self.shuffle = cfg.shuffle if is_train_split else False self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict( - self.data_cfg.get_feature_transforms(split, is_train_split) + self.cfg.get_feature_transforms(split, is_train_split) ) self.pre_tokenizer = pre_tokenizer self.bpe_tokenizer = bpe_tokenizer + self.tgt_lens = self.get_tgt_lens_and_check_oov() + logger.info(self.__repr__()) + def get_tgt_lens_and_check_oov(self): + if self.tgt_texts is None: + return [0 for _ in range(self.n_samples)] + tgt_lens = [] + n_tokens, n_oov_tokens = 0, 0 + for i in range(self.n_samples): + tokenized = self.get_tokenized_tgt_text(i).split(" ") + oov_tokens = [ + t + for t in tokenized + if self.tgt_dict.index(t) == self.tgt_dict.unk_index + ] + n_tokens += len(tokenized) + n_oov_tokens += len(oov_tokens) + tgt_lens.append(len(tokenized)) + logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV") + return tgt_lens + def __repr__(self): return ( self.__class__.__name__ + f'(split="{self.split}", n_samples={self.n_samples}, ' - f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, " + f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, " f"shuffle={self.shuffle}, transforms={self.feature_transforms})" ) @@ -259,55 +313,65 @@ def is_lang_tag(cls, token): return re.match(pattern, token) def check_tgt_lang_tag(self): - if self.data_cfg.prepend_tgt_lang_tag: + if self.cfg.prepend_tgt_lang_tag: assert self.tgt_langs is not None and self.tgt_dict is not None tgt_lang_tags = [ self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs) ] assert all(t in self.tgt_dict for t in tgt_lang_tags) - def tokenize_text(self, text: str): - if self.pre_tokenizer is not None: - text = self.pre_tokenizer.encode(text) - if self.bpe_tokenizer is not None: - text = self.bpe_tokenizer.encode(text) + @classmethod + def tokenize(cls, 
tokenizer, text: str): + return text if tokenizer is None else tokenizer.encode(text) + + def get_tokenized_tgt_text(self, index: int): + text = self.tokenize(self.pre_tokenizer, self.tgt_texts[index]) + text = self.tokenize(self.bpe_tokenizer, text) return text - def __getitem__( - self, index: int - ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]: + @classmethod + def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): + lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang)) + assert lang_tag_idx != dictionary.unk() + return lang_tag_idx + + def __getitem__(self, index: int) -> SpeechToTextDatasetItem: source = get_features_or_waveform( - self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input + self.audio_paths[index], + need_waveform=self.cfg.use_audio_input, + use_sample_rate=self.cfg.use_sample_rate, ) if self.feature_transforms is not None: - assert not self.data_cfg.use_audio_input + assert not self.cfg.use_audio_input source = self.feature_transforms(source) source = torch.from_numpy(source).float() target = None if self.tgt_texts is not None: - tokenized = self.tokenize_text(self.tgt_texts[index]) + tokenized = self.get_tokenized_tgt_text(index) target = self.tgt_dict.encode_line( tokenized, add_if_not_exist=False, append_eos=True ).long() - if self.data_cfg.prepend_tgt_lang_tag: - lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index]) - lang_tag_idx = self.tgt_dict.index(lang_tag) + if self.cfg.prepend_tgt_lang_tag: + lang_tag_idx = self.get_lang_tag_idx( + self.tgt_langs[index], self.tgt_dict + ) target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0) - return index, source, target + + return SpeechToTextDatasetItem(index=index, source=source, target=target) def __len__(self): return self.n_samples - def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dict: + def collater( + self, samples: List[SpeechToTextDatasetItem], return_order: bool = False + ) -> Dict: if 
len(samples) == 0: return {} - indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long) - frames = _collate_frames( - [s for _, s, _ in samples], self.data_cfg.use_audio_input - ) + indices = torch.tensor([x.index for x in samples], dtype=torch.long) + frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input) # sort samples by descending number of frames - n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long) + n_frames = torch.tensor([x.source.size()[0] for x in samples], dtype=torch.long) n_frames, order = n_frames.sort(descending=True) indices = indices.index_select(0, order) frames = frames.index_select(0, order) @@ -317,7 +381,7 @@ def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dic ntokens = None if self.tgt_texts is not None: target = fairseq_data_utils.collate_tokens( - [t for _, _, t in samples], + [x.target for x in samples], self.tgt_dict.pad(), self.tgt_dict.eos(), left_pad=False, @@ -325,41 +389,40 @@ def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dic ) target = target.index_select(0, order) target_lengths = torch.tensor( - [t.size(0) for _, _, t in samples], dtype=torch.long + [x.target.size()[0] for x in samples], dtype=torch.long ).index_select(0, order) prev_output_tokens = fairseq_data_utils.collate_tokens( - [t for _, _, t in samples], + [x.target for x in samples], self.tgt_dict.pad(), self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=True, ) prev_output_tokens = prev_output_tokens.index_select(0, order) - ntokens = sum(t.size(0) for _, _, t in samples) + ntokens = sum(x.target.size()[0] for x in samples) + net_input = { + "src_tokens": frames, + "src_lengths": n_frames, + "prev_output_tokens": prev_output_tokens, + } out = { "id": indices, - "net_input": { - "src_tokens": frames, - "src_lengths": n_frames, - "prev_output_tokens": prev_output_tokens, - }, + "net_input": net_input, "target": target, "target_lengths": 
target_lengths, "ntokens": ntokens, "nsentences": len(samples), } + if return_order: + out["order"] = order return out def num_tokens(self, index): return self.n_frames[index] def size(self, index): - t_len = 0 - if self.tgt_texts is not None: - tokenized = self.tokenize_text(self.tgt_texts[index]) - t_len = len(tokenized.split(" ")) - return self.n_frames[index], t_len + return self.n_frames[index], self.tgt_lens[index] @property def sizes(self): @@ -397,67 +460,111 @@ def _from_list( cls, split_name: str, is_train_split, - samples: List[List[Dict]], - data_cfg: S2TDataConfig, + samples: List[Dict], + cfg: S2TDataConfig, tgt_dict, pre_tokenizer, bpe_tokenizer, ) -> SpeechToTextDataset: - audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], [] - speakers, src_langs, tgt_langs = [], [], [] - for s in samples: - ids.extend([ss[cls.KEY_ID] for ss in s]) - audio_paths.extend( - [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s] - ) - n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s]) - tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s]) - src_texts.extend( - [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s] - ) - speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]) - src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]) - tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]) + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] return SpeechToTextDataset( 
split_name, is_train_split, - data_cfg, + cfg, audio_paths, n_frames, - src_texts, - tgt_texts, - speakers, - src_langs, - tgt_langs, - ids, - tgt_dict, - pre_tokenizer, - bpe_tokenizer, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, ) @classmethod - def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0): + def get_size_ratios( + cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0 + ) -> List[float]: """Size ratios for temperature-based sampling (https://arxiv.org/abs/1907.05019)""" - _sizes = np.array(sizes) - prob = _sizes / _sizes.sum() - smoothed_prob = prob ** alpha - smoothed_prob = smoothed_prob / smoothed_prob.sum() - size_ratio = (smoothed_prob * _sizes.sum()) / _sizes - - o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)}) - logger.info(f"original sampling probability: {o_str}") - p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)}) - logger.info(f"balanced sampling probability: {p_str}") - sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)}) - logger.info(f"balanced sampling size ratio: {sr_str}") - return size_ratio.tolist() + + id_to_lp, lp_to_sz = {}, defaultdict(int) + for ds in datasets: + lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)} + assert len(lang_pairs) == 1 + lang_pair = list(lang_pairs)[0] + id_to_lp[ds.split] = lang_pair + lp_to_sz[lang_pair] += sum(ds.n_frames) + + sz_sum = sum(v for v in lp_to_sz.values()) + lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()} + lp_to_tgt_prob = {k: v ** alpha for k, v in lp_to_prob.items()} + prob_sum = sum(v for v in lp_to_tgt_prob.values()) + lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()} + lp_to_sz_ratio = { + k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items() + } + size_ratio = 
[lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets] + + p_formatted = { + k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz + } + logger.info(f"sampling probability balancing: {p_formatted}") + sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)} + logger.info(f"balanced sampling size ratio: {sr_formatted}") + return size_ratio + + @classmethod + def _load_samples_from_tsv(cls, root: str, split: str): + tsv_path = Path(root) / f"{split}.tsv" + if not tsv_path.is_file(): + raise FileNotFoundError(f"Dataset not found: {tsv_path}") + with open(tsv_path) as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + samples = [dict(e) for e in reader] + if len(samples) == 0: + raise ValueError(f"Empty manifest: {tsv_path}") + return samples + + @classmethod + def _from_tsv( + cls, + root: str, + cfg: S2TDataConfig, + split: str, + tgt_dict, + is_train_split: bool, + pre_tokenizer, + bpe_tokenizer, + ) -> SpeechToTextDataset: + samples = cls._load_samples_from_tsv(root, split) + return cls._from_list( + split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer, bpe_tokenizer + ) @classmethod def from_tsv( cls, root: str, - data_cfg: S2TDataConfig, + cfg: S2TDataConfig, splits: str, tgt_dict, pre_tokenizer, @@ -466,46 +573,21 @@ def from_tsv( epoch: int, seed: int, ) -> SpeechToTextDataset: - samples = [] - _splits = splits.split(",") - for split in _splits: - tsv_path = op.join(root, f"{split}.tsv") - if not op.isfile(tsv_path): - raise FileNotFoundError(f"Dataset not found: {tsv_path}") - with open(tsv_path) as f: - reader = csv.DictReader( - f, - delimiter="\t", - quotechar=None, - doublequote=False, - lineterminator="\n", - quoting=csv.QUOTE_NONE, - ) - samples.append([dict(e) for e in reader]) - assert len(samples) > 0 - datasets = [ - cls._from_list( - name, - is_train_split, - [s], - data_cfg, - tgt_dict, - pre_tokenizer, - 
bpe_tokenizer, + cls._from_tsv( + root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer ) - for name, s in zip(_splits, samples) + for split in splits.split(",") ] - if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0: + if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: # temperature-based sampling - size_ratios = cls._get_size_ratios( - _splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha - ) + size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) datasets = [ ResamplingDataset( d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) ) - for d, r in zip(datasets, size_ratios) + for r, d in zip(size_ratios, datasets) ] - return ConcatDataset(datasets) + + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py index 07ebdd5f38..e21144a88e 100644 --- a/fairseq/data/transform_eos_lang_pair_dataset.py +++ b/fairseq/data/transform_eos_lang_pair_dataset.py @@ -49,6 +49,8 @@ def __len__(self): def collater(self, samples, **extra_args): samples = self.dataset.collater(samples, **extra_args) + if len(samples) == 0: + return samples if 'net_input' not in samples: return samples diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 5c935efaf5..aff9d0ffc7 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -308,7 +308,7 @@ def __init__(self, args): else: self.layer_norm = None - def _forward(self, src_tokens, src_lengths): + def _forward(self, src_tokens, src_lengths, return_all_hiddens=False): x, input_lengths = self.subsample(src_tokens, src_lengths) x = self.embed_scale * x @@ -317,8 +317,12 @@ def _forward(self, src_tokens, src_lengths): x += positions x = self.dropout_module(x) + encoder_states = [] + for layer in 
self.transformer_layers: x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) @@ -327,17 +331,19 @@ def _forward(self, src_tokens, src_lengths): "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T "encoder_embedding": [], # B x T x C - "encoder_states": [], # List[T x B x C] + "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], "src_lengths": [], } - def forward(self, src_tokens, src_lengths): + def forward(self, src_tokens, src_lengths, return_all_hiddens=False): if self.num_updates < self.encoder_freezing_updates: with torch.no_grad(): - x = self._forward(src_tokens, src_lengths) + x = self._forward(src_tokens, src_lengths, + return_all_hiddens=return_all_hiddens) else: - x = self._forward(src_tokens, src_lengths) + x = self._forward(src_tokens, src_lengths, + return_all_hiddens=return_all_hiddens) return x def reorder_encoder_out(self, encoder_out, new_order): diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 8bdf215643..5795c04bf7 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import logging -import os.path as op +from pathlib import Path from argparse import Namespace from fairseq.data import Dictionary, encoders @@ -22,8 +22,8 @@ @register_task("speech_to_text") class SpeechToTextTask(LegacyFairseqTask): - @staticmethod - def add_args(parser): + @classmethod + def add_args(cls, parser): parser.add_argument("data", help="manifest root path") parser.add_argument( "--config-yaml", @@ -49,15 +49,15 @@ def add_args(parser): def __init__(self, args, tgt_dict): super().__init__(args) self.tgt_dict = tgt_dict - self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml)) + self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) @classmethod def setup_task(cls, args, **kwargs): - data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml)) - dict_path = op.join(args.data, data_cfg.vocab_filename) - if not op.isfile(dict_path): - raise FileNotFoundError(f"Dict not found: {dict_path}") - tgt_dict = Dictionary.load(dict_path) + data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) + dict_path = Path(args.data) / data_cfg.vocab_filename + if not dict_path.is_file(): + raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}") + tgt_dict = Dictionary.load(dict_path.as_posix()) logger.info( f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}" ) @@ -126,7 +126,10 @@ def build_generator( for s, i in self.tgt_dict.indices.items() if SpeechToTextDataset.is_lang_tag(s) } - extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids} + if extra_gen_cls_kwargs is None: + extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids} + else: + extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids return super().build_generator( models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) From 75051ecf26239be1b082101b9d1fe8886e734b45 Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Tue, 27 Jul 2021 13:24:45 -0700 Subject: [PATCH 667/707] 
wav2vec2 speech translation OSS Summary: wav2vec2 speech translation OSS - Based on https://github.com/fairinternal/fairseq-py/pull/1829 - Updated `Wav2VecEncoder` API to make it consistent for `Wav2VecCTC` (for ASR) and `Wav2Vec2Seq2Seq` (for ST) - Small fixes in `Wav2Vec2Seq2Seq` - Refactored `audio_pretraining` into `audio_pretraining` and `audio_finetuning` Reviewed By: sravyapopuri388, cndn Differential Revision: D29285182 fbshipit-source-id: 89f93b42caa88079940a4b2cac0f8952547d3ff0 --- examples/wav2vec/README.md | 2 +- .../wav2vec/config/finetuning/base_100h.yaml | 3 +- .../wav2vec/config/finetuning/base_10h.yaml | 3 +- .../wav2vec/config/finetuning/base_10m.yaml | 3 +- .../wav2vec/config/finetuning/base_1h.yaml | 3 +- .../wav2vec/config/finetuning/base_960h.yaml | 3 +- .../wav2vec/config/finetuning/vox_100h.yaml | 3 +- .../wav2vec/config/finetuning/vox_10h.yaml | 3 +- .../wav2vec/config/finetuning/vox_10m.yaml | 3 +- .../wav2vec/config/finetuning/vox_1h.yaml | 3 +- .../wav2vec/config/finetuning/vox_960h.yaml | 3 +- .../config/finetuning/w2v_finetune.yaml | 2 +- fairseq/data/add_target_dataset.py | 6 + fairseq/models/wav2vec/wav2vec2_asr.py | 77 ++-- fairseq/tasks/audio_finetuning.py | 346 ++++++++++++++++++ fairseq/tasks/audio_pretraining.py | 212 +---------- fairseq_cli/generate.py | 6 + 17 files changed, 419 insertions(+), 262 deletions(-) create mode 100644 fairseq/tasks/audio_finetuning.py diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index badfda8979..2d6717dc04 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -143,7 +143,7 @@ Next, run the evaluation command: ```shell script $subset=dev_other -python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \ +python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_finetuning \ --nbest 1 --path /path/to/model --gen-subset $subset 
--results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ --lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \ --post-process letter diff --git a/examples/wav2vec/config/finetuning/base_100h.yaml b/examples/wav2vec/config/finetuning/base_100h.yaml index 539dabb047..153b5df170 100644 --- a/examples/wav2vec/config/finetuning/base_100h.yaml +++ b/examples/wav2vec/config/finetuning/base_100h.yaml @@ -10,7 +10,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? normalize: false labels: ltr @@ -56,4 +56,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 0 - diff --git a/examples/wav2vec/config/finetuning/base_10h.yaml b/examples/wav2vec/config/finetuning/base_10h.yaml index 16a3c4d96c..5044518025 100644 --- a/examples/wav2vec/config/finetuning/base_10h.yaml +++ b/examples/wav2vec/config/finetuning/base_10h.yaml @@ -13,7 +13,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? normalize: false labels: ltr @@ -61,4 +61,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/config/finetuning/base_10m.yaml b/examples/wav2vec/config/finetuning/base_10m.yaml index 3ceb77a252..14abc013bd 100644 --- a/examples/wav2vec/config/finetuning/base_10m.yaml +++ b/examples/wav2vec/config/finetuning/base_10m.yaml @@ -13,7 +13,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? 
normalize: false labels: ltr @@ -61,4 +61,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/config/finetuning/base_1h.yaml b/examples/wav2vec/config/finetuning/base_1h.yaml index 3ceb77a252..14abc013bd 100644 --- a/examples/wav2vec/config/finetuning/base_1h.yaml +++ b/examples/wav2vec/config/finetuning/base_1h.yaml @@ -13,7 +13,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? normalize: false labels: ltr @@ -61,4 +61,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/config/finetuning/base_960h.yaml b/examples/wav2vec/config/finetuning/base_960h.yaml index 2d38211e91..3eadc36b37 100644 --- a/examples/wav2vec/config/finetuning/base_960h.yaml +++ b/examples/wav2vec/config/finetuning/base_960h.yaml @@ -10,7 +10,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? normalize: false labels: ltr @@ -55,4 +55,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 0 - diff --git a/examples/wav2vec/config/finetuning/vox_100h.yaml b/examples/wav2vec/config/finetuning/vox_100h.yaml index 2fdb0c568c..b8f81e5e18 100644 --- a/examples/wav2vec/config/finetuning/vox_100h.yaml +++ b/examples/wav2vec/config/finetuning/vox_100h.yaml @@ -10,7 +10,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? 
normalize: true labels: ltr @@ -56,4 +56,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/config/finetuning/vox_10h.yaml b/examples/wav2vec/config/finetuning/vox_10h.yaml index f1a979e05d..8f1ca71ee2 100644 --- a/examples/wav2vec/config/finetuning/vox_10h.yaml +++ b/examples/wav2vec/config/finetuning/vox_10h.yaml @@ -13,7 +13,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? normalize: true labels: ltr @@ -61,4 +61,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/config/finetuning/vox_10m.yaml b/examples/wav2vec/config/finetuning/vox_10m.yaml index d12439bb28..07e327fe74 100644 --- a/examples/wav2vec/config/finetuning/vox_10m.yaml +++ b/examples/wav2vec/config/finetuning/vox_10m.yaml @@ -13,7 +13,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? normalize: true labels: ltr @@ -61,4 +61,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/config/finetuning/vox_1h.yaml b/examples/wav2vec/config/finetuning/vox_1h.yaml index 7f3b04c034..fac1bbb32f 100644 --- a/examples/wav2vec/config/finetuning/vox_1h.yaml +++ b/examples/wav2vec/config/finetuning/vox_1h.yaml @@ -13,7 +13,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? 
normalize: true labels: ltr @@ -61,4 +61,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/config/finetuning/vox_960h.yaml b/examples/wav2vec/config/finetuning/vox_960h.yaml index 0633915bb2..9d72404fa3 100644 --- a/examples/wav2vec/config/finetuning/vox_960h.yaml +++ b/examples/wav2vec/config/finetuning/vox_960h.yaml @@ -10,7 +10,7 @@ checkpoint: best_checkpoint_metric: wer task: - _name: audio_pretraining + _name: audio_finetuning data: ??? normalize: true labels: ltr @@ -55,4 +55,3 @@ model: activation_dropout: 0.1 feature_grad_mult: 0.0 freeze_finetune_updates: 10000 - diff --git a/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml b/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml index e94da2ba4e..19a3ef3484 100644 --- a/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml +++ b/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml @@ -11,7 +11,7 @@ checkpoint: save_interval_updates: 20000 task: - _name: audio_pretraining + _name: audio_finetuning data: ??? 
normalize: true labels: ltr diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py index 673963d0ed..d8a08e746d 100644 --- a/fairseq/data/add_target_dataset.py +++ b/fairseq/data/add_target_dataset.py @@ -71,3 +71,9 @@ def collater(self, samples): ).long() collated["ntokens"] += target.size(0) return collated + + def filter_indices_by_size(self, indices, max_sizes): + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 04307e8771..eb5d819da5 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -65,6 +65,19 @@ class Wav2Vec2AsrConfig(FairseqDataclass): "help": "dropout probability after activation in FFN inside wav2vec 2.0 model" }, ) + conv_feature_layers: Optional[str] = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": ( + "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + ), + }, + ) + encoder_embed_dim: Optional[int] = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) # masking apply_mask: bool = field( @@ -92,6 +105,10 @@ class Wav2Vec2AsrConfig(FairseqDataclass): no_mask_overlap: bool = field( default=False, metadata={"help": "whether to allow masks to overlap"} ) + mask_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) # channel masking mask_channel_length: int = field( @@ -123,6 +140,10 @@ class Wav2Vec2AsrConfig(FairseqDataclass): layerdrop: float = field( default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"} ) + mask_channel_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + 
) mask_channel_before: bool = False normalize: bool = II("task.normalize") data: str = II("task.data") @@ -134,27 +155,6 @@ class Wav2Vec2AsrConfig(FairseqDataclass): class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): blank_weight: float = 0 blank_mode: str = "add" - mask_min_space: Optional[int] = field( - default=1, - metadata={"help": "min space between spans (if no overlap is enabled)"}, - ) - mask_channel_min_space: Optional[int] = field( - default=1, - metadata={"help": "min space between spans (if no overlap is enabled)"}, - ) - conv_feature_layers: Optional[str] = field( - default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", - metadata={ - "help": ( - "string describing convolutional feature extraction " - "layers in form of a python list that contains " - "[(dim, kernel_size, stride), ...]" - ), - }, - ) - encoder_embed_dim: Optional[int] = field( - default=768, metadata={"help": "encoder embedding dimension"} - ) @register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig) @@ -299,7 +299,7 @@ def build_decoder(cls, cfg: Wav2Vec2Seq2SeqConfig, tgt_dict, embed_tokens): return TransformerDecoder(cfg, tgt_dict, embed_tokens) def forward(self, **kwargs): - encoder_out = self.encoder(tbc=False, **kwargs) + encoder_out = self.encoder(**kwargs) decoder_out = self.decoder(encoder_out=encoder_out, **kwargs) return decoder_out @@ -386,7 +386,8 @@ def set_num_updates(self, num_updates): super().set_num_updates(num_updates) self.num_updates = num_updates - def forward(self, source, padding_mask, tbc=True, **kwargs): + def forward(self, source, padding_mask, **kwargs): + w2v_args = { "source": source, "padding_mask": padding_mask, @@ -401,9 +402,8 @@ def forward(self, source, padding_mask, tbc=True, **kwargs): x = res["x"] padding_mask = res["padding_mask"] - if tbc: - # BTC -> TBC - x = x.transpose(0, 1) + # B x T x C -> T x B x C + x = x.transpose(0, 1) x = self.final_dropout(x) @@ -412,21 +412,24 @@ def forward(self, source, padding_mask, tbc=True, 
**kwargs): return { "encoder_out": x, # T x B x C - "encoder_padding_mask": padding_mask.transpose(0, 1) - if padding_mask is not None - else None, # T x B - "padding_mask": padding_mask, + "padding_mask": padding_mask, # B x T, "layer_results": res["layer_results"], } + def forward_torchscript(self, net_input): + if torch.jit.is_scripting(): + return self.forward(net_input["source"], net_input["padding_mask"]) + else: + return self.forward_non_torchscript(net_input) + def reorder_encoder_out(self, encoder_out, new_order): if encoder_out["encoder_out"] is not None: encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( 1, new_order ) - if encoder_out["encoder_padding_mask"] is not None: - encoder_out["encoder_padding_mask"] = encoder_out[ - "encoder_padding_mask" + if encoder_out["padding_mask"] is not None: + encoder_out["padding_mask"] = encoder_out[ + "padding_mask" ].index_select(0, new_order) return encoder_out @@ -469,7 +472,7 @@ def __init__( self.layerdrop = cfg.decoder_layerdrop - padding_idx = embed_tokens.padding_idx + self.padding_idx = embed_tokens.padding_idx self.max_target_positions = cfg.max_target_positions self.embed_tokens = embed_tokens @@ -485,7 +488,7 @@ def __init__( PositionalEmbedding( cfg.max_target_positions, embed_dim, - padding_idx, + self.padding_idx, learned=cfg.decoder_learned_pos, ) if not cfg.no_token_positional_embeddings @@ -589,6 +592,9 @@ def extract_features( inner_states = [x] # decoder layers + self_attn_padding_mask = None + if prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) for layer in self.layers: dropout_probability = np.random.random() if not self.training or (dropout_probability > self.layerdrop): @@ -600,6 +606,7 @@ def extract_features( self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, + self_attn_padding_mask=self_attn_padding_mask ) inner_states.append(x) diff --git 
a/fairseq/tasks/audio_finetuning.py b/fairseq/tasks/audio_finetuning.py new file mode 100644 index 0000000000..4ef87c604f --- /dev/null +++ b/fairseq/tasks/audio_finetuning.py @@ -0,0 +1,346 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import torch +import json + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, Any + +from fairseq.data import AddTargetDataset, Dictionary, encoders +from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + +from . import register_task +from .. import utils +from ..logging import metrics + + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + + +def label_len_fn(label): + return len(label.split(" ")) + + +@dataclass +class AudioFinetuningConfig(AudioPretrainingConfig): + # Options for reporting WER metrics during validation. 
Only applicable to + # Seq2Seq models during fine-tuning + eval_wer: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_config: GenerationConfig = field( + default_factory=lambda: GenerationConfig(), + metadata={"help": "beam search config for evaluating wer during training"}, + ) + eval_wer_tokenizer: Any = field( + default=None, + metadata={"help": "tokenizer config for evaluating wer during training"}, + ) + eval_wer_post_process: str = field( + default="letter", + metadata={ + "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" + }, + ) + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_detok: Optional[str] = field( + default=None, metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); " + "required if using --eval-bleu; use 'space' to disable " + "detokenization; see fairseq.data.encoders for other options" + } + ) + eval_bleu_detok_args: str = field( + default="{}", + metadata={"help": "args for building the tokenizer, if needed"} + ) + eval_tokenized_bleu: bool = field( + default=False, + metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, metadata={"help": "remove BPE before computing BLEU"} + ) + eval_bleu_args: str = field( + default="{}", + metadata={"help": "generation args for BLEU scoring, e.g., " + "'{\"beam\": 4, \"lenpen\": 0.6}'"} + ) + eval_bleu_print_samples: bool = field( + default=False, + metadata={"help": "print sample generations during validation"} + ) + autoregressive: bool = field( + default=False, + metadata={ + "help": "required for autoregressive decoders (like seq2seq models); " + "adds 'prev_output_tokens' to input and appends eos to target" + }, + ) + + +@register_task("audio_finetuning", dataclass=AudioFinetuningConfig) +class AudioFinetuningTask(AudioPretrainingTask): + """ """ + + cfg: 
AudioFinetuningConfig + + def __init__( + self, + cfg: AudioFinetuningConfig, + ): + super().__init__(cfg) + self.blank_symbol = "<s>" + + self.state.add_factory("target_dictionary", self.load_target_dictionary) + + def load_target_dictionary(self): + if self.cfg.labels: + dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + return Dictionary.load(dict_path) + return None + + def load_dataset(self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs): + super().load_dataset(split, task_cfg, **kwargs) + + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level + ) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_wer and self.cfg.autoregressive: + 
metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + if self.cfg.eval_bleu and self.cfg.autoregressive: + metrics = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output['_bleu_sys_len'] = metrics.sys_len + logging_output['_bleu_ref_len'] = metrics.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(metrics.counts) == 4 + for i in range(4): + logging_output[f"_bleu_counts_{i}"] = metrics.counts[i] + logging_output[f"_bleu_totals_{i}"] = metrics.totals[i] + return loss, sample_size, logging_output + + def build_model(self, model_cfg: FairseqDataclass): + model = super().build_model(model_cfg) + + if self.cfg.eval_wer and self.cfg.autoregressive: + self.sequence_generator = self.build_generator( + [model], + self.cfg.eval_wer_config, + ) + if self.cfg.eval_wer_tokenizer: + self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None + if self.cfg.eval_bleu and self.cfg.autoregressive: + assert self.cfg.eval_bleu_detok is not None, ( + '--eval-bleu-detok is required if using --eval-bleu; ' + 'try --eval-bleu-detok=moses (or --eval-bleu-detok=space ' + 'to disable detokenization, e.g., when using sentencepiece)' + ) + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + gen_args = json.loads(self.cfg.eval_bleu_args) + gen_args = Namespace(**gen_args) + self.sequence_generator = self.build_generator([model], gen_args) + + return model + + def _inference_with_wer(self, generator, sample, model): + import editdistance + + def decode(toks): + 
s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + ) + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split() + ref_words = ref.split() + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, is_ref): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is `<unk>`, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. 
+ unk_string=( + "UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP" + ), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]['tokens'], is_ref=False)) + refs.append( + decode( + utils.strip_pad( + sample['target'][i], + self.target_dictionary.pad() + ), + is_ref=True, # don't count <unk> as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info('H-{} {}'.format(sample["id"][0], hyps[0])) + logger.info('T-{} {}'.format(sample["id"][0], refs[0])) + + eval_tokenization = 'none' if self.cfg.eval_tokenized_bleu else '13a' + return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if self.cfg.eval_wer: + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + if num_words > 0: + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum + * 100.0 + / meters["_num_words"].sum + if meters["_num_words"].sum > 0 + else float("nan"), + ) + if self.cfg.eval_bleu: + len_keys = ["_bleu_sys_len", "_bleu_ref_len"] + count_keys = 
[f"_bleu_counts_{i}" for i in range(4)] + total_keys = [f"_bleu_totals_{i}" for i in range(4)] + for k in len_keys + count_keys + total_keys: + metrics.log_scalar( + k, sum(log.get(k, 0) for log in logging_outputs) + ) + + import sacrebleu + metrics.log_derived( + 'bleu', + lambda meters: sacrebleu.compute_bleu( + correct=[meters[k].sum for k in count_keys], + total=[meters[k].sum for k in total_keys], + sys_len=meters['_bleu_sys_len'].sum, + ref_len=meters['_bleu_ref_len'].sum, + smooth_method="exp" + ).score + ) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index b7d0f3da57..c99c6bf7d1 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -8,46 +8,22 @@ import logging import os import sys -import torch from argparse import Namespace from dataclasses import dataclass, field -from typing import Optional, Any +from typing import Optional from omegaconf import MISSING, II, OmegaConf -from fairseq.data import ( - AddTargetDataset, - BinarizedAudioDataset, - Dictionary, - FileAudioDataset, - encoders, -) +from fairseq.data import BinarizedAudioDataset, FileAudioDataset from fairseq.dataclass import FairseqDataclass, ChoiceEnum -from fairseq.dataclass.configs import GenerationConfig -from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel +from fairseq.data.text_compressor import TextCompressionLevel from . import FairseqTask, register_task -from .. 
import utils -from ..logging import metrics logger = logging.getLogger(__name__) -class LabelEncoder(object): - def __init__(self, dictionary): - self.dictionary = dictionary - - def __call__(self, label): - return self.dictionary.encode_line( - label, append_eos=False, add_if_not_exist=False - ) - - -def label_len_fn(label): - return len(label.split(" ")) - - @dataclass class InferredW2vConfig: # The following are needed to precompute mask and mask channel indices @@ -74,7 +50,8 @@ class AudioPretrainingConfig(FairseqDataclass): data: str = field(default=MISSING, metadata={"help": "path to data directory"}) labels: Optional[str] = field( default=None, - metadata={"help": "extension of the label file to load, used for fine-tuning"}, + metadata={ + "help": "extension of the label file to load, used for fine-tuning"}, ) binarized_dataset: bool = field( default=False, @@ -102,33 +79,6 @@ class AudioPretrainingConfig(FairseqDataclass): min_sample_size: Optional[int] = field( default=None, metadata={"help": "min sample size to skip small examples"} ) - - # Options for reporting WER metrics during validation. 
Only applicable to - # Seq2Seq models during fine-tuning - eval_wer: bool = field( - default=False, metadata={"help": "compute WER for Seq2Seq models"} - ) - eval_wer_config: GenerationConfig = field( - default_factory=lambda: GenerationConfig(), - metadata={"help": "beam search config for evaluating wer during training"}, - ) - eval_wer_tokenizer: Any = field( - default=None, - metadata={"help": "tokenizer config for evaluating wer during training"}, - ) - eval_wer_post_process: str = field( - default="letter", - metadata={ - "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" - }, - ) - autoregressive: bool = field( - default=False, - metadata={ - "help": "required for autoregressive decoders (like seq2seq models); " - "adds 'prev_output_tokens' to input and appends eos to target" - }, - ) num_batch_buckets: int = field( default=0, metadata={"help": "number of buckets"}, @@ -163,17 +113,6 @@ class AudioPretrainingTask(FairseqTask): cfg: AudioPretrainingConfig - def __init__( - self, - cfg: AudioPretrainingConfig, - ): - super().__init__(cfg) - if cfg.eval_wer: - assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" - self.blank_symbol = "<s>" - - self.state.add_factory("target_dictionary", self.load_target_dictionary) - @classmethod def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): """Setup the task (e.g., load dictionaries). @@ -184,12 +123,6 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): return cls(cfg) - def load_target_dictionary(self): - if self.cfg.labels: - dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") - return Dictionary.load(dict_path) - return None - def _get_mask_precompute_kwargs(self, cfg): if self.cfg.precompute_mask_indices or self.cfg.tpu: assert ( @@ -249,155 +182,24 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): "0. You may want to set this to a low value close to 0." 
) - if task_cfg.labels: - label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") - skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) - text_compressor = TextCompressor(level=text_compression_level) - with open(label_path, "r") as f: - labels = [ - text_compressor.compress(l) - for i, l in enumerate(f) if i not in skipped_indices - ] - - assert len(labels) == len(self.datasets[split]), ( - f"labels length ({len(labels)}) and dataset length " - f"({len(self.datasets[split])}) do not match" - ) - - process_label = LabelEncoder(self.target_dictionary) - - self.datasets[split] = AddTargetDataset( - self.datasets[split], - labels, - pad=self.target_dictionary.pad(), - eos=self.target_dictionary.eos(), - batch_targets=True, - process_label=process_label, - label_len_fn=label_len_fn, - add_to_input=task_cfg.get("autoregressive", False), - text_compression_level=text_compression_level - ) - @property def source_dictionary(self): return None @property def target_dictionary(self): - """Return the :class:`~fairseq.data.Dictionary` for the language - model.""" - return self.state.target_dictionary + return None def max_positions(self): """Maximum input length supported by the encoder.""" - return (sys.maxsize, sys.maxsize) - - def filter_indices_by_size( - self, - indices, - dataset, - max_positions=None, - ignore_invalid_inputs=False, - ): - # we do not need to filter by size in this task as dataloaders take care of this - return indices - - def valid_step(self, sample, model, criterion): - loss, sample_size, logging_output = super().valid_step(sample, model, criterion) - if self.cfg.eval_wer and self.cfg.autoregressive: - metrics = self._inference_with_wer(self.sequence_generator, sample, model) - logging_output["_num_char_errors"] = metrics["num_char_errors"] - logging_output["_num_chars"] = metrics["num_chars"] - logging_output["_num_word_errors"] = metrics["num_word_errors"] - logging_output["_num_words"] = metrics["num_words"] - return 
loss, sample_size, logging_output + return sys.maxsize, sys.maxsize def build_model(self, model_cfg: FairseqDataclass): model = super().build_model(model_cfg) - if self.cfg.eval_wer and self.cfg.autoregressive: - self.sequence_generator = self.build_generator( - [model], - self.cfg.eval_wer_config, - ) - if self.cfg.eval_wer_tokenizer: - self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) - else: - self.tokenizer = None - actualized_cfg = getattr(model, "cfg", None) if actualized_cfg is not None: if "w2v_args" in actualized_cfg: model_cfg.w2v_args = actualized_cfg.w2v_args return model - - def _inference_with_wer(self, generator, sample, model): - import editdistance - - def decode(toks): - s = self.target_dictionary.string( - toks.int().cpu(), - self.cfg.eval_wer_post_process, - escape_unk=True, - ) - if self.tokenizer: - s = self.tokenizer.decode(s) - return s - - num_word_errors, num_char_errors = 0, 0 - num_chars, num_words = 0, 0 - gen_out = self.inference_step(generator, [model], sample, None) - for i in range(len(gen_out)): - hyp = decode(gen_out[i][0]["tokens"]) - ref = decode( - utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), - ) - num_char_errors += editdistance.eval(hyp, ref) - num_chars += len(ref) - hyp_words = hyp.split() - ref_words = ref.split() - num_word_errors += editdistance.eval(hyp_words, ref_words) - num_words += len(ref_words) - - return { - "num_char_errors": num_char_errors, - "num_chars": num_chars, - "num_word_errors": num_word_errors, - "num_words": num_words, - } - - def reduce_metrics(self, logging_outputs, criterion): - super().reduce_metrics(logging_outputs, criterion) - - zero = torch.scalar_tensor(0.0) - num_char_errors = sum( - log.get("_num_char_errors", zero) for log in logging_outputs - ) - num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) - num_word_errors = sum( - log.get("_num_word_errors", zero) for log in logging_outputs - ) - num_words = 
sum(log.get("_num_words", zero) for log in logging_outputs) - metrics.log_scalar("_num_char_errors", num_char_errors) - metrics.log_scalar("_num_chars", num_chars) - metrics.log_scalar("_num_word_errors", num_word_errors) - metrics.log_scalar("_num_words", num_words) - if num_chars > 0: - metrics.log_derived( - "uer", - lambda meters: meters["_num_char_errors"].sum - * 100.0 - / meters["_num_chars"].sum - if meters["_num_chars"].sum > 0 - else float("nan"), - ) - if num_words > 0: - metrics.log_derived( - "wer", - lambda meters: meters["_num_word_errors"].sum - * 100.0 - / meters["_num_words"].sum - if meters["_num_words"].sum > 0 - else float("nan"), - ) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 7bd582b256..c9ea52493d 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -400,6 +400,12 @@ def decode_fn(x): def cli_main(): parser = options.get_generation_parser() + # TODO: replace this workaround with refactoring of `AudioPretraining` + parser.add_argument( + '--arch', '-a', metavar='ARCH', default="transformer", + help='Model architecture. For constructing tasks that rely on ' + 'model args (e.g. 
`AudioPretraining`)' + ) args = options.parse_args_and_arch(parser) main(args) From 440def26c167af762ae8e8fc64716e1b88f3968b Mon Sep 17 00:00:00 2001 From: Edan Tessel Sneh <edan@fb.com> Date: Wed, 28 Jul 2021 14:58:47 -0700 Subject: [PATCH 668/707] added more descriptive comment to urgent change Summary: Added description of problem and task associated with it in code Reviewed By: dianaml0 Differential Revision: D29942196 fbshipit-source-id: 037d92af720f52bdd51f4efc5b9d49a6465cdd31 --- fairseq/models/transformer/transformer_base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/fairseq/models/transformer/transformer_base.py b/fairseq/models/transformer/transformer_base.py index 49fd30e502..810c9b98db 100644 --- a/fairseq/models/transformer/transformer_base.py +++ b/fairseq/models/transformer/transformer_base.py @@ -11,8 +11,12 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.distributed import fsdp_wrap from fairseq.models import FairseqEncoderDecoderModel +from fairseq.models.transformer import ( + TransformerEncoderBase, + TransformerDecoderBase, + TransformerConfig, +) from torch import Tensor -from fairseq.models.transformer import (TransformerEncoderBase, TransformerDecoderBase, TransformerConfig) class TransformerModelBase(FairseqEncoderDecoderModel): @@ -49,14 +53,16 @@ def add_args(parser): def build_model(cls, cfg, task): """Build a new model instance.""" - # hacky fixes for issue with II + # -- TODO T96535332 + # bug caused by interaction between OmegaConf II and argparsing cfg.decoder.input_dim = int(cfg.decoder.input_dim) cfg.decoder.output_dim = int(cfg.decoder.output_dim) + # -- if cfg.encoder.layers_to_keep: - cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(',')) + cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(",")) if cfg.decoder.layers_to_keep: - cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(',')) + cfg.decoder.layers = 
len(cfg.decoder.layers_to_keep.split(",")) src_dict, tgt_dict = task.source_dictionary, task.target_dictionary From 1df3e50a2f3d93cd815d2a8730d6a3c8271af891 Mon Sep 17 00:00:00 2001 From: Henry Hu <henryhu6@fb.com> Date: Thu, 29 Jul 2021 12:03:01 -0700 Subject: [PATCH 669/707] Minor performance optimization for sequence generator Summary: Minor performance optimization for sequence generator finalize_hypos function. Reviewed By: myleott Differential Revision: D29784509 fbshipit-source-id: ac42dda75995cea3750c8d4ef07d8950a891c5f8 --- fairseq/sequence_generator.py | 53 +++++++++++++++-------------------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index ac04dc7db8..d9c906ceea 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -654,44 +654,38 @@ def finalize_hypos( prev += 1 else: cum_unfin.append(prev) + cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) - # The keys here are of the form "{sent}_{unfin_idx}", where + unfin_idx = bbsz_idx // beam_size + sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) + + # Create a set of "{sent}{unfin_idx}", where # "unfin_idx" is the index in the current (possibly reduced) # list of sentences, and "sent" is the index in the original, # unreduced batch - # set() is not supported in script export - sents_seen: Dict[str, Optional[Tensor]] = {} - # For every finished beam item - for i in range(bbsz_idx.size()[0]): - idx = bbsz_idx[i] - score = eos_scores[i] - # sentence index in the current (possibly reduced) batch - unfin_idx = idx // beam_size - # sentence index in the original (unreduced) batch - sent = unfin_idx + cum_unfin[unfin_idx] - # Cannot create dict for key type '(int, int)' in torchscript. 
- # The workaround is to cast int to string - seen = str(sent.item()) + "_" + str(unfin_idx.item()) - if seen not in sents_seen: - sents_seen[seen] = None - - if self.match_source_len and step > src_lengths[unfin_idx]: - score = torch.tensor(-math.inf).to(score) + # sentence index in the current (possibly reduced) batch + seen = (sent << 32) + unfin_idx + unique_seen: List[int] = torch.unique(seen).tolist() + if self.match_source_len: + condition = step > torch.index_select(src_lengths, 0, unfin_idx) + eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores) + sent_list: List[int] = sent.tolist() + for i in range(bbsz_idx.size()[0]): # An input sentence (among those in a batch) is finished when # beam_size hypotheses have been collected for it - if len(finalized[sent]) < beam_size: + if len(finalized[sent_list[i]]) < beam_size: if attn_clone is not None: # remove padding tokens from attn scores hypo_attn = attn_clone[i] else: hypo_attn = torch.empty(0) - finalized[sent].append( + finalized[sent_list[i]].append( { "tokens": tokens_clone[i], - "score": score, + "score": eos_scores[i], "attention": hypo_attn, # src_len x tgt_len "alignment": torch.empty(0), "positional_scores": pos_scores[i], @@ -699,17 +693,16 @@ def finalize_hypos( ) newly_finished: List[int] = [] - - for seen in sents_seen.keys(): + for unique_s in unique_seen: # check termination conditions for this sentence - sent: int = int(float(seen.split("_")[0])) - unfin_idx: int = int(float(seen.split("_")[1])) + unique_sent: int = unique_s >> 32 + unique_unfin_idx: int = unique_s - (unique_sent << 32) - if not finished[sent] and self.is_finished( - step, unfin_idx, max_len, len(finalized[sent]), beam_size + if not finished[unique_sent] and self.is_finished( + step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size ): - finished[sent] = True - newly_finished.append(unfin_idx) + finished[unique_sent] = True + newly_finished.append(unique_unfin_idx) return newly_finished From 
f1b447075844e7cb912879786c787e30d413ebda Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Thu, 29 Jul 2021 15:59:22 -0700 Subject: [PATCH 670/707] use suffix when saving best checkpoints with metric Summary: This is needed when training with FSDP + sharded state, as the checkpoints should have `-shard0.pt` Reviewed By: sshleifer Differential Revision: D29947728 fbshipit-source-id: d2cb7c23e1e6d027d115cb734e2f2f1c34239eba --- fairseq/checkpoint_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 35cce7fda7..aba1e9f725 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -82,18 +82,21 @@ def is_better(a, b): worst_best = getattr(save_checkpoint, "best", None) chkpts = checkpoint_paths( cfg.save_dir, - pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( - cfg.best_checkpoint_metric + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix ), ) if len(chkpts) > 0: p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] - worst_best = float(p.rsplit("_")[-1].replace(".pt", "")) + worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), "")) # add random digits to resolve ties rand_sfx = randint(0, cfg.keep_best_checkpoints) checkpoint_conds[ - "checkpoint.best_{}_{:.3f}{}.pt".format( - cfg.best_checkpoint_metric, val_loss, rand_sfx + "checkpoint.best_{}_{:.3f}{}{}.pt".format( + cfg.best_checkpoint_metric, + val_loss, + rand_sfx, + suffix ) ] = worst_best is None or is_better(val_loss, worst_best) checkpoint_conds[ From 3cf9053ea260c0de31beaa8948ee4acae2221ba1 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Thu, 29 Jul 2021 15:59:22 -0700 Subject: [PATCH 671/707] use pathmanager to delete old checkpoints Summary: --keep-best-checkpoints doesn't work for Manifold paths because the current code assumes normal paths. 
Let's use PathManager to properly remove them. Reviewed By: myleott, sshleifer Differential Revision: D29947965 fbshipit-source-id: 237a7b5aaa8293bb203ad05937decdf3b3ae2fc0 --- fairseq/checkpoint_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index aba1e9f725..de2b3eaecd 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -164,6 +164,8 @@ def is_better(a, b): for old_chk in checkpoints[cfg.keep_last_epochs :]: if os.path.lexists(old_chk): os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) if cfg.keep_best_checkpoints > 0: # only keep the best N checkpoints according to validation metric @@ -178,6 +180,8 @@ def is_better(a, b): for old_chk in checkpoints[cfg.keep_best_checkpoints :]: if os.path.lexists(old_chk): os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): From 777bba3a41bb92aecc611037e49b1e42fee24836 Mon Sep 17 00:00:00 2001 From: Alex Xiao <axiao@fb.com> Date: Thu, 29 Jul 2021 15:59:22 -0700 Subject: [PATCH 672/707] seed random suffix in checkpoint to be consistent across shards Summary: Currently the random suffix for saving sharded checkpoints can be different for each shard when training with FSDP and use_sharded_state=True. This makes it difficult for downstream applications to load the checkpoints properly, since each shard may have different suffixes.
This diff seeds the random suffix to be consistent across shards Reviewed By: zhengwy888 Differential Revision: D29951167 fbshipit-source-id: 65749357e62a28978f3b46b71767204a508e1f61 --- fairseq/checkpoint_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index de2b3eaecd..daabba4574 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -7,15 +7,16 @@ import collections import contextlib import logging +import numpy as np import os import re import time import traceback from collections import OrderedDict from typing import Any, Dict, Optional, Union -from random import randint import torch +from fairseq.data import data_utils from fairseq.dataclass.configs import CheckpointConfig from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, @@ -90,7 +91,9 @@ def is_better(a, b): p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), "")) # add random digits to resolve ties - rand_sfx = randint(0, cfg.keep_best_checkpoints) + with data_utils.numpy_seed(epoch, updates, val_loss): + rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints) + checkpoint_conds[ "checkpoint.best_{}_{:.3f}{}{}.pt".format( cfg.best_checkpoint_metric, From 7a6706f5a3cd9995c9370ec7adcd2da454aecd97 Mon Sep 17 00:00:00 2001 From: Yun Tang <yuntang@fb.com> Date: Thu, 29 Jul 2021 17:43:20 -0700 Subject: [PATCH 673/707] Add speech/text joint training for speech to text task (step 2) Summary: Add scripts for speech/text joint training for the speech to text task. 
It includes scripts/recipes from the following papers "A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks", ICASSP 2021 "Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task", ACL 2021 "FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task", IWSLT 2021 Reviewed By: kahne Differential Revision: D29820444 fbshipit-source-id: 925eaedb69233e0a6f4c110045db63a6007a2b60 --- examples/speech_text_joint_to_text/README.md | 46 + .../speech_text_joint_to_text/__init__.py | 6 + .../configs/mustc_noise.list | 49 + .../criterions/__init__.py | 15 + .../text_guide_cross_entropy_acc.py | 223 ++++ .../docs/ende-mustc.md | 112 ++ .../docs/iwslt2021.md | 76 ++ .../models/__init__.py | 14 + .../models/s2t_dualinputtransformer.py | 1090 +++++++++++++++++ .../models/s2t_dualinputxmtransformer.py | 584 +++++++++ .../scripts/g2p_encode.py | 191 +++ .../tasks/__init__.py | 12 + .../tasks/speech_text_joint.py | 372 ++++++ fairseq/data/audio/multi_modality_dataset.py | 263 ++++ .../audio/speech_to_text_joint_dataset.py | 288 +++++ fairseq/data/iterators.py | 125 ++ fairseq/models/speech_to_text/__init__.py | 3 +- .../models/speech_to_text/xm_transformer.py | 504 ++++++++ 18 files changed, 3972 insertions(+), 1 deletion(-) create mode 100644 examples/speech_text_joint_to_text/README.md create mode 100644 examples/speech_text_joint_to_text/__init__.py create mode 100644 examples/speech_text_joint_to_text/configs/mustc_noise.list create mode 100644 examples/speech_text_joint_to_text/criterions/__init__.py create mode 100644 examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py create mode 100644 examples/speech_text_joint_to_text/docs/ende-mustc.md create mode 100644 examples/speech_text_joint_to_text/docs/iwslt2021.md create mode 100644 examples/speech_text_joint_to_text/models/__init__.py create mode 100644 
examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py create mode 100644 examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py create mode 100644 examples/speech_text_joint_to_text/scripts/g2p_encode.py create mode 100644 examples/speech_text_joint_to_text/tasks/__init__.py create mode 100644 examples/speech_text_joint_to_text/tasks/speech_text_joint.py create mode 100644 fairseq/data/audio/multi_modality_dataset.py create mode 100644 fairseq/data/audio/speech_to_text_joint_dataset.py create mode 100644 fairseq/models/speech_to_text/xm_transformer.py diff --git a/examples/speech_text_joint_to_text/README.md b/examples/speech_text_joint_to_text/README.md new file mode 100644 index 0000000000..e071d241e0 --- /dev/null +++ b/examples/speech_text_joint_to_text/README.md @@ -0,0 +1,46 @@ +# Joint Speech Text training in Fairseq +An extension of Fairseq s2t project with the speech to text task enhanced by the co-trained text to text mapping task. More details about Fairseq s2t can be found [here](../speech_to_text/README.md) + +## Examples +Examples of speech text joint training in fairseq +- [English-to-German MuST-C model](docs/ende-mustc.md) +- [IWSLT 2021 Multilingual Speech Translation](docs/iwslt2021.md) + +## Citation +Please cite as: +``` +@inproceedings{Tang2021AGM, + title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks}, + author={Yun Tang and J. 
Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel}, + booktitle={ICASSP}, + year={2021} +} + +@inproceedings{Tang2021IST, + title = {Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task}, + author = {Yun Tang and Juan Pino and Xian Li and Changhan Wang and Dmitriy Genzel}, + booktitle = {ACL}, + year = {2021}, +} + +@inproceedings{Tang2021FST, + title = {FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task}, + author = {Yun Tang and Hongyu Gong and Xian Li and Changhan Wang and Juan Pino and Holger Schwenk and Naman Goyal}, + booktitle = {IWSLT}, + year = {2021}, +} + +@inproceedings{wang2020fairseqs2t, + title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq}, + author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino}, + booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations}, + year = {2020}, +} + +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` diff --git a/examples/speech_text_joint_to_text/__init__.py b/examples/speech_text_joint_to_text/__init__.py new file mode 100644 index 0000000000..239d2e69f9 --- /dev/null +++ b/examples/speech_text_joint_to_text/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . 
import tasks, criterions, models # noqa diff --git a/examples/speech_text_joint_to_text/configs/mustc_noise.list b/examples/speech_text_joint_to_text/configs/mustc_noise.list new file mode 100644 index 0000000000..02eeac4e00 --- /dev/null +++ b/examples/speech_text_joint_to_text/configs/mustc_noise.list @@ -0,0 +1,49 @@ +"(Applause) NOISE +"(Laughter) VOICE +"(Laughter)" VOICE +(Applause) NOISE +(Applause). NOISE +(Audience) VOICE +(Audio) NOISE +(Beat) NOISE +(Beatboxing) VOICE +(Beep) NOISE +(Beeps) NOISE +(Cheering) VOICE +(Cheers) VOICE +(Claps) NOISE +(Clicking) NOISE +(Clunk) NOISE +(Coughs) NOISE +(Drums) NOISE +(Explosion) NOISE +(Gasps) VOICE +(Guitar) NOISE +(Honk) NOISE +(Laugher) VOICE +(Laughing) VOICE +(Laughs) VOICE +(Laughter) VOICE +(Laughter). VOICE +(Laughter)... VOICE +(Mumbling) VOICE +(Music) NOISE +(Noise) NOISE +(Recording) VOICE +(Ringing) NOISE +(Shouts) VOICE +(Sigh) VOICE +(Sighs) VOICE +(Silence) NOISE +(Singing) VOICE +(Sings) VOICE +(Spanish) VOICE +(Static) NOISE +(Tones) NOISE +(Trumpet) NOISE +(Video) NOISE +(Video): NOISE +(Voice-over) NOISE +(Whistle) NOISE +(Whistling) NOISE +(video): NOISE diff --git a/examples/speech_text_joint_to_text/criterions/__init__.py b/examples/speech_text_joint_to_text/criterions/__init__.py new file mode 100644 index 0000000000..7faae73119 --- /dev/null +++ b/examples/speech_text_joint_to_text/criterions/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + criterion_name = file[: file.find(".py")] + importlib.import_module( + "examples.speech_text_joint_to_text.criterions." 
+ criterion_name + ) diff --git a/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py b/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py new file mode 100644 index 0000000000..0d356e5a10 --- /dev/null +++ b/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import math + +import torch +import torch.nn.functional as F +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss +from fairseq import metrics, utils + + +@register_criterion("guided_label_smoothed_cross_entropy_with_accuracy") +class GuidedCrossEntAccCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + guide_alpha, + text_input_cost_ratio, + label_smoothing, + disable_text_guide_update_num=0, + attentive_cost_regularization=0, + ): + """ + guide_alpha: alpha to interpolate nll and kd loss + text_input_cost_ratio: loss ratio for text only input data + label_smoothing: label smoothing ratio + disable_text_guide_update_num: only use nll loss for the first N updates + attentive_cost_regularization: ratio of attentive cost + """ + super().__init__(task) + self.alpha = guide_alpha + self.attn_beta = attentive_cost_regularization + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.text_input_cost_ratio = text_input_cost_ratio + self.disable_update_num = disable_text_guide_update_num + assert self.alpha >= 0 and self.alpha <= 1.0 + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', + help='epsilon for label smoothing, 0 means no label smoothing') + # fmt: off +
parser.add_argument('--guide-alpha', default=0., type=float, metavar='D', + help='alpha to merge kd cost from text to speech input with ce loss') + # fmt: off + parser.add_argument('--disable-text-guide-update-num', default=0, type=int, metavar='D', + help='disable guided target from text for the first N updates.') + parser.add_argument("--attentive-cost-regularization", default=0.0, type=float, metavar='D', + help="use encoder attentive loss regularization with cost ratio D") + parser.add_argument("--attentive-cost-without-normalize", action='store_true', + help="Don't do normalization during attentive cost computation") + + def forward(self, model, sample, reduce=True): + reduction = 'sum' if reduce else 'none' + net_input = sample["net_input"] + net_output = model(**net_input) + attn_cost = None + lprobs = model.get_normalized_probs(net_output, log_probs=True) + is_dual_input = True if net_input['src_tokens'] is not None and net_input.get('src_txt_tokens') is not None else False + target = model.get_targets(sample, net_output) + src_token_num = 0 + if is_dual_input: + # lprobs_spch from speech encoder and lprobs_text from text encoder + lprobs_spch, lprobs_text = torch.chunk(lprobs, 2) + lprobs_spch.batch_first = lprobs.batch_first + lprobs_text.batch_first = lprobs.batch_first + + speech_loss, speech_nll_loss, speech_correct, speech_total = \ + self.guide_loss_and_acc(model, lprobs_spch, lprobs_text, target, reduce=(reduction == 'sum')) + text_loss, text_nll_loss, text_correct, text_total = self.compute_loss_and_acc(model, lprobs_text, target, reduction=reduction) + loss = (speech_loss + text_loss) + nll_loss = (speech_nll_loss + text_nll_loss) + correct = speech_correct + text_correct + total = speech_total + text_total + + attn_cost = net_output[1].get('attn_cost') + if attn_cost is not None: + # attn_cost is batch_first and padding tokens have been masked already + src_token_num = attn_cost.ne(0).sum() + attn_cost = attn_cost.sum() + loss = loss + attn_cost 
* self.attn_beta + else: + attn_cost = 0 + else: + loss, nll_loss, correct, total = self.compute_loss_and_acc(model, lprobs, target, reduction=reduction) + if sample["net_input"]['src_tokens'] is None: # text input only + loss = loss * self.text_input_cost_ratio + speech_loss = None + speech_nll_loss = None + + sample_size, logging_output = self.get_logging_output( + sample, loss, nll_loss, correct, total, src_token_num, speech_loss, speech_nll_loss, attn_cost, is_dual_input + ) + return loss, sample_size, logging_output + + def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'): + if not lprobs.batch_first: + lprobs = lprobs.transpose(0, 1) + lprobs = lprobs.view(-1, lprobs.size(-1)) # -> (B x T) x C + target = target.view(-1) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=(reduction == 'sum'), + ) + + mask = target.ne(self.padding_idx) + correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))) + total = torch.sum(mask) + return loss, nll_loss, correct, total + + def guide_loss_and_acc(self, model, lprobs, lprobs_teacher, target, reduce=True): + """ lprobs_teacher is used as guide for lprobs """ + if self.alpha == 0.0 or model.num_updates < self.disable_update_num: + return self.compute_loss_and_acc(model, lprobs, target, reduction=('sum' if reduce else 'none')) + if not lprobs.batch_first: + lprobs = lprobs.transpose(0, 1) + lprobs_teacher = lprobs_teacher.transpose(0, 1) + + lprobs = lprobs.view(-1, lprobs.size(-1)).float() # -> (B x T) x C + lprobs_teacher = lprobs_teacher.view(-1, lprobs_teacher.size(-1)).float() # -> (B x T) x C + target = target.view(-1) + loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx, reduction='sum' if reduce else 'none') + nll_loss = loss + probs_teacher = lprobs_teacher.exp().masked_fill_(target.unsqueeze(-1).eq(self.padding_idx), 0) + probs_teacher = probs_teacher.detach() + guide_loss = 
-(probs_teacher*lprobs).sum() if reduce else -(probs_teacher*lprobs).sum(-1, keepdim=True) + loss = self.alpha*guide_loss + (1.0 - self.alpha)*loss + + mask = target.ne(self.padding_idx) + correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))) + total = torch.sum(mask) + return loss, nll_loss, correct, total + + def get_logging_output( + self, + sample, + loss, + nll_loss, + correct, + total, + src_token_num=0, + speech_loss=None, + speech_nll_loss=None, + attn_cost=None, + is_dual_input=False, + ): + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + mul_size = 2 if is_dual_input else 1 + + logging_output = { + "loss": utils.item(loss.data), # * sample['ntokens'], + "nll_loss": utils.item(nll_loss.data), # * sample['ntokens'], + "ntokens": sample["ntokens"]*mul_size, + "nsentences": sample["target"].size(0)*mul_size, + "sample_size": sample_size*mul_size, + "correct": utils.item(correct.data), + "total": utils.item(total.data), + "src_token_num": utils.item(src_token_num.data) if src_token_num > 0 else 0, + "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(), + } + + if speech_loss is not None: + logging_output["speech_loss"] = utils.item(speech_loss.data) + logging_output["speech_nll_loss"] = utils.item(speech_nll_loss.data) + logging_output["sample_size_speech_cost"] = sample_size + logging_output["speech_attn_loss"] = attn_cost + + return sample_size*mul_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + correct_sum = sum(log.get("correct", 0) for log in logging_outputs) + total_sum = sum(log.get("total", 0) for log in logging_outputs) + src_token_sum = sum(log.get("src_token_num", 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) + ntokens = 
sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + nframes = sum(log.get("nframes", 0) for log in logging_outputs) + speech_loss_sum = sum(log.get("speech_loss", 0) for log in logging_outputs) + speech_nll_loss_sum = sum(log.get("speech_nll_loss", 0) for log in logging_outputs) + speech_attn_loss_sum = sum(log.get("speech_attn_loss", 0) for log in logging_outputs) + sample_size_speech = sum(log.get("sample_size_speech_cost", 0) for log in logging_outputs) + + agg_output = { + "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0, + "nll_loss": nll_loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0, + # if args.sentence_avg, then sample_size is nsentences, and loss + # is per-sentence loss; else sample_size is ntokens, and the loss + # becomes per-output token loss + "speech_loss": speech_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0, + "speech_nll_loss": speech_nll_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0, + "speech_attn_loss": speech_attn_loss_sum / src_token_sum / math.log(2) if src_token_sum > 0 else 0.0, + "ntokens": ntokens, + "nsentences": nsentences, + "nframes": nframes, + "sample_size": sample_size, + "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0, + "correct": correct_sum, + "total": total_sum, + "src_token_num": src_token_sum, + # total is the number of validate tokens + } + return agg_output + + @classmethod + def reduce_metrics(cls, logging_outputs): + """Aggregate logging outputs from data parallel training.""" + agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs) + for k, v in agg_logging_outputs.items(): + if k in {'nsentences', 'ntokens', 'sample_size'}: + continue + metrics.log_scalar(k, v, round=3) diff --git 
a/examples/speech_text_joint_to_text/docs/ende-mustc.md b/examples/speech_text_joint_to_text/docs/ende-mustc.md new file mode 100644 index 0000000000..3487af6671 --- /dev/null +++ b/examples/speech_text_joint_to_text/docs/ende-mustc.md @@ -0,0 +1,112 @@ +[[Back]](..) + +# Joint Speech Text Training for the MuST-C English to German Speech Translation task + +Joint Training Baseline: it is based on paper ["A general multi-task learning framework to leverage text data for speech to text tasks"](https://arxiv.org/pdf/2010.11338.pdf) + +Enhanced Joint Training: the joint training is enhanced with pre-trained models, cross attentive regularization and online knowledge distillation based on paper ["Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task"](https://research.fb.com/publications/improving-speech-translation-by-understanding-and-learning-from-the-auxiliary-text-translation-task) + +## Prepare Data +#### Download files +- Sentence piece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/spm.model) +- Dictionary [dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/dict.txt) +- config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/config.yaml) +#### Prepare MuST-C data set +- [Please follow the data preparation in the S2T example](https://github.com/pytorch/fairseq/blob/master/examples/speech_to_text/docs/mustc_example.md) +- Append src_text in the tsv file with phoneme representation. 
+```bash + python examples/speech_text_joint_to_text/scripts/g2p_encode.py \ + --lower-case --do-filter --use-word-start --no-punc \ + --reserve-word examples/speech_text_joint_to_text/configs/mustc_noise.list \ + --data-path ${must_c_en_de_src_text} \ + --out-path ${must_c_en_de_src_text_pho} +``` +- Update tsv data with src_text generated above and save to $MANIFEST_ROOT +- Prepare phoneme dictionary and save to $MANIFEST_ROOT as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt) +#### Prepare WMT text data +- [Download wmt data](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-wmt14en2de.sh) +- Convert source text (English) into phoneme representation as above +- Generate binary parallel file for training (as translation example) and save data in $parallel_text_data + +## Training +The model is trained with 8 v100 GPUs. + +#### Download pretrained models +- [pretrain_encoder](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_asr_transformer_m.pt) +- [pretrain_nmt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_mt.pt) + +#### Training scripts +- Jointly trained model from scratch +```bash +python train.py ${MANIFEST_ROOT} \ + --save-dir ${save_dir} \ + --num-workers 8 \ + --task speech_text_joint_to_text \ + --arch dualinputs2ttransformer_s \ + --user-dir examples/speech_text_joint_to_text \ + --max-epoch 100 --update-mix-data \ + --optimizer adam --lr-scheduler inverse_sqrt \ + --lr 0.001 --update-freq 4 --clip-norm 10.0 \ + --criterion guided_label_smoothed_cross_entropy_with_accuracy \ + --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \ + --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \ + --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \ + --dropout 0.1 --warmup-updates 20000 \ + --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \ + --text-input-cost-ratio 0.5 --enc-grad-mult 
2.0 --add-speech-eos \ + --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \ + --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \ + --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \ + --keep-last-epochs 10 +``` +- Jointly trained model with good initialization, cross attentive loss and online knowledge distillation +```bash +python train.py ${MANIFEST_ROOT} \ + --save-dir ${save_dir} \ + --num-workers 8 \ + --task speech_text_joint_to_text \ + --arch dualinputs2ttransformer_m \ + --user-dir examples/speech_text_joint_to_text \ + --max-epoch 100 --update-mix-data \ + --optimizer adam --lr-scheduler inverse_sqrt \ + --lr 0.002 --update-freq 4 --clip-norm 10.0 \ + --criterion guided_label_smoothed_cross_entropy_with_accuracy \ + --guide-alpha 0.8 --disable-text-guide-update-num 5000 \ + --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \ + --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \ + --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \ + --dropout 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \ + --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \ + --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \ + --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \ + --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \ + --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \ + --load-pretrain-speech-encoder ${pretrain_encoder} \ + --load-pretrain-decoder ${pretrain_nmt} \ + --load-pretrain-text-encoder-last ${pretrain_nmt} \ + --keep-last-epochs 10 +``` + +## Evaluation +```bash +python ./fairseq_cli/generate.py \ + ${MANIFEST_ROOT} \ + --task speech_text_joint_to_text \ + --max-tokens 25000 \ + --nbest 1 \ + --results-path ${infer_results} \ + --batch-size 512 \ + --path ${model} \ + --gen-subset tst-COMMON \ + --config-yaml config_spm.yaml \ + --scoring sacrebleu \ + 
--beam 5 --lenpen 1.0 \ + --user-dir examples/speech_text_joint_to_text \ + --load-speech-only +``` + +## Results (Joint training with initialization + CAR + online KD) +|Direction|En-De | En-Es | En-Fr | +|---|---|---|---| +|BLEU|27.4| 31.2 | 37.6 | +|checkpoint | [link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_ave_10.pt) |[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_es/checkpoint_ave_10.pt)|[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_fr/checkpoint_ave_10.pt)| diff --git a/examples/speech_text_joint_to_text/docs/iwslt2021.md b/examples/speech_text_joint_to_text/docs/iwslt2021.md new file mode 100644 index 0000000000..37a07c4a05 --- /dev/null +++ b/examples/speech_text_joint_to_text/docs/iwslt2021.md @@ -0,0 +1,76 @@ +[[Back]](..) + +# Joint Speech Text Training for the 2021 IWSLT multilingual speech translation task + +This directory contains the code from the paper ["FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task"](https://arxiv.org/pdf/2107.06959.pdf).
+ +## Prepare Data +#### Download files +- Sentence piece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/spm.model) +- Dictionary [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/dict.txt) +- Config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/config.yaml) + +#### Prepare +- [Please follow the data preparation in speech-to-text](https://github.com/pytorch/fairseq/blob/master/examples/speech_to_text/docs/mtedx_example.md) + + + +## Training + +#### Download pretrained models +- [Pretrained mbart model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/mbart.pt) +- [Pretrained w2v model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/xlsr_53_56k.pt) + + +#### Training scripts + +```bash +python train.py ${MANIFEST_ROOT} \ + --save-dir ${save_dir} \ + --user-dir examples/speech_text_joint_to_text \ + --train-subset train_es_en_tedx,train_es_es_tedx,train_fr_en_tedx,train_fr_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_en_tedx,train_pt_pt_tedx \ + --valid-subset valid_es_en_tedx,valid_es_es_tedx,valid_es_fr_tedx,valid_es_it_tedx,valid_es_pt_tedx,valid_fr_en_tedx,valid_fr_es_tedx,valid_fr_fr_tedx,valid_fr_pt_tedx,valid_it_en_tedx,valid_it_es_tedx,valid_it_it_tedx,valid_pt_en_tedx,valid_pt_es_tedx,valid_pt_pt_tedx \ + --config-yaml config.yaml --ddp-backend no_c10d \ + --num-workers 2 --task speech_text_joint_to_text \ + --criterion guided_label_smoothed_cross_entropy_with_accuracy \ + --label-smoothing 0.3 --guide-alpha 0.8 \ + --disable-text-guide-update-num 5000 --arch dualinputxmtransformer_base \ + --max-tokens 500000 --max-sentences 3 --max-tokens-valid 800000 \ + --max-source-positions 800000 --enc-grad-mult 2.0 \ + --attentive-cost-regularization 0.02 --optimizer adam \ + --clip-norm 1.0 --log-format simple --log-interval 200 \ + --keep-last-epochs 5 --seed 1 \ + --w2v-path ${w2v_path} \ + 
--load-pretrained-mbart-from ${mbart_path} \ + --max-update 1000000 --update-freq 4 \ + --skip-invalid-size-inputs-valid-test \ + --skip-encoder-projection --save-interval 1 \ + --attention-dropout 0.3 --mbart-dropout 0.3 \ + --finetune-w2v-params all --finetune-mbart-decoder-params all \ + --finetune-mbart-encoder-params all --stack-w2v-mbart-encoder \ + --drop-w2v-layers 12 --normalize \ + --lr 5e-05 --lr-scheduler inverse_sqrt --warmup-updates 5000 +``` + +## Evaluation +```bash +python ./fairseq_cli/generate.py \ + ${MANIFEST_ROOT} \ + --task speech_text_joint_to_text \ + --user-dir ./examples/speech_text_joint_to_text \ + --load-speech-only --gen-subset test_es_en_tedx \ + --path ${model} \ + --max-source-positions 800000 \ + --skip-invalid-size-inputs-valid-test \ + --config-yaml config.yaml \ + --infer-target-lang en \ + --max-tokens 800000 \ + --beam 5 \ + --results-path ${RESULTS_DIR} \ + --scoring sacrebleu +``` +The trained model can be downloaded [here](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/checkpoint17.pt). + +|direction|es_en|fr_en|pt_en|it_en|fr_es|pt_es|it_es|es_es|fr_fr|pt_pt|it_it| +|---|---|---|---|---|---|---|---|---|---|---|---| +|BLEU|31.62|36.93|35.07|27.12|38.87|35.57|34.13|74.59|74.64|70.84|69.76| diff --git a/examples/speech_text_joint_to_text/models/__init__.py b/examples/speech_text_joint_to_text/models/__init__.py new file mode 100644 index 0000000000..7a394c7e4f --- /dev/null +++ b/examples/speech_text_joint_to_text/models/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + model_name = file[: file.find(".py")] + importlib.import_module( + "examples.speech_text_joint_to_text.models."
+ model_name + ) diff --git a/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py b/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py new file mode 100644 index 0000000000..7970a3c714 --- /dev/null +++ b/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py @@ -0,0 +1,1090 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import namedtuple + +import torch +import torch.nn as nn +from fairseq import checkpoint_utils +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqDecoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.speech_to_text import ( + TransformerDecoder, + S2TTransformerEncoder, +) +from fairseq.models.transformer import TransformerEncoder +from fairseq.modules import ( + TransformerEncoderLayer, + GradMultiply, + LayerNorm, +) + +logger = logging.getLogger(__name__) + + +class SpeechEoSEncoder(FairseqEncoder): + def __init__(self, encoder, eos_num, feat_dim, adapter_type="None", adapter_dim=0): + super().__init__(None) + self.encoder = encoder + self.eos_num = eos_num # downsampling rate for speech input feature + self.eos_emb = ( + nn.Parameter(torch.zeros(1, feat_dim), requires_grad=True) + if eos_num > 0 + else None + ) + self.adapter = self.add_adapter(adapter_type, adapter_dim) + + def add_adapter(self, adapter_type, adapter_dim): + def _make_identity(linear, eps=1e-5): + assert isinstance(linear, nn.Linear) + linear.weight.data.mul_(eps) + linear.weight.data.fill_diagonal_(1.0) + if linear.bias is not None: + linear.bias.data.mul_(eps) + + adapter = None + if adapter_type == "Linear": + assert adapter_dim > 0 + adapter = nn.Sequential( + nn.Linear(adapter_dim, adapter_dim), 
LayerNorm(adapter_dim) + ) + # initialize the adapter as identity matrix first + _make_identity(adapter[0]) + + elif adapter_type == "MLP": + assert adapter_dim > 0 + # assume the model is pre-norm model + adapter = nn.Sequential( + nn.Linear(adapter_dim, 2 * adapter_dim), + nn.ReLU(), + nn.Linear(2 * adapter_dim, adapter_dim), + LayerNorm(adapter_dim), + ) + _make_identity(adapter[0]) + _make_identity(adapter[2]) + return adapter + + def add_eos(self, src_tokens, src_lengths): + bsz, max_seq_len, fdim = src_tokens.size() + if self.eos_num > 0: + src_token_eos = torch.zeros( + [bsz, max_seq_len + self.eos_num, fdim], + dtype=src_tokens.dtype, + device=src_tokens.device, + ) + src_token_eos[:, :max_seq_len] = src_tokens + for bi in range(bsz): + src_token_eos[bi][ + src_lengths[bi] : src_lengths[bi] + self.eos_num + ] = self.eos_emb.expand(self.eos_num, fdim) + src_lengths = src_lengths + self.eos_num + src_tokens = src_token_eos + return src_tokens, src_lengths + + def apply_adapter(self, enc_out): + if self.adapter is None: + return enc_out + rst = self.adapter(enc_out.encoder_out) + if enc_out.encoder_padding_mask is not None: + rst.masked_fill_( + enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0 + ) + return EncoderOut( + encoder_out=rst, + encoder_padding_mask=enc_out.encoder_padding_mask, + encoder_embedding=enc_out.encoder_embedding, + encoder_states=enc_out.encoder_states, + src_tokens=enc_out.src_tokens, + src_lengths=enc_out.src_lengths, + ) + + def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs): + """ + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + src_tokens, src_lengths = self.add_eos(src_tokens, src_lengths) + enc_out = self.encoder(src_tokens, src_lengths, return_all_hiddens) + enc_out = self.apply_adapter(enc_out) + return enc_out + + def reorder_encoder_out(self, encoder_out, new_order): + return 
self.encoder.reorder_encoder_out(encoder_out, new_order) + + +class DualInputEncoder(FairseqEncoder): + def __init__( + self, + args, + spch_encoder, + text_encoder, + dictionary, + cross_attentive_loss_before_last_layer=-1, + ): + super().__init__(dictionary) + + self.spch_encoder = spch_encoder + self.text_encoder = text_encoder + self.enc_grad_mult = args.enc_grad_mult + self.cross_attentive_loss_before_last_layer = ( + cross_attentive_loss_before_last_layer + ) + self.use_cross_attentive_loss = ( + False if cross_attentive_loss_before_last_layer <= -1 else True + ) + self.enc2_along_grad_mult = args.enc2_along_grad_mult + + @classmethod + def set_shared_layer(cls, share_level, src_layer, tgt_layer): + """ + share parameters from tgt_layer to src_layer + share_level: + 0: share everything + 1: share everything but different model + 2: share weight but not bias, layernorm + """ + if share_level == 0: + return tgt_layer + if isinstance(src_layer, nn.Linear): + return tgt_layer + if isinstance(src_layer, TransformerEncoderLayer): + assert src_layer.embed_dim == tgt_layer.embed_dim + assert src_layer.normalize_before == tgt_layer.normalize_before + if share_level == 1: + src_layer.fc1 = tgt_layer.fc1 + src_layer.fc2 = tgt_layer.fc2 + src_layer.self_attn = tgt_layer.self_attn + src_layer.final_layer_norm = tgt_layer.final_layer_norm + src_layer.self_attn_layer_norm = tgt_layer.self_attn_layer_norm + src_layer.layernorm_embedding = tgt_layer.layernorm_embedding + else: + src_layer.fc1.weight = tgt_layer.fc1.weight + src_layer.fc2.weight = tgt_layer.fc2.weight + src_layer.self_attn.k_proj.weight = tgt_layer.self_attn.k_proj.weight + src_layer.self_attn.v_proj.weight = tgt_layer.self_attn.v_proj.weight + src_layer.self_attn.q_proj.weight = tgt_layer.self_attn.q_proj.weight + src_layer.self_attn.out_proj.weight = ( + tgt_layer.self_attn.out_proj.weight + ) + else: + if share_level == 1: + return tgt_layer + return src_layer + + @classmethod + def build_spch_encoder(cls, 
args): + cfg = { + "input_feat_per_channel": args.input_feat_per_channel, + "input_channels": args.input_channels, + "conv_kernel_sizes": args.conv_kernel_sizes, + "conv_channels": args.conv_channels, + "encoder_embed_dim": args.encoder_embed_dim, + "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim, + "encoder_layers": args.speech_encoder_layers, + "encoder_layerdrop": args.encoder_layerdrop, + "encoder_attention_heads": args.encoder_attention_heads, + "max_source_positions": args.max_source_positions, + "dropout": args.dropout, + "encoder_normalize_before": args.encoder_normalize_before, + "activation_dropout": args.activation_dropout, + "attention_dropout": args.attention_dropout, + "activation_fn": args.activation_fn, + "layernorm_embedding": args.layernorm_embedding, + "no_token_positional_embeddings": args.no_token_positional_embeddings, + "no_scale_embedding": args.no_scale_embedding, + "quant_noise_pq": args.quant_noise_pq, + "encoder_freezing_updates": 0, + } + model_args = namedtuple("args", cfg.keys())(*cfg.values()) + spch_encoder = S2TTransformerEncoder(model_args) + if args.add_speech_eos: + spch_encoder = SpeechEoSEncoder( + spch_encoder, + 2 * len(args.conv_kernel_sizes.split(",")), + args.input_feat_per_channel, + adapter_type=getattr(args, "speech_encoder_adapter_type", "None"), + adapter_dim=args.encoder_embed_dim, + ) + return spch_encoder + + @classmethod + def build_text_encoder(cls, args, src_dictionary, spch_encoder): + if args.encoder_shared_layers > 0: + mx_shared_layers = ( + args.speech_encoder_layers + if args.speech_encoder_layers < args.text_encoder_layers + else args.text_encoder_layers + ) + args.encoder_shared_layers = ( + args.encoder_shared_layers + if args.encoder_shared_layers <= mx_shared_layers + else mx_shared_layers + ) + cfg = { + "encoder_embed_dim": args.encoder_text_embed_dim, + "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim, + "encoder_layers": args.text_encoder_layers, + "encoder_layerdrop": 
args.encoder_layerdrop, + "encoder_attention_heads": args.encoder_attention_heads, + "encoder_learned_pos": args.encoder_learned_pos, + "max_source_positions": args.max_source_positions, + "dropout": args.dropout, + "encoder_normalize_before": args.encoder_normalize_before, + "activation_dropout": args.activation_dropout, + "attention_dropout": args.attention_dropout, + "activation_fn": args.activation_fn, + "adaptive_input": args.adaptive_input, + "no_token_positional_embeddings": args.no_token_positional_embeddings, + "no_scale_embedding": args.no_scale_embedding, + "quant_noise_pq": args.quant_noise_pq, + } + model_args = namedtuple("args", cfg.keys())(*cfg.values()) + enc_emb = nn.Embedding( + len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad() + ) + text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb) + if args.add_speech_eos: + spch_encoder = spch_encoder.encoder + if args.encoder_shared_layers > 0: + text_encoder.layer_norm = cls.set_shared_layer( + args.encoder_shared_layer_level, + text_encoder.layer_norm, + spch_encoder.layer_norm, + ) + for i, ly in enumerate( + spch_encoder.transformer_layers[-args.encoder_shared_layers :] + ): + ly_id = i + args.text_encoder_layers - args.encoder_shared_layers + assert isinstance(text_encoder.layers[ly_id], type(ly)) + text_encoder.layers[ly_id] = cls.set_shared_layer( + args.encoder_shared_layer_level, + text_encoder.layers[ly_id], + ly, + ) + return text_encoder + + def mult_rst_grad(self, rst, ratio): + assert isinstance(rst, dict) # instead of EncoderOut + assert len(rst["encoder_out"]) == 1 + rst["encoder_out"][0] = GradMultiply.apply(rst["encoder_out"][0], ratio) + return rst + + def process_attentive_loss_states(self, rst, interstates): + assert isinstance(rst, dict) # instead of EncoderOut + rst["encoder_states"] = interstates + return rst + + def forward( + self, + src_tokens, + src_lengths=None, + src_txt_tokens=None, + src_txt_lengths=None, + **kwargs + ): + """ + Args: 
+ src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (speech) (B,) + src_txt_tokens: padded tensor (B, T) + src_txt_lengths: tensor of original lengths of input utterances (text) (B,) + """ + # src_tokens only: inference + # src_tokens, src_lengths: speech only training + # src_txt_tokens, src_txt_lengths: text only training + # all valid: speech + text training + + if src_tokens is None and src_txt_tokens is None: + raise ValueError( + "src_tokens and src_txt_tokens cannot be None at the same time" + ) + ret1 = None + ret2 = None + return_all_hiddens = False + if src_tokens is not None: + if ( + self.use_cross_attentive_loss and src_txt_tokens is not None + ): # remove self.training so we can get attn score during validation step + return_all_hiddens = True + ret1 = self.spch_encoder( + src_tokens, src_lengths, return_all_hiddens=return_all_hiddens + ) + + if self.use_cross_attentive_loss and src_txt_tokens is not None: + assert self.cross_attentive_loss_before_last_layer < len( + ret1["encoder_states"] + ) + ret1 = self.process_attentive_loss_states( + ret1, + ret1["encoder_states"][ + -self.cross_attentive_loss_before_last_layer - 1 + ], + ) + + if src_txt_tokens is not None: + ret2 = self.text_encoder( + src_txt_tokens, src_txt_lengths, return_all_hiddens=return_all_hiddens + ) + if return_all_hiddens: + if self.cross_attentive_loss_before_last_layer == len( + self.text_encoder.layers + ): + text_embedding, _ = self.text_encoder.forward_embedding( + src_txt_tokens + ) + text_embedding = text_embedding.transpose(0, 1) + ret2 = self.process_attentive_loss_states(ret2, text_embedding) + else: + assert self.cross_attentive_loss_before_last_layer < len( + self.text_encoder.layers + ) + ret2 = self.process_attentive_loss_states( + ret2, + ret2["encoder_states"][ + -self.cross_attentive_loss_before_last_layer - 1 + ], + ) + + def merge_output(rst1, rst2): + if rst1 is None: + if not (self.enc2_along_grad_mult == 
1.0 or self.training): + rst2 = self.mult_rst_grad(rst2, self.enc2_along_grad_mult) + return rst2 + if rst2 is None: + return rst1 + if self.enc_grad_mult != 1.0 and self.training: + rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult) + rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult) + rst = (rst1, rst2) + return rst + + return merge_output(ret1, ret2) + + def reorder_encoder_out(self, encoder_out, new_order): + assert self.training is False # used for inference only + return self.spch_encoder.reorder_encoder_out(encoder_out, new_order) + + +# TransformerMultiInputDecoder: take one or two encoder inputs +class TransformerMultiInputDecoder(FairseqDecoder): + def __init__( + self, + dictionary, + spch_decoder, + text_decoder, + compute_cross_attentive_loss=False, + cross_attentive_loss_with_norm=True, + cross_attentive_loss_reverse=False, + ): + + super().__init__(dictionary) + self.spch_decoder = spch_decoder + self.text_decoder = text_decoder + self.compute_cross_attentive_loss = compute_cross_attentive_loss + self.cross_attentive_loss_with_norm = cross_attentive_loss_with_norm + self.cross_attentive_loss_reverse = cross_attentive_loss_reverse + + @classmethod + def share_spchdecoder(cls, task_args, text_decoder, spch_decoder): + if task_args.decoder_shared_layer_level == 0: + return text_decoder + assert text_decoder.embed_tokens == spch_decoder.embed_tokens + spch_decoder.project_in_dim = text_decoder.project_in_dim + spch_decoder.embed_positions = text_decoder.embed_positions + spch_decoder.layernorm_embedding = text_decoder.layernorm_embedding + spch_decoder.project_out_dim = text_decoder.project_out_dim + spch_decoder.adaptive_softmax = text_decoder.adaptive_softmax + if task_args.decoder_shared_layer_level == 1: + spch_decoder.output_projection = text_decoder.output_projection + spch_decoder.layer_norm = text_decoder.layer_norm + else: # 2 + spch_decoder.output_projection.weight = ( + text_decoder.output_projection.weight + ) + for i, ly in 
enumerate(text_decoder.layers): + sly = spch_decoder.layers[i] + sly.self_attn = ly.self_attn + sly.self_attn_layer_norm = ly.self_attn_layer_norm + # sly.encoder_attn = ly.encoder_attn + if ( + task_args.decoder_shared_layer_level == 1 + ): # share everything, but under different models + sly.encoder_attn = ly.encoder_attn + sly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm + sly.fc1 = ly.fc1 + sly.fc2 = ly.fc2 + sly.final_layer_norm = ly.final_layer_norm + else: # task_args.decoder_shared_layer_level == 2: #separated encoder_attn_layer_norm and bias + sly.encoder_attn.k_proj.weight = ly.encoder_attn.k_proj.weight + sly.encoder_attn.v_proj.weight = ly.encoder_attn.v_proj.weight + sly.encoder_attn.q_proj.weight = ly.encoder_attn.q_proj.weight + sly.encoder_attn.out_proj.weight = ly.encoder_attn.out_proj.weight + sly.fc1.weight = ly.fc1.weight + sly.fc2.weight = ly.fc2.weight + + return spch_decoder + + def cross_attentive_loss( + self, teacher_states, student_states, teacher_masking, student_masking, eps=1e-6 + ): + x = teacher_states.transpose(0, 1) # from T X B X D to B X T X D + y = student_states.transpose(0, 1) + if self.cross_attentive_loss_with_norm: + x = x / (x.norm(dim=2, keepdim=True) + eps) + y = y / (y.norm(dim=2, keepdim=True) + eps) + dim = x.size(-1) + # lengths: batch X seqLen + sim_scores_xy = torch.bmm(x, y.transpose(1, 2)) # batch X lenx X leny ] + if y.dtype == torch.float16: + sim_scores_xy = sim_scores_xy.float() + y = y.float() + x = x.float() + if teacher_masking != []: + assert len(teacher_masking) == 1 + sim_scores_xy = sim_scores_xy.masked_fill( + teacher_masking[0].unsqueeze(-1), float("-inf") + ) + if student_masking != []: + sim_scores_xy = sim_scores_xy.masked_fill( + student_masking[0].unsqueeze(1), float("-inf") + ) + # do masking + y_weights = utils.softmax(sim_scores_xy, dim=-1) + if teacher_masking != []: + y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0) + x_reconstruct_from_y = torch.bmm(y_weights, 
y) + + sim_scores_xx = torch.bmm(x, x.transpose(1, 2)) # batch X lenx X lenx ] + x_weights = utils.softmax(sim_scores_xx, dim=-1) + if teacher_masking != []: + x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0) + + # no gradient for teacher state + x_reconstruct_from_x = torch.bmm(x_weights, x).detach() + cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2) + if teacher_masking != []: + cost = cost.masked_fill(teacher_masking[0], 0) + + if not self.cross_attentive_loss_with_norm: + cost = cost / dim + return cost + + def forward( + self, + prev_output_tokens, + encoder_out, + incremental_state=None, + has_txt_input=False, + **kwargs + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for input feeding/teacher forcing. If there are + two or more input during training, they will share the same prev_output_tokens + encoder_out (tuple[Tensor]): output from the encoder, used for + encoder-side attention. It will be tuple if there are more inputs, but a tensor + if only one input + incremental_state ([dict]): dictionary used for storing state during + :ref:`Incremental decoding`. It is only valid for inference, only from single + input + Returns: + tuple: + - the last decoder layer's output of shape `(batch, tgt_len, + vocab)`. 
If there are N inputs, batch will be N times larger than for a single input + - the last decoder layer's attention weights of shape `(batch, + tgt_len, src_len)` + """ + assert not isinstance(encoder_out, EncoderOut) + if isinstance(encoder_out, tuple): # training with multiple inputs + rst = [] + assert len(encoder_out) == 2 + for i, eo in enumerate(encoder_out): + assert incremental_state is None + if i == 0: + rst.append( + self.spch_decoder(prev_output_tokens, eo, incremental_state) + ) + else: + rst.append( + self.text_decoder(prev_output_tokens, eo, incremental_state) + ) + dec_out = torch.cat([r[0] for r in rst], dim=0) + attn_cost = None + if self.compute_cross_attentive_loss: + assert isinstance(encoder_out[0], dict) + if self.cross_attentive_loss_reverse: + attn_cost = self.cross_attentive_loss( + teacher_states=encoder_out[1]["encoder_states"], # text_states + student_states=encoder_out[0]["encoder_states"], # spch_states + teacher_masking=encoder_out[1]["encoder_padding_mask"], + student_masking=encoder_out[0]["encoder_padding_mask"], + ) + else: + attn_cost = self.cross_attentive_loss( + teacher_states=encoder_out[0]["encoder_states"], # spch_states + student_states=encoder_out[1]["encoder_states"], # text_states + teacher_masking=encoder_out[0]["encoder_padding_mask"], + student_masking=encoder_out[1]["encoder_padding_mask"], + ) + + return (dec_out, {"attn_cost": attn_cost}) + else: # inference or training with one input + if has_txt_input: + return self.text_decoder( + prev_output_tokens, encoder_out, incremental_state + ) + return self.spch_decoder(prev_output_tokens, encoder_out, incremental_state) + + +# Note: +# dual input transformer: +# encoder: S2TTransformerEncoder for speech + TransformerEncoder for text +# decoder: TransformerDecoder for text +@register_model("dual_input_s2t_transformer") +class DualInputS2TTransformerModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + self.num_updates = 0
+ + def max_positions(self): + return None # it is provided in task + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # encoder 1: S2TTransformerEncoder for speech + parser.add_argument( + "--conv-kernel-sizes", + type=str, + metavar="N", + help="kernel sizes of Conv1d subsampling layers", + ) + parser.add_argument( + "--conv-channels", + type=int, + metavar="N", + help="# of channels in Conv1d subsampling layers", + ) + parser.add_argument( + "--enc-output-dim", + type=int, + metavar="N", + help=""" + encoder output dimension, can be None. If specified, projecting the + transformer output to the specified dimension""", + ) + # standard Transformer + parser.add_argument( + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-text-embed-dim", + type=int, + metavar="N", + help="encoder text embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + 
parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + # non-standard transformer parameters + parser.add_argument( + "--speech-encoder-layers", + type=int, + metavar="N", + help="num speech encoder layers", + ) + parser.add_argument( + "--text-encoder-layers", + type=int, + metavar="N", + help="num text encoder layers", + ) + parser.add_argument( + "--encoder-shared-layers", + type=int, + metavar="N", + help="num shared encoder layers", + ) + parser.add_argument( + "--encoder-shared-layer-level", + type=int, + metavar="N", + default=0, + choices=[0, 1, 2], + help="share layer level 0: all share 1: all share with separate model 2: share weight but not bias and layernorm", + ) + + parser.add_argument( + "--decoder-shared-layer-level", + default=0, + choices=[0, 1, 2], + type=int, + metavar="N", + help="0: share everything; 1: share everything with different model 2: no share layer_norm and bias", + ) + ### + parser.add_argument( + "--text-input-cost-ratio", + type=float, + default=1.0, + metavar="V", + help="text input cost ratio relative to speech input cost", + ) + parser.add_argument( + "--init-scale", + type=float, + default=1.0, + metavar="V", + help="scale the initial weight by given factor", + ) + parser.add_argument( + "--enc-grad-mult", + type=float, + metavar="V", + default=1.0, + help="multiply enc1 and enc2 gradient by V", + ) + parser.add_argument( + "--enc2-along-grad-mult", + type=float, + metavar="V", + default=1.0, + help="multiply enc2 gradient by V if only enc2 is used", + ) + parser.add_argument( + "--load-pretrain-encoder", + type=str, + default="", + 
metavar="EXPR", + help=""" path to the pretrained encoder """, + ) + parser.add_argument( + "--load-pretrain-speech-encoder", + type=str, + default="", + metavar="EXPR", + help=""" path to the pretrained speech encoder """, + ) + parser.add_argument( + "--load-pretrain-text-encoder", + type=str, + default="", + metavar="EXPR", + help=""" path to the pretrained text encoder """, + ) + parser.add_argument( + "--load-pretrain-text-encoder-last", + type=str, + default="", + metavar="EXPR", + help=""" path to the pretrained text encoder (loaded after the speech encoder) """, + ) + parser.add_argument( + "--load-pretrain-decoder", + type=str, + metavar="EXPR", + default="", + help=""" path to the pretrained decoder """, + ) + parser.add_argument( + "--add-speech-eos", + action="store_true", + help="add eos token at the end of input feature", + ) + parser.add_argument( + "--speech-encoder-adapter-type", + type=str, + metavar="EXPR", + default="None", + choices=["None", "Linear", "MLP"], + help="add speech encoder adapter", + ) + + @classmethod + def build_encoder(cls, args, task): + spch_encoder = DualInputEncoder.build_spch_encoder(args) + text_encoder = DualInputEncoder.build_text_encoder( + args, task.src_dict, spch_encoder + ) + cross_attentive_loss_before_last_layer = ( + 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1 + ) + encoder = DualInputEncoder( + args, + spch_encoder, + text_encoder, + task.src_dict, + cross_attentive_loss_before_last_layer, + ) + if args.init_scale != 1.0: + with torch.no_grad(): + for param in encoder.parameters(): + param.data.mul_(args.init_scale) + if args.load_pretrain_text_encoder != "": + checkpoint_utils.load_pretrained_component_from_model( + text_encoder, args.load_pretrain_text_encoder + ) + if args.load_pretrain_speech_encoder != "": + if hasattr(spch_encoder, "encoder"): + checkpoint_utils.load_pretrained_component_from_model( + spch_encoder.encoder, args.load_pretrain_speech_encoder + ) + else: + 
checkpoint_utils.load_pretrained_component_from_model( + spch_encoder, args.load_pretrain_speech_encoder + ) + if ( + args.load_pretrain_text_encoder_last != "" + ): # if share encoder, speech encoder parameters will be used. + # It provides a chance to use pre-trained mt encoder instead + checkpoint_utils.load_pretrained_component_from_model( + text_encoder, args.load_pretrain_text_encoder_last + ) + + if args.load_pretrain_encoder != "": + checkpoint_utils.load_pretrained_component_from_model( + encoder, args.load_pretrain_encoder + ) + return encoder + + @classmethod + def build_decoder(cls, args, task): + dec_cfg = { + "decoder_layerdrop": args.decoder_layerdrop, + "share_decoder_input_output_embed": args.share_decoder_input_output_embed, + "decoder_embed_dim": args.decoder_embed_dim, + "max_target_positions": args.max_target_positions, + "dropout": args.dropout, + "encoder_learned_pos": args.encoder_learned_pos, + "decoder_learned_pos": args.decoder_learned_pos, + "layernorm_embedding": args.layernorm_embedding, + "decoder_normalize_before": args.decoder_normalize_before, + "activation_dropout": args.activation_dropout, + "attention_dropout": args.attention_dropout, + "decoder_ffn_embed_dim": args.decoder_ffn_embed_dim, + "decoder_layers": args.decoder_layers, + "decoder_attention_heads": args.decoder_attention_heads, + "decoder_output_dim": args.decoder_embed_dim, + "no_scale_embedding": args.no_scale_embedding, + "adaptive_input": args.adaptive_input, + "quant_noise_pq": args.quant_noise_pq, + "adaptive_softmax_cutoff": args.adaptive_softmax_cutoff, + "tie_adaptive_weights": args.tie_adaptive_weights, + "no_token_positional_embeddings": args.no_token_positional_embeddings, + } + dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values()) + dec_emb = nn.Embedding( + len(task.target_dictionary), + args.decoder_embed_dim, + task.target_dictionary.pad(), + ) + compute_cross_attentive_loss = ( + True if getattr(args, "attentive_cost_regularization", 0.0) > 
0.0 else False + ) + cross_attentive_loss_without_norm = getattr( + args, "attentive_cost_without_normalize", False + ) + cross_attentive_loss_reverse = ( + False # getattr(args, "attentive_cost_reverse", False) + ) + + text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb) + spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb) + spch_decoder = TransformerMultiInputDecoder.share_spchdecoder( + args, text_decoder, spch_decoder + ) + decoder = TransformerMultiInputDecoder( + dictionary=task.target_dictionary, + spch_decoder=spch_decoder, + text_decoder=text_decoder, + compute_cross_attentive_loss=compute_cross_attentive_loss, + cross_attentive_loss_with_norm=True + if not cross_attentive_loss_without_norm + else False, + cross_attentive_loss_reverse=cross_attentive_loss_reverse, + ) + if args.init_scale != 1.0: + with torch.no_grad(): + for param in decoder.parameters(): + param.data.mul_(args.init_scale) + if args.load_pretrain_decoder != "": + try: + checkpoint_utils.load_pretrained_component_from_model( + decoder, args.load_pretrain_decoder + ) + except RuntimeError: + checkpoint_utils.load_pretrained_component_from_model( + decoder.text_decoder, args.load_pretrain_decoder + ) + if args.decoder_shared_layer_level > 0: + checkpoint_utils.load_pretrained_component_from_model( + decoder.spch_decoder, args.load_pretrain_decoder + ) + + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted + # (in case there are any new ones) + dualinputs2ttransformer_base(args) + + encoder = cls.build_encoder(args, task) + decoder = cls.build_decoder(args, task) + return cls(encoder, decoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + lprobs.batch_first = True + return lprobs + + def 
set_num_updates(self, num_updates): + """Set the number of parameter updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + use_encoder_outputs=False, + src_txt_tokens=None, + src_txt_lengths=None, + mode="sup_speech", + **kwargs + ): + """ + Run the forward pass for an encoder-decoder model. + + First feed a batch of source tokens through the encoder. Then, feed the + encoder output and previous decoder outputs (i.e., teacher forcing) to + the decoder to produce the next outputs:: + + encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder(prev_output_tokens, encoder_out) + + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + mode (str): 'sup_speech' or 'text'; in 'text' mode, src_tokens and src_lengths are used as the text input + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + if mode == "text": + assert src_txt_tokens is None + src_txt_tokens = src_tokens + src_txt_lengths = src_lengths + src_tokens = None + src_lengths = None + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + src_txt_tokens=src_txt_tokens, + src_txt_lengths=src_txt_lengths, + **kwargs + ) + has_txt_input = True if src_txt_tokens is not None else False + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + has_txt_input=has_txt_input, + **kwargs + ) + if use_encoder_outputs: + return decoder_out, encoder_out + return decoder_out + + +@register_model_architecture( + "dual_input_s2t_transformer", "dualinputs2ttransformer_base" +) +def dualinputs2ttransformer_base(args): + args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) + # Convolutional subsampler + 
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") + args.conv_channels = getattr(args, "conv_channels", 1024) + # Transformer + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_text_embed_dim = getattr( + args, "encoder_text_embed_dim", args.encoder_embed_dim + ) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", args.dropout) + args.activation_dropout = getattr(args, "activation_dropout", args.dropout) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, 
"decoder_layerdrop", 0.0) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) + args.encoder_shared_layers = getattr(args, "encoder_shared_layers", 0) + args.decoder_layers = getattr(args, "decoder_layers", 6) + + args.add_speech_eos = getattr(args, "add_speech_eos", False) + + +@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_s") +def dualinputs2ttransformer_s(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 7) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 7) + args.decoder_layers = getattr(args, "decoder_layers", 7) + dualinputs2ttransformer_base(args) + + +@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_m") +def dualinputs2ttransformer_m(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.dropout = getattr(args, "dropout", 0.15) + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 
6) + args.decoder_layers = getattr(args, "decoder_layers", 6) + dualinputs2ttransformer_base(args) + + +@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_b") +def dualinputs2ttransformer_b(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + args.dropout = getattr(args, "dropout", 0.15) + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) + args.decoder_layers = getattr(args, "decoder_layers", 6) + dualinputs2ttransformer_base(args) + + +@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_l") +def dualinputs2ttransformer_l(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.2) + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) + args.decoder_layers = getattr(args, "decoder_layers", 6) + dualinputs2ttransformer_base(args) diff --git a/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py b/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py new file mode 100644 index 0000000000..6c853b96ed --- /dev/null +++ b/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py @@ -0,0 +1,584 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import torch.nn as nn +from fairseq import checkpoint_utils +from fairseq import utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + register_model, + register_model_architecture, + FairseqEncoder, +) +from fairseq.models.speech_to_text import XMTransformerModel, Wav2VecEncoderWithAdaptor +from fairseq.models.speech_to_text.xm_transformer import ( + set_default_adaptor_args, + set_default_w2v_encoder_args, +) +from fairseq.models.transformer import TransformerEncoder, TransformerDecoder +from fairseq.models.wav2vec import TransformerSentenceEncoderLayer + +from .s2t_dualinputtransformer import ( + DualInputS2TTransformerModel, + TransformerMultiInputDecoder, + DualInputEncoder, +) + + +class TransformerSentenceEncoderLayerStd(TransformerSentenceEncoderLayer): + def __init__(self, sent_enc_layer): + super(TransformerSentenceEncoderLayer, self).__init__() + self.embedding_dim = sent_enc_layer.embedding_dim + self.dropout = sent_enc_layer.dropout + self.activation_dropout = sent_enc_layer.activation_dropout + + # Initialize blocks + self.activation_fn = sent_enc_layer.activation_fn + self.self_attn = sent_enc_layer.self_attn + + self.dropout1 = sent_enc_layer.dropout1 + self.dropout2 = sent_enc_layer.dropout2 + self.dropout3 = sent_enc_layer.dropout3 + + self.layer_norm_first = sent_enc_layer.layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = sent_enc_layer.self_attn_layer_norm + self.fc1 = sent_enc_layer.fc1 + self.fc2 = sent_enc_layer.fc2 + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = sent_enc_layer.final_layer_norm + + def forward( + self, + x, + self_attn_mask=None, + self_attn_padding_mask=None, + need_weights=None, + att_args=None, + ): + x, attn = super().forward( + x, 
self_attn_mask, self_attn_padding_mask, need_weights, att_args + ) + return x + + +# TODO retire SharedEncoder +class SharedEncoder(FairseqEncoder): + def __init__(self, wav2vec_enc, mbart_enc, adaptor, shared_layers): + super().__init__(None) + self.w2v_encoder = wav2vec_enc + self.shared_layers = self.w2v_encoder.w2v_model.encoder.layers[-shared_layers:] + self.w2v_encoder.w2v_model.encoder.layers = ( + self.w2v_encoder.w2v_model.encoder.layers[:-shared_layers] + ) + self.adaptor = adaptor + if self.shared_layers[-1].layer_norm_first: + self.final_layer_norm = mbart_enc.layer_norm + else: + mbart_enc.layer_norm = None + self.final_layer_norm = None + shared_layer_from = len(mbart_enc.layers) - shared_layers + if shared_layer_from < 0: + shared_layer_from = 0 + for layer_id, layer in enumerate(self.shared_layers): + mbart_enc.layers[ + shared_layer_from + layer_id + ] = TransformerSentenceEncoderLayerStd(layer) + + def forward(self, src_tokens, src_lengths=None, **kwargs): + padding_mask = lengths_to_padding_mask(src_lengths) + if not padding_mask.any(): + padding_mask = None + + out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True) + x = out["encoder_out"] + enc_padding_mask = None + if out["encoder_padding_mask"] is not None: + enc_padding_mask = out["encoder_padding_mask"].transpose( + 0, 1 + ) # T X B --> B X T + + x, enc_padding_mask = self.adaptor(x, enc_padding_mask) + for layer in self.shared_layers: + x, _ = layer(x, enc_padding_mask) + if self.final_layer_norm is not None: + x = self.final_layer_norm(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [enc_padding_mask] + if enc_padding_mask is not None + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": [], # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + +class StackedWav2VecEncoderWithAdaptor(FairseqEncoder): + def __init__( + self, + wav2vec_enc, + mbart_enc_layers, + mbart_layer_norm, + adaptor, + drop_w2v_layers=0, + 
): + super().__init__(None) + self.w2v_encoder = wav2vec_enc + self.adaptor = adaptor + self.mbart_encoder_layers = mbart_enc_layers + self.final_layer_norm = mbart_layer_norm + if drop_w2v_layers > 0: + self.w2v_encoder.w2v_model.encoder.layers = ( + self.w2v_encoder.w2v_model.encoder.layers[:-drop_w2v_layers] + ) + + def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs): + padding_mask = lengths_to_padding_mask(src_lengths) + if not padding_mask.any(): + padding_mask = None + + out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True) + x = out["encoder_out"] + enc_padding_mask = None + if out["encoder_padding_mask"] is not None: + enc_padding_mask = out["encoder_padding_mask"].transpose( + 0, 1 + ) # T X B --> B X T + + x, enc_padding_mask = self.adaptor(x, enc_padding_mask) + encoder_states = [] + for layer in self.mbart_encoder_layers: + x = layer(x, enc_padding_mask) + if return_all_hiddens: + encoder_states.append(x) + if self.final_layer_norm is not None: + x = self.final_layer_norm(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [enc_padding_mask] + if enc_padding_mask is not None + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in 
enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } + + +# Note: +# dual input transformer: +# encoder: wav2vec for speech + mbart encoder for text +# decoder: mbart decoder for text +@register_model("dual_input_xm_transformer") +class DualInputXMTransformerModel(DualInputS2TTransformerModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # wav2vec encoder + Wav2VecEncoderWithAdaptor.add_args(parser) + # add_decoder_args(parser) + # mbart Transformer + parser.add_argument( + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + + parser.add_argument( + "--mbart-dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--mbart-attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--mbart-activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", 
+ help="apply layernorm before each encoder block", + ) + + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-mbart-from", + type=str, + metavar="STR", + help="model to take text encoder decoder weights from (for initialization)", + ) + # parser.add_argument("--finetune-w2v-params", type=str, metavar="STR", + # help="comma-separated param strings to finetune.") + parser.add_argument( + "--finetune-mbart-decoder-params", + type=str, + metavar="STR", + help="comma-separated param strings to finetune.", + ) + parser.add_argument( + "--finetune-mbart-encoder-params", + type=str, + metavar="STR", + help="comma-separated param strings to finetune.", + ) + parser.add_argument( + "--skip-encoder-projection", + action="store_true", + help="skip the projection layer in encoder", + ) + + parser.add_argument( + "--enc-grad-mult", + type=float, + metavar="V", + default=1.0, + help="multiply enc1 and enc2 gradient by V", + ) + parser.add_argument( + "--enc2-along-grad-mult", + type=float, + metavar="V", + default=1.0, + help="multiply enc2 gradient by V if only enc2 is used", + ) + parser.add_argument( + "--text-input-cost-ratio", + type=float, + default=1.0, + 
metavar="V", + help="text input cost ratio relative to speech input cost", + ) + parser.add_argument( + "--stack-w2v-mbart-encoder", + action="store_true", + help="stack w2v and mbart encoder", + ) + parser.add_argument( + "--stack-w2v-mbart-nonorm-encoder", + action="store_true", + help="stack w2v and mbart encoder", + ) + parser.add_argument( + "--no-final-norm-decoder", action="store_true", help="no layer norm" + ) + parser.add_argument( + "--drop-w2v-layers", + type=int, + default=0, + metavar="N", + help="drop w2v encoder layers", + ) + + parser.add_argument( + "--share-w2v-text-encoder", + action="store_true", + help="share w2v encoder layers with text encoder", + ) + parser.add_argument( + "--shared-w2v-layers", + type=int, + default=0, + metavar="N", + help="shared encoder layers from w2v encoder", + ) + + @classmethod + def build_encoder(cls, args, task): + _args = copy.deepcopy(args) + _args.dropout = args.mbart_dropout + _args.attention_dropout = args.mbart_attention_dropout + _args.activation_dropout = args.mbart_activation_dropout + _args.max_source_positions = 1024 + enc_emb = nn.Embedding( + len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad() + ) + text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb) + spch_encoder = Wav2VecEncoderWithAdaptor(args) + if getattr(args, "load_pretrained_mbart_from", None): + text_encoder = checkpoint_utils.load_pretrained_component_from_model( + component=text_encoder, checkpoint=args.load_pretrained_mbart_from + ) + if getattr(args, "stack_w2v_mbart_encoder", False): + assert getattr(args, "share_w2v_text_encoder", False) is False + spch_encoder = StackedWav2VecEncoderWithAdaptor( + spch_encoder.w2v_encoder, + text_encoder.layers, + text_encoder.layer_norm, + spch_encoder.adaptor, + args.drop_w2v_layers, + ) + elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False): + text_encoder.layer_norm = None + spch_encoder = StackedWav2VecEncoderWithAdaptor( + spch_encoder.w2v_encoder, + 
text_encoder.layers, + text_encoder.layer_norm, + spch_encoder.adaptor, + args.drop_w2v_layers, + ) + elif getattr(args, "share_w2v_text_encoder", False): + spch_encoder = SharedEncoder( + spch_encoder.w2v_encoder, + text_encoder, + spch_encoder.adaptor, + args.shared_w2v_layers, + ) + + for k, p in spch_encoder.named_parameters(): + # Freeze pretrained models by default + if hasattr( + args, "finetune_w2v_params" + ) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k): + p.requires_grad = True + else: + p.requires_grad = False + for k, p in text_encoder.named_parameters(): + # Freeze pretrained models by default + if hasattr( + args, "finetune_mbart_encoder_params" + ) and XMTransformerModel.finetune_params( + args.finetune_mbart_encoder_params, k + ): + p.requires_grad = True + else: + p.requires_grad = False + cross_attentive_loss_before_last_layer = ( + 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1 + ) + encoder = DualInputEncoder( + args, + spch_encoder, + text_encoder, + task.src_dict, + cross_attentive_loss_before_last_layer, + ) + return encoder + + @classmethod + def build_decoder(cls, args, task): + _args = copy.deepcopy(args) + _args.dropout = args.mbart_dropout + _args.attention_dropout = args.mbart_attention_dropout + _args.activation_dropout = args.mbart_activation_dropout + _args.max_target_positions = 1024 + dec_emb = nn.Embedding( + len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad() + ) + decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb) + if getattr(args, "load_pretrained_mbart_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_mbart_from + ) + if getattr(args, "no_final_norm_decoder", False): + decoder.layer_norm = None + for k, p in decoder.named_parameters(): + # Freeze pretrained models by default + if hasattr( + args, "finetune_mbart_decoder_params" + ) and XMTransformerModel.finetune_params( + 
args.finetune_mbart_decoder_params, k + ): + p.requires_grad = True + else: + p.requires_grad = False + + compute_cross_attentive_loss = ( + True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False + ) + cross_attentive_loss_without_norm = getattr( + args, "attentive_cost_without_normalize", False + ) + cross_attentive_loss_reverse = ( + False # getattr(args, "attentive_cost_reverse", False) + ) + decoder = TransformerMultiInputDecoder( + dictionary=task.target_dictionary, + spch_decoder=decoder, + text_decoder=decoder, + compute_cross_attentive_loss=compute_cross_attentive_loss, + cross_attentive_loss_with_norm=True + if not cross_attentive_loss_without_norm + else False, + cross_attentive_loss_reverse=cross_attentive_loss_reverse, + ) + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted + # (in case there are any new ones) + dualinputxmtransformer_base(args) + + encoder = cls.build_encoder(args, task) + decoder = cls.build_decoder(args, task) + return cls(encoder, decoder) + + +@register_model_architecture("dual_input_xm_transformer", "dualinputxmtransformer_base") +def dualinputxmtransformer_base(args): + # wav2vec encoder + set_default_w2v_encoder_args(args) + set_default_adaptor_args(args) + + # mbart model + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr( + args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim + ) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, 
"decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + + args.adaptive_input = getattr(args, "adaptive_input", False) + + args.mbart_attention_dropout = getattr(args, "mbart_attention_dropout", 0.0) + args.mbart_activation_dropout = getattr(args, "mbart_activation_dropout", 0.0) + args.mbart_dropout = getattr(args, "mbart_dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", True + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) diff --git a/examples/speech_text_joint_to_text/scripts/g2p_encode.py b/examples/speech_text_joint_to_text/scripts/g2p_encode.py new file mode 100644 index 0000000000..9db779396f --- /dev/null +++ b/examples/speech_text_joint_to_text/scripts/g2p_encode.py @@ -0,0 +1,191 @@ +# Copyright (c) 
Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import itertools +import logging +import re +import time + +from g2p_en import G2p + +logger = logging.getLogger(__name__) + +FAIL_SENT = "FAILED_SENTENCE" + + +def parse(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--out-path", type=str, required=True) + parser.add_argument("--lower-case", action="store_true") + parser.add_argument("--do-filter", action="store_true") + parser.add_argument("--use-word-start", action="store_true") + parser.add_argument("--dup-vowel", default=1, type=int) + parser.add_argument("--dup-consonant", default=1, type=int) + parser.add_argument("--no-punc", action="store_true") + parser.add_argument("--reserve-word", type=str, default="") + parser.add_argument( + "--reserve-first-column", + action="store_true", + help="first column is sentence id", + ) + ### + parser.add_argument("--parallel-process-num", default=1, type=int) + parser.add_argument("--logdir", default="") + args = parser.parse_args() + return args + + +def process_sent(sent, g2p, res_wrds, args): + sents = pre_process_sent(sent, args.do_filter, args.lower_case, res_wrds) + pho_seqs = [do_g2p(g2p, s, res_wrds, i == 0) for i, s in enumerate(sents)] + pho_seq = ( + [FAIL_SENT] + if [FAIL_SENT] in pho_seqs + else list(itertools.chain.from_iterable(pho_seqs)) + ) + if args.no_punc: + pho_seq = remove_punc(pho_seq) + if args.dup_vowel > 1 or args.dup_consonant > 1: + pho_seq = dup_pho(pho_seq, args.dup_vowel, args.dup_consonant) + if args.use_word_start: + pho_seq = add_word_start(pho_seq) + return " ".join(pho_seq) + + +def remove_punc(sent): + ns = [] + regex = re.compile("[^a-zA-Z0-9 ]") + for p in sent: + if (not regex.search(p)) or p == FAIL_SENT: + if p == " " and (len(ns) == 0 or ns[-1] == " "): + continue + 
ns.append(p) + return ns + + +def do_g2p(g2p, sent, res_wrds, is_first_sent): + if sent in res_wrds: + pho_seq = [res_wrds[sent]] + else: + pho_seq = g2p(sent) + if not is_first_sent: + pho_seq = [" "] + pho_seq # add space to separate + return pho_seq + + +def pre_process_sent(sent, do_filter, lower_case, res_wrds): + if do_filter: + sent = re.sub("-", " ", sent) + sent = re.sub("—", " ", sent) + if len(res_wrds) > 0: + wrds = sent.split() + wrds = ["SPLIT_ME " + w + " SPLIT_ME" if w in res_wrds else w for w in wrds] + sents = [x.strip() for x in " ".join(wrds).split("SPLIT_ME") if x.strip() != ""] + else: + sents = [sent] + if lower_case: + sents = [s.lower() if s not in res_wrds else s for s in sents] + return sents + + +def dup_pho(sent, dup_v_num, dup_c_num): + """ + duplicate phoneme defined as cmudict + http://www.speech.cs.cmu.edu/cgi-bin/cmudict + """ + if dup_v_num == 1 and dup_c_num == 1: + return sent + ns = [] + for p in sent: + ns.append(p) + if re.search(r"\d$", p): + for i in range(1, dup_v_num): + ns.append(f"{p}-{i}P") + elif re.search(r"\w", p): + for i in range(1, dup_c_num): + ns.append(f"{p}-{i}P") + return ns + + +def add_word_start(sent): + ns = [] + do_add = True + ws = "▁" + for p in sent: + if do_add: + p = ws + p + do_add = False + if p == " ": + do_add = True + else: + ns.append(p) + return ns + + +def load_reserve_word(reserve_word): + if reserve_word == "": + return [] + with open(reserve_word, "r") as fp: + res_wrds = [x.strip().split() for x in fp.readlines() if x.strip() != ""] + assert sum([0 if len(x) == 2 else 1 for x in res_wrds]) == 0 + res_wrds = dict(res_wrds) + return res_wrds + + +def process_sents(sents, args): + g2p = G2p() + out_sents = [] + res_wrds = load_reserve_word(args.reserve_word) + for sent in sents: + col1 = "" + if args.reserve_first_column: + col1, sent = sent.split(None, 1) + sent = process_sent(sent, g2p, res_wrds, args) + if args.reserve_first_column and col1 != "": + sent = f"{col1} {sent}" + 
out_sents.append(sent) + return out_sents + + +def main(): + args = parse() + out_sents = [] + with open(args.data_path, "r") as fp: + sent_list = [x.strip() for x in fp.readlines()] + if args.parallel_process_num > 1: + try: + import submitit + except ImportError: + logger.warning( + "submitit is not found and only one job is used to process the data" + ) + submitit = None + + if args.parallel_process_num == 1 or submitit is None: + out_sents = process_sents(sent_list, args) + else: + # process sentences with parallel computation + lsize = len(sent_list) // args.parallel_process_num + 1 + executor = submitit.AutoExecutor(folder=args.logdir) + executor.update_parameters(timeout_min=1000, cpus_per_task=4) + jobs = [] + for i in range(args.parallel_process_num): + job = executor.submit( + process_sents, sent_list[lsize * i : lsize * (i + 1)], args + ) + jobs.append(job) + is_running = True + while is_running: + time.sleep(5) + is_running = sum([job.done() for job in jobs]) < len(jobs) + out_sents = list(itertools.chain.from_iterable([job.result() for job in jobs])) + with open(args.out_path, "w") as fp: + fp.write("\n".join(out_sents) + "\n") + + +if __name__ == "__main__": + main() diff --git a/examples/speech_text_joint_to_text/tasks/__init__.py b/examples/speech_text_joint_to_text/tasks/__init__.py new file mode 100644 index 0000000000..d878278475 --- /dev/null +++ b/examples/speech_text_joint_to_text/tasks/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + task_name = file[: file.find(".py")] + importlib.import_module("examples.speech_text_joint_to_text.tasks." 
+ task_name) diff --git a/examples/speech_text_joint_to_text/tasks/speech_text_joint.py b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py new file mode 100644 index 0000000000..f2b3966d2d --- /dev/null +++ b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py @@ -0,0 +1,372 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import os +from argparse import Namespace +from pathlib import Path + +import torch +from fairseq.data import ( + encoders, + Dictionary, + ResamplingDataset, + TransformEosLangPairDataset, + ConcatDataset, +) +from fairseq.data.iterators import GroupedEpochBatchIterator +from fairseq.data.audio.multi_modality_dataset import ( + MultiModalityDataset, + LangPairMaskDataset, + ModalityDatasetItem, +) +from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset, SpeechToTextDatasetCreator +from fairseq.data.audio.speech_to_text_joint_dataset import ( + S2TJointDataConfig, + SpeechToTextJointDatasetCreator, +) +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.tasks.translation import load_langpair_dataset + +logger = logging.getLogger(__name__) +LANG_TAG_TEMPLATE = "<lang:{}>" + + +@register_task("speech_text_joint_to_text") +class SpeechTextJointToTextTask(SpeechToTextTask): + """ + Task for joint training speech and text to text. 
+ """ + + @classmethod + def add_args(cls, parser): + """Add task-specific arguments to the parser.""" + super(SpeechTextJointToTextTask, cls).add_args(parser) + ### + parser.add_argument( + "--parallel-text-data", + default="", + help="path to parallel text data directory", + ) + parser.add_argument( + "--max-tokens-text", + type=int, + metavar="N", + help="maximum tokens for encoder text input ", + ) + parser.add_argument( + "--max-positions-text", + type=int, + metavar="N", + default=400, + help="maximum tokens for per encoder text input ", + ) + parser.add_argument( + "--langpairs", + default=None, + metavar="S", + help='language pairs for text training, separated with ","', + ) + parser.add_argument( + "--speech-sample-ratio", + default=1, + type=float, + metavar="N", + help="Multiple Ratio for speech dataset with transcripts ", + ) + parser.add_argument( + "--text-sample-ratio", + default=1, + type=float, + metavar="N", + help="Multiple Ratio for text set ", + ) + parser.add_argument( + "--update-mix-data", + action="store_true", + help="use mixed data in one update when update-freq > 1", + ) + parser.add_argument( + "--load-speech-only", + action="store_true", + help="load speech data only", + ) + parser.add_argument( + "--mask-text-ratio", + type=float, + metavar="V", + default=0.0, + help="mask V source tokens for text only mode", + ) + parser.add_argument( + "--mask-text-type", + default="random", + choices=["random", "tail"], + help="mask text typed", + ) + parser.add_argument( + "--noise-token", + default="", + help="noise token for masking src text tokens if mask-text-ratio > 0", + ) + parser.add_argument( + "--infer-target-lang", + default="", + metavar="S", + help="target language for inference", + ) + + def __init__(self, args, src_dict, tgt_dict, infer_tgt_lang_id=None): + super().__init__(args, tgt_dict) + self.src_dict = src_dict + self.data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml) + assert self.tgt_dict.pad() == 
self.src_dict.pad() + assert self.tgt_dict.eos() == self.src_dict.eos() + self.speech_only = args.load_speech_only + self._infer_tgt_lang_id = infer_tgt_lang_id + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries).""" + data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml) + tgt_dict_path = Path(args.data) / data_cfg.vocab_filename + src_dict_path = Path(args.data) / data_cfg.src_vocab_filename + if (not os.path.isfile(src_dict_path)) or (not os.path.isfile(tgt_dict_path)): + raise FileNotFoundError("Dict not found: {}".format(args.data)) + src_dict = Dictionary.load(src_dict_path.as_posix()) + tgt_dict = Dictionary.load(tgt_dict_path.as_posix()) + + print("| src dictionary: {} types".format(len(src_dict))) + print("| tgt dictionary: {} types".format(len(tgt_dict))) + + if args.parallel_text_data != "": + if not os.path.isabs(args.parallel_text_data): + args.parallel_text_data = os.path.join( + args.data, args.parallel_text_data + ) + + if args.langpairs is None: + raise Exception( + "Could not infer language pair, please provide it explicitly" + ) + infer_tgt_lang_id = None + if args.infer_target_lang != "" and data_cfg.prepend_tgt_lang_tag_no_change: + tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format( + args.infer_target_lang + ) + infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag) + assert infer_tgt_lang_id != tgt_dict.unk() + return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id) + + def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0): + lang_pairs = [] + text_dataset = None + split = "train" + for lp in self.args.langpairs.split(","): + src, tgt = lp.split("-") + text_dataset = load_langpair_dataset( + self.args.parallel_text_data, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=True, + dataset_impl=None, + upsample_primary=1, + left_pad_source=False, + left_pad_target=False, + 
max_source_positions=self.args.max_positions_text, + max_target_positions=self.args.max_target_positions, + load_alignments=False, + truncate_source=False, + ) + if prepend_tgt_lang_tag: + # TODO + text_dataset = TransformEosLangPairDataset( + text_dataset, + src_eos=self.src_dict.eos(), + tgt_bos=self.tgt_dict.eos(), # 'prev_output_tokens' starts with eos + new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)), + ) + lang_pairs.append(text_dataset) + if len(lang_pairs) > 1: + if sampling_alpha != 1.0: + size_ratios = SpeechToTextDatasetCreator.get_size_ratios( + self.args.langpairs.split(","), + [len(s) for s in lang_pairs], + alpha=sampling_alpha, + ) + lang_pairs = [ + ResamplingDataset( + d, size_ratio=r, epoch=epoch, replace=(r >= 1.0) + ) + for d, r in zip(lang_pairs, size_ratios) + ] + return ConcatDataset(lang_pairs) + return text_dataset + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + bos_token=self._infer_tgt_lang_id, + ) + + def build_src_tokenizer(self, args): + logger.info(f"src-pre-tokenizer: {self.data_cfg.src_pre_tokenizer}") + return encoders.build_tokenizer(Namespace(**self.data_cfg.src_pre_tokenizer)) + + def build_src_bpe(self, args): + logger.info(f"tokenizer: {self.data_cfg.src_bpe_tokenizer}") + return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer)) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. 
+ + Args: + split (str): name of the split (e.g., train, valid, test) + """ + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + src_pre_tokenizer = self.build_src_tokenizer(self.args) + src_bpe_tokenizer = self.build_src_bpe(self.args) + ast_dataset = SpeechToTextJointDatasetCreator.from_tsv( + self.args.data, + self.data_cfg, + split, + self.tgt_dict, + src_dict=None if self.speech_only else self.src_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + src_pre_tokenizer=src_pre_tokenizer, + src_bpe_tokenizer=src_bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, + ) + noise_token_id = -1 + text_dataset = None + if self.args.parallel_text_data != "" and is_train_split: + text_dataset = self.load_langpair_dataset( + self.data_cfg.prepend_tgt_lang_tag_no_change, + 1.0, + epoch=epoch, + ) + if self.args.mask_text_ratio > 0: + # add mask + noise_token_id = ( + self.src_dict.unk() + if self.args.noise_token == "" + else self.src_dict.index(self.args.noise_token) + ) + text_dataset = LangPairMaskDataset( + text_dataset, + src_bos=self.src_dict.bos(), + src_eos=self.src_dict.eos(), + noise_id=noise_token_id, + mask_ratio=self.args.mask_text_ratio, + mask_type=self.args.mask_text_type, + ) + + if text_dataset is not None: + mdsets = [ + ModalityDatasetItem( + "sup_speech", + ast_dataset, + (self.args.max_source_positions, self.args.max_target_positions), + self.args.max_tokens, + self.args.batch_size, + ), + ModalityDatasetItem( + "text", + text_dataset, + (self.args.max_positions_text, self.args.max_target_positions), + self.args.max_tokens_text + if self.args.max_tokens_text is not None + else self.args.max_tokens, + self.args.batch_size, + ), + ] + ast_dataset = MultiModalityDataset(mdsets) + self.datasets[split] = ast_dataset + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the 
language + model.""" + return self.tgt_dict + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None if self.speech_only else self.src_dict + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=0, + data_buffer_size=0, + disable_iterator_cache=False, + ): + + if not isinstance(dataset, MultiModalityDataset): + return super(SpeechTextJointToTextTask, self).get_batch_iterator( + dataset, + max_tokens, + max_sentences, + max_positions, + ignore_invalid_inputs, + required_batch_size_multiple, + seed, + num_shards, + shard_id, + num_workers, + epoch, + data_buffer_size, + disable_iterator_cache, + ) + + mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio] + assert len(dataset.datasets) == 2 + + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + batch_samplers = dataset.get_batch_samplers( + mult_ratio, required_batch_size_multiple, seed + ) + + # return a reusable, sharded iterator + epoch_iter = GroupedEpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_samplers=batch_samplers, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + mult_rate=1 if self.args.update_mix_data else max(self.args.update_freq), + buffer_size=data_buffer_size, + ) + self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch + return epoch_iter diff --git a/fairseq/data/audio/multi_modality_dataset.py b/fairseq/data/audio/multi_modality_dataset.py new file mode 100644 index 0000000000..69d23d31c1 --- /dev/null +++ b/fairseq/data/audio/multi_modality_dataset.py @@ -0,0 +1,263 @@ +# Copyright (c) 2021-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import math +from typing import Any, List, Optional, NamedTuple + +import numpy as np +import torch +from fairseq.data import ( + ConcatDataset, + LanguagePairDataset, + FileAudioDataset, + data_utils, +) +from fairseq.data import FairseqDataset + +logger = logging.getLogger(__name__) + + +class ModalityDatasetItem(NamedTuple): + datasetname: str + dataset: Any + max_positions: List[int] + max_tokens: Optional[int] = None + max_sentences: Optional[int] = None + +# MultiModalityDataset: concatenates multiple datasets with different modalities. +# Compared with ConcatDataset, it 1) samples data given the ratios for the different datasets +# and 2) adds a mode field indicating which type of dataset each sample comes from. +# It is used together with GroupedEpochBatchIterator to generate mini-batches whose samples +# all come from the same type of dataset. +# If only one dataset is used, it behaves like the original dataset with the mode added. +class MultiModalityDataset(ConcatDataset): + def __init__(self, datasets: List[ModalityDatasetItem]): + id_to_mode = [] + dsets = [] + max_tokens = [] + max_sentences = [] + max_positions = [] + for dset in datasets: + id_to_mode.append(dset.datasetname) + dsets.append(dset.dataset) + max_tokens.append(dset.max_tokens) + max_positions.append(dset.max_positions) + max_sentences.append(dset.max_sentences) + weights = [1.0 for s in dsets] + super().__init__(dsets, weights) + self.max_tokens = max_tokens + self.max_positions = max_positions + self.max_sentences = max_sentences + self.id_to_mode = id_to_mode + self.raw_sub_batch_samplers = [] + self._cur_epoch = 0 + + def set_epoch(self, epoch): + super().set_epoch(epoch) + self._cur_epoch = epoch + + def __getitem__(self, idx): + dataset_idx, sample_idx = 
self._get_dataset_and_sample_index(idx) + sample = self.datasets[dataset_idx][sample_idx] + return (dataset_idx, sample) + + def collater(self, samples): + if len(samples) == 0: + return {} + dataset_idx = samples[0][0] + # make sure all samples come from the same dataset + assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0 + samples = self.datasets[dataset_idx].collater([x[1] for x in samples]) + # add mode + samples["net_input"]["mode"] = self.id_to_mode[dataset_idx] + + return samples + + def size(self, index: int): + if len(self.datasets) == 1: + return self.datasets[0].size(index) + return super().size(index) + + @property + def sizes(self): + if len(self.datasets) == 1: + return self.datasets[0].sizes + return super().sizes + + def ordered_indices(self): + """ + Returns indices sorted by length, so that less padding is needed. + """ + if len(self.datasets) == 1: + return self.datasets[0].ordered_indices() + indices_group = [] + for d_idx, ds in enumerate(self.datasets): + sample_num = self.cumulative_sizes[d_idx] + if d_idx > 0: + sample_num = sample_num - self.cumulative_sizes[d_idx - 1] + assert sample_num == len(ds) + indices_group.append(ds.ordered_indices()) + return indices_group + + def get_raw_batch_samplers(self, required_batch_size_multiple, seed): + if len(self.raw_sub_batch_samplers) > 0: + logger.info(" raw_sub_batch_samplers exists. 
No action is taken") + return + with data_utils.numpy_seed(seed): + indices = self.ordered_indices() + for i, ds in enumerate(self.datasets): + indices[i] = ds.filter_indices_by_size( + indices[i], + self.max_positions[i], + )[0] + sub_batch_sampler = ds.batch_by_size( + indices[i], + max_tokens=self.max_tokens[i], + max_sentences=self.max_sentences[i], + required_batch_size_multiple=required_batch_size_multiple, + ) + self.raw_sub_batch_samplers.append(sub_batch_sampler) + + def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed): + self.get_raw_batch_samplers(required_batch_size_multiple, seed) + batch_samplers = [] + for i, _ in enumerate(self.datasets): + if i > 0: + sub_batch_sampler = [ + [y + self.cumulative_sizes[i - 1] for y in x] + for x in self.raw_sub_batch_samplers[i] + ] + else: + sub_batch_sampler = list(self.raw_sub_batch_samplers[i]) + smp_r = mult_ratios[i] + if smp_r != 1: + is_increase = "increased" if smp_r > 1 else "decreased" + logger.info( + "number of batch for the dataset {} is {} from {} to {}".format( + self.id_to_mode[i], + is_increase, + len(sub_batch_sampler), + int(len(sub_batch_sampler) * smp_r), + ) + ) + mul_samplers = [] + for _ in range(math.floor(smp_r)): + mul_samplers = mul_samplers + sub_batch_sampler + if math.floor(smp_r) != smp_r: + with data_utils.numpy_seed(seed + self._cur_epoch): + np.random.shuffle(sub_batch_sampler) + smp_num = int( + (smp_r - math.floor(smp_r)) * len(sub_batch_sampler) + ) + mul_samplers = mul_samplers + sub_batch_sampler[:smp_num] + sub_batch_sampler = mul_samplers + else: + logger.info( + "dataset {} batch number is {} ".format( + self.id_to_mode[i], len(sub_batch_sampler) + ) + ) + batch_samplers.append(sub_batch_sampler) + + return batch_samplers + + +class LangPairMaskDataset(FairseqDataset): + def __init__( + self, + dataset: LanguagePairDataset, + src_eos: int, + src_bos: Optional[int] = None, + noise_id: Optional[int] = -1, + mask_ratio: Optional[float] = 0, + 
mask_type: Optional[str] = "random", + ): + self.dataset = dataset + self.src_eos = src_eos + self.src_bos = src_bos + self.noise_id = noise_id + self.mask_ratio = mask_ratio + self.mask_type = mask_type + assert mask_type in ("random", "tail") + + @property + def src_sizes(self): + return self.dataset.src_sizes + + @property + def tgt_sizes(self): + return self.dataset.tgt_sizes + + @property + def sizes(self): + # dataset.sizes can be a dynamically computed sizes: + return self.dataset.sizes + + def get_batch_shapes(self): + return self.dataset.buckets + + def num_tokens_vec(self, indices): + return self.dataset.num_tokens_vec(indices) + + def __len__(self): + return len(self.dataset) + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) + + def mask_src_tokens(self, sample): + src_item = sample["source"] + mask = None + if self.mask_type == "random": + mask = torch.rand(len(src_item)).le(self.mask_ratio) + else: + mask = torch.ones(len(src_item)) + mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0 + mask = mask.eq(1) + if src_item[0] == self.src_bos: + mask[0] = False + if src_item[-1] == self.src_eos: + mask[-1] = False + mask_src_item = src_item.masked_fill(mask, self.noise_id) + smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]} + return smp + + def __getitem__(self, index): + sample = self.dataset[index] + if self.mask_ratio > 0: + sample = self.mask_src_tokens(sample) + return sample + + def collater(self, samples, pad_to_length=None): + return self.dataset.collater(samples, pad_to_length) + + +class FileAudioDatasetWrapper(FileAudioDataset): + def collater(self, samples): + samples = 
super().collater(samples) + if len(samples) == 0: + return {} + samples["net_input"]["src_tokens"] = samples["net_input"]["source"] + samples["net_input"]["prev_output_tokens"] = None + del samples["net_input"]["source"] + samples["net_input"]["src_lengths"] = None + samples["net_input"]["alignment"] = None + return samples diff --git a/fairseq/data/audio/speech_to_text_joint_dataset.py b/fairseq/data/audio/speech_to_text_joint_dataset.py new file mode 100644 index 0000000000..885ee7e0a3 --- /dev/null +++ b/fairseq/data/audio/speech_to_text_joint_dataset.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path +from typing import Dict, List, Optional, NamedTuple + +import torch +from fairseq.data import ( + ConcatDataset, + Dictionary, + ResamplingDataset, + data_utils as fairseq_data_utils, +) +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, + S2TDataConfig, + SpeechToTextDatasetCreator, +) + + +logger = logging.getLogger(__name__) + + +class S2TJointDataConfig(S2TDataConfig): + """Wrapper class for data config YAML""" + + @property + def src_vocab_filename(self): + """fairseq vocabulary file under data root""" + return self.config.get("src_vocab_filename", "src_dict.txt") + + @property + def src_pre_tokenizer(self) -> Dict: + """Pre-tokenizer to apply before subword tokenization. Returning + a dictionary with `tokenizer` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get("src_pre_tokenizer", {"tokenizer": None}) + + @property + def src_bpe_tokenizer(self) -> Dict: + """Subword tokenizer to apply on source text after pre-tokenization. 
+ Returning a dictionary with `bpe` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get("src_bpe_tokenizer", {"bpe": None}) + + @property + def prepend_tgt_lang_tag_no_change(self) -> bool: + """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for + to-many multilingual setting). No change needed during inference. + """ + return self.config.get("prepend_tgt_lang_tag_no_change", False) + + +class SpeechToTextJointDatasetItem(NamedTuple): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + src_txt_tokens: Optional[torch.Tensor] = None + tgt_lang_tag: Optional[int] = None + + +class SpeechToTextJointDataset(SpeechToTextDataset): + def __init__( + self, + split: str, + is_train_split: bool, + cfg: S2TJointDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + src_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + src_pre_tokenizer=None, + src_bpe_tokenizer=None, + ): + super().__init__( + split, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + ) + + self.src_dict = src_dict + self.src_pre_tokenizer = src_pre_tokenizer + self.src_bpe_tokenizer = src_bpe_tokenizer + + def get_tokenized_src_text(self, index: int): + text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index]) + text = self.tokenize(self.src_bpe_tokenizer, text) + return text + + def __getitem__(self, index: int) -> 
SpeechToTextJointDatasetItem: + s2t_dataset_item = super().__getitem__(index) + src_tokens = None + if self.src_texts is not None and self.src_dict is not None: + src_tokens = self.get_tokenized_src_text(index) + src_tokens = self.src_dict.encode_line( + src_tokens, add_if_not_exist=False, append_eos=True + ).long() + tgt_lang_tag = None + if self.cfg.prepend_tgt_lang_tag_no_change: + # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead + tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict) + + return SpeechToTextJointDatasetItem( + index=index, + source=s2t_dataset_item.source, + target=s2t_dataset_item.target, + src_txt_tokens=src_tokens, + tgt_lang_tag=tgt_lang_tag, + ) + + def __len__(self): + return self.n_samples + + def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict: + s2t_out = super().collater(samples, return_order=True) + if s2t_out == {}: + return s2t_out + net_input, order = s2t_out["net_input"], s2t_out["order"] + + if self.src_texts is not None and self.src_dict is not None: + src_txt_tokens = fairseq_data_utils.collate_tokens( + [x.src_txt_tokens for x in samples], + self.src_dict.pad(), + self.src_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ) + src_txt_tokens = src_txt_tokens.index_select(0, order) + src_txt_lengths = torch.tensor( + [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long + ).index_select(0, order) + net_input["src_txt_tokens"] = src_txt_tokens + net_input["src_txt_lengths"] = src_txt_lengths + + if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None: + for i in range(len(samples)): + net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag + + out = { + "id": s2t_out["id"], + "net_input": net_input, + "target": s2t_out["target"], + "target_lengths": s2t_out["target_lengths"], + "ntokens": s2t_out["ntokens"], + "nsentences": len(samples), + } + return out + + +class 
SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator): + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + cfg: S2TJointDataConfig, + tgt_dict, + src_dict, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + ) -> SpeechToTextJointDataset: + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + return SpeechToTextJointDataset( + split_name, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + src_dict=src_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + src_pre_tokenizer=src_pre_tokenizer, + src_bpe_tokenizer=src_bpe_tokenizer, + ) + + @classmethod + def _from_tsv( + cls, + root: str, + cfg: S2TJointDataConfig, + split: str, + tgt_dict, + src_dict, + is_train_split: bool, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + ) -> SpeechToTextJointDataset: + samples = cls._load_samples_from_tsv(root, split) + return cls._from_list( + split, + is_train_split, + samples, + cfg, + tgt_dict, + src_dict, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + ) + + @classmethod + def from_tsv( + cls, + root: str, + cfg: S2TJointDataConfig, + splits: str, + tgt_dict, + src_dict, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + is_train_split: bool, 
+ epoch: int, + seed: int, + ) -> SpeechToTextJointDataset: + datasets = [ + cls._from_tsv( + root, + cfg, + split, + tgt_dict, + src_dict, + is_train_split, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + ) + for split in splits.split(",") + ] + + if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: + # temperature-based sampling + size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) + datasets = [ + ResamplingDataset( + d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) + ) + for r, d in zip(size_ratios, datasets) + ] + + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 86f6d05533..1ce26e57e5 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -638,3 +638,128 @@ def __next__(self): if item is _sentinel: raise StopIteration() return item + +class GroupedEpochBatchIterator(EpochBatchIterator): + """Grouped version of EpochBatchIterator + It takes several samplers from different datasets. + Each epoch shuffle the dataset wise sampler individually with different + random seed. The those sub samplers are combined with into + one big samplers with deterministic permutation to mix batches from + different datasets. 
It will act like EpochBatchIterator but make sure + 1) data from one data set each time + 2) for different workers, they use the same order to fetch the data + so they will use data from the same dataset everytime + mult_rate is used for update_freq > 1 case where we want to make sure update_freq + mini-batches come from same source + """ + + def __init__( + self, + dataset, + collate_fn, + batch_samplers, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=0, + mult_rate=1, + buffer_size=0, + ): + super().__init__( + dataset, + collate_fn, + batch_samplers, + seed, + num_shards, + shard_id, + num_workers, + epoch, + buffer_size, + ) + # level 0: sub-samplers 1: batch_idx 2: batches + self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers]) + self.step_size = mult_rate * num_shards + + self.lengths = [ + (len(x) // self.step_size) * self.step_size for x in self.frozen_batches + ] + + def __len__(self): + return sum(self.lengths) + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." 
+ ) + + if self.dataset.supports_fetch_outside_dataloader: + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0][0]]) + else: + return "DUMMY" + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + np.random.shuffle(batches) + return batches + + def return_full_batches(batch_sets, seed, shuffle): + if shuffle: + batch_sets = [shuffle_batches(list(x), seed) for x in batch_sets] + + batch_sets = [ + batch_sets[i][: self.lengths[i]] for i in range(len(batch_sets)) + ] + batches = list(itertools.chain.from_iterable(batch_sets)) + + if shuffle: + with data_utils.numpy_seed(seed): + idx = np.random.permutation(len(batches) // self.step_size) + if len(idx) * self.step_size != len(batches): + raise ValueError( + "ERROR: %d %d %d %d" + % (len(idx), self.step_size, len(batches), self.shard_id), + ":".join(["%d" % x for x in self.lengths]), + ) + mini_shards = [ + batches[i * self.step_size : (i + 1) * self.step_size] + for i in idx + ] + batches = list(itertools.chain.from_iterable(mini_shards)) + + return batches + + if self._supports_prefetch: + raise NotImplementedError("To be implemented") + else: + batches = return_full_batches( + self.frozen_batches, self.seed + epoch, shuffle + ) + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + + if offset > 0 and offset >= len(batches): + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=batches[offset:], + num_workers=self.num_workers, + ) + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + return CountingIterator(itr, start=offset) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index c6ae9b17ba..cac365cbb8 
100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -5,4 +5,5 @@ from .berard import * # noqa from .convtransformer import * # noqa -from .s2t_transformer import * # noqa +from .s2t_transformer import * # noqa +from .xm_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/xm_transformer.py b/fairseq/models/speech_to_text/xm_transformer.py new file mode 100644 index 0000000000..03c434b5ea --- /dev/null +++ b/fairseq/models/speech_to_text/xm_transformer.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import copy +from typing import Dict, List, Optional, Tuple + +from fairseq import utils, checkpoint_utils +from fairseq.models import (FairseqEncoderDecoderModel, FairseqEncoder, + register_model, register_model_architecture) +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.models.wav2vec import Wav2VecEncoder +from fairseq.modules.layer_norm import LayerNorm +from fairseq.data.data_utils import lengths_to_padding_mask +from torch import Tensor +import torch.nn as nn + + +logger = logging.getLogger(__name__) + + +class Conv1dAdaptor(nn.Module): + def __init__(self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2, + add_layernorm=False): + super().__init__() + self.layers = nn.ModuleList( + nn.Conv1d(in_dim if i == 0 else out_dim, out_dim * 2, kernel_size, + stride=stride, padding=kernel_size // 2) + for i in range(n_layers) + ) + self.layernorms = None + if add_layernorm: + self.layernorms = nn.ModuleList(LayerNorm(out_dim) + for _ in range(n_layers)) + self.stride = stride + + @classmethod + def add_args(cls, parser): + parser.add_argument("--adaptor-n-layers", type=int) + parser.add_argument("--adaptor-kernel-size", type=int) + 
parser.add_argument("--adaptor-stride", type=int) + parser.add_argument("--adaptor-layernorm", action='store_true') + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for _ in self.layers: + out = ((out.float() - 1) / self.stride + 1).floor().long() + return out + + def forward(self, x, padding_mask): + # T x B x C -> B x C x T + x = x.transpose(0, 1).transpose(1, 2) + for i, layer in enumerate(self.layers): + x = nn.functional.glu(layer(x), dim=1) + if self.layernorms is not None: + x = self.layernorms[i](x.transpose(1, 2)).transpose(1, 2) + # B x C x T -> T x B x C + x = x.transpose(1, 2).transpose(0, 1) + + if padding_mask is None: + out_padding_mask = None + else: + out_lengths = self.get_out_seq_lens_tensor((~padding_mask).sum(1)) + out_padding_mask = lengths_to_padding_mask(out_lengths) + return x, out_padding_mask + + +def add_wav2vec_asr_args(parser): + parser.add_argument("--w2v-path", help="path to wav2vec 2.0 model") + parser.add_argument( + "--no-pretrained-weights", + action="store_true", + help="if true, does not load pretrained weights", + ) + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + parser.add_argument( + "--final-dropout", + type=float, + metavar="D", + help="dropout after transformer and before final projection", + ) + parser.add_argument( + "--apply-mask", action="store_true", help="apply masking during fine-tuning" + ) + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout probability inside wav2vec 2.0 model", + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights inside wav2vec 2.0 model", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN inside wav2vec 2.0 model", + ) + + parser.add_argument( + 
"--mask-length", type=int, help="repeat the mask indices multiple times" + ) + + parser.add_argument( + "--mask-prob", type=float, help="probability of replacing a token with mask" + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-channel-length", type=int, help="repeat the mask indices multiple times" + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--freeze-finetune-updates", + default=0, + type=int, + help="dont finetune wav2vec for this many updates", + ) + + parser.add_argument( + "--feature-grad-mult", + default=None, + type=float, + help="reset feature grad mult in wav2vec 2.0 to this", + ) + + parser.add_argument( + "--layerdrop", + default=0.0, + type=float, + help="probability of dropping a layer in wav2vec 2.0", + ) + parser.add_argument("--w2v-args", default=None) + + +class Wav2VecEncoderWithAdaptor(FairseqEncoder): + def __init__(self, args): + super().__init__(None) + self.w2v_encoder = Wav2VecEncoder(args) + encoder_out_dim = self.w2v_encoder.w2v_model.encoder.embedding_dim + # Projection + 8x shrinking + self.adaptor = Conv1dAdaptor( + 
encoder_out_dim, args.decoder_embed_dim, + n_layers=args.adaptor_n_layers, + kernel_size=args.adaptor_kernel_size, stride=args.adaptor_stride, + add_layernorm=args.adaptor_layernorm + ) + for k, p in self.w2v_encoder.w2v_model.named_parameters(): + # Freeze pretrained models by default + if hasattr(args, 'finetune_w2v_params') and XMTransformerModel.finetune_params( + args.finetune_w2v_params, k): + p.requires_grad = True + else: + p.requires_grad = False + + @classmethod + def add_args(cls, parser): + add_wav2vec_asr_args(parser) + parser.add_argument( + "--normalize", action="store_true", + help="if set, normalizes input to have 0 mean and unit variance", + ) + parser.add_argument("--finetune-w2v-params", type=str, metavar="STR", + help="comma-separated param strings to finetune.") + Conv1dAdaptor.add_args(parser) + + def forward(self, src_tokens, src_lengths=None, **kwargs): + padding_mask = lengths_to_padding_mask(src_lengths) + out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True) + x = out["encoder_out"] + enc_padding_mask = None + if out["encoder_padding_mask"] is not None: + enc_padding_mask = out["encoder_padding_mask"].transpose(0, 1) # T X B --> B X T + + x, enc_padding_mask = self.adaptor(x, enc_padding_mask) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [enc_padding_mask] if enc_padding_mask.any() else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": [], # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] if len(encoder_out["encoder_padding_mask"]) == 0 + else [x.index_select(0, new_order) for x in + encoder_out["encoder_padding_mask"]] + ) + + new_encoder_embedding = ( + [] if len(encoder_out["encoder_embedding"]) == 0 + else 
[x.index_select(0, new_order) for x in + encoder_out["encoder_embedding"]] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } + + +def add_decoder_args(parser): + parser.add_argument("--activation-fn", type=str, default='relu', + choices=utils.get_available_activation_fns(), + help="activation function to use") + parser.add_argument("--decoder-dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--decoder-attention-dropout", type=float, + metavar="D", + help="dropout probability for attention weights") + parser.add_argument("--decoder-activation-dropout", type=float, + metavar="D", + help="dropout probability after activation in FFN.") + parser.add_argument("--decoder-embed-dim", type=int, metavar="N", + help="decoder embedding dimension") + parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N", + help="decoder embedding dimension for FFN") + parser.add_argument("--decoder-layers", type=int, metavar="N", + help="num decoder layers") + parser.add_argument("--decoder-attention-heads", type=int, metavar="N", + help="num decoder attention heads") + parser.add_argument("--decoder-normalize-before", action="store_true", + help="apply layernorm before each decoder block") + parser.add_argument("--layernorm-embedding", action="store_true", + help="add layernorm to embedding") + parser.add_argument("--no-scale-embedding", action="store_true", + help="if True, dont scale embeddings") + parser.add_argument( + "--load-pretrained-decoder-from", type=str, metavar="STR", + help="model to take decoder weights 
from (for initialization)" + ) + parser.add_argument("--finetune-decoder-params", type=str, + metavar="STR", + help="comma-separated param strings to finetune.") + parser.add_argument("--checkpoint-activations", action="store_true") + + +@register_model("xm_transformer") +class XMTransformerModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + Wav2VecEncoderWithAdaptor.add_args(parser) + add_decoder_args(parser) + + @classmethod + def build_encoder(cls, args): + _args = copy.deepcopy(args) + state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path) + if state.get("cfg") is not None: + encoder_embed_dim = state["cfg"]._content["model"]["encoder_embed_dim"] + elif state.get("args") is not None: + encoder_embed_dim = state["args"].encoder_embed_dim + else: + raise ValueError(f"Invalid config in {args.w2v_path}") + _args.decoder_embed_dim = encoder_embed_dim + encoder = Wav2VecEncoderWithAdaptor(_args) + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + _args = copy.deepcopy(args) + _args.dropout = args.decoder_dropout + _args.attention_dropout = args.decoder_attention_dropout + _args.activation_dropout = args.decoder_activation_dropout + _args.max_target_positions = 1024 + + decoder = TransformerDecoder(_args, task.target_dictionary, + embed_tokens) + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + for k, p in decoder.named_parameters(): + # Freeze pretrained models by default + if hasattr(args, 'finetune_decoder_params') and XMTransformerModel.finetune_params( + args.finetune_decoder_params, k): + p.requires_grad = True + else: + p.requires_grad = False + return decoder + + @classmethod + def build_model(cls, args, task): + 
"""Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + decoder_embed_tokens = build_embedding(task.target_dictionary, + args.decoder_embed_dim) + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + return cls(encoder, decoder) + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + lprobs.batch_first = True + return lprobs + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overrites the forward method definition without **kwargs. 
+ """ + encoder_out = self.encoder(src_tokens=src_tokens, + src_lengths=src_lengths, **kwargs) + decoder_out = self.decoder(prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out) + return decoder_out + + def upgrade_state_dict(self, state_dict): + for k, _ in state_dict.items(): + if 'adaptor.layers' in state_dict: + print(k) + new = k.replace('adaptor.layers', 'adaptor_layers') + state_dict[new] = state_dict[k] + del state_dict[k] + + @staticmethod + def finetune_params(finetune_params, param_name): + if finetune_params == "all": + return True + finetune_params_list = finetune_params.split(",") + for finetune_param in finetune_params_list: + if finetune_param in param_name: + return True + return False + + +def set_default_w2v_encoder_args(args): + args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False) + args.dropout_input = getattr(args, "dropout_input", 0) + args.final_dropout = getattr(args, "final_dropout", 0) + args.apply_mask = getattr(args, "apply_mask", False) + args.dropout = getattr(args, "dropout", 0) + args.attention_dropout = getattr(args, "attention_dropout", 0) + args.activation_dropout = getattr(args, "activation_dropout", 0) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.5) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) + args.mask_channel_before = getattr(args, "mask_channel_before", False) + args.mask_channel_selection = getattr(args, "mask_channel_selection", + "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", + False) + + args.freeze_finetune_updates = getattr(args, 
"freeze_finetune_updates", 0) + args.feature_grad_mult = 0.1 + args.layerdrop = getattr(args, "layerdrop", 0.0) + + args.normalize = getattr(args, "normalize", False) + + +def set_default_adaptor_args(args): + args.adaptor_n_layers = getattr(args, "adaptor_n_layers", 3) + args.adaptor_kernel_size = getattr(args, "adaptor_kernel_size", 3) + args.adaptor_stride = getattr(args, "adaptor_stride", 2) + args.adaptor_layernorm = getattr(args, "adaptor_layernorm", False) + + +def set_default_mbart_decoder_args(args): + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', + 4 * 1024) + args.decoder_layers = getattr(args, 'decoder_layers', 12) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) + args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', + True) + args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_attention_dropout = getattr(args, 'decoder_attention_dropout', + 0.) + args.decoder_activation_dropout = getattr(args, + 'decoder_activation_dropout', 0.) 
+ args.decoder_dropout = getattr(args, 'decoder_dropout', 0.1) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', + None) + args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) + args.share_decoder_input_output_embed = getattr( + args, 'share_decoder_input_output_embed', True + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + args.decoder_output_dim = getattr(args, 'decoder_output_dim', + args.decoder_embed_dim) + args.decoder_input_dim = getattr(args, 'decoder_input_dim', + args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.layernorm_embedding = getattr(args, 'layernorm_embedding', True) + + args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') + args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + + +@register_model_architecture(model_name="xm_transformer", + arch_name="xm_transformer") +def base_architecture(args): + set_default_w2v_encoder_args(args) + set_default_adaptor_args(args) + set_default_mbart_decoder_args(args) From 20fbc348215e558d23da9461a5daaec85b97d114 Mon Sep 17 00:00:00 2001 From: Wei-Ning Hsu <31931787+wnhsu@users.noreply.github.com> Date: Fri, 30 Jul 2021 10:15:09 -0700 Subject: [PATCH 674/707] update hubert decode config (#2106) Summary: Update HuBERT decode config yaml to make compatible with the new decoder config Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2106 Reviewed By: alexeib Differential Revision: D29967631 Pulled By: wnhsu fbshipit-source-id: fe39c5126f50c3024022f8333e2f3aa97065cbfc --- examples/hubert/config/decode/infer_fsqlm.yaml | 8 ++++---- examples/hubert/config/decode/infer_kenlm.yaml | 6 +++--- 
examples/hubert/config/decode/infer_viterbi.yaml | 8 +++----- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/hubert/config/decode/infer_fsqlm.yaml b/examples/hubert/config/decode/infer_fsqlm.yaml index bc77cab32e..026ad8db89 100644 --- a/examples/hubert/config/decode/infer_fsqlm.yaml +++ b/examples/hubert/config/decode/infer_fsqlm.yaml @@ -5,14 +5,15 @@ defaults: hydra: run: - dir: ${common_eval.results_path}/beam${decoding.decoder.beam}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} sweep: dir: ${common_eval.results_path} - subdir: beam${decoding.decoder.beam}_th${decoding.decoder.beamthreshold}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} task: _name: hubert_pretraining single_target: true + fine_tuning: true data: ??? normalize: ??? @@ -20,13 +21,12 @@ decoding: type: fairseqlm lexicon: ??? lmpath: ??? - beamthreshold: 25 # 100 + beamthreshold: 25 beam: 500 lmweight: 2 wordscore: -1 silweight: 0 unique_wer_file: true - beam: 500 common_eval: results_path: ??? path: ??? 
diff --git a/examples/hubert/config/decode/infer_kenlm.yaml b/examples/hubert/config/decode/infer_kenlm.yaml index 26f5c48928..04642aeb65 100644 --- a/examples/hubert/config/decode/infer_kenlm.yaml +++ b/examples/hubert/config/decode/infer_kenlm.yaml @@ -5,14 +5,15 @@ defaults: hydra: run: - dir: ${common_eval.results_path}/beam${decoding.decoder.beam}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} sweep: dir: ${common_eval.results_path} - subdir: beam${decoding.decoder.beam}_th${decoding.decoder.beamthreshold}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} task: _name: hubert_pretraining single_target: true + fine_tuning: true data: ??? normalize: ??? @@ -26,7 +27,6 @@ decoding: wordscore: -1 silweight: 0 unique_wer_file: true - beam: 500 common_eval: results_path: ??? path: ??? 
diff --git a/examples/hubert/config/decode/infer_viterbi.yaml b/examples/hubert/config/decode/infer_viterbi.yaml index 935d7d1d01..4afc74c18c 100644 --- a/examples/hubert/config/decode/infer_viterbi.yaml +++ b/examples/hubert/config/decode/infer_viterbi.yaml @@ -5,14 +5,15 @@ defaults: hydra: run: - dir: ${common_eval.results_path}/beam${decoding.decoder.beam}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + dir: ${common_eval.results_path}/viterbi sweep: dir: ${common_eval.results_path} - subdir: beam${decoding.decoder.beam}_th${decoding.decoder.beamthreshold}_lmw${decoding.decoder.lmweight}_wrd${decoding.decoder.wordscore}_sil${decoding.decoder.silweight} + subdir: viterbi task: _name: hubert_pretraining single_target: true + fine_tuning: true data: ??? normalize: ??? @@ -23,9 +24,6 @@ common_eval: results_path: ??? path: ??? post_process: letter -generation: - nbest: 1 - beam: 500 dataset: max_tokens: 1100000 gen_subset: ??? From 972401937b9aa44e45ee3380fa497c8eb30005c4 Mon Sep 17 00:00:00 2001 From: Ann Lee <annl@fb.com> Date: Fri, 30 Jul 2021 14:33:53 -0700 Subject: [PATCH 675/707] add paper link (#2116) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2116 Reviewed By: michaelauli Differential Revision: D30019908 Pulled By: an918tw fbshipit-source-id: ca8d7a6e97ed81e7df9a15e778c68fad8fb0a308 --- examples/discriminative_reranking_nmt/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/discriminative_reranking_nmt/README.md b/examples/discriminative_reranking_nmt/README.md index aba0090370..e6f42b1278 100644 --- a/examples/discriminative_reranking_nmt/README.md +++ b/examples/discriminative_reranking_nmt/README.md @@ -1,4 +1,6 @@ # Discriminative Reranking for Neural Machine Translation +https://aclanthology.org/2021.acl-long.563/ + This folder contains source code for training DrNMT, a discriminatively trained reranker for neural machine translation. ## Data preparation From 9d70f9ca6eb0dc49065a2691f302df2e68c1cac4 Mon Sep 17 00:00:00 2001 From: Ishani Karmarkar <ikarmarkar@fb.com> Date: Fri, 30 Jul 2021 17:41:47 -0700 Subject: [PATCH 676/707] iPQ Summary: Implemented iterative product quantization (iPQ trainer) and unit tests Reviewed By: AkshatSh, AdithyaSagar007 Differential Revision: D29662949 fbshipit-source-id: fdc1f124decc722b54225a7fe0031695823e1c69 --- fairseq/modules/quantization/pq/__init__.py | 2 +- fairseq/modules/quantization/pq/utils.py | 37 +++++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/fairseq/modules/quantization/pq/__init__.py b/fairseq/modules/quantization/pq/__init__.py index 5b10b51b1b..c142a802e0 100644 --- a/fairseq/modules/quantization/pq/__init__.py +++ b/fairseq/modules/quantization/pq/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .utils import SizeTracker, quantize_model_ # NOQA +from .utils import SizeTracker, get_param, attrsetter, quantize_model_ # NOQA diff --git a/fairseq/modules/quantization/pq/utils.py b/fairseq/modules/quantization/pq/utils.py index 3c5ea4155d..14c015b7c1 100644 --- a/fairseq/modules/quantization/pq/utils.py +++ b/fairseq/modules/quantization/pq/utils.py @@ -6,7 +6,7 @@ import logging import re from operator import attrgetter, itemgetter - +import torch import numpy as np import torch.distributed as dist import torch.nn as nn @@ -25,7 +25,9 @@ def quantize_model_( n_iter=15, eps=1e-6, max_tentatives=100, + remove_weights=False, verbose=True, + state_dict=None, ): """ Quantize a model in-place by stages. All the targeted @@ -58,7 +60,7 @@ def quantize_model_( to layers_to_quantize[step] """ - quantized_layers = get_layers(model, layers_to_quantize[step]) + quantized_layers = get_layers(model, layers_to_quantize[step], remove_weights=remove_weights) for layer in quantized_layers: @@ -96,6 +98,37 @@ def quantize_model_( centroids = quantizer.centroids.contiguous() assignments = quantizer.assignments.contiguous() + # If n_iter = 0 and state_dict is provided, then + # we initialize random assignments and centroids to + # random values of the appropriate dimensions + # because the quantized model parameters will + # overwritten by the state_dict later on. + if n_iter == 0 and state_dict: + # Initialize random centroids of the correct size + centroids = torch.rand(centroids.size()) + centroids.cuda() + # Get counts and assignment keys from layer in loaded checkpoint. + counts_key = layer+"."+"counts" + assignment_key = layer+"."+"assignments" + # Get number of different bins to include. + counts = list(state_dict[counts_key].shape)[0] + print(layer) + print(state_dict[counts_key]) + print(counts) + # Initialize random assignments of the correct size + # with an appropriate number of bins. 
+ num_assignments = list(state_dict[assignment_key].shape)[0] + num_extra = num_assignments - counts + print(num_assignments) + print(num_extra) + assignments_bins = torch.arange(counts) + assignments_rand = torch.randint(0, counts-1, (num_extra, )) + assignments = torch.cat((assignments_bins, assignments_rand), 0) + # assignments = assignments.type(torch.IntTensor) + assignments.cuda() + print("assignments") + print(assignments) + # broadcast results to make sure weights are up-to-date if dist.is_initialized(): dist.broadcast(centroids, 0) From 82e8a24bddc1cb5aecb267bb529f76f5b3901432 Mon Sep 17 00:00:00 2001 From: Wei-Ning Hsu <wnhsu@csail.mit.edu> Date: Mon, 2 Aug 2021 12:39:30 -0700 Subject: [PATCH 677/707] add max_keep_size (#2124) Summary: Set max_keep_size to filter long utterances. Needed when trained on labeled datasets with long utterances. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2124 Reviewed By: Abdel-rahmanMohamed Differential Revision: D30046509 Pulled By: wnhsu fbshipit-source-id: ec52ae0997284a05295dff35626927a71c78cf52 --- fairseq/tasks/hubert_pretraining.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py index ee3fedce3f..f756080dd1 100644 --- a/fairseq/tasks/hubert_pretraining.py +++ b/fairseq/tasks/hubert_pretraining.py @@ -76,6 +76,10 @@ class HubertPretrainingConfig(FairseqDataclass): default=False, metadata={"help": "pad shorter samples instead of cropping"}, ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) max_sample_size: Optional[int] = field( default=None, metadata={"help": "max sample size to crop to for batching"}, @@ -123,13 +127,11 @@ def __init__( else: self.state.add_factory("dictionaries", self.load_dictionaries) - self._source_dictionary = None - self.blank_symbol = "<s>" @property def source_dictionary(self) -> 
Optional[Dictionary]: - return self._source_dictionary + return None @property def target_dictionary(self) -> Optional[Dictionary]: @@ -174,7 +176,7 @@ def load_dataset(self, split: str, **kwargs) -> None: pad_list=pad_list, eos_list=eos_list, label_processors=procs, - max_keep_sample_size=None, + max_keep_sample_size=self.cfg.max_keep_size, min_keep_sample_size=self.cfg.min_sample_size, max_sample_size=self.cfg.max_sample_size, pad_audio=self.cfg.pad_audio, From db4f96b09295ffae2e534adbc6eefcd9f2d2089f Mon Sep 17 00:00:00 2001 From: Jingfei Du <jingfeidu@fb.com> Date: Mon, 2 Aug 2021 14:35:22 -0700 Subject: [PATCH 678/707] fixing checkpoint config upgrade for generation print_alignment (#2125) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes config upgrade conditions for upgrading generation. print_alignment ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2125 Reviewed By: myleott Differential Revision: D30049140 Pulled By: jingfeidu fbshipit-source-id: e613821e94d0cdb876c35bc6e3fede7affbf4628 --- fairseq/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index daabba4574..b8c46f8253 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -653,7 +653,7 @@ def _upgrade_state_dict(state): ): cfg.task.eval_wer_config.print_alignment = "hard" if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): - cfg.generation.print_alignment = "hard" + cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None if ( "model" in cfg and "w2v_args" in cfg.model From fe15926d48167c5028d2d7e2ed7d0a66642d0700 Mon Sep 17 00:00:00 2001 From: Edan Tessel Sneh <edan@fb.com> Date: Mon, 2 Aug 2021 18:39:08 -0700 Subject: [PATCH 679/707] Adding Hydra based trainer target to fairseq in fbcode Summary: Adding fairseq entrypoint section of e2e pipeline so FairseqConfig to hydra_main, runs smoothly Reviewed By: jieru-hu Differential Revision: D29714729 fbshipit-source-id: e3694e0037bb4c4f69208c1d6ec7df91d42fb588 --- fairseq/dataclass/initialize.py | 2 +- fairseq_cli/hydra_train.py | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index e43b31790e..8f6cbafb80 100644 --- a/fairseq/dataclass/initialize.py +++ b/fairseq/dataclass/initialize.py @@ -47,7 +47,7 @@ def add_defaults(cfg: DictConfig) -> None: field_cfg = DictConfig({"_name": field_cfg}) field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"] - name = field_cfg.get("_name") + name = getattr(field_cfg, "_name", None) if k == "task": dc = TASK_DATACLASS_REGISTRY.get(name) diff --git a/fairseq_cli/hydra_train.py 
b/fairseq_cli/hydra_train.py index 9de01084ba..6555ab415e 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -11,6 +11,7 @@ from fairseq_cli.train import main as pre_main from fairseq import distributed_utils, metrics from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.utils import omegaconf_no_object_check from fairseq.utils import reset_logging import hydra @@ -24,25 +25,32 @@ @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") def hydra_main(cfg: FairseqConfig) -> float: + _hydra_main(cfg) + + +def _hydra_main(cfg: FairseqConfig, **kwargs) -> float: add_defaults(cfg) if cfg.common.reset_logging: reset_logging() # Hydra hijacks logging, fix that else: - with open_dict(cfg): - # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) - cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True) - - cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) + # check if directly called or called through hydra_main + if HydraConfig.initialized(): + with open_dict(cfg): + # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) + cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True) + + with omegaconf_no_object_check(): + cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) OmegaConf.set_struct(cfg, True) try: if cfg.common.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): - distributed_utils.call_main(cfg, pre_main) + distributed_utils.call_main(cfg, pre_main, **kwargs) else: - distributed_utils.call_main(cfg, pre_main) + distributed_utils.call_main(cfg, pre_main, **kwargs) except BaseException as e: if not cfg.common.suppress_crashes: raise From 3d90df4a1fd3e7734622f7bd6cfda5dd6b4d3aef Mon Sep 17 00:00:00 2001 From: Ishani Karmarkar 
<ikarmarkar@fb.com> Date: Mon, 2 Aug 2021 21:57:40 -0700 Subject: [PATCH 680/707] Quant Noise Summary: Implemented fix bit scalar quantization with quant noise for pytext models Reviewed By: AkshatSh Differential Revision: D29662977 fbshipit-source-id: ebab68a4a5ff1583a0c6dfadcf2671663e232c18 --- fairseq/modules/quantization/scalar/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fairseq/modules/quantization/scalar/utils.py b/fairseq/modules/quantization/scalar/utils.py index 76db40fec0..2ec6af3fcb 100644 --- a/fairseq/modules/quantization/scalar/utils.py +++ b/fairseq/modules/quantization/scalar/utils.py @@ -27,7 +27,6 @@ def quantize_model_(model, p=0.2, bits=8, update_step=3000, method="histogram", - bits: number of bits - update_step: update quantization parameters every update_step steps """ - # quantize all layers # remove weights indicates whether the weights extension should be removed, in addition to # weight_orig and weight extension on names From 9825786fbe8a32053f21bec988d953d175f7262a Mon Sep 17 00:00:00 2001 From: Sam Shleifer <sshleifer@gmail.com> Date: Wed, 4 Aug 2021 16:30:23 -0700 Subject: [PATCH 681/707] --fp16-adam-stats (#2139) Summary: - stores exp_avg and exp_sq_avg in fp16, with `scale` variables to avoid overflow. 
- myleott added this to gshard, following github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2139 Reviewed By: myleott Differential Revision: D30113175 Pulled By: sshleifer fbshipit-source-id: 03995c8eb096629675eadec4e7b8e7f18fc2730e --- fairseq/optim/adam.py | 13 ++++++++- fairseq/optim/cpu_adam.py | 4 +++ fairseq/optim/fp16_optimizer.py | 2 ++ fairseq/optim/fused_adam.py | 52 ++++++++++++++++++++++++++++----- 4 files changed, 62 insertions(+), 9 deletions(-) diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index 6a31e53a62..d3ae9e64a7 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -33,6 +33,9 @@ class FairseqAdamConfig(FairseqDataclass): use_old_adam: bool = field( default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} ) + fp16_adam_stats: bool = field( + default=False, metadata={"help": "use FP16 stats (with automatic scaling)"} + ) # TODO common vars below in parent tpu: bool = II("common.tpu") lr: List[float] = II("optimization.lr") @@ -56,13 +59,21 @@ def __init__(self, cfg: FairseqAdamConfig, params): and torch.cuda.is_available() ) if getattr(cfg, "tpu", False): + if self.cfg.fp16_adam_stats: + raise NotImplementedError("--fp16-adam-stats is only supported on GPU") # on TPUs we use the Adam defined here, since it # automatically casts gradients to FP32 self._optimizer = Adam(params, **self.optimizer_config) elif use_fused_adam: logger.info("using FusedAdam") - self._optimizer = fused_adam_cls(params, **self.optimizer_config) + self._optimizer = fused_adam_cls( + params, + use_fp16_stats=self.cfg.fp16_adam_stats, + **self.optimizer_config + ) else: + if self.cfg.fp16_adam_stats: + raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") self._optimizer = Adam(params, **self.optimizer_config) @property diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index 211c376756..b2f893aeda 100644 --- 
a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -116,6 +116,10 @@ def __init__( self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode ) + @property + def supports_memory_efficient_fp16(self): + return True + @property def supports_flat_params(self): return True diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index b84236e685..c59b21cf6b 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -66,6 +66,8 @@ def build_fp32_params(cls, args, params, flatten=True): p32 = torch.nn.Parameter(p.data.float()) if hasattr(p, 'expert'): p32.expert = True + elif hasattr(p, 'base_expert'): + p32.base_expert = True p32.grad = torch.zeros_like(p32.data) if hasattr(p, "param_group"): p32.param_group = p.param_group diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py index e2b8e1bcd1..7a6d1f73d5 100644 --- a/fairseq/optim/fused_adam.py +++ b/fairseq/optim/fused_adam.py @@ -80,6 +80,7 @@ def __init__( weight_decay=0.0, max_grad_norm=0.0, amsgrad=False, + use_fp16_stats=False, ): global fused_adam_cuda import importlib @@ -99,6 +100,9 @@ def __init__( super().__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 + self.use_fp16_stats = use_fp16_stats + self.FLOAT16_MAX = 65504.0 + @property def supports_memory_efficient_fp16(self): return True @@ -173,29 +177,42 @@ def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): "please consider SparseAdam instead" ) - p_data_fp32 = p.data.float() + if p.device.type == "cpu": + p_data_fp32 = p.data.cuda(non_blocking=True).float() + out_p = torch.tensor([], dtype = torch.float) + else: + p_data_fp32 = p.data.float() + out_p = p.data state = self.state[p] # State initialization + dtype = torch.float16 if self.use_fp16_stats else p_data_fp32.dtype if len(state) == 0: state["step"] = 0 # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg"] = 
torch.zeros_like(p_data_fp32, dtype=dtype) # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32, dtype=dtype) + if self.use_fp16_stats: + state["exp_avg_scale"] = 1.0 + state["exp_avg_sq_scale"] = 1.0 else: - state["exp_avg"] = state["exp_avg"].to(p_data_fp32) - state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) + device = p_data_fp32.device + state["exp_avg"] = state["exp_avg"].to(device, dtype) + state["exp_avg_sq"] = state["exp_avg_sq"].to(device, dtype) exp_avg = state["exp_avg"] exp_avg_sq = state["exp_avg_sq"] + if self.use_fp16_stats: + assert exp_avg.dtype == torch.float16 + exp_avg = exp_avg.float() * state["exp_avg_scale"] + exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"] beta1, beta2 = group["betas"] state["step"] += 1 - out_p = p.data - with torch.cuda.device(p.device): + with torch.cuda.device(p_data_fp32.device): fused_adam_cuda.adam( p_data_fp32, out_p, @@ -213,6 +230,23 @@ def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): group["weight_decay"], ) + if p.device.type == "cpu": + p.data.copy_(p_data_fp32, non_blocking=True) + + if self.use_fp16_stats: + def inf_norm(t): + return torch.norm(t, float("inf")) + + # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py + state["exp_avg_scale"], state["exp_avg_sq_scale"] = ( + 1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX, + 1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX, + ) + state["exp_avg"], state["exp_avg_sq"] = ( + (exp_avg / state["exp_avg_scale"]).half(), + (exp_avg_sq / state["exp_avg_sq_scale"]).half(), + ) + return loss @@ -226,7 +260,9 @@ class FusedAdamV2(FusedAdam): and params to FP32 internally to support ``--memory-efficient-fp16``. 
""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, use_fp16_stats=False, **kwargs): + if use_fp16_stats: + raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") super().__init__(*args, **kwargs) if not hasattr(self, "multi_tensor_adam"): raise Exception( From 741fd138c69eebc1fac09b590047da85cdeafe2d Mon Sep 17 00:00:00 2001 From: Kushal Lakhotia <kushall@fb.com> Date: Wed, 11 Aug 2021 16:07:06 -0700 Subject: [PATCH 682/707] Commit README for GSLM (#2151) Summary: ## What does this PR do? Adds GSLM directory with README. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2151 Reviewed By: wnhsu Differential Revision: D30147672 Pulled By: hikushalhere fbshipit-source-id: bcc7cbbde3626ea3d91917707a91aff85d715baa --- examples/textless_nlp/gslm/README.md | 49 ++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 examples/textless_nlp/gslm/README.md diff --git a/examples/textless_nlp/gslm/README.md b/examples/textless_nlp/gslm/README.md new file mode 100644 index 0000000000..7fdb337335 --- /dev/null +++ b/examples/textless_nlp/gslm/README.md @@ -0,0 +1,49 @@ +# Generative Spoken Language Modeling + +## Speech to Unit Model (S2U) +### Acoustic Model +For quantizing speech we learn a K-means clustering over acoustic representations for which we either use Log-Mel Filterbank or pretrained acoustic representation models. For using pretrained models, please download from their respective locations linked below. +* [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) +* [Wav2Vec 2.0-Base](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) +* [CPC](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc_big_ll6kh_top_ctc.pt) + +### Quantization +For quantizing speech with a given acoustic representation, please follow the steps below. +1. 
Learn K-means clustering model +``` +N_CLUSTERS=<num_cluster> +TYPE=<logmel/hubert/w2v2/cpc> +CKPT_PATH=<path_of_pretrained_acoustic_model> +LAYER=<layer_of_acoustic_model_to_extract_features_from> +MANIFEST=<path_manifest_of_input_audio_files_to_train_with> +KM_MODEL_PATH=<path_of_trained_kmeans_model> + +PYTHONPATH=. python examples/textless_nlp/gslm/u2s/clustering/cluster_kmeans.py \ + --num_clusters $N_CLUSTERS \ + --feature_type $TYPE \ + --checkpoint_path $CKPT_PATH \ + --layer $LAYER \ + --manifest_path $MANIFEST \ + --out_kmeans_model_path $KM_MODEL_PATH +``` +2. Quantize using the learned clusters +``` +MANIFEST=<path_manifest_of_input_audio_files_to_quantize> +OUT_QUANT_FILE=<path_quzntized_audio_file> + +python examples/textless_nlp/gslm/u2s/clustering/del/quantize_with_kmeans.py \ + --feature_type $TYPE \ + --kmeans_model_path $KM_MODEL_PATH \ + --checkpoint_path $CKPT_PATH \ + --layer $LAYER \ + --manifest_path $MANIFEST \ + --out_quantized_file_path $OUT_QUANT_FILE \ + --extension .flac +``` + +## Unit Language Model (ULM) +Unit Language Model is a generative LM trained on quantized speech. We use it to generate novel quantized spoken language with and without prompt. + +## Unit to Speech Model (U2S) +Unit to speech model is modified Tacotron2 model that learns to syntehsize speech from discrete speech units. We use to synthesize quantized spoken language. 
+ From 2513524a1604dbafcc4ea9cc5a99ae0aa4f19694 Mon Sep 17 00:00:00 2001 From: alexeib <alexei.b@gmail.com> Date: Tue, 17 Aug 2021 06:55:26 -0700 Subject: [PATCH 683/707] add finetuned robust w2v models and update readme (#2196) Summary: adds finetuned robust w2v models and updates readme fixes #3721 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2196 Reviewed By: wnhsu Differential Revision: D30367999 Pulled By: alexeib fbshipit-source-id: 616b373bf31265c89f694fba7dccce2961d394f3 --- README.md | 2 ++ examples/wav2vec/README.md | 2 ++ 2 files changed, 4 insertions(+) diff --git a/README.md b/README.md index 460f3439fb..cd9654cf31 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ We provide reference implementations of various sequence modeling papers: + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027) + + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084) * **Non-autoregressive Transformers** + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 
2018) @@ -66,6 +67,7 @@ We provide reference implementations of various sequence modeling papers: * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md) * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md) * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) +* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md) * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) * February 2021 [Added LASER training code](examples/laser/README.md) * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 2d6717dc04..253c8af251 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -28,6 +28,8 @@ Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https:/ Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) Wav2Vec 2.0 Large (LV-60 + CV + SWBD + FSH) ** | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) + [CommonVoice](https://commonvoice.mozilla.org/en/languages) + [Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) + [Fisher](https://catalog.ldc.upenn.edu/LDC2004T19) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv.pt) +Wav2Vec 2.0 Large (LV-60 + CV + SWBD + FSH) ** | 960 hours Librispeech | 
[Libri-Light](https://github.com/facebookresearch/libri-light) + [CommonVoice](https://commonvoice.mozilla.org/en/languages) + [Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) + [Fisher](https://catalog.ldc.upenn.edu/LDC2004T19) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv_ftls960.pt) +Wav2Vec 2.0 Large (LV-60 + CV + SWBD + FSH) ** | 300 hours Switchboard | [Libri-Light](https://github.com/facebookresearch/libri-light) + [CommonVoice](https://commonvoice.mozilla.org/en/languages) + [Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) + [Fisher](https://catalog.ldc.upenn.edu/LDC2004T19) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv_ftsb300.pt) \* updated (Oct. 24, 2020)\ ** updated (Jul. 8, 2021) From cb747010c47a017e71285556afa9acce0ec62786 Mon Sep 17 00:00:00 2001 From: Vaibhav Singh <sivaibhav@google.com> Date: Tue, 17 Aug 2021 07:02:41 -0700 Subject: [PATCH 684/707] Set batch size to 4 to prevent OOM due dynamic batch sizing (#3781) Summary: ## What does this PR do? Fixes OOM which happens from TPUs due to dynamic batching exceed the max a single core can work with. 
Pull Request resolved: https://github.com/pytorch/fairseq/pull/3781 Reviewed By: wnhsu Differential Revision: D30327091 Pulled By: alexeib fbshipit-source-id: 0ebe6b18329fa05d359083fa8ac54aba7b48bc53 --- examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml index bee41157a9..3192ce4cba 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml @@ -18,6 +18,7 @@ task: normalize: true dataset: + batch_size: 4 num_workers: 6 max_tokens: 1200000 skip_invalid_size_inputs_valid_test: true From 1f7ef9ed1e1061f8c7f88f8b94c7186834398690 Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Thu, 19 Aug 2021 12:49:06 -0700 Subject: [PATCH 685/707] (fix #2177) Erase the encoder_embed_dim default (#2213) Summary: Fix https://github.com/fairinternal/fairseq-py/issues/2177 for the transformer conversion to Hydra. The defaults are now handled differently, so when you use the legacy Namespace configuration, you end up with a default encoder_embed_dim, which in the VGG case sets up an encoder attention in the TransformerDecoderLayer with the wrong dimensions. The easiest solution is to erase the default value for encoder_embed_dim (by forcing it to None) when converting the VGG config to the raw Namespace for the decoder layer.
Tested with: `pytest tests/speech_recognition/test_vggtransformer.py -k Transformer` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2213 Test Plan: pytest tests/speech_recognition/test_vggtransformer.py -k Transformer Reviewed By: sshleifer Differential Revision: D30425143 Pulled By: Mortimerp9 fbshipit-source-id: 92f6dea2ffbb68e441700bcc55274b3167a587b3 --- examples/speech_recognition/models/vggtransformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/speech_recognition/models/vggtransformer.py b/examples/speech_recognition/models/vggtransformer.py index 97974360a4..bca0ae59a8 100644 --- a/examples/speech_recognition/models/vggtransformer.py +++ b/examples/speech_recognition/models/vggtransformer.py @@ -203,6 +203,7 @@ def prepare_transformer_decoder_params( relu_dropout, ): args = argparse.Namespace() + args.encoder_embed_dim = None args.decoder_embed_dim = input_dim args.decoder_attention_heads = num_heads args.attention_dropout = attention_dropout From 6f847c8654d56b4d1b1fbacec027f47419426ddb Mon Sep 17 00:00:00 2001 From: Kushal Lakhotia <kushall@fb.com> Date: Thu, 26 Aug 2021 05:46:38 -0700 Subject: [PATCH 686/707] Release GSLM (#2201) Summary: ## What does this PR do? 
Open sourcing code for Generative Spoken Language Modeling Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2201 Reviewed By: wnhsu, eugene-kharitonov Differential Revision: D30563114 Pulled By: hikushalhere fbshipit-source-id: 6c1ee3b29038fd2c9fb5939bddcc70af0794dab4 --- examples/textless_nlp/gslm/README.md | 54 +- examples/textless_nlp/gslm/metrics/README.md | 10 + .../gslm/metrics/abx_metrics/README.md | 77 ++ .../metrics/abx_metrics/dump_abx_feats.py | 107 +++ .../gslm/metrics/asr_metrics/README.md | 87 +++ .../metrics/asr_metrics/continuation_eval.py | 99 +++ .../metrics/asr_metrics/misc/bleu_utils.py | 166 +++++ .../gslm/metrics/asr_metrics/misc/cut_as.py | 69 ++ .../metrics/asr_metrics/misc/dict.ltr.txt | 28 + .../gslm/metrics/asr_metrics/ppx.py | 122 ++++ .../metrics/asr_metrics/self_auto_bleu.py | 201 ++++++ .../textless_nlp/gslm/speech2unit/README.md | 71 ++ .../textless_nlp/gslm/speech2unit/__init__.py | 0 .../gslm/speech2unit/clustering/__init__.py | 0 .../speech2unit/clustering/cluster_kmeans.py | 212 ++++++ .../gslm/speech2unit/clustering/dump_feats.py | 91 +++ .../clustering/quantize_with_kmeans.py | 125 ++++ .../gslm/speech2unit/clustering/utils.py | 20 + .../pretrained/cpc_feature_reader.py | 192 +++++ .../pretrained/hubert_feature_reader.py | 59 ++ .../pretrained/logmel_feature_reader.py | 30 + .../gslm/speech2unit/pretrained/utils.py | 126 ++++ .../pretrained/w2v2_feature_reader.py | 46 ++ examples/textless_nlp/gslm/tools/README.md | 22 + .../gslm/tools/resynthesize_speech.py | 138 ++++ examples/textless_nlp/gslm/ulm/README.md | 72 ++ examples/textless_nlp/gslm/ulm/sample.py | 174 +++++ .../textless_nlp/gslm/unit2speech/README.md | 42 ++ .../gslm/unit2speech/convert_to_16k.py | 56 ++ .../textless_nlp/gslm/unit2speech/glow.py | 311 ++++++++ .../gslm/unit2speech/multiproc.py | 27 + .../synthesize_audio_from_units.py | 97 +++ .../gslm/unit2speech/tacotron2/__init__.py | 0 .../unit2speech/tacotron2/audio_processing.py | 93 +++ 
.../gslm/unit2speech/tacotron2/cleaners.py | 90 +++ .../gslm/unit2speech/tacotron2/cmudict.py | 65 ++ .../gslm/unit2speech/tacotron2/layers.py | 103 +++ .../gslm/unit2speech/tacotron2/model.py | 669 ++++++++++++++++++ .../gslm/unit2speech/tacotron2/numbers.py | 71 ++ .../gslm/unit2speech/tacotron2/stft.py | 141 ++++ .../gslm/unit2speech/tacotron2/symbols.py | 18 + .../gslm/unit2speech/tacotron2/text.py | 107 +++ .../gslm/unit2speech/tacotron2/utils.py | 167 +++++ .../tacotron2/waveglow_denoiser.py | 40 ++ .../textless_nlp/gslm/unit2speech/tts_data.py | 52 ++ .../textless_nlp/gslm/unit2speech/utils.py | 55 ++ fairseq/tasks/audio_pretraining.py | 3 +- 47 files changed, 4563 insertions(+), 42 deletions(-) create mode 100644 examples/textless_nlp/gslm/metrics/README.md create mode 100644 examples/textless_nlp/gslm/metrics/abx_metrics/README.md create mode 100644 examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py create mode 100644 examples/textless_nlp/gslm/metrics/asr_metrics/README.md create mode 100644 examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py create mode 100644 examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py create mode 100644 examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py create mode 100644 examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt create mode 100644 examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py create mode 100644 examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py create mode 100644 examples/textless_nlp/gslm/speech2unit/README.md create mode 100644 examples/textless_nlp/gslm/speech2unit/__init__.py create mode 100644 examples/textless_nlp/gslm/speech2unit/clustering/__init__.py create mode 100644 examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py create mode 100644 examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py create mode 100644 examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py create mode 
100644 examples/textless_nlp/gslm/speech2unit/clustering/utils.py create mode 100644 examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py create mode 100644 examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py create mode 100644 examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py create mode 100644 examples/textless_nlp/gslm/speech2unit/pretrained/utils.py create mode 100644 examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py create mode 100644 examples/textless_nlp/gslm/tools/README.md create mode 100644 examples/textless_nlp/gslm/tools/resynthesize_speech.py create mode 100644 examples/textless_nlp/gslm/ulm/README.md create mode 100644 examples/textless_nlp/gslm/ulm/sample.py create mode 100644 examples/textless_nlp/gslm/unit2speech/README.md create mode 100644 examples/textless_nlp/gslm/unit2speech/convert_to_16k.py create mode 100644 examples/textless_nlp/gslm/unit2speech/glow.py create mode 100644 examples/textless_nlp/gslm/unit2speech/multiproc.py create mode 100644 examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/model.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/text.py create mode 100644 
examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py create mode 100644 examples/textless_nlp/gslm/unit2speech/tts_data.py create mode 100644 examples/textless_nlp/gslm/unit2speech/utils.py diff --git a/examples/textless_nlp/gslm/README.md b/examples/textless_nlp/gslm/README.md index 7fdb337335..79de55d96e 100644 --- a/examples/textless_nlp/gslm/README.md +++ b/examples/textless_nlp/gslm/README.md @@ -1,49 +1,21 @@ # Generative Spoken Language Modeling -## Speech to Unit Model (S2U) -### Acoustic Model -For quantizing speech we learn a K-means clustering over acoustic representations for which we either use Log-Mel Filterbank or pretrained acoustic representation models. For using pretrained models, please download from their respective locations linked below. -* [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) -* [Wav2Vec 2.0-Base](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) -* [CPC](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc_big_ll6kh_top_ctc.pt) +* [Paper](https://arxiv.org/abs/2102.01192) +* [Demo](https://speechbot.github.io/gslm/index.html) -### Quantization -For quantizing speech with a given acoustic representation, please follow the steps below. -1. 
Learn K-means clustering model -``` -N_CLUSTERS=<num_cluster> -TYPE=<logmel/hubert/w2v2/cpc> -CKPT_PATH=<path_of_pretrained_acoustic_model> -LAYER=<layer_of_acoustic_model_to_extract_features_from> -MANIFEST=<path_manifest_of_input_audio_files_to_train_with> -KM_MODEL_PATH=<path_of_trained_kmeans_model> +We build and evaluate generative speech2speech systems using [Log Mel Filtebank](https://pytorch.org/audio/stable/compliance.kaldi.html#fbank), [Modified CPC](https://github.com/facebookresearch/CPC_audio), [HuBERT Base](https://github.com/pytorch/fairseq/tree/master/examples/hubert) and [Wav2Vec 2.0 Large](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec). Our system is composed of three components, namely, *speech2unit*, *ulm* and *unit2speech*. We explain about models and usage of these components in their respective sub-directories. See the links below. -PYTHONPATH=. python examples/textless_nlp/gslm/u2s/clustering/cluster_kmeans.py \ - --num_clusters $N_CLUSTERS \ - --feature_type $TYPE \ - --checkpoint_path $CKPT_PATH \ - --layer $LAYER \ - --manifest_path $MANIFEST \ - --out_kmeans_model_path $KM_MODEL_PATH -``` -2. Quantize using the learned clusters -``` -MANIFEST=<path_manifest_of_input_audio_files_to_quantize> -OUT_QUANT_FILE=<path_quzntized_audio_file> +## Speech to Unit Model (speech2unit) +Speech to unit model is used for quantizing raw speech into learned discrete speech units. [More details](speech2unit) -python examples/textless_nlp/gslm/u2s/clustering/del/quantize_with_kmeans.py \ - --feature_type $TYPE \ - --kmeans_model_path $KM_MODEL_PATH \ - --checkpoint_path $CKPT_PATH \ - --layer $LAYER \ - --manifest_path $MANIFEST \ - --out_quantized_file_path $OUT_QUANT_FILE \ - --extension .flac -``` +## Unit Language Model (ulm) +Unit Language Model is a generative language model trained on discrete speech units. [More details](ulm) -## Unit Language Model (ULM) -Unit Language Model is a generative LM trained on quantized speech. 
We use it to generate novel quantized spoken language with and without prompt. +## Unit to Speech Model (unit2speech) +Unit to speech model is used for synthesizing speech from discrete speech units. [More details](unit2speech) -## Unit to Speech Model (U2S) -Unit to speech model is modified Tacotron2 model that learns to syntehsize speech from discrete speech units. We use to synthesize quantized spoken language. +## Metrics +We show how to compute ASR based metrics as well as zero-shot metrics proposed in our paper [here](metrics). +## Tools +We share two tools to resynthesize a given spoken utterance, and generate novel spoken language given a spoken prompt. [More detail](tools) \ No newline at end of file diff --git a/examples/textless_nlp/gslm/metrics/README.md b/examples/textless_nlp/gslm/metrics/README.md new file mode 100644 index 0000000000..0a63e2f0d8 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/README.md @@ -0,0 +1,10 @@ +# GSLM Metrics + +## ASR Metrics +The suite of metrics here uses an ASR model to transcribe the synthesized speech into text, and then uses text-based metrics. We also use word error rate from ASR transcription itself as one of the metrics. [More details](asr_metrics) + +## ABX Metrics +We use [ABX](https://www.semanticscholar.org/paper/ABX-Discriminability-Measures-and-Applications-Schatz/13d3537228f728c1063cc83743cb118bba3367a0) to evaluate how well-separated phonetic categories are with quantized representations. [More details](abx_metrics) + +## sWUGGY and sBLIMP +We refer to [ZeroSpeech challenge](https://www.zerospeech.com/2021/track_s.html#scoring-based-metrics) for details on the sWUGGY and sBLIMP metrics. 
diff --git a/examples/textless_nlp/gslm/metrics/abx_metrics/README.md b/examples/textless_nlp/gslm/metrics/abx_metrics/README.md new file mode 100644 index 0000000000..aa2560f045 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/abx_metrics/README.md @@ -0,0 +1,77 @@ +# ABX-based evaluation + +ABX is used to evaluate the quality of the obtained discrete units. + +The life cycle of the ABX-based evaluation for the Speech-to-Unit contains the following steps: +1. Training an acoustic model (or use an existing acoustic model) ([description](./../..)) +2. Perform quantization of speech by learning a K-means clustering model ([description](./../..)) +3. Compute discrete features for ABX computation using the learned clusters +4. Compute the ABX score over the discrete features taking advantage of [libri-light's ABX evaluation script][ll-abx] + +Here we assume that you already went throught the first two steps and focus solely on extracting features and computing ABX scores. + +## Libri-light setup + +Follow [libri-light's instructions][ll-instructions] for installation and [ABX evaluation setup][ll-abx] (including the download of the data items required for ABX computation). + +## Computing ABX + +### Dumping quantized features + +The first step for the ABX computation is to dump the quantized representations corresponding to the test files. + +```shell +TYPE="hubert" +LAYER=6 +CKPT_PATH="<PATH_TO_HUBERT_MODEL_CHECKPOINT_FILE>" +KM_MODEL_PATH="<PATH_TO_PRETRAINED_KM_MODEL_FILE>" + +SUBSET="dev-clean" +MANIFEST="<PATH_TO_MANIFEST_FOR_LS_DEV-CLEAN>" +DATA_DIR="<PATH_TO_DIR_TO_STORE_FEATURES>/$SUBSET" + +PYTHONPATH=. 
python examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py \ + --feature_type $TYPE \ + --kmeans_model_path $KM_MODEL_PATH \ + --checkpoint_path $CKPT_PATH \ + --layer $LAYER \ + --manifest_path $MANIFEST \ + --out_dir_path $DATA_DIR \ + --extension ".flac" +``` + +Again, the manifest file follows the same structure as elsewhere in the codebase. + +### Compute ABX with Libri-light + +Use libri-light's `eval_ABX.py` script (with the appropriate environment set up) as follows: + +```shell +LIBRILIGHT_ROOT="<PATH_TO_LIBRILIGHT>" + +SUBSET="dev-clean" +DATA_DIR="<PATH_TO_DIR_TO_STORE_FEATURES>/$SUBSET" +ITEM_FILE_PATH="$LIBRILIGHT_ROOT/eval/ABX_data/$SUBSET.item" +OUT_DIR="<PATH_TO_DIR_TO_STORE_ABX_SCORES>/$SUBSET" + +FILE_EXTENSION=".npy" +FEATURE_SIZE=0.02 # depends on the model used + +PYTHONPATH=$LIBRILIGHT_ROOT \ + python $LIBRILIGHT_ROOT/eval/eval_ABX.py \ + $DATA_DIR \ + $ITEM_FILE_PATH \ + --file_extension $FILE_EXTENSION \ + --feature_size $FEATURE_SIZE \ + --out $OUT_DIR \ + --mode "all" +``` + +Note that `FEATURE_SIZE` will depend on the model type you are using to extract the acoustic features: +* For HuBERT and Wav2Vec2.0, use `FEATURE_SIZE=0.02` +* For CPC and Log Mel, use `FEATURE_SIZE=0.01` + +If you have a GPU available, make sure you add the `--cuda` flag for faster computation. + +[ll-instructions]: https://github.com/facebookresearch/libri-light +[ll-abx]: https://github.com/facebookresearch/libri-light/tree/master/eval#abx diff --git a/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py b/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py new file mode 100644 index 0000000000..41cf558970 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +import argparse +import logging +import os + +import joblib +import numpy as np + +from examples.textless_nlp.gslm.speech2unit.clustering.utils import get_audio_files +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_features + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + +def get_parser(): + parser = argparse.ArgumentParser( + description="Quantize using K-means clustering over acoustic features." + ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + required=True, + help="Acoustic feature type", + ) + parser.add_argument( + "--kmeans_model_path", + type=str, + required=True, + help="K-means model file path to use for inference", + ) + parser.add_argument( + "--manifest_path", + type=str, + default=None, + help="Manifest file containing the root dir and file names", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Pretrained model checkpoint", + ) + parser.add_argument( + "--layer", + type=int, + help="The layer of the pretrained model to extract features from", + default=-1, + ) + parser.add_argument( + "--out_dir_path", + required=True, + type=str, + help="Directory to write the quantized feature files to.", + ) + parser.add_argument( + "--extension", type=str, default=".flac", help="Audio file extension" + ) + return parser + + +def one_hot(feat, n_clusters): + return np.eye(n_clusters)[feat] + +def main(args, logger): + # Feature extraction + logger.info(f"Extracting {args.feature_type} acoustic features...") + features_batch = get_features( + feature_type=args.feature_type, + checkpoint_path=args.checkpoint_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=1.0, + flatten=False, + ) + logger.info(f"Features extracted for {len(features_batch)} utterances.\n") + logger.info(f"Dimensionality of
representation = {features_batch[0].shape[1]}") + + logger.info(f"Loading K-means model from {args.kmeans_model_path} ...") + kmeans_model = joblib.load(open(args.kmeans_model_path, "rb")) + kmeans_model.verbose = False + + _, fnames, _ = get_audio_files(args.manifest_path) + + os.makedirs(args.out_dir_path, exist_ok=True) + logger.info(f"Writing quantized features to {args.out_dir_path}") + for i, feats in enumerate(features_batch): + pred = kmeans_model.predict(feats) + emb = one_hot(pred, kmeans_model.n_clusters) + base_fname = os.path.splitext(os.path.basename(fnames[i]))[0] + output_path = os.path.join(args.out_dir_path, f"{base_fname}.npy") + with open(output_path, "wb") as f: + np.save(f, emb) + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/README.md b/examples/textless_nlp/gslm/metrics/asr_metrics/README.md new file mode 100644 index 0000000000..d05bc73d0d --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/README.md @@ -0,0 +1,87 @@ +# ASR-based evaluation + +Overall, the life cycle of the ASR-based evaluation for a ULM contains the following steps: + 1. Training a ULM and sampling from it [[description]](./../../ulm) + 2. Running UTS on the sampled unit sequences [[description]](./../../unit2speech) + 3. Pre-processing for the ASR (down-sampling to 16 kHz, aligning length of the generated audio with ground-truth utterances) + 4. Running ASR + 5. Calculating the post-ASR evaluation metrics + +Here we assume that you have already gone through the first two steps and focus on the rest.
+ +## Preprocessing +### Down-sampling to 16 kHz +The bulk conversion can be done by running +```bash + python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py $UTS_OUTPUT $UTS_OUTPUT_DOWNSAMPLE + ``` + where `$UTS_OUTPUT` specifies the directory with the generated audio and `$UTS_OUTPUT_DOWNSAMPLE` is the directory where the downsampled audio will be saved. + + ### Matching by length +This step is optional. However, if you want to compare the fluency and diversity of a generated speech utterance to that of the ground-truth speech with the same prefix, it is a good idea to force them to be of the same length. +```bash +python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py \ + --samples_dir=$UTS_OUTPUT_DOWNSAMPLE --out_dir=$UTS_OUTPUT_DOWNSAMPLE_CUT \ + --prompts_description=data/ground_truth_continuation_dev.json +``` + +Here `ground_truth_continuation_dev.json` is a JSON file with ground-truth text from LibriSpeech dev-clean, associated with some metadata (assuming the evaluation is done on dev-clean). This file can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_dev.json). A similar file for test-clean is [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_test.json). These files are used for the evaluation and contain texts for audio sequences that are at least 6s long. + +## Running ASR +We use a pre-trained wav2vec model to run the ASR step. We first need to prepare manifest files which, roughly, tell the ASR system which files we want to transcribe. You can find more details and download the `960h_scratch.pt` checkpoint +[[here]](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md). To run ASR, you will also need to +install KenLM and the Flashlight decoder, and download the KenLM 4-gram English language model.
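
As a rough illustration of what such a manifest contains: the first line is the root directory and every following line is a tab-separated relative path and frame count. The sketch below parses that layout. The helper name `parse_manifest`, the toy paths, and the frame counts are made up for illustration; the `__` separator mirrors how the metric scripts in this directory recover the prompt name from a file name.

```python
import pathlib


def parse_manifest(text):
    """Split manifest text into (root_dir, [(relative_path, n_frames), ...])."""
    lines = text.strip().splitlines()
    root = lines[0]  # first line: root directory of the audio files
    entries = []
    for line in lines[1:]:
        rel_path, n_frames = line.split("\t")
        entries.append((rel_path, int(n_frames)))
    return root, entries


# Made-up manifest content, for illustration only.
toy_manifest = "/data/uts_16k_cut\nutt_a__0.wav\t112000\nutt_b__0.wav\t97600\n"
root, entries = parse_manifest(toy_manifest)

for index, (rel_path, n_frames) in enumerate(entries):
    # Assumption: generated files are named "<prompt>__<suffix>.wav".
    prompt = pathlib.Path(rel_path).name.split("__")[0]
    # Duration in seconds, assuming 16 kHz audio.
    print(index, prompt, n_frames / 16_000)
```

The line index printed here is the kind of ID the metric scripts use to match transcript lines back to manifest entries.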
+ +```bash + python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py \ + $UTS_OUTPUT_DOWNSAMPLE_CUT --valid-percent 0.0 --dest $MANIFEST_DIR --ext wav +``` +where `$UTS_OUTPUT_DOWNSAMPLE_CUT` specifies the directory with the preprocessed UTS outputs and `$MANIFEST_DIR` is the output directory. + +We will be running an out-of-the-box evaluation script which requires ground-truth transcripts to measure quality metrics. We are only +interested in the transcripts (and we have no ground-truth transcripts for what our ULM generates!), hence we will just generate +some dummy transcripts instead: +```bash +cp $FAIRSEQ_ROOT/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt $MANIFEST_DIR +python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dummy_asr_data.py --tsv=$MANIFEST_DIR/train.tsv \ + --output-dir=$MANIFEST_DIR +``` + +Now we are ready to run ASR: +``` +mkdir -p asr +python $FAIRSEQ_ROOT/examples/speech_recognition/infer.py \ + $MANIFEST_DIR \ + --task audio_pretraining --nbest 1 --path 960h_scratch.pt \ + --gen-subset=train --results-path $PATH_TO_ASR_OUTPUT \ + --w2l-decoder kenlm --lm-model 4-gram.bin \ + --lexicon librispeech/lexicon_ltr.lst --word-score -1 \ + --sil-weight 0 --lm-weight 2 --criterion ctc --labels ltr --max-tokens 300000 --remove-bpe letter +``` +where `lexicon_ltr.lst` is the LibriSpeech lexicon (can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/lexicon_ltr.lst)) and `$PATH_TO_ASR_OUTPUT` is the output directory. + +## Evaluation metrics +We run evaluation on the 1,000 shortest sequences that are at least 6s long. To filter those from the ASR transcript, we additionally provide each metric script with the paths to the manifest and `ground_truth_continuation_*` files.
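
The selection of the shortest sequences can be sketched as follows. This is a minimal, self-contained approximation of what the metric scripts do, assuming each entry of a `ground_truth_continuation_*` file maps a sequence name to `[duration_in_seconds, text]`. The function name `shortest_sequences` and the toy entries are hypothetical.

```python
import json


def shortest_sequences(continuations, to_take=1000, min_seconds=6.0):
    """Return the names of the `to_take` shortest ground-truth sequences."""
    by_length = [(name, float(entry[0])) for name, entry in continuations.items()]
    # The evaluation files are expected to hold only sequences of at least 6 s.
    assert all(length >= min_seconds for _, length in by_length)
    by_length.sort(key=lambda pair: pair[1])
    return {name for name, _ in by_length[:to_take]}


# Toy stand-in for a ground_truth_continuation_*.json file (made-up entries).
toy = json.loads(
    '{"utt_a": [7.5, "A B"], "utt_b": [6.1, "C"], "utt_c": [9.0, "D E F"]}'
)
print(sorted(shortest_sequences(toy, to_take=2)))  # ['utt_a', 'utt_b']
```

Transcript lines whose manifest entry does not belong to the returned set are then skipped before a metric is computed.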
+ +### Perplexity (PPX) +To get a PPX metric estimate on an ASR transcript, you need to run the following command: +```bash +python ppx.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \ + --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json +``` +where `--cut-tail` tells the script to ignore the last token on each line (ASR puts the sequence ID there). + +### Self- and Auto-BLEU +```bash +python self_auto_bleu.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \ + --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json +``` + +### Continuation-BLEU +```bash +python continuation_eval.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt \ + --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json +``` + +### AUC +Based on the metrics calculated above, we can estimate the AUC of the perplexity/diversity trade-off. We provide an illustration in a [Colab notebook](https://colab.research.google.com/drive/1pVPfOVax_PU3MkYdHRSsa-SI8GBUldNt?usp=sharing). diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py b/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py new file mode 100644 index 0000000000..72b92a341d --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ + +from collections import defaultdict +import numpy as np +from misc.bleu_utils import sentence_bleu +import json +import warnings + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser("Tool to calculate Continuation-BLEU2") + parser.add_argument('--asr-transcript', type=str, + help='Path to the transcript file.') + parser.add_argument('--prompts-description', type=str, + help='Path to the ground-truth continuation') + parser.add_argument('--manifest', type=str, required=True) + parser.add_argument('--take-shortest', type=int, default=1000) + + args = parser.parse_args() + + return args + + +def main(): + # NLTK produces warnings + warnings.filterwarnings("ignore") + + args = get_args() + + with open(args.prompts_description, 'r') as fin: + original_continuations = json.loads(fin.read()) + + sequence2length = [(k, v[0]) for k, v in original_continuations.items()] + assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds + + sequence2length.sort(key=lambda x: x[1]) + to_take = set(v[0] for v in sequence2length[:args.take_shortest]) + + with open(args.manifest, 'r') as fin: + fin.readline() + + linenum2file = dict([ + (i, l.split("__")[0]) for (i, l) in enumerate(fin) + ]) + + max_files = max(linenum2file.keys()) + continuations = defaultdict(list) + + mean_length_after = 0 + n_examples = 0 + + with open(args.asr_transcript, 'r') as fin: + for line in fin: + n_examples += 1 + line = line.split() + sequence_id = int(line[-1].split('-')[1][:-1]) + + assert sequence_id <= max_files + + sequence_name = linenum2file[sequence_id] + + continuations[sequence_name].append(line[:-1]) + mean_length_after += len(line) + + mean_length_after /= n_examples + print(f'Mean length of continuations, in words: {mean_length_after}') + metric_values = [] + + mean_ground_truth_words = 0 + n_examples = 0 + n_candidates = 0 + + for k, candidates in continuations.items(): + if k not in to_take: + continue + + n_examples += 1 + + ground_truth = 
original_continuations[k][1].split() + n_candidates += len(candidates) + bleu = sentence_bleu(candidates, ground_truth, weights=( + 0.5, 0.5), no_length_penalty=True, averaging_mode="geometric") + mean_ground_truth_words += len(ground_truth) + + metric_values.append(bleu) + + n = len(metric_values) + print( + f'Median BLEU over {n} examples: {np.median(metric_values)} +- {np.std(metric_values) / np.sqrt(n)}') + + +if __name__ == '__main__': + main() diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py b/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py new file mode 100644 index 0000000000..75cc5272d3 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py @@ -0,0 +1,166 @@ +""" + +TODO: the code is taken from Apache-2 Licensed NLTK: make sure we do this properly! + + +Copied over from nltk.translate.bleu_score. This code has two major changes: + - allows turning off the length/brevity penalty, which makes no sense for self-bleu, + - allows using the arithmetic instead of the geometric mean +""" + +import math +import sys +from fractions import Fraction +import warnings +from collections import Counter +from nltk.translate.bleu_score import modified_precision, closest_ref_length, brevity_penalty, SmoothingFunction + + +def corpus_bleu( + list_of_references, + hypotheses, + weights=(0.25, 0.25, 0.25, 0.25), + smoothing_function=None, + auto_reweigh=False, + averaging_mode="geometric", + no_length_penalty=False +): + """ + Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all + the hypotheses and their respective references. + + Instead of averaging the sentence level BLEU scores (i.e. macro-average + precision), the original BLEU metric (Papineni et al. 2002) accounts for + the micro-average precision (i.e. summing the numerators and denominators + for each hypothesis-reference(s) pairs before the division). + + >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', + ...
'ensures', 'that', 'the', 'military', 'always', + ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] + >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', + ... 'ensures', 'that', 'the', 'military', 'will', 'forever', + ... 'heed', 'Party', 'commands'] + >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which', + ... 'guarantees', 'the', 'military', 'forces', 'always', + ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party'] + >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', + ... 'army', 'always', 'to', 'heed', 'the', 'directions', + ... 'of', 'the', 'party'] + + >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was', + ... 'interested', 'in', 'world', 'history'] + >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history', + ... 'because', 'he', 'read', 'the', 'book'] + + >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]] + >>> hypotheses = [hyp1, hyp2] + >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS + 0.5920... + + The example below show that corpus_bleu() is different from averaging + sentence_bleu() for hypotheses + + >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1) + >>> score2 = sentence_bleu([ref2a], hyp2) + >>> (score1 + score2) / 2 # doctest: +ELLIPSIS + 0.6223... + + :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses + :type list_of_references: list(list(list(str))) + :param hypotheses: a list of hypothesis sentences + :type hypotheses: list(list(str)) + :param weights: weights for unigrams, bigrams, trigrams and so on + :type weights: list(float) + :param smoothing_function: + :type smoothing_function: SmoothingFunction + :param auto_reweigh: Option to re-normalize the weights uniformly. + :type auto_reweigh: bool + :return: The corpus-level BLEU score. + :rtype: float + """ + # Before proceeding to compute BLEU, perform sanity checks. + + p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches. 
+ p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref. + hyp_lengths, ref_lengths = 0, 0 + + assert len(list_of_references) == len(hypotheses), ( + "The number of hypotheses and their reference(s) should be the " "same " + ) + + # Iterate through each hypothesis and their corresponding references. + for references, hypothesis in zip(list_of_references, hypotheses): + # For each order of ngram, calculate the numerator and + # denominator for the corpus-level modified precision. + for i, _ in enumerate(weights, start=1): + p_i = modified_precision(references, hypothesis, i) + p_numerators[i] += p_i.numerator + p_denominators[i] += p_i.denominator + + # Calculate the hypothesis length and the closest reference length. + # Adds them to the corpus-level hypothesis and reference counts. + hyp_len = len(hypothesis) + hyp_lengths += hyp_len + ref_lengths += closest_ref_length(references, hyp_len) + + # Calculate corpus-level brevity penalty. + if no_length_penalty and averaging_mode == 'geometric': + bp = 1.0 + elif no_length_penalty and averaging_mode == 'arithmetic': + bp = 0.0 + else: + assert not no_length_penalty + assert averaging_mode != 'arithmetic', 'Not sure how to apply length penalty in arithmetic mode' + bp = brevity_penalty(ref_lengths, hyp_lengths) + + # Uniformly re-weighting based on maximum hypothesis lengths if largest + # order of n-grams < 4 and weights is set at default. + if auto_reweigh: + if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): + weights = (1 / hyp_lengths,) * hyp_lengths + + # Collects the various precision values for the different ngram orders. + p_n = [ + Fraction(p_numerators[i], p_denominators[i], _normalize=False) + for i, _ in enumerate(weights, start=1) + ] + + # Returns 0 if there's no matching n-grams + # We only need to check for p_numerators[1] == 0, since if there's + # no unigrams, there won't be any higher order ngrams.
+ if p_numerators[1] == 0: + return 0 + + # If there's no smoothing function, use method0 from the SmoothingFunction class. + if not smoothing_function: + smoothing_function = SmoothingFunction().method0 + # Smoothen the modified precision. + # Note: smoothing_function() may convert values into floats; + # it tries to retain the Fraction object as much as the + # smoothing method allows. + p_n = smoothing_function( + p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths + ) + + if averaging_mode == "geometric": + s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n)) + s = bp * math.exp(math.fsum(s)) + elif averaging_mode == "arithmetic": + s = (w_i * p_i for w_i, p_i in zip(weights, p_n)) + s = math.fsum(s) + + return s + + +def sentence_bleu( + references, + hypothesis, + weights=(0.25, 0.25, 0.25, 0.25), + smoothing_function=None, + auto_reweigh=False, + averaging_mode="geometric", + no_length_penalty=False +): + return corpus_bleu( + [references], [hypothesis], weights, smoothing_function, auto_reweigh, averaging_mode, no_length_penalty + ) \ No newline at end of file diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py b/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py new file mode 100644 index 0000000000..5b7e1e9685 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ + +import torchaudio +import argparse +import json +import pathlib + + +def get_args(): + parser = argparse.ArgumentParser( + "Assuring generated audio have the same length as ground-truth audio") + parser.add_argument('--samples_dir', required=True, type=str) + parser.add_argument('--out_dir', required=True, type=str) + parser.add_argument('--prompts_description', required=True, type=str) + return parser.parse_args() + + +def cut(src, tgt, l): + x, sr = torchaudio.load(str(src)) + assert sr == 16_000 + + x = x.squeeze() + target_frames = int(l * sr) + + flag = 0 + if target_frames <= x.size(0): + x = x[:target_frames] + flag = 1 + else: + flag = 0 + torchaudio.save(str(tgt), x.unsqueeze(0), sr) + return flag + + +def main(): + args = get_args() + tgt_dir = pathlib.Path(args.out_dir) + tgt_dir.mkdir(exist_ok=True, parents=True) + + total_files, sufficiently_long = 0, 0 + + with open(args.prompts_description, 'r') as f: + description = json.loads(f.read()) + + for src_f in pathlib.Path(args.samples_dir).glob('*.wav'): + name_prompt = src_f.with_suffix('').name.split('__')[0] + + assert name_prompt in description, f'Cannot find {name_prompt}!' 
+ + target_length = description[name_prompt][0] + tgt_f = tgt_dir / (src_f.name) + + is_long_enough = cut(src_f, tgt_f, target_length) + sufficiently_long += is_long_enough + if not is_long_enough: + print(f'{src_f} is not long enough') + + total_files += 1 + + print( + f'Total files: {total_files}; sufficiently long: {sufficiently_long}') + + +if __name__ == '__main__': + main() diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt b/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt new file mode 100644 index 0000000000..69929e1666 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt @@ -0,0 +1,28 @@ +| 94802 +E 51860 +T 38431 +A 33152 +O 31495 +N 28855 +I 28794 +H 27187 +S 26071 +R 23546 +D 18289 +L 16308 +U 12400 +M 10685 +W 10317 +C 9844 +F 9062 +G 8924 +Y 8226 +P 6890 +B 6339 +V 3936 +K 3456 +' 1023 +X 636 +J 598 +Q 437 +Z 213 diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py b/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py new file mode 100644 index 0000000000..d6a40e4d35 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch +import numpy as np +import warnings + + +def get_target_sequences(manifest, ground_truth, to_take=1000): + import json + import pathlib + + with open(ground_truth, 'r') as fin: + original_continuations = json.loads(fin.read()) + + sequence2length = [(k, v[0]) for k, v in original_continuations.items()] + assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds + + sequence2length.sort(key=lambda x: x[1]) + to_take_sequences = set(v[0] for v in sequence2length[:to_take]) + to_take_ids = [] + + with open(manifest, 'r') as f: + f.readline() + + for i, line in enumerate(f.readlines()): + seq_id = line.split()[0] + seq_id = pathlib.Path(seq_id).name.split('__')[0] + + if seq_id in to_take_sequences: + to_take_ids.append(i) + + print(f'Took {len(to_take_ids)} ids') + return set(to_take_ids) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser("Evaluate PPX metric of a transcript.") + parser.add_argument('--asr-transcript', type=str, + help='Path to the transcript file.') + parser.add_argument('--cut-id', action='store_true', + help='Whether cut the first token (typically a seq id)') + parser.add_argument('--cut-tail', action='store_true', + help='Whether cut the last token (typically a speaker id)') + + parser.add_argument('--manifest', type=str, default=None) + parser.add_argument('--prompts-description', type=str, default=None) + + args = parser.parse_args() + + return args + + +def main(): + args = get_args() + + lm = torch.hub.load( + 'pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe') + + lm.eval().cuda() # disable dropout + + if args.manifest is None and args.prompts_description is None: + target_ids = None + else: + target_ids = get_target_sequences( + args.manifest, args.prompts_description) + + with open(args.asr_transcript, 'r') as fin: + lines = fin.readlines() + + if target_ids is not None: + filtered = [] + for line in lines: + line_id = line.split()[-1] + line_id = 
int(line_id.split('-')[1][:-1]) + if line_id in target_ids: + filtered.append(line) + lines = filtered + else: + pass + + if args.cut_id: + lines = [' '.join(x.split()[1:]) for x in lines] + if args.cut_tail: + lines = [' '.join(x.split()[:-1]) for x in lines] + lines = [x.strip().lower() for x in lines] + + def get_logprob(sent): return \ + lm.score(sent)['positional_scores'].mean().neg().item() + + logprobs = [get_logprob(l) for l in lines] + + filtered = [x for x in logprobs if not np.isnan(x)] + if len(filtered) != len(logprobs): + warnings.warn("NaNs detected!") + logprobs = filtered + + perplexities = [np.exp(l) for l in logprobs] + + for name, stats in [('logprob', logprobs), ('perplexity', perplexities)]: + mean = np.mean(stats) + sem = np.std(stats) / np.sqrt(len(stats)) + + median = np.median(stats) + interval = list(np.percentile(stats, [10, 90])) + + mean, sem, median, percentile10, percentile90 = [ + round(x, 2) for x in [mean, sem, median] + interval] + + print(name) + print(f"\tMean {mean} +- {sem}") + print( + f"\tMedian {median}, 10th to 90th percentile range {percentile10}...{percentile90}") + + +if __name__ == '__main__': + main() diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py b/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py new file mode 100644 index 0000000000..062bb82f66 --- /dev/null +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py @@ -0,0 +1,201 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +import numpy as np +import nltk +from misc.bleu_utils import sentence_bleu +import warnings + + +def get_target_sequences(manifest, ground_truth, to_take=1000): + import json + import pathlib + + with open(ground_truth, 'r') as fin: + original_continuations = json.loads(fin.read()) + + sequence2length = [(k, v[0]) for k, v in original_continuations.items()] + assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds + + sequence2length.sort(key=lambda x: x[1]) + to_take_sequences = set(v[0] for v in sequence2length[:to_take]) + to_take_ids = [] + + with open(manifest, 'r') as f: + f.readline() + + for i, line in enumerate(f.readlines()): + seq_id = line.split()[0] + seq_id = pathlib.Path(seq_id).name.split('__')[0] + + if seq_id in to_take_sequences: + to_take_ids.append(i) + + print(f'Took {len(to_take_ids)} ids') + return set(to_take_ids) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--asr-transcript', type=str, + help='Path to the transcript file.') + + parser.add_argument('--manifest', required=True) + parser.add_argument('--prompts-description', required=True) + + parser.add_argument('--cut-id', action='store_true', + help='Whether cut the first token (typically a seq id)') + parser.add_argument('--cut-tail', action='store_true', + help='Whether cut the last token (typically a speaker id)') + parser.add_argument('--debug', action='store_true') + + args = parser.parse_args() + + return args + + +def get_self_bleu(utterances, averaging_mode, weights): + self_bleu = [] + + for i in range(len(utterances)): + hypo = utterances[i] + rest = utterances[:i] + utterances[i+1:] + + self_bleu.append(sentence_bleu(rest, hypo, weights, + no_length_penalty=True, averaging_mode=averaging_mode)) + + return self_bleu + + +def get_self_bleu2_arithmetic(utterances): + weights = (0.5, 0.5) # equal weight for unigrams and bigrams + return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights) + + 
+def get_self_bleu2_geometric(utterances): + weights = (0.5, 0.5) + return get_self_bleu(utterances, averaging_mode='geometric', weights=weights) + + +def get_auto_bleu2_arithmetic(utterances): + weights = (0.5, 0.5) + return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances] + + +def get_auto_bleu2_geometric(utterances): + weights = (0.5, 0.5) + return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances] + + +def get_auto_bleu3_geometric(utterances): + weights = (1./3, 1./3, 1./3) + return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances] + + +def get_auto_bleu3_arithmetic(utterances): + weights = (1./3, 1./3, 1./3) + return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances] + + +def get_self_bleu3_arithmetic(utterances): + weights = (1./3, 1./3, 1./3) + return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights) + + +def get_self_bleu3_geometric(utterances): + weights = (1./3, 1./3, 1./3) + return get_self_bleu(utterances, averaging_mode='geometric', weights=weights) + + +def auto_bleu(sentence, weights, mean_mode='arithmetic'): + if len(sentence) <= 1: + return 0 + + N = len(weights) + + bleu_n = np.zeros([N]) + for n in range(N): + targ_ngrams = list(nltk.ngrams(sentence, n+1)) + for p in range(len(targ_ngrams)): + left = sentence[:p] + right = sentence[(p+n+1):] + rest_ngrams = list(nltk.ngrams(left, n+1)) + \ + list(nltk.ngrams(right, n+1)) + # compute the number of matching ngrams + bleu_n[n] += targ_ngrams[p] in rest_ngrams + bleu_n[n] /= len(targ_ngrams) # average them to get a proportion + + weights = np.array(weights) + if mean_mode == 'arithmetic': + return (bleu_n * weights).sum() + elif mean_mode == 'geometric': + return (bleu_n ** weights).prod() + else: + raise ValueError(f'Unknown aggregation mode {mean_mode}') + + +def main(): + from multiprocessing import Pool + + args = get_args() + target_ids = get_target_sequences(args.manifest,
args.prompts_description) + + with open(args.asr_transcript, 'r') as fin: + lines = fin.readlines() + + terms = [x.strip().split() for x in lines] + filtered = [] + for term in terms: + line_id = int(term[-1].split('-')[1][:-1]) + if line_id in target_ids: + filtered.append(term) + terms = filtered + + if args.cut_id: + terms = [x[1:] for x in terms] + if args.cut_tail: + terms = [x[:-1] for x in terms] + + if args.debug: + terms = terms[:10] + + tasks = [ + ('Self-BLEU2-arithmetic', get_self_bleu2_arithmetic), + ('Self-BLEU2-geometric', get_self_bleu2_geometric), + ('Auto-BLEU2-arithmetic', get_auto_bleu2_arithmetic), + ('Auto-BLEU2-geometric', get_auto_bleu2_geometric), + + ('Self-BLEU3-arithmetic', get_self_bleu3_arithmetic), + ('Self-BLEU3-geometric', get_self_bleu3_geometric), + ('Auto-BLEU3-arithmetic', get_auto_bleu3_arithmetic), + ('Auto-BLEU3-geometric', get_auto_bleu3_geometric), + ] + + n_processes = min(16, len(tasks)) + with Pool(n_processes) as pool: + metrics = pool.map(run_f, [(t[1], terms) for t in tasks]) + + for (metric_name, _), metric in zip(tasks, metrics): + metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric)) + + metric, sem = [ + round(100 * x, 2) for x in [metric, sem] + ] + + print(f'{metric_name} {metric} +- {sem}') + + +def run_f(task_params): + f, terms = task_params + return f(terms) + + +if __name__ == '__main__': + # NLTK produces warnings + warnings.filterwarnings("ignore") + + main() diff --git a/examples/textless_nlp/gslm/speech2unit/README.md b/examples/textless_nlp/gslm/speech2unit/README.md new file mode 100644 index 0000000000..1a3d131ec1 --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/README.md @@ -0,0 +1,71 @@ +# Speech to Unit Model (speech2unit) + +## Acoustic Model +For quantizing speech we learn a K-means clustering over acoustic representations for which we either use Log-Mel Filterbank or pretrained acoustic representation models. 
To use pretrained models, download them from the locations linked below. +* [Modified CPC](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/cpc_big_ll6kh_top_ctc.pt) +* [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) +* [Wav2Vec 2.0-Large](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) + +## Quantization Model +You can download pretrained K-means quantization models from the list below. + +K-Means Model | Download Link +|-|- +Log Mel Filterbank + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km50/km.bin) +Log Mel Filterbank + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km100/km.bin) +Log Mel Filterbank + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km200/km.bin) +Log Mel Filterbank + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km500/km.bin) +Modified CPC + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km50/km.bin) +Modified CPC + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km100/km.bin) +Modified CPC + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km200/km.bin) +Modified CPC + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km500/km.bin) +HuBERT Base + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km50/km.bin) +HuBERT Base + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin) +HuBERT Base + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km200/km.bin) +HuBERT Base + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km500/km.bin) +wav2vec 2.0 Large + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km50/km.bin) +wav2vec 2.0 Large + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km100/km.bin) +wav2vec 2.0 Large + KM200 | 
[download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km200/km.bin) +wav2vec 2.0 Large + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km500/km.bin) + +### Quantization +To quantize speech with a given acoustic representation, follow the steps below. +1. Learn K-means clustering model +``` +N_CLUSTERS=<number_of_clusters_used_for_kmeans> +TYPE=<one_of_logmel/cpc/hubert/w2v2> +CKPT_PATH=<path_of_pretrained_acoustic_model> +LAYER=<layer_of_acoustic_model_to_extract_features_from> +MANIFEST=<tab_separated_manifest_of_audio_files_for_training_kmeans> +KM_MODEL_PATH=<output_path_of_the_kmeans_model> + +PYTHONPATH=. python examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py \ + --num_clusters $N_CLUSTERS \ + --feature_type $TYPE \ + --checkpoint_path $CKPT_PATH \ + --layer $LAYER \ + --manifest_path $MANIFEST \ + --out_kmeans_model_path $KM_MODEL_PATH +``` +2. Quantize using the learned clusters +``` +MANIFEST=<tab_separated_manifest_of_audio_files_to_quantize> +OUT_QUANTIZED_FILE=<output_quantized_audio_file_path> + +PYTHONPATH=. python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \ + --feature_type $TYPE \ + --kmeans_model_path $KM_MODEL_PATH \ + --acoustic_model_path $CKPT_PATH \ + --layer $LAYER \ + --manifest_path $MANIFEST \ + --out_quantized_file_path $OUT_QUANTIZED_FILE \ + --extension ".flac" +``` + +Note: the manifest file lists the paths and lengths of the input audio files. The first line is the root directory; each following line holds a tab-separated relative path and number of frames. The format of the file is as follows: +``` +<path_of_root_directory_containing_audio_files> +<relative_path_of_audio_file_1>\t<number_of_frames_1> +<relative_path_of_audio_file_2>\t<number_of_frames_2> +... 
+``` \ No newline at end of file diff --git a/examples/textless_nlp/gslm/speech2unit/__init__.py b/examples/textless_nlp/gslm/speech2unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/textless_nlp/gslm/speech2unit/clustering/__init__.py b/examples/textless_nlp/gslm/speech2unit/clustering/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py b/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py new file mode 100644 index 0000000000..7cf844a95a --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py @@ -0,0 +1,212 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +import time + +import numpy as np +from sklearn.cluster import MiniBatchKMeans + +import joblib +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import ( + get_and_dump_features, + get_features, +) + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Learn K-means clustering over acoustic features." 
+ ) + + # Features arguments + parser.add_argument( + "--in_features_path", type=str, default=None, help="Features file path" + ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + help="Acoustic feature type", + ) + parser.add_argument( + "--manifest_path", + type=str, + default=None, + help="Manifest file containing the root dir and file names", + ) + parser.add_argument( + "--out_features_path", + type=str, + default=None, + help="Features file path to write to", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Pretrained acoustic model checkpoint", + ) + parser.add_argument( + "--layer", + type=int, + help="The layer of the pretrained model to extract features from", + default=-1, + ) + parser.add_argument( + "--sample_pct", + type=float, + help="Fraction of the data (0.0-1.0) to use for K-means training", + default=0.1, + ) + + # K-means arguments + parser.add_argument( + "--num_clusters", type=int, help="Number of clusters", default=50 + ) + parser.add_argument("--init", default="k-means++") + parser.add_argument( + "--max_iter", + type=int, + help="Maximum number of iterations for K-means training", + default=150, + ) + parser.add_argument( + "--batch_size", + type=int, + help="Batch size for K-means training", + default=10000, + ) + parser.add_argument("--tol", default=0.0, type=float) + parser.add_argument("--max_no_improvement", default=100, type=int) + parser.add_argument("--n_init", default=20, type=int) + parser.add_argument("--reassignment_ratio", default=0.5, type=float) + parser.add_argument( + "--out_kmeans_model_path", + type=str, + required=True, + help="Path to save K-means model", + ) + + # Leftovers + parser.add_argument( + "--seed", + type=int, + help="Random seed to use for K-means training", + default=1369, + ) + + return parser + + +def get_kmeans_model( + n_clusters, + init, + max_iter, + batch_size, + tol, + max_no_improvement, + n_init, + reassignment_ratio, + 
random_state, +): + return MiniBatchKMeans( + n_clusters=n_clusters, + init=init, + max_iter=max_iter, + batch_size=batch_size, + tol=tol, + max_no_improvement=max_no_improvement, + n_init=n_init, + reassignment_ratio=reassignment_ratio, + random_state=random_state, + verbose=1, + compute_labels=True, + init_size=None, + ) + + +def train_kmeans(kmeans_model, features_batch): + start_time = time.time() + kmeans_model.fit(features_batch) + time_taken = round((time.time() - start_time) / 60, 2) # elapsed minutes, 2 decimals + return kmeans_model, time_taken + + +def main(args, logger): + # Features loading/extraction for K-means + if args.in_features_path: + # Feature loading + logger.info(f"Loading features from {args.in_features_path}...") + features_batch = np.load(args.in_features_path, allow_pickle=True) + else: + # Feature extraction + logger.info(f"Extracting {args.feature_type} acoustic features...") + features_batch = ( + get_features( + feature_type=args.feature_type, + checkpoint_path=args.checkpoint_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=args.sample_pct, + flatten=True, + ) + if not args.out_features_path + else get_and_dump_features( + feature_type=args.feature_type, + checkpoint_path=args.checkpoint_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=args.sample_pct, + flatten=True, + out_features_path=args.out_features_path, + ) + ) + if args.out_features_path: + logger.info( + f"Saved extracted features at {args.out_features_path}" + ) + logger.info(f"Features shape = {features_batch.shape}\n") + + # Learn and save K-means model + kmeans_model = get_kmeans_model( + n_clusters=args.num_clusters, + init=args.init, + max_iter=args.max_iter, + batch_size=args.batch_size, + tol=args.tol, + max_no_improvement=args.max_no_improvement, + n_init=args.n_init, + reassignment_ratio=args.reassignment_ratio, + random_state=args.seed, + ) + logger.info("Starting k-means training...") + kmeans_model, time_taken = train_kmeans( + 
kmeans_model=kmeans_model, features_batch=features_batch + ) + logger.info(f"...done k-means training in {time_taken} minutes") + inertia = -kmeans_model.score(features_batch) / len(features_batch) + logger.info(f"Total inertia: {round(inertia, 2)}\n") + + logger.info(f"Saving k-means model to {args.out_kmeans_model_path}") + os.makedirs(os.path.dirname(args.out_kmeans_model_path), exist_ok=True) + joblib.dump(kmeans_model, open(args.out_kmeans_model_path, "wb")) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py b/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py new file mode 100644 index 0000000000..031567c6d8 --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging + +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import ( + get_and_dump_features, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Compute and dump acoustic features." 
+ ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + help="Acoustic feature type", + ) + parser.add_argument( + "--manifest_path", + type=str, + default=None, + help="Manifest file containing the root dir and file names", + ) + parser.add_argument( + "--out_features_path", + type=str, + default=None, + help="Features file path to write to", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Pretrained acoustic model checkpoint", + ) + parser.add_argument( + "--layer", + type=int, + help="The layer of the pretrained model to extract features from", + default=-1, + ) + parser.add_argument( + "--sample_pct", + type=float, + help="Fraction of the data (0.0-1.0) to use for K-means training", + default=0.1, + ) + return parser + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +if __name__ == "__main__": + """ + Example command: + python ~/speechbot/clustering/dump_logmelfank_feats.py \ + --manifest_path /checkpoint/kushall/data/LJSpeech-1.1/asr_input_wavs_16k/train.tsv + --out_features_path /checkpoint/kushall/experiments/speechbot/logmelfbank/features/ljspeech/train.npy + """ + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + + logger.info(f"Extracting {args.feature_type} acoustic features...") + get_and_dump_features( + feature_type=args.feature_type, + checkpoint_path=args.checkpoint_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=args.sample_pct, + flatten=True, + out_features_path=args.out_features_path, + ) + logger.info(f"Saved extracted features at {args.out_features_path}") diff --git a/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py 
b/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py new file mode 100644 index 0000000000..2c87445d81 --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os + +import numpy as np + +import joblib +from examples.textless_nlp.gslm.speech2unit.clustering.utils import ( + get_audio_files, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import ( + get_features, +) + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Quantize using K-means clustering over acoustic features." + ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + required=True, + help="Acoustic feature type", + ) + parser.add_argument( + "--acoustic_model_path", + type=str, + help="Pretrained acoustic model checkpoint" + ) + parser.add_argument( + "--layer", + type=int, + help="The layer of the pretrained model to extract features from", + default=-1, + ) + parser.add_argument( + "--kmeans_model_path", + type=str, + required=True, + help="K-means model file path to use for inference", + ) + parser.add_argument( + "--features_path", + type=str, + default=None, + help="Features file path. 
You don't need to enter acoustic model details if you have dumped features", + ) + parser.add_argument( + "--manifest_path", + type=str, + default=None, + help="Manifest file containing the root dir and file names", + ) + parser.add_argument( + "--out_quantized_file_path", + required=True, + type=str, + help="File path of quantized output.", + ) + parser.add_argument( + "--extension", type=str, default=".flac", help="Audio file extension to strip from file names" + ) + return parser + + +def main(args, logger): + # Feature extraction + if args.features_path is not None: + logger.info(f"Loading acoustic features from {args.features_path}...") + features_batch = np.load(args.features_path) + else: + logger.info(f"Extracting {args.feature_type} acoustic features...") + features_batch = get_features( + feature_type=args.feature_type, + checkpoint_path=args.acoustic_model_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=1.0, + flatten=False, + ) + logger.info( + f"Features extracted for {len(features_batch)} utterances.\n" + ) + logger.info( + f"Dimensionality of representation = {features_batch[0].shape[1]}" + ) + + # K-means model + logger.info(f"Loading K-means model from {args.kmeans_model_path} ...") + kmeans_model = joblib.load(open(args.kmeans_model_path, "rb")) + kmeans_model.verbose = False + + _, fnames, _ = get_audio_files(args.manifest_path) + + os.makedirs(os.path.dirname(args.out_quantized_file_path), exist_ok=True) + print(f"Writing quantized predictions to {args.out_quantized_file_path}") + with open(args.out_quantized_file_path, "w") as fout: + for i, feats in enumerate(features_batch): + pred = kmeans_model.predict(feats) + pred_str = " ".join(str(p) for p in pred) + base_fname = os.path.basename(fnames[i]).rsplit(args.extension, 1)[0] # rstrip() strips a character set, not a suffix + fout.write(f"{base_fname}|{pred_str}\n") + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git 
a/examples/textless_nlp/gslm/speech2unit/clustering/utils.py b/examples/textless_nlp/gslm/speech2unit/clustering/utils.py new file mode 100644 index 0000000000..cf08d1fe4b --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/clustering/utils.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + + +def get_audio_files(manifest_path: str) -> Tuple[str, List[str], List[int]]: + fnames, sizes = [], [] + with open(manifest_path, "r") as f: + root_dir = f.readline().strip() + for line in f: + items = line.strip().split("\t") + assert ( + len(items) == 2 + ), f"File must have two columns separated by tab. Got {line}" + fnames.append(items[0]) + sizes.append(int(items[1])) + return root_dir, fnames, sizes diff --git a/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py b/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py new file mode 100644 index 0000000000..c613f52d3c --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py @@ -0,0 +1,192 @@ +import soundfile as sf +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CpcFeatureReader: + """ + Wrapper class to run inference on CPC model. + Helps extract features for a given audio file. 
+ """ + + def __init__( + self, + checkpoint_path, + layer, + use_encoder_layer=False, + norm_features=False, + sample_rate=16000, + max_chunk=64000, + ): + self.model = load_cpc_model(checkpoint_path, layer).eval().cuda() + self.sample_rate = sample_rate + self.max_chunk = max_chunk + self.norm_features = norm_features + self.use_encoder_layer = use_encoder_layer + + def read_audio(self, path, ref_len=None): + wav, sr = sf.read(path) + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + assert sr == self.sample_rate, sr + if ref_len is not None and abs(ref_len - len(wav)) > 160: + print(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + def get_feats(self, file_path, ref_len=None): + x = self.read_audio(file_path, ref_len) + # Inspired from CPC_audio feature_loader.py + with torch.no_grad(): + x = torch.from_numpy(x).float().cuda() + x = x.view(1, 1, -1) + size = x.size(2) + feat = [] + start = 0 + while start < size: + if start + self.max_chunk > size: + break + x_chunk = x[..., start : start + self.max_chunk] + feat_chunk = self.model.extract_features( + source=x_chunk, + get_encoded=self.use_encoder_layer, + norm_output=self.norm_features, + ) + feat.append(feat_chunk) + start += self.max_chunk + + if start < size: + # x is (1, 1, T): slice the last axis (time) to take the final chunk + x_chunk = x[..., -self.max_chunk :] + feat_chunk = self.model.extract_features( + source=x_chunk, + get_encoded=self.use_encoder_layer, + norm_output=self.norm_features, + ) + df = x_chunk.size(2) // feat_chunk.size(1) + delta = (size - start) // df + feat.append(feat_chunk[:, -delta:]) + return torch.cat(feat, 1).squeeze(0) + + +def load_cpc_model(checkpoint_path, layer=None): + state_dict = torch.load(checkpoint_path) + weights = state_dict["weights"] + config = state_dict["config"] + if layer is not None: + config["nLevelsGRU"] = layer + + encoder = CPCEncoder(config["hiddenEncoder"]) + ar_net = CPCAR( + config["hiddenEncoder"], config["hiddenGar"], False, config["nLevelsGRU"] + ) + + model = CPCModel(encoder, 
ar_net) + model.load_state_dict(weights, strict=False) + model.config = config + + return model + + +class ChannelNorm(nn.Module): + def __init__(self, num_features, epsilon=1e-05, affine=True): + super(ChannelNorm, self).__init__() + if affine: + self.weight = nn.parameter.Parameter(torch.Tensor(1, num_features, 1)) + self.bias = nn.parameter.Parameter(torch.Tensor(1, num_features, 1)) + else: + self.weight = None + self.bias = None + self.epsilon = epsilon + self.p = 0 + self.affine = affine + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x): + cum_mean = x.mean(dim=1, keepdim=True) + cum_var = x.var(dim=1, keepdim=True) + x = (x - cum_mean) * torch.rsqrt(cum_var + self.epsilon) + if self.weight is not None: + x = x * self.weight + self.bias + return x + + +class CPCEncoder(nn.Module): + def __init__(self, hidden_dim=512): + super(CPCEncoder, self).__init__() + self.conv0 = nn.Conv1d(1, hidden_dim, 10, stride=5, padding=3) + self.batchNorm0 = ChannelNorm(hidden_dim) + self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, 8, stride=4, padding=2) + self.batchNorm1 = ChannelNorm(hidden_dim) + self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1) + self.batchNorm2 = ChannelNorm(hidden_dim) + self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1) + self.batchNorm3 = ChannelNorm(hidden_dim) + self.conv4 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1) + self.batchNorm4 = ChannelNorm(hidden_dim) + self.DOWNSAMPLING = 160 + + def get_output_dim(self): + return self.conv4.out_channels + + def forward(self, x): + x = F.relu(self.batchNorm0(self.conv0(x))) + x = F.relu(self.batchNorm1(self.conv1(x))) + x = F.relu(self.batchNorm2(self.conv2(x))) + x = F.relu(self.batchNorm3(self.conv3(x))) + x = F.relu(self.batchNorm4(self.conv4(x))) + return x + + +class CPCAR(nn.Module): + def __init__(self, dim_encoded, dim_output, 
keep_hidden, num_layers): + super(CPCAR, self).__init__() + self.baseNet = nn.LSTM( + dim_encoded, dim_output, num_layers=num_layers, batch_first=True + ) + self.hidden = None + self.keep_hidden = keep_hidden + + def get_output_dim(self): + return self.baseNet.hidden_size + + def forward(self, x): + try: + self.baseNet.flatten_parameters() + except RuntimeError: + pass + x, h = self.baseNet(x, self.hidden) + if self.keep_hidden: + if isinstance(h, tuple): + self.hidden = tuple(x.detach() for x in h) + else: + self.hidden = h.detach() + return x + + +class CPCModel(nn.Module): + def __init__(self, encoder, ar_net): + super(CPCModel, self).__init__() + self.gEncoder = encoder + self.gAR = ar_net + self.config = None + + def forward(self, x, label): + encoded = self.gEncoder(x).permute(0, 2, 1) + cpc_feature = self.gAR(encoded) + return cpc_feature, encoded, label + + def extract_features(self, source, get_encoded=False, norm_output=False): + cpc_feature, encoded, _ = self.forward(source, None) + if get_encoded: + cpc_feature = encoded + if norm_output: + mean = cpc_feature.mean(dim=1, keepdim=True) + var = cpc_feature.var(dim=1, keepdim=True) + cpc_feature = (cpc_feature - mean) / torch.sqrt(var + 1e-08) + return cpc_feature diff --git a/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py b/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py new file mode 100644 index 0000000000..09442206e1 --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import fairseq +import soundfile as sf +import torch.nn.functional as F + + +class HubertFeatureReader: + """ + Wrapper class to run inference on HuBERT model. + Helps extract features for a given audio file. 
+ """ + + def __init__(self, checkpoint_path, layer, max_chunk=1600000): + ( + model, + cfg, + task, + ) = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path] + ) + self.model = model[0].eval().cuda() + self.task = task + self.layer = layer + self.max_chunk = max_chunk + + def read_audio(self, path, ref_len=None): + wav, sr = sf.read(path) + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + assert sr == self.task.cfg.sample_rate, sr + if ref_len is not None and abs(ref_len - len(wav)) > 160: + print(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + def get_feats(self, file_path, ref_len=None): + x = self.read_audio(file_path, ref_len) + with torch.no_grad(): + x = torch.from_numpy(x).float().cuda() + if self.task.cfg.normalize: + x = F.layer_norm(x, x.shape) + x = x.view(1, -1) + + feat = [] + for start in range(0, x.size(1), self.max_chunk): + x_chunk = x[:, start: start + self.max_chunk] + feat_chunk, _ = self.model.extract_features( + source=x_chunk, + padding_mask=None, + mask=False, + output_layer=self.layer, + ) + feat.append(feat_chunk) + return torch.cat(feat, 1).squeeze(0) diff --git a/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py b/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py new file mode 100644 index 0000000000..106f502476 --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import soundfile as sf +import torch +import torchaudio.compliance.kaldi as kaldi + + +class LogMelFeatureReader: + """ + Wrapper class to compute log-Mel filterbank features. + Helps extract features for a given audio file. 
+ """ + + def __init__(self, *args, **kwargs): + self.num_mel_bins = kwargs.get("num_mel_bins", 80) + self.frame_length = kwargs.get("frame_length", 25.0) + + def get_feats(self, file_path): + wav, sr = sf.read(file_path) + feats = torch.from_numpy(wav).float() + feats = kaldi.fbank( + feats.unsqueeze(0), + num_mel_bins=self.num_mel_bins, + frame_length=self.frame_length, + sample_frequency=sr, + ) + return feats diff --git a/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py b/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py new file mode 100644 index 0000000000..5aaddf6421 --- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import gc +import os +import random +import shutil +import numpy as np + +import torch +import tqdm +from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import ( + CpcFeatureReader, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import ( + HubertFeatureReader, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import ( + LogMelFeatureReader, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import ( + Wav2VecFeatureReader, +) + + +def get_feature_reader(feature_type): + if feature_type == "logmel": + return LogMelFeatureReader + elif feature_type == "hubert": + return HubertFeatureReader + elif feature_type == "w2v2": + return Wav2VecFeatureReader + elif feature_type == "cpc": + return CpcFeatureReader + else: + raise NotImplementedError(f"{feature_type} is not supported.") + + +def get_feature_iterator( + feature_type, checkpoint_path, layer, manifest_path, sample_pct +): + feature_reader_cls = get_feature_reader(feature_type) + with open(manifest_path, "r") as fp: + lines = 
fp.read().split("\n") + root = lines.pop(0).strip() + file_path_list = [ + os.path.join(root, line.split("\t")[0]) + for line in lines + if len(line) > 0 + ] + if sample_pct < 1.0: + file_path_list = random.sample( + file_path_list, int(sample_pct * len(file_path_list)) + ) + num_files = len(file_path_list) + reader = feature_reader_cls( + checkpoint_path=checkpoint_path, layer=layer + ) + + def iterate(): + for file_path in file_path_list: + feats = reader.get_feats(file_path) + yield feats.cpu().numpy() + + return iterate, num_files + + +def get_features( + feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten +): + generator, num_files = get_feature_iterator( + feature_type=feature_type, + checkpoint_path=checkpoint_path, + layer=layer, + manifest_path=manifest_path, + sample_pct=sample_pct, + ) + iterator = generator() + + features_list = [] + for features in tqdm.tqdm(iterator, total=num_files): + features_list.append(features) + + # Explicit clean up + del iterator + del generator + gc.collect() + torch.cuda.empty_cache() + + if flatten: + return np.concatenate(features_list) + + return features_list + + +def get_and_dump_features( + feature_type, + checkpoint_path, + layer, + manifest_path, + sample_pct, + flatten, + out_features_path, +): + # Feature extraction + features_batch = get_features( + feature_type=feature_type, + checkpoint_path=checkpoint_path, + layer=layer, + manifest_path=manifest_path, + sample_pct=sample_pct, + flatten=flatten, + ) + + # Save features + out_dir_path = os.path.dirname(out_features_path) + os.makedirs(out_dir_path, exist_ok=True) + shutil.copyfile( + manifest_path, + os.path.join(out_dir_path, os.path.basename(manifest_path)), + ) + np.save(out_features_path, features_batch) + + return features_batch diff --git a/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py b/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py new file mode 100644 index 0000000000..b878321e44 
--- /dev/null +++ b/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import fairseq +import soundfile as sf + + +class Wav2VecFeatureReader: + """ + Wrapper class to run inference on Wav2Vec 2.0 model. + Helps extract features for a given audio file. + """ + + def __init__(self, checkpoint_path, layer): + state = fairseq.checkpoint_utils.load_checkpoint_to_cpu( + checkpoint_path + ) + + w2v_args = state["args"] + self.task = fairseq.tasks.setup_task(w2v_args) + model = self.task.build_model(w2v_args) + model.load_state_dict(state["model"], strict=True) + model.eval() + model.cuda() + self.model = model + self.layer = layer + + def read_audio(self, fname): + wav, sr = sf.read(fname) + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + assert sr == self.task.cfg.sample_rate, sr + return wav + + def get_feats(self, file_path): + x = self.read_audio(file_path) + with torch.no_grad(): + source = torch.from_numpy(x).view(1, -1).float().cuda() + res = self.model( + source=source, mask=False, features_only=True, layer=self.layer + ) + return res["layer_results"][self.layer][0].squeeze(1) diff --git a/examples/textless_nlp/gslm/tools/README.md b/examples/textless_nlp/gslm/tools/README.md new file mode 100644 index 0000000000..61fcbbded8 --- /dev/null +++ b/examples/textless_nlp/gslm/tools/README.md @@ -0,0 +1,22 @@ +# GSLM Tools + +## Resynthesis +You can use the command line tool below to input an audio file and get the resynthesized audio. This tool implements the unsupervised method for resynthesis described in the paper. The way to invoke the command line tool is shown below. 
+``` +FAIRSEQ_ROOT=<path_to_your_fairseq_repo_root> +TYPE=<one_of_logmel/cpc/hubert/w2v2> +ACOUSTIC_MODEL_PATH=<path_of_pretrained_acoustic_model> +LAYER=<layer_of_acoustic_model_to_extract_features_from> +KM_MODEL_PATH=<path_of_the_kmeans_model> +TTS_MODEL_PATH=<unit2speech_model_file_path> +WAVEGLOW_PATH=<path_where_you_have_downloaded_waveglow_checkpoint> + +PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/tools/resynthesize_speech.py \ + --feature_type $TYPE \ + --acoustic_model_path $ACOUSTIC_MODEL_PATH \ + --layer $LAYER \ + --kmeans_model_path $KM_MODEL_PATH \ + --tts_model_path $TTS_MODEL_PATH \ + --waveglow_path $WAVEGLOW_PATH \ + --max_decoder_steps 2000 +``` \ No newline at end of file diff --git a/examples/textless_nlp/gslm/tools/resynthesize_speech.py b/examples/textless_nlp/gslm/tools/resynthesize_speech.py new file mode 100644 index 0000000000..2b6215d372 --- /dev/null +++ b/examples/textless_nlp/gslm/tools/resynthesize_speech.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import gc +import logging + +import joblib +import soundfile as sf +import torch +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import ( + get_feature_reader, +) +from examples.textless_nlp.gslm.unit2speech.tts_data import ( + TacotronInputDataset, +) +from examples.textless_nlp.gslm.unit2speech.utils import ( + load_tacotron, + load_waveglow, + synthesize_audio, +) + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def get_parser(): + parser = argparse.ArgumentParser( + description="GSLM speech resynthesis tool." 
+ ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + required=True, + help="Acoustic feature type", + ) + parser.add_argument( + "--acoustic_model_path", + type=str, + help="Pretrained acoustic model checkpoint", + ) + parser.add_argument( + "--layer", type=int, help="Layer of acoustic model" + ) + parser.add_argument( + "--kmeans_model_path", + type=str, + required=True, + help="K-means model file path to use for inference", + ) + parser.add_argument( + "--tts_model_path", + type=str, + help="TTS model file path to use for inference", + ) + parser.add_argument( + "--waveglow_path", + type=str, + help="Waveglow (vocoder) model file path to use for inference", + ) + parser.add_argument("--max_decoder_steps", type=int, default=2000) + parser.add_argument("--denoiser_strength", type=float, default=0.1) + return parser + + +################################################ +def main(args, logger): + # Acoustic Model + logger.info(f"Loading acoustic model from {args.acoustic_model_path}...") + feature_reader_cls = get_feature_reader(args.feature_type) + reader = feature_reader_cls( + checkpoint_path=args.acoustic_model_path, layer=args.layer + ) + + # K-means Model + logger.info(f"Loading K-means model from {args.kmeans_model_path} ...") + kmeans_model = joblib.load(open(args.kmeans_model_path, "rb")) + kmeans_model.verbose = False + + # TTS Model + logger.info(f"Loading TTS model from {args.tts_model_path}...") + tacotron_model, sample_rate, hparams = load_tacotron( + tacotron_model_path=args.tts_model_path, + max_decoder_steps=args.max_decoder_steps, + ) + + # Waveglow Model + logger.info(f"Loading Waveglow model from {args.waveglow_path}...") + waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path) + + # Dataset + tts_dataset = TacotronInputDataset(hparams) + + iters = 0 + while True: + in_file_path = input( + "Input: Enter the full file path of audio file...\n" + ) + out_file_path =
input( + "Output: Enter the full file path of audio file...\n" + ) + feats = reader.get_feats(in_file_path).cpu().numpy() + iters += 1 + if iters == 1000: + gc.collect() + torch.cuda.empty_cache() + + quantized_units = kmeans_model.predict(feats) + quantized_units_str = " ".join(map(str, quantized_units)) + + tts_input = tts_dataset.get_tensor(quantized_units_str) + mel, aud, aud_dn, has_eos = synthesize_audio( + tacotron_model, + waveglow, + denoiser, + tts_input.unsqueeze(0), + strength=args.denoiser_strength, + ) + sf.write( + f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate + ) + logger.info("Resynthesis done!\n") + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/examples/textless_nlp/gslm/ulm/README.md b/examples/textless_nlp/gslm/ulm/README.md new file mode 100644 index 0000000000..01459121ce --- /dev/null +++ b/examples/textless_nlp/gslm/ulm/README.md @@ -0,0 +1,72 @@ +# Unit Language Model (ULM) + +Here you can find links to the pre-trained ULMs and instructions on training new models using fairseq. At the end of the page, we also share how to run sampling for those models and provide pointers to the transcribed prompts we used. 
+ +## Pre-trained models + +Using the links below, you can download pre-trained models for various unit types and vocabulary sizes: + +| | 50 | 100 | 200 +|-|-|-|- +| LogMel Filterbank | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km50/logmel50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km100/logmel100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km200/logmel200_lm.tgz) +| Modified CPC | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km50/cpc50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km100/cpc100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km200/cpc200_lm.tgz) +| HuBERT | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km50/hubert50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km100/hubert100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km200/hubert200_lm.tgz) +| Wav2Vec 2.0 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km50/w2v2_50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km100/w2v2_100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km200/w2v2_200_lm.tgz) + + +## Preprocessing data +Assuming that unit-transcribed train, valid, and test sets are located in `data/train.txt`, `data/valid.txt`, and `data/test.txt`, respectively, +we run the following command to get a preprocessed version of the dataset in `data-bin`: + +```bash +fairseq-preprocess --only-source \ + --trainpref data/train.txt --validpref data/valid.txt --testpref data/test.txt \ + --destdir data-bin/ --workers 40 +``` +As a result, the `data-bin` directory should appear. + +## Fitting a Unit Language Model (ULM) +As a ULM, we train a standard fairseq Transformer LM.
Assuming 8 GPUs are used for training, a good starting point for training a ULM would be: +```bash + fairseq-train data-bin/ \ + --task=language_modeling \ + --arch=transformer_lm_big \ + --share-decoder-input-output-embed \ + --dropout=0.1 \ + --attention-dropout=0.1 \ + --optimizer=adam \ + --adam-betas='(0.9, 0.98)' \ + --clip-norm=1.0 \ + --lr=0.0005 \ + --lr-scheduler=inverse_sqrt \ + --warmup-updates=4000 \ + --warmup-init-lr=1e-07 \ + --tokens-per-sample=3072 \ + --update-freq=16 \ + --max-tokens=4096 \ + --num-workers=4 \ + --skip-invalid-size-inputs-valid-test \ + --max-update=500000 \ + --log-interval=10 \ + --seed=100501 \ + --fp16 \ + --sample-break-mode=eos +``` +This command will train a Transformer-large model (12 layers). You can train other standard LM models provided by fairseq, e.g. specify `--arch=transformer_lm` to train a smaller (6-layer) Transformer model. When training with a different number of GPUs, it might be a good idea to adjust the `--update-freq` parameter, for example so that the number of GPUs times `--update-freq` stays roughly constant. To save GPU memory at the expense of additional computation, it can be useful to enable activation checkpointing with `--checkpoint-activations`. + +## Sampling from a ULM +Once a ULM has been trained, we can use it to generate new utterances. Suppose that the prompts are given in a file named `prompts.txt`. Then we can sample continuations by running the following command: + +```bash + python sample.py data-bin/ \ + --path=checkpoints/checkpoint_best.pt --task=language_modeling --sampling --temperature=0.7 \ + --seed=1 --prompts=prompts.txt --output=samples.txt --max-len-a=0 --max-len-b=500 \ + --prefix-size=-1 --batch-size=16 --fp16 --samples-per-prompt=10 +``` +Here, `--prefix-size` controls the number of tokens that are used to prime the ULM. When set to a positive value, the sampling script will take the first `prefix-size` tokens to prompt the ULM; with `0` it runs unconditional sampling and with `-1` the entire prompt is used.
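
The three `--prefix-size` cases can be sketched in a few lines of Python. This is a minimal illustration that mirrors the prompt handling in `sample.py` (splitting each line on the first `|` and truncating when the prefix size is non-negative); the helper names `parse_prompt_line` and `select_prompt_prefix` and the example line are hypothetical, and unlike the script the sketch strips the trailing newline explicitly.

```python
def parse_prompt_line(line):
    """Split a `<sequence_id>|<units>` line on the first `|`.

    Hypothetical helper: sample.py performs the same split inline with
    `x.split('|', 1)`; stripping the newline here is an assumption made
    to keep the example output tidy.
    """
    seq_id, units = line.rstrip("\n").split("|", 1)
    return seq_id, units


def select_prompt_prefix(units, prefix_size):
    """Return the part of a unit-transcribed prompt used to prime the ULM."""
    if prefix_size >= 0:
        # Keep only the first `prefix_size` units; 0 leaves an empty
        # prompt, which corresponds to unconditional sampling.
        return " ".join(units.split()[:prefix_size])
    # A negative value such as -1 keeps the entire prompt.
    return units


if __name__ == "__main__":
    # Made-up prompt line, for illustration only.
    seq_id, units = parse_prompt_line("utt_0|12 7 7 43 5\n")
    for size in (3, 0, -1):
        print(seq_id, size, repr(select_prompt_prefix(units, size)))
```

Keeping the truncation separate from the parsing makes it easy to check how a given `--prefix-size` will shorten a prompt before launching a sampling run.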
+`--samples-per-prompt` specifies how many utterances are generated for every prompt, which can be useful when generating multiple prompt continuations. In this command, `--max-len-a` and `--max-len-b` control the number of generated tokens. + +When using a pretrained model from above, `data-bin` should point to the unpacked directory (with the `dict.txt` file). + +At evaluation time, to generate prompts, we used utterances from LibriSpeech dev-clean and test-clean that are longer than 6s. We took the first 3s of each utterance as a prompt. Unit transcripts of those prompts can be downloaded here: [[dev]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/dev_prompts.tgz) [[test]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/test_prompts.tgz) + diff --git a/examples/textless_nlp/gslm/ulm/sample.py b/examples/textless_nlp/gslm/ulm/sample.py new file mode 100644 index 0000000000..77302a6894 --- /dev/null +++ b/examples/textless_nlp/gslm/ulm/sample.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+""" +Sample from a trained LM; hacked fairseq-interactive +""" +from collections import namedtuple +import os +import ast +import numpy as np + +from fairseq import checkpoint_utils, options, tasks, utils + +import tqdm + +Batch = namedtuple('Batch', 'ids src_tokens src_lengths') +Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') + + +def make_batches(lines, args, task, max_positions): + tokens = [ + task.source_dictionary.encode_line( + src_str, add_if_not_exist=False + ).long() + for src_str in lines + ] + lengths = [t.numel() for t in tokens] + itr = task.get_batch_iterator( + dataset=task.build_dataset_for_inference(tokens, lengths), + max_tokens=args.dataset.max_tokens, + max_sentences=args.dataset.batch_size, + max_positions=max_positions, + ignore_invalid_inputs=args.dataset.skip_invalid_size_inputs_valid_test + ).next_epoch_itr(shuffle=False) + for batch in itr: + yield Batch( + ids=batch['id'], + src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'], + ) + + +def main(args): + arg_prompts = args.prompts + arg_output = args.output + arg_debug = args.debug + arg_sample_size = args.samples_per_prompt + + try: + from fairseq.dataclass.utils import convert_namespace_to_omegaconf + args = convert_namespace_to_omegaconf(args) + except: + pass + + # if args.max_tokens is None and args.max_sentences is None: + if args.common.seed is not None: + np.random.seed(args.common.seed) + utils.set_torch_seed(args.common.seed) + + if args.generation.sampling: + args.generation.nbest = args.generation.beam = arg_sample_size + + task = tasks.setup_task(args.task) + + overrides = ast.literal_eval(args.common_eval.model_overrides) + + models, _model_args = checkpoint_utils.load_model_ensemble( + args.common_eval.path.split(os.pathsep), + arg_overrides=overrides, + task=task, + suffix=getattr(args, "checkpoint_suffix", ""), + ) + + # Set dictionaries + src_dict = task.source_dictionary + tgt_dict = 
task.target_dictionary + + # Optimize ensemble for generation + for model in models: + model.prepare_for_inference_(args) + model.cuda() + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(args.generation.replace_unk) + + max_positions = utils.resolve_max_positions( + task.max_positions(), + *[model.max_positions() for model in models] + ) + + output_file = open(arg_output, 'w') + + with open(arg_prompts, 'r') as fin: + lines = fin.readlines() + + split = [x.split('|', 1) for x in lines] + seq_id = [x[0] for x in split] + prompts = [x[1] for x in split] + + if args.generation.prefix_size >= 0: + prompts = [' '.join(l.split()[:args.generation.prefix_size]) + for l in prompts] + + if arg_debug: + prompts = prompts[:10] + + generator = task.build_generator(models, args.generation) + + start_id = 0 + pbar = tqdm.tqdm(total=len(prompts)) + for batch in make_batches(prompts, args, task, max_positions): + src_tokens = batch.src_tokens + src_lengths = batch.src_lengths + src_tokens = src_tokens.cuda() + src_lengths = src_lengths.cuda() + + sample = { + 'net_input': { + 'src_tokens': src_tokens, + 'src_lengths': src_lengths, + }, + } + + results = [] + translations = task.inference_step(generator, models, sample) + for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): + src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) + results.append((i + start_id, src_tokens_i, hypos)) + + # sort output to match input order + for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]): + if src_dict is not None: + src_str = src_dict.string( + src_tokens, args.common_eval.post_process) + + # Process top predictions + for hypo_id, hypo in enumerate(hypos): + _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction( + hypo_tokens=hypo['tokens'].int().cpu(), + src_str=src_str, + alignment=hypo['alignment'], + 
align_dict=align_dict, + tgt_dict=tgt_dict, + remove_bpe=args.common_eval.post_process, + ) + + detok_hypo_str = hypo_str + utterance = detok_hypo_str + print(f'{seq_id[id]}__{hypo_id}|{utterance}', file=output_file) + pbar.update(1) + start_id += len(results) + + # output_file.close() + + +def cli_main(): + parser = options.get_interactive_generation_parser() + parser.add_argument('--prompts', type=str, default=None, required=True) + parser.add_argument('--output', type=str, default=None, required=True) + parser.add_argument('--debug', action='store_true') + parser.add_argument('--samples-per-prompt', type=int, default=1) + + args = options.parse_args_and_arch(parser) + + np.random.seed(args.seed) + utils.set_torch_seed(args.seed) + + main(args) + + +if __name__ == '__main__': + cli_main() diff --git a/examples/textless_nlp/gslm/unit2speech/README.md b/examples/textless_nlp/gslm/unit2speech/README.md new file mode 100644 index 0000000000..5710423065 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/README.md @@ -0,0 +1,42 @@ +# Unit to Speech Model (unit2speech) + +The unit to speech model is a modified Tacotron2 model that learns to synthesize speech from discrete speech units. All models are trained on quantized [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
+ +Upstream Units | Download Link +|-|- +Log Mel Filterbank + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/tts_checkpoint_best.pt) +Log Mel Filterbank + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/tts_checkpoint_best.pt) +Log Mel Filterbank + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/tts_checkpoint_best.pt) +Log Mel Filterbank + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km500/tts_checkpoint_best.pt) +Modified CPC + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/tts_checkpoint_best.pt) +Modified CPC + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/tts_checkpoint_best.pt) +Modified CPC + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/tts_checkpoint_best.pt) +Modified CPC + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km500/tts_checkpoint_best.pt) +HuBERT Base + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/tts_checkpoint_best.pt) +HuBERT Base + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/tts_checkpoint_best.pt) +HuBERT Base + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/tts_checkpoint_best.pt) +HuBERT Base + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km500/tts_checkpoint_best.pt) +wav2vec 2.0 Large + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/tts_checkpoint_best.pt) +wav2vec 2.0 Large + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/tts_checkpoint_best.pt) +wav2vec 2.0 Large + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/tts_checkpoint_best.pt) +wav2vec 2.0 Large + KM500 | 
[download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km500/tts_checkpoint_best.pt) + +## Run inference using a unit2speech model +* Install librosa, unidecode and inflect using `pip install librosa unidecode inflect` +* Download the [Waveglow checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/gslm/waveglow_256channels_new.pt). This is the vocoder. + +Sample command to run inference using trained unit2speech models. Please note that the quantized audio to be synthesized should use the same units that the unit2speech model was trained with. +``` +FAIRSEQ_ROOT=<path_to_your_fairseq_repo_root> +TTS_MODEL_PATH=<unit2speech_model_file_path> +QUANTIZED_UNIT_PATH=<quantized_audio_file_path> +OUT_DIR=<dir_to_dump_synthesized_audio_files> +WAVEGLOW_PATH=<path_where_you_have_downloaded_waveglow_checkpoint> + +PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py \ + --tts_model_path $TTS_MODEL_PATH \ + --quantized_unit_path $QUANTIZED_UNIT_PATH \ + --out_audio_dir $OUT_DIR \ + --waveglow_path $WAVEGLOW_PATH \ + --max_decoder_steps 2000 +``` \ No newline at end of file diff --git a/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py b/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py new file mode 100644 index 0000000000..2be848fcea --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py @@ -0,0 +1,56 @@ +import os +import shlex +import subprocess +import progressbar +from time import time +from pathlib import Path + +def find_all_files(path_dir, extension): + out = [] + for root, dirs, filenames in os.walk(path_dir): + for f in filenames: + if f.endswith(extension): + out.append(((str(Path(f).stem)), os.path.join(root, f))) + return out + +def convert16k(inputfile, outputfile16k): + command = ('sox -c 1 -b 16 {} -t wav {} rate 16k'.format(inputfile, outputfile16k)) + subprocess.call(shlex.split(command)) +
+if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Convert to wav 16k audio using sox.') + parser.add_argument('input_dir', type=str, + help='Path to the input dir.') + parser.add_argument('output_dir', type=str, + help='Path to the output dir.') + parser.add_argument('--extension', type=str, default='wav', + help='Audio file extension in the input. Default: wav') + args = parser.parse_args() + + # Find all sequences + print(f"Finding all audio files with extension '{args.extension}' from {args.input_dir}...") + audio_files = find_all_files(args.input_dir, args.extension) + print(f"Done! Found {len(audio_files)} files.") + + # Convert to relative path + audio_files = [os.path.relpath(file[-1], start=args.input_dir) for file in audio_files] + + # Create all the directories needed + rel_dirs_set = set([os.path.dirname(file) for file in audio_files]) + for rel_dir in rel_dirs_set: + Path(os.path.join(args.output_dir, rel_dir)).mkdir(parents=True, exist_ok=True) + + # Convert the audio files to 16 kHz wav + print("Converting the audio to wav files...") + bar = progressbar.ProgressBar(maxval=len(audio_files)) + bar.start() + start_time = time() + for index, file in enumerate(audio_files): + bar.update(index) + input_file = os.path.join(args.input_dir, file) + output_file = os.path.join(args.output_dir, os.path.splitext(file)[0]+".wav") + convert16k(input_file, output_file) + bar.finish() + print(f"...done {len(audio_files)} files in {time()-start_time} seconds.") \ No newline at end of file diff --git a/examples/textless_nlp/gslm/unit2speech/glow.py b/examples/textless_nlp/gslm/unit2speech/glow.py new file mode 100644 index 0000000000..7a7696403d --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/glow.py @@ -0,0 +1,311 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# ***************************************************************************** +import copy +import torch +from torch.autograd import Variable +import torch.nn.functional as F + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a+input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class WaveGlowLoss(torch.nn.Module): + def __init__(self, sigma=1.0): + super(WaveGlowLoss, self).__init__() + self.sigma = sigma + + def forward(self, model_output): + z, log_s_list, log_det_W_list = model_output + for i, log_s in enumerate(log_s_list): + if i == 0: + log_s_total = torch.sum(log_s) + log_det_W_total = log_det_W_list[i] + else: + log_s_total = log_s_total + torch.sum(log_s) + log_det_W_total += log_det_W_list[i] + + loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total + return loss/(z.size(0)*z.size(1)*z.size(2)) + + +class Invertible1x1Conv(torch.nn.Module): + """ + The layer outputs both the convolution, and the log determinant + of its weight matrix. 
If reverse=True it does convolution with + inverse + """ + def __init__(self, c): + super(Invertible1x1Conv, self).__init__() + self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, + bias=False) + + # Sample a random orthonormal matrix to initialize weights + W = torch.qr(torch.FloatTensor(c, c).normal_())[0] + + # Ensure determinant is 1.0 not -1.0 + if torch.det(W) < 0: + W[:,0] = -1*W[:,0] + W = W.view(c, c, 1) + self.conv.weight.data = W + + def forward(self, z, reverse=False): + # shape + batch_size, group_size, n_of_groups = z.size() + + W = self.conv.weight.squeeze() + + if reverse: + if not hasattr(self, 'W_inverse'): + # Reverse computation + W_inverse = W.float().inverse() + W_inverse = Variable(W_inverse[..., None]) + if z.type() == 'torch.cuda.HalfTensor': + W_inverse = W_inverse.half() + self.W_inverse = W_inverse + z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) + return z + else: + # Forward computation + log_det_W = batch_size * n_of_groups * torch.logdet(W) + z = self.conv(z) + return z, log_det_W + + +class WN(torch.nn.Module): + """ + This is the WaveNet like layer for the affine coupling. The primary difference + from WaveNet is the convolutions need not be causal. There is also no dilation + size reset. The dilation only doubles on each layer + """ + def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, + kernel_size): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + assert(n_channels % 2 == 0) + self.n_layers = n_layers + self.n_channels = n_channels + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + + start = torch.nn.Conv1d(n_in_channels, n_channels, 1) + start = torch.nn.utils.weight_norm(start, name='weight') + self.start = start + + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. 
This helps with training stability + end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + + cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = 2 ** i + padding = int((kernel_size*dilation - dilation)/2) + in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2*n_channels + else: + res_skip_channels = n_channels + res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, forward_input): + audio, spect = forward_input + audio = self.start(audio) + output = torch.zeros_like(audio) + n_channels_tensor = torch.IntTensor([self.n_channels]) + + spect = self.cond_layer(spect) + + for i in range(self.n_layers): + spect_offset = i*2*self.n_channels + acts = fused_add_tanh_sigmoid_multiply( + self.in_layers[i](audio), + spect[:,spect_offset:spect_offset+2*self.n_channels,:], + n_channels_tensor) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + audio = audio + res_skip_acts[:,:self.n_channels,:] + output = output + res_skip_acts[:,self.n_channels:,:] + else: + output = output + res_skip_acts + + return self.end(output) + + +class WaveGlow(torch.nn.Module): + def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, + n_early_size, WN_config): + super(WaveGlow, self).__init__() + + self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, + n_mel_channels, + 1024, stride=256) + assert(n_group % 2 == 0) + self.n_flows = n_flows + self.n_group = n_group + 
self.n_early_every = n_early_every + self.n_early_size = n_early_size + self.WN = torch.nn.ModuleList() + self.convinv = torch.nn.ModuleList() + + n_half = int(n_group/2) + + # Set up layers with the right sizes based on how many dimensions + # have been output already + n_remaining_channels = n_group + for k in range(n_flows): + if k % self.n_early_every == 0 and k > 0: + n_half = n_half - int(self.n_early_size/2) + n_remaining_channels = n_remaining_channels - self.n_early_size + self.convinv.append(Invertible1x1Conv(n_remaining_channels)) + self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) + self.n_remaining_channels = n_remaining_channels # Useful during inference + + def forward(self, forward_input): + """ + forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames + forward_input[1] = audio: batch x time + """ + spect, audio = forward_input + + # Upsample spectrogram to size of audio + spect = self.upsample(spect) + assert(spect.size(2) >= audio.size(1)) + if spect.size(2) > audio.size(1): + spect = spect[:, :, :audio.size(1)] + + spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) + + audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1) + output_audio = [] + log_s_list = [] + log_det_W_list = [] + + for k in range(self.n_flows): + if k % self.n_early_every == 0 and k > 0: + output_audio.append(audio[:,:self.n_early_size,:]) + audio = audio[:,self.n_early_size:,:] + + audio, log_det_W = self.convinv[k](audio) + log_det_W_list.append(log_det_W) + + n_half = int(audio.size(1)/2) + audio_0 = audio[:,:n_half,:] + audio_1 = audio[:,n_half:,:] + + output = self.WN[k]((audio_0, spect)) + log_s = output[:, n_half:, :] + b = output[:, :n_half, :] + audio_1 = torch.exp(log_s)*audio_1 + b + log_s_list.append(log_s) + + audio = torch.cat([audio_0, audio_1],1) + + output_audio.append(audio) + return torch.cat(output_audio,1), 
log_s_list, log_det_W_list + + def infer(self, spect, sigma=1.0): + spect = self.upsample(spect) + # trim conv artifacts. maybe pad spec to kernel multiple + time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] + spect = spect[:, :, :-time_cutoff] + + spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) + + if spect.type() == 'torch.cuda.HalfTensor': + audio = torch.cuda.HalfTensor(spect.size(0), + self.n_remaining_channels, + spect.size(2)).normal_() + else: + audio = torch.cuda.FloatTensor(spect.size(0), + self.n_remaining_channels, + spect.size(2)).normal_() + + audio = torch.autograd.Variable(sigma*audio) + + for k in reversed(range(self.n_flows)): + n_half = int(audio.size(1)/2) + audio_0 = audio[:,:n_half,:] + audio_1 = audio[:,n_half:,:] + + output = self.WN[k]((audio_0, spect)) + + s = output[:, n_half:, :] + b = output[:, :n_half, :] + audio_1 = (audio_1 - b)/torch.exp(s) + audio = torch.cat([audio_0, audio_1],1) + + audio = self.convinv[k](audio, reverse=True) + + if k % self.n_early_every == 0 and k > 0: + if spect.type() == 'torch.cuda.HalfTensor': + z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() + else: + z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() + audio = torch.cat((sigma*z, audio),1) + + audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data + return audio + + @staticmethod + def remove_weightnorm(model): + waveglow = model + for WN in waveglow.WN: + WN.start = torch.nn.utils.remove_weight_norm(WN.start) + WN.in_layers = remove(WN.in_layers) + WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer) + WN.res_skip_layers = remove(WN.res_skip_layers) + return waveglow + + +def remove(conv_list): + new_conv_list = torch.nn.ModuleList() + for old_conv in conv_list: + old_conv = torch.nn.utils.remove_weight_norm(old_conv) + 
new_conv_list.append(old_conv) + return new_conv_list diff --git a/examples/textless_nlp/gslm/unit2speech/multiproc.py b/examples/textless_nlp/gslm/unit2speech/multiproc.py new file mode 100644 index 0000000000..2a287a4e97 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/multiproc.py @@ -0,0 +1,27 @@ +import os +import time +import torch +import sys +import subprocess + +argslist = list(sys.argv)[1:] +log_dir = argslist[-1] +num_gpus = torch.cuda.device_count() +argslist.append('--n_gpus={}'.format(num_gpus)) +workers = [] +job_id = time.strftime("%Y_%m_%d-%H%M%S") +argslist.append("--group_name=group_{}".format(job_id)) + +print("GPU log directory is {}".format(log_dir)) +os.makedirs(log_dir, exist_ok=True) +for i in range(num_gpus): + argslist.append('--rank={}'.format(i)) + stdout = None if i == 0 else open("{}/{}_GPU_{}.log".format(log_dir, job_id, i), + "w") + print(argslist) + p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) + workers.append(p) + argslist = argslist[:-1] + +for p in workers: + p.wait() diff --git a/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py b/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py new file mode 100644 index 0000000000..f226d5f505 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +import os + +import soundfile as sf +from examples.textless_nlp.gslm.unit2speech.tts_data import ( + TacotronInputDataset, +) +from examples.textless_nlp.gslm.unit2speech.utils import ( + load_quantized_audio_from_file, + load_tacotron, + load_waveglow, + synthesize_audio, +) + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Wav2Vec 2.0 speech generator." + ) + parser.add_argument( + "--quantized_unit_path", + type=str, + help="Path to the file with quantized units to synthesize", + ) + parser.add_argument( + "--tts_model_path", + type=str, + help="TTS model file path to use for inference", + ) + parser.add_argument( + "--waveglow_path", + type=str, + help="Path to the waveglow checkpoint (vocoder).", + ) + parser.add_argument("--max_decoder_steps", type=int, default=2000) + parser.add_argument("--denoiser_strength", type=float, default=0.1) + parser.add_argument( + "--out_audio_dir", + type=str, + help="Output directory to dump audio files", + ) + + return parser + + +def main(args, logger): + # Load quantized audio + logger.info(f"Loading quantized audio from {args.quantized_unit_path}...") + names_batch, quantized_units_batch = load_quantized_audio_from_file( + file_path=args.quantized_unit_path + ) + + logger.info(f"Loading TTS model from {args.tts_model_path}...") + tacotron_model, sample_rate, hparams = load_tacotron( + tacotron_model_path=args.tts_model_path, + max_decoder_steps=args.max_decoder_steps, + ) + + logger.info(f"Loading Waveglow model from {args.waveglow_path}...") + waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path) + + tts_dataset = TacotronInputDataset(hparams) + for name, quantized_units in zip(names_batch, quantized_units_batch): + quantized_units_str = " ".join(map(str,
quantized_units)) + tts_input = tts_dataset.get_tensor(quantized_units_str) + mel, aud, aud_dn, has_eos = synthesize_audio( + tacotron_model, + waveglow, + denoiser, + tts_input.unsqueeze(0), + strength=args.denoiser_strength, + ) + out_file_path = os.path.join(args.out_audio_dir, f"{name}.wav") + sf.write( + f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate + ) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py new file mode 100644 index 0000000000..b5af7f723e --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py @@ -0,0 +1,93 @@ +import torch +import numpy as np +from scipy.signal import get_window +import librosa.util as librosa_util + + +def window_sumsquare(window, n_frames, hop_length=200, win_length=800, + n_fft=800, dtype=np.float32, norm=None): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. 
+ + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm)**2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py new file mode 100644 index 0000000000..e2e35c1a8c --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py @@ -0,0 +1,90 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Cleaners are transformations that run over the input text at both training 
and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +''' + +import re +from unidecode import unidecode +from .numbers import normalize_numbers + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = 
lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including number and abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py new file mode 100644 index 0000000000..62bfef745c --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py @@ -0,0 +1,65 @@ +""" from https://github.com/keithito/tacotron """ + +import re + + +valid_symbols = [ + 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', + 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', + 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', + 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', + 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', + 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', + 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' +] + +_valid_symbol_set = set(valid_symbols) + + +class CMUDict: + '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding='latin-1') as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + + def __len__(self): + return len(self._entries) + + + def lookup(self, word): + '''Returns list of ARPAbet pronunciations of the given word.''' + return self._entries.get(word.upper()) + + + +_alt_re = re.compile(r'\([0-9]+\)') + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): + parts = line.split(' ') + word = re.sub(_alt_re, '', parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(' ') + for part in parts: + if part not in _valid_symbol_set: + return None + return ' '.join(parts) diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py new file mode 100644 index 0000000000..f10d557ff5 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py @@ -0,0 +1,103 @@ +import torch +from librosa.filters import mel as librosa_mel_fn +from .audio_processing import dynamic_range_compression +from .audio_processing import dynamic_range_decompression +from .stft import STFT +from .utils import get_mask_from_lengths + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + 
gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + + +class GlobalAvgPool(torch.nn.Module): + def __init__(self): + super(GlobalAvgPool, self).__init__() + + def forward(self, x, lengths=None): + """Average pooling across time steps (dim=1) with optional lengths. + Args: + x: torch.Tensor of shape (N, T, ...)
+ lengths: None or torch.Tensor of shape (N,) + dim: dimension to pool + """ + if lengths is None: + return x.mean(dim=1, keepdim=False) + else: + mask = get_mask_from_lengths(lengths).type(x.type()).to(x.device) + mask_shape = list(mask.size()) + [1 for _ in range(x.ndimension()-2)] + mask = mask.reshape(*mask_shape) + numer = (x * mask).sum(dim=1, keepdim=False) + denom = mask.sum(dim=1, keepdim=False) + return numer / denom + + +class TacotronSTFT(torch.nn.Module): + def __init__(self, filter_length=1024, hop_length=256, win_length=1024, + n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, + mel_fmax=8000.0): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer('mel_basis', mel_basis) + + def spectral_normalize(self, magnitudes): + output = dynamic_range_compression(magnitudes) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert(torch.min(y.data) >= -1) + assert(torch.max(y.data) <= 1) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output) + return mel_output diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py new file mode 100644 index 0000000000..ccf132b150 --- /dev/null +++ 
b/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py @@ -0,0 +1,669 @@ +from math import sqrt +import torch +import torch.distributions as distr +from torch.autograd import Variable +from torch import nn +from torch.nn import functional as F +from .layers import ConvNorm, LinearNorm, GlobalAvgPool +from .utils import to_gpu, get_mask_from_lengths + + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, + attention_dim): + super(LocationLayer, self).__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = ConvNorm(2, attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, bias=False, stride=1, + dilation=1) + self.location_dense = LinearNorm(attention_n_filters, attention_dim, + bias=False, w_init_gain='tanh') + + def forward(self, attention_weights_cat): + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class Attention(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + attention_location_n_filters, attention_location_kernel_size): + super(Attention, self).__init__() + self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, + bias=False, w_init_gain='tanh') + self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, + w_init_gain='tanh') + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer(attention_location_n_filters, + attention_location_kernel_size, + attention_dim) + self.score_mask_value = -float("inf") + + def get_alignment_energies(self, query, processed_memory, + attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) + + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v(torch.tanh( + processed_query + processed_attention_weights + processed_memory)) + + energies = energies.squeeze(-1) + return energies + + def forward(self, attention_hidden_state, memory, processed_memory, + attention_weights_cat, mask): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cumulative attention weights + mask: binary mask for padded data + """ + alignment = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat) + + if mask is not None: + alignment.data.masked_fill_(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class Prenet(nn.Module): + def __init__(self, in_dim, sizes): + super(Prenet, self).__init__() + in_sizes = [in_dim] + sizes[:-1] + self.layers = nn.ModuleList( + [LinearNorm(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, sizes)]) + + def forward(self, x): + for linear in self.layers: + x = F.dropout(F.relu(linear(x)), p=0.5, training=True) + return x + + +class Postnet(nn.Module): + """Postnet + - Five 1-d convolutions with 512 channels and kernel size 5 + """ + + def __init__(self, hparams): + super(Postnet, self).__init__() + self.convolutions = nn.ModuleList() + + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, 
w_init_gain='tanh'), + nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ) + + for i in range(1, hparams.postnet_n_convolutions - 1): + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.postnet_embedding_dim, + hparams.postnet_embedding_dim, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ) + + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='linear'), + nn.BatchNorm1d(hparams.n_mel_channels)) + ) + + def forward(self, x): + for i in range(len(self.convolutions) - 1): + x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) + x = F.dropout(self.convolutions[-1](x), 0.5, self.training) + + return x + + +class Encoder(nn.Module): + """Encoder module: + - Three 1-d convolution banks + - Bidirectional LSTM + """ + def __init__(self, hparams): + super(Encoder, self).__init__() + + convolutions = [] + for _ in range(hparams.encoder_n_convolutions): + conv_layer = nn.Sequential( + ConvNorm(hparams.encoder_embedding_dim, + hparams.encoder_embedding_dim, + kernel_size=hparams.encoder_kernel_size, stride=1, + padding=int((hparams.encoder_kernel_size - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.BatchNorm1d(hparams.encoder_embedding_dim)) + convolutions.append(conv_layer) + self.convolutions = nn.ModuleList(convolutions) + + self.lstm = nn.LSTM(hparams.encoder_embedding_dim, + int(hparams.encoder_embedding_dim / 2), 1, + batch_first=True, bidirectional=True) + + def forward(self, x, input_lengths): + for conv in self.convolutions: + x = F.dropout(F.relu(conv(x)), 0.5, self.training) + + x = x.transpose(1, 2) + + # pytorch tensor are not reversible, hence the conversion + input_lengths = 
input_lengths.cpu().numpy() + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + + outputs, _ = nn.utils.rnn.pad_packed_sequence( + outputs, batch_first=True) + + return outputs + + def inference(self, x): + for conv in self.convolutions: + x = F.dropout(F.relu(conv(x)), 0.5, self.training) + + x = x.transpose(1, 2) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + + return outputs + + +class AudioEncoder(nn.Module): + def __init__(self, hparams): + super(AudioEncoder, self).__init__() + + assert hparams.lat_dim > 0 + + convolutions = [] + inp_dim = hparams.n_mel_channels + for _ in range(hparams.lat_n_convolutions): + conv_layer = nn.Sequential( + ConvNorm(inp_dim, hparams.lat_n_filters, + kernel_size=hparams.lat_kernel_size, stride=1, + padding=int((hparams.lat_kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hparams.lat_n_filters)) + inp_dim = hparams.lat_n_filters + convolutions.append(conv_layer) + self.convolutions = nn.ModuleList(convolutions) + + self.lstm = nn.LSTM(hparams.lat_n_filters, + int(hparams.lat_n_filters / 2), + hparams.lat_n_blstms, batch_first=True, + bidirectional=True) + self.pool = GlobalAvgPool() + + self.mu_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim) + self.logvar_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim) + self.lat_dim = hparams.lat_dim + + def forward(self, x, lengths): + """ + Args: + x (torch.Tensor): (B, F, T) + """ + + for conv in self.convolutions: + x = F.dropout(F.tanh(conv(x)), 0.5, self.training) + + x = x.transpose(1, 2) # (B, T, D) + + # x may not be sorted by length. 
Sort->process->unsort + max_len = x.size(1) + assert max_len == torch.max(lengths).item() + + lengths, perm_idx = lengths.sort(0, descending=True) + x = x[perm_idx] + x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) + + _, unperm_idx = perm_idx.sort(0) + outputs = outputs[unperm_idx] # (B, T, D) + lengths = lengths[unperm_idx] # (B,) + + outputs = self.pool(outputs, lengths) # (B, D) + + mu = self.mu_proj(outputs) + logvar = self.logvar_proj(outputs) + z = distr.Normal(mu, logvar).rsample() + return z, mu, logvar + + +class Decoder(nn.Module): + def __init__(self, hparams): + super(Decoder, self).__init__() + self.n_mel_channels = hparams.n_mel_channels + self.n_frames_per_step = hparams.n_frames_per_step + self.encoder_embedding_dim = hparams.encoder_embedding_dim + self.obs_dim = hparams.obs_dim + self.lat_dim = hparams.lat_dim + self.attention_rnn_dim = hparams.attention_rnn_dim + self.decoder_rnn_dim = hparams.decoder_rnn_dim + self.prenet_dim = hparams.prenet_dim + self.max_decoder_steps = hparams.max_decoder_steps + self.gate_threshold = hparams.gate_threshold + self.p_attention_dropout = hparams.p_attention_dropout + self.p_decoder_dropout = hparams.p_decoder_dropout + + self.prenet = Prenet( + hparams.n_mel_channels * hparams.n_frames_per_step, + [hparams.prenet_dim, hparams.prenet_dim]) + + self.attention_rnn = nn.LSTMCell( + hparams.prenet_dim + hparams.encoder_embedding_dim, + hparams.attention_rnn_dim) + + self.attention_layer = Attention( + hparams.attention_rnn_dim, hparams.encoder_embedding_dim, + hparams.attention_dim, hparams.attention_location_n_filters, + hparams.attention_location_kernel_size) + + encoder_tot_dim = (hparams.encoder_embedding_dim + \ + hparams.lat_dim + hparams.obs_dim) + self.decoder_rnn = nn.LSTMCell( + hparams.attention_rnn_dim + encoder_tot_dim, 
hparams.decoder_rnn_dim, 1) + + self.linear_projection = LinearNorm( + hparams.decoder_rnn_dim + encoder_tot_dim, + hparams.n_mel_channels * hparams.n_frames_per_step) + + self.gate_layer = LinearNorm( + hparams.decoder_rnn_dim + encoder_tot_dim, 1, + bias=True, w_init_gain='sigmoid') + + def get_go_frame(self, memory): + """ Gets all zeros frames to use as first decoder input + PARAMS + ------ + memory: decoder outputs + + RETURNS + ------- + decoder_input: all zeros frames + """ + B = memory.size(0) + decoder_input = Variable(memory.data.new( + B, self.n_mel_channels * self.n_frames_per_step).zero_()) + return decoder_input + + def initialize_decoder_states(self, memory, obs_and_lat, mask): + """ Initializes attention rnn states, decoder rnn states, attention + weights, attention cumulative weights, attention context, stores memory + and stores processed memory + PARAMS + ------ + memory: Encoder outputs + obs_and_lat: Observed and latent attribute embeddings + mask: Mask for padded data if training, expects None for inference + """ + B = memory.size(0) + MAX_TIME = memory.size(1) + + self.attention_hidden = Variable(memory.data.new( + B, self.attention_rnn_dim).zero_()) + self.attention_cell = Variable(memory.data.new( + B, self.attention_rnn_dim).zero_()) + + self.decoder_hidden = Variable(memory.data.new( + B, self.decoder_rnn_dim).zero_()) + self.decoder_cell = Variable(memory.data.new( + B, self.decoder_rnn_dim).zero_()) + + self.attention_weights = Variable(memory.data.new( + B, MAX_TIME).zero_()) + self.attention_weights_cum = Variable(memory.data.new( + B, MAX_TIME).zero_()) + self.attention_context = Variable(memory.data.new( + B, self.encoder_embedding_dim).zero_()) + + self.memory = memory + self.processed_memory = self.attention_layer.memory_layer(memory) + self.obs_and_lat = obs_and_lat + self.mask = mask + + def parse_decoder_inputs(self, decoder_inputs): + """ Prepares decoder inputs, i.e. 
mel outputs + PARAMS + ------ + decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs + + RETURNS + ------- + inputs: processed decoder inputs + + """ + # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) + decoder_inputs = decoder_inputs.transpose(1, 2) + decoder_inputs = decoder_inputs.view( + decoder_inputs.size(0), + int(decoder_inputs.size(1)/self.n_frames_per_step), -1) + # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) + decoder_inputs = decoder_inputs.transpose(0, 1) + return decoder_inputs + + def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): + """ Prepares decoder outputs for output + PARAMS + ------ + mel_outputs: + gate_outputs: gate output energies + alignments: + + RETURNS + ------- + mel_outputs: + gate_outputs: gate output energies + alignments: + """ + # (T_out, B) -> (B, T_out) + alignments = torch.stack(alignments).transpose(0, 1) + # (T_out, B) -> (B, T_out) + gate_outputs = torch.stack(gate_outputs).transpose(0, 1) + gate_outputs = gate_outputs.contiguous() + # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) + mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() + # decouple frames per step + mel_outputs = mel_outputs.view( + mel_outputs.size(0), -1, self.n_mel_channels) + # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) + mel_outputs = mel_outputs.transpose(1, 2) + + return mel_outputs, gate_outputs, alignments + + def decode(self, decoder_input): + """ Decoder step using stored states, attention and memory + PARAMS + ------ + decoder_input: previous mel output + + RETURNS + ------- + mel_output: + gate_output: gate output energies + attention_weights: + """ + cell_input = torch.cat((decoder_input, self.attention_context), -1) + self.attention_hidden, self.attention_cell = self.attention_rnn( + cell_input, (self.attention_hidden, self.attention_cell)) + self.attention_hidden = F.dropout( + self.attention_hidden, self.p_attention_dropout, self.training) + 
+ attention_weights_cat = torch.cat( + (self.attention_weights.unsqueeze(1), + self.attention_weights_cum.unsqueeze(1)), dim=1) + self.attention_context, self.attention_weights = self.attention_layer( + self.attention_hidden, self.memory, self.processed_memory, + attention_weights_cat, self.mask) + + self.attention_weights_cum += self.attention_weights + decoder_input = torch.cat( + (self.attention_hidden, self.attention_context), -1) + if self.obs_and_lat is not None: + decoder_input = torch.cat((decoder_input, self.obs_and_lat), -1) + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_input, (self.decoder_hidden, self.decoder_cell)) + self.decoder_hidden = F.dropout( + self.decoder_hidden, self.p_decoder_dropout, self.training) + + decoder_hidden_attention_context = torch.cat( + (self.decoder_hidden, self.attention_context), dim=1) + if self.obs_and_lat is not None: + decoder_hidden_attention_context = torch.cat( + (decoder_hidden_attention_context, self.obs_and_lat), dim=1) + decoder_output = self.linear_projection( + decoder_hidden_attention_context) + + gate_prediction = self.gate_layer(decoder_hidden_attention_context) + return decoder_output, gate_prediction, self.attention_weights + + def forward(self, memory, obs_and_lat, decoder_inputs, memory_lengths): + """ Decoder forward pass for training + PARAMS + ------ + memory: Encoder outputs + obs_and_lat: Observed and latent attribute embeddings + decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs + memory_lengths: Encoder output lengths for attention masking. 
+ + RETURNS + ------- + mel_outputs: mel outputs from the decoder + gate_outputs: gate outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + + decoder_input = self.get_go_frame(memory).unsqueeze(0) + decoder_inputs = self.parse_decoder_inputs(decoder_inputs) + decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) + decoder_inputs = self.prenet(decoder_inputs) + + self.initialize_decoder_states( + memory, obs_and_lat, mask=~get_mask_from_lengths(memory_lengths)) + + mel_outputs, gate_outputs, alignments = [], [], [] + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] + mel_output, gate_output, attention_weights = self.decode( + decoder_input) + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze()] + alignments += [attention_weights] + + mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( + mel_outputs, gate_outputs, alignments) + + return mel_outputs, gate_outputs, alignments + + def inference(self, memory, obs_and_lat, ret_has_eos=False): + """ Decoder inference + PARAMS + ------ + memory: Encoder outputs + obs_and_lat: Observed and latent attribute embeddings + + RETURNS + ------- + mel_outputs: mel outputs from the decoder + gate_outputs: gate outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + decoder_input = self.get_go_frame(memory) + + self.initialize_decoder_states(memory, obs_and_lat, mask=None) + + mel_outputs, gate_outputs, alignments = [], [], [] + has_eos = False + while True: + decoder_input = self.prenet(decoder_input) + mel_output, gate_output, alignment = self.decode(decoder_input) + + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output] + alignments += [alignment] + + if torch.sigmoid(gate_output.data) > self.gate_threshold: + has_eos = True + break + elif len(mel_outputs) == self.max_decoder_steps: + # print("Warning! 
Reached max decoder steps") + break + + decoder_input = mel_output + + mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( + mel_outputs, gate_outputs, alignments) + + if ret_has_eos: + return mel_outputs, gate_outputs, alignments, has_eos + else: + return mel_outputs, gate_outputs, alignments + + +class Tacotron2(nn.Module): + def __init__(self, hparams): + super(Tacotron2, self).__init__() + self.mask_padding = hparams.mask_padding + self.fp16_run = hparams.fp16_run + self.n_mel_channels = hparams.n_mel_channels + self.n_frames_per_step = hparams.n_frames_per_step + + # initialize text encoder embedding + self.embedding = nn.Embedding( + hparams.n_symbols, hparams.symbols_embedding_dim) + std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) + val = sqrt(3.0) * std # uniform bounds for std + self.embedding.weight.data.uniform_(-val, val) + + # initialize observed attribute embedding + self.obs_embedding = None + if hparams.obs_dim > 0: + self.obs_embedding = nn.Embedding( + hparams.obs_n_class, hparams.obs_dim) + std = sqrt(2.0 / (hparams.obs_n_class + hparams.obs_dim)) + val = sqrt(3.0) * std # uniform bounds for std + self.obs_embedding.weight.data.uniform_(-val, val) + + self.encoder = Encoder(hparams) + self.decoder = Decoder(hparams) + self.postnet = Postnet(hparams) + + self.lat_encoder = None + if hparams.lat_dim > 0: + self.lat_encoder = AudioEncoder(hparams) + + def parse_batch(self, batch): + (text_padded, input_lengths, obs_labels, + mel_padded, gate_padded, output_lengths) = batch + text_padded = to_gpu(text_padded).long() + input_lengths = to_gpu(input_lengths).long() + obs_labels = to_gpu(obs_labels).long() + max_len = torch.max(input_lengths.data).item() + mel_padded = to_gpu(mel_padded).float() + gate_padded = to_gpu(gate_padded).float() + output_lengths = to_gpu(output_lengths).long() + + return ( + (text_padded, input_lengths, obs_labels, + mel_padded, max_len, output_lengths), + (mel_padded, gate_padded)) + + def 
parse_output(self, outputs, output_lengths=None): + if self.mask_padding and output_lengths is not None: + mask = ~get_mask_from_lengths(output_lengths) + mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) + mask = mask.permute(1, 0, 2) + + outputs[0].data.masked_fill_(mask, 0.0) + outputs[1].data.masked_fill_(mask, 0.0) + outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies + + return outputs + + def forward(self, inputs): + (text_inputs, text_lengths, obs_labels, + mels, max_len, output_lengths) = inputs + text_lengths, output_lengths = text_lengths.data, output_lengths.data + + embedded_inputs = self.embedding(text_inputs).transpose(1, 2) + + encoder_outputs = self.encoder(embedded_inputs, text_lengths) + + obs = None + if self.obs_embedding is not None: + obs = self.obs_embedding(obs_labels) + + lat, lat_mu, lat_logvar = None, None, None + if self.lat_encoder is not None: + (lat, lat_mu, lat_logvar) = self.lat_encoder(mels, output_lengths) + + obs_and_lat = [x for x in [obs, lat] if x is not None] + if bool(obs_and_lat): + obs_and_lat = torch.cat(obs_and_lat, dim=-1) + else: + obs_and_lat = None + + mel_outputs, gate_outputs, alignments = self.decoder( + encoder_outputs, obs_and_lat, mels, memory_lengths=text_lengths) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + return self.parse_output( + [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, + lat_mu, lat_logvar], + output_lengths) + + def inference(self, inputs, obs_labels=None, lat=None, ret_has_eos=False): + embedded_inputs = self.embedding(inputs).transpose(1, 2) + encoder_outputs = self.encoder.inference(embedded_inputs) + + if obs_labels is None: + obs_labels = torch.LongTensor(len(inputs)) + obs_labels = obs_labels.to(inputs.device).zero_() + + obs = None + if self.obs_embedding is not None: + obs = self.obs_embedding(obs_labels) + + if self.lat_encoder is not None: + if lat is None: + lat = 
torch.FloatTensor(len(inputs), self.lat_encoder.lat_dim) + lat = lat.to(inputs.device).zero_().type(encoder_outputs.type()) + + obs_and_lat = [x for x in [obs, lat] if x is not None] + if bool(obs_and_lat): + obs_and_lat = torch.cat(obs_and_lat, dim=-1) + else: + obs_and_lat = None + + mel_outputs, gate_outputs, alignments, has_eos = self.decoder.inference( + encoder_outputs, obs_and_lat, ret_has_eos=True) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + outputs = self.parse_output( + [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) + + if ret_has_eos: + return outputs + [has_eos] + else: + return outputs diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py new file mode 100644 index 0000000000..0d5f7fa818 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py @@ -0,0 +1,71 @@ +""" from https://github.com/keithito/tacotron """ + +import inflect +import re + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + 
dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py new file mode 100644 index 0000000000..63fcd431e2 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py @@ -0,0 +1,141 @@ +""" +BSD 3-Clause License + +Copyright (c) 2017, Prem Seetharaman +All rights reserved. + +* Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import numpy as np +import torch.nn.functional as F +from torch.autograd import Variable +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from .audio_processing import window_sumsquare + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + def __init__(self, filter_length=800, hop_length=200, win_length=800, + window='hann'): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :])]) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, 
:]) + + if window is not None: + assert(filter_length >= win_length) + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer('forward_basis', forward_basis.float()) + self.register_buffer('inverse_basis', inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode='reflect') + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable( + torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, magnitude.size(-1), hop_length=self.hop_length, + win_length=self.win_length, n_fft=self.filter_length, + dtype=np.float32) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > 
tiny(window_sum))[0]) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False) + window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] + inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py new file mode 100644 index 0000000000..5f0d70fdad --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py @@ -0,0 +1,18 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Defines the set of symbols used in text input to the model. + +The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' +from . import cmudict + +_pad = '_' +_punctuation = '!\'(),.:;? 
' +_special = '-' +_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + +# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): +_arpabet = ['@' + s for s in cmudict.valid_symbols] + +# Export all symbols: +symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py new file mode 100644 index 0000000000..49e2ca498b --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py @@ -0,0 +1,107 @@ +""" from https://github.com/keithito/tacotron """ +import numpy as np +import re +from . import cleaners +from .symbols import symbols + + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +# Regular expression matching text enclosed in curly braces: +_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') + +# Special symbols +SOS_TOK = '<s>' +EOS_TOK = '</s>' + +def text_to_sequence(text, cleaner_names): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
+ + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [] + + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) + break + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) + sequence += _arpabet_to_sequence(m.group(2)) + text = m.group(3) + + return sequence + + +def sample_code_chunk(code, size): + assert(size > 0 and size <= len(code)) + start = np.random.randint(len(code) - size + 1) + end = start + size + return code[start:end], start, end + + +def code_to_sequence(code, code_dict, collapse_code): + if collapse_code: + prev_c = None + sequence = [] + for c in code: + if c in code_dict and c != prev_c: + sequence.append(code_dict[c]) + prev_c = c + else: + sequence = [code_dict[c] for c in code if c in code_dict] + if len(sequence) < 0.95 * len(code): + print('WARNING : over 5%% codes are OOV') + + return sequence + + +def sequence_to_text(sequence): + '''Converts a sequence of IDs back to a string''' + result = '' + for symbol_id in sequence: + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] + # Enclose ARPAbet back in curly braces: + if len(s) > 1 and s[0] == '@': + s = '{%s}' % s[1:] + result += s + return result.replace('}{', ' ') + + +def sequence_to_code(sequence, code_dict): + '''Analogous to sequence_to_text''' + id_to_code = {i: c for c, i in code_dict.items()} + return ' '.join([id_to_code[i] for i in sequence]) + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text) + return text + + +def _symbols_to_sequence(symbols): + return [_symbol_to_id[s] for s in symbols if 
_should_keep_symbol(s)] + + +def _arpabet_to_sequence(text): + return _symbols_to_sequence(['@' + s for s in text.split()]) + + +def _should_keep_symbol(s): + return s in _symbol_to_id and s != '_' and s != '~' diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py new file mode 100644 index 0000000000..66a426d222 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py @@ -0,0 +1,167 @@ +import collections +import io +import json +import librosa +import numpy as np +import soundfile as sf +import time +import torch +from scipy.io.wavfile import read +from .text import SOS_TOK, EOS_TOK + + +def get_mask_from_lengths(lengths): + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) + mask = (ids < lengths.unsqueeze(1)) + return mask + + +def load_wav_to_torch(full_path, sr=None): + data, sr = librosa.load(full_path, sr=sr) + data = np.clip(data, -1, 1) # potentially out of [-1, 1] due to resampling + data = data * 32768.0 # match values loaded by scipy + return torch.FloatTensor(data.astype(np.float32)), sr + + +def read_binary_audio(bin_data, tar_sr=None): + """ + read binary audio (`bytes` or `uint8` `numpy.ndarray`) to `float32` + `numpy.ndarray` + + RETURNS: + data (np.ndarray) : audio of shape (n,) or (2, n) + tar_sr (int) : sample rate + """ + data, ori_sr = sf.read(io.BytesIO(bin_data), dtype='float32') + data = data.T + if (tar_sr is not None) and (ori_sr != tar_sr): + data = librosa.resample(data, ori_sr, tar_sr) + else: + tar_sr = ori_sr + data = np.clip(data, -1, 1) + data = data * 32768.0 + return torch.FloatTensor(data.astype(np.float32)), tar_sr + + +def load_filepaths_and_text(filename): + with open(filename, encoding='utf-8') as f: + data = [json.loads(line.rstrip()) for line in f] + return data + + +def to_gpu(x): + x = x.contiguous() + + if torch.cuda.is_available(): + x = x.cuda(non_blocking=True) + 
return torch.autograd.Variable(x) + + +def load_code_dict(path, add_sos=False, add_eos=False): + if not path: + return {} + + with open(path, 'r') as f: + codes = ['_'] + [line.rstrip() for line in f] # '_' for pad + code_dict = {c: i for i, c in enumerate(codes)} + + if add_sos: + code_dict[SOS_TOK] = len(code_dict) + if add_eos: + code_dict[EOS_TOK] = len(code_dict) + assert(set(code_dict.values()) == set(range(len(code_dict)))) + + return code_dict + + +def load_obs_label_dict(path): + if not path: + return {} + with open(path, 'r') as f: + obs_labels = [line.rstrip() for line in f] + return {c: i for i, c in enumerate(obs_labels)} + + +# A simple timer class inspired from `tnt.TimeMeter` +class CudaTimer: + def __init__(self, keys): + self.keys = keys + self.reset() + + def start(self, key): + s = torch.cuda.Event(enable_timing=True) + s.record() + self.start_events[key].append(s) + return self + + def stop(self, key): + e = torch.cuda.Event(enable_timing=True) + e.record() + self.end_events[key].append(e) + return self + + def reset(self): + self.start_events = collections.defaultdict(list) + self.end_events = collections.defaultdict(list) + self.running_times = collections.defaultdict(float) + self.n = collections.defaultdict(int) + return self + + def value(self): + self._synchronize() + return {k: self.running_times[k] / self.n[k] for k in self.keys} + + def _synchronize(self): + torch.cuda.synchronize() + for k in self.keys: + starts = self.start_events[k] + ends = self.end_events[k] + if len(starts) == 0: + raise ValueError("Trying to divide by zero in TimeMeter") + if len(ends) != len(starts): + raise ValueError("Call stop before checking value!") + time = 0 + for start, end in zip(starts, ends): + time += start.elapsed_time(end) + self.running_times[k] += time * 1e-3 + self.n[k] += len(starts) + self.start_events = collections.defaultdict(list) + self.end_events = collections.defaultdict(list) + + +# Used to measure the time taken for multiple events 
+class Timer: + def __init__(self, keys): + self.keys = keys + self.n = {} + self.running_time = {} + self.total_time = {} + self.reset() + + def start(self, key): + self.running_time[key] = time.time() + return self + + def stop(self, key): + self.total_time[key] = time.time() - self.running_time[key] + self.n[key] += 1 + self.running_time[key] = None + return self + + def reset(self): + for k in self.keys: + self.total_time[k] = 0 + self.running_time[k] = None + self.n[k] = 0 + return self + + def value(self): + vals = {} + for k in self.keys: + if self.n[k] == 0: + raise ValueError("Trying to divide by zero in TimeMeter") + else: + vals[k] = self.total_time[k] / self.n[k] + return vals + diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py new file mode 100644 index 0000000000..6a6585e8b6 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py @@ -0,0 +1,40 @@ +# import sys +# sys.path.append('tacotron2') +import torch +from .layers import STFT + + +class Denoiser(torch.nn.Module): + """ Removes model bias from audio produced with waveglow """ + + def __init__(self, waveglow, filter_length=1024, n_overlap=4, + win_length=1024, mode='zeros'): + super(Denoiser, self).__init__() + self.stft = STFT(filter_length=filter_length, + hop_length=int(filter_length/n_overlap), + win_length=win_length).cuda() + if mode == 'zeros': + mel_input = torch.zeros( + (1, 80, 88), + dtype=waveglow.upsample.weight.dtype, + device=waveglow.upsample.weight.device) + elif mode == 'normal': + mel_input = torch.randn( + (1, 80, 88), + dtype=waveglow.upsample.weight.dtype, + device=waveglow.upsample.weight.device) + else: + raise Exception("Mode {} if not supported".format(mode)) + + with torch.no_grad(): + bias_audio = waveglow.infer(mel_input, sigma=0.0).float() + bias_spec, _ = self.stft.transform(bias_audio) + + self.register_buffer('bias_spec', 
bias_spec[:, :, 0][:, :, None]) + + def forward(self, audio, strength=0.1): + audio_spec, audio_angles = self.stft.transform(audio.cuda().float()) + audio_spec_denoised = audio_spec - self.bias_spec * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/examples/textless_nlp/gslm/unit2speech/tts_data.py b/examples/textless_nlp/gslm/unit2speech/tts_data.py new file mode 100644 index 0000000000..eb0f7c360d --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/tts_data.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import numpy as np +from examples.textless_nlp.gslm.unit2speech.tacotron2.text import ( + EOS_TOK, + SOS_TOK, + code_to_sequence, + text_to_sequence, +) +from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import ( + load_code_dict, +) + + +class TacotronInputDataset: + def __init__(self, hparams, append_str=""): + self.is_text = getattr(hparams, "text_or_code", "text") == "text" + if not self.is_text: + self.code_dict = load_code_dict(hparams.code_dict) + self.code_key = hparams.code_key + self.add_sos = hparams.add_sos + self.add_eos = hparams.add_eos + self.collapse_code = hparams.collapse_code + self.append_str = append_str + + def process_code(self, inp_str): + inp_toks = inp_str.split() + if self.add_sos: + inp_toks = [SOS_TOK] + inp_toks + if self.add_eos: + inp_toks = inp_toks + [EOS_TOK] + return code_to_sequence(inp_toks, self.code_dict, self.collapse_code) + + def process_text(self, inp_str): + return text_to_sequence(inp_str, ["english_cleaners"]) + + def get_tensor(self, inp_str): + # uid, txt, inp_str = self._get_data(idx) + inp_str = inp_str + self.append_str + if self.is_text: + inp_toks = self.process_text(inp_str) + else: + inp_toks 
= self.process_code(inp_str) + return torch.from_numpy(np.array(inp_toks)).long() + + def __len__(self): + return len(self.data) diff --git a/examples/textless_nlp/gslm/unit2speech/utils.py b/examples/textless_nlp/gslm/unit2speech/utils.py new file mode 100644 index 0000000000..7aced08d38 --- /dev/null +++ b/examples/textless_nlp/gslm/unit2speech/utils.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from examples.textless_nlp.gslm.unit2speech.tacotron2.model import Tacotron2 +from examples.textless_nlp.gslm.unit2speech.tacotron2.waveglow_denoiser import ( + Denoiser, +) + + +def load_quantized_audio_from_file(file_path): + base_fname_batch, quantized_units_batch = [], [] + with open(file_path) as f: + for line in f: + base_fname, quantized_units_str = line.rstrip().split("|") + quantized_units = [int(q) for q in quantized_units_str.split(" ")] + base_fname_batch.append(base_fname) + quantized_units_batch.append(quantized_units) + return base_fname_batch, quantized_units_batch + + +def synthesize_audio(model, waveglow, denoiser, inp, lab=None, strength=0.0): + assert inp.size(0) == 1 + inp = inp.cuda() + if lab is not None: + lab = torch.LongTensor(1).cuda().fill_(lab) + + with torch.no_grad(): + _, mel, _, ali, has_eos = model.inference(inp, lab, ret_has_eos=True) + aud = waveglow.infer(mel, sigma=0.666) + aud_dn = denoiser(aud, strength=strength).squeeze(1) + return mel, aud, aud_dn, has_eos + + +def load_tacotron(tacotron_model_path, max_decoder_steps): + ckpt_dict = torch.load(tacotron_model_path) + hparams = ckpt_dict["hparams"] + hparams.max_decoder_steps = max_decoder_steps + sr = hparams.sampling_rate + model = Tacotron2(hparams) + model.load_state_dict(ckpt_dict["model_dict"]) + model = model.cuda().eval().half() + return model, sr, hparams + + +def load_waveglow(waveglow_path): + 
waveglow = torch.load(waveglow_path)["model"] + waveglow = waveglow.cuda().eval().half() + for k in waveglow.convinv: + k.float() + denoiser = Denoiser(waveglow) + return waveglow, denoiser diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index c99c6bf7d1..cc310088db 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -199,7 +199,8 @@ def build_model(self, model_cfg: FairseqDataclass): actualized_cfg = getattr(model, "cfg", None) if actualized_cfg is not None: - if "w2v_args" in actualized_cfg: + # if "w2v_args" in actualized_cfg: + if hasattr(actualized_cfg, "w2v_args"): model_cfg.w2v_args = actualized_cfg.w2v_args return model From db0175a882e8ae0f30d89b5a610373dbe032d528 Mon Sep 17 00:00:00 2001 From: Myle Ott <myleott@fb.com> Date: Fri, 27 Aug 2021 05:54:21 -0700 Subject: [PATCH 687/707] Require fairscale >= 0.4.0 to combine FSDP and --update-freq (#2239) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2239 Reviewed By: sshleifer, ngoyal2707 Differential Revision: D30574791 Pulled By: myleott fbshipit-source-id: 0f83e6ffe53d608292545884df269a604a57448d --- fairseq/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d53e650b0a..c86b1a51ec 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -62,6 +62,7 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): self.device = torch.device("cpu") if self.is_fsdp: + import fairscale if self.cfg.common.bf16: raise ValueError( "FullyShardedDataParallel is not compatible with --bf16 or " @@ -72,6 +73,11 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): "FullyShardedDataParallel is not compatible with --zero-sharding " "option (it's already built in)" ) + if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0": + raise RuntimeError( + "Please update to fairscale 
0.4.0 or newer when combining " + "--update-freq with FullyShardedDataParallel" + ) else: if ( hasattr(self.cfg.distributed_training, "cpu_offload") From 932a3d4aad6cae3ef05aad59e257eba1c765a36c Mon Sep 17 00:00:00 2001 From: Jingfei Du <jingfeidu@fb.com> Date: Mon, 30 Aug 2021 18:05:50 -0700 Subject: [PATCH 688/707] fix beam search with prefix tokens (#2227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: 1. added test for generating pad tokens during beam search with prefix tokens 2. modified lprobs for pad token and prefix tokens to avoid generating pad # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2227 Reviewed By: xianxl Differential Revision: D30649356 Pulled By: jingfeidu fbshipit-source-id: d94903a912e767391c8fca61f98f65b5cea3b56e --- fairseq/sequence_generator.py | 22 ++++++------- tests/test_sequence_generator.py | 54 ++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 11 deletions(-) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index d9c906ceea..740c32d648 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -344,16 +344,6 @@ def _generate( probs = probs[:, -1, :] * self.lm_weight lprobs += probs - lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) - - lprobs[:, self.pad] = -math.inf # never select pad - lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty - - # handle max length constraint - if step >= max_len: - lprobs[:, : self.eos] = -math.inf - lprobs[:, self.eos + 1 :] = -math.inf - # handle prefix tokens (possibly with different lengths) if ( prefix_tokens is not None @@ -367,6 +357,16 @@ def _generate( # minimum length constraint (does not apply if using prefix_tokens) lprobs[:, self.eos] = -math.inf + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + # Record attention scores, only support avg_attn_scores is a Tensor if avg_attn_scores is not None: if attn is None: @@ -568,7 +568,7 @@ def _prefix_tokens( prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) prefix_mask = prefix_toks.ne(self.pad) - lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = torch.min(prefix_lprobs) 
lprobs[prefix_mask] = lprobs[prefix_mask].scatter( -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] ) diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py index afbdfb6c2c..9273191962 100644 --- a/tests/test_sequence_generator.py +++ b/tests/test_sequence_generator.py @@ -563,6 +563,60 @@ def test_diverse_beam_search(self): self.assertHypoScore(hypos[1][1], [0.7, 0.35, 0.9], [0, 2, 1], 0.5) +class TestPrefixBeamSearch(TestSequenceGeneratorBase): + def setUp(self): + # construct dummy dictionary + vocab_size = 10 + d = test_utils.dummy_dictionary(vocab_size=vocab_size) + self.assertEqual(d.pad(), 1) + self.assertEqual(d.eos(), 2) + self.assertEqual(d.unk(), 3) + self.eos = d.eos() + self.w1 = 4 + self.w2 = 5 + self.beam_size = 3 + + # construct prefix data + self.tokens = torch.LongTensor( + [ + [self.w1, self.w2, self.eos], + ] + ) + self.token_lengths = torch.LongTensor([2]) + + args = argparse.Namespace() + unk = 0.0 + args.beam_probs = [ + # prefix step 0: + torch.FloatTensor( + [ + # eos + [0.0, unk] + [1.0 / vocab_size] * vocab_size # beam 1 + ] * self.beam_size + ), + ] * vocab_size + + task = test_utils.TestTranslationTask.setup_task(args, d, d) + self.model = task.build_model(args) + self.tgt_dict = task.target_dictionary + + def test_prefix_beam_search(self): + search_strategy = search.BeamSearch(self.tgt_dict) + generator = SequenceGenerator( + [self.model], + self.tgt_dict, + beam_size=self.beam_size, + search_strategy=search_strategy, + ) + sample = { + "net_input": { + "src_tokens": self.tokens, + "src_lengths": self.token_lengths, + } + } + # make sure test sample doesn't break any assertion + generator.forward(sample, prefix_tokens=self.tokens[:, :-1]) + class TestTopPSamplingSearch(TestSequenceGeneratorBase): def setUp(self): # construct dummy dictionary From 5277ec47bdb51165592596e2a2c4c7ee650d9958 Mon Sep 17 00:00:00 2001 From: Rengan Xu <renganxu@fb.com> Date: Mon, 30 Aug 2021 21:47:52 -0700 
Subject: [PATCH 689/707] Fix test_eval_bleu unittest (#2236) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2236 The test_eval_bleu unittest in TestTranslation in tests/test_binaries.py failed after the sacrebleu version was updated to 2.0.0 in the OSS testing tool. Added the fix so that the test passes with both sacrebleu 1.x and 2.0.0. Reviewed By: myleott, sravyapopuri388 Differential Revision: D30525920 fbshipit-source-id: 8ef27509cec45422a8d22003c87c2a7acb55225d --- fairseq/tasks/translation.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index ea80fa2e73..8647360867 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -418,14 +418,20 @@ def sum_logs(key): def compute_bleu(meters): import inspect - import sacrebleu - - fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0] + try: + from sacrebleu.metrics import BLEU + comp_bleu = BLEU.compute_bleu + except ImportError: + # compatibility API for sacrebleu 1.x + import sacrebleu + comp_bleu = sacrebleu.compute_bleu + + fn_sig = inspect.getfullargspec(comp_bleu)[0] if "smooth_method" in fn_sig: smooth = {"smooth_method": "exp"} else: smooth = {"smooth": "exp"} - bleu = sacrebleu.compute_bleu( + bleu = comp_bleu( correct=meters["_bleu_counts"].sum, total=meters["_bleu_totals"].sum, sys_len=meters["_bleu_sys_len"].sum, From 68a81202a371574b3acb5f8a8c36bfac7ab255ed Mon Sep 17 00:00:00 2001 From: Pierre Andrews <mortimer@fb.com> Date: Tue, 31 Aug 2021 01:11:34 -0700 Subject: [PATCH 690/707] Indexed Huffman Coded dataset (#2029) Summary: ## What does this PR do? Currently, binarized datasets are stored as a binary representation of int tensors. At best, each int is coded as uint16 on disk. When coding a fixed size vocabulary dataset where we know the frequency of each symbol and where some symbols are more common than others, we can do better. 
This happens in particular when binarizing a dataset split into subword units, as the most common "tokenizers" such as BPE and SPM choose subwords with high frequencies over subwords with low frequencies.

In practice, if we know the frequency of all symbols (or a good estimate), we can use entropy encoding methods to compress the data. The idea is to assign a compressed representation where frequent symbols have shorter representations than infrequent symbols.

In this PR, we build a Huffman code from a frequency table and use this code to encode a dataset. The PR provides the Huffman coder implementation (using the single queue approach as we usually start with a sorted set of symbols) as well as a memory map implementation of a dataset that stores the data compressed with a Huffman code and can return indexed tensors from it.

Over a whole dataset, depending on how many symbols we sample to evaluate the frequency, we can save between 25% and 30% of storage space.

## Follow Ups

Currently the binarizer/preprocess script makes too many assumptions about the dataset writers, so the Huffman dataset writer cannot be used out of the box with it. I will make follow-up PRs to provide easy-to-use scripts to build such datasets. But it's as simple as doing:

```
from fairseq.data.huffman import HuffmanCodeBuilder, HuffmanMMapIndexedDatasetBuilder

# estimate symbol frequencies from a sample of the data
code_builder = HuffmanCodeBuilder()
with open(sample_file, 'r', encoding="utf-8") as input:
    for line in input:
        code_builder.add(*line.strip().split(" "))

coder = code_builder.build_code()

# encode the full dataset with the resulting code
with HuffmanMMapIndexedDatasetBuilder('/tmp/testing_huffman', coder) as builder:
    with open(dataset_file, 'r', encoding="utf-8") as input:
        for line in input:
            builder.add_item(line.strip().split(' '))
```

A lot of the `HuffmanMMapIndexedDataset` code comes from the normal `MMapIndexedDataset`, and we could probably extract the commonalities into a base class.

The `HuffmanCoder` is also really a special kind of `Dictionary`, and again a common base class could be abstracted out of them.
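
## Illustration

For readers unfamiliar with Huffman coding, here is a minimal standalone sketch of how a code table can be derived from a frequency table. This is not the implementation in this PR: `build_huffman_table` is a hypothetical helper, it uses a binary heap from the standard library rather than the queue approach mentioned above, and it produces plain bit strings instead of `bitarray` objects.

```python
import heapq
from collections import Counter
from itertools import count


def build_huffman_table(frequencies):
    """Map each symbol to a prefix-free bit string (hypothetical helper).

    Symbols with higher frequency never get a longer code than symbols
    with strictly lower frequency.
    """
    if not frequencies:
        return {}
    tiebreak = count()  # unique counter so heap entries never compare dicts
    # each heap entry: (total frequency, tiebreak, {symbol: partial code})
    heap = [(freq, next(tiebreak), {sym: ""}) for sym, freq in frequencies.items()]
    heapq.heapify(heap)
    if len(heap) == 1:
        # a single symbol still needs a one-bit code
        return {sym: "0" for sym in heap[0][2]}
    while len(heap) > 1:
        # merge the two least frequent subtrees, prefixing their codes
        f0, _, t0 = heapq.heappop(heap)
        f1, _, t1 = heapq.heappop(heap)
        merged = {s: "0" + c for s, c in t0.items()}
        merged.update({s: "1" + c for s, c in t1.items()})
        heapq.heappush(heap, (f0 + f1, next(tiebreak), merged))
    return heap[0][2]


# toy corpus: "the" is the most frequent token
corpus = "the cat sat on the mat the end".split()
freqs = Counter(corpus)
table = build_huffman_table(freqs)
encoded = "".join(table[tok] for tok in corpus)
```

On this toy corpus the most frequent token ends up with one of the shortest codes, which is where the storage saving comes from once the bit string is packed into bytes.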
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2029 Reviewed By: dianaml0 Differential Revision: D29557468 Pulled By: Mortimerp9 fbshipit-source-id: a01b6d98f38f937934cadebb3786133e257adefe --- fairseq/data/dictionary.py | 2 +- fairseq/data/huffman/__init__.py | 21 ++ fairseq/data/huffman/huffman_coder.py | 265 ++++++++++++++++ .../huffman/huffman_mmap_indexed_dataset.py | 287 ++++++++++++++++++ fairseq/data/indexed_dataset.py | 9 + fairseq/dataclass/constants.py | 2 +- fairseq_cli/preprocess.py | 2 + setup.py | 1 + tests/test_huffman.py | 201 ++++++++++++ 9 files changed, 788 insertions(+), 2 deletions(-) create mode 100644 fairseq/data/huffman/__init__.py create mode 100644 fairseq/data/huffman/huffman_coder.py create mode 100644 fairseq/data/huffman/huffman_mmap_indexed_dataset.py create mode 100644 tests/test_huffman.py diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 6876b461d7..d3ef0f9896 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -267,7 +267,7 @@ def add_from_file(self, f): self.add_symbol(word, n=count, overwrite=overwrite) except ValueError: raise ValueError( - "Incorrect dictionary format, expected '<token> <cnt> [flags]'" + f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\"" ) def _save(self, f, kv_iterator): diff --git a/fairseq/data/huffman/__init__.py b/fairseq/data/huffman/__init__.py new file mode 100644 index 0000000000..9b61fafadb --- /dev/null +++ b/fairseq/data/huffman/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder +from .huffman_mmap_indexed_dataset import ( + HuffmanMMapIndex, + HuffmanMMapIndexedDataset, + HuffmanMMapIndexedDatasetBuilder, + vocab_file_path, +) + +__all__ = [ + "HuffmanCoder", + "HuffmanCodeBuilder", + "HuffmanMMapIndexedDatasetBuilder", + "HuffmanMMapIndexedDataset", + "HuffmanMMapIndex", + "vocab_file_path", +] diff --git a/fairseq/data/huffman/huffman_coder.py b/fairseq/data/huffman/huffman_coder.py new file mode 100644 index 0000000000..6531f1547c --- /dev/null +++ b/fairseq/data/huffman/huffman_coder.py @@ -0,0 +1,265 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re +import typing as tp +from collections import Counter, deque +from dataclasses import dataclass + +from bitarray import bitarray, util +from fairseq.data import Dictionary + +# basically we have to write to addressable bytes for the memory mapped +# dataset loader. Sentences that get encoded to a length that is not a +# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder) +BLOCKSIZE = 8 + + +class HuffmanCoder: + def __init__( + self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>" + ): + self.root = root + self.table = root.code_table() + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + + def _pad(self, a: bitarray) -> bitarray: + """ + bitpadding, 1 then 0. + + If the array is already a multiple of blocksize, we add a full block. + """ + pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1 + padding = bitarray("1" + "0" * pad_len) + return a + padding + + def _unpad(self, a: bitarray) -> bitarray: + """ + remove the bitpadding. 
+ + There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that + """ + # count the 0 padding at the end until we find the first 1 + # we want to remove the one too + remove_cnt = util.rindex(a, 1) + return a[:remove_cnt] + + def encode(self, iter: tp.List[str]) -> bytes: + """ + encode a list of tokens a return bytes. We use bitpadding to make sure the encoded bits fit in bytes. + """ + a = bitarray() + for token in iter: + code = self.get_code(token) + if code is None: + if self.unk_word is None: + raise Exception(f"unknown token {token} cannot be encoded.") + else: + token = self.unk_word + a = a + self.get_code(token) + return self._pad(a).tobytes() + + def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]: + """ + take bitpadded bytes and decode it to a set of leaves. You can then use each node to find the symbol/id + """ + a = bitarray() + a.frombytes(bits) + return self.root.decode(self._unpad(a)) + + def get_code(self, symbol: str) -> tp.Optional[bitarray]: + node = self.get_node(symbol) + return None if node is None else node.code + + def get_node(self, symbol: str) -> "HuffmanNode": + return self.table.get(symbol) + + @classmethod + def from_file( + cls, + filename: str, + bos="<s>", + pad="<pad>", + eos="</s>", + unk="<unk>", + ) -> "HuffmanCoder": + builder = HuffmanCodeBuilder.from_file(filename) + return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk) + + def to_file(self, filename, sep="\t"): + nodes = list(self.table.values()) + nodes.sort(key=lambda n: n.id) + with open(filename, "w", encoding="utf-8") as output: + for n in nodes: + output.write(f"{n.symbol}{sep}{n.count}\n") + + def __iter__(self): + for n in self.table.values(): + yield n + + def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder": + builder = HuffmanCodeBuilder() + for n in self: + builder.increment(n.symbol, n.count) + for n in other_coder: + builder.increment(n.symbol, n.count) + return builder.build_code() + + def __eq__(self, 
other: "HuffmanCoder") -> bool: + return self.table == other.table + + def __len__(self) -> int: + return len(self.table) + + def __contains__(self, sym: str) -> bool: + return sym in self.table + + def to_dictionary(self) -> Dictionary: + dictionary = Dictionary(bos=self.bos, unk=self.unk, pad=self.pad, eos=self.eos) + for n in self: + dictionary.add_symbol(n.symbol, n=n.count) + dictionary.finalize() + return dictionary + + +@dataclass +class HuffmanNode: + """ + a node in a Huffman tree + """ + + id: int + count: int + symbol: tp.Optional[str] = None + left: tp.Optional["HuffmanNode"] = None + right: tp.Optional["HuffmanNode"] = None + code: tp.Optional[bitarray] = None + + def is_leaf(self) -> bool: + return self.left is None and self.right is None + + def code_table(self, prefix: tp.Optional[bitarray] = None) -> tp.Dict[str, "HuffmanNode"]: + defaulted_prefix = prefix if prefix is not None else bitarray() + if self.is_leaf(): + self.code = ( + defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0") + ) # leaf could be the root if there is only one symbol + return {self.symbol: self} + + codes_right = self.right.code_table(defaulted_prefix + bitarray([0])) + codes_left = self.left.code_table(defaulted_prefix + bitarray([1])) + return {**codes_left, **codes_right} + + def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]: + current_node = self + for bit in bits: + if bit == 0: # go right + current_node = current_node.right + else: # go left + current_node = current_node.left + if current_node is None: + # we shouldn't be on a leaf here + raise Exception("fell off a leaf") + if current_node.is_leaf(): + yield current_node + current_node = self + if current_node != self: + raise Exception("couldn't decode all the bits") + + +class HuffmanCodeBuilder: + """ + build a dictionary with occurence count and then build the Huffman code for it. 
+ """ + + def __init__(self): + self.symbols = Counter() + + def add_symbols(self, *syms) -> None: + self.symbols.update(syms) + + def increment(self, symbol: str, cnt: int) -> None: + self.symbols[symbol] += cnt + + @classmethod + def from_file(cls, filename): + c = cls() + with open(filename, "r", encoding="utf-8") as input: + for line in input: + split = re.split(r"[\s]+", line) + c.increment(split[0], int(split[1])) + return c + + def to_file(self, filename, sep="\t"): + with open(filename, "w", encoding="utf-8") as output: + for (tok, cnt) in self.symbols.most_common(): + output.write(f"{tok}{sep}{cnt}\n") + + def _smallest(self, q1: deque, q2: deque) -> HuffmanNode: + if len(q1) == 0: + return q2.pop() + + if len(q2) == 0: + return q1.pop() + + if q1[-1].count < q2[-1].count: + return q1.pop() + + return q2.pop() + + def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder": + new_c = self.symbols + c.symbols + new_b = HuffmanCodeBuilder() + new_b.symbols = new_c + return new_b + + def build_code( + self, + bos="<s>", + pad="<pad>", + eos="</s>", + unk="<unk>", + ) -> HuffmanCoder: + assert len(self.symbols) > 0, "cannot build code from empty list of symbols" + + if self.symbols[bos] == 0: + self.add_symbols(bos) + if self.symbols[pad] == 0: + self.add_symbols(pad) + if self.symbols[eos] == 0: + self.add_symbols(eos) + if self.symbols[unk] == 0: + self.add_symbols(unk) + + node_id = 0 + leaves_queue = deque( + [ + HuffmanNode(symbol=symbol, count=count, id=idx) + for idx, (symbol, count) in enumerate(self.symbols.most_common()) + ] + ) # left are the most common, right are the least common + + if len(leaves_queue) == 1: + root = leaves_queue.pop() + root.id = 0 + return HuffmanCoder(root) + + nodes_queue = deque() + + while len(leaves_queue) > 0 or len(nodes_queue) != 1: + # get the lowest two nodes at the head of each queue + node1 = self._smallest(leaves_queue, nodes_queue) + node2 = self._smallest(leaves_queue, nodes_queue) + + # add new node + 
nodes_queue.appendleft( + HuffmanNode( + count=node1.count + node2.count, left=node1, right=node2, id=node_id + ) + ) + node_id += 1 + + # we are left with the root + return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk) diff --git a/fairseq/data/huffman/huffman_mmap_indexed_dataset.py b/fairseq/data/huffman/huffman_mmap_indexed_dataset.py new file mode 100644 index 0000000000..3279dae89a --- /dev/null +++ b/fairseq/data/huffman/huffman_mmap_indexed_dataset.py @@ -0,0 +1,287 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import mmap +import os +import shutil +import struct +import typing as tp +from functools import lru_cache + +import numpy as np +import torch +from fairseq.data import indexed_dataset +from fairseq.data.huffman import HuffmanCoder +from fairseq.file_io import PathManager + + +class HuffmanMMapIndex: + """ + keep an index of the offsets in the huffman binary file. + First a header, then the list of sizes (num tokens) for each instance and finally + the addresses of each instance. 
+ """ + + _HDR_MAGIC = b"HUFFIDX\x00\x00" + _VERSION = 1 + + @classmethod + def writer(cls, path: str, data_len: int): + class _Writer: + def __enter__(self): + self._file = open(path, "wb") + + # write header (magic + version) + self._file.write(cls._HDR_MAGIC) + self._file.write(struct.pack("<Q", cls._VERSION)) + self._file.write(struct.pack("<Q", data_len)) + + return self + + def write(self, sizes, pointers): + # add number of items in the index to the header + self._file.write(struct.pack("<Q", len(sizes))) + + # write sizes + sizes = np.array(sizes, dtype=np.int32) + self._file.write(sizes.tobytes(order="C")) + del sizes + + # write address pointers + pointers = np.array(pointers, dtype=np.int64) + self._file.write(pointers.tobytes(order="C")) + del pointers + + def __exit__(self, exc_type, exc_val, exc_tb): + self._file.close() + + return _Writer() + + def __init__(self, path): + with open(path, "rb") as stream: + # read headers + magic_test = stream.read(9) + assert self._HDR_MAGIC == magic_test, ( + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." 
+ ) + (version,) = struct.unpack("<Q", stream.read(8)) + assert ( + self._VERSION == version + ), "Unexpected file version f{version} != code version f{self._VERSION}" + + # read length of data file + (self._data_len,) = struct.unpack("<Q", stream.read(8)) + # read number of items in data file/index + (self._len,) = struct.unpack("<Q", stream.read(8)) + offset = stream.tell() + + indexed_dataset._warmup_mmap_file(path) + + self._bin_buffer_mmap = np.memmap(path, mode="r", order="C") + self._bin_buffer = memoryview(self._bin_buffer_mmap) + self._sizes = np.frombuffer( + self._bin_buffer, dtype=np.int32, count=self._len, offset=offset + ) + self._pointers = np.frombuffer( + self._bin_buffer, + dtype=np.int64, + count=self._len, + offset=offset + self._sizes.nbytes, + ) + + def __del__(self): + self._bin_buffer_mmap._mmap.close() + del self._bin_buffer_mmap + + def __iter__(self): + for i in range(self._len): + yield self[i] + + @property + def data_len(self): + return self._data_len + + @property + def sizes(self): + return self._sizes + + @lru_cache(maxsize=8) + def __getitem__(self, i): + return self._pointers[i], self._sizes[i] + + def __len__(self): + return self._len + + +def vocab_file_path(prefix_path): + return prefix_path + ".vocab" + + +class HuffmanMMapIndexedDataset(torch.utils.data.Dataset): + """ + an indexed dataset that use mmap and memoryview to access data from disk + that was compressed with a HuffmanCoder. 
+ """ + + def __init__(self, prefix_path): + super().__init__() + + self._prefix_path = None + self._index = None + self._bin_buffer = None + self._coder = None + self._file = None + + self._bin_buffer_mmap = None + + self._do_init(prefix_path) + + def __getstate__(self): + return self._prefix_path + + def __setstate__(self, state): + self._do_init(state) + + def _do_init(self, prefix_path): + self._prefix_path = prefix_path + self._index = HuffmanMMapIndex( + indexed_dataset.index_file_path(self._prefix_path) + ) + self._coder = HuffmanCoder.from_file(vocab_file_path(self._prefix_path)) + + indexed_dataset._warmup_mmap_file( + indexed_dataset.data_file_path(self._prefix_path) + ) + self._file = os.open( + indexed_dataset.data_file_path(self._prefix_path), os.O_RDONLY + ) + self._bin_buffer_mmap = mmap.mmap( + self._file, + self._index.data_len, + access=mmap.ACCESS_READ, + ) + self._bin_buffer = memoryview(self._bin_buffer_mmap) + + def __del__(self): + del self._bin_buffer + if self._file: + os.close(self._file) + del self._index + + def __len__(self): + return len(self._index) + + def _decode(self, i): + ptr, _ = self._index[i] + if i == 0: + raw_bytes = self._bin_buffer[:ptr] + else: + (prev_ptr, _) = self._index[i - 1] + raw_bytes = self._bin_buffer[prev_ptr:ptr] + + return self._coder.decode(raw_bytes.tobytes()) + + @lru_cache(maxsize=8) + def __getitem__(self, i): + nodes = self._decode(i) + return torch.tensor([n.id for n in nodes], dtype=torch.int64) + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + def get_symbols(self, i): + nodes = self._decode(i) + for n in nodes: + yield n.symbol + + @property + def sizes(self): + return self._index.sizes + + @property + def supports_prefetch(self): + return False + + @property + def coder(self): + return self._coder + + @staticmethod + def exists(prefix_path): + return ( + PathManager.exists(indexed_dataset.index_file_path(prefix_path)) + and 
PathManager.exists(indexed_dataset.data_file_path(prefix_path)) + and PathManager.exists(vocab_file_path(prefix_path)) + ) + + +class HuffmanMMapIndexedDatasetBuilder: + """ + Helper to build a memory mapped datasets with a huffman encoder. + You can either open/close this manually or use it as a ContextManager. + Provide your own coder, it will then be stored alongside the dataset. + The builder will first write the vocab file, then open the binary file so you can stream + into it, finally the index will be written when the builder is closed (your index should fit in memory). + """ + + def __init__(self, path_prefix: str, coder: HuffmanCoder) -> None: + self._path_prefix = path_prefix + self._coder = coder + self._sizes = [] + self._ptrs = [] + self._data_len = 0 + + def open(self): + self._coder.to_file(vocab_file_path(self._path_prefix)) + self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb") + + def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder": + self.open() + return self + + def add_item(self, tokens: tp.List[str]) -> None: + """ + add a list of tokens to the dataset, they will compressed with the + provided coder before being written to file. + """ + encoded = self._coder.encode(tokens) + code_len = len(encoded) + last_ptr = 0 + if len(self._ptrs) > 0: + last_ptr = self._ptrs[-1] + self._sizes.append(len(tokens)) + self._ptrs.append(last_ptr + code_len) + self._data_len += code_len + self._data_file.write(encoded) + + def append(self, other_dataset_path_prefix: str) -> None: + """ + append an existing dataset. + Beware, if it wasn't built with the same coder, you are in trouble. 
+ """ + other_index = HuffmanMMapIndex( + indexed_dataset.index_file_path(other_dataset_path_prefix) + ) + for (ptr, size) in other_index: + self._ptrs.append(ptr + self._data_len) + self._sizes.append(size) + + # Concatenate data + with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f: + shutil.copyfileobj(f, self._data_file) + + self._data_len += other_index.data_len + + def close(self): + self._data_file.close() + with HuffmanMMapIndex.writer( + indexed_dataset.index_file_path(self._path_prefix), self._data_len + ) as index: + index.write(self._sizes, self._ptrs) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 802e37a7ff..23afb43356 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -12,6 +12,7 @@ from fairseq.dataclass.constants import DATASET_IMPL_CHOICES from fairseq.data.fasta_dataset import FastaDataset from fairseq.file_io import PathManager +from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex from . 
import FairseqDataset @@ -48,6 +49,8 @@ def infer_dataset_impl(path): return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: return "mmap" + elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]: + return "huffman" else: return None elif FastaDataset.exists(path): @@ -63,6 +66,8 @@ def make_builder(out_file, impl, vocab_size=None): ) elif impl == "fasta": raise NotImplementedError + elif impl == "huffman": + raise ValueError("Use HuffmanCodeBuilder directly as it has a different interface.") else: return IndexedDatasetBuilder(out_file) @@ -81,6 +86,8 @@ def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None): from fairseq.data.fasta_dataset import EncodedFastaDataset return EncodedFastaDataset(path, dictionary) + elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path): + return HuffmanMMapIndexedDataset(path) return None @@ -89,6 +96,8 @@ def dataset_exists(path, impl): return IndexedRawTextDataset.exists(path) elif impl == "mmap": return MMapIndexedDataset.exists(path) + elif impl == "huffman": + return HuffmanMMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 442c25982b..4f159cfe9a 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -44,7 +44,7 @@ def ChoiceEnum(choices: List[str]): "slow_mo", ]) DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"]) -DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"]) +DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"]) GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum( ["unigram", "ensemble", "vote", "dp", "bs"] diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py index f7170eb00f..4ee9a1e3ba 100644 --- a/fairseq_cli/preprocess.py +++ b/fairseq_cli/preprocess.py @@ -41,6 +41,8 @@ def main(args): ) logger.info(args) + 
assert args.dataset_impl != "huffman", "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly." + task = tasks.get_task(args.task) def train_path(lang): diff --git a/setup.py b/setup.py index 7a19b73c9e..c699936a99 100644 --- a/setup.py +++ b/setup.py @@ -209,6 +209,7 @@ def do_setup(package_data): "sacrebleu>=1.4.12", "torch", "tqdm", + "bitarray", ], dependency_links=dependency_links, packages=find_packages( diff --git a/tests/test_huffman.py b/tests/test_huffman.py new file mode 100644 index 0000000000..a8cd5222b4 --- /dev/null +++ b/tests/test_huffman.py @@ -0,0 +1,201 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import random +import string +import typing as tp +import unittest +from collections import Counter +from tempfile import NamedTemporaryFile, TemporaryDirectory + +from fairseq.data import Dictionary, indexed_dataset +from fairseq.data.huffman import ( + HuffmanCodeBuilder, + HuffmanCoder, + HuffmanMMapIndexedDataset, + HuffmanMMapIndexedDatasetBuilder, +) + +POPULATION = string.ascii_letters + string.digits + + +def make_sentence() -> tp.List[str]: + length = random.randint(10, 50) + return random.choices( + population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1) + ) + + +def make_data(length=1000) -> tp.List[tp.List[str]]: + return ( + [make_sentence() for _ in range(0, length)] + # add all the symbols at least once + + [list(string.ascii_letters), list(string.digits)] + ) + + +def make_counts(data: tp.List[tp.List[str]]) -> Counter: + return Counter([symbol for sentence in data for symbol in sentence]) + + +def make_code_builder(data: tp.List[tp.List[str]]) -> HuffmanCodeBuilder: + builder = HuffmanCodeBuilder() + for sentence in data: + builder.add_symbols(*sentence) + return builder + + +class TestCodeBuilder(unittest.TestCase): + def 
test_code_builder_can_count(self): + data = make_data() + counts = make_counts(data) + builder = make_code_builder(data) + + self.assertEqual(builder.symbols, counts) + + def test_code_builder_can_add(self): + data = make_data() + counts = make_counts(data) + builder = make_code_builder(data) + + new_builder = builder + builder + + self.assertEqual(new_builder.symbols, counts + counts) + + def test_code_builder_can_io(self): + data = make_data() + builder = make_code_builder(data) + + with NamedTemporaryFile() as tmp_fp: + builder.to_file(tmp_fp.name) + other_builder = HuffmanCodeBuilder.from_file(tmp_fp.name) + + self.assertEqual(builder.symbols, other_builder.symbols) + + +class TestCoder(unittest.TestCase): + def test_coder_can_io(self): + data = make_data() + builder = make_code_builder(data) + coder = builder.build_code() + + with NamedTemporaryFile() as tmp_fp: + coder.to_file(tmp_fp.name) + other_coder = HuffmanCoder.from_file(tmp_fp.name) + + self.assertEqual(coder, other_coder) + + def test_coder_can_encode_decode(self): + data = make_data() + builder = make_code_builder(data) + coder = builder.build_code() + + encoded = [coder.encode(sentence) for sentence in data] + decoded = [[n.symbol for n in coder.decode(enc)] for enc in encoded] + + self.assertEqual(decoded, data) + + unseen_data = make_data() + unseen_encoded = [coder.encode(sentence) for sentence in unseen_data] + unseen_decoded = [ + [n.symbol for n in coder.decode(enc)] for enc in unseen_encoded + ] + self.assertEqual(unseen_decoded, unseen_data) + + +def build_dataset(prefix, data, coder): + with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as builder: + for sentence in data: + builder.add_item(sentence) + + +def sizes(data): + return [len(sentence) for sentence in data] + + +class TestHuffmanDataset(unittest.TestCase): + def test_huffman_can_encode_decode(self): + data = make_data() + builder = make_code_builder(data) + coder = builder.build_code() + + with TemporaryDirectory() as dirname: 
+ prefix = os.path.join(dirname, "test1") + build_dataset(prefix, data, coder) + dataset = HuffmanMMapIndexedDataset(prefix) + + self.assertEqual(len(dataset), len(data)) + decoded = [list(dataset.get_symbols(i)) for i in range(0, len(dataset))] + + self.assertEqual(decoded, data) + data_sizes = [i.item() for i in dataset.sizes] + self.assertEqual(data_sizes, sizes(data)) + + def test_huffman_compresses(self): + data = make_data() + builder = make_code_builder(data) + coder = builder.build_code() + + with TemporaryDirectory() as dirname: + prefix = os.path.join(dirname, "huffman") + build_dataset(prefix, data, coder) + + prefix_mmap = os.path.join(dirname, "mmap") + mmap_builder = indexed_dataset.make_builder( + indexed_dataset.data_file_path(prefix_mmap), + "mmap", + vocab_size=len(POPULATION), + ) + dictionary = Dictionary() + for c in POPULATION: + dictionary.add_symbol(c) + dictionary.finalize() + for sentence in data: + mmap_builder.add_item(dictionary.encode_line(" ".join(sentence))) + mmap_builder.finalize(indexed_dataset.index_file_path(prefix_mmap)) + + huff_size = os.stat(indexed_dataset.data_file_path(prefix)).st_size + mmap_size = os.stat(indexed_dataset.data_file_path(prefix_mmap)).st_size + self.assertLess(huff_size, mmap_size) + + def test_huffman_can_append(self): + data1 = make_data() + builder = make_code_builder(data1) + coder = builder.build_code() + + with TemporaryDirectory() as dirname: + prefix1 = os.path.join(dirname, "test1") + build_dataset(prefix1, data1, coder) + + data2 = make_data() + prefix2 = os.path.join(dirname, "test2") + build_dataset(prefix2, data2, coder) + + prefix3 = os.path.join(dirname, "test3") + + with HuffmanMMapIndexedDatasetBuilder(prefix3, coder) as builder: + builder.append(prefix1) + builder.append(prefix2) + + dataset = HuffmanMMapIndexedDataset(prefix3) + + self.assertEqual(len(dataset), len(data1) + len(data2)) + + decoded1 = [list(dataset.get_symbols(i)) for i in range(0, len(data1))] + 
self.assertEqual(decoded1, data1) + + decoded2 = [ + list(dataset.get_symbols(i)) for i in range(len(data1), len(dataset)) + ] + self.assertEqual(decoded2, data2) + + data_sizes = [i.item() for i in dataset.sizes] + self.assertEqual(data_sizes[: len(data1)], sizes(data1)) + self.assertEqual(data_sizes[len(data1) : len(dataset)], sizes(data2)) + + +if __name__ == "__main__": + unittest.main() From 8feccf94412424a4683b01090de36fa77cb4951d Mon Sep 17 00:00:00 2001 From: Vimal Manohar <vimalmanohar@fb.com> Date: Wed, 1 Sep 2021 11:43:25 -0700 Subject: [PATCH 691/707] EMA Summary: Adds Exponential moving average (EMA) model for Kaizen semi-supervised training https://arxiv.org/abs/2106.07759 1. Add `ema.store_ema` to enable storing EMA. EMA will be written to extra_state in the state dict while saving checkpoint. 2. `ema.ema_start_update` to control when the EMA starts accumulating 3. Tasks can use `uses_ema` property to decide if the EMA should be passed to the task. (Default is False) 4. `load_ema_from_checkpoint` can be used to load EMA model in place of the model to be used for evaluation. Pyspeech has eval-ema option for this. ``` This module has the EMA class used to store a copy of the exponentially decayed model params. Typical usage of EMA class involves initializing an object using an existing model (random or from a seed model) and setting the config like ema_decay, ema_start_update which determine how the EMA model is updated. After every update of the model i.e. at the end of the train_step, the EMA should be updated by passing the new model to the EMA.step function. The EMA model state dict can be stored in the extra state under the key of "ema" and dumped into a checkpoint and loaded. The EMA object can be passed to tasks by setting task.uses_ema property. EMA is a smoothed/ensemble model which might have better performance when used for inference or further fine-tuning.
EMA class has a reverse function to load the EMA params into a model and use it like a regular model. ``` Reviewed By: cruvadom Differential Revision: D24238379 fbshipit-source-id: 879d3ba5070a614b7d365f9503af357001e875b2 --- fairseq/checkpoint_utils.py | 46 ++++++++ fairseq/dataclass/configs.py | 31 +++++ fairseq/models/ema/__init__.py | 20 ++++ fairseq/models/ema/ema.py | 189 +++++++++++++++++++++++++++++++ fairseq/options.py | 7 ++ fairseq/trainer.py | 78 ++++++++++++- tests/gpu/test_ema_gpu.py | 200 +++++++++++++++++++++++++++++++++ tests/test_ema.py | 199 ++++++++++++++++++++++++++++++++ 8 files changed, 769 insertions(+), 1 deletion(-) create mode 100644 fairseq/models/ema/__init__.py create mode 100644 fairseq/models/ema/ema.py create mode 100644 tests/gpu/test_ema_gpu.py create mode 100644 tests/test_ema.py diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index b8c46f8253..ef5d4c9022 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -810,3 +810,49 @@ def verify_checkpoint_directory(save_dir: str) -> None: raise e else: os.remove(temp_file_path) + + +def load_ema_from_checkpoint(fpath): + """Loads exponential moving averaged (EMA) checkpoint from input and + returns a model with ema weights. + + Args: + fpath: A string path of checkpoint to load from. + + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. 
+ """ + params_dict = collections.OrderedDict() + new_state = None + + with PathManager.open(fpath, 'rb') as f: + new_state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, 'cpu') + ), + ) + + # EMA model is stored in a separate "extra state" + model_params = new_state['extra_state']['ema'] + + for key in list(model_params.keys()): + p = model_params[key] + if isinstance(p, torch.HalfTensor): + p = p.float() + if key not in params_dict: + params_dict[key] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + raise ValueError("Key {} is repeated in EMA model params.".format(key)) + + if len(params_dict) == 0: + raise ValueError( + f"Input checkpoint path '{fpath}' does not contain " + "ema model weights, is this model trained with EMA?" + ) + + new_state['model'] = params_dict + return new_state diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 6a86ea0192..952f1ec4d1 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -984,6 +984,36 @@ class InteractiveConfig(FairseqDataclass): ) +@dataclass +class EMAConfig(FairseqDataclass): + store_ema: bool = field( + default=False, metadata={ + help: "store exponential moving average shadow model" + } + ) + ema_decay: float = field( + default=0.9999, metadata={ + "help": 'decay for exponential moving average model' + } + ) + ema_start_update : int = field( + default=0, metadata={"help": "start EMA update after this many model updates"} + ) + ema_seed_model : Optional[str] = field( + default=None, metadata={ + "help": "Seed to load EMA model from. " + "Used to load EMA model separately from the actual model." 
+ } + ) + ema_update_freq : int = field( + default=1, metadata={"help": "Do EMA update every this many model updates"} + ) + ema_fp32: bool = field( + default=False, + metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"}, + ) + + @dataclass class FairseqConfig(FairseqDataclass): common: CommonConfig = CommonConfig() @@ -1004,3 +1034,4 @@ class FairseqConfig(FairseqDataclass): scoring: Any = None bpe: Any = None tokenizer: Any = None + ema: EMAConfig = EMAConfig() diff --git a/fairseq/models/ema/__init__.py b/fairseq/models/ema/__init__.py new file mode 100644 index 0000000000..503ceaa609 --- /dev/null +++ b/fairseq/models/ema/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +from .ema import EMA + + +def build_ema(model, cfg, device): + return EMA(model, cfg, device) + + +# automatically import any Python files in the models/ema/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.models.ema." + file_name) diff --git a/fairseq/models/ema/ema.py b/fairseq/models/ema/ema.py new file mode 100644 index 0000000000..6c0af69325 --- /dev/null +++ b/fairseq/models/ema/ema.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 + +""" +This module has the EMA class used to store a copy of the exponentially decayed +model params. + +Typical usage of EMA class involves initializing an object using an existing +model (random or from a seed model) and setting the config like ema_decay, +ema_start_update which determine how the EMA model is updated. After every +update of the model i.e. at the end of the train_step, the EMA should be updated +by passing the new model to the EMA.step function. 
The EMA model state dict +can be stored in the extra state under the key of "ema" and dumped +into a checkpoint and loaded. The EMA object can be passed to tasks +by setting task.uses_ema property. +EMA is a smoothed/ensemble model which might have better performance +when used for inference or further fine-tuning. EMA class has a +reverse function to load the EMA params into a model and use it +like a regular model. +""" + +import copy +import logging + +import torch +from fairseq import checkpoint_utils + + +class EMA(object): + """Exponential Moving Average of Fairseq Models + EMA keeps a copy of the exponentially decayed model params. + The set of params should include both gradient-descent and + non-gradient descent params, such as batch mean/var and buffers. + This is a modified implementation of + the open source code in https://github.com/zhawe01/fairseq-gec.git, + and internal source code in + fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py. + + Similar to TF EMA. + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage. + EMA provides a averaged and smoothed set of model weights, and has been shown to + improve vision models. EMA class does all necessary functions to update, reload, + or init EMA methods. + + EMA object is initialized from an arbitrary model. By default, it is stored in + the same device (unless device specified at initialization) and with the + same precision as the model (unless ema_fp32 is True). ema_fp32 is recommended. + This stores the EMA parameters in fp32 only for the EMA update step, and + is used at the default precision otherwise. + EMA is usually enabled using EMAConfig with store_ema=True. Some important + parameters to configure EMA are + 1) ema_decay - The decay of EMA + 2) ema_update_freq - EMA is updated every this many model updates. 
+ 3) ema_start_update - Start EMA update after this many model updates [default 0] + + Key methods: + 1) step - One update of EMA using new model + 2) restore - Update EMA from a state dict + 3) reverse - Load EMA into a model + 4) get_decay, _set_decay - Used to get or set the decay. Note _set_decay is + called from step. + 5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params. + Note this is enabled only when ema_fp32=True + """ + + def __init__(self, model, config, device=None): + """ + @param model model to initialize the EMA with + @param config EMAConfig object with configuration like + ema_decay, ema_update_freq, ema_fp32 + @param device If provided, copy EMA to this device (e.g. gpu). + Otherwise EMA is in the same device as the model. + """ + + self.decay = config.ema_decay + self.model = copy.deepcopy(model) + self.model.requires_grad_(False) + self.config = config + self.fp32_params = {} + + if self.config.ema_seed_model is not None: + state = checkpoint_utils.load_ema_from_checkpoint(self.config.ema_seed_model) + self.model.load_state_dict(state["model"], strict=True) + + if device is not None: + logging.info(f"Copying EMA model to device {device}") + self.model = self.model.to(device=device) + + if self.config.ema_fp32: + self.build_fp32_params() + + self.update_freq_counter = 0 + + def get_model(self): + return self.model + + def build_fp32_params(self, state_dict=None): + """ + Store a copy of the EMA params in fp32. + If state dict is passed, the EMA params is copied from + the provided state dict. Otherwise, it is copied from the + current EMA model parameters. + """ + if not self.config.ema_fp32: + raise RuntimeError( + "build_fp32_params should not be called if ema_fp32=False. " + "Use ema_fp32=True if this is really intended." 
+ ) + + if state_dict is None: + state_dict = self.model.state_dict() + + def _to_float(t): + return t.float() if torch.is_floating_point(t) else t + + for param_key in state_dict: + if param_key in self.fp32_params: + self.fp32_params[param_key].copy_(state_dict[param_key]) + else: + self.fp32_params[param_key] = _to_float(state_dict[param_key]) + + def restore(self, state_dict, build_fp32_params=False): + """ Load data from a model spec into EMA model """ + self.model.load_state_dict(state_dict, strict=False) + if build_fp32_params: + self.build_fp32_params(state_dict) + + def _set_decay(self, decay): + self.decay = decay + + def get_decay(self): + return self.decay + + def _step_internal(self, new_model, updates=None): + """ One update of the EMA model based on new model weights """ + decay = self.decay + + ema_state_dict = {} + ema_params = self.fp32_params if self.config.ema_fp32 else self.model.state_dict() + for key, param in new_model.state_dict().items(): + try: + ema_param = ema_params[key] + except KeyError: + ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + + if param.shape != ema_param.shape: + raise ValueError( + "incompatible tensor shapes between model param and ema param: " + + "{} vs. {}".format(param.shape, ema_param.shape) + ) + if "version" in key: + # Do not decay a model.version pytorch param + continue + ema_param.mul_(decay) + ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1-decay) + ema_state_dict[key] = ema_param + self.restore(ema_state_dict, build_fp32_params=False) + + def step(self, new_model, updates=None): + """ + One update of EMA which is done every self.config.ema_update_freq + updates of the model. + + @param updates The current number of model updates done. + Decay is set to 0 if model updates < ema_start_update, which means + the model will be simply copied over to the EMA. + When model updates >= ema_start_update, then EMA is updated with + a decay of self.config.ema_decay.
+ """ + self._set_decay( + 0 + if updates is not None + and updates < self.config.ema_start_update + else self.config.ema_decay + ) + if updates is not None and self.config.ema_update_freq > 1: + self.update_freq_counter += 1 + if self.update_freq_counter >= self.config.ema_update_freq: + self._step_internal(new_model, updates) + self.update_freq_counter = 0 + else: + self._step_internal(new_model, updates) + + def reverse(self, model): + """ + Load the model parameters from EMA model. + Useful for inference or fine-tuning from the EMA model. + """ + model.load_state_dict(self.model.state_dict(), strict=False) + return model diff --git a/fairseq/options.py b/fairseq/options.py index 2d9f8381a7..03883fc561 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -20,6 +20,7 @@ GenerationConfig, InteractiveConfig, OptimizationConfig, + EMAConfig, ) from fairseq.dataclass.utils import gen_parser_from_dataclass @@ -40,6 +41,7 @@ def get_training_parser(default_task="translation"): add_model_args(parser) add_optimization_args(parser) add_checkpoint_args(parser) + add_ema_args(parser) return parser @@ -379,3 +381,8 @@ def get_args( setattr(args, k, v) return args + + +def add_ema_args(parser): + group = parser.add_argument_group("EMA configuration") + gen_parser_from_dataclass(group, EMAConfig()) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index c86b1a51ec..e46ccfe0b8 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -22,6 +22,7 @@ from fairseq.distributed import utils as distributed_utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics +from fairseq.models.ema import build_ema from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler from omegaconf import OmegaConf @@ -131,6 +132,7 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): self._warn_once = set() self._wrapped_criterion = None self._wrapped_model = None + self._ema = None # TODO(myleott): support 
tpu if self.cuda and self.data_parallel_world_size > 1: @@ -256,6 +258,19 @@ def model(self): self._wrapped_model = self._model return self._wrapped_model + @property + def ema(self): + if self._ema is None: + self._build_ema() + return self._ema + + def _build_ema(self): + if self.cfg.ema.store_ema: + self._ema = build_ema(self._model, self.cfg.ema, self.device) + logger.info( + "Exponential Moving Average Shadow Model is initialized." + ) + @property def optimizer(self): if self._optimizer is None: @@ -392,6 +407,12 @@ def state_dict(self): "previous_training_time": self.cumulative_training_time(), }, } + if self.cfg.ema.store_ema: + # Save EMA model state as extra state + state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict() + if self.cfg.ema.ema_fp32: + # Save EMA params in fp32 + state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params if not self.cfg.checkpoint.no_save_optimizer_state: if self._gathered_optim_state is not None: state_dict["last_optimizer_state"] = self._gathered_optim_state @@ -552,6 +573,31 @@ def load_checkpoint( if isinstance(meter, meters.TimeMeter): meter.reset() + if self.cfg.ema.store_ema: + if "ema" not in extra_state: + logger.warning( + "EMA not found in checkpoint, but store_ema is True. " + "EMA is re-initialized from checkpoint."
+ ) + self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32) + else: + logger.info( + "Loading EMA from checkpoint" + ) + self.ema.restore(extra_state["ema"], build_fp32_params=False) + + if self.cfg.ema.ema_fp32: + if "ema_fp32_params" in extra_state: + logger.info( + "Loading EMA fp32 params from checkpoint" + ) + self.ema.build_fp32_params(extra_state["ema_fp32_params"]) + else: + logger.info( + "Building EMA fp32 params from EMA model in checkpoint" + ) + self.ema.build_fp32_params() + logger.info( "Loaded checkpoint {} (epoch {} @ {} updates)".format( filename, epoch, self.get_num_updates() @@ -670,6 +716,13 @@ def train_step(self, samples, raise_oom=False): metrics.log_start_time("train_wall", priority=800, round=0) + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. + extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 for i, sample in enumerate(samples): # delayed update loop @@ -705,6 +758,7 @@ def maybe_no_sync(): optimizer=self.optimizer, update_num=self.get_num_updates(), ignore_grad=is_dummy_batch, + **extra_kwargs, ) del loss @@ -840,6 +894,7 @@ def maybe_no_sync(): self.optimizer, self.get_num_updates(), ignore_grad=False, + **extra_kwargs, ) raise except OverflowError as e: @@ -871,6 +926,20 @@ def maybe_no_sync(): if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo": self.set_num_updates(self.get_num_updates() + 1) + if self.cfg.ema.store_ema: + # Step EMA forward with new model. 
+ self.ema.step( + self.get_model(), + self.get_num_updates(), + ) + metrics.log_scalar( + "ema_decay", + self.ema.get_decay(), + priority=10000, + round=5, + weight=0, + ) + if self.tpu: import torch_xla.core.xla_model as xm @@ -953,6 +1022,13 @@ def valid_step(self, sample, raise_oom=False): xm.rendezvous("valid_step") # wait for all workers + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. + extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + with torch.no_grad(): self.model.eval() self.criterion.eval() @@ -961,7 +1037,7 @@ def valid_step(self, sample, raise_oom=False): try: _loss, sample_size, logging_output = self.task.valid_step( - sample, self.model, self.criterion + sample, self.model, self.criterion, **extra_kwargs ) except RuntimeError as e: if "out of memory" in str(e): diff --git a/tests/gpu/test_ema_gpu.py b/tests/gpu/test_ema_gpu.py new file mode 100644 index 0000000000..337107d69a --- /dev/null +++ b/tests/gpu/test_ema_gpu.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional + +import torch +from fairseq.models.ema import EMA + + +class DummyModule(torch.nn.Module): + def __init__(self) -> None: + """LightningModule for testing purposes + + Args: + epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum + validation loss for testing purposes (zero based). If None this is ignored. Defaults to None. 
+ """ + super().__init__() + self.layer = torch.nn.Linear(in_features=32, out_features=2) + self.another_layer = torch.nn.Linear(in_features=2, out_features=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer(x) + return self.another_layer(x) + + +@dataclass +class EMAConfig(object): + ema_decay: float = 0.99 + ema_start_update: int = 0 + ema_fp32: bool = False + ema_seed_model: Optional[str] = None + + +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") +class TestEMAGPU(unittest.TestCase): + def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None): + diff = x.float() - y.float() + diff_norm = torch.norm(diff) + other_norm = torch.norm(y.float()) + + if msg is None: + msg = "|input - other| > {} + {} * |other|".format( + atol, rtol + ) + + self.assertLessEqual( + diff_norm, + atol + rtol * other_norm, + msg=msg, + ) + + def test_ema(self): + model = DummyModule().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig() + ema = EMA(model, config) + + # set decay + ema._set_decay(config.ema_decay) + self.assertEqual(ema.get_decay(), config.ema_decay) + + # get model + self.assertEqual(ema.get_model(), ema.model) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # EMA step + x = torch.randn(32).cuda() + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + ema_state_dict = ema.get_model().state_dict() + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema_state_dict[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # Load EMA into model + model2 = 
DummyModule().cuda() + ema.reverse(model2) + + for key, param in model2.state_dict().items(): + ema_param = ema_state_dict[key] + self.assertTrue( + torch.allclose(ema_param, param) + ) + + def test_ema_fp32(self): + model = DummyModule().cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=True) + ema = EMA(model, config) + + x = torch.randn(32).cuda() + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertIn(key, ema.fp32_params) + + # EMA update is done in fp32, and hence the EMA param must be + # closer to the EMA update done in fp32 than in fp16. + self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + ) + self.assertTorchAllClose( + ema_param, + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(), + ) + + def test_ema_fp16(self): + model = DummyModule().cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=False) + ema = EMA(model, config) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + x = torch.randn(32).cuda() + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + # 
EMA update is done in fp16, and hence the EMA param must be + # closer to the EMA update done in fp16 than in fp32. + self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + ) + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ema.py b/tests/test_ema.py new file mode 100644 index 0000000000..88ea65a434 --- /dev/null +++ b/tests/test_ema.py @@ -0,0 +1,199 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional + +import torch +from fairseq.models.ema import EMA + + +class DummyModule(torch.nn.Module): + def __init__(self) -> None: + """LightningModule for testing purposes + + Args: + epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum + validation loss for testing purposes (zero based). If None this is ignored. Defaults to None. 
+ """ + super().__init__() + self.layer = torch.nn.Linear(in_features=32, out_features=2) + self.another_layer = torch.nn.Linear(in_features=2, out_features=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer(x) + return self.another_layer(x) + + +@dataclass +class EMAConfig(object): + ema_decay: float = 0.99 + ema_start_update: int = 0 + ema_fp32: bool = False + ema_seed_model: Optional[str] = None + + +class TestEMAGPU(unittest.TestCase): + def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None): + diff = x.float() - y.float() + diff_norm = torch.norm(diff) + other_norm = torch.norm(y.float()) + + if msg is None: + msg = "|input - other| > {} + {} * |other|".format( + atol, rtol + ) + + self.assertLessEqual( + diff_norm, + atol + rtol * other_norm, + msg=msg, + ) + + def test_ema(self): + model = DummyModule() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig() + ema = EMA(model, config) + + # set decay + ema._set_decay(config.ema_decay) + self.assertEqual(ema.get_decay(), config.ema_decay) + + # get model + self.assertEqual(ema.get_model(), ema.model) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # EMA step + x = torch.randn(32) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + ema_state_dict = ema.get_model().state_dict() + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema_state_dict[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # Load EMA into model + model2 = DummyModule() + ema.reverse(model2) + + for key, param in model2.state_dict().items(): + ema_param 
= ema_state_dict[key] + self.assertTrue( + torch.allclose(ema_param, param) + ) + + def test_ema_fp32(self): + model = DummyModule().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=True) + ema = EMA(model, config) + + x = torch.randn(32) + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertIn(key, ema.fp32_params) + + # EMA update is done in fp32, and hence the EMA param must be + # closer to the EMA update done in fp32 than in fp16. + self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + ) + self.assertTorchAllClose( + ema_param, + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(), + ) + + def test_ema_fp16(self): + model = DummyModule().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=False) + ema = EMA(model, config) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + x = torch.randn(32) + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + # EMA update is done in fp16, and hence the EMA param must be + # closer to the EMA update done in fp16 than in fp32. 
+ self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + ) + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + +if __name__ == "__main__": + unittest.main() From 14c5bd027f04aae9dbb32f1bd7b34591b61af97f Mon Sep 17 00:00:00 2001 From: Koustuv Sinha <koustuvsinha@hotmail.com> Date: Wed, 1 Sep 2021 13:13:08 -0700 Subject: [PATCH 692/707] Releasing models for our paper "Masked Language Modeling and the Distributional Hypothesis" (#1930) Summary: Paper submitted to EMNLP: https://arxiv.org/abs/2104.06644 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1930 Reviewed By: lematt1991 Differential Revision: D28885634 Pulled By: shruti-bh fbshipit-source-id: d433c87cff3603b3e676a129029a827c510a72c7 --- .../shuffled_word_order/README.finetuning.md | 135 ++++++++++++++++++ examples/shuffled_word_order/README.md | 84 +++++++++++ 2 files changed, 219 insertions(+) create mode 100644 examples/shuffled_word_order/README.finetuning.md create mode 100644 examples/shuffled_word_order/README.md diff --git a/examples/shuffled_word_order/README.finetuning.md b/examples/shuffled_word_order/README.finetuning.md new file mode 100644 index 0000000000..ecbcb65884 --- /dev/null +++ b/examples/shuffled_word_order/README.finetuning.md @@ -0,0 +1,135 @@ +# Fine-tuning details + +For each task (GLUE and PAWS), we perform hyperparam search for each model, and report the mean and standard deviation across 5 seeds of the best model. First, get the datasets following the instructions in [RoBERTa fine-tuning README](../roberta/README.glue.md). 
Alternatively, you can use [huggingface datasets](https://huggingface.co/docs/datasets/) to get the task data: + +```python +from datasets import load_dataset +import pandas as pd +from pathlib import Path + +key2file = { +"paws": { + "loc": "paws_data", + "columns": ["id", "sentence1", "sentence2", "label"], + "train": "train.tsv", + "validation": "dev.tsv", + "test": "test.tsv" + } +} + +task_data = load_dataset("paws", "labeled_final") +task_config = key2file["paws"] +save_path = Path(task_config["loc"]) +save_path.mkdir(exist_ok=True, parents=True) +for key, fl in task_config.items(): + if key in ["loc", "columns"]: + continue + print(f"Reading {key}") + columns = task_config["columns"] + df = pd.DataFrame(task_data[key]) + print(df.columns) + df = df[columns] + print(f"Got {len(df)} records") + save_loc = save_path / fl + print(f"Saving to : {save_loc}") + df.to_csv(save_loc, sep="\t", header=None, index=None) + +``` + +- Preprocess using RoBERTa GLUE preprocessing script, while keeping in mind the column numbers for `sentence1`, `sentence2` and `label` (which is 0,1,2 if you save the data according to the above example.) +- Then, fine-tuning is performed similarly to RoBERTa (for example, in case of RTE): + +```bash +TOTAL_NUM_UPDATES=30875 # 10 epochs through RTE for bsz 16 +WARMUP_UPDATES=1852 # 6 percent of the number of updates +LR=2e-05 # Peak LR for polynomial LR scheduler. +NUM_CLASSES=2 +MAX_SENTENCES=16 # Batch size. 
+SHUFFLED_ROBERTA_PATH=/path/to/shuffled_roberta/model.pt + +CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \ + --restore-file $SHUFFLED_ROBERTA_PATH \ + --max-positions 512 \ + --batch-size $MAX_SENTENCES \ + --max-tokens 4400 \ + --task sentence_prediction \ + --reset-optimizer --reset-dataloader --reset-meters \ + --required-batch-size-multiple 1 \ + --init-token 0 --separator-token 2 \ + --arch roberta_large \ + --criterion sentence_prediction \ + --num-classes $NUM_CLASSES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ + --clip-norm 0.0 \ + --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ + --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ + --max-epoch 10 \ + --find-unused-parameters \ + --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; +``` + +- `TOTAL_NUM_UPDATES` is computed based on the `--batch_size` value and the dataset size. 
+- `WARMUP_UPDATES` is computed as 6% of `TOTAL_NUM_UPDATES` +- Best hyperparam of `--lr` and `--batch_size` is reported below: + +## `--lr` + +| | name | RTE | MRPC | SST-2 | CoLA | QQP | QNLI | MNLI | PAWS | +| --: | :----------- | ----: | ----: | ----: | ----: | ----: | ----: | ----: | ----: | +| 0 | original | 2e-05 | 2e-05 | 1e-05 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 2e-05 | +| 1 | n_1 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 3e-05 | 1e-05 | 2e-05 | 2e-05 | +| 2 | n_2 | 2e-05 | 2e-05 | 1e-05 | 1e-05 | 2e-05 | 1e-05 | 1e-05 | 3e-05 | +| 3 | n_3 | 3e-05 | 1e-05 | 2e-05 | 2e-05 | 3e-05 | 1e-05 | 1e-05 | 2e-05 | +| 4 | n_4 | 3e-05 | 1e-05 | 2e-05 | 2e-05 | 2e-05 | 1e-05 | 1e-05 | 2e-05 | +| 5 | r512 | 1e-05 | 3e-05 | 2e-05 | 2e-05 | 3e-05 | 2e-05 | 3e-05 | 2e-05 | +| 6 | rand_corpus | 2e-05 | 1e-05 | 3e-05 | 1e-05 | 3e-05 | 3e-05 | 3e-05 | 2e-05 | +| 7 | rand_uniform | 2e-05 | 1e-05 | 3e-05 | 2e-05 | 3e-05 | 3e-05 | 3e-05 | 1e-05 | +| 8 | rand_init | 1e-05 | 1e-05 | 3e-05 | 1e-05 | 1e-05 | 1e-05 | 2e-05 | 1e-05 | +| 9 | no_pos | 1e-05 | 3e-05 | 2e-05 | 1e-05 | 1e-05 | 1e-05 | 1e-05 | 1e-05 | + +## `--batch_size` + +| | name | RTE | MRPC | SST-2 | CoLA | QQP | QNLI | MNLI | PAWS | +| --: | :----------- | --: | ---: | ----: | ---: | --: | ---: | ---: | ---: | +| 0 | orig | 16 | 16 | 32 | 16 | 16 | 32 | 32 | 16 | +| 1 | n_1 | 32 | 32 | 16 | 32 | 32 | 16 | 32 | 16 | +| 2 | n_2 | 32 | 16 | 32 | 16 | 32 | 32 | 16 | 32 | +| 3 | n_3 | 32 | 32 | 16 | 32 | 32 | 16 | 32 | 32 | +| 4 | n_4 | 32 | 16 | 32 | 16 | 32 | 32 | 32 | 32 | +| 5 | r512 | 32 | 16 | 16 | 32 | 32 | 16 | 16 | 16 | +| 6 | rand_corpus | 16 | 16 | 16 | 16 | 32 | 16 | 16 | 32 | +| 7 | rand_uniform | 16 | 32 | 16 | 16 | 32 | 16 | 16 | 16 | +| 8 | rand_init | 16 | 16 | 32 | 16 | 16 | 16 | 32 | 16 | +| 9 | no_pos | 16 | 32 | 16 | 16 | 32 | 16 | 16 | 16 | + +- Perform inference similar to RoBERTa as well: + +```python +from fairseq.models.roberta import RobertaModel + +roberta = RobertaModel.from_pretrained( + 
'checkpoints/', + checkpoint_file='checkpoint_best.pt', + data_name_or_path='PAWS-bin' +) + +label_fn = lambda label: roberta.task.label_dictionary.string( + [label + roberta.task.label_dictionary.nspecial] +) +ncorrect, nsamples = 0, 0 +roberta.cuda() +roberta.eval() +with open('paws_data/dev.tsv') as fin: + # No header row to skip: dev.tsv above is written with header=None + for index, line in enumerate(fin): + tokens = line.strip().split('\t') + # Columns are id, sentence1, sentence2, label + sent1, sent2, target = tokens[1], tokens[2], tokens[3] + tokens = roberta.encode(sent1, sent2) + prediction = roberta.predict('sentence_classification_head', tokens).argmax().item() + prediction_label = label_fn(prediction) + ncorrect += int(prediction_label == target) + nsamples += 1 +print('| Accuracy: ', float(ncorrect)/float(nsamples)) + +``` diff --git a/examples/shuffled_word_order/README.md b/examples/shuffled_word_order/README.md new file mode 100644 index 0000000000..14c240cb56 --- /dev/null +++ b/examples/shuffled_word_order/README.md @@ -0,0 +1,84 @@ +# Masked Language Modeling and the Distributional Hypothesis: Order Word Matters Pre-training for Little + +[https://arxiv.org/abs/2104.06644](https://arxiv.org/abs/2104.06644) + +## Introduction + +In this work, we pre-train [RoBERTa](../roberta) base on various word-shuffled variants of the BookWiki corpus (16GB). We observe that a word-shuffled pre-trained model achieves surprisingly good scores on GLUE, PAWS and several parametric probing tasks. Please read our paper for more details on the experiments.
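As a rough illustration of what "n-gram sentence word shuffling" could look like for the variants listed below, the following sketch splits a sentence into consecutive chunks of `n` words and permutes the chunks, leaving the word order inside each chunk intact. This is a toy under stated assumptions, not the preprocessing code used for the released models: the function name `shuffle_ngrams`, the whitespace tokenization, and the handling of a shorter final chunk are all hypothetical, and the paper defines the exact procedure.

```python
import random


def shuffle_ngrams(sentence: str, n: int, rng: random.Random) -> str:
    """Permute a sentence at the n-gram level (illustrative sketch only)."""
    words = sentence.split()
    # Group consecutive words into chunks of n; the last chunk may be shorter.
    chunks = [words[i:i + n] for i in range(0, len(words), n)]
    # Shuffle whole chunks, so word order inside each chunk is preserved.
    rng.shuffle(chunks)
    return " ".join(word for chunk in chunks for word in chunk)


rng = random.Random(0)  # fixed seed so repeated runs give the same permutation
sentence = "the quick brown fox jumps over the lazy dog"
for n in (1, 2, 4):
    print(n, shuffle_ngrams(sentence, n, rng))
```

With `n=1` every word can move independently, while larger `n` keeps progressively more local word order, which is presumably the axis along which the `n1` to `n4` variants differ.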
+ +## Pre-trained models + +| Model | Description | Download | +| ------------------------------------- | -------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | +| `roberta.base.orig` | RoBERTa (base) trained on natural corpus | [roberta.base.orig.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.orig.tar.gz) | +| `roberta.base.shuffle.n1` | RoBERTa (base) trained on n=1 gram sentence word shuffled data | [roberta.base.shuffle.n1.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n1.tar.gz) | +| `roberta.base.shuffle.n2` | RoBERTa (base) trained on n=2 gram sentence word shuffled data | [roberta.base.shuffle.n2.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n2.tar.gz) | +| `roberta.base.shuffle.n3` | RoBERTa (base) trained on n=3 gram sentence word shuffled data | [roberta.base.shuffle.n3.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n3.tar.gz) | +| `roberta.base.shuffle.n4` | RoBERTa (base) trained on n=4 gram sentence word shuffled data | [roberta.base.shuffle.n4.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n4.tar.gz) | +| `roberta.base.shuffle.512` | RoBERTa (base) trained on unigram 512 word block shuffled data | [roberta.base.shuffle.512.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.512.tar.gz) | +| `roberta.base.shuffle.corpus` | RoBERTa (base) trained on unigram corpus word shuffled data | [roberta.base.shuffle.corpus.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus.tar.gz) | +| `roberta.base.shuffle.corpus_uniform` | RoBERTa (base) trained on unigram corpus word shuffled data, where all words are uniformly sampled | 
[roberta.base.shuffle.corpus_uniform.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus_uniform.tar.gz) | +| `roberta.base.nopos` | RoBERTa (base) without positional embeddings, trained on natural corpus | [roberta.base.nopos.tar.gz](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.nopos.tar.gz) | + +## Results + +[GLUE (Wang et al, 2019)](https://gluebenchmark.com/) & [PAWS (Zhang et al, 2019)](https://github.com/google-research-datasets/paws) _(dev set, single model, single-task fine-tuning, median of 5 seeds)_ + +| name | CoLA | MNLI | MRPC | PAWS | QNLI | QQP | RTE | SST-2 | +| :------------------------------------ | ----: | ----: | ----: | ----: | ----: | ----: | ----: | ----: | +| `roberta.base.orig` | 61.4 | 86.11 | 89.19 | 94.46 | 92.53 | 91.26 | 74.64 | 93.92 | +| `roberta.base.shuffle.n1` | 35.15 | 82.64 | 86 | 89.97 | 89.02 | 91.01 | 69.02 | 90.47 | +| `roberta.base.shuffle.n2` | 54.37 | 83.43 | 86.24 | 93.46 | 90.44 | 91.36 | 70.83 | 91.79 | +| `roberta.base.shuffle.n3` | 48.72 | 83.85 | 86.36 | 94.05 | 91.69 | 91.24 | 70.65 | 92.02 | +| `roberta.base.shuffle.n4` | 58.64 | 83.77 | 86.98 | 94.32 | 91.69 | 91.4 | 70.83 | 92.48 | +| `roberta.base.shuffle.512` | 12.76 | 77.52 | 79.61 | 84.77 | 85.19 | 90.2 | 56.52 | 86.34 | +| `roberta.base.shuffle.corpus` | 0 | 71.9 | 70.52 | 58.52 | 71.11 | 85.52 | 53.99 | 83.35 | +| `roberta.base.shuffle.corpus_uniform` | 9.19 | 72.33 | 70.76 | 58.42 | 77.76 | 85.93 | 53.99 | 84.04 | +| `roberta.base.nopos` | 0 | 63.5 | 72.73 | 57.08 | 77.72 | 87.87 | 54.35 | 83.24 | + +For more results on probing tasks, please refer to [our paper](https://arxiv.org/abs/2104.06644).
+
+## Example Usage
+
+Follow the same usage as in [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) to load and test your models:
+
+```python
+# Download and unpack the roberta.base.shuffle.n1 model first (run these two commands in a shell):
+#   wget https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n1.tar.gz
+#   tar -xzvf roberta.base.shuffle.n1.tar.gz
+
+# Load the model in fairseq
+from fairseq.models.roberta import RobertaModel
+roberta = RobertaModel.from_pretrained('/path/to/roberta.base.shuffle.n1', checkpoint_file='model.pt')
+roberta.eval()  # disable dropout (or leave in train mode to finetune)
+```
+
+**Note**: The model trained without positional embeddings (`roberta.base.nopos`) is a modified RoBERTa model in which the positional embeddings are not used. Thus, the usual `from_pretrained` method of the fairseq version of RoBERTa will not be able to load these model weights. To load them, construct a new `RobertaModel` object with the flag `use_positional_embeddings` set to `False` (or, [in the latest code](https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/model.py#L543), with `no_token_positional_embeddings` set to `True`), and then load the individual weights.
+
+## Fine-tuning Evaluation
+
+For quick evaluation, we provide models fine-tuned on MNLI here (1 seed for each model). Please refer to [finetuning details](README.finetuning.md) for the parameters of these models. Follow the [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) instructions to evaluate these models.
+ +| Model | MNLI M Dev Accuracy | Link | +| :----------------------------------------- | :------------------ | :--------------------------------------------------------------------------------------------------------------- | +| `roberta.base.orig.mnli` | 86.14 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.orig.mnli.tar.gz) | +| `roberta.base.shuffle.n1.mnli` | 82.55 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n1.mnli.tar.gz) | +| `roberta.base.shuffle.n2.mnli` | 83.21 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n2.mnli.tar.gz) | +| `roberta.base.shuffle.n3.mnli` | 83.89 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n3.mnli.tar.gz) | +| `roberta.base.shuffle.n4.mnli` | 84.00 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.n4.mnli.tar.gz) | +| `roberta.base.shuffle.512.mnli` | 77.22 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.512.mnli.tar.gz) | +| `roberta.base.shuffle.corpus.mnli` | 71.88 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus.mnli.tar.gz) | +| `roberta.base.shuffle.corpus_uniform.mnli` | 72.46 | [Download](https://dl.fbaipublicfiles.com/unnatural_pretraining/roberta.base.shuffle.corpus_uniform.mnli.tar.gz) | + +## Citation + +```bibtex +@misc{sinha2021masked, + title={Masked Language Modeling and the Distributional Hypothesis: Order Word Matters Pre-training for Little}, + author={Koustuv Sinha and Robin Jia and Dieuwke Hupkes and Joelle Pineau and Adina Williams and Douwe Kiela}, + year={2021}, + eprint={2104.06644}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` From 5cfd373876ad374139b2de15735a870d4797c606 Mon Sep 17 00:00:00 2001 From: Jingfei Du <jingfeidu@fb.com> Date: Tue, 7 Sep 2021 13:18:00 -0700 Subject: [PATCH 693/707] fix default lprob score of beam 
search with prefix tokens (#2267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting The default score was set to the minimum of all lprobs, which could let beam search select tokens other than the prefix tokens. This change uses a fairly hacky fix to make it smaller than any lprob. - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2267 Reviewed By: myleott Differential Revision: D30730475 Pulled By: jingfeidu fbshipit-source-id: 7dab4e9ed2fc094910467bad776155230987e21a --- fairseq/sequence_generator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 740c32d648..2e61140dd8 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -343,7 +343,6 @@ def _generate( ) probs = probs[:, -1, :] * self.lm_weight lprobs += probs - # handle prefix tokens (possibly with different lengths) if ( prefix_tokens is not None @@ -568,7 +567,7 @@ def _prefix_tokens( prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) prefix_mask = prefix_toks.ne(self.pad) - lprobs[prefix_mask] = torch.min(prefix_lprobs) + lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1 lprobs[prefix_mask] = lprobs[prefix_mask].scatter( -1, 
prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] ) From 9549e7f76994095c92441b81c615a169dc21f478 Mon Sep 17 00:00:00 2001 From: Yunhong Xu <yunhong@fb.com> Date: Tue, 7 Sep 2021 15:08:44 -0700 Subject: [PATCH 694/707] try using gradient_as_bucket_view in DDP Summary: As title Reviewed By: zhengwy888, xiaoxiao26 Differential Revision: D30621478 fbshipit-source-id: d79aba3f98d39a5c46a53bf206522c5f7d05e02a --- fairseq/dataclass/configs.py | 7 +++++++ fairseq/models/distributed_fairseq_model.py | 1 + 2 files changed, 8 insertions(+) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 952f1ec4d1..80caa0f2da 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -306,6 +306,13 @@ class DistributedTrainingConfig(FairseqDataclass): "--ddp-backend=legacy_ddp)" }, ) + gradient_as_bucket_view: bool = field( + default=False, + metadata={ + "help": "when set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. 
" + "--gradient-as-bucket-view=gradient_as_bucket_view)" + }, + ) fast_stat_sync: bool = field( default=False, metadata={"help": "[deprecated] this is now defined per Criterion"}, diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index 06905455fd..5eda227640 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -63,6 +63,7 @@ def DistributedFairseqModel(args, model, process_group, device): bucket_cap_mb=args.bucket_cap_mb, process_group=process_group, find_unused_parameters=args.find_unused_parameters, + gradient_as_bucket_view=args.gradient_as_bucket_view, ) if args.ddp_comm_hook == "fp16": logger.info("enable fp16 communication hook in DDP") From 50b65368639ac25663764e1d4b3cf46b821975ec Mon Sep 17 00:00:00 2001 From: "Yuan Shangguan (June)" <yuansg@fb.com> Date: Wed, 8 Sep 2021 18:16:54 -0700 Subject: [PATCH 695/707] Fairseq needs to store and load metadata from model state_dict Summary: ## TL;DR Fairseq checkpoint saving and loading should mirror torch's checkpoint by saving and loading "state_dict()._metadata". ## Long Story: #### What happened: During model loading and saving, Quantization-aware-training models in Pytorch encounters a weird bug that says state_dict "fake_weight_quant.weight.min_val" is mismatched to "min_vals". #### What was the reason: - We found the issue in that torch uses state_dict()._metadata to store module._version, but the metadata was never store in checkpoint, nor are they loaded during checkpoint loading in fairseq. 
Reviewed By: frankseide Differential Revision: D30649933 fbshipit-source-id: ce262486b9b95fbcece463fa05c4e1903d4232d7 --- fairseq/checkpoint_utils.py | 37 ++++++++++++++++++++++++++++++--- fairseq/models/fairseq_model.py | 32 ++++++++++++++++++++++++++++ fairseq/trainer.py | 9 +++++++- 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index ef5d4c9022..7a494356ac 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss, save_metadata=False): from fairseq import meters # only one worker should attempt to create the required dir @@ -114,7 +114,7 @@ def is_better(a, b): os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: - trainer.save_checkpoint(checkpoints[0], extra_state) + trainer.save_checkpoint(checkpoints[0], extra_state, save_metadata) for cp in checkpoints[1:]: if cfg.write_checkpoints_asynchronously: # TODO[ioPath]: Need to implement a delayed asynchronous @@ -455,8 +455,35 @@ def load_model_ensemble_and_task( else: # model parallel checkpoint or unsharded checkpoint model = task.build_model(cfg.model) + new_state_model = state["model"] + + '''=====The following if-else statement is a work-around ===== + # the current metadata loading/saving of pytorch. + # In Pytorch, if state["model"]["_metadata"] exists as dictionary, then model.load_state_dict(strict=True) + # will throw an error for unexpected "_metadata" key. To avoid this error, we need the state_dict to be + # in orderedDict format, which has new_state_model._metadata attribute but not as key. + # TODO yuansg@ This issue should be fixed in pytorch ideally. 
+ ''' + if new_state_model.get("_metadata", None) is not None: + new_metadata = new_state_model.get("_metadata", None) + del state["model"]["_metadata"] + else: + new_metadata = None + # Construct state dict content. + contents = OrderedDict(new_state_model) + # We explicitly set _metadata for the state models. The _metadata is implicitly stored for pytorch models. + # calling state["model"] in fairseq will not invoke metadata storage. + if new_metadata is None: + logger.warning("===Jit: state[\"model\"] does not contain key \"_metadata\"=====") + logger.warning("===Jit: we will be filling in with current model's meta-data instead.") + # For models trained before this diff, we do the following to be backward compatible. + contents.__setattr__("_metadata", model.state_dict()._metadata) + else: + contents.__setattr__("_metadata", new_metadata) + '''====End of work-around logic=====''' + model.load_state_dict( - state["model"], strict=strict, model_cfg=cfg.model + contents, strict=strict, model_cfg=cfg.model ) # reset state so it gets loaded for the next model in ensemble @@ -683,6 +710,7 @@ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): It's called by functions that load models from checkpoints and does not need to be called directly. """ + state_meta_data = state_dict.get("_metadata", None) arch = None if model_cfg is not None: arch = ( @@ -762,6 +790,9 @@ def create_pruning_pass(layers_to_keep, layer_name): if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None + # Ensure metadata is stored. 
+ if state_meta_data is not None: + new_state_dict["_metadata"] = state_meta_data return new_state_dict diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index e55c7ba1ad..0645208efe 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -22,6 +22,7 @@ from fairseq.models import FairseqDecoder, FairseqEncoder from omegaconf import DictConfig from torch import Tensor +from collections import OrderedDict logger = logging.getLogger(__name__) @@ -122,6 +123,15 @@ def load_state_dict( from fairseq.checkpoint_utils import prune_state_dict new_state_dict = prune_state_dict(state_dict, model_cfg) + # The pytorch assumption of module is that it is an OrderedDict. + # Pytorch also assumes module._metadata exists in the state_dict, + # not as dictionary keys, rather as an attribute of the state dict. + new_state_dict = OrderedDict(new_state_dict) + metadata = new_state_dict.get("_metadata", None) + + if metadata: + del new_state_dict["_metadata"] + new_state_dict.__setattr__("_metadata", metadata) return super().load_state_dict(new_state_dict, strict) def upgrade_state_dict(self, state_dict): @@ -151,6 +161,28 @@ def do_upgrade(m, prefix): do_upgrade(self, name) + def update_metadata(self, model_meta): + """ The model.state_dict()._metadata is stored in a collective location in + state_dict["model"]["_metadata"]. + A pytorch module's _metadata contains the torch modules' versions, which is important + for versionsetting functions. + + During model loading time, we load the model state_dict, but we don't load the state_dict metadata. + This function helps to update the model according to the state_dict["model"]["_metadata"] dump. + InputArgs: + update_metadata: Dict; key is module names, value is {"version", 1} or other metadata. + """ + # Do nothing if the model level metadata is empty. 
+ if model_meta is None: + return + assert isinstance(model_meta, Dict), \ + "Input model_meta from state_dict should be a dictionary. Check state dict." + for key, val in model_meta.items(): + if key is None: # First level set up + self._metadata = val + else: # Subsequent levels of the model + self.get_submodule(key)._metadata = val + def set_num_updates(self, num_updates): """State from trainer to pass along to model at every update.""" for m in self.modules(): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index e46ccfe0b8..a10ec97aa7 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -424,12 +424,18 @@ def state_dict(self): state_dict["fsdp_metadata"] = self.model.local_metadata_dict() return state_dict - def save_checkpoint(self, filename, extra_state): + def save_checkpoint(self, filename, extra_state, save_metadata=False): """Save all training state in a checkpoint file.""" logger.info(f"Saving checkpoint to {filename}") # call state_dict on all ranks in case it needs internal communication state_dict = utils.move_to_cpu(self.state_dict()) state_dict["extra_state"].update(extra_state) + # This should be added because model versions are stored as metadata. + if save_metadata and getattr(self.model.state_dict(), "_metadata", None) is not None: + logger.warning("Trainer: _metadata is inside model.state_dict(). ") + state_dict["model"]["_metadata"] = self.model.state_dict()._metadata + else: + logger.warning("Trainer: _metadata is not saved inside model.state_dict(). 
") if self.should_save_checkpoint_on_current_rank: checkpoint_utils.torch_persistent_save( state_dict, @@ -502,6 +508,7 @@ def load_checkpoint( self.model.load_state_dict( state["model"], strict=True, model_cfg=self.cfg.model ) + self.model.update_metadata(getattr(state["model"], "_metadata", None)) # save memory for later steps del state["model"] if utils.has_parameters(self.get_criterion()): From e3fafbdfc9a39ba8339ebde98aa01c5349ab060d Mon Sep 17 00:00:00 2001 From: Xianfeng Rui <xfrui@fb.com> Date: Thu, 9 Sep 2021 10:09:34 -0700 Subject: [PATCH 696/707] annotation added for jitable Summary: 1) add annotation for encoder_out 2) force dropout to be float for jitable purpose. Reviewed By: cndn Differential Revision: D30826657 fbshipit-source-id: aca79845d7ae48d450b602a7be8f56404f4c7bab --- fairseq/models/lstm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index 12e3aff85d..e1e66a7d50 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -225,10 +225,10 @@ def __init__( super().__init__(dictionary) self.num_layers = num_layers self.dropout_in_module = FairseqDropout( - dropout_in, module_name=self.__class__.__name__ + dropout_in*1.0, module_name=self.__class__.__name__ ) self.dropout_out_module = FairseqDropout( - dropout_out, module_name=self.__class__.__name__ + dropout_out*1.0, module_name=self.__class__.__name__ ) self.bidirectional = bidirectional self.hidden_size = hidden_size @@ -329,7 +329,7 @@ def combine_bidir(self, outs, bsz: int): out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() return out.view(self.num_layers, bsz, -1) - def reorder_encoder_out(self, encoder_out, new_order): + def reorder_encoder_out(self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order): return tuple( ( encoder_out[0].index_select(1, new_order), @@ -402,10 +402,10 @@ def __init__( ): super().__init__(dictionary) self.dropout_in_module = FairseqDropout( 
- dropout_in, module_name=self.__class__.__name__ + dropout_in*1.0, module_name=self.__class__.__name__ ) self.dropout_out_module = FairseqDropout( - dropout_out, module_name=self.__class__.__name__ + dropout_out*1.0, module_name=self.__class__.__name__ ) self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed From dbe1f82fc8c76d7578081b453398f594e2ae671a Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Sun, 12 Sep 2021 20:19:11 -0700 Subject: [PATCH 697/707] add speech synthesis preprocessing and evaluation scripts Summary: [fairseq-py] add speech synthesis preprocessing and evaluation scripts Reviewed By: wnhsu Differential Revision: D30720282 fbshipit-source-id: 6e4b098b6f56fff41b82af4347518d7f7905c801 --- examples/speech_synthesis/README.md | 16 + examples/speech_synthesis/__init__.py | 4 + .../docs/common_voice_example.md | 56 +++ .../speech_synthesis/docs/ljspeech_example.md | 138 +++++ .../speech_synthesis/docs/vctk_example.md | 51 ++ .../speech_synthesis/evaluation/__init__.py | 4 + .../speech_synthesis/evaluation/eval_asr.py | 128 +++++ .../speech_synthesis/evaluation/eval_f0.py | 266 ++++++++++ .../speech_synthesis/evaluation/eval_sp.py | 131 +++++ .../evaluation/get_eval_manifest.py | 58 +++ .../preprocessing/__init__.py | 4 + .../preprocessing/denoise_and_vad_audio.py | 204 ++++++++ .../preprocessing/denoiser/__init__.py | 4 + .../preprocessing/denoiser/demucs.py | 473 ++++++++++++++++++ .../preprocessing/denoiser/pretrained.py | 81 +++ .../preprocessing/denoiser/resample.py | 79 +++ .../preprocessing/denoiser/utils.py | 176 +++++++ .../get_common_voice_audio_manifest.py | 140 ++++++ .../preprocessing/get_feature_manifest.py | 233 +++++++++ .../get_ljspeech_audio_manifest.py | 70 +++ .../preprocessing/get_speaker_embedding.py | 89 ++++ .../preprocessing/get_vctk_audio_manifest.py | 79 +++ .../speaker_embedder/__init__.py | 135 +++++ .../preprocessing/vad/__init__.py | 192 +++++++ 24 files changed, 2811 
insertions(+) create mode 100644 examples/speech_synthesis/README.md create mode 100644 examples/speech_synthesis/__init__.py create mode 100644 examples/speech_synthesis/docs/common_voice_example.md create mode 100644 examples/speech_synthesis/docs/ljspeech_example.md create mode 100644 examples/speech_synthesis/docs/vctk_example.md create mode 100644 examples/speech_synthesis/evaluation/__init__.py create mode 100644 examples/speech_synthesis/evaluation/eval_asr.py create mode 100644 examples/speech_synthesis/evaluation/eval_f0.py create mode 100644 examples/speech_synthesis/evaluation/eval_sp.py create mode 100644 examples/speech_synthesis/evaluation/get_eval_manifest.py create mode 100644 examples/speech_synthesis/preprocessing/__init__.py create mode 100644 examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py create mode 100644 examples/speech_synthesis/preprocessing/denoiser/__init__.py create mode 100644 examples/speech_synthesis/preprocessing/denoiser/demucs.py create mode 100644 examples/speech_synthesis/preprocessing/denoiser/pretrained.py create mode 100644 examples/speech_synthesis/preprocessing/denoiser/resample.py create mode 100644 examples/speech_synthesis/preprocessing/denoiser/utils.py create mode 100644 examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py create mode 100644 examples/speech_synthesis/preprocessing/get_feature_manifest.py create mode 100644 examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py create mode 100644 examples/speech_synthesis/preprocessing/get_speaker_embedding.py create mode 100644 examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py create mode 100644 examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py create mode 100644 examples/speech_synthesis/preprocessing/vad/__init__.py diff --git a/examples/speech_synthesis/README.md b/examples/speech_synthesis/README.md new file mode 100644 index 0000000000..4a3ae54b85 --- /dev/null +++ 
b/examples/speech_synthesis/README.md @@ -0,0 +1,16 @@ +Speech Synthesis (S^2) +=== + +Speech synthesis with fairseq. + +- Autoregressive and non-autoregressive models +- Multi-speaker synthesis +- Audio preprocessing +- Automatic metrics +- Similar data configuration as [S2T](../speech_to_text/README.md) + + +## Examples +- [Single-speaker synthesis on LJSpeech](docs/ljspeech_example.md) +- [Multi-speaker synthesis on VCTK](docs/vctk_example.md) +- [Multi-speaker synthesis on Common Voice](docs/common_voice_example.md) diff --git a/examples/speech_synthesis/__init__.py b/examples/speech_synthesis/__init__.py new file mode 100644 index 0000000000..6264236915 --- /dev/null +++ b/examples/speech_synthesis/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/speech_synthesis/docs/common_voice_example.md b/examples/speech_synthesis/docs/common_voice_example.md new file mode 100644 index 0000000000..40e841b284 --- /dev/null +++ b/examples/speech_synthesis/docs/common_voice_example.md @@ -0,0 +1,56 @@ +[[Back]](..) + +# Common Voice + +[Common Voice](https://commonvoice.mozilla.org/en/datasets) is a public domain speech corpus with 11.2K hours of read +speech in 76 languages (the latest version 7.0). We provide examples for building +[Transformer](https://arxiv.org/abs/1809.08895) models on this dataset. + + +## Data preparation +[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path `${DATA_ROOT}/${LANG_ID}`. 
+Create splits and generate audio manifests with +```bash +python -m examples.speech_synthesis.preprocessing.get_common_voice_audio_manifest \ + --data-root ${DATA_ROOT} \ + --lang ${LANG_ID} \ + --output-manifest-root ${AUDIO_MANIFEST_ROOT} --convert-to-wav +``` + +Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with +```bash +python -m examples.speech_synthesis.preprocessing.get_feature_manifest \ + --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \ + --output-root ${FEATURE_MANIFEST_ROOT} \ + --ipa-vocab --lang ${LANG_ID} +``` +where we use phoneme inputs (`--ipa-vocab`) as example. + +To denoise audio and trim leading/trailing silence using signal processing based VAD, run +```bash +for SPLIT in dev test train; do + python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \ + --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \ + --output-dir ${PROCESSED_DATA_ROOT} \ + --denoise --vad --vad-agg-level 2 +done +``` + + +## Training +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).) + + +## Inference +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).) + +## Automatic Evaluation +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).) + +## Results + +| Language | Speakers | --arch | Params | Test MCD | Model | +|---|---|---|---|---|---| +| English | 200 | tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/cv4_en200_transformer_phn.tar) | + +[[Back]](..) diff --git a/examples/speech_synthesis/docs/ljspeech_example.md b/examples/speech_synthesis/docs/ljspeech_example.md new file mode 100644 index 0000000000..2b8d21abf9 --- /dev/null +++ b/examples/speech_synthesis/docs/ljspeech_example.md @@ -0,0 +1,138 @@ +[[Back]](..) 
+ +# LJSpeech + +[LJSpeech](https://keithito.com/LJ-Speech-Dataset) is a public domain TTS +corpus with around 24 hours of English speech sampled at 22.05kHz. We provide examples for building +[Transformer](https://arxiv.org/abs/1809.08895) and [FastSpeech 2](https://arxiv.org/abs/2006.04558) +models on this dataset. + + +## Data preparation + +Download data, create splits and generate audio manifests with +```bash +python -m examples.speech_synthesis.preprocessing.get_ljspeech_audio_manifest \ + --output-data-root ${AUDIO_DATA_ROOT} \ + --output-manifest-root ${AUDIO_MANIFEST_ROOT} +``` + +Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with +```bash +python -m examples.speech_synthesis.preprocessing.get_feature_manifest \ + --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \ + --output-root ${FEATURE_MANIFEST_ROOT} \ + --ipa-vocab --use-g2p +``` +where we use phoneme inputs (`--ipa-vocab --use-g2p`) as example. + +FastSpeech 2 additionally requires frame durations, pitch and energy as auxiliary training targets. +Add `--add-fastspeech-targets` to include these fields in the feature manifests. We get frame durations either from +phoneme-level force-alignment or frame-level pseudo-text unit sequence. They should be pre-computed and specified via: +- `--textgrid-zip ${TEXT_GRID_ZIP_PATH}` for a ZIP file, inside which there is one + [TextGrid](https://www.fon.hum.uva.nl/praat/manual/TextGrid.html) file per sample to provide force-alignment info. +- `--id-to-units-tsv ${ID_TO_UNIT_TSV}` for a TSV file, where there are 2 columns for sample ID and + space-delimited pseudo-text unit sequence, respectively. 
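
The two-column TSV layout described above (sample ID, then a space-delimited pseudo-text unit sequence) can be illustrated with a small sketch. This is only one plausible way to turn a frame-level unit sequence into per-unit frame durations, by run-length encoding consecutive repeats; it is not claimed to be what `get_feature_manifest` does internally. The helper names `parse_row` and `units_to_durations`, the header-less row format, and the sample row are assumptions made for illustration.

```python
from itertools import groupby
from typing import List, Tuple


def parse_row(line: str) -> Tuple[str, str]:
    """Split one TSV row into (sample ID, unit sequence).

    Assumes exactly two tab-separated columns and no header row.
    """
    sample_id, unit_sequence = line.rstrip("\n").split("\t")
    return sample_id, unit_sequence


def units_to_durations(unit_sequence: str) -> Tuple[List[str], List[int]]:
    """Collapse consecutive repeated units and count the frames in each run."""
    collapsed: List[str] = []
    durations: List[int] = []
    for unit, run in groupby(unit_sequence.split()):
        collapsed.append(unit)
        durations.append(sum(1 for _ in run))
    return collapsed, durations


if __name__ == "__main__":
    # Hypothetical row; real unit IDs depend on the model that produced them.
    sample_id, seq = parse_row("LJ001-0001\t71 71 71 12 12 5\n")
    print(sample_id, units_to_durations(seq))
```

Under these assumptions the durations sum to the number of frames in the sequence, which is a quick sanity check to run against the spectrogram length.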
+
+For your convenience, we provide pre-computed
+[forced alignments](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_mfa.zip) from
+[Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) and
+[pseudo-text units](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_hubert.tsv) from
+[HuBERT](https://github.com/pytorch/fairseq/tree/master/examples/hubert). You can also generate them yourself using
+different software or a different model.
+
+
+## Training
+#### Transformer
+```bash
+fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
+  --config-yaml config.yaml --train-subset train --valid-subset dev \
+  --num-workers 4 --max-tokens 30000 --max-update 200000 \
+  --task text_to_speech --criterion tacotron2 --arch tts_transformer \
+  --clip-norm 5.0 --n-frames-per-step 4 --bce-pos-weight 5.0 \
+  --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \
+  --encoder-normalize-before --decoder-normalize-before \
+  --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+  --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
+```
+where `SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to
+update it accordingly when using more than 1 GPU.
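
The relationship between `--update-freq` and the number of GPUs can be sketched with a little arithmetic. This is a minimal illustration that assumes the goal is to keep `num_gpus * update_freq` equal to the 8-GPU reference setting used above; the helper `update_freq_for` is hypothetical and not part of fairseq.

```python
def update_freq_for(num_gpus: int, reference_gpus: int = 8) -> int:
    """Return the --update-freq value that keeps the effective batch size
    equal to that of `reference_gpus` GPUs with --update-freq 1.

    Assumes every GPU processes the same per-step batch (same --max-tokens).
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if reference_gpus % num_gpus != 0:
        raise ValueError("reference_gpus must be divisible by num_gpus")
    return reference_gpus // num_gpus


if __name__ == "__main__":
    for n in (1, 2, 4, 8):
        print(f"{n} GPU(s): --update-freq {update_freq_for(n)}")
```

With 1 GPU this gives the value 8 used in the command above; with 4 GPUs it would suggest `--update-freq 2`.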
+ +#### FastSpeech2 +```bash +fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \ + --config-yaml config.yaml --train-subset train --valid-subset dev \ + --num-workers 4 --max-sentences 6 --max-update 200000 \ + --task text_to_speech --criterion fastspeech2 --arch fastspeech2 \ + --clip-norm 5.0 --n-frames-per-step 1 \ + --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \ + --encoder-normalize-before --decoder-normalize-before \ + --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss +``` + + +## Inference +Average the last 5 checkpoints, generate the test split spectrogram and waveform using the default Griffin-Lim vocoder: +```bash +SPLIT=test +CHECKPOINT_NAME=avg_last_5 +CHECKPOINT_PATH=${SAVE_DIR}/checkpoint_${CHECKPOINT_NAME}.pt +python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \ + --num-epoch-checkpoints 5 \ + --output ${CHECKPOINT_PATH} + +python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \ + --config-yaml config.yaml --gen-subset ${SPLIT} --task text_to_speech \ + --path ${CHECKPOINT_PATH} --max-tokens 50000 --spec-bwd-max-iter 32 \ + --dump-waveforms +``` +which dumps files (waveform, feature, attention plot, etc.) to `${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT}`. To +re-synthesize target waveforms for automatic evaluation, add `--dump-target`. + +## Automatic Evaluation +To start with, generate the manifest for synthetic speech, which will be taken as inputs by evaluation scripts. 
+```bash +python -m examples.speech_synthesis.evaluation.get_eval_manifest \ + --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \ + --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \ + --output-path ${EVAL_OUTPUT_ROOT}/eval.tsv \ + --vocoder griffin_lim --sample-rate 22050 --audio-format flac \ + --use-resynthesized-target +``` +Speech recognition (ASR) models usually operate at lower sample rates (e.g. 16kHz). For the WER/CER metric, +you may need to resample the audios accordingly --- add `--output-sample-rate 16000` for `generate_waveform.py` and +use `--sample-rate 16000` for `get_eval_manifest.py`. + + +#### WER/CER metric +We use wav2vec 2.0 ASR model as example. [Download](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec) +the model checkpoint and dictionary, then compute WER/CER with +```bash +python -m examples.speech_synthesis.evaluation.eval_asr \ + --audio-header syn --text-header text --err-unit char --split ${SPLIT} \ + --w2v-ckpt ${WAV2VEC2_CHECKPOINT_PATH} --w2v-dict-dir ${WAV2VEC2_DICT_DIR} \ + --raw-manifest ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv --asr-dir ${EVAL_OUTPUT_ROOT}/asr +``` + +#### MCD/MSD metric +```bash +python -m examples.speech_synthesis.evaluation.eval_sp \ + ${EVAL_OUTPUT_ROOT}/eval.tsv --mcd --msd +``` + +#### F0 metrics +```bash +python -m examples.speech_synthesis.evaluation.eval_f0 \ + ${EVAL_OUTPUT_ROOT}/eval.tsv --gpe --vde --ffe +``` + + +## Results + +| --arch | Params | Test MCD | Model | +|---|---|---|---| +| tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_transformer_phn.tar) | +| fastspeech2 | 41M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_fastspeech2_phn.tar) | + +[[Back]](..) 
diff --git a/examples/speech_synthesis/docs/vctk_example.md b/examples/speech_synthesis/docs/vctk_example.md new file mode 100644 index 0000000000..2ba78f3f73 --- /dev/null +++ b/examples/speech_synthesis/docs/vctk_example.md @@ -0,0 +1,51 @@ +[[Back]](..) + +# VCTK + +[VCTK](https://datashare.ed.ac.uk/handle/10283/3443) is an open English speech corpus. We provide examples +for building [Transformer](https://arxiv.org/abs/1809.08895) models on this dataset. + + +## Data preparation +Download data, create splits and generate audio manifests with +```bash +python -m examples.speech_synthesis.preprocessing.get_vctk_audio_manifest \ + --output-data-root ${AUDIO_DATA_ROOT} \ + --output-manifest-root ${AUDIO_MANIFEST_ROOT} +``` + +Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with +```bash +python -m examples.speech_synthesis.preprocessing.get_feature_manifest \ + --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \ + --output-root ${FEATURE_MANIFEST_ROOT} \ + --ipa-vocab --use-g2p +``` +where we use phoneme inputs (`--ipa-vocab --use-g2p`) as example. + +To denoise audio and trim leading/trailing silence using signal processing based VAD, run +```bash +for SPLIT in dev test train; do + python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \ + --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \ + --output-dir ${PROCESSED_DATA_ROOT} \ + --denoise --vad --vad-agg-level 3 +done +``` + +## Training +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).) + +## Inference +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).) + +## Automatic Evaluation +(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).) + +## Results + +| --arch | Params | Test MCD | Model | +|---|---|---|---| +| tts_transformer | 54M | 3.4 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/vctk_transformer_phn.tar) | + +[[Back]](..) 
diff --git a/examples/speech_synthesis/evaluation/__init__.py b/examples/speech_synthesis/evaluation/__init__.py new file mode 100644 index 0000000000..6264236915 --- /dev/null +++ b/examples/speech_synthesis/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/speech_synthesis/evaluation/eval_asr.py b/examples/speech_synthesis/evaluation/eval_asr.py new file mode 100644 index 0000000000..005a11bfb3 --- /dev/null +++ b/examples/speech_synthesis/evaluation/eval_asr.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import editdistance +import re +import shutil +import soundfile as sf +import subprocess +from pathlib import Path + +from examples.speech_to_text.data_utils import load_tsv_to_dicts + + +def preprocess_text(text): + text = "|".join(re.sub(r"[^A-Z' ]", " ", text.upper()).split()) + text = " ".join(text) + return text + + +def prepare_w2v_data( + dict_dir, sample_rate, label, audio_paths, texts, split, data_dir +): + data_dir.mkdir(parents=True, exist_ok=True) + shutil.copyfile( + dict_dir / f"dict.{label}.txt", + data_dir / f"dict.{label}.txt" + ) + with open(data_dir / f"{split}.tsv", "w") as f: + f.write("/\n") + for audio_path in audio_paths: + wav, sr = sf.read(audio_path) + assert sr == sample_rate, f"{sr} != sample_rate" + nsample = len(wav) + f.write(f"{audio_path}\t{nsample}\n") + with open(data_dir / f"{split}.{label}", "w") as f: + for text in texts: + text = preprocess_text(text) + f.write(f"{text}\n") + + +def run_asr(asr_dir, split, w2v_ckpt, w2v_label, res_dir): + """ + results will be saved at + {res_dir}/{ref,hypo}.word-{w2v_ckpt.filename}-{split}.txt + """ + cmd = ["python", "-m", 
"examples.speech_recognition.infer"] + cmd += [str(asr_dir.resolve())] + cmd += ["--task", "audio_finetuning", "--nbest", "1", "--quiet"] + cmd += ["--w2l-decoder", "viterbi", "--criterion", "ctc"] + cmd += ["--post-process", "letter", "--max-tokens", "4000000"] + cmd += ["--path", str(w2v_ckpt.resolve()), "--labels", w2v_label] + cmd += ["--gen-subset", split, "--results-path", str(res_dir.resolve())] + + print(f"running cmd:\n{' '.join(cmd)}") + subprocess.run(cmd, check=True) + + +def compute_error_rate(hyp_wrd_path, ref_wrd_path, unit="word"): + """each line is "<text> (None-<index>)" """ + tokenize_line = { + "word": lambda x: re.sub(r" \(.*\)$", "", x.rstrip()).split(), + "char": lambda x: list(re.sub(r" \(.*\)$", "", x.rstrip())) + }.get(unit) + if tokenize_line is None: + raise ValueError(f"{unit} not supported") + + inds = [int(re.sub(r"\D*(\d*)\D*", r"\1", line)) + for line in open(hyp_wrd_path)] + hyps = [tokenize_line(line) for line in open(hyp_wrd_path)] + refs = [tokenize_line(line) for line in open(ref_wrd_path)] + assert(len(hyps) == len(refs)) + err_rates = [ + editdistance.eval(hyp, ref) / len(ref) for hyp, ref in zip(hyps, refs) + ] + ind_to_err_rates = {i: e for i, e in zip(inds, err_rates)} + return ind_to_err_rates + + +def main(args): + samples = load_tsv_to_dicts(args.raw_manifest) + ids = [ + sample[args.id_header] if args.id_header else "" for sample in samples + ] + audio_paths = [sample[args.audio_header] for sample in samples] + texts = [sample[args.text_header] for sample in samples] + + prepare_w2v_data( + args.w2v_dict_dir, + args.w2v_sample_rate, + args.w2v_label, + audio_paths, + texts, + args.split, + args.asr_dir + ) + run_asr(args.asr_dir, args.split, args.w2v_ckpt, args.w2v_label, args.asr_dir) + ind_to_err_rates = compute_error_rate( + args.asr_dir / f"hypo.word-{args.w2v_ckpt.name}-{args.split}.txt", + args.asr_dir / f"ref.word-{args.w2v_ckpt.name}-{args.split}.txt", + args.err_unit, + ) + + uer_path = args.asr_dir / 
f"uer_{args.err_unit}.{args.split}.tsv" + with open(uer_path, "w") as f: + f.write("id\taudio\tuer\n") + for ind, (id_, audio_path) in enumerate(zip(ids, audio_paths)): + f.write(f"{id_}\t{audio_path}\t{ind_to_err_rates[ind]:.4f}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--raw-manifest", required=True, type=Path) + parser.add_argument("--asr-dir", required=True, type=Path) + parser.add_argument("--id-header", default="id", type=str) + parser.add_argument("--audio-header", default="audio", type=str) + parser.add_argument("--text-header", default="src_text", type=str) + parser.add_argument("--split", default="raw", type=str) + parser.add_argument("--w2v-ckpt", required=True, type=Path) + parser.add_argument("--w2v-dict-dir", required=True, type=Path) + parser.add_argument("--w2v-sample-rate", default=16000, type=int) + parser.add_argument("--w2v-label", default="ltr", type=str) + parser.add_argument("--err-unit", default="word", type=str) + args = parser.parse_args() + + main(args) diff --git a/examples/speech_synthesis/evaluation/eval_f0.py b/examples/speech_synthesis/evaluation/eval_f0.py new file mode 100644 index 0000000000..df721d6831 --- /dev/null +++ b/examples/speech_synthesis/evaluation/eval_f0.py @@ -0,0 +1,266 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Signal processing-based evaluation using waveforms +""" +import numpy as np +import os.path as op + +import torchaudio +import tqdm +from tabulate import tabulate + +from examples.speech_synthesis.utils import ( + gross_pitch_error, voicing_decision_error, f0_frame_error +) +from examples.speech_synthesis.evaluation.eval_sp import load_eval_spec + + +def difference_function(x, n, tau_max): + """ + Compute difference function of data x. This solution is implemented directly + with Numpy fft. 
+ + + :param x: audio data + :param n: length of data + :param tau_max: integration window size + :return: difference function + :rtype: list + """ + + x = np.array(x, np.float64) + w = x.size + tau_max = min(tau_max, w) + x_cumsum = np.concatenate((np.array([0.]), (x * x).cumsum())) + size = w + tau_max + p2 = (size // 32).bit_length() + nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32) + size_pad = min(x * 2 ** p2 for x in nice_numbers if x * 2 ** p2 >= size) + fc = np.fft.rfft(x, size_pad) + conv = np.fft.irfft(fc * fc.conjugate())[:tau_max] + return x_cumsum[w:w - tau_max:-1] + x_cumsum[w] - x_cumsum[:tau_max] - \ + 2 * conv + + +def cumulative_mean_normalized_difference_function(df, n): + """ + Compute cumulative mean normalized difference function (CMND). + + :param df: Difference function + :param n: length of data + :return: cumulative mean normalized difference function + :rtype: list + """ + + # scipy method + cmn_df = df[1:] * range(1, n) / np.cumsum(df[1:]).astype(float) + return np.insert(cmn_df, 0, 1) + + +def get_pitch(cmdf, tau_min, tau_max, harmo_th=0.1): + """ + Return fundamental period of a frame based on CMND function. + + :param cmdf: Cumulative Mean Normalized Difference function + :param tau_min: minimum period for speech + :param tau_max: maximum period for speech + :param harmo_th: harmonicity threshold to determine if it is necessary to + compute pitch frequency + :return: fundamental period if there is values under threshold, 0 otherwise + :rtype: float + """ + tau = tau_min + while tau < tau_max: + if cmdf[tau] < harmo_th: + while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]: + tau += 1 + return tau + tau += 1 + + return 0 # if unvoiced + + +def compute_yin(sig, sr, w_len=512, w_step=256, f0_min=100, f0_max=500, + harmo_thresh=0.1): + """ + + Compute the Yin Algorithm. Return fundamental frequency and harmonic rate. 
+ + https://github.com/NVIDIA/mellotron adaption of + https://github.com/patriceguyot/Yin + + :param sig: Audio signal (list of float) + :param sr: sampling rate (int) + :param w_len: size of the analysis window (samples) + :param w_step: size of the lag between two consecutive windows (samples) + :param f0_min: Minimum fundamental frequency that can be detected (hertz) + :param f0_max: Maximum fundamental frequency that can be detected (hertz) + :param harmo_thresh: Threshold of detection. The algorithm returns the + first minimum of the CMND function below this threshold. + + :returns: + + * pitches: list of fundamental frequencies, + * harmonic_rates: list of harmonic rate values for each fundamental + frequency value (= confidence value) + * argmins: minimums of the Cumulative Mean Normalized Difference Function + * times: list of time of each estimation + :rtype: tuple + """ + + tau_min = int(sr / f0_max) + tau_max = int(sr / f0_min) + + # time values for each analysis window + time_scale = range(0, len(sig) - w_len, w_step) + times = [t/float(sr) for t in time_scale] + frames = [sig[t:t + w_len] for t in time_scale] + + pitches = [0.0] * len(time_scale) + harmonic_rates = [0.0] * len(time_scale) + argmins = [0.0] * len(time_scale) + + for i, frame in enumerate(frames): + # Compute YIN + df = difference_function(frame, w_len, tau_max) + cm_df = cumulative_mean_normalized_difference_function(df, tau_max) + p = get_pitch(cm_df, tau_min, tau_max, harmo_thresh) + + # Get results + if np.argmin(cm_df) > tau_min: + argmins[i] = float(sr / np.argmin(cm_df)) + if p != 0: # A pitch was found + pitches[i] = float(sr / p) + harmonic_rates[i] = cm_df[p] + else: # No pitch, but we compute a value of the harmonic rate + harmonic_rates[i] = min(cm_df) + + return pitches, harmonic_rates, argmins, times + + +def extract_f0(samples): + f0_samples = [] + for sample in tqdm.tqdm(samples): + if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]): +
f0_samples.append(None) + continue + + # assume single channel + yref, sr = torchaudio.load(sample["ref"]) + ysyn, _sr = torchaudio.load(sample["syn"]) + yref, ysyn = yref[0], ysyn[0] + assert sr == _sr, f"{sr} != {_sr}" + + yref_f0 = compute_yin(yref, sr) + ysyn_f0 = compute_yin(ysyn, sr) + + f0_samples += [ + { + "ref": yref_f0, + "syn": ysyn_f0 + } + ] + + return f0_samples + + +def eval_f0_error(samples, distortion_fn): + results = [] + for sample in tqdm.tqdm(samples): + if sample is None: + results.append(None) + continue + # assume single channel + yref_f, _, _, yref_t = sample["ref"] + ysyn_f, _, _, ysyn_t = sample["syn"] + + yref_f = np.array(yref_f) + yref_t = np.array(yref_t) + ysyn_f = np.array(ysyn_f) + ysyn_t = np.array(ysyn_t) + + distortion = distortion_fn(yref_t, yref_f, ysyn_t, ysyn_f) + results.append((distortion.item(), + len(yref_f), + len(ysyn_f) + )) + return results + + +def eval_gross_pitch_error(samples): + return eval_f0_error(samples, gross_pitch_error) + + +def eval_voicing_decision_error(samples): + return eval_f0_error(samples, voicing_decision_error) + + +def eval_f0_frame_error(samples): + return eval_f0_error(samples, f0_frame_error) + + +def print_results(results, show_bin): + results = np.array(list(filter(lambda x: x is not None, results))) + + np.set_printoptions(precision=3) + + def _print_result(results): + res = { + "nutt": len(results), + "error": results[:, 0].mean(), + "std": results[:, 0].std(), + "dur_ref": int(results[:, 1].sum()), + "dur_syn": int(results[:, 2].sum()), + } + print(tabulate([res.values()], res.keys(), floatfmt=".4f")) + + print(">>>> ALL") + _print_result(results) + + if show_bin: + edges = [0, 200, 400, 600, 800, 1000, 2000, 4000] + for i in range(1, len(edges)): + mask = np.logical_and(results[:, 1] >= edges[i-1], + results[:, 1] < edges[i]) + if not mask.any(): + continue + bin_results = results[mask] + print(f">>>> ({edges[i-1]}, {edges[i]})") + _print_result(bin_results) + + +def main(eval_f0, 
gpe, vde, ffe, show_bin): + samples = load_eval_spec(eval_f0) + if gpe or vde or ffe: + f0_samples = extract_f0(samples) + + if gpe: + print("===== Evaluate Gross Pitch Error =====") + results = eval_gross_pitch_error(f0_samples) + print_results(results, show_bin) + if vde: + print("===== Evaluate Voicing Decision Error =====") + results = eval_voicing_decision_error(f0_samples) + print_results(results, show_bin) + if ffe: + print("===== Evaluate F0 Frame Error =====") + results = eval_f0_frame_error(f0_samples) + print_results(results, show_bin) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("eval_f0") + parser.add_argument("--gpe", action="store_true") + parser.add_argument("--vde", action="store_true") + parser.add_argument("--ffe", action="store_true") + parser.add_argument("--show-bin", action="store_true") + args = parser.parse_args() + + main(args.eval_f0, args.gpe, args.vde, args.ffe, args.show_bin) diff --git a/examples/speech_synthesis/evaluation/eval_sp.py b/examples/speech_synthesis/evaluation/eval_sp.py new file mode 100644 index 0000000000..702c498038 --- /dev/null +++ b/examples/speech_synthesis/evaluation/eval_sp.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +""" +Signal processing-based evaluation using waveforms +""" + +import csv +import numpy as np +import os.path as op + +import torch +import tqdm +from tabulate import tabulate +import torchaudio + +from examples.speech_synthesis.utils import batch_mel_spectral_distortion +from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion + + +def load_eval_spec(path): + with open(path) as f: + reader = csv.DictReader(f, delimiter='\t') + samples = list(reader) + return samples + + +def eval_distortion(samples, distortion_fn, device="cuda"): + nmiss = 0 + results = [] + for sample in tqdm.tqdm(samples): + if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]): + nmiss += 1 + results.append(None) + continue + # assume single channel + yref, sr = torchaudio.load(sample["ref"]) + ysyn, _sr = torchaudio.load(sample["syn"]) + yref, ysyn = yref[0].to(device), ysyn[0].to(device) + assert sr == _sr, f"{sr} != {_sr}" + + distortion, extra = distortion_fn([yref], [ysyn], sr, None)[0] + _, _, _, _, _, pathmap = extra + nins = torch.sum(pathmap.sum(dim=1) - 1) # extra frames in syn + ndel = torch.sum(pathmap.sum(dim=0) - 1) # missing frames from syn + results.append( + (distortion.item(), # path distortion + pathmap.size(0), # yref num frames + pathmap.size(1), # ysyn num frames + pathmap.sum().item(), # path length + nins.item(), # insertion + ndel.item(), # deletion + ) + ) + return results + + +def eval_mel_cepstral_distortion(samples, device="cuda"): + return eval_distortion(samples, batch_mel_cepstral_distortion, device) + + +def eval_mel_spectral_distortion(samples, device="cuda"): + return eval_distortion(samples, batch_mel_spectral_distortion, device) + + +def print_results(results, show_bin): + results = np.array(list(filter(lambda x: x is not None, results))) + + np.set_printoptions(precision=3) + + def _print_result(results): + dist, dur_ref, dur_syn, dur_ali, nins, ndel = results.sum(axis=0) + res = { + "nutt": len(results), + "dist": dist, + 
"dur_ref": int(dur_ref), + "dur_syn": int(dur_syn), + "dur_ali": int(dur_ali), + "dist_per_ref_frm": dist/dur_ref, + "dist_per_syn_frm": dist/dur_syn, + "dist_per_ali_frm": dist/dur_ali, + "ins": nins/dur_ref, + "del": ndel/dur_ref, + } + print(tabulate( + [res.values()], + res.keys(), + floatfmt=".4f" + )) + + print(">>>> ALL") + _print_result(results) + + if show_bin: + edges = [0, 200, 400, 600, 800, 1000, 2000, 4000] + for i in range(1, len(edges)): + mask = np.logical_and(results[:, 1] >= edges[i-1], + results[:, 1] < edges[i]) + if not mask.any(): + continue + bin_results = results[mask] + print(f">>>> ({edges[i-1]}, {edges[i]})") + _print_result(bin_results) + + +def main(eval_spec, mcd, msd, show_bin): + samples = load_eval_spec(eval_spec) + device = "cpu" + if mcd: + print("===== Evaluate Mean Cepstral Distortion =====") + results = eval_mel_cepstral_distortion(samples, device) + print_results(results, show_bin) + if msd: + print("===== Evaluate Mean Spectral Distortion =====") + results = eval_mel_spectral_distortion(samples, device) + print_results(results, show_bin) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("eval_spec") + parser.add_argument("--mcd", action="store_true") + parser.add_argument("--msd", action="store_true") + parser.add_argument("--show-bin", action="store_true") + args = parser.parse_args() + + main(args.eval_spec, args.mcd, args.msd, args.show_bin) diff --git a/examples/speech_synthesis/evaluation/get_eval_manifest.py b/examples/speech_synthesis/evaluation/get_eval_manifest.py new file mode 100644 index 0000000000..a28cd607a0 --- /dev/null +++ b/examples/speech_synthesis/evaluation/get_eval_manifest.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import csv +from pathlib import Path + + +def main(args): + """ + `uid syn ref text` + """ + in_root = Path(args.generation_root).resolve() + ext = args.audio_format + with open(args.audio_manifest) as f, open(args.output_path, "w") as f_out: + reader = csv.DictReader( + f, delimiter="\t", quotechar=None, doublequote=False, + lineterminator="\n", quoting=csv.QUOTE_NONE + ) + header = ["id", "syn", "ref", "text", "speaker"] + f_out.write("\t".join(header) + "\n") + for row in reader: + dir_name = f"{ext}_{args.sample_rate}hz_{args.vocoder}" + id_ = row["id"] + syn = (in_root / dir_name / f"{id_}.{ext}").as_posix() + ref = row["audio"] + if args.use_resynthesized_target: + ref = (in_root / f"{dir_name}_tgt" / f"{id_}.{ext}").as_posix() + sample = [id_, syn, ref, row["tgt_text"], row["speaker"]] + f_out.write("\t".join(sample) + "\n") + print(f"wrote evaluation file to {args.output_path}") + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--generation-root", help="output directory for generate_waveform.py" + ) + parser.add_argument( + "--audio-manifest", + help="used to determine the original utterance ID and text" + ) + parser.add_argument( + "--output-path", help="path to output evaluation spec file" + ) + parser.add_argument( + "--use-resynthesized-target", action="store_true", + help="use resynthesized reference instead of the original audio" + ) + parser.add_argument("--vocoder", type=str, default="griffin_lim") + parser.add_argument("--sample-rate", type=int, default=22_050) + parser.add_argument("--audio-format", type=str, default="wav") + args = parser.parse_args() + + main(args) diff --git a/examples/speech_synthesis/preprocessing/__init__.py b/examples/speech_synthesis/preprocessing/__init__.py new file mode 100644 index 0000000000..6264236915 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py b/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py new file mode 100644 index 0000000000..4e13b38a5d --- /dev/null +++ b/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py @@ -0,0 +1,204 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +import csv +import tempfile +from collections import defaultdict +from pathlib import Path + +import torchaudio +try: + import webrtcvad +except ImportError: + raise ImportError("Please install py-webrtcvad: pip install webrtcvad") +import pandas as pd +from tqdm import tqdm + +from examples.speech_synthesis.preprocessing.denoiser.pretrained import master64 +import examples.speech_synthesis.preprocessing.denoiser.utils as utils +from examples.speech_synthesis.preprocessing.vad import ( + frame_generator, vad_collector, read_wave, write_wave, FS_MS, THRESHOLD, + SCALE +) +from examples.speech_to_text.data_utils import save_df_to_tsv + + +log = logging.getLogger(__name__) + +PATHS = ["after_denoise", "after_vad"] +MIN_T = 0.05 + + +def generate_tmp_filename(extension="txt"): + return tempfile._get_default_tempdir() + "/" + \ + next(tempfile._get_candidate_names()) + "." 
+ extension + + +def convert_sr(inpath, sr, output_path=None): + if not output_path: + output_path = generate_tmp_filename("wav") + cmd = f"sox {inpath} -r {sr} {output_path}" + os.system(cmd) + return output_path + + +def apply_vad(vad, inpath): + audio, sample_rate = read_wave(inpath) + frames = frame_generator(FS_MS, audio, sample_rate) + frames = list(frames) + segments = vad_collector(sample_rate, FS_MS, 300, vad, frames) + merge_segments = list() + timestamp_start = 0.0 + timestamp_end = 0.0 + # removing start, end, and long sequences of sils + for i, segment in enumerate(segments): + merge_segments.append(segment[0]) + if i and timestamp_start: + sil_duration = segment[1] - timestamp_end + if sil_duration > THRESHOLD: + merge_segments.append(int(THRESHOLD / SCALE) * (b'\x00')) + else: + merge_segments.append(int((sil_duration / SCALE)) * (b'\x00')) + timestamp_start = segment[1] + timestamp_end = segment[2] + segment = b''.join(merge_segments) + return segment, sample_rate + + +def write(wav, filename, sr=16_000): + # Normalize audio if it prevents clipping + wav = wav / max(wav.abs().max().item(), 1) + torchaudio.save(filename, wav.cpu(), sr, encoding="PCM_S", + bits_per_sample=16) + + +def process(args): + # making sure we are requested either denoise or vad + if not args.denoise and not args.vad: + log.error("No denoise or vad is requested.") + return + + log.info("Creating out directories...") + if args.denoise: + out_denoise = Path(args.output_dir).absolute().joinpath(PATHS[0]) + out_denoise.mkdir(parents=True, exist_ok=True) + if args.vad: + out_vad = Path(args.output_dir).absolute().joinpath(PATHS[1]) + out_vad.mkdir(parents=True, exist_ok=True) + + log.info("Loading pre-trained speech enhancement model...") + model = master64().to(args.device) + + log.info("Building the VAD model...") + vad = webrtcvad.Vad(int(args.vad_agg_level)) + + # preparing the output dict + output_dict = defaultdict(list) + + log.info(f"Parsing input manifest: 
{args.audio_manifest}") + with open(args.audio_manifest, "r") as f: + manifest_dict = csv.DictReader(f, delimiter="\t") + for row in tqdm(manifest_dict): + filename = str(row["audio"]) + + final_output = filename + keep_sample = True + n_frames = row["n_frames"] + snr = -1 + if args.denoise: + output_path_denoise = out_denoise.joinpath(Path(filename).name) + # convert to 16khz in case we use a different sr + tmp_path = convert_sr(final_output, 16000) + + # loading audio file and generating the enhanced version + out, sr = torchaudio.load(tmp_path) + out = out.to(args.device) + estimate = model(out) + estimate = (1 - args.dry_wet) * estimate + args.dry_wet * out + write(estimate[0], str(output_path_denoise), sr) + + snr = utils.cal_snr(out, estimate) + snr = snr.cpu().detach().numpy()[0][0] + final_output = str(output_path_denoise) + + if args.vad: + output_path_vad = out_vad.joinpath(Path(filename).name) + sr = torchaudio.info(final_output).sample_rate + if sr in [16000, 32000, 48000]: + tmp_path = final_output + elif sr < 16000: + tmp_path = convert_sr(final_output, 16000) + elif sr < 32000: + tmp_path = convert_sr(final_output, 32000) + else: + tmp_path = convert_sr(final_output, 48000) + # apply VAD + segment, sample_rate = apply_vad(vad, tmp_path) + if len(segment) < sample_rate * MIN_T: + keep_sample = False + print(( + f"WARNING: skip {filename} because it is too short " + f"after VAD ({len(segment) / sample_rate} < {MIN_T})" + )) + else: + if sample_rate != sr: + tmp_path = generate_tmp_filename("wav") + write_wave(tmp_path, segment, sample_rate) + convert_sr(tmp_path, sr, + output_path=str(output_path_vad)) + else: + write_wave(str(output_path_vad), segment, sample_rate) + final_output = str(output_path_vad) + segment, _ = torchaudio.load(final_output) + n_frames = segment.size(1) + + if keep_sample: + output_dict["id"].append(row["id"]) + output_dict["audio"].append(final_output) + output_dict["n_frames"].append(n_frames) +
output_dict["tgt_text"].append(row["tgt_text"]) + output_dict["speaker"].append(row["speaker"]) + output_dict["src_text"].append(row["src_text"]) + output_dict["snr"].append(snr) + + out_tsv_path = Path(args.output_dir) / Path(args.audio_manifest).name + log.info(f"Saving manifest to {out_tsv_path.as_posix()}") + save_df_to_tsv(pd.DataFrame.from_dict(output_dict), out_tsv_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--audio-manifest", "-i", required=True, + type=str, help="path to the input manifest.") + parser.add_argument( + "--output-dir", "-o", required=True, type=str, + help="path to the output dir. it will contain files after denoising and" + " vad" + ) + parser.add_argument("--vad-agg-level", "-a", type=int, default=2, + help="the aggressiveness level of the VAD [0-3].") + parser.add_argument( + "--dry-wet", "-dw", type=float, default=0.01, + help="the level of linear interpolation between noisy and enhanced " + "files." + ) + parser.add_argument( + "--device", "-d", type=str, default="cpu", + help="the device to be used for the speech enhancement model: " + "cpu | cuda." + ) + parser.add_argument("--denoise", action="store_true", + help="apply denoising") + parser.add_argument("--vad", action="store_true", help="apply a VAD") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_synthesis/preprocessing/denoiser/__init__.py b/examples/speech_synthesis/preprocessing/denoiser/__init__.py new file mode 100644 index 0000000000..6264236915 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/denoiser/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
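Note on `--dry-wet` in `denoise_and_vad_audio.py` above: the value is the weight given to the original (noisy) signal when it is mixed back into the denoiser output. A value of 0 keeps only the enhanced audio and 1 returns the input unchanged. The following is a minimal sketch of that interpolation on plain Python lists, assuming equal-length signals. `mix_dry_wet` is an illustrative name; the script applies the same formula directly to torch tensors.

```python
def mix_dry_wet(noisy, enhanced, dry_wet=0.01):
    """Blend the noisy input back into the enhanced signal.

    Follows the script's formula:
        (1 - dry_wet) * enhanced + dry_wet * noisy
    """
    if len(noisy) != len(enhanced):
        raise ValueError("noisy and enhanced must have the same length")
    return [(1 - dry_wet) * e + dry_wet * n for n, e in zip(noisy, enhanced)]


if __name__ == "__main__":
    noisy = [1.0, 2.0, -1.0]
    enhanced = [0.0, 0.0, 0.0]
    print(mix_dry_wet(noisy, enhanced, dry_wet=0.5))
```

The script's default of 0.01 leaves the denoiser output almost untouched and adds back only a small fraction of the input.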
diff --git a/examples/speech_synthesis/preprocessing/denoiser/demucs.py b/examples/speech_synthesis/preprocessing/denoiser/demucs.py new file mode 100644 index 0000000000..3f70e73d6a --- /dev/null +++ b/examples/speech_synthesis/preprocessing/denoiser/demucs.py @@ -0,0 +1,473 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import math +import time + +import torch as th +from torch import nn +from torch.nn import functional as F + +from .resample import downsample2, upsample2 +from .utils import capture_init + + +class BLSTM(nn.Module): + def __init__(self, dim, layers=2, bi=True): + super().__init__() + klass = nn.LSTM + self.lstm = klass( + bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim + ) + self.linear = None + if bi: + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x, hidden=None): + x, hidden = self.lstm(x, hidden) + if self.linear: + x = self.linear(x) + return x, hidden + + +def rescale_conv(conv, reference): + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): + rescale_conv(sub, reference) + + +class Demucs(nn.Module): + """ + Demucs speech enhancement model. + Args: + - chin (int): number of input channels. + - chout (int): number of output channels. + - hidden (int): number of initial hidden channels. + - depth (int): number of layers. + - kernel_size (int): kernel size for each layer. + - stride (int): stride for each layer. + - causal (bool): if false, uses BiLSTM instead of LSTM. + - resample (int): amount of resampling to apply to the input/output. + Can be one of 1, 2 or 4. 
+ - growth (float): number of channels is multiplied by this for every layer. + - max_hidden (int): maximum number of channels. Can be useful to + control the size/speed of the model. + - normalize (bool): if true, normalize the input. + - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions. + - rescale (float): controls custom weight initialization. + See https://arxiv.org/abs/1911.13254. + - floor (float): stability flooring when normalizing. + + """ + @capture_init + def __init__(self, + chin=1, + chout=1, + hidden=48, + depth=5, + kernel_size=8, + stride=4, + causal=True, + resample=4, + growth=2, + max_hidden=10_000, + normalize=True, + glu=True, + rescale=0.1, + floor=1e-3): + + super().__init__() + if resample not in [1, 2, 4]: + raise ValueError("Resample should be 1, 2 or 4.") + + self.chin = chin + self.chout = chout + self.hidden = hidden + self.depth = depth + self.kernel_size = kernel_size + self.stride = stride + self.causal = causal + self.floor = floor + self.resample = resample + self.normalize = normalize + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + activation = nn.GLU(1) if glu else nn.ReLU() + ch_scale = 2 if glu else 1 + + for index in range(depth): + encode = [] + encode += [ + nn.Conv1d(chin, hidden, kernel_size, stride), + nn.ReLU(), + nn.Conv1d(hidden, hidden * ch_scale, 1), activation, + ] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + decode += [ + nn.Conv1d(hidden, ch_scale * hidden, 1), activation, + nn.ConvTranspose1d(hidden, chout, kernel_size, stride), + ] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + chout = hidden + chin = hidden + hidden = min(int(growth * hidden), max_hidden) + + self.lstm = BLSTM(chin, bi=not causal) + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a 
convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length. + """ + length = math.ceil(length * self.resample) + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(length, 1) + for _ in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + length = int(math.ceil(length / self.resample)) + return int(length) + + @property + def total_stride(self): + return self.stride ** self.depth // self.resample + + def forward(self, mix): + if mix.dim() == 2: + mix = mix.unsqueeze(1) + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + mix = mix / (self.floor + std) + else: + std = 1 + length = mix.shape[-1] + x = mix + x = F.pad(x, (0, self.valid_length(length) - length)) + if self.resample == 2: + x = upsample2(x) + elif self.resample == 4: + x = upsample2(x) + x = upsample2(x) + skips = [] + for encode in self.encoder: + x = encode(x) + skips.append(x) + x = x.permute(2, 0, 1) + x, _ = self.lstm(x) + x = x.permute(1, 2, 0) + for decode in self.decoder: + skip = skips.pop(-1) + x = x + skip[..., :x.shape[-1]] + x = decode(x) + if self.resample == 2: + x = downsample2(x) + elif self.resample == 4: + x = downsample2(x) + x = downsample2(x) + + x = x[..., :length] + return std * x + + +def fast_conv(conv, x): + """ + Faster convolution evaluation if either kernel size is 1 + or length of sequence is 1. 
+ """ + batch, chin, length = x.shape + chout, chin, kernel = conv.weight.shape + assert batch == 1 + if kernel == 1: + x = x.view(chin, length) + out = th.addmm(conv.bias.view(-1, 1), + conv.weight.view(chout, chin), x) + elif length == kernel: + x = x.view(chin * kernel, 1) + out = th.addmm(conv.bias.view(-1, 1), + conv.weight.view(chout, chin * kernel), x) + else: + out = conv(x) + return out.view(batch, chout, -1) + + +class DemucsStreamer: + """ + Streaming implementation for Demucs. It supports being fed with any amount + of audio at a time. You will get back as much audio as possible at that + point. + + Args: + - demucs (Demucs): Demucs model. + - dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum + noise removal, 1 just returns the input signal. Small values > 0 + allows to limit distortions. + - num_frames (int): number of frames to process at once. Higher values + will increase overall latency but improve the real time factor. + - resample_lookahead (int): extra lookahead used for the resampling. + - resample_buffer (int): size of the buffer of previous inputs/outputs + kept for resampling. 
+ """ + def __init__(self, demucs, + dry=0, + num_frames=1, + resample_lookahead=64, + resample_buffer=256): + device = next(iter(demucs.parameters())).device + self.demucs = demucs + self.lstm_state = None + self.conv_state = None + self.dry = dry + self.resample_lookahead = resample_lookahead + resample_buffer = min(demucs.total_stride, resample_buffer) + self.resample_buffer = resample_buffer + self.frame_length = demucs.valid_length(1) + \ + demucs.total_stride * (num_frames - 1) + self.total_length = self.frame_length + self.resample_lookahead + self.stride = demucs.total_stride * num_frames + self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device) + self.resample_out = th.zeros( + demucs.chin, resample_buffer, device=device + ) + + self.frames = 0 + self.total_time = 0 + self.variance = 0 + self.pending = th.zeros(demucs.chin, 0, device=device) + + bias = demucs.decoder[0][2].bias + weight = demucs.decoder[0][2].weight + chin, chout, kernel = weight.shape + self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1) + self._weight = weight.permute(1, 2, 0).contiguous() + + def reset_time_per_frame(self): + self.total_time = 0 + self.frames = 0 + + @property + def time_per_frame(self): + return self.total_time / self.frames + + def flush(self): + """ + Flush remaining audio by padding it with zero. Call this + when you have no more input and want to get back the last chunk of audio. + """ + pending_length = self.pending.shape[1] + padding = th.zeros( + self.demucs.chin, self.total_length, device=self.pending.device + ) + out = self.feed(padding) + return out[:, :pending_length] + + def feed(self, wav): + """ + Apply the model to mix using true real time evaluation. + Normalization is done online as is the resampling. 
+ """ + begin = time.time() + demucs = self.demucs + resample_buffer = self.resample_buffer + stride = self.stride + resample = demucs.resample + + if wav.dim() != 2: + raise ValueError("input wav should be two dimensional.") + chin, _ = wav.shape + if chin != demucs.chin: + raise ValueError(f"Expected {demucs.chin} channels, got {chin}") + + self.pending = th.cat([self.pending, wav], dim=1) + outs = [] + while self.pending.shape[1] >= self.total_length: + self.frames += 1 + frame = self.pending[:, :self.total_length] + dry_signal = frame[:, :stride] + if demucs.normalize: + mono = frame.mean(0) + variance = (mono**2).mean() + self.variance = variance / self.frames + \ + (1 - 1 / self.frames) * self.variance + frame = frame / (demucs.floor + math.sqrt(self.variance)) + frame = th.cat([self.resample_in, frame], dim=-1) + self.resample_in[:] = frame[:, stride - resample_buffer:stride] + + if resample == 4: + frame = upsample2(upsample2(frame)) + elif resample == 2: + frame = upsample2(frame) + # remove pre sampling buffer + frame = frame[:, resample * resample_buffer:] + # remove extra samples after window + frame = frame[:, :resample * self.frame_length] + + out, extra = self._separate_frame(frame) + padded_out = th.cat([self.resample_out, out, extra], 1) + self.resample_out[:] = out[:, -resample_buffer:] + if resample == 4: + out = downsample2(downsample2(padded_out)) + elif resample == 2: + out = downsample2(padded_out) + else: + out = padded_out + + out = out[:, resample_buffer // resample:] + out = out[:, :stride] + + if demucs.normalize: + out *= math.sqrt(self.variance) + out = self.dry * dry_signal + (1 - self.dry) * out + outs.append(out) + self.pending = self.pending[:, stride:] + + self.total_time += time.time() - begin + if outs: + out = th.cat(outs, 1) + else: + out = th.zeros(chin, 0, device=wav.device) + return out + + def _separate_frame(self, frame): + demucs = self.demucs + skips = [] + next_state = [] + first = self.conv_state is None + stride = 
self.stride * demucs.resample + x = frame[None] + for idx, encode in enumerate(demucs.encoder): + stride //= demucs.stride + length = x.shape[2] + if idx == demucs.depth - 1: + # This is slightly faster for the last conv + x = fast_conv(encode[0], x) + x = encode[1](x) + x = fast_conv(encode[2], x) + x = encode[3](x) + else: + if not first: + prev = self.conv_state.pop(0) + prev = prev[..., stride:] + tgt = (length - demucs.kernel_size) // demucs.stride + 1 + missing = tgt - prev.shape[-1] + offset = length - demucs.kernel_size - \ + demucs.stride * (missing - 1) + x = x[..., offset:] + x = encode[1](encode[0](x)) + x = fast_conv(encode[2], x) + x = encode[3](x) + if not first: + x = th.cat([prev, x], -1) + next_state.append(x) + skips.append(x) + + x = x.permute(2, 0, 1) + x, self.lstm_state = demucs.lstm(x, self.lstm_state) + x = x.permute(1, 2, 0) + # In the following, x contains only correct samples, i.e. the ones + # for which each time position is covered by two windows of the upper + # layer. extra contains extra samples to the right, and is used only as + # a better padding for the online resampling. 
+ extra = None + for idx, decode in enumerate(demucs.decoder): + skip = skips.pop(-1) + x += skip[..., :x.shape[-1]] + x = fast_conv(decode[0], x) + x = decode[1](x) + + if extra is not None: + skip = skip[..., x.shape[-1]:] + extra += skip[..., :extra.shape[-1]] + extra = decode[2](decode[1](decode[0](extra))) + x = decode[2](x) + next_state.append( + x[..., -demucs.stride:] - decode[2].bias.view(-1, 1) + ) + if extra is None: + extra = x[..., -demucs.stride:] + else: + extra[..., :demucs.stride] += next_state[-1] + x = x[..., :-demucs.stride] + + if not first: + prev = self.conv_state.pop(0) + x[..., :demucs.stride] += prev + if idx != demucs.depth - 1: + x = decode[3](x) + extra = decode[3](extra) + self.conv_state = next_state + return x[0], extra[0] + + +def test(): + import argparse + parser = argparse.ArgumentParser( + "denoiser.demucs", + description="Benchmark the streaming Demucs implementation, as well as " + "checking the delta with the offline implementation.") + parser.add_argument("--depth", default=5, type=int) + parser.add_argument("--resample", default=4, type=int) + parser.add_argument("--hidden", default=48, type=int) + parser.add_argument("--sample_rate", default=16000, type=float) + parser.add_argument("--device", default="cpu") + parser.add_argument("-t", "--num_threads", type=int) + parser.add_argument("-f", "--num_frames", type=int, default=1) + args = parser.parse_args() + if args.num_threads: + th.set_num_threads(args.num_threads) + sr = args.sample_rate + sr_ms = sr / 1000 + demucs = Demucs( + depth=args.depth, hidden=args.hidden, resample=args.resample + ).to(args.device) + x = th.randn(1, int(sr * 4)).to(args.device) + out = demucs(x[None])[0] + streamer = DemucsStreamer(demucs, num_frames=args.num_frames) + out_rt = [] + frame_size = streamer.total_length + with th.no_grad(): + while x.shape[1] > 0: + out_rt.append(streamer.feed(x[:, :frame_size])) + x = x[:, frame_size:] + frame_size = streamer.demucs.total_stride + 
out_rt.append(streamer.flush()) + out_rt = th.cat(out_rt, 1) + model_size = sum(p.numel() for p in demucs.parameters()) * 4 / 2**20 + initial_lag = streamer.total_length / sr_ms + tpf = 1000 * streamer.time_per_frame + print(f"model size: {model_size:.1f}MB, ", end='') + print(f"delta batch/streaming: {th.norm(out - out_rt) / th.norm(out):.2%}") + print(f"initial lag: {initial_lag:.1f}ms, ", end='') + print(f"stride: {streamer.stride * args.num_frames / sr_ms:.1f}ms") + print(f"time per frame: {tpf:.1f}ms, ", end='') + rtf = (1000 * streamer.time_per_frame) / (streamer.stride / sr_ms) + print(f"RTF: {rtf:.2f}") + print(f"Total lag with computation: {initial_lag + tpf:.1f}ms") + + +if __name__ == "__main__": + test() diff --git a/examples/speech_synthesis/preprocessing/denoiser/pretrained.py b/examples/speech_synthesis/preprocessing/denoiser/pretrained.py new file mode 100644 index 0000000000..2fa846075b --- /dev/null +++ b/examples/speech_synthesis/preprocessing/denoiser/pretrained.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# author: adefossez + +import logging + +import torch.hub + +from .demucs import Demucs +from .utils import deserialize_model + +logger = logging.getLogger(__name__) +ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/" +DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th" +DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th" +MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th" + + +def _demucs(pretrained, url, **kwargs): + model = Demucs(**kwargs) + if pretrained: + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu') + model.load_state_dict(state_dict) + return model + + +def dns48(pretrained=True): + return _demucs(pretrained, DNS_48_URL, hidden=48) + + +def dns64(pretrained=True): + return _demucs(pretrained, DNS_64_URL, hidden=64) + + +def master64(pretrained=True): + return _demucs(pretrained, MASTER_64_URL, hidden=64) + + +def add_model_flags(parser): + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument( + "-m", "--model_path", help="Path to local trained model." + ) + group.add_argument( + "--dns48", action="store_true", + help="Use pre-trained real time H=48 model trained on DNS." + ) + group.add_argument( + "--dns64", action="store_true", + help="Use pre-trained real time H=64 model trained on DNS." + ) + group.add_argument( + "--master64", action="store_true", + help="Use pre-trained real time H=64 model trained on DNS and Valentini." + ) + + +def get_model(args): + """ + Load local model package or torchhub pre-trained model. + """ + if args.model_path: + logger.info("Loading model from %s", args.model_path) + pkg = torch.load(args.model_path) + model = deserialize_model(pkg) + elif args.dns64: + logger.info("Loading pre-trained real time H=64 model trained on DNS.") + model = dns64() + elif args.master64: + logger.info( + "Loading pre-trained real time H=64 model trained on DNS and Valentini." 
+ ) + model = master64() + else: + logger.info("Loading pre-trained real time H=48 model trained on DNS.") + model = dns48() + logger.debug(model) + return model diff --git a/examples/speech_synthesis/preprocessing/denoiser/resample.py b/examples/speech_synthesis/preprocessing/denoiser/resample.py new file mode 100644 index 0000000000..1222addc42 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/denoiser/resample.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import math + +import torch as th +from torch.nn import functional as F + + +def sinc(t): + """sinc. + + :param t: the input tensor + """ + return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype), + th.sin(t) / t) + + +def kernel_upsample2(zeros=56): + """kernel_upsample2. + + """ + win = th.hann_window(4 * zeros + 1, periodic=False) + winodd = win[1::2] + t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t *= math.pi + kernel = (sinc(t) * winodd).view(1, 1, -1) + return kernel + + +def upsample2(x, zeros=56): + """ + Upsampling the input by 2 using sinc interpolation. + Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method." + ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. + Vol. 9. IEEE, 1984. + """ + *other, time = x.shape + kernel = kernel_upsample2(zeros).to(x) + out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view( + *other, time + ) + y = th.stack([x, out], dim=-1) + return y.view(*other, -1) + + +def kernel_downsample2(zeros=56): + """kernel_downsample2. 
+ + """ + win = th.hann_window(4 * zeros + 1, periodic=False) + winodd = win[1::2] + t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t.mul_(math.pi) + kernel = (sinc(t) * winodd).view(1, 1, -1) + return kernel + + +def downsample2(x, zeros=56): + """ + Downsampling the input by 2 using sinc interpolation. + Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method." + ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. + Vol. 9. IEEE, 1984. + """ + if x.shape[-1] % 2 != 0: + x = F.pad(x, (0, 1)) + xeven = x[..., ::2] + xodd = x[..., 1::2] + *other, time = xodd.shape + kernel = kernel_downsample2(zeros).to(x) + out = xeven + F.conv1d( + xodd.view(-1, 1, time), kernel, padding=zeros + )[..., :-1].view(*other, time) + return out.view(*other, -1).mul(0.5) diff --git a/examples/speech_synthesis/preprocessing/denoiser/utils.py b/examples/speech_synthesis/preprocessing/denoiser/utils.py new file mode 100644 index 0000000000..734d047f1b --- /dev/null +++ b/examples/speech_synthesis/preprocessing/denoiser/utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import functools +import logging +from contextlib import contextmanager +import inspect +import time + +logger = logging.getLogger(__name__) + +EPS = 1e-8 + + +def capture_init(init): + """capture_init. + + Decorate `__init__` with this, and you can then + recover the *args and **kwargs passed to it in `self._init_args_kwargs` + """ + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ + + +def deserialize_model(package, strict=False): + """deserialize_model. 
+ + """ + klass = package['class'] + if strict: + model = klass(*package['args'], **package['kwargs']) + else: + sig = inspect.signature(klass) + kw = package['kwargs'] + for key in list(kw): + if key not in sig.parameters: + logger.warning("Dropping nonexistent parameter %s", key) + del kw[key] + model = klass(*package['args'], **kw) + model.load_state_dict(package['state']) + return model + + +def copy_state(state): + return {k: v.cpu().clone() for k, v in state.items()} + + +def serialize_model(model): + args, kwargs = model._init_args_kwargs + state = copy_state(model.state_dict()) + return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state} + + +@contextmanager +def swap_state(model, state): + """ + Context manager that swaps the state of a model, e.g: + + # model is in old state + with swap_state(model, new_state): + # model in new state + # model back to old state + """ + old_state = copy_state(model.state_dict()) + model.load_state_dict(state) + try: + yield + finally: + model.load_state_dict(old_state) + + +def pull_metric(history, name): + out = [] + for metrics in history: + if name in metrics: + out.append(metrics[name]) + return out + + +class LogProgress: + """ + Sort of like tqdm but using log lines and not as real time. + Args: + - logger: logger obtained from `logging.getLogger`, + - iterable: iterable object to wrap + - updates (int): number of lines that will be printed, e.g. + if `updates=5`, log every 1/5th of the total length. + - total (int): length of the iterable, in case it does not support + `len`. + - name (str): prefix to use in the log. + - level: logging level (like `logging.INFO`). 
+ """ + def __init__(self, + logger, + iterable, + updates=5, + total=None, + name="LogProgress", + level=logging.INFO): + self.iterable = iterable + self.total = total or len(iterable) + self.updates = updates + self.name = name + self.logger = logger + self.level = level + + def update(self, **infos): + self._infos = infos + + def __iter__(self): + self._iterator = iter(self.iterable) + self._index = -1 + self._infos = {} + self._begin = time.time() + return self + + def __next__(self): + self._index += 1 + try: + value = next(self._iterator) + except StopIteration: + raise + else: + return value + finally: + log_every = max(1, self.total // self.updates) + # logging is delayed by 1 it, in order to have the metrics from update + if self._index >= 1 and self._index % log_every == 0: + self._log() + + def _log(self): + self._speed = (1 + self._index) / (time.time() - self._begin) + infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items()) + if self._speed < 1e-4: + speed = "oo sec/it" + elif self._speed < 0.1: + speed = f"{1/self._speed:.1f} sec/it" + else: + speed = f"{self._speed:.1f} it/sec" + out = f"{self.name} | {self._index}/{self.total} | {speed}" + if infos: + out += " | " + infos + self.logger.log(self.level, out) + + +def colorize(text, color): + """ + Display text with some ANSI color in the terminal. + """ + code = f"\033[{color}m" + restore = "\033[0m" + return "".join([code, text, restore]) + + +def bold(text): + """ + Display text in bold in the terminal. 
+ """ + return colorize(text, "1") + + +def cal_snr(lbl, est): + import torch + y = 10.0 * torch.log10( + torch.sum(lbl**2, dim=-1) / (torch.sum((est-lbl)**2, dim=-1) + EPS) + + EPS + ) + return y diff --git a/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py b/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py new file mode 100644 index 0000000000..a302546043 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +from pathlib import Path +from collections import defaultdict +from typing import List, Dict, Tuple + +import pandas as pd +import numpy as np +import torchaudio +from tqdm import tqdm + +from examples.speech_to_text.data_utils import load_df_from_tsv, save_df_to_tsv + + +log = logging.getLogger(__name__) + +SPLITS = ["train", "dev", "test"] + + +def get_top_n( + root: Path, n_speakers: int = 10, min_n_tokens: int = 5 +) -> pd.DataFrame: + df = load_df_from_tsv(root / "validated.tsv") + df["n_tokens"] = [len(s.split()) for s in df["sentence"]] + df = df[df["n_tokens"] >= min_n_tokens] + df["n_frames"] = [ + torchaudio.info((root / "clips" / p).as_posix()).num_frames + for p in tqdm(df["path"]) + ] + df["id"] = [Path(p).stem for p in df["path"]] + total_duration_ms = df.groupby("client_id")["n_frames"].agg(["sum"]) + total_duration_ms = total_duration_ms.sort_values("sum", ascending=False) + + top_n_total_duration_ms = total_duration_ms.head(n_speakers) + top_n_client_ids = set(top_n_total_duration_ms.index.tolist()) + df_top_n = df[df["client_id"].isin(top_n_client_ids)] + return df_top_n + + +def get_splits( + df, train_split_ratio=0.99, speaker_in_all_splits=False, rand_seed=0 +) -> Tuple[Dict[str, str], List[str]]: + 
np.random.seed(rand_seed) + dev_split_ratio = (1. - train_split_ratio) / 3 + grouped = list(df.groupby("client_id")) + id_to_split = {} + for _, cur_df in tqdm(grouped): + cur_n_examples = len(cur_df) + if speaker_in_all_splits and cur_n_examples < 3: + continue + cur_n_train = int(cur_n_examples * train_split_ratio) + cur_n_dev = int(cur_n_examples * dev_split_ratio) + cur_n_test = cur_n_examples - cur_n_dev - cur_n_train + if speaker_in_all_splits and cur_n_dev * cur_n_test == 0: + cur_n_dev, cur_n_test = 1, 1 + cur_n_train = cur_n_examples - cur_n_dev - cur_n_test + cur_indices = cur_df.index.tolist() + cur_shuffled_indices = np.random.permutation(cur_n_examples) + cur_shuffled_indices = [cur_indices[i] for i in cur_shuffled_indices] + cur_indices_by_split = { + "train": cur_shuffled_indices[:cur_n_train], + "dev": cur_shuffled_indices[cur_n_train: cur_n_train + cur_n_dev], + "test": cur_shuffled_indices[cur_n_train + cur_n_dev:] + } + for split in SPLITS: + for i in cur_indices_by_split[split]: + id_ = df["id"].loc[i] + id_to_split[id_] = split + return id_to_split, sorted(df["client_id"].unique()) + + +def convert_to_wav(root: Path, filenames: List[str], target_sr=16_000): + out_root = root / "wav" + out_root.mkdir(exist_ok=True, parents=True) + print("Converting to WAV...") + for n in tqdm(filenames): + in_path = (root / "clips" / n).as_posix() + waveform, sr = torchaudio.load(in_path) + converted, converted_sr = torchaudio.sox_effects.apply_effects_tensor( + waveform, sr, [["rate", str(target_sr)], ["channels", "1"]] + ) + out_path = (out_root / Path(n).with_suffix(".wav").name).as_posix() + torchaudio.save(out_path, converted, converted_sr, encoding="PCM_S", + bits_per_sample=16) + + +def process(args): + data_root = Path(args.data_root).absolute() / args.lang + + # Generate TSV manifest + print("Generating manifest...") + + df_top_n = get_top_n(data_root) + id_to_split, speakers = get_splits(df_top_n) + + if args.convert_to_wav: + convert_to_wav(data_root, 
df_top_n["path"].tolist()) + + manifest_by_split = {split: defaultdict(list) for split in SPLITS} + for sample in tqdm(df_top_n.to_dict(orient="index").values()): + sample_id = sample["id"] + split = id_to_split[sample_id] + manifest_by_split[split]["id"].append(sample_id) + if args.convert_to_wav: + audio_path = data_root / "wav" / f"{sample_id}.wav" + else: + audio_path = data_root / "clips" / f"{sample_id}.mp3" + manifest_by_split[split]["audio"].append(audio_path.as_posix()) + manifest_by_split[split]["n_frames"].append(sample["n_frames"]) + manifest_by_split[split]["tgt_text"].append(sample["sentence"]) + manifest_by_split[split]["speaker"].append(sample["client_id"]) + manifest_by_split[split]["src_text"].append(sample["sentence"]) + + output_root = Path(args.output_manifest_root).absolute() + output_root.mkdir(parents=True, exist_ok=True) + for split in SPLITS: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[split]), + output_root / f"{split}.audio.tsv" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument("--output-manifest-root", "-m", required=True, type=str) + parser.add_argument("--lang", "-l", required=True, type=str) + parser.add_argument("--convert-to-wav", action="store_true") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_synthesis/preprocessing/get_feature_manifest.py b/examples/speech_synthesis/preprocessing/get_feature_manifest.py new file mode 100644 index 0000000000..516f2cc469 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/get_feature_manifest.py @@ -0,0 +1,233 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +from pathlib import Path +import shutil +from tempfile import NamedTemporaryFile +from collections import Counter, defaultdict + +import pandas as pd +import torchaudio +from tqdm import tqdm + +from fairseq.data.audio.audio_utils import convert_waveform +from examples.speech_to_text.data_utils import ( + create_zip, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + load_tsv_to_dicts, + save_df_to_tsv +) +from examples.speech_synthesis.data_utils import ( + extract_logmel_spectrogram, extract_pitch, extract_energy, get_global_cmvn, + ipa_phonemize, get_mfa_alignment, get_unit_alignment +) + + +log = logging.getLogger(__name__) + + +def process(args): + assert "train" in args.splits + out_root = Path(args.output_root).absolute() + out_root.mkdir(exist_ok=True) + + print("Fetching data...") + audio_manifest_root = Path(args.audio_manifest_root).absolute() + samples = [] + for s in args.splits: + for e in load_tsv_to_dicts(audio_manifest_root / f"{s}.audio.tsv"): + e["split"] = s + samples.append(e) + sample_ids = [s["id"] for s in samples] + + # Get alignment info + id_to_alignment = None + if args.textgrid_zip is not None: + assert args.id_to_units_tsv is None + id_to_alignment = get_mfa_alignment( + args.textgrid_zip, sample_ids, args.sample_rate, args.hop_length + ) + elif args.id_to_units_tsv is not None: + # assume identical hop length on the unit sequence + id_to_alignment = get_unit_alignment(args.id_to_units_tsv, sample_ids) + + # Extract features and pack features into ZIP + feature_name = "logmelspec80" + zip_path = out_root / f"{feature_name}.zip" + pitch_zip_path = out_root / "pitch.zip" + energy_zip_path = out_root / "energy.zip" + gcmvn_npz_path = out_root / "gcmvn_stats.npz" + if zip_path.exists() and gcmvn_npz_path.exists(): + print(f"{zip_path} and {gcmvn_npz_path} exist.") + else: + feature_root = out_root / feature_name + feature_root.mkdir(exist_ok=True) + pitch_root = out_root / "pitch" + energy_root = 
out_root / "energy" + if args.add_fastspeech_targets: + pitch_root.mkdir(exist_ok=True) + energy_root.mkdir(exist_ok=True) + print("Extracting Mel spectrogram features...") + for sample in tqdm(samples): + waveform, sample_rate = torchaudio.load(sample["audio"]) + waveform, sample_rate = convert_waveform( + waveform, sample_rate, normalize_volume=args.normalize_volume, + to_sample_rate=args.sample_rate + ) + sample_id = sample["id"] + target_length = None + if id_to_alignment is not None: + a = id_to_alignment[sample_id] + target_length = sum(a.frame_durations) + if a.start_sec is not None and a.end_sec is not None: + start_frame = int(a.start_sec * sample_rate) + end_frame = int(a.end_sec * sample_rate) + waveform = waveform[:, start_frame: end_frame] + extract_logmel_spectrogram( + waveform, sample_rate, feature_root / f"{sample_id}.npy", + win_length=args.win_length, hop_length=args.hop_length, + n_fft=args.n_fft, n_mels=args.n_mels, f_min=args.f_min, + f_max=args.f_max, target_length=target_length + ) + if args.add_fastspeech_targets: + assert id_to_alignment is not None + extract_pitch( + waveform, sample_rate, pitch_root / f"{sample_id}.npy", + hop_length=args.hop_length, log_scale=True, + phoneme_durations=id_to_alignment[sample_id].frame_durations + ) + extract_energy( + waveform, energy_root / f"{sample_id}.npy", + hop_length=args.hop_length, n_fft=args.n_fft, + log_scale=True, + phoneme_durations=id_to_alignment[sample_id].frame_durations + ) + print("ZIPing features...") + create_zip(feature_root, zip_path) + get_global_cmvn(feature_root, gcmvn_npz_path) + shutil.rmtree(feature_root) + if args.add_fastspeech_targets: + create_zip(pitch_root, pitch_zip_path) + shutil.rmtree(pitch_root) + create_zip(energy_root, energy_zip_path) + shutil.rmtree(energy_root) + + print("Fetching ZIP manifest...") + audio_paths, audio_lengths = get_zip_manifest(zip_path) + pitch_paths, pitch_lengths, energy_paths, energy_lengths = [None] * 4 + if args.add_fastspeech_targets: 
+ pitch_paths, pitch_lengths = get_zip_manifest(pitch_zip_path) + energy_paths, energy_lengths = get_zip_manifest(energy_zip_path) + # Generate TSV manifest + print("Generating manifest...") + manifest_by_split = {split: defaultdict(list) for split in args.splits} + for sample in tqdm(samples): + sample_id, split = sample["id"], sample["split"] + normalized_utt = sample["tgt_text"] + if id_to_alignment is not None: + normalized_utt = " ".join(id_to_alignment[sample_id].tokens) + elif args.ipa_vocab: + normalized_utt = ipa_phonemize( + normalized_utt, lang=args.lang, use_g2p=args.use_g2p + ) + manifest_by_split[split]["id"].append(sample_id) + manifest_by_split[split]["audio"].append(audio_paths[sample_id]) + manifest_by_split[split]["n_frames"].append(audio_lengths[sample_id]) + manifest_by_split[split]["tgt_text"].append(normalized_utt) + manifest_by_split[split]["speaker"].append(sample["speaker"]) + manifest_by_split[split]["src_text"].append(sample["src_text"]) + if args.add_fastspeech_targets: + assert id_to_alignment is not None + duration = " ".join( + str(d) for d in id_to_alignment[sample_id].frame_durations + ) + manifest_by_split[split]["duration"].append(duration) + manifest_by_split[split]["pitch"].append(pitch_paths[sample_id]) + manifest_by_split[split]["energy"].append(energy_paths[sample_id]) + for split in args.splits: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[split]), + out_root / f"{split}.tsv" + ) + # Generate vocab + vocab_name, spm_filename = None, None + if id_to_alignment is not None or args.ipa_vocab: + vocab = Counter() + for t in manifest_by_split["train"]["tgt_text"]: + vocab.update(t.split(" ")) + vocab_name = "vocab.txt" + with open(out_root / vocab_name, "w") as f: + for s, c in vocab.most_common(): + f.write(f"{s} {c}\n") + else: + spm_filename_prefix = "spm_char" + spm_filename = f"{spm_filename_prefix}.model" + with NamedTemporaryFile(mode="w") as f: + for t in manifest_by_split["train"]["tgt_text"]: + f.write(t 
+ "\n") + f.flush() # needed to ensure gen_vocab sees dumped text + gen_vocab(Path(f.name), out_root / spm_filename_prefix, "char") + # Generate speaker list + speakers = sorted({sample["speaker"] for sample in samples}) + speakers_path = out_root / "speakers.txt" + with open(speakers_path, "w") as f: + for speaker in speakers: + f.write(f"{speaker}\n") + # Generate config YAML + win_len_t = args.win_length / args.sample_rate + hop_len_t = args.hop_length / args.sample_rate + extra = { + "sample_rate": args.sample_rate, + "features": { + "type": "spectrogram+melscale+log", + "eps": 1e-2, "n_mels": args.n_mels, "n_fft": args.n_fft, + "window_fn": "hann", "win_length": args.win_length, + "hop_length": args.hop_length, "sample_rate": args.sample_rate, + "win_len_t": win_len_t, "hop_len_t": hop_len_t, + "f_min": args.f_min, "f_max": args.f_max, + "n_stft": args.n_fft // 2 + 1 + } + } + if len(speakers) > 1: + extra["speaker_set_filename"] = "speakers.txt" + gen_config_yaml( + out_root, spm_filename=spm_filename, vocab_name=vocab_name, + audio_root=out_root.as_posix(), input_channels=None, + input_feat_per_channel=None, specaugment_policy=None, + cmvn_type="global", gcmvn_path=gcmvn_npz_path, extra=extra + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--audio-manifest-root", "-m", required=True, type=str) + parser.add_argument("--output-root", "-o", required=True, type=str) + parser.add_argument("--splits", "-s", type=str, nargs="+", + default=["train", "dev", "test"]) + parser.add_argument("--ipa-vocab", action="store_true") + parser.add_argument("--use-g2p", action="store_true") + parser.add_argument("--lang", type=str, default="en-us") + parser.add_argument("--win-length", type=int, default=1024) + parser.add_argument("--hop-length", type=int, default=256) + parser.add_argument("--n-fft", type=int, default=1024) + parser.add_argument("--n-mels", type=int, default=80) + parser.add_argument("--f-min", type=int, default=20) + 
parser.add_argument("--f-max", type=int, default=8000) + parser.add_argument("--sample-rate", type=int, default=22050) + parser.add_argument("--normalize-volume", "-n", action="store_true") + parser.add_argument("--textgrid-zip", type=str, default=None) + parser.add_argument("--id-to-units-tsv", type=str, default=None) + parser.add_argument("--add-fastspeech-targets", action="store_true") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py b/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py new file mode 100644 index 0000000000..7ec1fb7521 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +from pathlib import Path +from collections import defaultdict + +import pandas as pd +from torchaudio.datasets import LJSPEECH +from tqdm import tqdm + +from examples.speech_to_text.data_utils import save_df_to_tsv + + +log = logging.getLogger(__name__) + +SPLITS = ["train", "dev", "test"] + + +def process(args): + out_root = Path(args.output_data_root).absolute() + out_root.mkdir(parents=True, exist_ok=True) + + # Generate TSV manifest + print("Generating manifest...") + # following FastSpeech's splits + dataset = LJSPEECH(out_root.as_posix(), download=True) + id_to_split = {} + for x in dataset._flist: + id_ = x[0] + speaker = id_.split("-")[0] + id_to_split[id_] = { + "LJ001": "test", "LJ002": "test", "LJ003": "dev" + }.get(speaker, "train") + manifest_by_split = {split: defaultdict(list) for split in SPLITS} + progress = tqdm(enumerate(dataset), total=len(dataset)) + for i, (waveform, _, utt, normalized_utt) in progress: + sample_id = dataset._flist[i][0] + split = 
id_to_split[sample_id] + manifest_by_split[split]["id"].append(sample_id) + audio_path = f"{dataset._path}/{sample_id}.wav" + manifest_by_split[split]["audio"].append(audio_path) + manifest_by_split[split]["n_frames"].append(len(waveform[0])) + manifest_by_split[split]["tgt_text"].append(normalized_utt) + manifest_by_split[split]["speaker"].append("ljspeech") + manifest_by_split[split]["src_text"].append(utt) + + manifest_root = Path(args.output_manifest_root).absolute() + manifest_root.mkdir(parents=True, exist_ok=True) + for split in SPLITS: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[split]), + manifest_root / f"{split}.audio.tsv" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--output-data-root", "-d", required=True, type=str) + parser.add_argument("--output-manifest-root", "-m", required=True, type=str) + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_synthesis/preprocessing/get_speaker_embedding.py b/examples/speech_synthesis/preprocessing/get_speaker_embedding.py new file mode 100644 index 0000000000..0e3e4c5cd7 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/get_speaker_embedding.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import argparse +from collections import defaultdict +from itertools import chain +from pathlib import Path + +import numpy as np +import torchaudio +import torchaudio.sox_effects as ta_sox +import yaml +from tqdm import tqdm + +from examples.speech_to_text.data_utils import load_tsv_to_dicts +from examples.speech_synthesis.preprocessing.speaker_embedder import SpkrEmbedder + + +def extract_embedding(audio_path, embedder): + wav, sr = torchaudio.load(audio_path) # 2D + if sr != embedder.RATE: + wav, sr = ta_sox.apply_effects_tensor( + wav, sr, [["rate", str(embedder.RATE)]] + ) + try: + emb = embedder([wav[0].cuda().float()]).cpu().numpy() + except RuntimeError: + emb = None + return emb + + +def process(args): + print("Fetching data...") + raw_manifest_root = Path(args.raw_manifest_root).absolute() + samples = [load_tsv_to_dicts(raw_manifest_root / (s + ".tsv")) + for s in args.splits] + samples = list(chain(*samples)) + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + with open(f"{config['audio_root']}/{config['speaker_set_filename']}") as f: + speaker_to_id = {r.strip(): i for i, r in enumerate(f)} + + embedder = SpkrEmbedder(args.ckpt).cuda() + speaker_to_cnt = defaultdict(float) + speaker_to_emb = defaultdict(float) + for sample in tqdm(samples, desc="extract emb"): + emb = extract_embedding(sample["audio"], embedder) + if emb is not None: + speaker_to_cnt[sample["speaker"]] += 1 + speaker_to_emb[sample["speaker"]] += emb + if len(speaker_to_emb) != len(speaker_to_id): + missed = set(speaker_to_id) - set(speaker_to_emb.keys()) + print( + f"WARNING: missing embeddings for {len(missed)} speaker:\n{missed}" + ) + speaker_emb_mat = np.zeros((len(speaker_to_id), len(emb)), float) + for speaker in speaker_to_emb: + idx = speaker_to_id[speaker] + emb = speaker_to_emb[speaker] + cnt = speaker_to_cnt[speaker] + speaker_emb_mat[idx, :] = emb / cnt + speaker_emb_name = "speaker_emb.npy" + speaker_emb_path = 
f"{config['audio_root']}/{speaker_emb_name}" + np.save(speaker_emb_path, speaker_emb_mat) + config["speaker_emb_filename"] = speaker_emb_name + + with open(args.new_config, "w") as f: + yaml.dump(config, f) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--raw-manifest-root", "-m", required=True, type=str) + parser.add_argument("--splits", "-s", type=str, nargs="+", + default=["train"]) + parser.add_argument("--config", "-c", required=True, type=str) + parser.add_argument("--new-config", "-n", required=True, type=str) + parser.add_argument("--ckpt", required=True, type=str, + help="speaker embedder checkpoint") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py b/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py new file mode 100644 index 0000000000..7afa40fcd1 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +import numpy as np +import re +from pathlib import Path +from collections import defaultdict + +import pandas as pd +from torchaudio.datasets import VCTK +from tqdm import tqdm + +from examples.speech_to_text.data_utils import save_df_to_tsv + + +log = logging.getLogger(__name__) + +SPLITS = ["train", "dev", "test"] + + +def normalize_text(text): + return re.sub(r"[^a-zA-Z.?!,'\- ]", '', text) + + +def process(args): + out_root = Path(args.output_data_root).absolute() + out_root.mkdir(parents=True, exist_ok=True) + + # Generate TSV manifest + print("Generating manifest...") + dataset = VCTK(out_root.as_posix(), download=False) + ids = list(dataset._walker) + np.random.seed(args.seed) + np.random.shuffle(ids) + n_train = len(ids) - args.n_dev - args.n_test + _split = ["train"] * n_train + ["dev"] * args.n_dev + ["test"] * args.n_test + id_to_split = dict(zip(ids, _split)) + manifest_by_split = {split: defaultdict(list) for split in SPLITS} + progress = tqdm(enumerate(dataset), total=len(dataset)) + for i, (waveform, _, text, speaker_id, _) in progress: + sample_id = dataset._walker[i] + _split = id_to_split[sample_id] + audio_dir = Path(dataset._path) / dataset._folder_audio / speaker_id + audio_path = audio_dir / f"{sample_id}.wav" + text = normalize_text(text) + manifest_by_split[_split]["id"].append(sample_id) + manifest_by_split[_split]["audio"].append(audio_path.as_posix()) + manifest_by_split[_split]["n_frames"].append(len(waveform[0])) + manifest_by_split[_split]["tgt_text"].append(text) + manifest_by_split[_split]["speaker"].append(speaker_id) + manifest_by_split[_split]["src_text"].append(text) + + manifest_root = Path(args.output_manifest_root).absolute() + manifest_root.mkdir(parents=True, exist_ok=True) + for _split in SPLITS: + save_df_to_tsv( + pd.DataFrame.from_dict(manifest_by_split[_split]), + manifest_root / f"{_split}.audio.tsv" + ) + + +def main(): + parser = argparse.ArgumentParser() + 
parser.add_argument("--output-data-root", "-d", required=True, type=str) + parser.add_argument("--output-manifest-root", "-m", required=True, type=str) + parser.add_argument("--n-dev", default=50, type=int) + parser.add_argument("--n-test", default=100, type=int) + parser.add_argument("--seed", "-s", default=1234, type=int) + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py b/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py new file mode 100644 index 0000000000..3b178676ba --- /dev/null +++ b/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import librosa +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +import torchaudio + + +EMBEDDER_PARAMS = { + 'num_mels': 40, + 'n_fft': 512, + 'emb_dim': 256, + 'lstm_hidden': 768, + 'lstm_layers': 3, + 'window': 80, + 'stride': 40, +} + + +def set_requires_grad(nets, requires_grad=False): + """Set requires_grad=False for all the networks to avoid unnecessary + computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +class LinearNorm(nn.Module): + def __init__(self, hp): + super(LinearNorm, self).__init__() + self.linear_layer = nn.Linear(hp["lstm_hidden"], hp["emb_dim"]) + + def forward(self, x): + return self.linear_layer(x) + + +class SpeechEmbedder(nn.Module): + def __init__(self, hp): + super(SpeechEmbedder, self).__init__() + self.lstm = nn.LSTM(hp["num_mels"], + 
hp["lstm_hidden"], + num_layers=hp["lstm_layers"], + batch_first=True) + self.proj = LinearNorm(hp) + self.hp = hp + + def forward(self, mel): + # (num_mels, T) -> (num_mels, T', window) + mels = mel.unfold(1, self.hp["window"], self.hp["stride"]) + mels = mels.permute(1, 2, 0) # (T', window, num_mels) + x, _ = self.lstm(mels) # (T', window, lstm_hidden) + x = x[:, -1, :] # (T', lstm_hidden), use last frame only + x = self.proj(x) # (T', emb_dim) + x = x / torch.norm(x, p=2, dim=1, keepdim=True) # (T', emb_dim) + + x = x.mean(dim=0) + if x.norm(p=2) != 0: + x = x / x.norm(p=2) + return x + + +class SpkrEmbedder(nn.Module): + RATE = 16000 + + def __init__( + self, + embedder_path, + embedder_params=EMBEDDER_PARAMS, + rate=16000, + hop_length=160, + win_length=400, + pad=False, + ): + super(SpkrEmbedder, self).__init__() + embedder_pt = torch.load(embedder_path, map_location="cpu") + self.embedder = SpeechEmbedder(embedder_params) + self.embedder.load_state_dict(embedder_pt) + self.embedder.eval() + set_requires_grad(self.embedder, requires_grad=False) + self.embedder_params = embedder_params + + self.register_buffer('mel_basis', torch.from_numpy( + librosa.filters.mel( + sr=self.RATE, + n_fft=self.embedder_params["n_fft"], + n_mels=self.embedder_params["num_mels"]) + ) + ) + + self.resample = None + if rate != self.RATE: + self.resample = torchaudio.transforms.Resample(rate, self.RATE) + self.hop_length = hop_length + self.win_length = win_length + self.pad = pad + + def get_mel(self, y): + if self.pad and y.shape[-1] < 14000: + y = F.pad(y, (0, 14000 - y.shape[-1])) + + window = torch.hann_window(self.win_length).to(y) + y = torch.stft(y, n_fft=self.embedder_params["n_fft"], + hop_length=self.hop_length, + win_length=self.win_length, + window=window) + magnitudes = torch.norm(y, dim=-1, p=2) ** 2 + mel = torch.log10(self.mel_basis @ magnitudes + 1e-6) + return mel + + def forward(self, inputs): + dvecs = [] + for wav in inputs: + mel = self.get_mel(wav) + if 
mel.dim() == 3: + mel = mel.squeeze(0) + dvecs += [self.embedder(mel)] + dvecs = torch.stack(dvecs) + + dvec = torch.mean(dvecs, dim=0) + dvec = dvec / torch.norm(dvec) + + return dvec diff --git a/examples/speech_synthesis/preprocessing/vad/__init__.py b/examples/speech_synthesis/preprocessing/vad/__init__.py new file mode 100644 index 0000000000..9cf121081f --- /dev/null +++ b/examples/speech_synthesis/preprocessing/vad/__init__.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import collections +import contextlib +import wave + +try: + import webrtcvad +except ImportError: + raise ImportError("Please install py-webrtcvad: pip install webrtcvad") +import argparse +import os +import logging +from tqdm import tqdm + +AUDIO_SUFFIX = '.wav' +FS_MS = 30 +SCALE = 6e-5 +THRESHOLD = 0.3 + + +def read_wave(path): + """Reads a .wav file. + Takes the path, and returns (PCM audio data, sample rate). + """ + with contextlib.closing(wave.open(path, 'rb')) as wf: + num_channels = wf.getnchannels() + assert num_channels == 1 + sample_width = wf.getsampwidth() + assert sample_width == 2 + sample_rate = wf.getframerate() + assert sample_rate in (8000, 16000, 32000, 48000) + pcm_data = wf.readframes(wf.getnframes()) + return pcm_data, sample_rate + + +def write_wave(path, audio, sample_rate): + """Writes a .wav file. + Takes path, PCM audio data, and sample rate. + """ + with contextlib.closing(wave.open(path, 'wb')) as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(audio) + + +class Frame(object): + """Represents a "frame" of audio data.""" + def __init__(self, bytes, timestamp, duration): + self.bytes = bytes + self.timestamp = timestamp + self.duration = duration + + +def frame_generator(frame_duration_ms, audio, sample_rate): + """Generates audio frames from PCM audio data. 
+ Takes the desired frame duration in milliseconds, the PCM data, and + the sample rate. + Yields Frames of the requested duration. + """ + n = int(sample_rate * (frame_duration_ms / 1000.0) * 2) + offset = 0 + timestamp = 0.0 + duration = (float(n) / sample_rate) / 2.0 + while offset + n < len(audio): + yield Frame(audio[offset:offset + n], timestamp, duration) + timestamp += duration + offset += n + + +def vad_collector(sample_rate, frame_duration_ms, + padding_duration_ms, vad, frames): + """Filters out non-voiced audio frames. + Given a webrtcvad.Vad and a source of audio frames, yields only + the voiced audio. + Uses a padded, sliding window algorithm over the audio frames. + When more than 90% of the frames in the window are voiced (as + reported by the VAD), the collector triggers and begins yielding + audio frames. Then the collector waits until 90% of the frames in + the window are unvoiced to detrigger. + The window is padded at the front and back to provide a small + amount of silence or the beginnings/endings of speech around the + voiced frames. + Arguments: + sample_rate - The audio sample rate, in Hz. + frame_duration_ms - The frame duration in milliseconds. + padding_duration_ms - The amount to pad the window, in milliseconds. + vad - An instance of webrtcvad.Vad. + frames - a source of audio frames (sequence or generator). + Returns: A generator that yields PCM audio data. + """ + num_padding_frames = int(padding_duration_ms / frame_duration_ms) + # We use a deque for our sliding window/ring buffer. + ring_buffer = collections.deque(maxlen=num_padding_frames) + # We have two states: TRIGGERED and NOTTRIGGERED. We start in the + # NOTTRIGGERED state. 
+ triggered = False + + voiced_frames = [] + for frame in frames: + is_speech = vad.is_speech(frame.bytes, sample_rate) + + # sys.stdout.write('1' if is_speech else '0') + if not triggered: + ring_buffer.append((frame, is_speech)) + num_voiced = len([f for f, speech in ring_buffer if speech]) + # If we're NOTTRIGGERED and more than 90% of the frames in + # the ring buffer are voiced frames, then enter the + # TRIGGERED state. + if num_voiced > 0.9 * ring_buffer.maxlen: + triggered = True + # We want to yield all the audio we see from now until + # we are NOTTRIGGERED, but we have to start with the + # audio that's already in the ring buffer. + for f, _ in ring_buffer: + voiced_frames.append(f) + ring_buffer.clear() + else: + # We're in the TRIGGERED state, so collect the audio data + # and add it to the ring buffer. + voiced_frames.append(frame) + ring_buffer.append((frame, is_speech)) + num_unvoiced = len([f for f, speech in ring_buffer if not speech]) + # If more than 90% of the frames in the ring buffer are + # unvoiced, then enter NOTTRIGGERED and yield whatever + # audio we've collected. + if num_unvoiced > 0.9 * ring_buffer.maxlen: + triggered = False + yield [b''.join([f.bytes for f in voiced_frames]), + voiced_frames[0].timestamp, voiced_frames[-1].timestamp] + ring_buffer.clear() + voiced_frames = [] + # If we have any leftover voiced audio when we run out of input, + # yield it. 
+ if voiced_frames: + yield [b''.join([f.bytes for f in voiced_frames]), + voiced_frames[0].timestamp, voiced_frames[-1].timestamp] + + +def main(args): + # create output folder + try: + cmd = f"mkdir -p {args.out_path}" + os.system(cmd) + except Exception: + logging.error("Cannot create output folder") + exit(-1) + + # build vad object + vad = webrtcvad.Vad(int(args.agg)) + # iterating over wavs in dir + for file in tqdm(os.listdir(args.in_path)): + if file.endswith(AUDIO_SUFFIX): + audio_inpath = os.path.join(args.in_path, file) + audio_outpath = os.path.join(args.out_path, file) + audio, sample_rate = read_wave(audio_inpath) + frames = frame_generator(FS_MS, audio, sample_rate) + frames = list(frames) + segments = vad_collector(sample_rate, FS_MS, 300, vad, frames) + merge_segments = list() + timestamp_start = 0.0 + timestamp_end = 0.0 + # removing start, end, and long sequences of silences + for i, segment in enumerate(segments): + merge_segments.append(segment[0]) + if i and timestamp_start: + sil_duration = segment[1] - timestamp_end + if sil_duration > THRESHOLD: + merge_segments.append(int(THRESHOLD / SCALE)*(b'\x00')) + else: + merge_segments.append(int((sil_duration / SCALE))*(b'\x00')) + timestamp_start = segment[1] + timestamp_end = segment[2] + segment = b''.join(merge_segments) + write_wave(audio_outpath, segment, sample_rate) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Apply vad to a folder of files.') + parser.add_argument('in_path', type=str, help='Path to the input files') + parser.add_argument('out_path', type=str, + help='Path to save the processed files') + parser.add_argument('--agg', type=int, default=3, + help='The level of aggressiveness of the VAD: [0-3]') + args = parser.parse_args() + + main(args) From d974c709bf57cf494738a824a1597e1886bebb7a Mon Sep 17 00:00:00 2001 From: Changhan Wang <changhan@fb.com> Date: Sun, 12 Sep 2021 22:21:09 -0700 Subject: [PATCH 698/707] update S2T Summary: [fairseq-py] update 
S2T Reviewed By: wnhsu Differential Revision: D30720434 fbshipit-source-id: dc4e46b0cc3dec24943baeabe59424dabd5be38f --- examples/speech_to_text/data_utils.py | 115 +++++++---- examples/speech_to_text/prep_covost_data.py | 11 +- .../speech_to_text/prep_librispeech_data.py | 13 +- examples/speech_to_text/prep_mtedx_data.py | 103 ++++++---- examples/speech_to_text/prep_mustc_data.py | 137 ++++++++----- fairseq/data/audio/audio_utils.py | 172 ++++++++++++---- fairseq/data/audio/data_cfg.py | 139 +++++++++++++ fairseq/data/audio/speech_to_text_dataset.py | 188 ++++++------------ fairseq_cli/generate.py | 2 +- 9 files changed, 571 insertions(+), 309 deletions(-) create mode 100644 fairseq/data/audio/data_cfg.py diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 2bcff046f7..41afac0bf8 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. 
# # This source code is licensed under the MIT license found in the @@ -10,14 +9,17 @@ from functools import reduce from multiprocessing import cpu_count from typing import Any, Dict, List, Optional, Union +import io import numpy as np import pandas as pd import sentencepiece as sp from fairseq.data.audio.audio_utils import ( - _convert_to_mono, _get_kaldi_fbank, _get_torchaudio_fbank + convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data, + is_sf_audio_data ) import torch +import soundfile as sf from tqdm import tqdm @@ -78,8 +80,9 @@ def extract_fbank_features( if output_path is not None and output_path.is_file() and not overwrite: return - _waveform = _convert_to_mono(waveform, sample_rate) - _waveform = _waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers + _waveform = convert_waveform(waveform, sample_rate, to_mono=True) + # Kaldi compliance: 16-bit signed integers + _waveform = _waveform * (2 ** 15) _waveform = _waveform.numpy() features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins) @@ -92,8 +95,7 @@ def extract_fbank_features( if output_path is not None: np.save(output_path.as_posix(), features) - else: - return features + return features def create_zip(data_root: Path, zip_path: Path): @@ -103,42 +105,58 @@ def create_zip(data_root: Path, zip_path: Path): f.write(path, arcname=path.name) -def is_npy_data(data: bytes) -> bool: - return data[0] == 147 and data[1] == 78 - - -def get_zip_manifest(zip_path: Path, zip_root: Optional[Path] = None): - _zip_path = zip_path if zip_root is None else Path.joinpath(zip_root, zip_path) +def get_zip_manifest( + zip_path: Path, zip_root: Optional[Path] = None, is_audio=False +): + _zip_path = Path.joinpath(zip_root or Path(""), zip_path) with zipfile.ZipFile(_zip_path, mode="r") as f: info = f.infolist() - manifest = {} + paths, lengths = {}, {} for i in tqdm(info): utt_id = Path(i.filename).stem offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size - 
manifest[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}" + paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}" with open(_zip_path, "rb") as f: f.seek(offset) - data = f.read(file_size) - assert len(data) > 1 and is_npy_data(data) - return manifest + byte_data = f.read(file_size) + assert len(byte_data) > 1 + if is_audio: + assert is_sf_audio_data(byte_data), i + else: + assert is_npy_data(byte_data), i + byte_data_fp = io.BytesIO(byte_data) + if is_audio: + lengths[utt_id] = sf.info(byte_data_fp).frames + else: + lengths[utt_id] = np.load(byte_data_fp).shape[0] + return paths, lengths def gen_config_yaml( manifest_root: Path, - spm_filename: str, + spm_filename: Optional[str] = None, + vocab_name: Optional[str] = None, yaml_filename: str = "config.yaml", - specaugment_policy: str = "lb", + specaugment_policy: Optional[str] = "lb", prepend_tgt_lang_tag: bool = False, - sampling_alpha: float = 1.0, + sampling_alpha: Optional[float] = None, + input_channels: Optional[int] = 1, + input_feat_per_channel: Optional[int] = 80, audio_root: str = "", cmvn_type: str = "utterance", gcmvn_path: Optional[Path] = None, + extra=None ): manifest_root = manifest_root.absolute() writer = S2TDataConfigWriter(manifest_root / yaml_filename) - writer.set_vocab_filename(spm_filename.replace(".model", ".txt")) - writer.set_input_channels(1) - writer.set_input_feat_per_channel(80) + assert spm_filename is not None or vocab_name is not None + vocab_name = spm_filename.replace(".model", ".txt") if vocab_name is None \ + else vocab_name + writer.set_vocab_filename(vocab_name) + if input_channels is not None: + writer.set_input_channels(input_channels) + if input_feat_per_channel is not None: + writer.set_input_feat_per_channel(input_feat_per_channel) specaugment_setters = { "lb": writer.set_specaugment_lb_policy, "ld": writer.set_specaugment_ld_policy, @@ -148,34 +166,42 @@ def gen_config_yaml( specaugment_setter = specaugment_setters.get(specaugment_policy, None) if 
specaugment_setter is not None: specaugment_setter() - writer.set_bpe_tokenizer( - { - "bpe": "sentencepiece", - "sentencepiece_model": (manifest_root / spm_filename).as_posix(), - } - ) + if spm_filename is not None: + writer.set_bpe_tokenizer( + { + "bpe": "sentencepiece", + "sentencepiece_model": (manifest_root / spm_filename).as_posix(), + } + ) if prepend_tgt_lang_tag: writer.set_prepend_tgt_lang_tag(True) - writer.set_sampling_alpha(sampling_alpha) + if sampling_alpha is not None: + writer.set_sampling_alpha(sampling_alpha) if cmvn_type not in ["global", "utterance"]: raise NotImplementedError - writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"]) + if specaugment_policy is not None: + writer.set_feature_transforms( + "_train", [f"{cmvn_type}_cmvn", "specaugment"] + ) writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"]) if cmvn_type == "global": - assert gcmvn_path is not None, ( - 'Please provide path of global cmvn file.' - ) - writer.set_global_cmvn(str(gcmvn_path)) + if gcmvn_path is None: + raise ValueError("Please provide path of global cmvn file.") + else: + writer.set_global_cmvn(gcmvn_path.as_posix()) if len(audio_root) > 0: writer.set_audio_root(audio_root) + + if extra is not None: + writer.set_extra(extra) writer.flush() -def load_df_from_tsv(path: Union[str, Path]): +def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame: _path = path if isinstance(path, str) else path.as_posix() return pd.read_csv( _path, @@ -201,6 +227,20 @@ def save_df_to_tsv(dataframe, path: Union[str, Path]): ) +def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]: + with open(path, "r") as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + rows = [dict(e) for e in reader] + return rows + + def filter_manifest_df( df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000 ): @@ -337,3 +377,6 @@ def 
set_prepend_tgt_lang_tag(self, flag: bool = True): def set_sampling_alpha(self, sampling_alpha: float = 1.0): self.config["sampling_alpha"] = sampling_alpha + + def set_extra(self, data): + self.config.update(data) diff --git a/examples/speech_to_text/prep_covost_data.py b/examples/speech_to_text/prep_covost_data.py index af1d3fc6b8..411e9b5515 100644 --- a/examples/speech_to_text/prep_covost_data.py +++ b/examples/speech_to_text/prep_covost_data.py @@ -209,7 +209,7 @@ def process(args): print("ZIPing features...") create_zip(feature_root, zip_path) print("Fetching ZIP manifest...") - zip_manifest = get_zip_manifest(zip_path) + audio_paths, audio_lengths = get_zip_manifest(zip_path) # Generate TSV manifest print("Generating manifest...") train_text = [] @@ -219,11 +219,10 @@ def process(args): for split in CoVoST.SPLITS: manifest = {c: [] for c in MANIFEST_COLUMNS} dataset = CoVoST(root, split, args.src_lang, args.tgt_lang) - for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): + for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): manifest["id"].append(utt_id) - manifest["audio"].append(zip_manifest[utt_id]) - duration_ms = int(wav.size(1) / sr * 1000) - manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) + manifest["audio"].append(audio_paths[utt_id]) + manifest["n_frames"].append(audio_lengths[utt_id]) manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt) manifest["speaker"].append(speaker_id) is_train_split = split.startswith("train") @@ -247,7 +246,7 @@ def process(args): # Generate config YAML gen_config_yaml( root, - spm_filename_prefix + ".model", + spm_filename=spm_filename_prefix + ".model", yaml_filename=f"config_{task}.yaml", specaugment_policy="lb", ) diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py index 7b08447190..f379fa7bf1 100644 --- a/examples/speech_to_text/prep_librispeech_data.py +++ 
b/examples/speech_to_text/prep_librispeech_data.py @@ -58,19 +58,18 @@ def process(args): print("ZIPing features...") create_zip(feature_root, zip_path) print("Fetching ZIP manifest...") - zip_manifest = get_zip_manifest(zip_path) + audio_paths, audio_lengths = get_zip_manifest(zip_path) # Generate TSV manifest print("Generating manifest...") train_text = [] for split in SPLITS: manifest = {c: [] for c in MANIFEST_COLUMNS} dataset = LIBRISPEECH(out_root.as_posix(), url=split) - for wav, sample_rate, utt, spk_id, chapter_no, utt_no in tqdm(dataset): + for _, _, utt, spk_id, chapter_no, utt_no in tqdm(dataset): sample_id = f"{spk_id}-{chapter_no}-{utt_no}" manifest["id"].append(sample_id) - manifest["audio"].append(zip_manifest[sample_id]) - duration_ms = int(wav.size(1) / sample_rate * 1000) - manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) + manifest["audio"].append(audio_paths[sample_id]) + manifest["n_frames"].append(audio_lengths[sample_id]) manifest["tgt_text"].append(utt.lower()) manifest["speaker"].append(spk_id) save_df_to_tsv( @@ -92,7 +91,9 @@ def process(args): ) # Generate config YAML gen_config_yaml( - out_root, spm_filename_prefix + ".model", specaugment_policy="ld" + out_root, + spm_filename=spm_filename_prefix + ".model", + specaugment_policy="ld" ) # Clean up shutil.rmtree(feature_root) diff --git a/examples/speech_to_text/prep_mtedx_data.py b/examples/speech_to_text/prep_mtedx_data.py index 34b1c398c8..2dfd631763 100644 --- a/examples/speech_to_text/prep_mtedx_data.py +++ b/examples/speech_to_text/prep_mtedx_data.py @@ -29,13 +29,15 @@ from torch.utils.data import Dataset from tqdm import tqdm -from fairseq.data.audio.audio_utils import get_waveform +from fairseq.data.audio.audio_utils import get_waveform, convert_waveform log = logging.getLogger(__name__) -MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"] +MANIFEST_COLUMNS = [ + "id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang" +] class 
mTEDx(Dataset): @@ -46,9 +48,9 @@ class mTEDx(Dataset): """ SPLITS = ["train", "valid", "test"] - LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar", "de-de", - "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es", "fr-pt", - "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"] + LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar", + "de-de", "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es", + "fr-pt", "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"] def __init__(self, root: str, lang: str, split: str) -> None: assert split in self.SPLITS and lang in self.LANGPAIRS @@ -59,7 +61,9 @@ def __init__(self, root: str, lang: str, split: str) -> None: try: import yaml except ImportError: - print("Please install PyYAML to load the Multilingual TEDx YAML files") + print( + "Please install PyYAML to load the Multilingual TEDx YAML files" + ) with open(txt_root / f"{split}.yaml") as f: segments = yaml.load(f, Loader=yaml.BaseLoader) # Load source and target utterances @@ -95,8 +99,11 @@ def __init__(self, root: str, lang: str, split: str) -> None: ) ) - def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str, str]: - wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = self.data[n] + def __getitem__( + self, n: int + ) -> Tuple[torch.Tensor, int, str, str, str, str, str]: + wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, \ + utt_id = self.data[n] waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) waveform = torch.from_numpy(waveform) return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id @@ -113,36 +120,50 @@ def process(args): print(f"{cur_root.as_posix()} does not exist. 
Skipped.") continue # Extract features - feature_root = cur_root / "fbank80" - feature_root.mkdir(exist_ok=True) + audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80") + audio_root.mkdir(exist_ok=True) for split in mTEDx.SPLITS: print(f"Fetching split {split}...") dataset = mTEDx(root.as_posix(), lang, split) - print("Extracting log mel filter bank features...") - for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset): - extract_fbank_features( - waveform, sample_rate, feature_root / f"{utt_id}.npy" - ) + if args.use_audio_input: + print("Converting audios...") + for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset): + tgt_sample_rate = 16_000 + _wavform, _ = convert_waveform( + waveform, sample_rate, to_mono=True, + to_sample_rate=tgt_sample_rate + ) + sf.write( + (audio_root / f"{utt_id}.flac").as_posix(), + _wavform.numpy(), tgt_sample_rate + ) + else: + print("Extracting log mel filter bank features...") + for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset): + extract_fbank_features( + waveform, sample_rate, audio_root / f"{utt_id}.npy" + ) # Pack features into ZIP - zip_path = cur_root / "fbank80.zip" - print("ZIPing features...") - create_zip(feature_root, zip_path) + zip_path = cur_root / f"{audio_root.name}.zip" + print("ZIPing audios/features...") + create_zip(audio_root, zip_path) print("Fetching ZIP manifest...") - zip_manifest = get_zip_manifest(zip_path) + audio_paths, audio_lengths = get_zip_manifest(zip_path, is_audio=args.use_audio_input) # Generate TSV manifest print("Generating manifest...") train_text = [] for split in mTEDx.SPLITS: is_train_split = split.startswith("train") manifest = {c: [] for c in MANIFEST_COLUMNS} - dataset = mTEDx(args.data_root, lang, split) - for wav, sr, src_utt, tgt_utt, speaker_id, tgt_lang, utt_id in tqdm(dataset): + ds = mTEDx(args.data_root, lang, split) + for _, _, src_utt, tgt_utt, spk_id, tgt_lang, utt_id in tqdm(ds): manifest["id"].append(utt_id) - 
manifest["audio"].append(zip_manifest[utt_id]) - 
duration_ms = int(wav.size(1) / sr * 1000) - manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) - manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt) - manifest["speaker"].append(speaker_id) + manifest["audio"].append(audio_paths[utt_id]) + manifest["n_frames"].append(audio_lengths[utt_id]) + manifest["tgt_text"].append( + src_utt if args.task == "asr" else tgt_utt + ) + manifest["speaker"].append(spk_id) manifest["tgt_lang"].append(tgt_lang) if is_train_split: train_text.extend(manifest["tgt_text"]) @@ -162,14 +183,23 @@ def process(args): args.vocab_size, ) # Generate config YAML - gen_config_yaml( - cur_root, - spm_filename_prefix + ".model", - yaml_filename=f"config_{args.task}.yaml", - specaugment_policy="lb", - ) + if args.use_audio_input: + gen_config_yaml( + cur_root, + spm_filename=spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy=None, + extra={"use_audio_input": True} + ) + else: + gen_config_yaml( + cur_root, + spm_filename=spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="lb", + ) # Clean up - shutil.rmtree(feature_root) + shutil.rmtree(audio_root) def process_joint(args): @@ -188,7 +218,9 @@ def process_joint(args): special_symbols = None if args.joint: # Add tgt_lang tags to dict - special_symbols = list({f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS}) + special_symbols = list( + {f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS} + ) gen_vocab( Path(f.name), cur_root / spm_filename_prefix, @@ -199,7 +231,7 @@ def process_joint(args): # Generate config YAML gen_config_yaml( cur_root, - spm_filename_prefix + ".model", + spm_filename=spm_filename_prefix + ".model", yaml_filename=f"config_{args.task}.yaml", specaugment_policy="ld", prepend_tgt_lang_tag=(args.joint), @@ -226,6 +258,7 @@ def main(): parser.add_argument("--vocab-size", default=8000, type=int) parser.add_argument("--task", type=str, 
choices=["asr", "st"]) parser.add_argument("--joint", action="store_true", help="") + parser.add_argument("--use-audio-input", action="store_true") args = parser.parse_args() if args.joint: diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 0ee204e651..3f0d3fcbd9 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -31,7 +31,7 @@ from torch.utils.data import Dataset from tqdm import tqdm -from fairseq.data.audio.audio_utils import get_waveform +from fairseq.data.audio.audio_utils import get_waveform, convert_waveform log = logging.getLogger(__name__) @@ -92,8 +92,11 @@ def __init__(self, root: str, lang: str, split: str) -> None: ) ) - def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]: - wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] + def __getitem__( + self, n: int + ) -> Tuple[torch.Tensor, int, str, str, str, str]: + wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, \ + utt_id = self.data[n] waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) waveform = torch.from_numpy(waveform) return waveform, sr, src_utt, tgt_utt, spk_id, utt_id @@ -110,40 +113,50 @@ def process(args): print(f"{cur_root.as_posix()} does not exist. 
Skipped.") continue # Extract features - feature_root = cur_root / "fbank80" - feature_root.mkdir(exist_ok=True) + audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80") + audio_root.mkdir(exist_ok=True) + for split in MUSTC.SPLITS: print(f"Fetching split {split}...") dataset = MUSTC(root.as_posix(), lang, split) - print("Extracting log mel filter bank features...") - if split == 'train' and args.cmvn_type == "global": - print("And estimating cepstral mean and variance stats...") + if args.use_audio_input: + print("Converting audios...") + for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): + tgt_sample_rate = 16_000 + _wavform, _ = convert_waveform( + waveform, sample_rate, to_mono=True, + to_sample_rate=tgt_sample_rate + ) + sf.write( + (audio_root / f"{utt_id}.flac").as_posix(), + _wavform.numpy(), tgt_sample_rate + ) + else: + print("Extracting log mel filter bank features...") gcmvn_feature_list = [] + if split == 'train' and args.cmvn_type == "global": + print("And estimating cepstral mean and variance stats...") - for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): - features = extract_fbank_features(waveform, sample_rate) - - np.save( - (feature_root / f"{utt_id}.npy").as_posix(), - features - ) + for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): + features = extract_fbank_features( + waveform, sample_rate, audio_root / f"{utt_id}.npy" + ) + if split == 'train' and args.cmvn_type == "global": + if len(gcmvn_feature_list) < args.gcmvn_max_num: + gcmvn_feature_list.append(features) if split == 'train' and args.cmvn_type == "global": - if len(gcmvn_feature_list) < args.gcmvn_max_num: - gcmvn_feature_list.append(features) - - if split == 'train' and args.cmvn_type == "global": - # Estimate and save cmv - stats = cal_gcmvn_stats(gcmvn_feature_list) - with open(cur_root / "gcmvn.npz", "wb") as f: - np.savez(f, mean=stats["mean"], std=stats["std"]) + # Estimate and save cmv + stats = cal_gcmvn_stats(gcmvn_feature_list) 
+ with open(cur_root / "gcmvn.npz", "wb") as f: + np.savez(f, mean=stats["mean"], std=stats["std"]) # Pack features into ZIP - zip_path = cur_root / "fbank80.zip" - print("ZIPing features...") - create_zip(feature_root, zip_path) + zip_path = cur_root / f"{audio_root.name}.zip" + print("ZIPing audios/features...") + create_zip(audio_root, zip_path) print("Fetching ZIP manifest...") - zip_manifest = get_zip_manifest(zip_path) + audio_paths, audio_lengths = get_zip_manifest(zip_path) # Generate TSV manifest print("Generating manifest...") train_text = [] @@ -151,12 +164,13 @@ def process(args): is_train_split = split.startswith("train") manifest = {c: [] for c in MANIFEST_COLUMNS} dataset = MUSTC(args.data_root, lang, split) - for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): + for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): manifest["id"].append(utt_id) - manifest["audio"].append(zip_manifest[utt_id]) - duration_ms = int(wav.size(1) / sr * 1000) - manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) - manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt) + manifest["audio"].append(audio_paths[utt_id]) + manifest["n_frames"].append(audio_lengths[utt_id]) + manifest["tgt_text"].append( + src_utt if args.task == "asr" else tgt_utt + ) manifest["speaker"].append(speaker_id) if is_train_split: train_text.extend(manifest["tgt_text"]) @@ -176,25 +190,35 @@ def process(args): args.vocab_size, ) # Generate config YAML - gen_config_yaml( - cur_root, - spm_filename_prefix + ".model", - yaml_filename=f"config_{args.task}.yaml", - specaugment_policy="lb", - cmvn_type=args.cmvn_type, - gcmvn_path=( - cur_root / "gcmvn.npz" if args.cmvn_type == "global" - else None - ), - ) + if args.use_audio_input: + gen_config_yaml( + cur_root, + spm_filename=spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy=None, + extra={"use_audio_input": True} + ) + else: + gen_config_yaml( + 
cur_root, + spm_filename=spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="lb", + cmvn_type=args.cmvn_type, + gcmvn_path=( + cur_root / "gcmvn.npz" if args.cmvn_type == "global" + else None + ), + ) # Clean up - shutil.rmtree(feature_root) + shutil.rmtree(audio_root) def process_joint(args): cur_root = Path(args.data_root) - assert all((cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES), \ - "do not have downloaded data available for all 8 languages" + assert all( + (cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES + ), "do not have downloaded data available for all 8 languages" # Generate vocab vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}" @@ -217,7 +241,7 @@ def process_joint(args): # Generate config YAML gen_config_yaml( cur_root, - spm_filename_prefix + ".model", + spm_filename=spm_filename_prefix + ".model", yaml_filename=f"config_{args.task}.yaml", specaugment_policy="ld", prepend_tgt_lang_tag=(args.task == "st"), @@ -244,14 +268,17 @@ def main(): parser.add_argument("--vocab-size", default=8000, type=int) parser.add_argument("--task", type=str, choices=["asr", "st"]) parser.add_argument("--joint", action="store_true", help="") - parser.add_argument("--cmvn-type", default="utterance", - choices=["global", "utterance"], - help="The type of cepstral mean and variance normalization") - parser.add_argument("--gcmvn-max-num", default=150000, type=int, - help=( - "Maximum number of sentences to use to estimate" - "global mean and variance" - )) + parser.add_argument( + "--cmvn-type", default="utterance", + choices=["global", "utterance"], + help="The type of cepstral mean and variance normalization" + ) + parser.add_argument( + "--gcmvn-max-num", default=150000, type=int, + help="Maximum number of sentences to use to estimate global mean and " + "variance" + ) + 
parser.add_argument("--use-audio-input", action="store_true") args = parser.parse_args() if args.joint: diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index 7c2638dc0c..b9444cb8d0 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -1,65 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + from pathlib import Path from typing import BinaryIO, Optional, Tuple, Union, List import numpy as np import torch +import torch.nn.functional as F SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"} FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"} -def update_sample_rate( - waveform: np.ndarray, - sample_rate: int, - tgt_sample_rate: int, -) -> np.ndarray: - if tgt_sample_rate > 0 and tgt_sample_rate != sample_rate: - _waveform = torch.from_numpy(waveform) - effects = [["rate", f"{tgt_sample_rate}"]] - return _sox_convert(_waveform, sample_rate, effects).numpy() - return waveform - - -def _sox_convert( - waveform: torch.FloatTensor, - sample_rate: int, - effects: List[List[str]], -) -> torch.FloatTensor: +def convert_waveform( + waveform: Union[np.ndarray, torch.Tensor], sample_rate: int, + normalize_volume: bool = False, to_mono: bool = False, + to_sample_rate: Optional[int] = None +) -> Tuple[Union[np.ndarray, torch.Tensor], int]: + """convert a waveform: + - to a target sample rate + - from multi-channel to mono channel + - volume normalization + + Args: + waveform (numpy.ndarray or torch.Tensor): 2D original waveform + (channels x length) + sample_rate (int): original sample rate + normalize_volume (bool): perform volume normalization + to_mono (bool): convert to mono channel if having multiple channels + to_sample_rate (Optional[int]): target sample rate + Returns: + waveform (numpy.ndarray): converted 2D waveform (channels x length) + 
sample_rate (float): target sample rate + """ try: import torchaudio.sox_effects as ta_sox except ImportError: - raise ImportError("Please install torchaudio to convert audios") - return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0] - - -def convert_to_mono(waveform: np.ndarray, sample_rate: int) -> np.ndarray: - if waveform.shape[0] > 1: - _waveform = torch.from_numpy(waveform) - effects = [["channels", "1"]] - return _sox_convert(_waveform, sample_rate, effects).numpy() - return waveform + raise ImportError("Please install torchaudio: pip install torchaudio") + + effects = [] + if normalize_volume: + effects.append(["gain", "-n"]) + if to_sample_rate is not None and to_sample_rate != sample_rate: + effects.append(["rate", f"{to_sample_rate}"]) + if to_mono and waveform.shape[0] > 1: + effects.append(["channels", "1"]) + if len(effects) > 0: + is_np_input = isinstance(waveform, np.ndarray) + _waveform = torch.from_numpy(waveform) if is_np_input else waveform + converted, converted_sample_rate = ta_sox.apply_effects_tensor( + _waveform, sample_rate, effects + ) + if is_np_input: + converted = converted.numpy() + return converted, converted_sample_rate + return waveform, sample_rate def get_waveform( - path_or_fp: Union[str, BinaryIO], - normalization=True, - mono=True, - frames=-1, - start=0, - always_2d=True, - output_sample_rate=-1, + path_or_fp: Union[str, BinaryIO], normalization: bool = True, + mono: bool = True, frames: int = -1, start: int = 0, + always_2d: bool = True, output_sample_rate: Optional[int] = None, + normalize_volume: bool = False ) -> Tuple[np.ndarray, int]: """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. 
Args: path_or_fp (str or BinaryIO): the path or file-like object - normalization (bool): Normalize values to [-1, 1] (Default: True) + normalization (bool): normalize values to [-1, 1] (Default: True) mono (bool): convert multi-channel audio to mono-channel one frames (int): the number of frames to read. (-1 for reading all) start (int): Where to start reading. A negative value counts from the end. always_2d (bool): always return 2D array even for mono-channel audios - output_sample_rate (int): output sample rate, -1 using default + output_sample_rate (Optional[int]): output sample rate + normalize_volume (bool): normalize volume Returns: waveform (numpy.ndarray): 1D or 2D waveform (channels x length) sample_rate (float): sample rate @@ -72,17 +90,17 @@ def get_waveform( try: import soundfile as sf except ImportError: - raise ImportError("Please install soundfile to load WAV/FLAC/OGG Vorbis audios") + raise ImportError("Please install soundfile: pip install soundfile") waveform, sample_rate = sf.read( path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start ) waveform = waveform.T # T x C -> C x T - if mono and waveform.shape[0] > 1: - waveform = convert_to_mono(waveform, sample_rate) - if output_sample_rate > 0: - waveform = update_sample_rate(waveform, sample_rate, output_sample_rate) - sample_rate = output_sample_rate + waveform, sample_rate = convert_waveform( + waveform, sample_rate, normalize_volume=normalize_volume, to_mono=mono, + to_sample_rate=output_sample_rate + ) + if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers if not always_2d: @@ -190,3 +208,73 @@ def parse_path(path: str) -> Tuple[str, List[int]]: assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}" slice_ptr = [int(i) for i in slice_ptr] return _path, slice_ptr + + +def get_window( + window_fn: callable, n_fft: int, win_length: int +) -> torch.Tensor: + padding = n_fft - win_length + assert padding >= 0 + return F.pad(window_fn(win_length), 
(padding // 2, padding - padding // 2)) + + +def get_fourier_basis(n_fft: int) -> torch.Tensor: + basis = np.fft.fft(np.eye(n_fft)) + basis = np.vstack( + [np.real(basis[:n_fft // 2 + 1, :]), np.imag(basis[:n_fft // 2 + 1, :])] + ) + return torch.from_numpy(basis).float() + + +def get_mel_filters( + sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float +) -> torch.Tensor: + try: + import librosa + except ImportError: + raise ImportError("Please install librosa: pip install librosa") + basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max) + return torch.from_numpy(basis).float() + + +class TTSSpectrogram(torch.nn.Module): + def __init__( + self, n_fft: int, win_length: int, hop_length: int, + window_fn: callable = torch.hann_window, return_phase: bool = False + ) -> None: + super(TTSSpectrogram, self).__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.return_phase = return_phase + + basis = get_fourier_basis(n_fft).unsqueeze(1) + basis *= get_window(window_fn, n_fft, win_length) + self.register_buffer('basis', basis) + + def forward( + self, waveform: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + padding = (self.n_fft // 2, self.n_fft // 2) + x = F.pad(waveform.unsqueeze(1), padding, mode='reflect') + x = F.conv1d(x, self.basis, stride=self.hop_length) + real_part = x[:, :self.n_fft // 2 + 1, :] + imag_part = x[:, self.n_fft // 2 + 1:, :] + magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + if self.return_phase: + phase = torch.atan2(imag_part, real_part) + return magnitude, phase + return magnitude + + +class TTSMelScale(torch.nn.Module): + def __init__( + self, n_mels: int, sample_rate: int, f_min: float, f_max: float, + n_stft: int + ) -> None: + super(TTSMelScale, self).__init__() + basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, + f_max) + self.register_buffer('basis', basis) + + def forward(self, specgram: torch.Tensor) -> torch.Tensor: + return 
torch.matmul(self.basis, specgram) diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py new file mode 100644 index 0000000000..95b403ad9c --- /dev/null +++ b/fairseq/data/audio/data_cfg.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +from typing import Dict, Optional + + +class S2TDataConfig(object): + """Wrapper class for data config YAML""" + + def __init__(self, yaml_path: Path): + try: + import yaml + except ImportError: + print("Please install PyYAML: pip install PyYAML") + self.config = {} + if yaml_path.is_file(): + try: + with open(yaml_path) as f: + self.config = yaml.load(f, Loader=yaml.FullLoader) + except Exception as e: + raise Exception( + f"Failed to load config from {yaml_path.as_posix()}: {e}" + ) + else: + raise FileNotFoundError(f"{yaml_path.as_posix()} not found") + self.root = yaml_path.parent + + def _auto_convert_to_abs_path(self, x): + if isinstance(x, str): + if not Path(x).exists() and (self.root / x).exists(): + return (self.root / x).as_posix() + elif isinstance(x, dict): + return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()} + return x + + @property + def vocab_filename(self): + """fairseq vocabulary file under data root""" + return self.config.get("vocab_filename", "dict.txt") + + @property + def speaker_set_filename(self): + """speaker set file under data root""" + return self.config.get("speaker_set_filename", None) + + @property + def shuffle(self) -> bool: + """Shuffle dataset samples before batching""" + return self.config.get("shuffle", False) + + @property + def pre_tokenizer(self) -> Dict: + """Pre-tokenizer to apply before subword tokenization. Returning + a dictionary with `tokenizer` providing the tokenizer name and + the other items providing the tokenizer-specific arguments.
+ Tokenizers are defined in `fairseq.data.encoders.*`""" + tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None}) + return self._auto_convert_to_abs_path(tokenizer) + + @property + def bpe_tokenizer(self) -> Dict: + """Subword tokenizer to apply after pre-tokenization. Returning + a dictionary with `bpe` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + tokenizer = self.config.get("bpe_tokenizer", {"bpe": None}) + return self._auto_convert_to_abs_path(tokenizer) + + @property + def prepend_tgt_lang_tag(self) -> bool: + """Prepend target lang ID token as the target BOS (e.g. for to-many + multilingual setting). During inference, this requires `--prefix-size 1` + to force BOS to be lang ID token.""" + return self.config.get("prepend_tgt_lang_tag", False) + + @property + def input_feat_per_channel(self): + """The dimension of input features (per audio channel)""" + return self.config.get("input_feat_per_channel", 80) + + @property + def input_channels(self): + """The number of channels in the input audio""" + return self.config.get("input_channels", 1) + + @property + def sample_rate(self): + return self.config.get("sample_rate", 16_000) + + @property + def sampling_alpha(self): + """Hyper-parameter alpha = 1/T for temperature-based resampling. + (alpha = 1 for no resampling)""" + return self.config.get("sampling_alpha", 1.0) + + @property + def use_audio_input(self): + """Needed by the dataset loader to see if the model requires + raw audio as inputs.""" + return self.config.get("use_audio_input", False) + + @property + def use_sample_rate(self): + """Needed by the dataset loader to see if the model requires + raw audio with specific sample rate as inputs.""" + return self.config.get("use_sample_rate", 16000) + + @property + def audio_root(self): + """Audio paths in the manifest TSV can be relative and this provides + the root path. 
Set this to empty string when using absolute paths.""" + return self.config.get("audio_root", "") + + def get_feature_transforms(self, split, is_train): + """Split-specific feature transforms. Allowing train set + wildcard `_train`, evaluation set wildcard `_eval` and general + wildcard `*` for matching.""" + from copy import deepcopy + + cfg = deepcopy(self.config) + _cur = cfg.get("transforms", {}) + cur = _cur.get(split) + cur = _cur.get("_train") if cur is None and is_train else cur + cur = _cur.get("_eval") if cur is None and not is_train else cur + cur = _cur.get("*") if cur is None else cur + cfg["transforms"] = cur + return cfg + + @property + def global_cmvn_stats_npz(self) -> Optional[str]: + path = self.config.get("global_cmvn", {}).get("stats_npz_path", None) + return self._auto_convert_to_abs_path(path) + + @property + def vocoder(self) -> Optional[Dict[str, str]]: + return self.config.get("vocoder", None) diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index ba6c28632e..164bf413e4 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -9,7 +9,8 @@ import re from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, NamedTuple +from typing import Dict, List, Optional +from dataclasses import dataclass import numpy as np import torch @@ -30,113 +31,12 @@ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS, ) from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform +from fairseq.data.audio.data_cfg import S2TDataConfig logger = logging.getLogger(__name__) -class S2TDataConfig(object): - """Wrapper class for data config YAML""" - - def __init__(self, yaml_path: Path): - try: - import yaml - except ImportError: - print("Please install PyYAML to load YAML files for S2T data config") - self.config = {} - if yaml_path.is_file(): - try: - with open(yaml_path) as f: - self.config = yaml.load(f, 
Loader=yaml.FullLoader) - except Exception as e: - raise Exception( - f"Failed to load config from {yaml_path.as_posix()}: {e}" - ) - else: - raise FileNotFoundError(f"{yaml_path.as_posix()} not found") - - @property - def vocab_filename(self): - """fairseq vocabulary file under data root""" - return self.config.get("vocab_filename", "dict.txt") - - @property - def shuffle(self) -> bool: - """Shuffle dataset samples before batching""" - return self.config.get("shuffle", False) - - @property - def pre_tokenizer(self) -> Dict: - """Pre-tokenizer to apply before subword tokenization. Returning - a dictionary with `tokenizer` providing the tokenizer name and - the other items providing the tokenizer-specific arguments. - Tokenizers are defined in `fairseq.data.encoders.*`""" - return self.config.get("pre_tokenizer", {"tokenizer": None}) - - @property - def bpe_tokenizer(self) -> Dict: - """Subword tokenizer to apply after pre-tokenization. Returning - a dictionary with `bpe` providing the tokenizer name and - the other items providing the tokenizer-specific arguments. - Tokenizers are defined in `fairseq.data.encoders.*`""" - return self.config.get("bpe_tokenizer", {"bpe": None}) - - @property - def prepend_tgt_lang_tag(self) -> bool: - """Prepend target lang ID token as the target BOS (e.g. for to-many - multilingual setting). During inference, this requires `--prefix-size 1` - to force BOS to be lang ID token.""" - return self.config.get("prepend_tgt_lang_tag", False) - - @property - def input_feat_per_channel(self): - """The dimension of input features (per audio channel)""" - return self.config.get("input_feat_per_channel", 80) - - @property - def input_channels(self): - """The number of channels in the input audio""" - return self.config.get("input_channels", 1) - - @property - def sampling_alpha(self): - """Hyper-parameter alpha = 1/T for temperature-based resampling. 
- (alpha = 1 for no resampling)""" - return self.config.get("sampling_alpha", 1.0) - - @property - def use_audio_input(self): - """Needed by the dataset loader to see if the model requires - raw audio as inputs.""" - return self.config.get("use_audio_input", False) - - @property - def use_sample_rate(self): - """Needed by the dataset loader to see if the model requires - raw audio with specific sample rate as inputs.""" - return self.config.get("use_sample_rate", 16000) - - @property - def audio_root(self): - """Audio paths in the manifest TSV can be relative and this provides - the root path. Set this to empty string when using absolute paths.""" - return self.config.get("audio_root", "") - - def get_feature_transforms(self, split, is_train): - """Split-specific feature transforms. Allowing train set wildcard `_train`, - evaluation set wildcard `_eval` and general wildcard `*` for matching.""" - from copy import deepcopy - - cfg = deepcopy(self.config) - _cur = cfg.get("transforms", {}) - cur = _cur.get(split) - cur = _cur.get("_train") if cur is None and is_train else cur - cur = _cur.get("_eval") if cur is None and not is_train else cur - cur = _cur.get("*") if cur is None else cur - cfg["transforms"] = cur - return cfg - - def get_features_from_npy_or_audio(path): ext = Path(path).suffix if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: @@ -145,11 +45,7 @@ def get_features_from_npy_or_audio(path): def get_features_or_waveform_from_stored_zip( - path, - byte_offset, - byte_size, - need_waveform=False, - use_sample_rate=-1, + path, byte_offset, byte_size, need_waveform=False, use_sample_rate=None, ): assert path.endswith(".zip") data = read_from_stored_zip(path, byte_offset, byte_size) @@ -157,17 +53,18 @@ def get_features_or_waveform_from_stored_zip( if is_npy_data(data): features_or_waveform = np.load(f) elif is_sf_audio_data(data): - features_or_waveform = ( - get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0] - if need_waveform - else 
get_fbank(f) - ) + features_or_waveform = \ + get_waveform( + f, always_2d=False, output_sample_rate=use_sample_rate + )[0] if need_waveform else get_fbank(f) else: raise ValueError(f'Unknown file format for "{path}"') return features_or_waveform -def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=-1): +def get_features_or_waveform( + path: str, need_waveform=False, use_sample_rate=None +): """Get speech features from .npy file or waveform from .wav/.flac file. The file may be inside an uncompressed ZIP file and is accessed via byte offset and length. @@ -190,11 +87,8 @@ def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=-1) return get_features_from_npy_or_audio(_path) elif len(slice_ptr) == 2: features_or_waveform = get_features_or_waveform_from_stored_zip( - _path, - slice_ptr[0], - slice_ptr[1], - need_waveform=need_waveform, - use_sample_rate=use_sample_rate, + _path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform, + use_sample_rate=use_sample_rate ) else: raise ValueError(f"Invalid path: {path}") @@ -223,10 +117,12 @@ def _collate_frames( return out -class SpeechToTextDatasetItem(NamedTuple): +@dataclass +class SpeechToTextDatasetItem(object): index: int source: torch.Tensor target: Optional[torch.Tensor] = None + speaker_id: Optional[int] = None class SpeechToTextDataset(FairseqDataset): @@ -248,6 +144,8 @@ def __init__( tgt_dict: Optional[Dictionary] = None, pre_tokenizer=None, bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None ): self.split, self.is_train_split = split, is_train_split self.cfg = cfg @@ -265,6 +163,7 @@ def __init__( ) self.src_texts, self.tgt_texts = src_texts, tgt_texts self.src_langs, self.tgt_langs = src_langs, tgt_langs + self.speakers = speakers self.tgt_dict = tgt_dict self.check_tgt_lang_tag() self.ids = ids @@ -276,6 +175,8 @@ def __init__( self.pre_tokenizer = pre_tokenizer self.bpe_tokenizer = bpe_tokenizer + self.n_frames_per_step = n_frames_per_step + 
self.speaker_to_id = speaker_to_id self.tgt_lens = self.get_tgt_lens_and_check_oov() @@ -302,9 +203,10 @@ def get_tgt_lens_and_check_oov(self): def __repr__(self): return ( self.__class__.__name__ - + f'(split="{self.split}", n_samples={self.n_samples}, ' + + f'(split="{self.split}", n_samples={self.n_samples:_}, ' f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, " - f"shuffle={self.shuffle}, transforms={self.feature_transforms})" + f"shuffle={self.shuffle}, transforms={self.feature_transforms}, " + f"n_frames_per_step={self.n_frames_per_step})" ) @classmethod @@ -329,6 +231,13 @@ def get_tokenized_tgt_text(self, index: int): text = self.tokenize(self.bpe_tokenizer, text) return text + def pack_frames(self, feature: torch.Tensor): + if self.n_frames_per_step == 1: + return feature + n_packed_frames = feature.shape[0] // self.n_frames_per_step + feature = feature[:self.n_frames_per_step * n_packed_frames] + return feature.reshape(n_packed_frames, -1) + @classmethod def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang)) @@ -345,6 +254,7 @@ def __getitem__(self, index: int) -> SpeechToTextDatasetItem: assert not self.cfg.use_audio_input source = self.feature_transforms(source) source = torch.from_numpy(source).float() + source = self.pack_frames(source) target = None if self.tgt_texts is not None: @@ -358,7 +268,12 @@ def __getitem__(self, index: int) -> SpeechToTextDatasetItem: ) target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0) - return SpeechToTextDatasetItem(index=index, source=source, target=target) + speaker_id = None + if self.speaker_to_id is not None: + speaker_id = self.speaker_to_id[self.speakers[index]] + return SpeechToTextDatasetItem( + index=index, source=source, target=target, speaker_id=speaker_id + ) def __len__(self): return self.n_samples @@ -371,7 +286,7 @@ def collater( indices = torch.tensor([x.index for x in samples], dtype=torch.long) frames =
_collate_frames([x.source for x in samples], self.cfg.use_audio_input) # sort samples by descending number of frames - n_frames = torch.tensor([x.source.size()[0] for x in samples], dtype=torch.long) + n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long) n_frames, order = n_frames.sort(descending=True) indices = indices.index_select(0, order) frames = frames.index_select(0, order) @@ -389,7 +304,7 @@ def collater( ) target = target.index_select(0, order) target_lengths = torch.tensor( - [x.target.size()[0] for x in samples], dtype=torch.long + [x.target.size(0) for x in samples], dtype=torch.long ).index_select(0, order) prev_output_tokens = fairseq_data_utils.collate_tokens( [x.target for x in samples], @@ -399,7 +314,13 @@ def collater( move_eos_to_beginning=True, ) prev_output_tokens = prev_output_tokens.index_select(0, order) - ntokens = sum(x.target.size()[0] for x in samples) + ntokens = sum(x.target.size(0) for x in samples) + + speaker = None + if self.speaker_to_id is not None: + speaker = torch.tensor( + [s.speaker_id for s in samples], dtype=torch.long + ).index_select(0, order).view(-1, 1) net_input = { "src_tokens": frames, @@ -409,6 +330,7 @@ def collater( out = { "id": indices, "net_input": net_input, + "speaker": speaker, "target": target, "target_lengths": target_lengths, "ntokens": ntokens, @@ -465,6 +387,8 @@ def _from_list( tgt_dict, pre_tokenizer, bpe_tokenizer, + n_frames_per_step, + speaker_to_id ) -> SpeechToTextDataset: audio_root = Path(cfg.audio_root) ids = [s[cls.KEY_ID] for s in samples] @@ -490,6 +414,8 @@ def _from_list( tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer, bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id ) @classmethod @@ -554,10 +480,13 @@ def _from_tsv( is_train_split: bool, pre_tokenizer, bpe_tokenizer, + n_frames_per_step, + speaker_to_id ) -> SpeechToTextDataset: samples = cls._load_samples_from_tsv(root, split) return cls._from_list( - split, 
is_train_split, samples, cfg, tgt_dict, pre_tokenizer, bpe_tokenizer + split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer, + bpe_tokenizer, n_frames_per_step, speaker_to_id ) @classmethod @@ -572,10 +501,13 @@ def from_tsv( is_train_split: bool, epoch: int, seed: int, + n_frames_per_step: int = 1, + speaker_to_id=None ) -> SpeechToTextDataset: datasets = [ cls._from_tsv( - root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer + root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, + bpe_tokenizer, n_frames_per_step, speaker_to_id ) for split in splits.split(",") ] diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index c9ea52493d..7e887e8864 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -402,7 +402,7 @@ def cli_main(): parser = options.get_generation_parser() # TODO: replace this workaround with refactoring of `AudioPretraining` parser.add_argument( - '--arch', '-a', metavar='ARCH', default="transformer", + '--arch', '-a', metavar='ARCH', default="wav2vec2", help='Model architecture. For constructing tasks that rely on ' 'model args (e.g. `AudioPretraining`)' ) From e679327497702f52e4c6e5c2ab29b4d576c44ec4 Mon Sep 17 00:00:00 2001 From: "Yuan Shangguan (June)" <yuansg@fb.com> Date: Mon, 13 Sep 2021 13:19:35 -0700 Subject: [PATCH 699/707] Back out "Fairseq needs to store and load metadata from model state_dict" (#3861) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3861 Back out the fairseq changes. Fix with a suggested, more optimal change in checkpoint utils.
Reviewed By: zhengwy888 Differential Revision: D30886481 fbshipit-source-id: 12b6dd4d5107ab4371b73a58d9a044a17c733260 --- fairseq/checkpoint_utils.py | 37 +++------------------------------ fairseq/models/fairseq_model.py | 32 ---------------------------- fairseq/trainer.py | 9 +------- 3 files changed, 4 insertions(+), 74 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 7a494356ac..ef5d4c9022 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss, save_metadata=False): +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): from fairseq import meters # only one worker should attempt to create the required dir @@ -114,7 +114,7 @@ def is_better(a, b): os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: - trainer.save_checkpoint(checkpoints[0], extra_state, save_metadata) + trainer.save_checkpoint(checkpoints[0], extra_state) for cp in checkpoints[1:]: if cfg.write_checkpoints_asynchronously: # TODO[ioPath]: Need to implement a delayed asynchronous @@ -455,35 +455,8 @@ def load_model_ensemble_and_task( else: # model parallel checkpoint or unsharded checkpoint model = task.build_model(cfg.model) - new_state_model = state["model"] - - '''=====The following if-else statement is a work-around ===== - # the current metadata loading/saving of pytorch. - # In Pytorch, if state["model"]["_metadata"] exists as dictionary, then model.load_state_dict(strict=True) - # will throw an error for unexpected "_metadata" key. To avoid this error, we need the state_dict to be - # in orderedDict format, which has new_state_model._metadata attribute but not as key. - # TODO yuansg@ This issue should be fixed in pytorch ideally. 
- ''' - if new_state_model.get("_metadata", None) is not None: - new_metadata = new_state_model.get("_metadata", None) - del state["model"]["_metadata"] - else: - new_metadata = None - # Construct state dict content. - contents = OrderedDict(new_state_model) - # We explicitly set _metadata for the state models. The _metadata is implicitly stored for pytorch models. - # calling state["model"] in fairseq will not invoke metadata storage. - if new_metadata is None: - logger.warning("===Jit: state[\"model\"] does not contain key \"_metadata\"=====") - logger.warning("===Jit: we will be filling in with current model's meta-data instead.") - # For models trained before this diff, we do the following to be backward compatible. - contents.__setattr__("_metadata", model.state_dict()._metadata) - else: - contents.__setattr__("_metadata", new_metadata) - '''====End of work-around logic=====''' - model.load_state_dict( - contents, strict=strict, model_cfg=cfg.model + state["model"], strict=strict, model_cfg=cfg.model ) # reset state so it gets loaded for the next model in ensemble @@ -710,7 +683,6 @@ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): It's called by functions that load models from checkpoints and does not need to be called directly. """ - state_meta_data = state_dict.get("_metadata", None) arch = None if model_cfg is not None: arch = ( @@ -790,9 +762,6 @@ def create_pruning_pass(layers_to_keep, layer_name): if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None - # Ensure metadata is stored. 
- if state_meta_data is not None: - new_state_dict["_metadata"] = state_meta_data return new_state_dict diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 0645208efe..e55c7ba1ad 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -22,7 +22,6 @@ from fairseq.models import FairseqDecoder, FairseqEncoder from omegaconf import DictConfig from torch import Tensor -from collections import OrderedDict logger = logging.getLogger(__name__) @@ -123,15 +122,6 @@ def load_state_dict( from fairseq.checkpoint_utils import prune_state_dict new_state_dict = prune_state_dict(state_dict, model_cfg) - # The pytorch assumption of module is that it is an OrderedDict. - # Pytorch also assumes module._metadata exists in the state_dict, - # not as dictionary keys, rather as an attribute of the state dict. - new_state_dict = OrderedDict(new_state_dict) - metadata = new_state_dict.get("_metadata", None) - - if metadata: - del new_state_dict["_metadata"] - new_state_dict.__setattr__("_metadata", metadata) return super().load_state_dict(new_state_dict, strict) def upgrade_state_dict(self, state_dict): @@ -161,28 +151,6 @@ def do_upgrade(m, prefix): do_upgrade(self, name) - def update_metadata(self, model_meta): - """ The model.state_dict()._metadata is stored in a collective location in - state_dict["model"]["_metadata"]. - A pytorch module's _metadata contains the torch modules' versions, which is important - for versionsetting functions. - - During model loading time, we load the model state_dict, but we don't load the state_dict metadata. - This function helps to update the model according to the state_dict["model"]["_metadata"] dump. - InputArgs: - update_metadata: Dict; key is module names, value is {"version", 1} or other metadata. - """ - # Do nothing if the model level metadata is empty. 
- if model_meta is None: - return - assert isinstance(model_meta, Dict), \ - "Input model_meta from state_dict should be a dictionary. Check state dict." - for key, val in model_meta.items(): - if key is None: # First level set up - self._metadata = val - else: # Subsequent levels of the model - self.get_submodule(key)._metadata = val - def set_num_updates(self, num_updates): """State from trainer to pass along to model at every update.""" for m in self.modules(): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index a10ec97aa7..e46ccfe0b8 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -424,18 +424,12 @@ def state_dict(self): state_dict["fsdp_metadata"] = self.model.local_metadata_dict() return state_dict - def save_checkpoint(self, filename, extra_state, save_metadata=False): + def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" logger.info(f"Saving checkpoint to {filename}") # call state_dict on all ranks in case it needs internal communication state_dict = utils.move_to_cpu(self.state_dict()) state_dict["extra_state"].update(extra_state) - # This should be added because model versions are stored as metadata. - if save_metadata and getattr(self.model.state_dict(), "_metadata", None) is not None: - logger.warning("Trainer: _metadata is inside model.state_dict(). ") - state_dict["model"]["_metadata"] = self.model.state_dict()._metadata - else: - logger.warning("Trainer: _metadata is not saved inside model.state_dict(). 
") if self.should_save_checkpoint_on_current_rank: checkpoint_utils.torch_persistent_save( state_dict, @@ -508,7 +502,6 @@ def load_checkpoint( self.model.load_state_dict( state["model"], strict=True, model_cfg=self.cfg.model ) - self.model.update_metadata(getattr(state["model"], "_metadata", None)) # save memory for later steps del state["model"] if utils.has_parameters(self.get_criterion()): From 32b31173aa30e9b1555c4048917e8aa9f6227e18 Mon Sep 17 00:00:00 2001 From: "Yuan Shangguan (June)" <yuansg@fb.com> Date: Mon, 13 Sep 2021 13:19:35 -0700 Subject: [PATCH 700/707] Allow attributes of OrderedDict to be saved (#3862) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3862 We resolved a bug for missing "_metadata" attribute for pytorch models during checkpoing saving and loading using forced state["model"]["_metadata"], but it's not an efficient solution due to expensive model.state_dict() invocation. This diff offers an alternative solution. Reviewed By: zhengwy888 Differential Revision: D30857147 fbshipit-source-id: 5daa978e2a558ad4159e2da55470253950151bde --- fairseq/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fairseq/utils.py b/fairseq/utils.py index d1ec9a274c..623ff87a8c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -17,6 +17,7 @@ import torch import torch.nn.functional as F from torch import Tensor +import collections if TYPE_CHECKING: from fairseq.modules.multihead_attention import MultiheadAttention @@ -82,6 +83,11 @@ def apply_to_sample(f, sample): def _apply(x): if torch.is_tensor(x): return f(x) + elif isinstance(x, collections.OrderedDict): + # OrderedDict has attributes that needs to be preserved + od = collections.OrderedDict((key, _apply(value)) for key, value in x.items()) + od.__dict__ = x.__dict__ + return od elif isinstance(x, dict): return {key: _apply(value) for key, value in x.items()} elif isinstance(x, list): From 0ac3f3270c90e6d62284272b28ce076f61fb14eb Mon Sep 17 00:00:00 2001 From: 
Changhan Wang <changhan@fb.com> Date: Mon, 13 Sep 2021 18:12:38 -0700 Subject: [PATCH 701/707] add TTS Summary: [fairseq-py] add TTS Reviewed By: wnhsu Differential Revision: D30720666 fbshipit-source-id: b5288acec72bea1d3a9f3884a4ed51b616c7a403 --- examples/speech_synthesis/data_utils.py | 320 ++++++++++++ .../speech_synthesis/generate_waveform.py | 191 +++++++ examples/speech_synthesis/utils.py | 101 ++++ fairseq/criterions/fastspeech2_loss.py | 125 +++++ fairseq/criterions/tacotron2_loss.py | 210 ++++++++ .../data/audio/frm_text_to_speech_dataset.py | 207 ++++++++ fairseq/data/audio/text_to_speech_dataset.py | 215 ++++++++ fairseq/data/dictionary.py | 3 +- fairseq/models/speech_to_text/__init__.py | 4 +- fairseq/models/speech_to_text/utils.py | 1 - fairseq/models/text_to_speech/__init__.py | 8 + fairseq/models/text_to_speech/fastspeech2.py | 352 +++++++++++++ fairseq/models/text_to_speech/hifigan.py | 173 +++++++ fairseq/models/text_to_speech/tacotron2.py | 350 +++++++++++++ .../models/text_to_speech/tts_transformer.py | 371 ++++++++++++++ fairseq/models/text_to_speech/vocoder.py | 197 ++++++++ fairseq/modules/__init__.py | 4 + fairseq/modules/location_attention.py | 72 +++ fairseq/modules/lstm_cell_with_zoneout.py | 37 ++ .../optim/lr_scheduler/step_lr_scheduler.py | 86 ++++ fairseq/options.py | 18 + fairseq/speech_generator.py | 219 ++++++++ fairseq/tasks/frm_text_to_speech.py | 56 +++ fairseq/tasks/speech_to_text.py | 21 +- fairseq/tasks/text_to_speech.py | 467 ++++++++++++++++++ setup.py | 1 + 26 files changed, 3801 insertions(+), 8 deletions(-) create mode 100644 examples/speech_synthesis/data_utils.py create mode 100644 examples/speech_synthesis/generate_waveform.py create mode 100644 examples/speech_synthesis/utils.py create mode 100644 fairseq/criterions/fastspeech2_loss.py create mode 100644 fairseq/criterions/tacotron2_loss.py create mode 100644 fairseq/data/audio/frm_text_to_speech_dataset.py create mode 100644 
fairseq/data/audio/text_to_speech_dataset.py create mode 100644 fairseq/models/text_to_speech/__init__.py create mode 100644 fairseq/models/text_to_speech/fastspeech2.py create mode 100644 fairseq/models/text_to_speech/hifigan.py create mode 100644 fairseq/models/text_to_speech/tacotron2.py create mode 100644 fairseq/models/text_to_speech/tts_transformer.py create mode 100644 fairseq/models/text_to_speech/vocoder.py create mode 100644 fairseq/modules/location_attention.py create mode 100644 fairseq/modules/lstm_cell_with_zoneout.py create mode 100644 fairseq/optim/lr_scheduler/step_lr_scheduler.py create mode 100644 fairseq/speech_generator.py create mode 100644 fairseq/tasks/frm_text_to_speech.py create mode 100644 fairseq/tasks/text_to_speech.py diff --git a/examples/speech_synthesis/data_utils.py b/examples/speech_synthesis/data_utils.py new file mode 100644 index 0000000000..f43a4a9004 --- /dev/null +++ b/examples/speech_synthesis/data_utils.py @@ -0,0 +1,320 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +from pathlib import Path +from typing import Optional, List, Dict +import zipfile +import tempfile +from dataclasses import dataclass +from itertools import groupby + +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm + +from examples.speech_to_text.data_utils import load_tsv_to_dicts +from fairseq.data.audio.audio_utils import TTSSpectrogram, TTSMelScale + + +def trim_or_pad_to_target_length( + data_1d_or_2d: np.ndarray, target_length: int +) -> np.ndarray: + assert len(data_1d_or_2d.shape) in {1, 2} + delta = data_1d_or_2d.shape[0] - target_length + if delta >= 0: # trim if being longer + data_1d_or_2d = data_1d_or_2d[: target_length] + else: # pad if being shorter + if len(data_1d_or_2d.shape) == 1: + data_1d_or_2d = np.concatenate( + [data_1d_or_2d, np.zeros(-delta)], axis=0 + ) + else: + data_1d_or_2d = np.concatenate( + [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))], + axis=0 + ) + return data_1d_or_2d + + +def extract_logmel_spectrogram( + waveform: torch.Tensor, sample_rate: int, + output_path: Optional[Path] = None, win_length: int = 1024, + hop_length: int = 256, n_fft: int = 1024, + win_fn: callable = torch.hann_window, n_mels: int = 80, + f_min: float = 0., f_max: float = 8000, eps: float = 1e-5, + overwrite: bool = False, target_length: Optional[int] = None +): + if output_path is not None and output_path.is_file() and not overwrite: + return + + spectrogram_transform = TTSSpectrogram( + n_fft=n_fft, win_length=win_length, hop_length=hop_length, + window_fn=win_fn + ) + mel_scale_transform = TTSMelScale( + n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, + n_stft=n_fft // 2 + 1 + ) + spectrogram = spectrogram_transform(waveform) + mel_spec = mel_scale_transform(spectrogram) + logmel_spec = torch.clamp(mel_spec, min=eps).log() + assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1 + logmel_spec = logmel_spec.squeeze().t() # D x T -> T x D + if target_length is 
not None: + logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length) + + if output_path is not None: + np.save(output_path.as_posix(), logmel_spec) + else: + return logmel_spec + + +def extract_pitch( + waveform: torch.Tensor, sample_rate: int, + output_path: Optional[Path] = None, hop_length: int = 256, + log_scale: bool = True, phoneme_durations: Optional[List[int]] = None +): + if output_path is not None and output_path.is_file(): + return + + try: + import pyworld + except ImportError: + raise ImportError("Please install PyWORLD: pip install pyworld") + + _waveform = waveform.squeeze(0).double().numpy() + pitch, t = pyworld.dio( + _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000 + ) + pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate) + + if phoneme_durations is not None: + pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations)) + try: + from scipy.interpolate import interp1d + except ImportError: + raise ImportError("Please install SciPy: pip install scipy") + nonzero_ids = np.where(pitch != 0)[0] + interp_fn = interp1d( + nonzero_ids, + pitch[nonzero_ids], + fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), + bounds_error=False, + ) + pitch = interp_fn(np.arange(0, len(pitch))) + d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations])) + pitch = np.array( + [ + np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]]) + for i in range(1, len(d_cumsum)) + ] + ) + assert len(pitch) == len(phoneme_durations) + + if log_scale: + pitch = np.log(pitch + 1) + + if output_path is not None: + np.save(output_path.as_posix(), pitch) + else: + return pitch + + +def extract_energy( + waveform: torch.Tensor, output_path: Optional[Path] = None, + hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True, + phoneme_durations: Optional[List[int]] = None +): + if output_path is not None and output_path.is_file(): + return + + assert len(waveform.shape) == 2 and waveform.shape[0] == 1 + waveform = waveform.view(1, 1, 
waveform.shape[1]) + waveform = F.pad( + waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0], + mode="reflect" + ) + waveform = waveform.squeeze(1) + + fourier_basis = np.fft.fft(np.eye(n_fft)) + cutoff = int((n_fft / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + forward_transform = F.conv1d( + waveform, forward_basis, stride=hop_length, padding=0 + ) + + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + energy = torch.norm(magnitude, dim=1).squeeze(0).numpy() + + if phoneme_durations is not None: + energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations)) + d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations])) + energy = np.array( + [ + np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]]) + for i in range(1, len(d_cumsum)) + ] + ) + assert len(energy) == len(phoneme_durations) + + if log_scale: + energy = np.log(energy + 1) + + if output_path is not None: + np.save(output_path.as_posix(), energy) + else: + return energy + + +def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None): + mean_x, mean_x2, n_frames = None, None, 0 + feature_paths = feature_root.glob("*.npy") + for p in tqdm(feature_paths): + with open(p, 'rb') as f: + frames = np.load(f).squeeze() + + n_frames += frames.shape[0] + + cur_mean_x = frames.sum(axis=0) + if mean_x is None: + mean_x = cur_mean_x + else: + mean_x += cur_mean_x + + cur_mean_x2 = (frames ** 2).sum(axis=0) + if mean_x2 is None: + mean_x2 = cur_mean_x2 + else: + mean_x2 += cur_mean_x2 + + mean_x /= n_frames + mean_x2 /= n_frames + var_x = mean_x2 - mean_x ** 2 + std_x = np.sqrt(np.maximum(var_x, 1e-10)) + + if output_path is not None: + with open(output_path, 'wb') as f: + np.savez(f, mean=mean_x, std=std_x) + else: + return {"mean": mean_x, "std": 
std_x} + + +def ipa_phonemize(text, lang="en-us", use_g2p=False): + if use_g2p: + assert lang == "en-us", "g2pE phonemizer only works for en-us" + try: + from g2p_en import G2p + g2p = G2p() + return " ".join("|" if p == " " else p for p in g2p(text)) + except ImportError: + raise ImportError( + "Please install g2p_en: pip install g2p_en" + ) + else: + try: + from phonemizer import phonemize + from phonemizer.separator import Separator + return phonemize( + text, backend='espeak', language=lang, + separator=Separator(word="| ", phone=" ") + ) + except ImportError: + raise ImportError( + "Please install phonemizer: pip install phonemizer" + ) + + +@dataclass +class ForceAlignmentInfo(object): + tokens: List[str] + frame_durations: List[int] + start_sec: Optional[float] + end_sec: Optional[float] + + +def get_mfa_alignment_by_sample_id( + textgrid_zip_path: str, sample_id: str, sample_rate: int, + hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn") +) -> ForceAlignmentInfo: + try: + import tgt + except ImportError: + raise ImportError("Please install TextGridTools: pip install tgt") + + filename = f"{sample_id}.TextGrid" + out_root = Path(tempfile.gettempdir()) + tgt_path = out_root / filename + with zipfile.ZipFile(textgrid_zip_path) as f_zip: + f_zip.extract(filename, path=out_root) + textgrid = tgt.io.read_textgrid(tgt_path.as_posix()) + os.remove(tgt_path) + + phones, frame_durations = [], [] + start_sec, end_sec, end_idx = 0, 0, 0 + for t in textgrid.get_tier_by_name("phones")._objects: + s, e, p = t.start_time, t.end_time, t.text + # Trim leading silences + if len(phones) == 0: + if p in silence_phones: + continue + else: + start_sec = s + phones.append(p) + if p not in silence_phones: + end_sec = e + end_idx = len(phones) + r = sample_rate / hop_length + frame_durations.append(int(np.round(e * r) - np.round(s * r))) + # Trim trailing silences + phones = phones[:end_idx] + frame_durations = frame_durations[:end_idx] + + return 
ForceAlignmentInfo( + tokens=phones, frame_durations=frame_durations, start_sec=start_sec, + end_sec=end_sec + ) + + +def get_mfa_alignment( + textgrid_zip_path: str, sample_ids: List[str], sample_rate: int, + hop_length: int +) -> Dict[str, ForceAlignmentInfo]: + return { + i: get_mfa_alignment_by_sample_id( + textgrid_zip_path, i, sample_rate, hop_length + ) for i in tqdm(sample_ids) + } + + +def get_unit_alignment( + id_to_unit_tsv_path: str, sample_ids: List[str] +) -> Dict[str, ForceAlignmentInfo]: + id_to_units = { + e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path) + } + id_to_units = {i: id_to_units[i].split() for i in sample_ids} + id_to_units_collapsed = { + i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items() + } + id_to_durations = { + i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items() + } + + return { + i: ForceAlignmentInfo( + tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i], + start_sec=None, end_sec=None + ) + for i in sample_ids + } diff --git a/examples/speech_synthesis/generate_waveform.py b/examples/speech_synthesis/generate_waveform.py new file mode 100644 index 0000000000..bfc2ef8eb3 --- /dev/null +++ b/examples/speech_synthesis/generate_waveform.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import logging +import matplotlib.pyplot as plt +import numpy as np +from pathlib import Path +import soundfile as sf +import sys +import torch +import torchaudio + +from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.logging import progress_bar +from fairseq.tasks.text_to_speech import plot_tts_output +from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset + + +logging.basicConfig() +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def make_parser(): + parser = options.get_speech_generation_parser() + parser.add_argument("--dump-features", action="store_true") + parser.add_argument("--dump-waveforms", action="store_true") + parser.add_argument("--dump-attentions", action="store_true") + parser.add_argument("--dump-eos-probs", action="store_true") + parser.add_argument("--dump-plots", action="store_true") + parser.add_argument("--dump-target", action="store_true") + parser.add_argument("--output-sample-rate", default=22050, type=int) + parser.add_argument("--teacher-forcing", action="store_true") + parser.add_argument( + "--audio-format", type=str, default="wav", choices=["wav", "flac"] + ) + return parser + + +def postprocess_results( + dataset: TextToSpeechDataset, sample, hypos, resample_fn, dump_target +): + def to_np(x): + return None if x is None else x.detach().cpu().numpy() + + sample_ids = [dataset.ids[i] for i in sample["id"].tolist()] + texts = sample["src_texts"] + attns = [to_np(hypo["attn"]) for hypo in hypos] + eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos] + feat_preds = [to_np(hypo["feature"]) for hypo in hypos] + wave_preds = [to_np(resample_fn(h["waveform"])) for h in hypos] + if dump_target: + feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos] + wave_targs = [to_np(resample_fn(h["targ_waveform"])) for h in hypos] + else: + feat_targs = [None for _ in hypos] + wave_targs = [None for _ in hypos] + + 
return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds, + feat_targs, wave_targs) + + +def dump_result( + is_na_model, + args, + vocoder, + sample_id, + text, + attn, + eos_prob, + feat_pred, + wave_pred, + feat_targ, + wave_targ, +): + sample_rate = args.output_sample_rate + out_root = Path(args.results_path) + if args.dump_features: + feat_dir = out_root / "feat" + feat_dir.mkdir(exist_ok=True, parents=True) + np.save(feat_dir / f"{sample_id}.npy", feat_pred) + if args.dump_target: + feat_tgt_dir = out_root / "feat_tgt" + feat_tgt_dir.mkdir(exist_ok=True, parents=True) + np.save(feat_tgt_dir / f"{sample_id}.npy", feat_targ) + if args.dump_attentions: + attn_dir = out_root / "attn" + attn_dir.mkdir(exist_ok=True, parents=True) + np.save(attn_dir / f"{sample_id}.npy", attn) + if args.dump_eos_probs and not is_na_model: + eos_dir = out_root / "eos" + eos_dir.mkdir(exist_ok=True, parents=True) + np.save(eos_dir / f"{sample_id}.npy", eos_prob) + + if args.dump_plots: + images = [feat_pred.T] if is_na_model else [feat_pred.T, attn] + names = ["output"] if is_na_model else ["output", "alignment"] + if feat_targ is not None: + images = [feat_targ.T] + images + names = [f"target (idx={sample_id})"] + names + if is_na_model: + plot_tts_output(images, names, attn, "alignment", suptitle=text) + else: + plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text) + plot_dir = out_root / "plot" + plot_dir.mkdir(exist_ok=True, parents=True) + plt.savefig(plot_dir / f"{sample_id}.png") + plt.close() + + if args.dump_waveforms: + ext = args.audio_format + if wave_pred is not None: + wav_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}" + wav_dir.mkdir(exist_ok=True, parents=True) + sf.write(wav_dir / f"{sample_id}.{ext}", wave_pred, sample_rate) + if args.dump_target and wave_targ is not None: + wav_tgt_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}_tgt" + wav_tgt_dir.mkdir(exist_ok=True, parents=True) + sf.write(wav_tgt_dir / 
f"{sample_id}.{ext}", wave_targ, sample_rate) + + +def main(args): + assert(args.dump_features or args.dump_waveforms or args.dump_attentions + or args.dump_eos_probs or args.dump_plots) + if args.max_tokens is None and args.batch_size is None: + args.max_tokens = 8000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + task = tasks.setup_task(args) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [args.path], + task=task, + ) + model = models[0].cuda() if use_cuda else models[0] + # use the original n_frames_per_step + task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step + task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task) + + data_cfg = task.data_cfg + sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050) + resample_fn = { + False: lambda x: x, + True: lambda x: torchaudio.sox_effects.apply_effects_tensor( + x.detach().cpu().unsqueeze(0), sample_rate, + [['rate', str(args.output_sample_rate)]] + )[0].squeeze(0) + }.get(args.output_sample_rate != sample_rate) + if args.output_sample_rate != sample_rate: + logger.info(f"resampling to {args.output_sample_rate}Hz") + + generator = task.build_generator([model], args) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.batch_size, + max_positions=(sys.maxsize, sys.maxsize), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, + ).next_epoch_itr(shuffle=False) + + Path(args.results_path).mkdir(exist_ok=True, parents=True) + is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False) + dataset = task.dataset(args.gen_subset) + vocoder = task.args.vocoder + with progress_bar.build_progress_bar(args, itr) as t: + for sample in t: + sample = 
utils.move_to_cuda(sample) if use_cuda else sample + hypos = generator.generate(model, sample, has_targ=args.dump_target) + for result in postprocess_results( + dataset, sample, hypos, resample_fn, args.dump_target + ): + dump_result(is_na_model, args, vocoder, *result) + + +def cli_main(): + parser = make_parser() + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/examples/speech_synthesis/utils.py b/examples/speech_synthesis/utils.py new file mode 100644 index 0000000000..2c7b03733d --- /dev/null +++ b/examples/speech_synthesis/utils.py @@ -0,0 +1,101 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from scipy.interpolate import interp1d +import torchaudio + +from fairseq.tasks.text_to_speech import ( + batch_compute_distortion, compute_rms_dist +) + + +def batch_mel_spectral_distortion( + y1, y2, sr, normalize_type="path", mel_fn=None +): + """ + https://arxiv.org/pdf/2011.03568.pdf + + Same as Mel Cepstral Distortion, but computed on log-mel spectrograms. 
+ """ + if mel_fn is None or mel_fn.sample_rate != sr: + mel_fn = torchaudio.transforms.MelSpectrogram( + sr, n_fft=int(0.05 * sr), win_length=int(0.05 * sr), + hop_length=int(0.0125 * sr), f_min=20, n_mels=80, + window_fn=torch.hann_window + ).to(y1[0].device) + offset = 1e-6 + return batch_compute_distortion( + y1, y2, sr, lambda y: torch.log(mel_fn(y) + offset).transpose(-1, -2), + compute_rms_dist, normalize_type + ) + + +# This code is based on +# "https://github.com/bastibe/MAPS-Scripts/blob/master/helper.py" +def _same_t_in_true_and_est(func): + def new_func(true_t, true_f, est_t, est_f): + assert type(true_t) is np.ndarray + assert type(true_f) is np.ndarray + assert type(est_t) is np.ndarray + assert type(est_f) is np.ndarray + + interpolated_f = interp1d( + est_t, est_f, bounds_error=False, kind='nearest', fill_value=0 + )(true_t) + return func(true_t, true_f, true_t, interpolated_f) + + return new_func + + +@_same_t_in_true_and_est +def gross_pitch_error(true_t, true_f, est_t, est_f): + """The relative frequency in percent of pitch estimates that are + outside a threshold around the true pitch. Only frames that are + considered pitched by both the ground truth and the estimator (if + applicable) are considered. 
+ """ + + correct_frames = _true_voiced_frames(true_t, true_f, est_t, est_f) + gross_pitch_error_frames = _gross_pitch_error_frames( + true_t, true_f, est_t, est_f + ) + return np.sum(gross_pitch_error_frames) / np.sum(correct_frames) + + +def _gross_pitch_error_frames(true_t, true_f, est_t, est_f, eps=1e-8): + voiced_frames = _true_voiced_frames(true_t, true_f, est_t, est_f) + true_f_p_eps = [x + eps for x in true_f] + pitch_error_frames = np.abs(est_f / true_f_p_eps - 1) > 0.2 + return voiced_frames & pitch_error_frames + + +def _true_voiced_frames(true_t, true_f, est_t, est_f): + return (est_f != 0) & (true_f != 0) + + +def _voicing_decision_error_frames(true_t, true_f, est_t, est_f): + return (est_f != 0) != (true_f != 0) + + +@_same_t_in_true_and_est +def f0_frame_error(true_t, true_f, est_t, est_f): + gross_pitch_error_frames = _gross_pitch_error_frames( + true_t, true_f, est_t, est_f + ) + voicing_decision_error_frames = _voicing_decision_error_frames( + true_t, true_f, est_t, est_f + ) + return (np.sum(gross_pitch_error_frames) + + np.sum(voicing_decision_error_frames)) / (len(true_t)) + + +@_same_t_in_true_and_est +def voicing_decision_error(true_t, true_f, est_t, est_f): + voicing_decision_error_frames = _voicing_decision_error_frames( + true_t, true_f, est_t, est_f + ) + return np.sum(voicing_decision_error_frames) / (len(true_t)) diff --git a/fairseq/criterions/fastspeech2_loss.py b/fairseq/criterions/fastspeech2_loss.py new file mode 100644 index 0000000000..085d5628d4 --- /dev/null +++ b/fairseq/criterions/fastspeech2_loss.py @@ -0,0 +1,125 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +from typing import List, Dict, Any +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.data.data_utils import lengths_to_mask +from fairseq.models.fairseq_model import FairseqEncoderModel + + +@dataclass +class FastSpeech2CriterionConfig(FairseqDataclass): + ctc_weight: float = field( + default=0.0, metadata={"help": "weight for CTC loss"} + ) + + +@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig) +class FastSpeech2Loss(FairseqCriterion): + def __init__(self, task, ctc_weight): + super().__init__(task) + self.ctc_weight = ctc_weight + + def forward(self, model: FairseqEncoderModel, sample, reduction="mean"): + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + tgt_lens = sample["target_lengths"] + _feat_out, _, log_dur_out, pitch_out, energy_out = model( + src_tokens=src_tokens, + src_lengths=src_lens, + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], + durations=sample["durations"], + pitches=sample["pitches"], + energies=sample["energies"] + ) + + src_mask = lengths_to_mask(sample["net_input"]["src_lengths"]) + tgt_mask = lengths_to_mask(sample["target_lengths"]) + + pitches, energies = sample["pitches"], sample["energies"] + pitch_out, pitches = pitch_out[src_mask], pitches[src_mask] + energy_out, energies = energy_out[src_mask], energies[src_mask] + + feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask] + l1_loss = F.l1_loss(feat_out, feat, reduction=reduction) + + pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction) + energy_loss = F.mse_loss(energy_out, energies, reduction=reduction) + + log_dur_out = log_dur_out[src_mask] + dur = sample["durations"].float() + dur = 
dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur + log_dur = torch.log(dur + 1)[src_mask] + dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction) + + ctc_loss = torch.tensor(0.).type_as(l1_loss) + if self.ctc_weight > 0.: + lprobs = model.get_normalized_probs((_feat_out,), log_probs=True) + lprobs = lprobs.transpose(0, 1) # T x B x C + src_mask = lengths_to_mask(src_lens) + src_tokens_flat = src_tokens.masked_select(src_mask) + ctc_loss = F.ctc_loss( + lprobs, src_tokens_flat, tgt_lens, src_lens, + reduction=reduction, zero_infinity=True + ) * self.ctc_weight + + loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss + + sample_size = sample["nsentences"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "dur_loss": utils.item(dur_loss.data), + "pitch_loss": utils.item(pitch_loss.data), + "energy_loss": utils.item(energy_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + } + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + ns = [log.get("sample_size", 0) for log in logging_outputs] + ntot = sum(ns) + ws = [n / (ntot + 1e-8) for n in ns] + for key in [ + "loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss", + "ctc_loss" + ]: + vals = [log.get(key, 0) for log in logging_outputs] + val = sum(val * w for val, w in zip(vals, ws)) + metrics.log_scalar(key, val, ntot, round=3) + metrics.log_scalar("sample_size", ntot, len(logging_outputs)) + + # inference metrics + if "targ_frames" not in logging_outputs[0]: + return + n = sum(log.get("targ_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, 
val / n, n, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return False diff --git a/fairseq/criterions/tacotron2_loss.py b/fairseq/criterions/tacotron2_loss.py new file mode 100644 index 0000000000..8c7b655c8c --- /dev/null +++ b/fairseq/criterions/tacotron2_loss.py @@ -0,0 +1,210 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +from typing import Any, Dict, List +from functools import lru_cache +from dataclasses import dataclass, field + +import torch +from omegaconf import II + +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.data.data_utils import lengths_to_mask +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +@dataclass +class Tacotron2CriterionConfig(FairseqDataclass): + bce_pos_weight: float = field( + default=1.0, + metadata={"help": "weight of positive examples for BCE loss"}, + ) + n_frames_per_step: int = field( + default=0, + metadata={"help": "Number of frames per decoding step"}, + ) + use_guided_attention_loss: bool = field( + default=False, + metadata={"help": "use guided attention loss"}, + ) + guided_attention_loss_sigma: float = field( + default=0.4, + metadata={"help": "sigma for guided attention loss"}, + ) + ctc_weight: float = field( + default=0.0, metadata={"help": "weight for CTC loss"} + ) + sentence_avg: bool = II("optimization.sentence_avg") + + +class GuidedAttentionLoss(torch.nn.Module): + """ + Efficiently Trainable Text-to-Speech System Based on Deep Convolutional + Networks with Guided Attention (https://arxiv.org/abs/1710.08969) + """ + + def __init__(self, sigma): + super().__init__() +
self.sigma = sigma + + @staticmethod + @lru_cache(maxsize=8) + def _get_weight(s_len, t_len, sigma): + grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len)) + grid_x = grid_x.to(s_len.device) + grid_y = grid_y.to(s_len.device) + w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2 + return 1.0 - torch.exp(-w / (2 * (sigma ** 2))) + + def _get_weights(self, src_lens, tgt_lens): + bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens) + weights = torch.zeros((bsz, max_t_len, max_s_len)) + for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)): + weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, + self.sigma) + return weights + + @staticmethod + def _get_masks(src_lens, tgt_lens): + in_masks = lengths_to_mask(src_lens) + out_masks = lengths_to_mask(tgt_lens) + return out_masks.unsqueeze(2) & in_masks.unsqueeze(1) + + def forward(self, attn, src_lens, tgt_lens, reduction="mean"): + weights = self._get_weights(src_lens, tgt_lens).to(attn.device) + masks = self._get_masks(src_lens, tgt_lens).to(attn.device) + loss = (weights * attn.transpose(1, 2)).masked_select(masks) + loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss) + return loss + + +@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig) +class Tacotron2Criterion(FairseqCriterion): + def __init__(self, task, sentence_avg, n_frames_per_step, + use_guided_attention_loss, guided_attention_loss_sigma, + bce_pos_weight, ctc_weight): + super().__init__(task) + self.sentence_avg = sentence_avg + self.n_frames_per_step = n_frames_per_step + self.bce_pos_weight = bce_pos_weight + + self.guided_attn = None + if use_guided_attention_loss: + self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma) + self.ctc_weight = ctc_weight + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + 
eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + tgt_lens = sample["target_lengths"] + + feat_out, eos_out, extra = model( + src_tokens=src_tokens, + src_lengths=src_lens, + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"] + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt, + tgt_lens, reduction, + ) + attn_loss = torch.tensor(0.).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction) + ctc_loss = torch.tensor(0.).type_as(l1_loss) + if self.ctc_weight > 0.: + net_output = (feat_out, eos_out, extra) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.transpose(0, 1) # T x B x C + src_mask = lengths_to_mask(src_lens) + src_tokens_flat = src_tokens.masked_select(src_mask) + ctc_loss = F.ctc_loss( + lprobs, src_tokens_flat, tgt_lens, src_lens, + reduction=reduction, zero_infinity=True + ) * self.ctc_weight + loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss + + sample_size = sample["nsentences"] if self.sentence_avg \ + else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + } + return loss, sample_size, logging_output + + def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt, + eos_tgt, tgt_lens, reduction="mean"): + mask = lengths_to_mask(tgt_lens) + 
_eos_out = eos_out[mask].squeeze() + _eos_tgt = eos_tgt[mask] + _feat_tgt = feat_tgt[mask] + _feat_out = feat_out[mask] + _feat_out_post = feat_out_post[mask] + + l1_loss = ( + F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + + F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction) + ) + mse_loss = ( + F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + + F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction) + ) + eos_loss = F.binary_cross_entropy_with_logits( + _eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight), + reduction=reduction + ) + return l1_loss, mse_loss, eos_loss + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + ns = [log.get("sample_size", 0) for log in logging_outputs] + ntot = sum(ns) + ws = [n / (ntot + 1e-8) for n in ns] + for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]: + vals = [log.get(key, 0) for log in logging_outputs] + val = sum(val * w for val, w in zip(vals, ws)) + metrics.log_scalar(key, val, ntot, round=3) + metrics.log_scalar("sample_size", ntot, len(logging_outputs)) + + # inference metrics + if "targ_frames" not in logging_outputs[0]: + return + n = sum(log.get("targ_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return False diff --git a/fairseq/data/audio/frm_text_to_speech_dataset.py b/fairseq/data/audio/frm_text_to_speech_dataset.py new file mode 100644 index 0000000000..125b1fc0c0 --- /dev/null +++ b/fairseq/data/audio/frm_text_to_speech_dataset.py @@ -0,0 +1,207 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import csv +import logging +import os.path as op +from typing import List, Optional + +import numpy as np +import torch +from fairseq.data import Dictionary +from fairseq.data.audio.speech_to_text_dataset import ( + S2TDataConfig +) +from fairseq.data.audio.text_to_speech_dataset import ( + TextToSpeechDataset, TextToSpeechDatasetCreator +) + +logger = logging.getLogger(__name__) + + +class FrmTextToSpeechDataset(TextToSpeechDataset): + def __init__( + self, + split: str, + is_train_split: bool, + data_cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + do_chunk=False, + chunk_bound=-1, + chunk_init=50, + chunk_incr=5, + add_eos=True, + dedup=True, + ref_fpu=-1 + ): + # It assumes texts are encoded at a fixed frame-rate + super().__init__( + split=split, + is_train_split=is_train_split, + data_cfg=data_cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id + ) + + self.do_chunk = do_chunk + self.chunk_bound = chunk_bound + self.chunk_init = chunk_init + self.chunk_incr = chunk_incr + self.add_eos = add_eos + self.dedup = dedup + self.ref_fpu = ref_fpu + + self.chunk_size = -1 + + if do_chunk:
+ assert self.chunk_incr >= 0 + assert self.pre_tokenizer is None + + def __getitem__(self, index): + index, source, target, speaker_id, _, _, _ = super().__getitem__(index) + if target[-1].item() == self.tgt_dict.eos_index: + target = target[:-1] + + fpu = source.size(0) / target.size(0) # frame-per-unit + fps = self.n_frames_per_step + assert ( + self.ref_fpu == -1 or + abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1 + ), f"{fpu*fps} != {self.ref_fpu}" + + # only chunk training split + if self.is_train_split and self.do_chunk and self.chunk_size > 0: + lang = target[:int(self.data_cfg.prepend_tgt_lang_tag)] + text = target[int(self.data_cfg.prepend_tgt_lang_tag):] + size = len(text) + chunk_size = min(self.chunk_size, size) + chunk_start = np.random.randint(size - chunk_size + 1) + text = text[chunk_start:chunk_start+chunk_size] + target = torch.cat((lang, text), 0) + + f_size = int(np.floor(chunk_size * fpu)) + f_start = int(np.floor(chunk_start * fpu)) + assert(f_size > 0) + source = source[f_start:f_start+f_size, :] + + if self.dedup: + target = torch.unique_consecutive(target) + + if self.add_eos: + eos_idx = self.tgt_dict.eos_index + target = torch.cat((target, torch.LongTensor([eos_idx])), 0) + + return index, source, target, speaker_id + + def set_epoch(self, epoch): + if self.is_train_split and self.do_chunk: + old = self.chunk_size + self.chunk_size = self.chunk_init + epoch * self.chunk_incr + if self.chunk_bound > 0: + self.chunk_size = min(self.chunk_size, self.chunk_bound) + logger.info(( + f"{self.split}: setting chunk size " + f"from {old} to {self.chunk_size}" + )) + + +class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator): + # inherit for key names + @classmethod + def from_tsv( + cls, + root: str, + data_cfg: S2TDataConfig, + split: str, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split: bool, + n_frames_per_step: int, + speaker_to_id, + do_chunk: bool = False, + chunk_bound: int = -1, + chunk_init: int = 50, + 
chunk_incr: int = 5, + add_eos: bool = True, + dedup: bool = True, + ref_fpu: float = -1 + ) -> FrmTextToSpeechDataset: + tsv_path = op.join(root, f"{split}.tsv") + if not op.isfile(tsv_path): + raise FileNotFoundError(f"Dataset not found: {tsv_path}") + with open(tsv_path) as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + s = [dict(e) for e in reader] + assert len(s) > 0 + + ids = [ss[cls.KEY_ID] for ss in s] + audio_paths = [ + op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s + ] + n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s] + tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s] + src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s] + speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s] + src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s] + tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s] + + return FrmTextToSpeechDataset( + split=split, + is_train_split=is_train_split, + data_cfg=data_cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + do_chunk=do_chunk, + chunk_bound=chunk_bound, + chunk_init=chunk_init, + chunk_incr=chunk_incr, + add_eos=add_eos, + dedup=dedup, + ref_fpu=ref_fpu + ) diff --git a/fairseq/data/audio/text_to_speech_dataset.py b/fairseq/data/audio/text_to_speech_dataset.py new file mode 100644 index 0000000000..abfcb2be40 --- /dev/null +++ b/fairseq/data/audio/text_to_speech_dataset.py @@ -0,0 +1,215 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from pathlib import Path +from typing import List, Dict, Optional, Any +from dataclasses import dataclass + +import numpy as np +import torch + +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig, + _collate_frames, get_features_or_waveform +) +from fairseq.data import Dictionary, data_utils as fairseq_data_utils + + +@dataclass +class TextToSpeechDatasetItem(object): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + speaker_id: Optional[int] = None + duration: Optional[torch.Tensor] = None + pitch: Optional[torch.Tensor] = None + energy: Optional[torch.Tensor] = None + + +class TextToSpeechDataset(SpeechToTextDataset): + def __init__( + self, + split: str, + is_train_split: bool, + cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + durations: Optional[List[List[int]]] = None, + pitches: Optional[List[str]] = None, + energies: Optional[List[str]] = None + ): + super(TextToSpeechDataset, self).__init__( + split, is_train_split, cfg, audio_paths, n_frames, + src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers, + src_langs=src_langs, tgt_langs=tgt_langs, ids=ids, + tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id + ) + self.durations = durations + self.pitches =
pitches + self.energies = energies + + def __getitem__(self, index: int) -> TextToSpeechDatasetItem: + s2t_item = super().__getitem__(index) + + duration, pitch, energy = None, None, None + if self.durations is not None: + duration = torch.tensor( + self.durations[index] + [0], dtype=torch.long # pad 0 for EOS + ) + if self.pitches is not None: + pitch = get_features_or_waveform(self.pitches[index]) + pitch = torch.from_numpy( + np.concatenate((pitch, [0])) # pad 0 for EOS + ).float() + if self.energies is not None: + energy = get_features_or_waveform(self.energies[index]) + energy = torch.from_numpy( + np.concatenate((energy, [0])) # pad 0 for EOS + ).float() + return TextToSpeechDatasetItem( + index=index, source=s2t_item.source, target=s2t_item.target, + speaker_id=s2t_item.speaker_id, duration=duration, pitch=pitch, + energy=energy + ) + + def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]: + if len(samples) == 0: + return {} + + src_lengths, order = torch.tensor( + [s.target.shape[0] for s in samples], dtype=torch.long + ).sort(descending=True) + id_ = torch.tensor([s.index for s in samples], + dtype=torch.long).index_select(0, order) + feat = _collate_frames( + [s.source for s in samples], self.cfg.use_audio_input + ).index_select(0, order) + target_lengths = torch.tensor( + [s.source.shape[0] for s in samples], dtype=torch.long + ).index_select(0, order) + + src_tokens = fairseq_data_utils.collate_tokens( + [s.target for s in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ).index_select(0, order) + + speaker = None + if self.speaker_to_id is not None: + speaker = torch.tensor( + [s.speaker_id for s in samples], dtype=torch.long + ).index_select(0, order).view(-1, 1) + + bsz, _, d = feat.size() + prev_output_tokens = torch.cat( + (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1 + ) + + durations, pitches, energies = None, None, None + if self.durations is not None: + 
durations = fairseq_data_utils.collate_tokens( + [s.duration for s in samples], 0 + ).index_select(0, order) + assert src_tokens.shape[1] == durations.shape[1] + if self.pitches is not None: + pitches = _collate_frames([s.pitch for s in samples], True) + pitches = pitches.index_select(0, order) + assert src_tokens.shape[1] == pitches.shape[1] + if self.energies is not None: + energies = _collate_frames([s.energy for s in samples], True) + energies = energies.index_select(0, order) + assert src_tokens.shape[1] == energies.shape[1] + src_texts = [self.tgt_dict.string(samples[i].target) for i in order] + + return { + "id": id_, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + "prev_output_tokens": prev_output_tokens, + }, + "speaker": speaker, + "target": feat, + "durations": durations, + "pitches": pitches, + "energies": energies, + "target_lengths": target_lengths, + "ntokens": sum(target_lengths).item(), + "nsentences": len(samples), + "src_texts": src_texts, + } + + +class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator): + KEY_DURATION = "duration" + KEY_PITCH = "pitch" + KEY_ENERGY = "energy" + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + cfg: S2TDataConfig, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id + ) -> TextToSpeechDataset: + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + + durations = [s.get(cls.KEY_DURATION, None) 
for s in samples] + durations = [ + None if dd is None else [int(d) for d in dd.split(" ")] + for dd in durations + ] + durations = None if any(dd is None for dd in durations) else durations + + pitches = [s.get(cls.KEY_PITCH, None) for s in samples] + pitches = [ + None if pp is None else (audio_root / pp).as_posix() + for pp in pitches + ] + pitches = None if any(pp is None for pp in pitches) else pitches + + energies = [s.get(cls.KEY_ENERGY, None) for s in samples] + energies = [ + None if ee is None else (audio_root / ee).as_posix() + for ee in energies] + energies = None if any(ee is None for ee in energies) else energies + + return TextToSpeechDataset( + split_name, is_train_split, cfg, audio_paths, n_frames, + src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict, + pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, + durations, pitches, energies + ) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index d3ef0f9896..d6495389f0 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -92,7 +92,8 @@ def string( ) extra_symbols_to_ignore = set(extra_symbols_to_ignore or []) - extra_symbols_to_ignore.add(self.eos()) + if not include_eos: + extra_symbols_to_ignore.add(self.eos()) def token_string(i): if i == self.unk(): diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index cac365cbb8..1c5189c0f7 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -5,5 +5,5 @@ from .berard import * # noqa from .convtransformer import * # noqa -from .s2t_transformer import * # noqa -from .xm_transformer import * # noqa +from .s2t_transformer import * # noqa +from .xm_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/utils.py b/fairseq/models/speech_to_text/utils.py index 573f8537c9..168b8bf13b 100644 --- a/fairseq/models/speech_to_text/utils.py +++ b/fairseq/models/speech_to_text/utils.py @@ 
-1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) 2017-present, Facebook, Inc. # All rights reserved. # diff --git a/fairseq/models/text_to_speech/__init__.py b/fairseq/models/text_to_speech/__init__.py new file mode 100644 index 0000000000..652fee0d68 --- /dev/null +++ b/fairseq/models/text_to_speech/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .tacotron2 import * # noqa +from .tts_transformer import * # noqa +from .fastspeech2 import * # noqa diff --git a/fairseq/models/text_to_speech/fastspeech2.py b/fairseq/models/text_to_speech/fastspeech2.py new file mode 100644 index 0000000000..9c38d0917d --- /dev/null +++ b/fairseq/models/text_to_speech/fastspeech2.py @@ -0,0 +1,352 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch +from torch import nn + +from fairseq.models import (FairseqEncoder, FairseqEncoderModel, register_model, + register_model_architecture) +from fairseq.modules import ( + LayerNorm, PositionalEmbedding, FairseqDropout, MultiheadAttention +) +from fairseq import utils +from fairseq.data.data_utils import lengths_to_padding_mask + + +logger = logging.getLogger(__name__) + + +def model_init(m): + if isinstance(m, nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu")) + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + return m + + +class PositionwiseFeedForward(nn.Module): + def __init__(self, in_dim, hidden_dim, kernel_size, dropout): + super().__init__() + self.ffn = nn.Sequential( + nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size, + padding=(kernel_size - 1) // 2), + nn.ReLU(), + nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size, + padding=(kernel_size - 1) // 2) + ) + self.layer_norm = LayerNorm(in_dim) + self.dropout = self.dropout_module = FairseqDropout( + p=dropout, module_name=self.__class__.__name__ + ) + + def forward(self, x): + # B x T x C + residual = x + x = self.ffn(x.transpose(1, 2)).transpose(1, 2) + x = self.dropout(x) + return self.layer_norm(x + residual) + + +class FFTLayer(torch.nn.Module): + def __init__( + self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, + attention_dropout + ): + super().__init__() + self.self_attn = MultiheadAttention( + embed_dim, n_heads, dropout=attention_dropout, self_attention=True + ) + self.layer_norm = LayerNorm(embed_dim) + self.ffn = PositionwiseFeedForward( + embed_dim, hidden_dim, kernel_size, dropout=dropout + ) + + def forward(self, x, padding_mask=None): + # B x T x C + residual = x + x = x.transpose(0, 1) + x, _ = self.self_attn( + query=x, key=x, value=x, 
key_padding_mask=padding_mask, + need_weights=False + ) + x = x.transpose(0, 1) + x = self.layer_norm(x + residual) + return self.ffn(x) + + +class LengthRegulator(nn.Module): + def forward(self, x, durations): + # x: B x T x C + out_lens = durations.sum(dim=1) + max_len = out_lens.max() + bsz, seq_len, dim = x.size() + out = x.new_zeros((bsz, max_len, dim)) + + for b in range(bsz): + indices = [] + for t in range(seq_len): + indices.extend([t] * utils.item(durations[b, t])) + indices = torch.tensor(indices, dtype=torch.long).to(x.device) + out_len = utils.item(out_lens[b]) + out[b, :out_len] = x[b].index_select(0, indices) + + return out, out_lens + + +class VariancePredictor(nn.Module): + def __init__(self, args): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv1d( + args.encoder_embed_dim, args.var_pred_hidden_dim, + kernel_size=args.var_pred_kernel_size, + padding=(args.var_pred_kernel_size - 1) // 2 + ), + nn.ReLU() + ) + self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim) + self.dropout_module = FairseqDropout( + p=args.var_pred_dropout, module_name=self.__class__.__name__ + ) + self.conv2 = nn.Sequential( + nn.Conv1d( + args.var_pred_hidden_dim, args.var_pred_hidden_dim, + kernel_size=args.var_pred_kernel_size, padding=1 + ), + nn.ReLU() + ) + self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim) + self.proj = nn.Linear(args.var_pred_hidden_dim, 1) + + def forward(self, x): + # Input: B x T x C; Output: B x T + x = self.conv1(x.transpose(1, 2)).transpose(1, 2) + x = self.dropout_module(self.ln1(x)) + x = self.conv2(x.transpose(1, 2)).transpose(1, 2) + x = self.dropout_module(self.ln2(x)) + return self.proj(x).squeeze(dim=2) + + +class VarianceAdaptor(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.length_regulator = LengthRegulator() + self.duration_predictor = VariancePredictor(args) + self.pitch_predictor = VariancePredictor(args) + self.energy_predictor = VariancePredictor(args) + + n_bins, steps = 
self.args.var_pred_n_bins, self.args.var_pred_n_bins - 1 + self.pitch_bins = torch.linspace(args.pitch_min, args.pitch_max, steps) + self.embed_pitch = Embedding(n_bins, args.encoder_embed_dim) + self.energy_bins = torch.linspace(args.energy_min, args.energy_max, steps) + self.embed_energy = Embedding(n_bins, args.encoder_embed_dim) + + def get_pitch_emb(self, x, tgt=None, factor=1.0): + out = self.pitch_predictor(x) + bins = self.pitch_bins.to(x.device) + if tgt is None: + out = out * factor + emb = self.embed_pitch(torch.bucketize(out, bins)) + else: + emb = self.embed_pitch(torch.bucketize(tgt, bins)) + return out, emb + + def get_energy_emb(self, x, tgt=None, factor=1.0): + out = self.energy_predictor(x) + bins = self.energy_bins.to(x.device) + if tgt is None: + out = out * factor + emb = self.embed_energy(torch.bucketize(out, bins)) + else: + emb = self.embed_energy(torch.bucketize(tgt, bins)) + return out, emb + + def forward( + self, x, padding_mask, durations=None, pitches=None, energies=None, + d_factor=1.0, p_factor=1.0, e_factor=1.0 + ): + # x: B x T x C + log_dur_out = self.duration_predictor(x) + dur_out = torch.clamp( + torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0 + ) + dur_out.masked_fill_(padding_mask, 0) + + pitch_out, pitch_emb = self.get_pitch_emb(x, pitches, p_factor) + x = x + pitch_emb + energy_out, energy_emb = self.get_energy_emb(x, energies, e_factor) + x = x + energy_emb + + x, out_lens = self.length_regulator( + x, dur_out if durations is None else durations + ) + + return x, out_lens, log_dur_out, pitch_out, energy_out + + +class FastSpeech2Encoder(FairseqEncoder): + def __init__(self, args, src_dict, embed_speaker): + super().__init__(src_dict) + self.args = args + self.padding_idx = src_dict.pad() + self.n_frames_per_step = args.n_frames_per_step + self.out_dim = args.output_frame_dim * args.n_frames_per_step + + self.embed_speaker = embed_speaker + self.spk_emb_proj = None + if embed_speaker is not None: + 
self.spk_emb_proj = nn.Linear( + args.encoder_embed_dim + args.speaker_embed_dim, + args.encoder_embed_dim + ) + + self.dropout_module = FairseqDropout( + p=args.dropout, module_name=self.__class__.__name__ + ) + self.embed_tokens = Embedding( + len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx + ) + + self.embed_positions = PositionalEmbedding( + args.max_source_positions, args.encoder_embed_dim, self.padding_idx + ) + self.pos_emb_alpha = nn.Parameter(torch.ones(1)) + self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1)) + + self.encoder_fft_layers = nn.ModuleList( + FFTLayer( + args.encoder_embed_dim, args.encoder_attention_heads, + args.fft_hidden_dim, args.fft_kernel_size, + dropout=args.dropout, attention_dropout=args.attention_dropout + ) + for _ in range(args.encoder_layers) + ) + + self.var_adaptor = VarianceAdaptor(args) + + self.decoder_fft_layers = nn.ModuleList( + FFTLayer( + args.decoder_embed_dim, args.decoder_attention_heads, + args.fft_hidden_dim, args.fft_kernel_size, + dropout=args.dropout, attention_dropout=args.attention_dropout + ) + for _ in range(args.decoder_layers) + ) + + self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim) + + self.apply(model_init) + + def forward(self, src_tokens, src_lengths=None, speaker=None, + durations=None, pitches=None, energies=None, **kwargs): + x = self.embed_tokens(src_tokens) + + enc_padding_mask = src_tokens.eq(self.padding_idx) + x += self.pos_emb_alpha * self.embed_positions(enc_padding_mask) + x = self.dropout_module(x) + + for layer in self.encoder_fft_layers: + x = layer(x, enc_padding_mask) + + if self.embed_speaker is not None: + bsz, seq_len, _ = x.size() + emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1) + x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) + + x, out_lens, log_dur_out, pitch_out, energy_out = \ + self.var_adaptor(x, enc_padding_mask, durations, pitches, energies) + + dec_padding_mask = lengths_to_padding_mask(out_lens) + x += 
self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask) + for layer in self.decoder_fft_layers: + x = layer(x, dec_padding_mask) + + x = self.out_proj(x) + + return x, out_lens, log_dur_out, pitch_out, energy_out + + +@register_model("fastspeech2") +class FastSpeech2Model(FairseqEncoderModel): + """ + Implementation for https://arxiv.org/abs/2006.04558 + """ + + NON_AUTOREGRESSIVE = True + + @staticmethod + def add_args(parser): + parser.add_argument("--dropout", type=float) + parser.add_argument("--output-frame-dim", type=int) + parser.add_argument("--speaker-embed-dim", type=int) + # FFT blocks + parser.add_argument("--fft-hidden-dim", type=int) + parser.add_argument("--fft-kernel-size", type=int) + parser.add_argument("--attention-dropout", type=float) + parser.add_argument("--encoder-layers", type=int) + parser.add_argument("--encoder-embed-dim", type=int) + parser.add_argument("--encoder-attention-heads", type=int) + parser.add_argument("--decoder-layers", type=int) + parser.add_argument("--decoder-embed-dim", type=int) + parser.add_argument("--decoder-attention-heads", type=int) + # variance predictor + parser.add_argument("--var-pred-n-bins", type=int) + parser.add_argument("--var-pred-hidden-dim", type=int) + parser.add_argument("--var-pred-kernel-size", type=int) + parser.add_argument("--var-pred-dropout", type=float) + + def __init__(self, encoder, args, src_dict): + super().__init__(encoder) + self._num_updates = 0 + + out_dim = args.output_frame_dim * args.n_frames_per_step + self.ctc_proj = None + if getattr(args, "ctc_weight", 0.) 
> 0.: + self.ctc_proj = nn.Linear(out_dim, len(src_dict)) + + @classmethod + def build_model(cls, args, task): + embed_speaker = task.get_speaker_embeddings(args) + encoder = FastSpeech2Encoder(args, task.src_dict, embed_speaker) + return cls(encoder, args, task.src_dict) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self._num_updates = num_updates + + def get_normalized_probs(self, net_output, log_probs, sample=None): + logits = self.ctc_proj(net_output[0]) + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + +@register_model_architecture("fastspeech2", "fastspeech2") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.2) + args.output_frame_dim = getattr(args, "output_frame_dim", 80) + args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64) + # FFT blocks + args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1024) + args.fft_kernel_size = getattr(args, "fft_kernel_size", 9) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.encoder_layers = getattr(args, "encoder_layers", 4) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2) + args.decoder_layers = getattr(args, "decoder_layers", 4) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2) + # variance predictor + args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256) + args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256) + args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3) + args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5) diff --git a/fairseq/models/text_to_speech/hifigan.py b/fairseq/models/text_to_speech/hifigan.py new file mode 100644 index 0000000000..edc7db6015 --- /dev/null +++ 
b/fairseq/models/text_to_speech/hifigan.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +class ResBlock(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for layer in self.convs1: + remove_weight_norm(layer) + for layer in self.convs2: + remove_weight_norm(layer) + + 
+class Generator(torch.nn.Module): + def __init__(self, cfg): + super(Generator, self).__init__() + self.num_kernels = len(cfg["resblock_kernel_sizes"]) + self.num_upsamples = len(cfg["upsample_rates"]) + self.conv_pre = weight_norm( + Conv1d(80, cfg["upsample_initial_channel"], 7, 1, padding=3) + ) + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"]) + ): + self.ups.append( + weight_norm( + ConvTranspose1d( + cfg["upsample_initial_channel"] // (2 ** i), + cfg["upsample_initial_channel"] // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = cfg["upsample_initial_channel"] // (2 ** (i + 1)) + for k, d in zip( + cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"] + ): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/fairseq/models/text_to_speech/tacotron2.py b/fairseq/models/text_to_speech/tacotron2.py new file mode 100644 index 0000000000..bb327e81e7 --- /dev/null +++ b/fairseq/models/text_to_speech/tacotron2.py @@ -0,0 +1,350 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from torch import nn +from torch.nn import functional as F + +from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, register_model, + register_model_architecture) +from fairseq.modules import LSTMCellWithZoneOut, LocationAttention + + +logger = logging.getLogger(__name__) + + +def encoder_init(m): + if isinstance(m, nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu")) + + +class Tacotron2Encoder(FairseqEncoder): + def __init__(self, args, src_dict, embed_speaker): + super().__init__(src_dict) + self.padding_idx = src_dict.pad() + self.embed_speaker = embed_speaker + self.spk_emb_proj = None + if embed_speaker is not None: + self.spk_emb_proj = nn.Linear( + args.encoder_embed_dim + args.speaker_embed_dim, + args.encoder_embed_dim + ) + + self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, + padding_idx=self.padding_idx) + + assert(args.encoder_conv_kernel_size % 2 == 1) + self.convolutions = nn.ModuleList( + nn.Sequential( + nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim, + kernel_size=args.encoder_conv_kernel_size, + padding=((args.encoder_conv_kernel_size - 1) // 2)), + nn.BatchNorm1d(args.encoder_embed_dim), + nn.ReLU(), + nn.Dropout(args.encoder_dropout) + ) + for _ in range(args.encoder_conv_layers) + ) + + self.lstm = nn.LSTM(args.encoder_embed_dim, args.encoder_embed_dim // 2, + num_layers=args.encoder_lstm_layers, + batch_first=True, bidirectional=True) + + self.apply(encoder_init) + + def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs): + x = self.embed_tokens(src_tokens) + x = x.transpose(1, 2).contiguous() # B x T x C -> B x C x T + for conv in self.convolutions: + x = conv(x) + x = x.transpose(1, 2).contiguous() # B x C x T -> B x T x C + + src_lengths = 
src_lengths.cpu().long() + x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True) + x = self.lstm(x)[0] + x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] + + encoder_padding_mask = src_tokens.eq(self.padding_idx) + + if self.embed_speaker is not None: + seq_len, bsz, _ = x.size() + emb = self.embed_speaker(speaker).expand(seq_len, bsz, -1) + x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) + + return { + "encoder_out": [x], # B x T x C + "encoder_padding_mask": encoder_padding_mask, # B x T + } + + +class Prenet(nn.Module): + def __init__(self, in_dim, n_layers, n_units, dropout): + super().__init__() + self.layers = nn.ModuleList( + nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), + nn.ReLU()) + for i in range(n_layers) + ) + self.dropout = dropout + + def forward(self, x): + for layer in self.layers: + x = F.dropout(layer(x), p=self.dropout) # always applies dropout + return x + + +class Postnet(nn.Module): + def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout): + super(Postnet, self).__init__() + self.convolutions = nn.ModuleList() + assert(kernel_size % 2 == 1) + for i in range(n_layers): + cur_layers = [ + nn.Conv1d(in_dim if i == 0 else n_channels, + n_channels if i < n_layers - 1 else in_dim, + kernel_size=kernel_size, + padding=((kernel_size - 1) // 2)), + nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim) + ] + ([nn.Tanh()] if i < n_layers - 1 else []) + [nn.Dropout(dropout)] + nn.init.xavier_uniform_( + cur_layers[0].weight, + torch.nn.init.calculate_gain( + "tanh" if i < n_layers - 1 else "linear" + ) + ) + self.convolutions.append(nn.Sequential(*cur_layers)) + + def forward(self, x): + x = x.transpose(1, 2) # B x T x C -> B x C x T + for conv in self.convolutions: + x = conv(x) + return x.transpose(1, 2) + + +def decoder_init(m): + if isinstance(m, torch.nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh")) + + +class 
Tacotron2Decoder(FairseqIncrementalDecoder): + def __init__(self, args, src_dict): + super().__init__(None) + self.args = args + self.n_frames_per_step = args.n_frames_per_step + self.out_dim = args.output_frame_dim * args.n_frames_per_step + + self.prenet = Prenet(self.out_dim, args.prenet_layers, args.prenet_dim, + args.prenet_dropout) + + # take prev_context, prev_frame, (speaker embedding) as input + self.attention_lstm = LSTMCellWithZoneOut( + args.zoneout, + args.prenet_dim + args.encoder_embed_dim, + args.decoder_lstm_dim + ) + + # take attention_lstm output, attention_state, encoder_out as input + self.attention = LocationAttention( + args.attention_dim, args.encoder_embed_dim, args.decoder_lstm_dim, + (1 + int(args.attention_use_cumprob)), + args.attention_conv_dim, args.attention_conv_kernel_size + ) + + # take attention_lstm output, context, (gated_latent) as input + self.lstm = nn.ModuleList( + LSTMCellWithZoneOut( + args.zoneout, + args.encoder_embed_dim + args.decoder_lstm_dim, + args.decoder_lstm_dim + ) + for i in range(args.decoder_lstm_layers) + ) + + proj_in_dim = args.encoder_embed_dim + args.decoder_lstm_dim + self.feat_proj = nn.Linear(proj_in_dim, self.out_dim) + self.eos_proj = nn.Linear(proj_in_dim, 1) + + self.postnet = Postnet(self.out_dim, args.postnet_conv_dim, + args.postnet_conv_kernel_size, + args.postnet_layers, args.postnet_dropout) + + self.ctc_proj = None + if getattr(args, "ctc_weight", 0.) 
> 0.: + self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) + + self.apply(decoder_init) + + def _get_states(self, incremental_state, enc_out): + bsz, in_len, _ = enc_out.size() + alstm_h = self.get_incremental_state(incremental_state, "alstm_h") + if alstm_h is None: + alstm_h = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + alstm_c = self.get_incremental_state(incremental_state, "alstm_c") + if alstm_c is None: + alstm_c = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + + lstm_h = self.get_incremental_state(incremental_state, "lstm_h") + if lstm_h is None: + lstm_h = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + for _ in range(self.args.decoder_lstm_layers)] + lstm_c = self.get_incremental_state(incremental_state, "lstm_c") + if lstm_c is None: + lstm_c = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + for _ in range(self.args.decoder_lstm_layers)] + + attn_w = self.get_incremental_state(incremental_state, "attn_w") + if attn_w is None: + attn_w = enc_out.new_zeros(bsz, in_len) + attn_w_cum = self.get_incremental_state(incremental_state, "attn_w_cum") + if attn_w_cum is None: + attn_w_cum = enc_out.new_zeros(bsz, in_len) + return alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum + + def _get_init_attn_c(self, enc_out, enc_mask): + bsz = enc_out.size(0) + if self.args.init_attn_c == "zero": + return enc_out.new_zeros(bsz, self.args.encoder_embed_dim) + elif self.args.init_attn_c == "avg": + enc_w = (~enc_mask).type(enc_out.type()) + enc_w = enc_w / enc_w.sum(dim=1, keepdim=True) + return torch.sum(enc_out * enc_w.unsqueeze(2), dim=1) + else: + raise ValueError(f"{self.args.init_attn_c} not supported") + + def forward(self, prev_output_tokens, encoder_out=None, + incremental_state=None, target_lengths=None, **kwargs): + enc_mask = encoder_out["encoder_padding_mask"] + enc_out = encoder_out["encoder_out"][0] + in_len = enc_out.size(1) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:, :] + bsz, 
out_len, _ = prev_output_tokens.size() + + prenet_out = self.prenet(prev_output_tokens) + (alstm_h, alstm_c, lstm_h, lstm_c, + attn_w, attn_w_cum) = self._get_states(incremental_state, enc_out) + attn_ctx = self._get_init_attn_c(enc_out, enc_mask) + + attn_out = enc_out.new_zeros(bsz, in_len, out_len) + feat_out = enc_out.new_zeros(bsz, out_len, self.out_dim) + eos_out = enc_out.new_zeros(bsz, out_len) + for t in range(out_len): + alstm_in = torch.cat((attn_ctx, prenet_out[:, t, :]), dim=1) + alstm_h, alstm_c = self.attention_lstm(alstm_in, (alstm_h, alstm_c)) + + attn_state = attn_w.unsqueeze(1) + if self.args.attention_use_cumprob: + attn_state = torch.stack((attn_w, attn_w_cum), dim=1) + attn_ctx, attn_w = self.attention( + enc_out, enc_mask, alstm_h, attn_state + ) + attn_w_cum = attn_w_cum + attn_w + attn_out[:, :, t] = attn_w + + for i, cur_lstm in enumerate(self.lstm): + if i == 0: + lstm_in = torch.cat((attn_ctx, alstm_h), dim=1) + else: + lstm_in = torch.cat((attn_ctx, lstm_h[i - 1]), dim=1) + lstm_h[i], lstm_c[i] = cur_lstm(lstm_in, (lstm_h[i], lstm_c[i])) + + proj_in = torch.cat((attn_ctx, lstm_h[-1]), dim=1) + feat_out[:, t, :] = self.feat_proj(proj_in) + eos_out[:, t] = self.eos_proj(proj_in).squeeze(1) + self.attention.clear_cache() + + self.set_incremental_state(incremental_state, "alstm_h", alstm_h) + self.set_incremental_state(incremental_state, "alstm_c", alstm_c) + self.set_incremental_state(incremental_state, "lstm_h", lstm_h) + self.set_incremental_state(incremental_state, "lstm_c", lstm_c) + self.set_incremental_state(incremental_state, "attn_w", attn_w) + self.set_incremental_state(incremental_state, "attn_w_cum", attn_w_cum) + + post_feat_out = feat_out + self.postnet(feat_out) + eos_out = eos_out.view(bsz, out_len, 1) + return post_feat_out, eos_out, {"attn": attn_out, "feature_out": feat_out} + + +@register_model("tacotron_2") +class Tacotron2Model(FairseqEncoderDecoderModel): + """ + Implementation for https://arxiv.org/pdf/1712.05884.pdf 
+ """ + + @staticmethod + def add_args(parser): + # encoder + parser.add_argument("--encoder-dropout", type=float) + parser.add_argument("--encoder-embed-dim", type=int) + parser.add_argument("--encoder-conv-layers", type=int) + parser.add_argument("--encoder-conv-kernel-size", type=int) + parser.add_argument("--encoder-lstm-layers", type=int) + # decoder + parser.add_argument("--attention-dim", type=int) + parser.add_argument("--attention-conv-dim", type=int) + parser.add_argument("--attention-conv-kernel-size", type=int) + parser.add_argument("--prenet-dropout", type=float) + parser.add_argument("--prenet-layers", type=int) + parser.add_argument("--prenet-dim", type=int) + parser.add_argument("--postnet-dropout", type=float) + parser.add_argument("--postnet-layers", type=int) + parser.add_argument("--postnet-conv-dim", type=int) + parser.add_argument("--postnet-conv-kernel-size", type=int) + parser.add_argument("--init-attn-c", type=str) + parser.add_argument("--attention-use-cumprob", action='store_true') + parser.add_argument("--zoneout", type=float) + parser.add_argument("--decoder-lstm-layers", type=int) + parser.add_argument("--decoder-lstm-dim", type=int) + parser.add_argument("--output-frame-dim", type=int) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._num_updates = 0 + + @classmethod + def build_model(cls, args, task): + embed_speaker = task.get_speaker_embeddings(args) + encoder = Tacotron2Encoder(args, task.src_dict, embed_speaker) + decoder = Tacotron2Decoder(args, task.src_dict) + return cls(encoder, decoder) + + def forward_encoder(self, src_tokens, src_lengths, **kwargs): + return self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self._num_updates = num_updates + + +@register_model_architecture("tacotron_2", "tacotron_2") +def base_architecture(args): + # encoder + args.encoder_dropout = getattr(args, 
"encoder_dropout", 0.5) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3) + args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5) + args.encoder_lstm_layers = getattr(args, "encoder_lstm_layers", 1) + # decoder + args.attention_dim = getattr(args, "attention_dim", 128) + args.attention_conv_dim = getattr(args, "attention_conv_dim", 32) + args.attention_conv_kernel_size = getattr(args, + "attention_conv_kernel_size", 15) + args.prenet_dropout = getattr(args, "prenet_dropout", 0.5) + args.prenet_layers = getattr(args, "prenet_layers", 2) + args.prenet_dim = getattr(args, "prenet_dim", 256) + args.postnet_dropout = getattr(args, "postnet_dropout", 0.5) + args.postnet_layers = getattr(args, "postnet_layers", 5) + args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512) + args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5) + args.init_attn_c = getattr(args, "init_attn_c", "zero") + args.attention_use_cumprob = getattr(args, "attention_use_cumprob", True) + args.zoneout = getattr(args, "zoneout", 0.1) + args.decoder_lstm_layers = getattr(args, "decoder_lstm_layers", 2) + args.decoder_lstm_dim = getattr(args, "decoder_lstm_dim", 1024) + args.output_frame_dim = getattr(args, "output_frame_dim", 80) diff --git a/fairseq/models/text_to_speech/tts_transformer.py b/fairseq/models/text_to_speech/tts_transformer.py new file mode 100644 index 0000000000..ff7af78bd4 --- /dev/null +++ b/fairseq/models/text_to_speech/tts_transformer.py @@ -0,0 +1,371 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from typing import List, Optional + +import torch +from torch import nn + +from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, register_model, + register_model_architecture) +from fairseq.modules import ( + TransformerEncoderLayer, TransformerDecoderLayer +) +from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet +from fairseq.modules import LayerNorm, PositionalEmbedding, FairseqDropout +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq import utils + +logger = logging.getLogger(__name__) + + +def encoder_init(m): + if isinstance(m, nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu")) + + +def Embedding(num_embeddings, embedding_dim): + m = nn.Embedding(num_embeddings, embedding_dim) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + return m + + +class TTSTransformerEncoder(FairseqEncoder): + def __init__(self, args, src_dict, embed_speaker): + super().__init__(src_dict) + self.padding_idx = src_dict.pad() + self.embed_speaker = embed_speaker + self.spk_emb_proj = None + if embed_speaker is not None: + self.spk_emb_proj = nn.Linear( + args.encoder_embed_dim + args.speaker_embed_dim, + args.encoder_embed_dim + ) + + self.dropout_module = FairseqDropout( + p=args.dropout, module_name=self.__class__.__name__ + ) + self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, + padding_idx=self.padding_idx) + assert(args.encoder_conv_kernel_size % 2 == 1) + self.prenet = nn.ModuleList( + nn.Sequential( + nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim, + kernel_size=args.encoder_conv_kernel_size, + padding=((args.encoder_conv_kernel_size - 1) // 2)), + nn.BatchNorm1d(args.encoder_embed_dim), + nn.ReLU(), + nn.Dropout(args.encoder_dropout), + ) + for _ in range(args.encoder_conv_layers) + ) + self.prenet_proj = nn.Linear( + args.encoder_embed_dim, args.encoder_embed_dim + ) + self.embed_positions = 
PositionalEmbedding( + args.max_source_positions, args.encoder_embed_dim, self.padding_idx + ) + self.pos_emb_alpha = nn.Parameter(torch.ones(1)) + + self.transformer_layers = nn.ModuleList( + TransformerEncoderLayer(args) + for _ in range(args.encoder_transformer_layers) + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + self.apply(encoder_init) + + def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs): + x = self.embed_tokens(src_tokens) + x = x.transpose(1, 2).contiguous() # B x T x C -> B x C x T + for conv in self.prenet: + x = conv(x) + x = x.transpose(1, 2).contiguous() # B x C x T -> B x T x C + x = self.prenet_proj(x) + + padding_mask = src_tokens.eq(self.padding_idx) + positions = self.embed_positions(padding_mask) + x += self.pos_emb_alpha * positions + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + for layer in self.transformer_layers: + x = layer(x, padding_mask) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + if self.embed_speaker is not None: + seq_len, bsz, _ = x.size() + emb = self.embed_speaker(speaker).transpose(0, 1) + emb = emb.expand(seq_len, bsz, -1) + x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [padding_mask] if padding_mask.any() else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": [], # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + +def decoder_init(m): + if isinstance(m, torch.nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh")) + + +class TTSTransformerDecoder(FairseqIncrementalDecoder): + def __init__(self, args, src_dict): + super().__init__(None) + self._future_mask = torch.empty(0) + + self.args = args + self.padding_idx = src_dict.pad() + self.n_frames_per_step = args.n_frames_per_step + self.out_dim = args.output_frame_dim * 
args.n_frames_per_step + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.embed_positions = PositionalEmbedding( + args.max_target_positions, args.decoder_embed_dim, self.padding_idx + ) + self.pos_emb_alpha = nn.Parameter(torch.ones(1)) + self.prenet = nn.Sequential( + Prenet(self.out_dim, args.prenet_layers, args.prenet_dim, + args.prenet_dropout), + nn.Linear(args.prenet_dim, args.decoder_embed_dim), + ) + + self.n_transformer_layers = args.decoder_transformer_layers + self.transformer_layers = nn.ModuleList( + TransformerDecoderLayer(args) + for _ in range(self.n_transformer_layers) + ) + if args.decoder_normalize_before: + self.layer_norm = LayerNorm(args.decoder_embed_dim) + else: + self.layer_norm = None + + self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim) + self.eos_proj = nn.Linear(args.decoder_embed_dim, 1) + + self.postnet = Postnet(self.out_dim, args.postnet_conv_dim, + args.postnet_conv_kernel_size, + args.postnet_layers, args.postnet_dropout) + + self.ctc_proj = None + if getattr(args, "ctc_weight", 0.) 
> 0.: + self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) + + self.apply(decoder_init) + + def extract_features( + self, prev_outputs, encoder_out=None, incremental_state=None, + target_lengths=None, speaker=None, **kwargs + ): + alignment_layer = self.n_transformer_layers - 1 + self_attn_padding_mask = lengths_to_padding_mask(target_lengths) + positions = self.embed_positions( + self_attn_padding_mask, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_outputs = prev_outputs[:, -1:, :] + self_attn_padding_mask = self_attn_padding_mask[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + x = self.prenet(prev_outputs) + x += self.pos_emb_alpha * positions + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + if not self_attn_padding_mask.any(): + self_attn_padding_mask = None + + attn: Optional[torch.Tensor] = None + inner_states: List[Optional[torch.Tensor]] = [x] + for idx, transformer_layer in enumerate(self.transformer_layers): + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = transformer_layer( + x, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + # average probabilities over heads, transpose to + # (B, src_len, tgt_len) + attn = attn.mean(dim=0).transpose(2, 1) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T 
x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def forward(self, prev_output_tokens, encoder_out=None, + incremental_state=None, target_lengths=None, speaker=None, + **kwargs): + x, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, + incremental_state=incremental_state, target_lengths=target_lengths, + speaker=speaker, **kwargs + ) + attn = extra["attn"] + feat_out = self.feat_proj(x) + bsz, seq_len, _ = x.size() + eos_out = self.eos_proj(x) + post_feat_out = feat_out + self.postnet(feat_out) + return post_feat_out, eos_out, {"attn": attn, "feature_out": feat_out} + + def get_normalized_probs(self, net_output, log_probs, sample): + logits = self.ctc_proj(net_output[2]["feature_out"]) + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 
+ if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 + ) + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + +@register_model("tts_transformer") +class TTSTransformerModel(FairseqEncoderDecoderModel): + """ + Implementation for https://arxiv.org/pdf/1809.08895.pdf + """ + + @staticmethod + def add_args(parser): + parser.add_argument("--dropout", type=float) + parser.add_argument("--output-frame-dim", type=int) + parser.add_argument("--speaker-embed-dim", type=int) + # encoder prenet + parser.add_argument("--encoder-dropout", type=float) + parser.add_argument("--encoder-conv-layers", type=int) + parser.add_argument("--encoder-conv-kernel-size", type=int) + # encoder transformer layers + parser.add_argument("--encoder-transformer-layers", type=int) + parser.add_argument("--encoder-embed-dim", type=int) + parser.add_argument("--encoder-ffn-embed-dim", type=int) + parser.add_argument("--encoder-normalize-before", action="store_true") + parser.add_argument("--encoder-attention-heads", type=int) + parser.add_argument("--attention-dropout", type=float) + parser.add_argument("--activation-dropout", "--relu-dropout", type=float) + parser.add_argument("--activation-fn", type=str, default="relu") + # decoder prenet + parser.add_argument("--prenet-dropout", type=float) + parser.add_argument("--prenet-layers", type=int) + parser.add_argument("--prenet-dim", type=int) + # decoder postnet + parser.add_argument("--postnet-dropout", type=float) + parser.add_argument("--postnet-layers", type=int) + parser.add_argument("--postnet-conv-dim", type=int) + parser.add_argument("--postnet-conv-kernel-size", type=int) + # decoder transformer layers + parser.add_argument("--decoder-transformer-layers", type=int) + parser.add_argument("--decoder-embed-dim", type=int) + 
parser.add_argument("--decoder-ffn-embed-dim", type=int) + parser.add_argument("--decoder-normalize-before", action="store_true") + parser.add_argument("--decoder-attention-heads", type=int) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._num_updates = 0 + + @classmethod + def build_model(cls, args, task): + embed_speaker = task.get_speaker_embeddings(args) + encoder = TTSTransformerEncoder(args, task.src_dict, embed_speaker) + decoder = TTSTransformerDecoder(args, task.src_dict) + return cls(encoder, decoder) + + def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs): + return self.encoder(src_tokens, src_lengths=src_lengths, + speaker=speaker, **kwargs) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self._num_updates = num_updates + + +@register_model_architecture("tts_transformer", "tts_transformer") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.output_frame_dim = getattr(args, "output_frame_dim", 80) + args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64) + # encoder prenet + args.encoder_dropout = getattr(args, "encoder_dropout", 0.5) + args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3) + args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5) + # encoder transformer layers + args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + # decoder 
prenet + args.prenet_dropout = getattr(args, "prenet_dropout", 0.5) + args.prenet_layers = getattr(args, "prenet_layers", 2) + args.prenet_dim = getattr(args, "prenet_dim", 256) + # decoder postnet + args.postnet_dropout = getattr(args, "postnet_dropout", 0.5) + args.postnet_layers = getattr(args, "postnet_layers", 5) + args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512) + args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5) + # decoder transformer layers + args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) diff --git a/fairseq/models/text_to_speech/vocoder.py b/fairseq/models/text_to_speech/vocoder.py new file mode 100644 index 0000000000..65d9f9f06b --- /dev/null +++ b/fairseq/models/text_to_speech/vocoder.py @@ -0,0 +1,197 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import json +from typing import Dict + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from fairseq.data.audio.audio_utils import ( + get_window, get_fourier_basis, get_mel_filters, TTSSpectrogram +) +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig +from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel + +logger = logging.getLogger(__name__) + + +class PseudoInverseMelScale(torch.nn.Module): + def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None: + super(PseudoInverseMelScale, self).__init__() + self.n_mels = n_mels + basis = get_mel_filters( + sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max + ) + basis = torch.pinverse(basis) # F x F_mel + self.register_buffer('basis', basis) + + def forward(self, melspec: torch.Tensor) -> torch.Tensor: + # pack batch + shape = melspec.shape # B_1 x ... x B_K x F_mel x T + n_mels, time = shape[-2], shape[-1] + melspec = melspec.view(-1, n_mels, time) + + freq, _ = self.basis.size() # F x F_mel + assert self.n_mels == n_mels, (self.n_mels, n_mels) + specgram = self.basis.matmul(melspec).clamp(min=0) + + # unpack batch + specgram = specgram.view(shape[:-2] + (freq, time)) + return specgram + + +class GriffinLim(torch.nn.Module): + def __init__( + self, n_fft: int, win_length: int, hop_length: int, n_iter: int, + window_fn=torch.hann_window + ): + super(GriffinLim, self).__init__() + self.transform = TTSSpectrogram( + n_fft, win_length, hop_length, return_phase=True + ) + + basis = get_fourier_basis(n_fft) + basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :] + basis *= get_window(window_fn, n_fft, win_length) + self.register_buffer('basis', basis) + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.n_iter = n_iter + + self.tiny = 1.1754944e-38 + + @classmethod + def get_window_sum_square( + cls, n_frames, hop_length, win_length, n_fft, + 
window_fn=torch.hann_window + ) -> torch.Tensor: + w_sq = get_window(window_fn, n_fft, win_length) ** 2 + n = n_fft + hop_length * (n_frames - 1) + x = torch.zeros(n, dtype=torch.float32) + for i in range(n_frames): + ofst = i * hop_length + x[ofst: min(n, ofst + n_fft)] += w_sq[:max(0, min(n_fft, n - ofst))] + return x + + def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor: + x = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], + dim=1 + ) + x = F.conv_transpose1d(x, self.basis, stride=self.hop_length) + win_sum_sq = self.get_window_sum_square( + magnitude.shape[-1], hop_length=self.hop_length, + win_length=self.win_length, n_fft=self.n_fft + ).to(magnitude.device) + # remove modulation effects + approx_nonzero_indices = win_sum_sq > self.tiny + x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices] + x *= self.n_fft / self.hop_length + x = x[:, :, self.n_fft // 2:] + x = x[:, :, :-self.n_fft // 2:] + return x + + def forward(self, specgram: torch.Tensor) -> torch.Tensor: + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*specgram.shape))) + angles = torch.from_numpy(angles).to(specgram) + _specgram = specgram.view(-1, specgram.shape[-2], specgram.shape[-1]) + waveform = self.inverse(_specgram, angles).squeeze(1) + for _ in range(self.n_iter): + _, angles = self.transform(waveform) + waveform = self.inverse(_specgram, angles).squeeze(1) + return waveform.squeeze(0) + + +class GriffinLimVocoder(nn.Module): + def __init__(self, sample_rate, win_size, hop_size, n_fft, + n_mels, f_min, f_max, window_fn, + spec_bwd_max_iter=32, + fp16=False): + super().__init__() + self.inv_mel_transform = PseudoInverseMelScale( + n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate, + f_min=f_min, f_max=f_max + ) + self.gl_transform = GriffinLim( + n_fft=n_fft, win_length=win_size, hop_length=hop_size, + window_fn=window_fn, n_iter=spec_bwd_max_iter + ) + if fp16: + self.half() + self.inv_mel_transform.half() + 
self.gl_transform.half() + else: + self.float() + self.inv_mel_transform.float() + self.gl_transform.float() + + def forward(self, x): + # x: (B x) T x D -> (B x) 1 x T + # NOTE: batched forward produces a noisier waveform; running + # one utterance at a time is recommended + self.eval() + x = x.exp().transpose(-1, -2) + x = self.inv_mel_transform(x) + x = self.gl_transform(x) + return x + + @classmethod + def from_data_cfg(cls, args, data_cfg: S2TDataConfig): + feat_cfg = data_cfg.config["features"] + window_fn = getattr(torch, feat_cfg["window_fn"] + "_window") + return cls( + sample_rate=feat_cfg["sample_rate"], + win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]), + hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]), + n_fft=feat_cfg["n_fft"], n_mels=feat_cfg["n_mels"], + f_min=feat_cfg["f_min"], f_max=feat_cfg["f_max"], + window_fn=window_fn, spec_bwd_max_iter=args.spec_bwd_max_iter, + fp16=args.fp16 + ) + + +class HiFiGANVocoder(nn.Module): + def __init__( + self, checkpoint_path: str, model_cfg: Dict[str, str], + fp16: bool = False + ) -> None: + super().__init__() + self.model = HiFiGANModel(model_cfg) + state_dict = torch.load(checkpoint_path) + self.model.load_state_dict(state_dict["generator"]) + if fp16: + self.model.half() + logger.info(f"loaded HiFiGAN checkpoint from {checkpoint_path}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # (B x) T x D -> (B x) 1 x T + model = self.model.eval() + if len(x.shape) == 2: + return model(x.unsqueeze(0).transpose(1, 2)).detach().squeeze(0) + else: + return model(x.transpose(-1, -2)).detach() + + @classmethod + def from_data_cfg(cls, args, data_cfg: S2TDataConfig): + vocoder_cfg = data_cfg.vocoder + assert vocoder_cfg.get("type", "griffin_lim") == "hifigan" + with open(vocoder_cfg["config"]) as f: + model_cfg = json.load(f) + return cls(vocoder_cfg["checkpoint"], model_cfg, fp16=args.fp16) + + +def get_vocoder(args, data_cfg: S2TDataConfig): + if args.vocoder == "griffin_lim": + return
GriffinLimVocoder.from_data_cfg(args, data_cfg) + elif args.vocoder == "hifigan": + return HiFiGANVocoder.from_data_cfg(args, data_cfg) + else: + raise ValueError("Unknown vocoder") diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 81930aa71c..d7a030e2b5 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -25,6 +25,8 @@ from .learned_positional_embedding import LearnedPositionalEmbedding from .lightweight_convolution import LightweightConv, LightweightConv1dTBC from .linearized_convolution import LinearizedConvolution +from .location_attention import LocationAttention +from .lstm_cell_with_zoneout import LSTMCellWithZoneOut from .multihead_attention import MultiheadAttention from .positional_embedding import PositionalEmbedding from .same_pad import SamePad @@ -63,6 +65,8 @@ "LightweightConv1dTBC", "LightweightConv", "LinearizedConvolution", + "LocationAttention", + "LSTMCellWithZoneOut", "MultiheadAttention", "PositionalEmbedding", "SamePad", diff --git a/fairseq/modules/location_attention.py b/fairseq/modules/location_attention.py new file mode 100644 index 0000000000..a970876bba --- /dev/null +++ b/fairseq/modules/location_attention.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch.nn as nn +import torch +import torch.nn.functional as F + + +class LocationAttention(nn.Module): + """ + Attention-Based Models for Speech Recognition + https://arxiv.org/pdf/1506.07503.pdf + + :param int encoder_dim: # projection-units of encoder + :param int decoder_dim: # units of decoder + :param int attn_dim: attention dimension + :param int conv_dim: # channels of attention convolution + :param int conv_kernel_size: filter size of attention convolution + """ + + def __init__(self, attn_dim, encoder_dim, decoder_dim, + attn_state_kernel_size, conv_dim, conv_kernel_size, + scaling=2.0): + super(LocationAttention, self).__init__() + self.attn_dim = attn_dim + self.decoder_dim = decoder_dim + self.scaling = scaling + self.proj_enc = nn.Linear(encoder_dim, attn_dim) + self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False) + self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False) + self.conv = nn.Conv1d(attn_state_kernel_size, conv_dim, + 2 * conv_kernel_size + 1, + padding=conv_kernel_size, bias=False) + self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1)) + + self.proj_enc_out = None # cache + + def clear_cache(self): + self.proj_enc_out = None + + def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state): + """ + :param torch.Tensor encoder_out: padded encoder hidden state B x T x D + :param torch.Tensor encoder_padding_mask: encoder padding mask + :param torch.Tensor decoder_h: decoder hidden state B x D + :param torch.Tensor attn_state: previous attention weights B x K x T + :return: attention weighted encoder state (B, D) + :rtype: torch.Tensor + :return: attention weights (B x T) + :rtype: torch.Tensor + """ + bsz, seq_len, _ = encoder_out.size() + if self.proj_enc_out is None: + self.proj_enc_out = self.proj_enc(encoder_out) + + # B x K x T -> B x C x T + attn = self.conv(attn_state) + # B x C x T -> B x T x C -> B x T x D + attn = self.proj_attn(attn.transpose(1, 2)) + + if decoder_h is None: +
decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim) + dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim) + + out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2) + out.masked_fill_(encoder_padding_mask, -float("inf")) + + w = F.softmax(self.scaling * out, dim=1) + c = torch.sum(encoder_out * w.view(bsz, seq_len, 1), dim=1) + return c, w diff --git a/fairseq/modules/lstm_cell_with_zoneout.py b/fairseq/modules/lstm_cell_with_zoneout.py new file mode 100644 index 0000000000..f04e5db255 --- /dev/null +++ b/fairseq/modules/lstm_cell_with_zoneout.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + + +class LSTMCellWithZoneOut(nn.Module): + """ + Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations + https://arxiv.org/abs/1606.01305 + """ + + def __init__(self, prob: float, input_size: int, hidden_size: int, + bias: bool = True): + super(LSTMCellWithZoneOut, self).__init__() + self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) + self.prob = prob + if prob > 1.0 or prob < 0.0: + raise ValueError("zoneout probability must be in the range from " + "0.0 to 1.0.") + + def zoneout(self, h, next_h, prob): + if isinstance(h, tuple): + return tuple( + [self.zoneout(h[i], next_h[i], prob) for i in range(len(h))] + ) + + if self.training: + mask = h.new_zeros(*h.size()).bernoulli_(prob) + return mask * h + (1 - mask) * next_h + + return prob * h + (1 - prob) * next_h + + def forward(self, x, h): + return self.zoneout(h, self.lstm_cell(x, h), self.prob) diff --git a/fairseq/optim/lr_scheduler/step_lr_scheduler.py b/fairseq/optim/lr_scheduler/step_lr_scheduler.py new file mode 100644 index 0000000000..8cb2006860 --- /dev/null +++ b/fairseq/optim/lr_scheduler/step_lr_scheduler.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class StepLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = field( + default=II("optimization.lr"), + metadata={"help": "max learning rate, must be more than cfg.min_lr"}, + ) + min_lr: float = field(default=0.0, metadata={"help": "min learning rate"}) + lr_deacy_period: int = field(default=25000, metadata={"help": "decay period"}) + lr_decay: float = field(default=0.5, metadata={"help": "decay factor"}) + + +@register_lr_scheduler("step", dataclass=StepLRScheduleConfig) +class StepLRSchedule(FairseqLRScheduler): + """Decay learning rate every k updates by a fixed factor + """ + + def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer): + super().__init__(cfg, fairseq_optimizer) + self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + self.min_lr = cfg.min_lr + self.lr_deacy_period = cfg.lr_deacy_period + self.lr_decay = cfg.lr_decay + self.warmup_updates = cfg.warmup_updates + self.warmup_init_lr = ( + cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr + ) + + assert(self.lr_deacy_period > 0) + assert(self.lr_decay <= 1) + assert(self.min_lr >= 0) + assert(self.max_lr > self.min_lr) + + if cfg.warmup_updates > 0: + # linearly warmup for the first cfg.warmup_updates + self.warmup_lr_step = ( + (self.max_lr - self.warmup_init_lr) / 
self.warmup_updates + ) + else: + self.warmup_lr_step = 1 + + # initial learning rate + self.lr = self.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step + else: + curr_updates = num_updates - self.cfg.warmup_updates + lr_mult = self.lr_decay ** (curr_updates // self.lr_deacy_period) + self.lr = max(self.max_lr * lr_mult, self.min_lr) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/options.py b/fairseq/options.py index 03883fc561..b4d350f902 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -56,6 +56,14 @@ def get_generation_parser(interactive=False, default_task="translation"): return parser +def get_speech_generation_parser(default_task="text_to_speech"): + parser = get_parser("Speech Generation", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_speech_generation_args(parser) + return parser + + def get_interactive_generation_parser(default_task="translation"): return get_generation_parser(interactive=True, default_task=default_task) @@ -344,6 +352,16 @@ def add_generation_args(parser): return group +def add_speech_generation_args(parser): + group = parser.add_argument_group("Speech Generation") + add_common_eval_args(group) # NOTE: remove_bpe is not needed + # fmt: off + group.add_argument('--eos_prob_threshold', default=0.5, type=float, + help='terminate when eos probability exceeds this') + # fmt: on + return group + + def add_interactive_args(parser): group = parser.add_argument_group("Interactive") gen_parser_from_dataclass(group, 
InteractiveConfig()) diff --git a/fairseq/speech_generator.py b/fairseq/speech_generator.py new file mode 100644 index 0000000000..8086e34d2b --- /dev/null +++ b/fairseq/speech_generator.py @@ -0,0 +1,219 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np + +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig + + +class SpeechGenerator(object): + def __init__(self, model, vocoder, data_cfg: S2TDataConfig): + self.model = model + self.vocoder = vocoder + stats_npz_path = data_cfg.global_cmvn_stats_npz + self.gcmvn_stats = None + if stats_npz_path is not None: + self.gcmvn_stats = np.load(stats_npz_path) + + def gcmvn_denormalize(self, x): + # x: B x T x C + if self.gcmvn_stats is None: + return x + mean = torch.from_numpy(self.gcmvn_stats["mean"]).to(x) + std = torch.from_numpy(self.gcmvn_stats["std"]).to(x) + assert len(x.shape) == 3 and mean.shape[0] == std.shape[0] == x.shape[2] + x = x * std.view(1, 1, -1).expand_as(x) + return x + mean.view(1, 1, -1).expand_as(x) + + def get_waveform(self, feat): + # T x C -> T + return None if self.vocoder is None else self.vocoder(feat).squeeze(0) + + +class AutoRegressiveSpeechGenerator(SpeechGenerator): + def __init__( + self, model, vocoder, data_cfg, max_iter: int = 6000, + eos_prob_threshold: float = 0.5, + ): + super().__init__(model, vocoder, data_cfg) + self.max_iter = max_iter + self.eos_prob_threshold = eos_prob_threshold + + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size() + n_frames_per_step = model.decoder.n_frames_per_step + out_dim = model.decoder.out_dim + raw_dim = out_dim // n_frames_per_step + + # initialize + encoder_out = 
model.forward_encoder(src_tokens, src_lengths, + speaker=sample["speaker"]) + incremental_state = {} + feat, attn, eos_prob = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter) + + prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim) + for step in range(self.max_iter): + cur_out_lens = out_lens.clone() + cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) + _, cur_eos_out, cur_extra = model.forward_decoder( + prev_feat_out, encoder_out=encoder_out, + incremental_state=incremental_state, + target_lengths=cur_out_lens, speaker=sample["speaker"], **kwargs + ) + cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) + feat.append(cur_extra['feature_out']) + attn.append(cur_extra['attn']) + eos_prob.append(cur_eos_prob) + + cur_finished = (cur_eos_prob.squeeze(1) > self.eos_prob_threshold) + out_lens.masked_fill_((~finished) & cur_finished, step + 1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + prev_feat_out = cur_extra['feature_out'] + + feat = torch.cat(feat, dim=1) + feat = model.decoder.postnet(feat) + feat + eos_prob = torch.cat(eos_prob, dim=1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + out_lens = out_lens * n_frames_per_step + + finalized = [ + { + 'feature': feat[b, :out_len], + 'eos_prob': eos_prob[b, :out_len], + 'attn': attn[b, :, :out_len], + 'alignment': alignment[b, :out_len], + 'waveform': self.get_waveform(feat[b, :out_len]), + } + for b, out_len in zip(range(bsz), out_lens) + ] + + if has_targ: + assert sample["target"].size(-1) == out_dim + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = 
self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class NonAutoregressiveSpeechGenerator(SpeechGenerator): + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + bsz, max_src_len = sample["net_input"]["src_tokens"].size() + n_frames_per_step = model.encoder.n_frames_per_step + out_dim = model.encoder.out_dim + raw_dim = out_dim // n_frames_per_step + + feat, out_lens, log_dur_out, _, _ = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=sample["target_lengths"], + speaker=sample["speaker"] + ) + + feat = feat.view(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + dur_out = torch.clamp( + torch.round(torch.exp(log_dur_out) - 1).long(), min=0 + ) + + def get_dur_plot_data(d): + r = [] + for i, dd in enumerate(d): + r += [i + 1] * dd.item() + return r + + out_lens = out_lens * n_frames_per_step + finalized = [ + { + 'feature': feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]), + 'waveform': self.get_waveform( + feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]) + ), + 'attn': feat.new_tensor(get_dur_plot_data(dur_out[b])), + } + for b, l in zip(range(bsz), out_lens) + ] + + if has_targ: + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator): + @torch.no_grad() + def 
generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + prev_out_tokens = sample["net_input"]["prev_output_tokens"] + tgt_lens = sample["target_lengths"] + n_frames_per_step = model.decoder.n_frames_per_step + raw_dim = model.decoder.out_dim // n_frames_per_step + bsz = src_tokens.shape[0] + + feat, eos_prob, extra = model( + src_tokens, src_lens, prev_out_tokens, incremental_state=None, + target_lengths=tgt_lens, speaker=sample["speaker"] + ) + + attn = extra["attn"] # B x T_s x T_t + alignment = attn.max(dim=1)[1] + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + tgt_lens = sample["target_lengths"] * n_frames_per_step + + finalized = [ + { + 'feature': feat[b, :tgt_len], + 'eos_prob': eos_prob[b, :tgt_len], + 'attn': attn[b, :, :tgt_len], + 'alignment': alignment[b, :tgt_len], + 'waveform': self.get_waveform(feat[b, :tgt_len]), + } + for b, tgt_len in zip(range(bsz), tgt_lens) + ] + + if has_targ: + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized diff --git a/fairseq/tasks/frm_text_to_speech.py b/fairseq/tasks/frm_text_to_speech.py new file mode 100644 index 0000000000..1fa9b0f83e --- /dev/null +++ b/fairseq/tasks/frm_text_to_speech.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator +from fairseq.tasks import register_task +from fairseq.tasks.text_to_speech import TextToSpeechTask + + +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO +) +logger = logging.getLogger(__name__) + + +@register_task('frm_text_to_speech') +class FrmTextToSpeechTask(TextToSpeechTask): + @staticmethod + def add_args(parser): + TextToSpeechTask.add_args(parser) + parser.add_argument( + "--do_chunk", action="store_true", help="train on chunks" + ) + parser.add_argument("--chunk_bound", default=-1, type=int) + parser.add_argument("--chunk_init", default=50, type=int) + parser.add_argument("--chunk_incr", default=5, type=int) + parser.add_argument("--add_eos", action="store_true") + parser.add_argument("--dedup", action="store_true") + parser.add_argument("--ref_fpu", default=-1, type=float) + + def load_dataset(self, split, **unused_kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv( + self.args.data, + self.data_cfg, + split, + self.src_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split=is_train_split, + n_frames_per_step=self.args.n_frames_per_step, + speaker_to_id=self.speaker_to_id, + do_chunk=self.args.do_chunk, + chunk_bound=self.args.chunk_bound, + chunk_init=self.args.chunk_init, + chunk_incr=self.args.chunk_incr, + add_eos=self.args.add_eos, + dedup=self.args.dedup, + ref_fpu=self.args.ref_fpu + ) diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 5795c04bf7..06e292103e 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -50,6 +50,16 @@ def __init__(self, args, tgt_dict): super().__init__(args) self.tgt_dict = tgt_dict self.data_cfg = 
S2TDataConfig(Path(args.data) / args.config_yaml) + self.speaker_to_id = self._get_speaker_to_id() + + def _get_speaker_to_id(self): + speaker_to_id = None + speaker_set_filename = self.data_cfg.config.get("speaker_set_filename") + if speaker_set_filename is not None: + speaker_set_path = Path(self.args.data) / speaker_set_filename + with open(speaker_set_path) as f: + speaker_to_id = {r.strip(): i for i, r in enumerate(f)} + return speaker_to_id @classmethod def setup_task(cls, args, **kwargs): @@ -91,6 +101,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): is_train_split=is_train_split, epoch=epoch, seed=self.args.seed, + speaker_to_id=self.speaker_to_id ) @property @@ -107,6 +118,7 @@ def max_positions(self): def build_model(self, args): args.input_feat_per_channel = self.data_cfg.input_feat_per_channel args.input_channels = self.data_cfg.input_channels + args.speaker_to_id = self.speaker_to_id return super(SpeechToTextTask, self).build_model(args) def build_generator( @@ -126,12 +138,13 @@ def build_generator( for s, i in self.tgt_dict.indices.items() if SpeechToTextDataset.is_lang_tag(s) } + if extra_gen_cls_kwargs is None: - extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids} - else: - extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids + extra_gen_cls_kwargs = {} + extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids return super().build_generator( - models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, args, seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs ) def build_tokenizer(self, args): diff --git a/fairseq/tasks/text_to_speech.py b/fairseq/tasks/text_to_speech.py new file mode 100644 index 0000000000..5646e41d39 --- /dev/null +++ b/fairseq/tasks/text_to_speech.py @@ -0,0 +1,467 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import os.path as op + +import torch +import torch.nn.functional as F +import numpy as np + +from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.speech_generator import ( + AutoRegressiveSpeechGenerator, NonAutoregressiveSpeechGenerator, + TeacherForcingAutoRegressiveSpeechGenerator +) + +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO +) +logger = logging.getLogger(__name__) + + +try: + from tensorboardX import SummaryWriter +except ImportError: + logger.info("Please install tensorboardX: pip install tensorboardX") + SummaryWriter = None + + +@register_task('text_to_speech') +class TextToSpeechTask(SpeechToTextTask): + @staticmethod + def add_args(parser): + parser.add_argument('data', help='manifest root path') + parser.add_argument( + '--config-yaml', type=str, default='config.yaml', + help='Configuration YAML filename (under manifest root)' + ) + parser.add_argument('--max-source-positions', default=1024, type=int, + metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1200, type=int, + metavar='N', + help='max number of tokens in the target sequence') + parser.add_argument("--n-frames-per-step", type=int, default=1) + parser.add_argument("--eos-prob-threshold", type=float, default=0.5) + parser.add_argument("--eval-inference", action="store_true") + parser.add_argument("--eval-tb-nsample", type=int, default=8) + parser.add_argument("--vocoder", type=str, default="griffin_lim") + parser.add_argument("--spec-bwd-max-iter", type=int, default=8) + + def __init__(self, args, src_dict): + super().__init__(args, 
src_dict) + self.src_dict = src_dict + self.sr = self.data_cfg.config.get("features").get("sample_rate") + + self.tensorboard_writer = None + self.tensorboard_dir = "" + if args.tensorboard_logdir and SummaryWriter is not None: + self.tensorboard_dir = os.path.join(args.tensorboard_logdir, + "valid_extra") + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + is_train_split = split.startswith('train') + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = TextToSpeechDatasetCreator.from_tsv( + self.args.data, self.data_cfg, split, self.src_dict, + pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split, + epoch=epoch, seed=self.args.seed, + n_frames_per_step=self.args.n_frames_per_step, + speaker_to_id=self.speaker_to_id + ) + + @property + def target_dictionary(self): + return None + + @property + def source_dictionary(self): + return self.src_dict + + def get_speaker_embeddings_path(self): + speaker_emb_path = None + if self.data_cfg.config.get("speaker_emb_filename") is not None: + speaker_emb_path = op.join( + self.args.data, self.data_cfg.config.get("speaker_emb_filename") + ) + return speaker_emb_path + + @classmethod + def get_speaker_embeddings(cls, args): + embed_speaker = None + if args.speaker_to_id is not None: + if args.speaker_emb_path is None: + embed_speaker = torch.nn.Embedding( + len(args.speaker_to_id), args.speaker_embed_dim + ) + else: + speaker_emb_mat = np.load(args.speaker_emb_path) + assert speaker_emb_mat.shape[1] == args.speaker_embed_dim + embed_speaker = torch.nn.Embedding.from_pretrained( + torch.from_numpy(speaker_emb_mat), freeze=True, + ) + logger.info( + f"load speaker embeddings from {args.speaker_emb_path}. " + f"train embedding? 
{embed_speaker.weight.requires_grad}\n" + f"embeddings:\n{speaker_emb_mat}" + ) + return embed_speaker + + def build_model(self, cfg): + cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None) + cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None) + cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None) + cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None) + cfg.speaker_emb_path = self.get_speaker_embeddings_path() + model = super().build_model(cfg) + self.generator = None + if getattr(cfg, "eval_inference", False): + self.generator = self.build_generator([model], cfg) + return model + + def build_generator(self, models, cfg, vocoder=None, **unused): + if vocoder is None: + vocoder = self.build_default_vocoder() + model = models[0] + if getattr(model, "NON_AUTOREGRESSIVE", False): + return NonAutoregressiveSpeechGenerator( + model, vocoder, self.data_cfg + ) + else: + generator = AutoRegressiveSpeechGenerator + if getattr(cfg, "teacher_forcing", False): + generator = TeacherForcingAutoRegressiveSpeechGenerator + logger.info("Teacher forcing mode for generation") + return generator( + model, vocoder, self.data_cfg, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold + ) + + def build_default_vocoder(self): + from fairseq.models.text_to_speech.vocoder import get_vocoder + vocoder = get_vocoder(self.args, self.data_cfg) + if torch.cuda.is_available() and not self.args.cpu: + vocoder = vocoder.cuda() + else: + vocoder = vocoder.cpu() + return vocoder + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step( + sample, model, criterion + ) + + if getattr(self.args, "eval_inference", False): + hypos, inference_losses = self.valid_step_with_inference( + sample, model, self.generator + ) + for k, v in inference_losses.items(): + assert(k not in logging_output) + logging_output[k] = v + + picked_id = 0 + if 
self.tensorboard_dir and (sample["id"] == picked_id).any(): + self.log_tensorboard( + sample, + hypos[:self.args.eval_tb_nsample], + model._num_updates, + is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False) + ) + return loss, sample_size, logging_output + + def valid_step_with_inference(self, sample, model, generator): + hypos = generator.generate(model, sample, has_targ=True) + + losses = { + "mcd_loss": 0., + "targ_frames": 0., + "pred_frames": 0., + "nins": 0., + "ndel": 0., + } + rets = batch_mel_cepstral_distortion( + [hypo["targ_waveform"] for hypo in hypos], + [hypo["waveform"] for hypo in hypos], + self.sr, + normalize_type=None + ) + for d, extra in rets: + pathmap = extra[-1] + losses["mcd_loss"] += d.item() + losses["targ_frames"] += pathmap.size(0) + losses["pred_frames"] += pathmap.size(1) + losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item() + losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item() + + return hypos, losses + + def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False): + if self.tensorboard_writer is None: + self.tensorboard_writer = SummaryWriter(self.tensorboard_dir) + tb_writer = self.tensorboard_writer + for b in range(len(hypos)): + idx = sample["id"][b] + text = sample["src_texts"][b] + targ = hypos[b]["targ_feature"] + pred = hypos[b]["feature"] + attn = hypos[b]["attn"] + + if is_na_model: + data = plot_tts_output( + [targ.transpose(0, 1), pred.transpose(0, 1)], + [f"target (idx={idx})", "output"], attn, + "alignment", ret_np=True, suptitle=text, + ) + else: + eos_prob = hypos[b]["eos_prob"] + data = plot_tts_output( + [targ.transpose(0, 1), pred.transpose(0, 1), attn], + [f"target (idx={idx})", "output", "alignment"], eos_prob, + "eos prob", ret_np=True, suptitle=text, + ) + + tb_writer.add_image( + f"inference_sample_{b}", data, num_updates, + dataformats="HWC" + ) + + if hypos[b]["waveform"] is not None: + targ_wave = hypos[b]["targ_waveform"].detach().cpu().float() + pred_wave = 
hypos[b]["waveform"].detach().cpu().float() + tb_writer.add_audio( + f"inference_targ_{b}", + targ_wave, + num_updates, + sample_rate=self.sr + ) + tb_writer.add_audio( + f"inference_pred_{b}", + pred_wave, + num_updates, + sample_rate=self.sr + ) + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +DEFAULT_V_MIN = np.log(1e-5) + + +def plot_tts_output( + data_2d, title_2d, data_1d, title_1d, figsize=(24, 4), + v_min=DEFAULT_V_MIN, v_max=3, ret_np=False, suptitle="" +): + try: + import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid1 import make_axes_locatable + except ImportError: + raise ImportError("Please install Matplotlib: pip install matplotlib") + + data_2d = [ + x.detach().cpu().float().numpy() + if isinstance(x, torch.Tensor) else x for x in data_2d + ] + fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize) + if suptitle: + fig.suptitle(suptitle[:400]) # capped at 400 chars + axes = [axes] if len(data_2d) == 0 else axes + for ax, x, name in zip(axes, data_2d, title_2d): + ax.set_title(name) + divider = make_axes_locatable(ax) + cax = divider.append_axes('right', size='5%', pad=0.05) + im = ax.imshow( + x, origin="lower", aspect="auto", vmin=max(x.min(), v_min), + vmax=min(x.max(), v_max) + ) + fig.colorbar(im, cax=cax, orientation='vertical') + + if isinstance(data_1d, torch.Tensor): + data_1d = data_1d.detach().cpu().numpy() + axes[-1].plot(data_1d) + axes[-1].set_title(title_1d) + plt.tight_layout() + + if ret_np: + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close(fig) + return data + + +def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None): + """ + for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs + + offset=2 (1, 1), + offset=3 (2, 1), (1, 2) + offset=4 (2, 2), (1, 3) + offset=5 (2, 3) + + constraints: + i + j = offset + min_j <= j < max_j + 
min_i <= offset - j < max_i + """ + if max_i is None: + max_i = offset + 1 + if max_j is None: + max_j = offset + 1 + min_j = max(min_j, offset - max_i + 1, 0) + max_j = min(max_j, offset - min_i + 1, offset + 1) + j = torch.arange(min_j, max_j) + i = offset - j + return torch.stack([i, j]) + + +def batch_dynamic_time_warping(distance, shapes=None): + """full batched DTW without any constraints + + distance: (batchsize, max_M, max_N) matrix + shapes: (batchsize,) vector specifying (M, N) for each entry + """ + # ptr: 0=left, 1=up-left, 2=up + ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)} + + bsz, m, n = distance.size() + cumdist = torch.zeros_like(distance) + backptr = torch.zeros_like(distance).type(torch.int32) - 1 + + # initialize + cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1) + cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1) + backptr[:, 0, :] = 0 + backptr[:, :, 0] = 2 + + # DP with optimized anti-diagonal parallelization, O(M+N) steps + for offset in range(2, m + n - 1): + ind = antidiag_indices(offset, 1, m, 1, n) + c = torch.stack( + [cumdist[:, ind[0], ind[1] - 1], cumdist[:, ind[0] - 1, ind[1] - 1], + cumdist[:, ind[0] - 1, ind[1]], ], + dim=2 + ) + v, b = c.min(axis=-1) + backptr[:, ind[0], ind[1]] = b.int() + cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]] + + # backtrace + pathmap = torch.zeros_like(backptr) + for b in range(bsz): + i = m - 1 if shapes is None else (shapes[b][0] - 1).item() + j = n - 1 if shapes is None else (shapes[b][1] - 1).item() + dtwpath = [(i, j)] + while (i != 0 or j != 0) and len(dtwpath) < 10000: + assert (i >= 0 and j >= 0) + di, dj = ptr2dij[backptr[b, i, j].item()] + i, j = i + di, j + dj + dtwpath.append((i, j)) + dtwpath = dtwpath[::-1] + indices = torch.from_numpy(np.array(dtwpath)) + pathmap[b, indices[:, 0], indices[:, 1]] = 1 + + return cumdist, backptr, pathmap + + +def compute_l2_dist(x1, x2): + """compute an (m, n) L2 distance matrix from (m, d) and (n, d) matrices""" + return 
torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2) + + +def compute_rms_dist(x1, x2): + l2_dist = compute_l2_dist(x1, x2) + return (l2_dist / x1.size(1)).pow(0.5) + + +def get_divisor(pathmap, normalize_type): + if normalize_type is None: + return 1 + elif normalize_type == "len1": + return pathmap.size(0) + elif normalize_type == "len2": + return pathmap.size(1) + elif normalize_type == "path": + return pathmap.sum().item() + else: + raise ValueError(f"normalize_type {normalize_type} not supported") + + +def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type): + d, s, x1, x2 = [], [], [], [] + for cur_y1, cur_y2 in zip(y1, y2): + assert (cur_y1.ndim == 1 and cur_y2.ndim == 1) + cur_x1 = feat_fn(cur_y1) + cur_x2 = feat_fn(cur_y2) + x1.append(cur_x1) + x2.append(cur_x2) + + cur_d = dist_fn(cur_x1, cur_x2) + d.append(cur_d) + s.append(d[-1].size()) + max_m = max(ss[0] for ss in s) + max_n = max(ss[1] for ss in s) + d = torch.stack( + [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d] + ) + s = torch.LongTensor(s).to(d.device) + cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s) + + rets = [] + itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps) + for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr: + cumdist = cumdist[:m, :n] + backptr = backptr[:m, :n] + pathmap = pathmap[:m, :n] + divisor = get_divisor(pathmap, normalize_type) + + distortion = cumdist[-1, -1] / divisor + ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap) + rets.append(ret) + return rets + + +def batch_mel_cepstral_distortion( + y1, y2, sr, normalize_type="path", mfcc_fn=None +): + """ + https://arxiv.org/pdf/2011.03568.pdf + + The root mean squared error computed on 13-dimensional MFCC using DTW for + alignment. MFCC features are computed from an 80-channel log-mel + spectrogram using a 50ms Hann window and hop of 12.5ms. 
+ + y1: list of waveforms + y2: list of waveforms + sr: sampling rate + """ + + try: + import torchaudio + except ImportError: + raise ImportError("Please install torchaudio: pip install torchaudio") + + if mfcc_fn is None or mfcc_fn.sample_rate != sr: + melkwargs = { + "n_fft": int(0.05 * sr), "win_length": int(0.05 * sr), + "hop_length": int(0.0125 * sr), "f_min": 20, + "n_mels": 80, "window_fn": torch.hann_window + } + mfcc_fn = torchaudio.transforms.MFCC( + sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs + ).to(y1[0].device) + return batch_compute_distortion( + y1, y2, sr, lambda y: mfcc_fn(y).transpose(-1, -2), compute_rms_dist, + normalize_type + ) diff --git a/setup.py b/setup.py index c699936a99..1f3998de80 100644 --- a/setup.py +++ b/setup.py @@ -210,6 +210,7 @@ def do_setup(package_data): "torch", "tqdm", "bitarray", + "torchaudio>=0.8.0", ], dependency_links=dependency_links, packages=find_packages( From 8adff65ab30dd5f3a3589315bbc1fafad52943e7 Mon Sep 17 00:00:00 2001 From: Vimal Manohar <vimalmanohar@fb.com> Date: Tue, 14 Sep 2021 21:42:39 -0700 Subject: [PATCH 702/707] Use batch_by_size in dataset in aligned training task Summary: Aligned training was not using batch_by_size in the dataset. Due to this, it was not possible to use batch sampling in MultiCorpusDataset with different transforms and collators for different datasets. 
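A minimal sketch of the per-dataset collation dispatch described above, assuming a simplified stand-in for `MultiCorpusDataset`. The class `TinyMultiCorpus`, the `"key"` field, and the two toy collate functions are hypothetical and are not fairseq API. The sketch only illustrates why a batch should come from a single sub-dataset when collators differ, and how reporting the key before re-raising makes a failing sub-dataset easier to find:

```python
# Hypothetical, simplified stand-in for MultiCorpusDataset; not the fairseq class.
class TinyMultiCorpus:
    def __init__(self, collaters):
        # collaters: dict mapping a dataset key to that dataset's collate function
        self.collaters = collaters

    def collater(self, samples):
        if len(samples) == 0:
            return None
        # Every sample in a batch is assumed to come from the same sub-dataset,
        # so the first sample's key selects the collate function.
        key = samples[0]["key"]
        try:
            return self.collaters[key](samples)
        except Exception:
            # Report which sub-dataset failed, then re-raise the original error.
            print(f"Collating failed for key {key}", flush=True)
            raise


corpus = TinyMultiCorpus({
    "text": lambda batch: [s["value"].upper() for s in batch],
    "nums": lambda batch: sum(s["value"] for s in batch),
})
text_batch = corpus.collater(
    [{"key": "text", "value": "a"}, {"key": "text", "value": "b"}]
)
nums_batch = corpus.collater(
    [{"key": "nums", "value": 1}, {"key": "nums", "value": 2}]
)
```

Mixing samples from both sub-datasets in one batch would send the `"nums"` samples through the `"text"` collate function, which is the situation that batching inside each dataset is meant to avoid.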
Reviewed By: xiaoxiao26 Differential Revision: D30889985 fbshipit-source-id: 224ad55d2337681a06a82caf19900e5a241a3d6a --- fairseq/data/multi_corpus_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 1bd61c32eb..746155e515 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -170,7 +170,12 @@ def collater(self, samples): return None if "full_id" in samples[0]: _, key = self._map_index(samples[0]["full_id"]) - return self.datasets[key].collater(samples) + try: + batch = self.datasets[key].collater(samples) + except Exception: + print(f"Collating failed for key {key}", flush=True) + raise + return batch else: # Subclasses may override __getitem__ to not specify full_id return list(self.datasets.values())[0].collater(samples) From 98d638c70cdbe751153c10fc571c34beac228347 Mon Sep 17 00:00:00 2001 From: Xutai Ma <xutaima@gmail.com> Date: Wed, 15 Sep 2021 01:48:57 -0700 Subject: [PATCH 703/707] Mma refactor (#2087) Summary: Fixing issues ([3546](https://github.com/pytorch/fairseq/issues/3546)) with latency augmented training for mma due to the change of fairseq APIs Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2087 Reviewed By: hygong-fb Differential Revision: D29851286 Pulled By: xutaima fbshipit-source-id: 6c3077db06b89c23b312b28527d7395a725f3b3a --- .../models/transformer_monotonic_attention.py | 4 + .../modules/monotonic_multihead_attention.py | 4 +- .../utils/data_utils.py | 100 ---- .../utils/functions.py | 5 +- .../simultaneous_translation/utils/latency.py | 451 ------------------ .../utils/monotonic_attention.py | 2 + ...moothed_cross_entropy_latency_augmented.py | 275 ++++++++--- fairseq/modules/transformer_layer.py | 5 +- 8 files changed, 215 insertions(+), 631 deletions(-) delete mode 100644 examples/simultaneous_translation/utils/data_utils.py delete mode 100644 
examples/simultaneous_translation/utils/latency.py diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index b0cdc43483..7b9414b0eb 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -100,6 +100,10 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): ] ) self.policy_criterion = getattr(args, "policy_criterion", "any") + self.num_updates = None + + def set_num_updates(self, num_updates): + self.num_updates = num_updates def pre_attention( self, diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 2b8a48b1de..11ef60c945 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -131,7 +131,7 @@ def energy_from_qk( return energy - def p_choose_from_qk(self, query, key, key_padding_mask): + def p_choose_from_qk(self, query, key, key_padding_mask, incremental_states=None): monotonic_energy = self.energy_from_qk( query, key, @@ -148,7 +148,7 @@ def p_choose_from_qk(self, query, key, key_padding_mask): ) return p_choose - def p_choose(self, query, key, key_padding_mask): + def p_choose(self, query, key, key_padding_mask, incremental_states=None): return self.p_choose_from_qk(self, query, key, key_padding_mask) def monotonic_attention_process_infer( diff --git a/examples/simultaneous_translation/utils/data_utils.py b/examples/simultaneous_translation/utils/data_utils.py deleted file mode 100644 index a763ea6686..0000000000 --- a/examples/simultaneous_translation/utils/data_utils.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. 
-# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch - - -def calc_mean_invstddev(feature): - if len(feature.size()) != 2: - raise ValueError("We expect the input feature to be 2-D tensor") - mean = feature.mean(0) - var = feature.var(0) - # avoid division by ~zero - eps = 1e-8 - if (var < eps).any(): - return mean, 1.0 / (torch.sqrt(var) + eps) - return mean, 1.0 / torch.sqrt(var) - - -def apply_mv_norm(features): - # If there is less than 2 spectrograms, the variance cannot be computed (is NaN) - # and normalization is not possible, so return the item as it is - if features.size(0) < 2: - return features - mean, invstddev = calc_mean_invstddev(features) - res = (features - mean) * invstddev - return res - - -def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False): - """ - convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor - - Args: - lengths: a (B, )-shaped tensor - - Return: - max_length: maximum length of B sequences - encoder_padding_mask: a (max_length, B) binary mask, where - [t, b] = 0 for t < lengths[b] and 1 otherwise - - TODO: - kernelize this function if benchmarking shows this function is slow - """ - max_lengths = torch.max(lengths).item() - bsz = lengths.size(0) - encoder_padding_mask = torch.arange( - max_lengths - ).to( # a (T, ) tensor with [0, ..., T-1] - lengths.device - ).view( # move to the right device - 1, max_lengths - ).expand( # reshape to (1, T)-shaped tensor - bsz, -1 - ) >= lengths.view( # expand to (B, T)-shaped tensor - bsz, 1 - ).expand( - -1, max_lengths - ) - if not batch_first: - return encoder_padding_mask.t(), max_lengths - else: - return encoder_padding_mask, max_lengths - - -def encoder_padding_mask_to_lengths( - encoder_padding_mask, max_lengths, batch_size, device -): - """ - convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor - - Conventionally, encoder output contains a encoder_padding_mask, 
which is - a 2-D mask in a shape (T, B), whose (t, b) element indicate whether - encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we - need to convert this mask tensor to a 1-D tensor in shape (B, ), where - [b] denotes the valid length of b-th sequence - - Args: - encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None, - indicating all are valid - Return: - seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the - number of valid elements of b-th sequence - - max_lengths: maximum length of all sequence, if encoder_padding_mask is - not None, max_lengths must equal to encoder_padding_mask.size(0) - - batch_size: batch size; if encoder_padding_mask is - not None, max_lengths must equal to encoder_padding_mask.size(1) - - device: which device to put the result on - """ - if encoder_padding_mask is None: - return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device) - - assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match" - assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match" - - return max_lengths - torch.sum(encoder_padding_mask, dim=0) diff --git a/examples/simultaneous_translation/utils/functions.py b/examples/simultaneous_translation/utils/functions.py index 0ced35a9d5..590a6c11ce 100644 --- a/examples/simultaneous_translation/utils/functions.py +++ b/examples/simultaneous_translation/utils/functions.py @@ -6,11 +6,12 @@ import torch -def prob_check(tensor): +def prob_check(tensor, eps=1e-10): assert not torch.isnan(tensor).any(), ( "Nan in a probability tensor." 
) - assert tensor.le(1.0).all() and tensor.ge(0.0).all(), ( + # Add the eps here to prevent errors introduced by precision + assert tensor.le(1.0 + eps).all() and tensor.ge(0.0 - eps).all(), ( "Incorrect values in a probability tensor" ", 0.0 <= tensor <= 1.0" ) diff --git a/examples/simultaneous_translation/utils/latency.py b/examples/simultaneous_translation/utils/latency.py deleted file mode 100644 index 5d800a5d9e..0000000000 --- a/examples/simultaneous_translation/utils/latency.py +++ /dev/null @@ -1,451 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch - - -class LatencyMetric(object): - @staticmethod - def length_from_padding_mask(padding_mask, batch_first: bool = False): - dim = 1 if batch_first else 0 - return padding_mask.size(dim) - padding_mask.sum(dim=dim, keepdim=True) - - def prepare_latency_metric( - self, - delays, - src_lens, - target_padding_mask=None, - batch_first: bool = False, - start_from_zero: bool = True, - ): - assert len(delays.size()) == 2 - assert len(src_lens.size()) == 2 - - if start_from_zero: - delays = delays + 1 - - if batch_first: - # convert to batch_last - delays = delays.t() - src_lens = src_lens.t() - tgt_len, bsz = delays.size() - _, bsz_1 = src_lens.size() - - if target_padding_mask is not None: - target_padding_mask = target_padding_mask.t() - tgt_len_1, bsz_2 = target_padding_mask.size() - assert tgt_len == tgt_len_1 - assert bsz == bsz_2 - - assert bsz == bsz_1 - - if target_padding_mask is None: - tgt_lens = tgt_len * delays.new_ones([1, bsz]).float() - else: - # 1, batch_size - tgt_lens = self.length_from_padding_mask(target_padding_mask, False).float() - delays = delays.masked_fill(target_padding_mask, 0) - - return delays, src_lens, tgt_lens, target_padding_mask - - def __call__( - self, - delays, - src_lens, - target_padding_mask=None, - batch_first: bool = 
False, - start_from_zero: bool = True, - ): - delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric( - delays, src_lens, target_padding_mask, batch_first, start_from_zero - ) - return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask) - - @staticmethod - def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): - """ - Expected sizes: - delays: tgt_len, batch_size - src_lens: 1, batch_size - target_padding_mask: tgt_len, batch_size - """ - raise NotImplementedError - - -class AverageProportion(LatencyMetric): - """ - Function to calculate Average Proportion from - Can neural machine translation do simultaneous translation? - (https://arxiv.org/abs/1606.02012) - - Delays are monotonic steps, range from 1 to src_len. - Give src x tgt y, AP is calculated as: - - AP = 1 / (|x||y]) sum_i^|Y| deleys_i - """ - - @staticmethod - def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): - if target_padding_mask is not None: - AP = torch.sum( - delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True - ) - else: - AP = torch.sum(delays, dim=0, keepdim=True) - - AP = AP / (src_lens * tgt_lens) - return AP - - -class AverageLagging(LatencyMetric): - """ - Function to calculate Average Lagging from - STACL: Simultaneous Translation with Implicit Anticipation - and Controllable Latency using Prefix-to-Prefix Framework - (https://arxiv.org/abs/1810.08398) - - Delays are monotonic steps, range from 1 to src_len. 
- Give src x tgt y, AP is calculated as: - - AL = 1 / tau sum_i^tau delays_i - (i - 1) / gamma - - Where - gamma = |y| / |x| - tau = argmin_i(delays_i = |x|) - """ - - @staticmethod - def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): - # tau = argmin_i(delays_i = |x|) - tgt_len, bsz = delays.size() - lagging_padding_mask = delays >= src_lens - lagging_padding_mask = torch.nn.functional.pad( - lagging_padding_mask.t(), (1, 0) - ).t()[:-1, :] - gamma = tgt_lens / src_lens - lagging = ( - delays - - torch.arange(delays.size(0)) - .unsqueeze(1) - .type_as(delays) - .expand_as(delays) - / gamma - ) - lagging.masked_fill_(lagging_padding_mask, 0) - tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True) - AL = lagging.sum(dim=0, keepdim=True) / tau - - return AL - - -class DifferentiableAverageLagging(LatencyMetric): - """ - Function to calculate Differentiable Average Lagging from - Monotonic Infinite Lookback Attention for Simultaneous Machine Translation - (https://arxiv.org/abs/1906.05218) - - Delays are monotonic steps, range from 0 to src_len-1. - (In the original paper thery are from 1 to src_len) - Give src x tgt y, AP is calculated as: - - DAL = 1 / |Y| sum_i^|Y| delays'_i - (i - 1) / gamma - - Where - delays'_i = - 1. delays_i if i == 1 - 2. 
max(delays_i, delays'_{i-1} + 1 / gamma) - - """ - - @staticmethod - def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): - tgt_len, bsz = delays.size() - - gamma = tgt_lens / src_lens - new_delays = torch.zeros_like(delays) - - for i in range(delays.size(0)): - if i == 0: - new_delays[i] = delays[i] - else: - new_delays[i] = torch.cat( - [ - new_delays[i - 1].unsqueeze(0) + 1 / gamma, - delays[i].unsqueeze(0), - ], - dim=0, - ).max(dim=0)[0] - - DAL = ( - new_delays - - torch.arange(delays.size(0)) - .unsqueeze(1) - .type_as(delays) - .expand_as(delays) - / gamma - ) - if target_padding_mask is not None: - DAL = DAL.masked_fill(target_padding_mask, 0) - - DAL = DAL.sum(dim=0, keepdim=True) / tgt_lens - - return DAL - - -class LatencyMetricVariance(LatencyMetric): - def prepare_latency_metric( - self, - delays, - src_lens, - target_padding_mask=None, - batch_first: bool = True, - start_from_zero: bool = True, - ): - assert batch_first - assert len(delays.size()) == 3 - assert len(src_lens.size()) == 2 - - if start_from_zero: - delays = delays + 1 - - # convert to batch_last - bsz, num_heads_x_layers, tgt_len = delays.size() - bsz_1, _ = src_lens.size() - assert bsz == bsz_1 - - if target_padding_mask is not None: - bsz_2, tgt_len_1 = target_padding_mask.size() - assert tgt_len == tgt_len_1 - assert bsz == bsz_2 - - if target_padding_mask is None: - tgt_lens = tgt_len * delays.new_ones([bsz, tgt_len]).float() - else: - # batch_size, 1 - tgt_lens = self.length_from_padding_mask(target_padding_mask, True).float() - delays = delays.masked_fill(target_padding_mask.unsqueeze(1), 0) - - return delays, src_lens, tgt_lens, target_padding_mask - - -class VarianceDelay(LatencyMetricVariance): - @staticmethod - def cal_metric(delays, src_lens, tgt_lens, target_padding_mask): - """ - delays : bsz, num_heads_x_layers, tgt_len - src_lens : bsz, 1 - target_lens : bsz, 1 - target_padding_mask: bsz, tgt_len or None - """ - if delays.size(1) == 1: - return 
delays.new_zeros([1]) - - variance_delays = delays.var(dim=1) - - if target_padding_mask is not None: - variance_delays.masked_fill_(target_padding_mask, 0) - - return variance_delays.sum(dim=1, keepdim=True) / tgt_lens - - -class LatencyInference(object): - def __init__(self, start_from_zero=True): - self.metric_calculator = { - "differentiable_average_lagging": DifferentiableAverageLagging(), - "average_lagging": AverageLagging(), - "average_proportion": AverageProportion(), - } - - self.start_from_zero = start_from_zero - - def __call__(self, monotonic_step, src_lens): - """ - monotonic_step range from 0 to src_len. src_len means eos - delays: bsz, tgt_len - src_lens: bsz, 1 - """ - if not self.start_from_zero: - monotonic_step -= 1 - - src_lens = src_lens - - delays = monotonic_step.view( - monotonic_step.size(0), -1, monotonic_step.size(-1) - ).max(dim=1)[0] - - delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as( - delays - ).masked_fill(delays < src_lens, 0) - return_dict = {} - for key, func in self.metric_calculator.items(): - return_dict[key] = func( - delays.float(), - src_lens.float(), - target_padding_mask=None, - batch_first=True, - start_from_zero=True, - ).t() - - return return_dict - - -class LatencyTraining(object): - def __init__( - self, - avg_weight, - var_weight, - avg_type, - var_type, - stay_on_last_token, - average_method, - ): - self.avg_weight = avg_weight - self.var_weight = var_weight - self.avg_type = avg_type - self.var_type = var_type - self.stay_on_last_token = stay_on_last_token - self.average_method = average_method - - self.metric_calculator = { - "differentiable_average_lagging": DifferentiableAverageLagging(), - "average_lagging": AverageLagging(), - "average_proportion": AverageProportion(), - } - - self.variance_calculator = { - "variance_delay": VarianceDelay(), - } - - def expected_delays_from_attention( - self, attention, source_padding_mask=None, target_padding_mask=None - ): - if type(attention) 
== list: - # bsz, num_heads, tgt_len, src_len - bsz, num_heads, tgt_len, src_len = attention[0].size() - attention = torch.cat(attention, dim=1) - bsz, num_heads_x_layers, tgt_len, src_len = attention.size() - # bsz * num_heads * num_layers, tgt_len, src_len - attention = attention.view(-1, tgt_len, src_len) - else: - # bsz * num_heads * num_layers, tgt_len, src_len - bsz, tgt_len, src_len = attention.size() - num_heads_x_layers = 1 - attention = attention.view(-1, tgt_len, src_len) - - if not self.stay_on_last_token: - residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True) - attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2) - - # bsz * num_heads_x_num_layers, tgt_len, src_len for MMA - steps = ( - torch.arange(1, 1 + src_len) - .unsqueeze(0) - .unsqueeze(1) - .expand_as(attention) - .type_as(attention) - ) - - if source_padding_mask is not None: - src_offset = ( - source_padding_mask.type_as(attention) - .sum(dim=1, keepdim=True) - .expand(bsz, num_heads_x_layers) - .contiguous() - .view(-1, 1) - ) - src_lens = src_len - src_offset - if source_padding_mask[:, 0].any(): - # Pad left - src_offset = src_offset.view(-1, 1, 1) - steps = steps - src_offset - steps = steps.masked_fill(steps <= 0, 0) - else: - src_lens = attention.new_ones([bsz, num_heads_x_layers]) * src_len - src_lens = src_lens.view(-1, 1) - - # bsz * num_heads_num_layers, tgt_len, src_len - expected_delays = ( - (steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len) - ) - - if target_padding_mask is not None: - expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0) - - return expected_delays, src_lens - - def avg_loss(self, expected_delays, src_lens, target_padding_mask): - - bsz, num_heads_x_layers, tgt_len = expected_delays.size() - target_padding_mask = ( - target_padding_mask.unsqueeze(1) - .expand_as(expected_delays) - .contiguous() - .view(-1, tgt_len) - ) - - if self.average_method == "average": - # bsz * tgt_len - expected_delays 
= expected_delays.mean(dim=1) - elif self.average_method == "weighted_average": - weights = torch.nn.functional.softmax(expected_delays, dim=1) - expected_delays = torch.sum(expected_delays * weights, dim=1) - elif self.average_method == "max": - # bsz * num_heads_x_num_layers, tgt_len - expected_delays = expected_delays.max(dim=1)[0] - else: - raise RuntimeError(f"{self.average_method} is not supported") - - src_lens = src_lens.view(bsz, -1)[:, :1] - target_padding_mask = target_padding_mask.view(bsz, -1, tgt_len)[:, 0] - - if self.avg_weight > 0.0: - if self.avg_type in self.metric_calculator: - average_delays = self.metric_calculator[self.avg_type]( - expected_delays, - src_lens, - target_padding_mask, - batch_first=True, - start_from_zero=False, - ) - else: - raise RuntimeError(f"{self.avg_type} is not supported.") - - # bsz * num_heads_x_num_layers, 1 - return self.avg_weight * average_delays.sum() - else: - return 0.0 - - def var_loss(self, expected_delays, src_lens, target_padding_mask): - src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[ - :, :1 - ] - if self.var_weight > 0.0: - if self.var_type in self.variance_calculator: - variance_delays = self.variance_calculator[self.var_type]( - expected_delays, - src_lens, - target_padding_mask, - batch_first=True, - start_from_zero=False, - ) - else: - raise RuntimeError(f"{self.var_type} is not supported.") - - return self.var_weight * variance_delays.sum() - else: - return 0.0 - - def loss(self, attention, source_padding_mask=None, target_padding_mask=None): - expected_delays, src_lens = self.expected_delays_from_attention( - attention, source_padding_mask, target_padding_mask - ) - - latency_loss = 0 - - latency_loss += self.avg_loss(expected_delays, src_lens, target_padding_mask) - - latency_loss += self.var_loss(expected_delays, src_lens, target_padding_mask) - - return latency_loss diff --git a/examples/simultaneous_translation/utils/monotonic_attention.py 
b/examples/simultaneous_translation/utils/monotonic_attention.py index fd45137735..61dbb112bf 100644 --- a/examples/simultaneous_translation/utils/monotonic_attention.py +++ b/examples/simultaneous_translation/utils/monotonic_attention.py @@ -145,6 +145,8 @@ def expected_soft_attention( # Mix precision to prevent overflow for fp16 beta = beta.type(dtype) + beta = beta.clamp(0, 1) + prob_check(beta) return beta diff --git a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py index 051785238f..223a16f740 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -3,13 +3,63 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +import torch +from fairseq import metrics, utils from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import ( LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig ) +try: + from simuleval.metrics.latency import ( + AverageLagging, + AverageProportion, + DifferentiableAverageLagging + ) + LATENCY_METRICS = { + "average_lagging": AverageLagging, + "average_proportion": AverageProportion, + "differentiable_average_lagging": DifferentiableAverageLagging, + } +except ImportError: + LATENCY_METRICS = None -@register_criterion("latency_augmented_label_smoothed_cross_entropy") + +@dataclass +class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + latency_avg_weight: float = field( + default=0.0, + metadata={"help": "weight fot average latency loss."}, + ) + latency_var_weight: float = field( + default=0.0, + metadata={"help": "weight fot variance latency loss."}, + ) + latency_avg_type: str = field( + 
default="differentiable_average_lagging", + metadata={"help": "latency type for average loss"}, + ) + latency_var_type: str = field( + default="variance_delay", + metadata={"help": "latency typ for variance loss"}, + ) + latency_gather_method: str = field( + default="weighted_average", + metadata={"help": "method to gather latency loss for all heads"}, + ) + latency_update_after: int = field( + default=0, + metadata={"help": "Add latency loss after certain steps"}, + ) + +@register_criterion( + "latency_augmented_label_smoothed_cross_entropy", + dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig +) class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( LabelSmoothedCrossEntropyCriterion ): @@ -20,89 +70,164 @@ def __init__( label_smoothing, ignore_prefix_size, report_accuracy, - latency_weight_avg, - latency_weight_avg_type, - latency_weight_var, - latency_weight_var_type, - mass_preservation, - average_method, + latency_avg_weight, + latency_var_weight, + latency_avg_type, + latency_var_type, + latency_gather_method, + latency_update_after, ): super().__init__( task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy ) - from examples.simultaneous_translation.utils.latency import LatencyTraining - self.eps = label_smoothing - self.latency_weight_avg = latency_weight_avg - self.latency_weight_avg_type = latency_weight_avg_type - self.latency_weight_var = latency_weight_var - self.latency_weight_var_type = latency_weight_var_type - self.mass_preservation = mass_preservation - self.average_method = average_method - self.latency_train = LatencyTraining( - self.latency_weight_avg, - self.latency_weight_var, - self.latency_weight_avg_type, - self.latency_weight_var_type, - self.mass_preservation, - self.average_method, + assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed." 
+ + self.latency_avg_weight = latency_avg_weight + self.latency_var_weight = latency_var_weight + self.latency_avg_type = latency_avg_type + self.latency_var_type = latency_var_type + self.latency_gather_method = latency_gather_method + self.latency_update_after = latency_update_after + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + # 1. Compute cross entropy loss + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + + # 2. Compute cross latency loss + latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss( + model, sample, net_output + ) + + if self.latency_update_after > 0: + num_updates = getattr(model.decoder, "num_updates", None) + assert num_updates is not None, ( + "model.decoder doesn't have attribute 'num_updates'" + ) + if num_updates <= self.latency_update_after: + latency_loss = 0 + + loss += latency_loss + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "latency": expected_latency, + "delays_var": expected_delays_var, + "latency_loss": latency_loss, + } + + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + def compute_latency_loss(self, model, sample, net_output): + assert ( + net_output[-1].encoder_padding_mask is None + or not net_output[-1].encoder_padding_mask[:, 0].any() + ), ( + "Only right padding on source is supported." + ) + # 1. 
Obtain the expected alignment + alpha_list = [item["alpha"] for item in net_output[1].attn_list] + num_layers = len(alpha_list) + bsz, num_heads, tgt_len, src_len = alpha_list[0].size() + + # bsz * num_layers * num_heads, tgt_len, src_len + alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len) + + # 2 compute expected delays + # bsz * num_heads * num_layers, tgt_len, src_len for MMA + steps = ( + torch.arange(1, 1 + src_len) + .unsqueeze(0) + .unsqueeze(1) + .expand_as(alpha_all) + .type_as(alpha_all) ) - @staticmethod - def add_args(parser): - super( - LatencyAugmentedLabelSmoothedCrossEntropyCriterion, - LatencyAugmentedLabelSmoothedCrossEntropyCriterion, - ).add_args(parser) - # fmt: off - - """Add criterion-specific arguments to the parser.""" - parser.add_argument( - "--label-smoothing", - default=0.0, - type=float, - metavar="D", - help="epsilon for label smoothing, 0 means no label smoothing", + expected_delays = torch.sum(steps * alpha_all, dim=-1) + + target_padding_mask = ( + model.get_targets(sample, net_output) + .eq(self.padding_idx) + .unsqueeze(1) + .expand(bsz, num_layers * num_heads, tgt_len) + .contiguous() + .view(-1, tgt_len) ) - parser.add_argument( - "--ignore_prefix_size", - default=0, - type=int, - help="ignore first N tokens", + + src_lengths = ( + sample["net_input"]["src_lengths"] + .unsqueeze(1) + .expand(bsz, num_layers * num_heads) + .contiguous() + .view(-1) ) - parser.add_argument( - "--report-accuracy", - default=False, - type=bool, - help="report accuracy metric", + expected_latency = LATENCY_METRICS[self.latency_avg_type]( + expected_delays, src_lengths, None, + target_padding_mask=target_padding_mask ) - parser.add_argument("--latency-weight-avg", default=0., type=float, metavar='D', - help="Average loss weight") - parser.add_argument("--latency-weight-var", default=0., type=float, metavar='D', - help="Variance loss weight") - parser.add_argument("--latency-weight-avg-type", default="differentiable_average_lagging", 
- help="Statistics for Average loss type") - parser.add_argument("--latency-weight-var-type", default="variance_delay", - help="Statistics for variance loss type") - parser.add_argument("--average-method", default="weighted_average", - help="Average loss type") - # fmt: on - - def compute_loss(self, model, net_output, sample, reduce=True): - # Compute cross entropy loss first - loss, nll_loss = super().compute_loss(model, net_output, sample, reduce) - - # Obtain the expected alignment - attn_list = [item["alpha"] for item in net_output[-1]["attn_list"]] - - target_padding_mask = model.get_targets(sample, net_output).eq(self.padding_idx) - - source_padding_mask = net_output[-1].get("encoder_padding_mask", None) - - # Get latency loss - latency_loss = self.latency_train.loss( - attn_list, source_padding_mask, target_padding_mask + + # 2.1 average expected latency of heads + # bsz, num_layers * num_heads + expected_latency = expected_latency.view(bsz, -1) + if self.latency_gather_method == "average": + # bsz * tgt_len + expected_latency = expected_delays.mean(dim=1) + elif self.latency_gather_method == "weighted_average": + weights = torch.nn.functional.softmax(expected_latency, dim=1) + expected_latency = torch.sum(expected_latency * weights, dim=1) + elif self.latency_gather_method == "max": + expected_latency = expected_latency.max(dim=1)[0] + else: + raise NotImplementedError + + expected_latency = expected_latency.sum() + avg_loss = self.latency_avg_weight * expected_latency + + # 2.2 variance of expected delays + expected_delays_var = ( + expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1) ) + expected_delays_var = expected_delays_var.sum() + var_loss = self.latency_avg_weight * expected_delays_var - loss += latency_loss + # 3. 
Final loss + latency_loss = avg_loss + var_loss + + return latency_loss, expected_latency, expected_delays_var - return loss, nll_loss + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + latency = sum( + log.get("latency", 0) for log in logging_outputs + ) + delays_var = sum( + log.get("delays_var", 0) for log in logging_outputs + ) + latency_loss = sum( + log.get("latency_loss", 0) for log in logging_outputs + ) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar( + "latency", latency.float() / nsentences, nsentences, round=3 + ) + metrics.log_scalar( + "delays_var", delays_var / nsentences, + nsentences, round=3 + ) + metrics.log_scalar( + "latency_loss", latency_loss / nsentences, + nsentences, round=3 + ) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index de25de6564..347b8118da 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -131,7 +131,10 @@ def forward( # the attention weight (before softmax) for some padded element in query # will become -inf, which results in NaN in model parameters if attn_mask is not None: - attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) + attn_mask = attn_mask.masked_fill( + attn_mask.to(torch.bool), + -1e8 if x.dtype == torch.float32 else -1e4 + ) residual = x if self.normalize_before: From f6abcc2a67328bee8b15c596bb626ce2d720aae6 Mon Sep 17 00:00:00 2001 From: dianaml0 <82468439+dianaml0@users.noreply.github.com> Date: Thu, 16 Sep 2021 10:01:11 -0700 Subject: [PATCH 704/707] update on branch renaming (#3879) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? 
- [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3879 Reviewed By: myleott Differential Revision: D30969142 Pulled By: dianaml0 fbshipit-source-id: 902154c03fd68ae6645d3e0ac07b7d729dfc7934 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index cd9654cf31..3316c963ce 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming). * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md) * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md) * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) From 5adfeaccf9a70cf8ad25eb0d3a0826a6665ac8d2 Mon Sep 17 00:00:00 2001 From: Diana Liskovich <dianaml@devfair0471.h2.fair> Date: Mon, 20 Sep 2021 08:04:06 -0700 Subject: [PATCH 705/707] Rename references from master -> main in preparation for branch name change (#2297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2297 Reviewed By: alexeib Differential Revision: D30906090 Pulled By: dianaml0 fbshipit-source-id: 941d30db7f766c9077a1b5bb2a04680f57e2e070 --- .github/ISSUE_TEMPLATE/bug_report.md | 4 ++-- .github/ISSUE_TEMPLATE/how-to-question.md | 10 +++++----- .github/PULL_REQUEST_TEMPLATE.md | 10 +++++----- .github/workflows/build.yml | 4 ++-- CONTRIBUTING.md | 2 +- README.md | 6 +++--- docs/conf.py | 2 +- examples/adaptive_span/README.md | 2 +- examples/constrained_decoding/README.md | 2 +- .../discriminative_reranking_nmt/README.md | 2 +- examples/fast_noisy_channel/README.md | 4 ++-- examples/layerdrop/README.md | 6 +++--- examples/m2m_100/README.md | 2 +- examples/multilingual/README.md | 6 +++--- examples/quant_noise/README.md | 20 +++++++++---------- examples/roberta/README.md | 8 ++++---- examples/roberta/commonsense_qa/README.md | 2 +- examples/shuffled_word_order/README.md | 6 +++--- .../speech_synthesis/docs/ljspeech_example.md | 4 ++-- examples/textless_nlp/gslm/README.md | 4 ++-- examples/wav2vec/unsupervised/README.md | 4 ++-- fairseq/models/bart/hub_interface.py | 2 +- fairseq/models/roberta/hub_interface.py | 2 +- 23 files changed, 57 insertions(+), 57 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index a7f4f0a902..aa15123d8e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -19,7 +19,7 @@ Steps to reproduce the behavior (**always include the command you ran**): #### Code sample -<!-- Ideally attach a minimal code sample to reproduce the decried issue. +<!-- Ideally attach a minimal code sample to reproduce the decried issue. Minimal means having the shortest code but still preserving the bug.
--> ### Expected behavior @@ -28,7 +28,7 @@ Minimal means having the shortest code but still preserving the bug. --> ### Environment - - fairseq Version (e.g., 1.0 or master): + - fairseq Version (e.g., 1.0 or main): - PyTorch Version (e.g., 1.0) - OS (e.g., Linux): - How you installed fairseq (`pip`, source): diff --git a/.github/ISSUE_TEMPLATE/how-to-question.md b/.github/ISSUE_TEMPLATE/how-to-question.md index 4beb180dbf..04f3f15d3e 100644 --- a/.github/ISSUE_TEMPLATE/how-to-question.md +++ b/.github/ISSUE_TEMPLATE/how-to-question.md @@ -6,9 +6,9 @@ labels: 'question, needs triage' ## ❓ Questions and Help -### Before asking: -1. search the issues. -2. search the docs. +### Before asking: +1. search the issues. +2. search the docs. <!-- If you still can't find what you need: --> @@ -16,13 +16,13 @@ labels: 'question, needs triage' #### Code -<!-- Please paste a code snippet if your question requires it! --> +<!-- Please paste a code snippet if your question requires it! --> #### What have you tried? #### What's your environment? - - fairseq Version (e.g., 1.0 or master): + - fairseq Version (e.g., 1.0 or main): - PyTorch Version (e.g., 1.0) - OS (e.g., Linux): - How you installed fairseq (`pip`, source): diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b28ff98e7b..d005e2df4f 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,15 +1,15 @@ # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) -- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? -- [ ] Did you make sure to update the docs? -- [ ] Did you write any new necessary tests? +- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)? +- [ ] Did you make sure to update the docs? +- [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). 
-## PR review -Anyone in the community is free to review the PR once the tests have passed. +## PR review +Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 105c42a503..f493f91f0d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,10 +1,10 @@ name: build on: - # Trigger the workflow on push to master or any pull request + # Trigger the workflow on push to main or any pull request push: branches: - - master + - main pull_request: jobs: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4d7ca6a98e..3930c46196 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,7 +5,7 @@ possible. ## Pull Requests We actively welcome your pull requests. -1. Fork the repo and create your branch from `master`. +1. Fork the repo and create your branch from `main`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. 
diff --git a/README.md b/README.md index 3316c963ce..dd68717480 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ <img src="docs/fairseq_logo.png" width="150"> <br /> <br /> - <a href="https://github.com/pytorch/fairseq/blob/master/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a> + <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a> <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a> <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a> <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a> @@ -48,7 +48,7 @@ We provide reference implementations of various sequence modeling papers: + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) - + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) + + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027) + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084) * **Non-autoregressive Transformers** @@ -93,7 +93,7 @@ We provide reference implementations of various sequence 
modeling papers: * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) * February 2020: [mBART model and code released](examples/mbart/README.md) -* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) +* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german) * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) * November 2019: [CamemBERT model and code released](examples/camembert/README.md) diff --git a/docs/conf.py b/docs/conf.py index 440784bfae..87b0db98c7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,7 +55,7 @@ copyright = "Facebook AI Research (FAIR)" author = "Facebook AI Research (FAIR)" -github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/" +github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/examples/adaptive_span/README.md b/examples/adaptive_span/README.md index 913a873386..d5224fb289 100644 --- a/examples/adaptive_span/README.md +++ b/examples/adaptive_span/README.md @@ -4,7 +4,7 @@ Adaptive Span is a novel self-attention mechanism that can learn its optimal attention span. This allows us to extend significantly the maximum context size used in Transformer, while maintaining control over their memory footprint and computational time. 
It uses the Truncated BPTT technique for training, -as in [transformerXL](https://github.com/pytorch/fairseq/blob/master/examples/truncated_bptt/README.md). +as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md). Adaptive Span was introduced by paper: [Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799), diff --git a/examples/constrained_decoding/README.md b/examples/constrained_decoding/README.md index cfca9c91fd..e04b8b6a01 100644 --- a/examples/constrained_decoding/README.md +++ b/examples/constrained_decoding/README.md @@ -12,7 +12,7 @@ Constrained search is enabled by adding the command-line argument `--constraints Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens) is a separate field. -The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md), +The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/main/examples/wmt19/README.md), translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints "hard" and "to influence". diff --git a/examples/discriminative_reranking_nmt/README.md b/examples/discriminative_reranking_nmt/README.md index e6f42b1278..b155e855f2 100644 --- a/examples/discriminative_reranking_nmt/README.md +++ b/examples/discriminative_reranking_nmt/README.md @@ -38,7 +38,7 @@ source_sentence_L_hypo_1 source_sentence_L_hypo_N ``` -2. Download the [XLMR model](https://github.com/fairinternal/fairseq-py/tree/master/examples/xlmr#pre-trained-models). +2. Download the [XLMR model](https://github.com/fairinternal/fairseq-py/tree/main/examples/xlmr#pre-trained-models). 
``` wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz tar zxvf xlmr.base.tar.gz diff --git a/examples/fast_noisy_channel/README.md b/examples/fast_noisy_channel/README.md index a04151a796..f2631a8c34 100644 --- a/examples/fast_noisy_channel/README.md +++ b/examples/fast_noisy_channel/README.md @@ -29,9 +29,9 @@ This framework provides a great way to utlize strong target language models trai ### Training Translation Models and Language Models -For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/translation) +For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/translation) -For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model) +For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model) ### Generation with Language Model for German-English translation with fairseq diff --git a/examples/layerdrop/README.md b/examples/layerdrop/README.md index 394e710b0f..4d48ee9615 100644 --- a/examples/layerdrop/README.md +++ b/examples/layerdrop/README.md @@ -126,9 +126,9 @@ This model override command overrides the training parameters and updates the mo Looking to reproduce the results in the paper? -1. For Translation on WMT16 en-de, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md) -2. To train RoBERTa, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta) -3. To train Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model) +1. 
For Translation on WMT16 en-de, we followed this setting [here](https://github.com/pytorch/fairseq/blob/main/examples/scaling_nmt/README.md) +2. To train RoBERTa, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/roberta) +3. To train Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model) ## Tips diff --git a/examples/m2m_100/README.md b/examples/m2m_100/README.md index 05801584d6..02a68a5f09 100644 --- a/examples/m2m_100/README.md +++ b/examples/m2m_100/README.md @@ -82,7 +82,7 @@ fairseq-preprocess \ 3. **Training Scripts** -To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/master/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale). +To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/main/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale). 4. **Generation** diff --git a/examples/multilingual/README.md b/examples/multilingual/README.md index 0076f5e8f0..46ff9c351b 100644 --- a/examples/multilingual/README.md +++ b/examples/multilingual/README.md @@ -17,9 +17,9 @@ This work is for training multilingual translation models with multiple bitext d - --finetune-from-model to specify the path from which to load the pretrained model ## Preprocessing data -Multilingual training requires a joint BPE vocab. Please follow [mBART's preprocessing steps](https://github.com/pytorch/fairseq/tree/master/examples/mbart#bpe-data) to reuse our pretrained sentence-piece model. +Multilingual training requires a joint BPE vocab. 
Please follow [mBART's preprocessing steps](https://github.com/pytorch/fairseq/tree/main/examples/mbart#bpe-data) to reuse our pretrained sentence-piece model. -You can also train a joint BPE model on your own dataset and then follow the steps in [[link]](https://github.com/pytorch/fairseq/tree/master/examples/translation#multilingual-translation). +You can also train a joint BPE model on your own dataset and then follow the steps in [[link]](https://github.com/pytorch/fairseq/tree/main/examples/translation#multilingual-translation). ## Training @@ -49,7 +49,7 @@ fairseq-train $path_2_data \ ``` ## Finetuning -We can also finetune multilingual models from a monolingual pretrained models, e.g. [mMBART](https://github.com/pytorch/fairseq/tree/master/examples/mbart). +We can also finetune multilingual models from a monolingual pretrained models, e.g. [mMBART](https://github.com/pytorch/fairseq/tree/main/examples/mbart). ```bash lang_pairs=<language pairs to be trained, e.g. "en-cs,cs-en"> path_2_data=<set to data path> diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 539c3d5af9..a04d7e4e8a 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -33,7 +33,7 @@ Unlike the section [Iterative Product Quantization](#iterative-product-quantizat #### Training -Scalar quantization with Quant-Noise consists in randomly quantizing a proportion `p` of the weights during training. Scalar quantization is implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar) under the form of Fake Quantization, meaning that we emulate int8 on GPU by quantizing and de-quantizing both the weights and the activations. We rely on PyTorch's [quantization primitives](https://github.com/pytorch/pytorch/tree/master/torch/quantization). +Scalar quantization with Quant-Noise consists in randomly quantizing a proportion `p` of the weights during training. 
Scalar quantization is implemented [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quantization/scalar) under the form of Fake Quantization, meaning that we emulate int8 on GPU by quantizing and de-quantizing both the weights and the activations. We rely on PyTorch's [quantization primitives](https://github.com/pytorch/pytorch/tree/master/torch/quantization). To train a model with Quant-Noise, add the following flag: ``` @@ -49,7 +49,7 @@ When evaluating a network, all quantized modules and activation hooks automatica #### Integration with your own code Looking to quantize your own models with Quant-Noise + Scalar Quantization? -- Use the function `quantize_model_` implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar/utils.py) to (1) replace all your modules by their quantized counterparts and (2) add hooks to those modules to quantize the activations. +- Use the function `quantize_model_` implemented [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quantization/scalar/utils.py) to (1) replace all your modules by their quantized counterparts and (2) add hooks to those modules to quantize the activations. - Then, perform your training as usual. Note that in `eval()` mode, the network is always fully quantized (weights and activations) by default (`p=1`). @@ -66,12 +66,12 @@ To train a model with Quant-Noise, add the following flags: --quant-noise-pq 0.1 --quant-noise-pq-block-size 8 ``` `quant-noise-pq` controls how much dropout is applied to the blocks of the weight matrix. `quant-noise-pq-block-size` controls the size of the weight matrix blocks. -We recommend training with 0.05 to 0.2 Quant-Noise, a value that worked well in our experiments. For the block-size, we recommend training with block-size of 8. Note that the block size must be a multiple of `input_features`, see the size checks [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py). 
Large block sizes result in higher compression ratio but may induce a loss in accuracy. +We recommend training with 0.05 to 0.2 Quant-Noise, a value that worked well in our experiments. For the block-size, we recommend training with block-size of 8. Note that the block size must be a multiple of `input_features`, see the size checks [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quant_noise.py). Large block sizes result in higher compression ratio but may induce a loss in accuracy. -We currently support training Transformer based models, such as sequence-to-sequence, language models, and BERT architectures. The `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py) wraps a module. It splits a weight matrix into blocks and applies random dropout to these blocks. +We currently support training Transformer based models, such as sequence-to-sequence, language models, and BERT architectures. The `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quant_noise.py) wraps a module. It splits a weight matrix into blocks and applies random dropout to these blocks. In the Transformer architectures, quant-noise is applied to the input and output embeddings, the attention, and the FFN. -Quant-Noise can also be combined with **LayerDrop** (see [here](https://github.com/pytorch/fairseq/tree/master/examples/layerdrop)) to add its pruning effect to the quantized model and make the model even smaller. We recommend training with LayerDrop 0.1 or 0.2. +Quant-Noise can also be combined with **LayerDrop** (see [here](https://github.com/pytorch/fairseq/tree/main/examples/layerdrop)) to add its pruning effect to the quantized model and make the model even smaller. We recommend training with LayerDrop 0.1 or 0.2. #### Quantization @@ -84,8 +84,8 @@ For the particular case of PQ, quantization is made sequentially. 
We recommend f #### Integration with your own code Looking to quantize your own models with Quant-Noise + iPQ? -- First wrap your modules with the `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py), which is module-agnostic and train your favorite model. -- Then, quantize your trained model using the code [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/pq). This can be done *without any changes to your training loop*. Below is an example code for integration. +- First wrap your modules with the `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quant_noise.py), which is module-agnostic and train your favorite model. +- Then, quantize your trained model using the code [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quantization/pq). This can be done *without any changes to your training loop*. Below is an example code for integration. Note that we tried our approach only on Transformers and various Convolutional Models such as EfficientNets. ```python @@ -128,7 +128,7 @@ We detail below how to reproduce the state-of-the-art results in reported in the ### Training with Quant-Noise -To **train** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta). +To **train** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/roberta). The following command can be used to train a RoBERTa Base + QuantNoise model: ```bash @@ -158,7 +158,7 @@ fairseq-train $DATA_DIR \ --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 --untie-weights-roberta ``` -To **finetune** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.glue.md). 
+To **finetune** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/blob/main/examples/roberta/README.glue.md). The following command can be used to finetune a RoBERTa Base + QuantNoise model on the RTE dataset: ```bash @@ -193,7 +193,7 @@ fairseq-train /path/to/rte/data/ \ --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 ``` -To **train** Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model). +To **train** Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model). The following command can be used to train a Transformer + QuantNoise model on Wikitext-103: ```bash diff --git a/examples/roberta/README.md b/examples/roberta/README.md index 58091b2c7d..ed4d5df52c 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -8,13 +8,13 @@ RoBERTa iterates on BERT's pretraining procedure, including training the model l ### What's New: -- December 2020: German model (GottBERT) is available: [GottBERT](https://github.com/pytorch/fairseq/tree/master/examples/gottbert). +- December 2020: German model (GottBERT) is available: [GottBERT](https://github.com/pytorch/fairseq/tree/main/examples/gottbert). - January 2020: Italian model (UmBERTo) is available from Musixmatch Research: [UmBERTo](https://github.com/musixmatchresearch/umberto). -- November 2019: French model (CamemBERT) is available: [CamemBERT](https://github.com/pytorch/fairseq/tree/master/examples/camembert). -- November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/tree/master/examples/xlmr). +- November 2019: French model (CamemBERT) is available: [CamemBERT](https://github.com/pytorch/fairseq/tree/main/examples/camembert). 
+- November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/tree/main/examples/xlmr). - September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers). - August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers). -- August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/master/examples/roberta/wsc#roberta-training-on-winogrande-dataset). +- August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/main/examples/roberta/wsc#roberta-training-on-winogrande-dataset). - August 2019: Added [tutorial for pretraining RoBERTa using your own data](README.pretraining.md). ## Pre-trained models diff --git a/examples/roberta/commonsense_qa/README.md b/examples/roberta/commonsense_qa/README.md index 05c6f841a8..7f386decd8 100644 --- a/examples/roberta/commonsense_qa/README.md +++ b/examples/roberta/commonsense_qa/README.md @@ -96,4 +96,4 @@ print('Accuracy: ' + str(ncorrect / float(nsamples))) ``` The above snippet is not batched, which makes it quite slow. See [instructions -for batched prediction with RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta#batched-prediction). +for batched prediction with RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/roberta#batched-prediction). 
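Note on the commonsense_qa hunk above: the README text it touches points readers to batched prediction because the per-example loop is slow. A minimal sketch of what batching involves, assuming right-padding with a pad index of 1; the helper names `pad_batch` and `batched` and the token ids are hypothetical, and this is plain Python for illustration, not fairseq's own collation code:

```python
from typing import Iterator, List, Sequence


def pad_batch(sequences: Sequence[Sequence[int]], pad_idx: int = 1) -> List[List[int]]:
    """Right-pad variable-length token id sequences to a common length."""
    if not sequences:
        return []
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_idx] * (max_len - len(seq)) for seq in sequences]


def batched(items: Sequence[Sequence[int]], batch_size: int) -> Iterator[Sequence[Sequence[int]]]:
    """Yield consecutive chunks of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Hypothetical token ids standing in for encoded question/answer pairs.
encoded = [[0, 31414, 2], [0, 100, 200, 300, 2], [0, 7, 2], [0, 2]]

# Each padded chunk is rectangular, so a model can score it in one forward pass.
padded_batches = [pad_batch(chunk, pad_idx=1) for chunk in batched(encoded, batch_size=2)]
for batch in padded_batches:
    print(batch)
```

Padding only to the longest sequence in each chunk, rather than in the whole dataset, keeps the wasted positions small; sorting examples by length before chunking would reduce them further.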
diff --git a/examples/shuffled_word_order/README.md b/examples/shuffled_word_order/README.md index 14c240cb56..f20483849a 100644 --- a/examples/shuffled_word_order/README.md +++ b/examples/shuffled_word_order/README.md @@ -40,7 +40,7 @@ For more results on probing tasks, please refer to [our paper](https://arxiv.org ## Example Usage -Follow the same usage as in [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) to load and test your models: +Follow the same usage as in [RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/roberta) to load and test your models: ```python # Download roberta.base.shuffle.n1 model @@ -53,11 +53,11 @@ roberta = RoBERTaModel.from_pretrained('/path/to/roberta.base.shuffle.n1', check roberta.eval() # disable dropout (or leave in train mode to finetune) ``` -**Note**: The model trained without positional embeddings (`roberta.base.nopos`) is a modified `RoBERTa` model, where the positional embeddings are not used. Thus, the typical `from_pretrained` method on fairseq version of RoBERTa will not be able to load the above model weights. To do so, construct a new `RoBERTaModel` object by setting the flag `use_positional_embeddings` to `False` (or [in the latest code](https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/model.py#L543), set `no_token_positional_embeddings` to `True`), and then load the individual weights. +**Note**: The model trained without positional embeddings (`roberta.base.nopos`) is a modified `RoBERTa` model, where the positional embeddings are not used. Thus, the typical `from_pretrained` method on fairseq version of RoBERTa will not be able to load the above model weights. To do so, construct a new `RoBERTaModel` object by setting the flag `use_positional_embeddings` to `False` (or [in the latest code](https://github.com/pytorch/fairseq/blob/main/fairseq/models/roberta/model.py#L543), set `no_token_positional_embeddings` to `True`), and then load the individual weights. 
## Fine-tuning Evaluation -We provide the trained fine-tuned models on MNLI here for each model above for quick evaluation (1 seed for each model). Please refer to [finetuning details](README.finetuning.md) for the parameters of these models. Follow [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) instructions to evaluate these models. +We provide the trained fine-tuned models on MNLI here for each model above for quick evaluation (1 seed for each model). Please refer to [finetuning details](README.finetuning.md) for the parameters of these models. Follow [RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/roberta) instructions to evaluate these models. | Model | MNLI M Dev Accuracy | Link | | :----------------------------------------- | :------------------ | :--------------------------------------------------------------------------------------------------------------- | diff --git a/examples/speech_synthesis/docs/ljspeech_example.md b/examples/speech_synthesis/docs/ljspeech_example.md index 2b8d21abf9..90c524fac8 100644 --- a/examples/speech_synthesis/docs/ljspeech_example.md +++ b/examples/speech_synthesis/docs/ljspeech_example.md @@ -38,7 +38,7 @@ For your convenience, we provide pre-computed [force-alignment](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_mfa.zip) from [Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) and [pseudo-text units](s3://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_hubert.tsv) from -[HuBERT](https://github.com/pytorch/fairseq/tree/master/examples/hubert). You can also generate them by yourself using +[HuBERT](https://github.com/pytorch/fairseq/tree/main/examples/hubert). You can also generate them by yourself using a different software or model. @@ -106,7 +106,7 @@ use `--sample-rate 16000` for `get_eval_manifest.py`. #### WER/CER metric -We use wav2vec 2.0 ASR model as example. 
[Download](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec) +We use wav2vec 2.0 ASR model as example. [Download](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec) the model checkpoint and dictionary, then compute WER/CER with ```bash python -m examples.speech_synthesis.evaluation.eval_asr \ diff --git a/examples/textless_nlp/gslm/README.md b/examples/textless_nlp/gslm/README.md index 79de55d96e..7a76ffd57c 100644 --- a/examples/textless_nlp/gslm/README.md +++ b/examples/textless_nlp/gslm/README.md @@ -3,7 +3,7 @@ * [Paper](https://arxiv.org/abs/2102.01192) * [Demo](https://speechbot.github.io/gslm/index.html) -We build and evaluate generative speech2speech systems using [Log Mel Filtebank](https://pytorch.org/audio/stable/compliance.kaldi.html#fbank), [Modified CPC](https://github.com/facebookresearch/CPC_audio), [HuBERT Base](https://github.com/pytorch/fairseq/tree/master/examples/hubert) and [Wav2Vec 2.0 Large](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec). Our system is composed of three components, namely, *speech2unit*, *ulm* and *unit2speech*. We explain about models and usage of these components in their respective sub-directories. See the links below. +We build and evaluate generative speech2speech systems using [Log Mel Filtebank](https://pytorch.org/audio/stable/compliance.kaldi.html#fbank), [Modified CPC](https://github.com/facebookresearch/CPC_audio), [HuBERT Base](https://github.com/pytorch/fairseq/tree/main/examples/hubert) and [Wav2Vec 2.0 Large](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec). Our system is composed of three components, namely, *speech2unit*, *ulm* and *unit2speech*. We explain about models and usage of these components in their respective sub-directories. See the links below. ## Speech to Unit Model (speech2unit) Speech to unit model is used for quantizing raw speech into learned discrete speech units. 
[More details](speech2unit) @@ -18,4 +18,4 @@ Unit to speech model is used for synthesizing speech from discrete speech units. We show how to compute ASR based metrics as well as zero-shot metrics proposed in our paper [here](metrics). ## Tools -We share two tools to resynthesize a given spoken utterance, and generate novel spoken language given a spoken prompt. [More detail](tools) \ No newline at end of file +We share two tools to resynthesize a given spoken utterance, and generate novel spoken language given a spoken prompt. [More detail](tools) diff --git a/examples/wav2vec/unsupervised/README.md b/examples/wav2vec/unsupervised/README.md index 046202e01c..0b213fd202 100644 --- a/examples/wav2vec/unsupervised/README.md +++ b/examples/wav2vec/unsupervised/README.md @@ -1,6 +1,6 @@ # wav2vec Unsupervised (wav2vec-U) -Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition systems without any labeled training data as described in [Unsupervised Speech Recognition (Baevski et al., 2021)](https://ai.facebook.com/research/publications/unsupervised-speech-recognition). The model takes as input wav2vec 2.0 or XLSR representations (see [pretrained models](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec)) as well as unlabeled speech and text data. +Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition systems without any labeled training data as described in [Unsupervised Speech Recognition (Baevski et al., 2021)](https://ai.facebook.com/research/publications/unsupervised-speech-recognition). The model takes as input wav2vec 2.0 or XLSR representations (see [pretrained models](https://github.com/pytorch/fairseq/blob/main/examples/wav2vec)) as well as unlabeled speech and text data. 
The wav2vec-U training procedure consists of three consecutive main steps: * Preparation of speech representations and text data @@ -8,7 +8,7 @@ Wav2vec Unsupervised (wav2vec-U) is a framework for building speech recognition * Iterative self-training + Kaldi LM-decoding ## Preparation of speech and text data -Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md), data folders contain {train,valid,test}.{tsv,wrd,phn} files, where audio paths are stored in tsv files, and word, letter or phoneme transcriptions are stored in .{wrd,ltr,phn}. +Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/main/examples/wav2vec/README.md), data folders contain {train,valid,test}.{tsv,wrd,phn} files, where audio paths are stored in tsv files, and word, letter or phoneme transcriptions are stored in .{wrd,ltr,phn}. In **/path/to/data/with_silence** you need a *train.tsv* file as well as (optionally) *{valid,test}.{tsv,wrd,phn}*. It is nice to have *10h.{tsv,phn}* files there too for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except *.tsv* files contain audios with silences removed using rVAD. diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 9afe385b9d..4d47d97518 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -23,7 +23,7 @@ class BARTHubInterface(GeneratorHubInterface): """A simple PyTorch Hub interface to BART. 
- Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart + Usage: https://github.com/pytorch/fairseq/tree/main/examples/bart """ def __init__(self, cfg, task, model): diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index c9af434bde..ba298d63ba 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -14,7 +14,7 @@ class RobertaHubInterface(nn.Module): """A simple PyTorch Hub interface to RoBERTa. - Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta + Usage: https://github.com/pytorch/fairseq/tree/main/examples/roberta """ def __init__(self, cfg, task, model): From 3dd70d8c0d17ef3268b22706805622826df7b6d3 Mon Sep 17 00:00:00 2001 From: freewym <freewym@gmail.com> Date: Mon, 20 Sep 2021 11:52:04 -0700 Subject: [PATCH 706/707] =?UTF-8?q?fix=20the=20problem=20that=20command=20?= =?UTF-8?q?line=20args=20for=20Transformer=20model=20do=20not=20o=E2=80=A6?= =?UTF-8?q?=20(#3773)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …verride the defaults # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3761. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? 
Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3773 Reviewed By: yuntang Differential Revision: D30310383 Pulled By: kahne fbshipit-source-id: cbfcbc032dbf53490a25ffdebe57f65c42d52e71 --- fairseq/models/transformer/transformer_base.py | 4 ++-- fairseq/models/transformer/transformer_legacy.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/fairseq/models/transformer/transformer_base.py b/fairseq/models/transformer/transformer_base.py index 810c9b98db..b4d5604dbb 100644 --- a/fairseq/models/transformer/transformer_base.py +++ b/fairseq/models/transformer/transformer_base.py @@ -41,8 +41,8 @@ def __init__(self, cfg, encoder, decoder): self.cfg = cfg self.supports_align_args = True - @staticmethod - def add_args(parser): + @classmethod + def add_args(cls, parser): """Add model-specific arguments to the parser.""" # we want to build the args recursively in this case. gen_parser_from_dataclass( diff --git a/fairseq/models/transformer/transformer_legacy.py b/fairseq/models/transformer/transformer_legacy.py index 9534e400b5..af9646740a 100644 --- a/fairseq/models/transformer/transformer_legacy.py +++ b/fairseq/models/transformer/transformer_legacy.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.models import ( register_model, register_model_architecture, @@ -78,6 +79,15 @@ def __init__(self, args, encoder, decoder): super().__init__(cfg, encoder, decoder) self.args = args + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + # we want to build the args recursively in this case.
+ # do not set defaults so that settings defaults from various architectures still works + gen_parser_from_dataclass( + parser, TransformerConfig(), delete_default=True, with_prefix="" + ) + @classmethod def build_model(cls, args, task): """Build a new model instance.""" From fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1 Mon Sep 17 00:00:00 2001 From: Diana Liskovich <dianaml@fb.com> Date: Mon, 20 Sep 2021 14:30:04 -0700 Subject: [PATCH 707/707] Update reference from master to main elsewhere in fbcode Summary: Update reference from master to main elsewhere in fbcode Reviewed By: alexeib Differential Revision: D30938472 fbshipit-source-id: 243b98550207f241c9d3265bf3d4060350aaf0a8 --- examples/fully_sharded_data_parallel/README.md | 2 +- examples/speech_text_joint_to_text/docs/ende-mustc.md | 4 ++-- examples/speech_text_joint_to_text/docs/iwslt2021.md | 2 +- examples/textless_nlp/gslm/metrics/asr_metrics/README.md | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/fully_sharded_data_parallel/README.md b/examples/fully_sharded_data_parallel/README.md index d620f0e4f1..b9e44fef48 100644 --- a/examples/fully_sharded_data_parallel/README.md +++ b/examples/fully_sharded_data_parallel/README.md @@ -48,7 +48,7 @@ CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs. These examples use the WikiText-103 dataset for demonstration purposes, but in practice a much larger dataset will be needed to achieve good results. -Follow the [instructions here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md#1-preprocess-the-data) +Follow the [instructions here](https://github.com/pytorch/fairseq/blob/main/examples/roberta/README.pretraining.md#1-preprocess-the-data) to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary. 
### 13B params on 1 V100 GPU (with CPU offloading) diff --git a/examples/speech_text_joint_to_text/docs/ende-mustc.md b/examples/speech_text_joint_to_text/docs/ende-mustc.md index 3487af6671..2897c4e27b 100644 --- a/examples/speech_text_joint_to_text/docs/ende-mustc.md +++ b/examples/speech_text_joint_to_text/docs/ende-mustc.md @@ -12,7 +12,7 @@ Enhanced Joint Training: the joint training is enhanced with pre-trained models, - Dictionary [dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/dict.txt) - config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/config.yaml) #### Prepare MuST-C data set -- [Please follow the data preparation in the S2T example](https://github.com/pytorch/fairseq/blob/master/examples/speech_to_text/docs/mustc_example.md) +- [Please follow the data preparation in the S2T example](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mustc_example.md) - Append src_text in the tsv file with phoneme representation. 
```bash python examples/speech_text_joint_to_text/scripts/g2p_encode.py \ @@ -24,7 +24,7 @@ Enhanced Joint Training: the joint training is enhanced with pre-trained models, - Update tsv data with src_text generated above and save to $MANIFEST_ROOT - Prepare phoneme dictionary and save to $MANIFEST_ROOT as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt) #### Prepare WMT text data -- [Download wmt data](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-wmt14en2de.sh) +- [Download wmt data](https://github.com/pytorch/fairseq/blob/main/examples/translation/prepare-wmt14en2de.sh) - Convert source text (English) into phoneme representation as above - Generate binary parallel file for training (as translation example) and save data in $parallel_text_data diff --git a/examples/speech_text_joint_to_text/docs/iwslt2021.md b/examples/speech_text_joint_to_text/docs/iwslt2021.md index 37a07c4a05..920ff271c2 100644 --- a/examples/speech_text_joint_to_text/docs/iwslt2021.md +++ b/examples/speech_text_joint_to_text/docs/iwslt2021.md @@ -11,7 +11,7 @@ This directory contains the code from paper ["FST: the FAIR Speech Translation S - Config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/config.yaml) #### Prepare -- [Please follow the data preparation in speech-to-text](https://github.com/pytorch/fairseq/blob/master/examples/speech_to_text/docs/mtedx_example.md) +- [Please follow the data preparation in speech-to-text](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mtedx_example.md) diff --git a/examples/textless_nlp/gslm/metrics/asr_metrics/README.md b/examples/textless_nlp/gslm/metrics/asr_metrics/README.md index d05bc73d0d..90741f42b0 100644 --- a/examples/textless_nlp/gslm/metrics/asr_metrics/README.md +++ b/examples/textless_nlp/gslm/metrics/asr_metrics/README.md @@ -29,7 +29,7 @@ Here `ground_truth_continuation_dev.json` is a json file 
with ground-truth text ## Running ASR We use a pre-trained wav2vec model to run the ASR step. We firstly need to prepare manifest files which, roughly, tell the ASR system which files we want to transcribe. You can find more details and download the `960h_scratch.pt` checkpoint -[[here]](https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md)). To run ASR, you would also need to +[[here]](https://github.com/pytorch/fairseq/blob/main/examples/wav2vec/README.md)). To run ASR, you would also need to install KenLM, Flashlight decoder, and download the KenLM 4-gram English language model. ```bash